package repo import ( "encoding/base64" "fmt" "net/url" "sort" "testing" "time" "github.com/coreos/go-oidc/oidc" "github.com/kylelemons/godebug/pretty" "github.com/coreos/dex/client" "github.com/coreos/dex/db" "github.com/coreos/dex/refresh" "github.com/coreos/dex/user" ) var ( testRefreshClientID = "client1" testRefreshClientID2 = "client2" testRefreshConnectorID = "IDPC-1" testRefreshClients = []client.LoadableClient{ { Client: client.Client{ Credentials: oidc.ClientCredentials{ ID: testRefreshClientID, Secret: base64.URLEncoding.EncodeToString([]byte("secret-2")), }, Metadata: oidc.ClientMetadata{ RedirectURIs: []url.URL{ url.URL{Scheme: "https", Host: "client1.example.com", Path: "/callback"}, }, }, }, }, { Client: client.Client{ Credentials: oidc.ClientCredentials{ ID: testRefreshClientID2, Secret: base64.URLEncoding.EncodeToString([]byte("secret-2")), }, Metadata: oidc.ClientMetadata{ RedirectURIs: []url.URL{ url.URL{Scheme: "https", Host: "client2.example.com", Path: "/callback"}, }, }, }, }, } testRefreshUserID = "user1" testRefreshUsers = []user.UserWithRemoteIdentities{ { User: user.User{ ID: testRefreshUserID, Email: "Email-1@example.com", CreatedAt: time.Now().Truncate(time.Second), }, RemoteIdentities: []user.RemoteIdentity{ { ConnectorID: testRefreshConnectorID, ID: "RID-1", }, }, }, } ) func newRefreshRepo(t *testing.T, users []user.UserWithRemoteIdentities, clients []client.LoadableClient) refresh.RefreshTokenRepo { dbMap := connect(t) if _, err := db.NewUserRepoFromUsers(dbMap, users); err != nil { t.Fatalf("Unable to add users: %v", err) } if _, err := db.NewClientRepoFromClients(dbMap, clients); err != nil { t.Fatalf("Unable to add clients: %v", err) } return db.NewRefreshTokenRepo(dbMap) } func TestRefreshTokenRepoCreateVerify(t *testing.T) { tests := []struct { createScopes []string verifyClientID string wantVerifyErr bool }{ { createScopes: []string{"openid", "profile"}, verifyClientID: testRefreshClientID, }, { createScopes: []string{}, verifyClientID: testRefreshClientID, }, { createScopes: []string{"openid", "profile"}, verifyClientID: "not-a-client", wantVerifyErr: true, }, } for i, tt := range tests { repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients) tok, err := repo.Create(testRefreshUserID, testRefreshClientID, testRefreshConnectorID, tt.createScopes) if err != nil { t.Fatalf("case %d: failed to create refresh token: %v", i, err) } tokUserID, gotConnectorID, gotScopes, err := repo.Verify(tt.verifyClientID, tok) if tt.wantVerifyErr { if err == nil { t.Errorf("case %d: want non-nil error.", i) } continue } if diff := pretty.Compare(tt.createScopes, gotScopes); diff != "" { t.Errorf("case %d: Compare(want, got): %v", i, diff) } if err != nil { t.Errorf("case %d: Could not verify token: %v", i, err) } else if tokUserID != testRefreshUserID { t.Errorf("case %d: Verified token returned wrong user id, want=%s, got=%s", i, testRefreshUserID, tokUserID) } if gotConnectorID != testRefreshConnectorID { t.Errorf("case %d: wanted connector_id=%q got=%q", i, testRefreshConnectorID, gotConnectorID) } } } // buildRefreshToken combines the token ID and token payload to create a new token. // used in the tests to created a refresh token. func buildRefreshToken(tokenID int64, tokenPayload []byte) string { return fmt.Sprintf("%d%s%s", tokenID, refresh.TokenDelimer, base64.URLEncoding.EncodeToString(tokenPayload)) } func TestRefreshRepoVerifyInvalidTokens(t *testing.T) { r := db.NewRefreshTokenRepo(connect(t)) token, err := r.Create("user-foo", "client-foo", testRefreshConnectorID, oidc.DefaultScope) if err != nil { t.Fatalf("Unexpected error: %v", err) } badTokenPayload, err := refresh.DefaultRefreshTokenGenerator() if err != nil { t.Fatalf("Unexpected error: %v", err) } tokenWithBadID := "404" + token[1:] tokenWithBadPayload := buildRefreshToken(1, badTokenPayload) tests := []struct { token string creds oidc.ClientCredentials err error expected string }{ { "invalid-token-format", oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"}, refresh.ErrorInvalidToken, "", }, { "b/invalid-base64-encoded-format", oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"}, refresh.ErrorInvalidToken, "", }, { "1/invalid-base64-encoded-format", oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"}, refresh.ErrorInvalidToken, "", }, { token + "corrupted-token-payload", oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"}, refresh.ErrorInvalidToken, "", }, { // The token's ID content is invalid. tokenWithBadID, oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"}, refresh.ErrorInvalidToken, "", }, { // The token's payload content is invalid. tokenWithBadPayload, oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"}, refresh.ErrorInvalidToken, "", }, { token, oidc.ClientCredentials{ID: "invalid-client", Secret: "secret-foo"}, refresh.ErrorInvalidClientID, "", }, { token, oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"}, nil, "user-foo", }, } for i, tt := range tests { result, _, _, err := r.Verify(tt.creds.ID, tt.token) if err != tt.err { t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err) } if result != tt.expected { t.Errorf("Case #%d: expected: %v, got: %v", i, tt.expected, result) } } } func TestRefreshTokenRepoClientsWithRefreshTokens(t *testing.T) { tests := []struct { clientIDs []string }{ {clientIDs: []string{"client1", "client2"}}, {clientIDs: []string{"client1"}}, {clientIDs: []string{}}, } for i, tt := range tests { repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients) for _, clientID := range tt.clientIDs { _, err := repo.Create(testRefreshUserID, clientID, testRefreshConnectorID, []string{"openid"}) if err != nil { t.Fatalf("case %d: client_id: %s couldn't create refresh token: %v", i, clientID, err) } } clients, err := repo.ClientsWithRefreshTokens(testRefreshUserID) if err != nil { t.Fatalf("case %d: unexpected error fetching clients %q", i, err) } var clientIDs []string for _, client := range clients { clientIDs = append(clientIDs, client.Credentials.ID) } sort.Strings(clientIDs) if diff := pretty.Compare(clientIDs, tt.clientIDs); diff != "" { t.Errorf("case %d: Compare(want, got): %v", i, diff) } } } func TestRefreshTokenRepoRevokeForClient(t *testing.T) { tests := []struct { createIDs []string revokeID string }{ { createIDs: []string{"client1", "client2"}, revokeID: "client1", }, { createIDs: []string{"client2"}, revokeID: "client1", }, { createIDs: []string{"client1"}, revokeID: "client1", }, { createIDs: []string{}, revokeID: "oops", }, } for i, tt := range tests { repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients) for _, clientID := range tt.createIDs { _, err := repo.Create(testRefreshUserID, clientID, testRefreshConnectorID, []string{"openid"}) if err != nil { t.Fatalf("case %d: client_id: %s couldn't create refresh token: %v", i, clientID, err) } if err := repo.RevokeTokensForClient(testRefreshUserID, tt.revokeID); err != nil { t.Fatalf("case %d: couldn't revoke refresh token(s): %v", i, err) } } var wantIDs []string for _, id := range tt.createIDs { if id != tt.revokeID { wantIDs = append(wantIDs, id) } } clients, err := repo.ClientsWithRefreshTokens(testRefreshUserID) if err != nil { t.Fatalf("case %d: unexpected error fetching clients %q", i, err) } var gotIDs []string for _, client := range clients { gotIDs = append(gotIDs, client.Credentials.ID) } sort.Strings(gotIDs) if diff := pretty.Compare(wantIDs, gotIDs); diff != "" { t.Errorf("case %d: Compare(wantIDs, gotIDs): %v", i, diff) } } } func TestRefreshRepoRevoke(t *testing.T) { r := db.NewRefreshTokenRepo(connect(t)) token, err := r.Create("user-foo", "client-foo", testRefreshConnectorID, oidc.DefaultScope) if err != nil { t.Fatalf("Unexpected error: %v", err) } badTokenPayload, err := refresh.DefaultRefreshTokenGenerator() if err != nil { t.Fatalf("Unexpected error: %v", err) } tokenWithBadID := "404" + token[1:] tokenWithBadPayload := buildRefreshToken(1, badTokenPayload) tests := []struct { token string userID string err error }{ { "invalid-token-format", "user-foo", refresh.ErrorInvalidToken, }, { "1/invalid-base64-encoded-format", "user-foo", refresh.ErrorInvalidToken, }, { token + "corrupted-token-payload", "user-foo", refresh.ErrorInvalidToken, }, { // The token's ID is invalid. tokenWithBadID, "user-foo", refresh.ErrorInvalidToken, }, { // The token's payload is invalid. tokenWithBadPayload, "user-foo", refresh.ErrorInvalidToken, }, { token, "invalid-user", refresh.ErrorInvalidUserID, }, { token, "user-foo", nil, }, } for i, tt := range tests { if err := r.Revoke(tt.userID, tt.token); err != tt.err { t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err) } } }