server: remove boilerplate setup code

Use the test fixture setup stuff in testutil instead.
This commit is contained in:
Bobby Rullo 2016-06-06 13:03:07 -07:00
parent 8d1a6f2324
commit ad1d5ab253
2 changed files with 110 additions and 412 deletions

View file

@ -10,7 +10,6 @@ import (
"time" "time"
"github.com/coreos/dex/client" "github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/db" "github.com/coreos/dex/db"
"github.com/coreos/dex/refresh/refreshtest" "github.com/coreos/dex/refresh/refreshtest"
"github.com/coreos/dex/session/manager" "github.com/coreos/dex/session/manager"
@ -22,7 +21,6 @@ import (
"github.com/kylelemons/godebug/pretty" "github.com/kylelemons/godebug/pretty"
) )
var clientTestSecret = base64.URLEncoding.EncodeToString([]byte("secret"))
var validRedirURL = url.URL{ var validRedirURL = url.URL{
Scheme: "http", Scheme: "http",
Host: "client.example.com", Host: "client.example.com",
@ -185,110 +183,42 @@ func TestServerNewSession(t *testing.T) {
} }
func TestServerLogin(t *testing.T) { func TestServerLogin(t *testing.T) {
ci := client.Client{ f, err := makeTestFixtures()
Credentials: oidc.ClientCredentials{
ID: testClientID,
Secret: clientTestSecret,
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
url.URL{
Scheme: "http",
Host: "client.example.com",
Path: "/callback",
},
},
},
}
dbm := db.NewMemDB()
clientRepo := db.NewClientRepo(dbm)
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), []client.Client{ci}, clientmanager.ManagerOptions{})
if err != nil { if err != nil {
t.Fatalf("Failed to create client identity manager: %v", err) t.Fatalf("error making test fixtures: %v", err)
} }
km := &StaticKeyManager{ sm := f.sessionManager
signer: &StaticSigner{sig: []byte("beer"), err: nil}, sessionID, err := sm.NewSession("IDPC-1", testClientID, "bogus", testRedirectURL, "", false, []string{"openid"})
}
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())) ident := oidc.Identity{ID: testUserRemoteID1, Name: "elroy", Email: testUserEmail1}
sm.GenerateCode = staticGenerateCodeFunc("fakecode")
sessionID, err := sm.NewSession("test_connector_id", ci.Credentials.ID, "bogus", ci.Metadata.RedirectURIs[0], "", false, []string{"openid"})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
userRepo, err := makeNewUserRepo()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
srv := &Server{
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
KeyManager: km,
SessionManager: sm,
ClientRepo: clientRepo,
ClientManager: clientManager,
UserRepo: userRepo,
}
ident := oidc.Identity{ID: "YYY", Name: "elroy", Email: "elroy@example.com"}
key, err := sm.NewSessionKey(sessionID) key, err := sm.NewSessionKey(sessionID)
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
redirectURL, err := srv.Login(ident, key) if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
redirectURL, err := f.srv.Login(ident, key)
if err != nil { if err != nil {
t.Fatalf("Unexpected err from Server.Login: %v", err) t.Fatalf("Unexpected err from Server.Login: %v", err)
} }
wantRedirectURL := "http://client.example.com/callback?code=fakecode&state=bogus" wantRedirectURL := "http://client.example.com/callback?code=code-3&state=bogus"
if wantRedirectURL != redirectURL { if wantRedirectURL != redirectURL {
t.Fatalf("Unexpected redirectURL: want=%q, got=%q", wantRedirectURL, redirectURL) t.Fatalf("Unexpected redirectURL: want=%q, got=%q", wantRedirectURL, redirectURL)
} }
} }
func TestServerLoginUnrecognizedSessionKey(t *testing.T) { func TestServerLoginUnrecognizedSessionKey(t *testing.T) {
clients := []client.Client{ f, err := makeTestFixtures()
client.Client{
Credentials: oidc.ClientCredentials{
ID: testClientID, Secret: clientTestSecret,
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
validRedirURL,
},
},
},
}
dbm := db.NewMemDB()
clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
clientRepo := db.NewClientRepo(dbm)
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil { if err != nil {
t.Fatalf("Failed to create client identity manager: %v", err) t.Fatalf("error making test fixtures: %v", err)
}
km := &StaticKeyManager{
signer: &StaticSigner{sig: nil, err: errors.New("fail")},
}
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
srv := &Server{
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
KeyManager: km,
SessionManager: sm,
ClientRepo: clientRepo,
ClientManager: clientManager,
} }
ident := oidc.Identity{ID: "YYY", Name: "elroy", Email: "elroy@example.com"} ident := oidc.Identity{ID: testUserRemoteID1, Name: "elroy", Email: testUserEmail1}
code, err := srv.Login(ident, testClientID) code, err := f.srv.Login(ident, testClientID)
if err == nil { if err == nil {
t.Fatalf("Expected non-nil error") t.Fatalf("Expected non-nil error")
} }
@ -299,47 +229,12 @@ func TestServerLoginUnrecognizedSessionKey(t *testing.T) {
} }
func TestServerLoginDisabledUser(t *testing.T) { func TestServerLoginDisabledUser(t *testing.T) {
ci := client.Client{ f, err := makeTestFixtures()
Credentials: oidc.ClientCredentials{
ID: testClientID,
Secret: clientTestSecret,
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
validRedirURL,
},
},
}
clients := []client.Client{ci}
dbm := db.NewMemDB()
clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
clientRepo := db.NewClientRepo(dbm)
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil { if err != nil {
t.Fatalf("Failed to create client identity manager: %v", err) t.Fatalf("error making test fixtures: %v", err)
}
km := &StaticKeyManager{
signer: &StaticSigner{sig: []byte("beer"), err: nil},
} }
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())) err = f.userRepo.Create(nil, user.User{
sm.GenerateCode = staticGenerateCodeFunc("fakecode")
sessionID, err := sm.NewSession("test_connector_id", ci.Credentials.ID, "bogus", ci.Metadata.RedirectURIs[0], "", false, []string{"openid"})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
userRepo, err := makeNewUserRepo()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
err = userRepo.Create(nil, user.User{
ID: "disabled-1", ID: "disabled-1",
Email: "disabled@example.com", Email: "disabled@example.com",
Disabled: true, Disabled: true,
@ -348,79 +243,29 @@ func TestServerLoginDisabledUser(t *testing.T) {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
err = userRepo.AddRemoteIdentity(nil, "disabled-1", user.RemoteIdentity{ err = f.userRepo.AddRemoteIdentity(nil, "disabled-1", user.RemoteIdentity{
ConnectorID: "test_connector_id", ConnectorID: "test_connector_id",
ID: "disabled-connector-id", ID: "disabled-connector-id",
}) })
srv := &Server{ sessionID, err := f.sessionManager.NewSession("test_connector_id", testClientID, "bogus", testRedirectURL, "", false, []string{"openid"})
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
KeyManager: km,
SessionManager: sm,
ClientRepo: clientRepo,
ClientManager: clientManager,
UserRepo: userRepo,
}
ident := oidc.Identity{ID: "disabled-connector-id", Name: "elroy", Email: "elroy@example.com"}
key, err := sm.NewSessionKey(sessionID)
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
_, err = srv.Login(ident, key) ident := oidc.Identity{ID: "disabled-connector-id", Name: "elroy", Email: "elroy@example.com"}
key, err := f.sessionManager.NewSessionKey(sessionID)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
_, err = f.srv.Login(ident, key)
if err == nil { if err == nil {
t.Errorf("disabled user was allowed to log in") t.Errorf("disabled user was allowed to log in")
} }
} }
func TestServerCodeToken(t *testing.T) { func TestServerCodeToken(t *testing.T) {
ci := client.Client{
Credentials: oidc.ClientCredentials{
ID: testClientID,
Secret: clientTestSecret,
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
validRedirURL,
},
},
}
clients := []client.Client{ci}
dbm := db.NewMemDB()
clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
clientRepo := db.NewClientRepo(dbm)
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil {
t.Fatalf("Failed to create client identity manager: %v", err)
}
km := &StaticKeyManager{
signer: &StaticSigner{sig: []byte("beer"), err: nil},
}
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
userRepo, err := makeNewUserRepo()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()
srv := &Server{
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
KeyManager: km,
SessionManager: sm,
ClientRepo: clientRepo,
ClientManager: clientManager,
UserRepo: userRepo,
RefreshTokenRepo: refreshTokenRepo,
}
tests := []struct { tests := []struct {
scope []string scope []string
refreshToken string refreshToken string
@ -440,7 +285,14 @@ func TestServerCodeToken(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
sessionID, err := sm.NewSession("bogus_idpc", ci.Credentials.ID, "bogus", url.URL{}, "", false, tt.scope) f, err := makeTestFixtures()
if err != nil {
t.Fatalf("error making test fixtures: %v", err)
}
f.srv.RefreshTokenRepo = refreshtest.NewTestRefreshTokenRepo()
sm := f.sessionManager
sessionID, err := sm.NewSession("bogus_idpc", testClientID, "bogus", url.URL{}, "", false, tt.scope)
if err != nil { if err != nil {
t.Fatalf("case %d: unexpected error: %v", i, err) t.Fatalf("case %d: unexpected error: %v", i, err)
} }
@ -449,7 +301,7 @@ func TestServerCodeToken(t *testing.T) {
t.Fatalf("case %d: unexpected error: %v", i, err) t.Fatalf("case %d: unexpected error: %v", i, err)
} }
_, err = sm.AttachUser(sessionID, "testid-1") _, err = sm.AttachUser(sessionID, testUserID1)
if err != nil { if err != nil {
t.Fatalf("case %d: unexpected error: %v", i, err) t.Fatalf("case %d: unexpected error: %v", i, err)
} }
@ -459,7 +311,11 @@ func TestServerCodeToken(t *testing.T) {
t.Fatalf("case %d: unexpected error: %v", i, err) t.Fatalf("case %d: unexpected error: %v", i, err)
} }
jwt, token, err := srv.CodeToken(ci.Credentials, key) jwt, token, err := f.srv.CodeToken(
oidc.ClientCredentials{
ID: testClientID,
Secret: clientTestSecret,
}, key)
if err != nil { if err != nil {
t.Fatalf("case %d: unexpected error: %v", i, err) t.Fatalf("case %d: unexpected error: %v", i, err)
} }
@ -473,45 +329,13 @@ func TestServerCodeToken(t *testing.T) {
} }
func TestServerTokenUnrecognizedKey(t *testing.T) { func TestServerTokenUnrecognizedKey(t *testing.T) {
ci := client.Client{ f, err := makeTestFixtures()
Credentials: oidc.ClientCredentials{
ID: testClientID,
Secret: clientTestSecret,
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
validRedirURL,
},
},
}
clients := []client.Client{ci}
dbm := db.NewMemDB()
clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
clientRepo := db.NewClientRepo(dbm)
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil { if err != nil {
t.Fatalf("Failed to create client identity manager: %v", err) t.Fatalf("error making test fixtures: %v", err)
} }
km := &StaticKeyManager{ sm := f.sessionManager
signer: &StaticSigner{sig: []byte("beer"), err: nil},
}
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
srv := &Server{ sessionID, err := sm.NewSession("connector_id", testClientID, "bogus", url.URL{}, "", false, []string{"openid", "offline_access"})
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
KeyManager: km,
SessionManager: sm,
ClientRepo: clientRepo,
ClientManager: clientManager,
}
sessionID, err := sm.NewSession("connector_id", ci.Credentials.ID, "bogus", url.URL{}, "", false, []string{"openid", "offline_access"})
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
@ -521,7 +345,7 @@ func TestServerTokenUnrecognizedKey(t *testing.T) {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
jwt, token, err := srv.CodeToken(ci.Credentials, "foo") jwt, token, err := f.srv.CodeToken(testClientCredentials, "foo")
if err == nil { if err == nil {
t.Fatalf("Expected non-nil error") t.Fatalf("Expected non-nil error")
} }
@ -534,12 +358,8 @@ func TestServerTokenUnrecognizedKey(t *testing.T) {
} }
func TestServerTokenFail(t *testing.T) { func TestServerTokenFail(t *testing.T) {
issuerURL := url.URL{Scheme: "http", Host: "server.example.com"}
keyFixture := "goodkey" keyFixture := "goodkey"
ccFixture := oidc.ClientCredentials{
ID: testClientID,
Secret: clientTestSecret,
}
signerFixture := &StaticSigner{sig: []byte("beer"), err: nil} signerFixture := &StaticSigner{sig: []byte("beer"), err: nil}
tests := []struct { tests := []struct {
@ -555,7 +375,7 @@ func TestServerTokenFail(t *testing.T) {
// NOTE(ericchiang): This test assumes that the database ID of the first // NOTE(ericchiang): This test assumes that the database ID of the first
// refresh token will be "1". // refresh token will be "1".
signer: signerFixture, signer: signerFixture,
argCC: ccFixture, argCC: testClientCredentials,
argKey: keyFixture, argKey: keyFixture,
scope: []string{"openid", "offline_access"}, scope: []string{"openid", "offline_access"},
refreshToken: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), refreshToken: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
@ -564,7 +384,7 @@ func TestServerTokenFail(t *testing.T) {
// no 'offline_access' in 'scope', should get empty refresh token // no 'offline_access' in 'scope', should get empty refresh token
{ {
signer: signerFixture, signer: signerFixture,
argCC: ccFixture, argCC: testClientCredentials,
argKey: keyFixture, argKey: keyFixture,
scope: []string{"openid"}, scope: []string{"openid"},
}, },
@ -572,7 +392,7 @@ func TestServerTokenFail(t *testing.T) {
// unrecognized key // unrecognized key
{ {
signer: signerFixture, signer: signerFixture,
argCC: ccFixture, argCC: testClientCredentials,
argKey: "foo", argKey: "foo",
err: oauth2.NewError(oauth2.ErrorInvalidGrant), err: oauth2.NewError(oauth2.ErrorInvalidGrant),
scope: []string{"openid", "offline_access"}, scope: []string{"openid", "offline_access"},
@ -590,7 +410,7 @@ func TestServerTokenFail(t *testing.T) {
// signing operation fails // signing operation fails
{ {
signer: &StaticSigner{sig: nil, err: errors.New("fail")}, signer: &StaticSigner{sig: nil, err: errors.New("fail")},
argCC: ccFixture, argCC: testClientCredentials,
argKey: keyFixture, argKey: keyFixture,
err: oauth2.NewError(oauth2.ErrorServerError), err: oauth2.NewError(oauth2.ErrorServerError),
scope: []string{"openid", "offline_access"}, scope: []string{"openid", "offline_access"},
@ -598,10 +418,19 @@ func TestServerTokenFail(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
sm.GenerateCode = func() (string, error) { return keyFixture, nil }
sessionID, err := sm.NewSession("connector_id", ccFixture.ID, "bogus", url.URL{}, "", false, tt.scope) f, err := makeTestFixtures()
if err != nil {
t.Fatalf("error making test fixtures: %v", err)
}
sm := f.sessionManager
sm.GenerateCode = func() (string, error) { return keyFixture, nil }
f.srv.RefreshTokenRepo = refreshtest.NewTestRefreshTokenRepo()
f.srv.KeyManager = &StaticKeyManager{
signer: tt.signer,
}
sessionID, err := sm.NewSession(testConnectorID1, testClientID, "bogus", url.URL{}, "", false, tt.scope)
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
@ -611,60 +440,17 @@ func TestServerTokenFail(t *testing.T) {
t.Errorf("case %d: unexpected error: %v", i, err) t.Errorf("case %d: unexpected error: %v", i, err)
continue continue
} }
km := &StaticKeyManager{ _, err = sm.AttachUser(sessionID, testUserID1)
signer: tt.signer,
}
clients := []client.Client{
client.Client{
Credentials: ccFixture,
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
validRedirURL,
},
},
},
}
dbm := db.NewMemDB()
clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
clientRepo := db.NewClientRepo(dbm)
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil {
t.Fatalf("Failed to create client identity manager: %v", err)
}
_, err = sm.AttachUser(sessionID, "testid-1")
if err != nil { if err != nil {
t.Fatalf("case %d: unexpected error: %v", i, err) t.Fatalf("case %d: unexpected error: %v", i, err)
} }
userRepo, err := makeNewUserRepo()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()
srv := &Server{
IssuerURL: issuerURL,
KeyManager: km,
SessionManager: sm,
ClientRepo: clientRepo,
ClientManager: clientManager,
UserRepo: userRepo,
RefreshTokenRepo: refreshTokenRepo,
}
_, err = sm.NewSessionKey(sessionID) _, err = sm.NewSessionKey(sessionID)
if err != nil { if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
jwt, token, err := srv.CodeToken(tt.argCC, tt.argKey) jwt, token, err := f.srv.CodeToken(tt.argCC, tt.argKey)
if token != tt.refreshToken { if token != tt.refreshToken {
fmt.Printf("case %d: expect refresh token %q, got %q\n", i, tt.refreshToken, token) fmt.Printf("case %d: expect refresh token %q, got %q\n", i, tt.refreshToken, token)
t.Fatalf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token) t.Fatalf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token)
@ -683,18 +469,7 @@ func TestServerTokenFail(t *testing.T) {
} }
func TestServerRefreshToken(t *testing.T) { func TestServerRefreshToken(t *testing.T) {
issuerURL := url.URL{Scheme: "http", Host: "server.example.com"}
clientA := client.Client{
Credentials: oidc.ClientCredentials{
ID: testClientID,
Secret: clientTestSecret,
},
Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{
url.URL{Scheme: "https", Host: "client.example.com", Path: "one/two/three"},
},
},
}
clientB := client.Client{ clientB := client.Client{
Credentials: oidc.ClientCredentials{ Credentials: oidc.ClientCredentials{
ID: "example2.com", ID: "example2.com",
@ -706,7 +481,6 @@ func TestServerRefreshToken(t *testing.T) {
}, },
}, },
} }
signerFixture := &StaticSigner{sig: []byte("beer"), err: nil} signerFixture := &StaticSigner{sig: []byte("beer"), err: nil}
// NOTE(ericchiang): These tests assume that the database ID of the first // NOTE(ericchiang): These tests assume that the database ID of the first
@ -721,39 +495,39 @@ func TestServerRefreshToken(t *testing.T) {
// Everything is good. // Everything is good.
{ {
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
clientA.Credentials.ID, testClientID,
clientA.Credentials, testClientCredentials,
signerFixture, signerFixture,
nil, nil,
}, },
// Invalid refresh token(malformatted). // Invalid refresh token(malformatted).
{ {
"invalid-token", "invalid-token",
clientA.Credentials.ID, testClientID,
clientA.Credentials, testClientCredentials,
signerFixture, signerFixture,
oauth2.NewError(oauth2.ErrorInvalidRequest), oauth2.NewError(oauth2.ErrorInvalidRequest),
}, },
// Invalid refresh token(invalid payload content). // Invalid refresh token(invalid payload content).
{ {
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))), fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))),
clientA.Credentials.ID, testClientID,
clientA.Credentials, testClientCredentials,
signerFixture, signerFixture,
oauth2.NewError(oauth2.ErrorInvalidRequest), oauth2.NewError(oauth2.ErrorInvalidRequest),
}, },
// Invalid refresh token(invalid ID content). // Invalid refresh token(invalid ID content).
{ {
fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
clientA.Credentials.ID, testClientID,
clientA.Credentials, testClientCredentials,
signerFixture, signerFixture,
oauth2.NewError(oauth2.ErrorInvalidRequest), oauth2.NewError(oauth2.ErrorInvalidRequest),
}, },
// Invalid client(client is not associated with the token). // Invalid client(client is not associated with the token).
{ {
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
clientA.Credentials.ID, testClientID,
clientB.Credentials, clientB.Credentials,
signerFixture, signerFixture,
oauth2.NewError(oauth2.ErrorInvalidClient), oauth2.NewError(oauth2.ErrorInvalidClient),
@ -761,7 +535,7 @@ func TestServerRefreshToken(t *testing.T) {
// Invalid client(no client ID). // Invalid client(no client ID).
{ {
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
clientA.Credentials.ID, testClientID,
oidc.ClientCredentials{ID: "", Secret: "aaa"}, oidc.ClientCredentials{ID: "", Secret: "aaa"},
signerFixture, signerFixture,
oauth2.NewError(oauth2.ErrorInvalidClient), oauth2.NewError(oauth2.ErrorInvalidClient),
@ -769,7 +543,7 @@ func TestServerRefreshToken(t *testing.T) {
// Invalid client(no such client). // Invalid client(no such client).
{ {
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
clientA.Credentials.ID, testClientID,
oidc.ClientCredentials{ID: "AAA", Secret: "aaa"}, oidc.ClientCredentials{ID: "AAA", Secret: "aaa"},
signerFixture, signerFixture,
oauth2.NewError(oauth2.ErrorInvalidClient), oauth2.NewError(oauth2.ErrorInvalidClient),
@ -777,7 +551,7 @@ func TestServerRefreshToken(t *testing.T) {
// Invalid client(no secrets). // Invalid client(no secrets).
{ {
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
clientA.Credentials.ID, testClientID,
oidc.ClientCredentials{ID: testClientID}, oidc.ClientCredentials{ID: testClientID},
signerFixture, signerFixture,
oauth2.NewError(oauth2.ErrorInvalidClient), oauth2.NewError(oauth2.ErrorInvalidClient),
@ -785,7 +559,7 @@ func TestServerRefreshToken(t *testing.T) {
// Invalid client(invalid secret). // Invalid client(invalid secret).
{ {
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
clientA.Credentials.ID, testClientID,
oidc.ClientCredentials{ID: "bad-id", Secret: "bad-secret"}, oidc.ClientCredentials{ID: "bad-id", Secret: "bad-secret"},
signerFixture, signerFixture,
oauth2.NewError(oauth2.ErrorInvalidClient), oauth2.NewError(oauth2.ErrorInvalidClient),
@ -793,8 +567,8 @@ func TestServerRefreshToken(t *testing.T) {
// Signing operation fails. // Signing operation fails.
{ {
fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))), fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
clientA.Credentials.ID, testClientID,
clientA.Credentials, testClientCredentials,
&StaticSigner{sig: nil, err: errors.New("fail")}, &StaticSigner{sig: nil, err: errors.New("fail")},
oauth2.NewError(oauth2.ErrorServerError), oauth2.NewError(oauth2.ErrorServerError),
}, },
@ -804,45 +578,22 @@ func TestServerRefreshToken(t *testing.T) {
km := &StaticKeyManager{ km := &StaticKeyManager{
signer: tt.signer, signer: tt.signer,
} }
f, err := makeTestFixtures()
clients := []client.Client{ if err != nil {
clientA, t.Fatalf("error making test fixtures: %v", err)
clientB, }
f.srv.RefreshTokenRepo = refreshtest.NewTestRefreshTokenRepo()
f.srv.KeyManager = km
_, err = f.clientRepo.New(nil, clientB)
if err != nil {
t.Errorf("case %d: error creating other client: %v", i, err)
} }
clientIDGenerator := func(hostport string) (string, error) { if _, err := f.srv.RefreshTokenRepo.Create(testUserID1, tt.clientID); err != nil {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
dbm := db.NewMemDB()
clientRepo := db.NewClientRepo(dbm)
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil {
t.Fatalf("Failed to create client identity manager: %v", err)
}
userRepo, err := makeNewUserRepo()
if err != nil {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo() jwt, err := f.srv.RefreshToken(tt.creds, tt.token)
srv := &Server{
IssuerURL: issuerURL,
KeyManager: km,
ClientRepo: clientRepo,
ClientManager: clientManager,
UserRepo: userRepo,
RefreshTokenRepo: refreshTokenRepo,
}
if _, err := refreshTokenRepo.Create("testid-1", tt.clientID); err != nil {
t.Fatalf("Unexpected error: %v", err)
}
jwt, err := srv.RefreshToken(tt.creds, tt.token)
if !reflect.DeepEqual(err, tt.err) { if !reflect.DeepEqual(err, tt.err) {
t.Errorf("Case %d: expect: %v, got: %v", i, tt.err, err) t.Errorf("Case %d: expect: %v, got: %v", i, tt.err, err)
} }
@ -855,71 +606,9 @@ func TestServerRefreshToken(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("Case %d: unexpected error: %v", i, err) t.Errorf("Case %d: unexpected error: %v", i, err)
} }
if claims["iss"] != issuerURL.String() || claims["sub"] != "testid-1" || claims["aud"] != testClientID { if claims["iss"] != testIssuerURL.String() || claims["sub"] != testUserID1 || claims["aud"] != testClientID {
t.Errorf("Case %d: invalid claims: %v", i, claims) t.Errorf("Case %d: invalid claims: %v", i, claims)
} }
} }
} }
// Test that we should return error when user cannot be found after
// verifying the token.
km := &StaticKeyManager{
signer: signerFixture,
}
clients := []client.Client{
clientA,
clientB,
}
clientIDGenerator := func(hostport string) (string, error) {
return hostport, nil
}
secGen := func() ([]byte, error) {
return []byte("secret"), nil
}
dbm := db.NewMemDB()
clientRepo := db.NewClientRepo(dbm)
clientManager, err := clientmanager.NewClientManagerFromClients(clientRepo, db.TransactionFactory(dbm), clients, clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen})
if err != nil {
t.Fatalf("Failed to create client identity manager: %v", err)
}
userRepo, err := makeNewUserRepo()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
// Create a user that will be removed later.
if err := userRepo.Create(nil, user.User{
ID: "testid-2",
Email: "test-2@example.com",
}); err != nil {
t.Fatalf("Unexpected error: %v", err)
}
refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()
srv := &Server{
IssuerURL: issuerURL,
KeyManager: km,
ClientRepo: clientRepo,
ClientManager: clientManager,
UserRepo: userRepo,
RefreshTokenRepo: refreshTokenRepo,
}
if _, err := refreshTokenRepo.Create("testid-2", clientA.Credentials.ID); err != nil {
t.Fatalf("Unexpected error: %v", err)
}
// Recreate the user repo to remove the user we created.
userRepo, err = makeNewUserRepo()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
srv.UserRepo = userRepo
_, err = srv.RefreshToken(clientA.Credentials, fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))))
if !reflect.DeepEqual(err, oauth2.NewError(oauth2.ErrorServerError)) {
t.Errorf("Expect: %v, got: %v", oauth2.NewError(oauth2.ErrorServerError), err)
}
} }

View file

@ -26,21 +26,33 @@ const (
) )
var ( var (
testUserID1 = "ID-1"
testUserEmail1 = "Email-1@example.com"
testUserRemoteID1 = "RID-1"
testIssuerURL = url.URL{Scheme: "http", Host: "server.example.com"} testIssuerURL = url.URL{Scheme: "http", Host: "server.example.com"}
testClientID = "client.example.com" testClientID = "client.example.com"
clientTestSecret = base64.URLEncoding.EncodeToString([]byte("secret"))
testClientCredentials = oidc.ClientCredentials{
ID: testClientID,
Secret: clientTestSecret,
}
testConnectorID1 = "IDPC-1"
testRedirectURL = url.URL{Scheme: "http", Host: "client.example.com", Path: "/callback"} testRedirectURL = url.URL{Scheme: "http", Host: "client.example.com", Path: "/callback"}
testUsers = []user.UserWithRemoteIdentities{ testUsers = []user.UserWithRemoteIdentities{
{ {
User: user.User{ User: user.User{
ID: "ID-1", ID: testUserID1,
Email: "Email-1@example.com", Email: testUserEmail1,
}, },
RemoteIdentities: []user.RemoteIdentity{ RemoteIdentities: []user.RemoteIdentity{
{ {
ConnectorID: "IDPC-1", ConnectorID: testConnectorID1,
ID: "RID-1", ID: testUserRemoteID1,
}, },
}, },
}, },
@ -140,10 +152,7 @@ func makeTestFixtures() (*testFixtures, error) {
clients := []client.Client{ clients := []client.Client{
client.Client{ client.Client{
Credentials: oidc.ClientCredentials{ Credentials: testClientCredentials,
ID: testClientID,
Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
},
Metadata: oidc.ClientMetadata{ Metadata: oidc.ClientMetadata{
RedirectURIs: []url.URL{ RedirectURIs: []url.URL{
testRedirectURL, testRedirectURL,