diff --git a/db/user.go b/db/user.go index 6863883a..303911ea 100644 --- a/db/user.go +++ b/db/user.go @@ -112,8 +112,12 @@ func (r *userRepo) Disable(tx repo.Transaction, userID string, disable bool) err return err } - if ct, err := result.RowsAffected(); err == nil && ct == 0 { - return user.ErrorInvalidID + ct, err := result.RowsAffected() + switch { + case err != nil: + return err + case ct == 0: + return user.ErrorNotFound } return nil diff --git a/functional/repo/user_repo_test.go b/functional/repo/user_repo_test.go index 5a1ac934..f99ce80b 100644 --- a/functional/repo/user_repo_test.go +++ b/functional/repo/user_repo_test.go @@ -35,6 +35,7 @@ var ( ID: "ID-2", Email: "Email-2@example.com", CreatedAt: time.Now(), + Disabled: true, }, RemoteIdentities: []user.RemoteIdentity{ { @@ -232,6 +233,61 @@ func TestUpdateUser(t *testing.T) { } } +func TestDisableUser(t *testing.T) { + tests := []struct { + id string + disable bool + err error + }{ + { + id: "ID-1", + }, + { + id: "ID-1", + disable: true, + }, + { + id: "ID-2", + }, + { + id: "ID-2", + disable: true, + }, + { + id: "NO SUCH ID", + err: user.ErrorNotFound, + }, + { + id: "NO SUCH ID", + err: user.ErrorNotFound, + disable: true, + }, + { + id: "", + err: user.ErrorInvalidID, + }, + } + + for i, tt := range tests { + repo := makeTestUserRepo() + err := repo.Disable(nil, tt.id, tt.disable) + switch { + case err != tt.err: + t.Errorf("case %d: want=%q, got=%q", i, tt.err, err) + case tt.err == nil: + gotUser, err := repo.Get(nil, tt.id) + if err != nil { + t.Fatalf("case %d: want nil err, got %q", i, err) + } + + if gotUser.Disabled != tt.disable { + t.Errorf("case %d: disabled status want=%v got=%v", + i, tt.disable, gotUser.Disabled) + } + } + } +} + func TestAttachRemoteIdentity(t *testing.T) { tests := []struct { id string diff --git a/user/user.go b/user/user.go index 9fecb835..e771e667 100644 --- a/user/user.go +++ b/user/user.go @@ -257,6 +257,9 @@ func (r *memUserRepo) Update(_ repo.Transaction, user User) error { } func (r *memUserRepo) Disable(_ repo.Transaction, id string, disable bool) error { + if id == "" { + return ErrorInvalidID + } user, ok := r.usersByID[id] if !ok { return ErrorNotFound