forked from mystiq/dex
server: check scope in requests.
Require 'openid' in scope for all requests. Require 'offline_access' for returning refresh token.
This commit is contained in:
parent
066fd859ec
commit
93a0830ae0
8 changed files with 147 additions and 59 deletions
|
@ -196,7 +196,7 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
|
|||
|
||||
// this will actually happen due to some interaction between the
|
||||
// end-user and a remote identity provider
|
||||
sessionID, err := sm.NewSession("bogus_idpc", ci.Credentials.ID, "bogus", url.URL{}, "", false, nil)
|
||||
sessionID, err := sm.NewSession("bogus_idpc", ci.Credentials.ID, "bogus", url.URL{}, "", false, []string{"openid", "offline_access"})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
|
|
@ -330,6 +330,34 @@ func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.T
|
|||
return
|
||||
}
|
||||
|
||||
// Check scopes.
|
||||
var scopes []string
|
||||
foundOpenIDScope := false
|
||||
for _, scope := range acr.Scope {
|
||||
switch scope {
|
||||
case "openid":
|
||||
foundOpenIDScope = true
|
||||
scopes = append(scopes, scope)
|
||||
case "offline_access":
|
||||
// According to the spec, for offline_access scope, the client must
|
||||
// use a response_type value that would result in an Authorization Code.
|
||||
// Currently oauth2.ResponseTypeCode is the only supported response type,
|
||||
// and it's been checked above, so we don't need to check it again here.
|
||||
//
|
||||
// TODO(yifan): Verify that 'consent' should be in 'prompt'.
|
||||
scopes = append(scopes, scope)
|
||||
default:
|
||||
// Pass all other scopes.
|
||||
scopes = append(scopes, scope)
|
||||
}
|
||||
}
|
||||
|
||||
if !foundOpenIDScope {
|
||||
log.Errorf("Invalid auth request: missing 'openid' in 'scope'")
|
||||
writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State)
|
||||
return
|
||||
}
|
||||
|
||||
nonce := q.Get("nonce")
|
||||
|
||||
key, err := srv.NewSession(connectorID, acr.ClientID, acr.State, redirectURL, nonce, register, acr.Scope)
|
||||
|
|
|
@ -102,6 +102,7 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
|
|||
"response_type": []string{"code"},
|
||||
"client_id": []string{"XXX"},
|
||||
"connector_id": []string{"fake"},
|
||||
"scope": []string{"openid"},
|
||||
},
|
||||
wantCode: http.StatusTemporaryRedirect,
|
||||
wantLocation: "http://fake.example.com",
|
||||
|
@ -114,6 +115,7 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
|
|||
"redirect_uri": []string{"http://client.example.com/callback"},
|
||||
"client_id": []string{"XXX"},
|
||||
"connector_id": []string{"fake"},
|
||||
"scope": []string{"openid"},
|
||||
},
|
||||
wantCode: http.StatusTemporaryRedirect,
|
||||
wantLocation: "http://fake.example.com",
|
||||
|
@ -126,6 +128,7 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
|
|||
"redirect_uri": []string{"http://unrecognized.example.com/callback"},
|
||||
"client_id": []string{"XXX"},
|
||||
"connector_id": []string{"fake"},
|
||||
"scope": []string{"openid"},
|
||||
},
|
||||
wantCode: http.StatusBadRequest,
|
||||
},
|
||||
|
@ -137,6 +140,7 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
|
|||
"redirect_uri": []string{"http://client.example.com/callback"},
|
||||
"client_id": []string{"YYY"},
|
||||
"connector_id": []string{"fake"},
|
||||
"scope": []string{"openid"},
|
||||
},
|
||||
wantCode: http.StatusBadRequest,
|
||||
},
|
||||
|
@ -147,10 +151,22 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
|
|||
"response_type": []string{"token"},
|
||||
"client_id": []string{"XXX"},
|
||||
"connector_id": []string{"fake"},
|
||||
"scope": []string{"openid"},
|
||||
},
|
||||
wantCode: http.StatusTemporaryRedirect,
|
||||
wantLocation: "http://client.example.com/callback?error=unsupported_response_type&state=",
|
||||
},
|
||||
|
||||
// no 'openid' in scope
|
||||
{
|
||||
query: url.Values{
|
||||
"response_type": []string{"code"},
|
||||
"redirect_uri": []string{"http://client.example.com/callback"},
|
||||
"client_id": []string{"XXX"},
|
||||
"connector_id": []string{"fake"},
|
||||
},
|
||||
wantCode: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
|
@ -211,6 +227,7 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
|
|||
"redirect_uri": []string{"http://foo.example.com/callback"},
|
||||
"client_id": []string{"XXX"},
|
||||
"connector_id": []string{"fake"},
|
||||
"scope": []string{"openid"},
|
||||
},
|
||||
wantCode: http.StatusTemporaryRedirect,
|
||||
wantLocation: "http://fake.example.com",
|
||||
|
@ -223,6 +240,7 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
|
|||
"redirect_uri": []string{"http://bar.example.com/callback"},
|
||||
"client_id": []string{"XXX"},
|
||||
"connector_id": []string{"fake"},
|
||||
"scope": []string{"openid"},
|
||||
},
|
||||
wantCode: http.StatusTemporaryRedirect,
|
||||
wantLocation: "http://fake.example.com",
|
||||
|
@ -235,6 +253,7 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
|
|||
"redirect_uri": []string{"http://unrecognized.example.com/callback"},
|
||||
"client_id": []string{"XXX"},
|
||||
"connector_id": []string{"fake"},
|
||||
"scope": []string{"openid"},
|
||||
},
|
||||
wantCode: http.StatusBadRequest,
|
||||
},
|
||||
|
@ -245,6 +264,7 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
|
|||
"response_type": []string{"code"},
|
||||
"client_id": []string{"XXX"},
|
||||
"connector_id": []string{"fake"},
|
||||
"scope": []string{"openid"},
|
||||
},
|
||||
wantCode: http.StatusBadRequest,
|
||||
},
|
||||
|
|
|
@ -245,7 +245,7 @@ func TestSendResetPasswordEmailHandler(t *testing.T) {
|
|||
t.Fatalf("case %d: could not make test fixtures: %v", i, err)
|
||||
}
|
||||
|
||||
_, err = f.srv.NewSession("local", "XXX", "", f.redirectURL, "", true, nil)
|
||||
_, err = f.srv.NewSession("local", "XXX", "", f.redirectURL, "", true, []string{"openid"})
|
||||
if err != nil {
|
||||
t.Fatalf("case %d: could not create new session: %v", i, err)
|
||||
}
|
||||
|
|
|
@ -197,7 +197,7 @@ func TestHandleRegister(t *testing.T) {
|
|||
t.Fatalf("case %d: could not make test fixtures: %v", i, err)
|
||||
}
|
||||
|
||||
key, err := f.srv.NewSession(tt.connID, "XXX", "", f.redirectURL, "", true, nil)
|
||||
key, err := f.srv.NewSession(tt.connID, "XXX", "", f.redirectURL, "", true, []string{"openid"})
|
||||
t.Logf("case %d: key for NewSession: %v", i, key)
|
||||
|
||||
if tt.attachRemote {
|
||||
|
|
|
@ -422,20 +422,26 @@ func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jo
|
|||
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
|
||||
}
|
||||
|
||||
log.Infof("Session %s token sent: clientID=%s", sessionID, creds.ID)
|
||||
// Generate refresh token when 'scope' contains 'offline_access'.
|
||||
var refreshToken string
|
||||
|
||||
// Generate refresh token.
|
||||
//
|
||||
// TODO(yifan): Return refresh token only when 'access_type == offline',
|
||||
// or 'scope' == 'offline_access'.
|
||||
refreshToken, err := s.RefreshTokenRepo.Create(ses.UserID, creds.ID)
|
||||
switch err {
|
||||
case nil:
|
||||
break
|
||||
default:
|
||||
log.Errorf("Failed to generate refresh token: %v", err)
|
||||
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
|
||||
for _, scope := range ses.Scope {
|
||||
if scope == "offline_access" {
|
||||
log.Infof("Session %s requests offline access, will generate refresh token", sessionID)
|
||||
|
||||
refreshToken, err = s.RefreshTokenRepo.Create(ses.UserID, creds.ID)
|
||||
switch err {
|
||||
case nil:
|
||||
break
|
||||
default:
|
||||
log.Errorf("Failed to generate refresh token: %v", err)
|
||||
return nil, "", oauth2.NewError(oauth2.ErrorServerError)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("Session %s token sent: clientID=%s", sessionID, creds.ID)
|
||||
return jwt, refreshToken, nil
|
||||
}
|
||||
|
||||
|
@ -487,7 +493,7 @@ func (s *Server) RefreshToken(creds oidc.ClientCredentials, token string) (*jose
|
|||
return nil, oauth2.NewError(oauth2.ErrorServerError)
|
||||
}
|
||||
|
||||
log.Infof("Token refreshed sent: clientID=%s", creds.ID)
|
||||
log.Infof("New token sent: clientID=%s", creds.ID)
|
||||
|
||||
return jwt, nil
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package server
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
@ -139,7 +140,7 @@ func TestServerNewSession(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
key, err := srv.NewSession("bogus_idpc", ci.Credentials.ID, state, ci.Metadata.RedirectURLs[0], nonce, false, nil)
|
||||
key, err := srv.NewSession("bogus_idpc", ci.Credentials.ID, state, ci.Metadata.RedirectURLs[0], nonce, false, []string{"openid"})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
@ -195,7 +196,7 @@ func TestServerLogin(t *testing.T) {
|
|||
|
||||
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
||||
sm.GenerateCode = staticGenerateCodeFunc("fakecode")
|
||||
sessionID, err := sm.NewSession("test_connector_id", ci.Credentials.ID, "bogus", ci.Metadata.RedirectURLs[0], "", false, nil)
|
||||
sessionID, err := sm.NewSession("test_connector_id", ci.Credentials.ID, "bogus", ci.Metadata.RedirectURLs[0], "", false, []string{"openid"})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
@ -292,34 +293,52 @@ func TestServerCodeToken(t *testing.T) {
|
|||
RefreshTokenRepo: refreshTokenRepo,
|
||||
}
|
||||
|
||||
sessionID, err := sm.NewSession("bogus_idpc", ci.Credentials.ID, "bogus", url.URL{}, "", false, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
_, err = sm.AttachRemoteIdentity(sessionID, oidc.Identity{})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
tests := []struct {
|
||||
scope []string
|
||||
refreshToken string
|
||||
}{
|
||||
// No 'offline_access' in scope, should get empty refresh token.
|
||||
{
|
||||
scope: []string{"openid"},
|
||||
refreshToken: "",
|
||||
},
|
||||
// Have 'offline_access' in scope, should get non-empty refresh token.
|
||||
{
|
||||
scope: []string{"openid", "offline_access"},
|
||||
refreshToken: "0/refresh-1",
|
||||
},
|
||||
}
|
||||
|
||||
_, err = sm.AttachUser(sessionID, "testid-1")
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
for i, tt := range tests {
|
||||
sessionID, err := sm.NewSession("bogus_idpc", ci.Credentials.ID, "bogus", url.URL{}, "", false, tt.scope)
|
||||
if err != nil {
|
||||
t.Fatalf("case %d: unexpected error: %v", i, err)
|
||||
}
|
||||
_, err = sm.AttachRemoteIdentity(sessionID, oidc.Identity{})
|
||||
if err != nil {
|
||||
t.Fatalf("case %d: unexpected error: %v", i, err)
|
||||
}
|
||||
|
||||
key, err := sm.NewSessionKey(sessionID)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
_, err = sm.AttachUser(sessionID, "testid-1")
|
||||
if err != nil {
|
||||
t.Fatalf("case %d: unexpected error: %v", i, err)
|
||||
}
|
||||
|
||||
jwt, token, err := srv.CodeToken(ci.Credentials, key)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if jwt == nil {
|
||||
t.Fatalf("Expected non-nil jwt")
|
||||
}
|
||||
if token == "" {
|
||||
t.Fatalf("Expected non-empty refresh token")
|
||||
key, err := sm.NewSessionKey(sessionID)
|
||||
if err != nil {
|
||||
t.Fatalf("case %d: unexpected error: %v", i, err)
|
||||
}
|
||||
|
||||
jwt, token, err := srv.CodeToken(ci.Credentials, key)
|
||||
if err != nil {
|
||||
t.Fatalf("case %d: unexpected error: %v", i, err)
|
||||
}
|
||||
if jwt == nil {
|
||||
t.Fatalf("case %d: expect non-nil jwt", i)
|
||||
}
|
||||
if token != tt.refreshToken {
|
||||
t.Fatalf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -343,7 +362,7 @@ func TestServerTokenUnrecognizedKey(t *testing.T) {
|
|||
ClientIdentityRepo: ciRepo,
|
||||
}
|
||||
|
||||
sessionID, err := sm.NewSession("connector_id", ci.Credentials.ID, "bogus", url.URL{}, "", false, nil)
|
||||
sessionID, err := sm.NewSession("connector_id", ci.Credentials.ID, "bogus", url.URL{}, "", false, []string{"openid", "offline_access"})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
@ -375,16 +394,28 @@ func TestServerTokenFail(t *testing.T) {
|
|||
signerFixture := &StaticSigner{sig: []byte("beer"), err: nil}
|
||||
|
||||
tests := []struct {
|
||||
signer jose.Signer
|
||||
argCC oidc.ClientCredentials
|
||||
argKey string
|
||||
err string
|
||||
signer jose.Signer
|
||||
argCC oidc.ClientCredentials
|
||||
argKey string
|
||||
err string
|
||||
scope []string
|
||||
refreshToken string
|
||||
}{
|
||||
// control test case to make sure fixtures check out
|
||||
{
|
||||
signer: signerFixture,
|
||||
argCC: ccFixture,
|
||||
argKey: keyFixture,
|
||||
scope: []string{"openid", "offline_access"},
|
||||
refreshToken: "0/refresh-1",
|
||||
},
|
||||
|
||||
// no 'offline_access' in 'scope', should get empty refresh token
|
||||
{
|
||||
signer: signerFixture,
|
||||
argCC: ccFixture,
|
||||
argKey: keyFixture,
|
||||
scope: []string{"openid"},
|
||||
},
|
||||
|
||||
// unrecognized key
|
||||
|
@ -393,6 +424,7 @@ func TestServerTokenFail(t *testing.T) {
|
|||
argCC: ccFixture,
|
||||
argKey: "foo",
|
||||
err: oauth2.ErrorInvalidGrant,
|
||||
scope: []string{"openid", "offline_access"},
|
||||
},
|
||||
|
||||
// unrecognized client
|
||||
|
@ -401,6 +433,7 @@ func TestServerTokenFail(t *testing.T) {
|
|||
argCC: oidc.ClientCredentials{ID: "YYY"},
|
||||
argKey: keyFixture,
|
||||
err: oauth2.ErrorInvalidClient,
|
||||
scope: []string{"openid", "offline_access"},
|
||||
},
|
||||
|
||||
// signing operation fails
|
||||
|
@ -409,6 +442,7 @@ func TestServerTokenFail(t *testing.T) {
|
|||
argCC: ccFixture,
|
||||
argKey: keyFixture,
|
||||
err: oauth2.ErrorServerError,
|
||||
scope: []string{"openid", "offline_access"},
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -416,7 +450,7 @@ func TestServerTokenFail(t *testing.T) {
|
|||
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
|
||||
sm.GenerateCode = func() (string, error) { return keyFixture, nil }
|
||||
|
||||
sessionID, err := sm.NewSession("connector_id", ccFixture.ID, "bogus", url.URL{}, "", false, nil)
|
||||
sessionID, err := sm.NewSession("connector_id", ccFixture.ID, "bogus", url.URL{}, "", false, tt.scope)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
@ -435,7 +469,7 @@ func TestServerTokenFail(t *testing.T) {
|
|||
|
||||
_, err = sm.AttachUser(sessionID, "testid-1")
|
||||
if err != nil {
|
||||
t.Fatalf("case %d: Unexpected error: %v", i, err)
|
||||
t.Fatalf("case %d: unexpected error: %v", i, err)
|
||||
}
|
||||
|
||||
userRepo, err := makeNewUserRepo()
|
||||
|
@ -463,22 +497,22 @@ func TestServerTokenFail(t *testing.T) {
|
|||
}
|
||||
|
||||
jwt, token, err := srv.CodeToken(tt.argCC, tt.argKey)
|
||||
if token != tt.refreshToken {
|
||||
fmt.Printf("case %d: expect refresh token %q, got %q\n", i, tt.refreshToken, token)
|
||||
t.Fatalf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token)
|
||||
panic("")
|
||||
}
|
||||
if tt.err == "" {
|
||||
if err != nil {
|
||||
t.Errorf("case %d: got non-nil error: %v", i, err)
|
||||
} else if jwt == nil {
|
||||
t.Errorf("case %d: got nil JWT", i)
|
||||
} else if token == "" {
|
||||
t.Errorf("case %d: got empty refresh token", i)
|
||||
}
|
||||
|
||||
} else {
|
||||
if err.Error() != tt.err {
|
||||
t.Errorf("case %d: want err %q, got %q", i, tt.err, err.Error())
|
||||
} else if jwt != nil {
|
||||
t.Errorf("case %d: got non-nil JWT", i)
|
||||
} else if token != "" {
|
||||
t.Errorf("case %d: got non-empty refresh token", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,7 +16,7 @@ func staticGenerateCodeFunc(code string) GenerateCodeFunc {
|
|||
func TestSessionManagerNewSession(t *testing.T) {
|
||||
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
|
||||
sm.GenerateCode = staticGenerateCodeFunc("boo")
|
||||
got, err := sm.NewSession("bogus_idpc", "XXX", "bogus", url.URL{}, "", false, nil)
|
||||
got, err := sm.NewSession("bogus_idpc", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
@ -27,7 +27,7 @@ func TestSessionManagerNewSession(t *testing.T) {
|
|||
|
||||
func TestSessionAttachRemoteIdentityTwice(t *testing.T) {
|
||||
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
|
||||
sessionID, err := sm.NewSession("bogus_idpc", "XXX", "bogus", url.URL{}, "", false, nil)
|
||||
sessionID, err := sm.NewSession("bogus_idpc", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
@ -44,7 +44,7 @@ func TestSessionAttachRemoteIdentityTwice(t *testing.T) {
|
|||
|
||||
func TestSessionManagerExchangeKey(t *testing.T) {
|
||||
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
|
||||
sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, nil)
|
||||
sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
@ -80,7 +80,7 @@ func TestSessionManagerGetSessionInStateNoExist(t *testing.T) {
|
|||
|
||||
func TestSessionManagerGetSessionInStateWrongState(t *testing.T) {
|
||||
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
|
||||
sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, nil)
|
||||
sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
@ -95,7 +95,7 @@ func TestSessionManagerGetSessionInStateWrongState(t *testing.T) {
|
|||
|
||||
func TestSessionManagerKill(t *testing.T) {
|
||||
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
|
||||
sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, nil)
|
||||
sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue