diff --git a/integration/oidc_test.go b/integration/oidc_test.go index 97a728b0..51bf0288 100644 --- a/integration/oidc_test.go +++ b/integration/oidc_test.go @@ -196,7 +196,7 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) { // this will actually happen due to some interaction between the // end-user and a remote identity provider - sessionID, err := sm.NewSession("bogus_idpc", ci.Credentials.ID, "bogus", url.URL{}, "", false, nil) + sessionID, err := sm.NewSession("bogus_idpc", ci.Credentials.ID, "bogus", url.URL{}, "", false, []string{"openid", "offline_access"}) if err != nil { t.Fatalf("Unexpected error: %v", err) } diff --git a/server/http.go b/server/http.go index 073f8b19..29cac285 100644 --- a/server/http.go +++ b/server/http.go @@ -330,6 +330,34 @@ func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.T return } + // Check scopes. + var scopes []string + foundOpenIDScope := false + for _, scope := range acr.Scope { + switch scope { + case "openid": + foundOpenIDScope = true + scopes = append(scopes, scope) + case "offline_access": + // According to the spec, for offline_access scope, the client must + // use a response_type value that would result in an Authorization Code. + // Currently oauth2.ResponseTypeCode is the only supported response type, + // and it's been checked above, so we don't need to check it again here. + // + // TODO(yifan): Verify that 'consent' should be in 'prompt'. + scopes = append(scopes, scope) + default: + // Pass all other scopes. + scopes = append(scopes, scope) + } + } + + if !foundOpenIDScope { + log.Errorf("Invalid auth request: missing 'openid' in 'scope'") + writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State) + return + } + nonce := q.Get("nonce") key, err := srv.NewSession(connectorID, acr.ClientID, acr.State, redirectURL, nonce, register, acr.Scope) diff --git a/server/http_test.go b/server/http_test.go index 3018dcfb..4f812204 100644 --- a/server/http_test.go +++ b/server/http_test.go @@ -102,6 +102,7 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) { "response_type": []string{"code"}, "client_id": []string{"XXX"}, "connector_id": []string{"fake"}, + "scope": []string{"openid"}, }, wantCode: http.StatusTemporaryRedirect, wantLocation: "http://fake.example.com", @@ -114,6 +115,7 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) { "redirect_uri": []string{"http://client.example.com/callback"}, "client_id": []string{"XXX"}, "connector_id": []string{"fake"}, + "scope": []string{"openid"}, }, wantCode: http.StatusTemporaryRedirect, wantLocation: "http://fake.example.com", @@ -126,6 +128,7 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) { "redirect_uri": []string{"http://unrecognized.example.com/callback"}, "client_id": []string{"XXX"}, "connector_id": []string{"fake"}, + "scope": []string{"openid"}, }, wantCode: http.StatusBadRequest, }, @@ -137,6 +140,7 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) { "redirect_uri": []string{"http://client.example.com/callback"}, "client_id": []string{"YYY"}, "connector_id": []string{"fake"}, + "scope": []string{"openid"}, }, wantCode: http.StatusBadRequest, }, @@ -147,10 +151,22 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) { "response_type": []string{"token"}, "client_id": []string{"XXX"}, "connector_id": []string{"fake"}, + "scope": []string{"openid"}, }, wantCode: http.StatusTemporaryRedirect, wantLocation: "http://client.example.com/callback?error=unsupported_response_type&state=", }, + + // no 'openid' in scope + { + query: url.Values{ + "response_type": []string{"code"}, + "redirect_uri": []string{"http://client.example.com/callback"}, + "client_id": []string{"XXX"}, + "connector_id": []string{"fake"}, + }, + wantCode: http.StatusBadRequest, + }, } for i, tt := range tests { @@ -211,6 +227,7 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) { "redirect_uri": []string{"http://foo.example.com/callback"}, "client_id": []string{"XXX"}, "connector_id": []string{"fake"}, + "scope": []string{"openid"}, }, wantCode: http.StatusTemporaryRedirect, wantLocation: "http://fake.example.com", @@ -223,6 +240,7 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) { "redirect_uri": []string{"http://bar.example.com/callback"}, "client_id": []string{"XXX"}, "connector_id": []string{"fake"}, + "scope": []string{"openid"}, }, wantCode: http.StatusTemporaryRedirect, wantLocation: "http://fake.example.com", @@ -235,6 +253,7 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) { "redirect_uri": []string{"http://unrecognized.example.com/callback"}, "client_id": []string{"XXX"}, "connector_id": []string{"fake"}, + "scope": []string{"openid"}, }, wantCode: http.StatusBadRequest, }, @@ -245,6 +264,7 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) { "response_type": []string{"code"}, "client_id": []string{"XXX"}, "connector_id": []string{"fake"}, + "scope": []string{"openid"}, }, wantCode: http.StatusBadRequest, }, diff --git a/server/password_test.go b/server/password_test.go index 4e344dc2..ff14dfdd 100644 --- a/server/password_test.go +++ b/server/password_test.go @@ -245,7 +245,7 @@ func TestSendResetPasswordEmailHandler(t *testing.T) { t.Fatalf("case %d: could not make test fixtures: %v", i, err) } - _, err = f.srv.NewSession("local", "XXX", "", f.redirectURL, "", true, nil) + _, err = f.srv.NewSession("local", "XXX", "", f.redirectURL, "", true, []string{"openid"}) if err != nil { t.Fatalf("case %d: could not create new session: %v", i, err) } diff --git a/server/register_test.go b/server/register_test.go index 7b517b17..fe2e5756 100644 --- a/server/register_test.go +++ b/server/register_test.go @@ -197,7 +197,7 @@ func TestHandleRegister(t *testing.T) { t.Fatalf("case %d: could not make test fixtures: %v", i, err) } - key, err := f.srv.NewSession(tt.connID, "XXX", "", f.redirectURL, "", true, nil) + key, err := f.srv.NewSession(tt.connID, "XXX", "", f.redirectURL, "", true, []string{"openid"}) t.Logf("case %d: key for NewSession: %v", i, key) if tt.attachRemote { diff --git a/server/server.go b/server/server.go index ac9a3531..67bbb517 100644 --- a/server/server.go +++ b/server/server.go @@ -422,20 +422,26 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo return nil, "", oauth2.NewError(oauth2.ErrorServerError) } - log.Infof("Session %s token sent: clientID=%s", sessionID, creds.ID) + // Generate refresh token when 'scope' contains 'offline_access'. + var refreshToken string - // Generate refresh token. - // - // TODO(yifan): Return refresh token only when 'access_type == offline', - // or 'scope' == 'offline_access'. - refreshToken, err := s.RefreshTokenRepo.Create(ses.UserID, creds.ID) - switch err { - case nil: - break - default: - log.Errorf("Failed to generate refresh token: %v", err) - return nil, "", oauth2.NewError(oauth2.ErrorServerError) + for _, scope := range ses.Scope { + if scope == "offline_access" { + log.Infof("Session %s requests offline access, will generate refresh token", sessionID) + + refreshToken, err = s.RefreshTokenRepo.Create(ses.UserID, creds.ID) + switch err { + case nil: + break + default: + log.Errorf("Failed to generate refresh token: %v", err) + return nil, "", oauth2.NewError(oauth2.ErrorServerError) + } + break + } } + + log.Infof("Session %s token sent: clientID=%s", sessionID, creds.ID) return jwt, refreshToken, nil } @@ -487,7 +493,7 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose return nil, oauth2.NewError(oauth2.ErrorServerError) } - log.Infof("Token refreshed sent: clientID=%s", creds.ID) + log.Infof("New token sent: clientID=%s", creds.ID) return jwt, nil } diff --git a/server/server_test.go b/server/server_test.go index 37c397d0..49e36846 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -2,6 +2,7 @@ package server import ( "errors" + "fmt" "net/url" "reflect" "testing" @@ -139,7 +140,7 @@ func TestServerNewSession(t *testing.T) { }, } - key, err := srv.NewSession("bogus_idpc", ci.Credentials.ID, state, ci.Metadata.RedirectURLs[0], nonce, false, nil) + key, err := srv.NewSession("bogus_idpc", ci.Credentials.ID, state, ci.Metadata.RedirectURLs[0], nonce, false, []string{"openid"}) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -195,7 +196,7 @@ func TestServerLogin(t *testing.T) { sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) sm.GenerateCode = staticGenerateCodeFunc("fakecode") - sessionID, err := sm.NewSession("test_connector_id", ci.Credentials.ID, "bogus", ci.Metadata.RedirectURLs[0], "", false, nil) + sessionID, err := sm.NewSession("test_connector_id", ci.Credentials.ID, "bogus", ci.Metadata.RedirectURLs[0], "", false, []string{"openid"}) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -292,34 +293,52 @@ func TestServerCodeToken(t *testing.T) { RefreshTokenRepo: refreshTokenRepo, } - sessionID, err := sm.NewSession("bogus_idpc", ci.Credentials.ID, "bogus", url.URL{}, "", false, nil) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - _, err = sm.AttachRemoteIdentity(sessionID, oidc.Identity{}) - if err != nil { - t.Fatalf("Unexpected error: %v", err) + tests := []struct { + scope []string + refreshToken string + }{ + // No 'offline_access' in scope, should get empty refresh token. + { + scope: []string{"openid"}, + refreshToken: "", + }, + // Have 'offline_access' in scope, should get non-empty refresh token. + { + scope: []string{"openid", "offline_access"}, + refreshToken: "0/refresh-1", + }, } - _, err = sm.AttachUser(sessionID, "testid-1") - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } + for i, tt := range tests { + sessionID, err := sm.NewSession("bogus_idpc", ci.Credentials.ID, "bogus", url.URL{}, "", false, tt.scope) + if err != nil { + t.Fatalf("case %d: unexpected error: %v", i, err) + } + _, err = sm.AttachRemoteIdentity(sessionID, oidc.Identity{}) + if err != nil { + t.Fatalf("case %d: unexpected error: %v", i, err) + } - key, err := sm.NewSessionKey(sessionID) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } + _, err = sm.AttachUser(sessionID, "testid-1") + if err != nil { + t.Fatalf("case %d: unexpected error: %v", i, err) + } - jwt, token, err := srv.CodeToken(ci.Credentials, key) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - if jwt == nil { - t.Fatalf("Expected non-nil jwt") - } - if token == "" { - t.Fatalf("Expected non-empty refresh token") + key, err := sm.NewSessionKey(sessionID) + if err != nil { + t.Fatalf("case %d: unexpected error: %v", i, err) + } + + jwt, token, err := srv.CodeToken(ci.Credentials, key) + if err != nil { + t.Fatalf("case %d: unexpected error: %v", i, err) + } + if jwt == nil { + t.Fatalf("case %d: expect non-nil jwt", i) + } + if token != tt.refreshToken { + t.Fatalf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token) + } } } @@ -343,7 +362,7 @@ func TestServerTokenUnrecognizedKey(t *testing.T) { ClientIdentityRepo: ciRepo, } - sessionID, err := sm.NewSession("connector_id", ci.Credentials.ID, "bogus", url.URL{}, "", false, nil) + sessionID, err := sm.NewSession("connector_id", ci.Credentials.ID, "bogus", url.URL{}, "", false, []string{"openid", "offline_access"}) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -375,16 +394,28 @@ func TestServerTokenFail(t *testing.T) { signerFixture := &StaticSigner{sig: []byte("beer"), err: nil} tests := []struct { - signer jose.Signer - argCC oidc.ClientCredentials - argKey string - err string + signer jose.Signer + argCC oidc.ClientCredentials + argKey string + err string + scope []string + refreshToken string }{ // control test case to make sure fixtures check out + { + signer: signerFixture, + argCC: ccFixture, + argKey: keyFixture, + scope: []string{"openid", "offline_access"}, + refreshToken: "0/refresh-1", + }, + + // no 'offline_access' in 'scope', should get empty refresh token { signer: signerFixture, argCC: ccFixture, argKey: keyFixture, + scope: []string{"openid"}, }, // unrecognized key @@ -393,6 +424,7 @@ func TestServerTokenFail(t *testing.T) { argCC: ccFixture, argKey: "foo", err: oauth2.ErrorInvalidGrant, + scope: []string{"openid", "offline_access"}, }, // unrecognized client @@ -401,6 +433,7 @@ func TestServerTokenFail(t *testing.T) { argCC: oidc.ClientCredentials{ID: "YYY"}, argKey: keyFixture, err: oauth2.ErrorInvalidClient, + scope: []string{"openid", "offline_access"}, }, // signing operation fails @@ -409,6 +442,7 @@ func TestServerTokenFail(t *testing.T) { argCC: ccFixture, argKey: keyFixture, err: oauth2.ErrorServerError, + scope: []string{"openid", "offline_access"}, }, } @@ -416,7 +450,7 @@ func TestServerTokenFail(t *testing.T) { sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()) sm.GenerateCode = func() (string, error) { return keyFixture, nil } - sessionID, err := sm.NewSession("connector_id", ccFixture.ID, "bogus", url.URL{}, "", false, nil) + sessionID, err := sm.NewSession("connector_id", ccFixture.ID, "bogus", url.URL{}, "", false, tt.scope) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -435,7 +469,7 @@ func TestServerTokenFail(t *testing.T) { _, err = sm.AttachUser(sessionID, "testid-1") if err != nil { - t.Fatalf("case %d: Unexpected error: %v", i, err) + t.Fatalf("case %d: unexpected error: %v", i, err) } userRepo, err := makeNewUserRepo() @@ -463,22 +497,22 @@ func TestServerTokenFail(t *testing.T) { } jwt, token, err := srv.CodeToken(tt.argCC, tt.argKey) + if token != tt.refreshToken { + fmt.Printf("case %d: expect refresh token %q, got %q\n", i, tt.refreshToken, token) + t.Fatalf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token) + panic("") + } if tt.err == "" { if err != nil { t.Errorf("case %d: got non-nil error: %v", i, err) } else if jwt == nil { t.Errorf("case %d: got nil JWT", i) - } else if token == "" { - t.Errorf("case %d: got empty refresh token", i) } - } else { if err.Error() != tt.err { t.Errorf("case %d: want err %q, got %q", i, tt.err, err.Error()) } else if jwt != nil { t.Errorf("case %d: got non-nil JWT", i) - } else if token != "" { - t.Errorf("case %d: got non-empty refresh token", i) } } } diff --git a/session/manager_test.go b/session/manager_test.go index 3f8ad8d2..4e925ec1 100644 --- a/session/manager_test.go +++ b/session/manager_test.go @@ -16,7 +16,7 @@ func staticGenerateCodeFunc(code string) GenerateCodeFunc { func TestSessionManagerNewSession(t *testing.T) { sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo()) sm.GenerateCode = staticGenerateCodeFunc("boo") - got, err := sm.NewSession("bogus_idpc", "XXX", "bogus", url.URL{}, "", false, nil) + got, err := sm.NewSession("bogus_idpc", "XXX", "bogus", url.URL{}, "", false, []string{"openid"}) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -27,7 +27,7 @@ func TestSessionManagerNewSession(t *testing.T) { func TestSessionAttachRemoteIdentityTwice(t *testing.T) { sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo()) - sessionID, err := sm.NewSession("bogus_idpc", "XXX", "bogus", url.URL{}, "", false, nil) + sessionID, err := sm.NewSession("bogus_idpc", "XXX", "bogus", url.URL{}, "", false, []string{"openid"}) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -44,7 +44,7 @@ func TestSessionAttachRemoteIdentityTwice(t *testing.T) { func TestSessionManagerExchangeKey(t *testing.T) { sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo()) - sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, nil) + sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"}) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -80,7 +80,7 @@ func TestSessionManagerGetSessionInStateNoExist(t *testing.T) { func TestSessionManagerGetSessionInStateWrongState(t *testing.T) { sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo()) - sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, nil) + sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"}) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -95,7 +95,7 @@ func TestSessionManagerGetSessionInStateWrongState(t *testing.T) { func TestSessionManagerKill(t *testing.T) { sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo()) - sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, nil) + sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"}) if err != nil { t.Fatalf("Unexpected error: %v", err) }