server: add tests for refreshing with explicit scopes

This commit is contained in:
Eric Chiang 2016-10-10 11:02:27 -07:00
parent 8518c30123
commit ac6e419d48
3 changed files with 183 additions and 97 deletions

View file

@ -538,20 +538,25 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
scopes := refresh.Scopes scopes := refresh.Scopes
if scope != "" { if scope != "" {
requestedScopes := strings.Split(scope, " ") requestedScopes := strings.Split(scope, " ")
contains := func() bool { var unauthorizedScopes []string
Loop:
for _, s := range requestedScopes { for _, s := range requestedScopes {
contains := func() bool {
for _, scope := range refresh.Scopes { for _, scope := range refresh.Scopes {
if s == scope { if s == scope {
continue Loop return true
} }
} }
return false return false
}
return true
}() }()
if !contains { if !contains {
tokenErr(w, errInvalidRequest, "Requested scopes did not contain authorized scopes.", http.StatusBadRequest) unauthorizedScopes = append(unauthorizedScopes, s)
}
}
if len(unauthorizedScopes) > 0 {
msg := fmt.Sprintf("Requested scopes contain unauthorized scope(s): %q.", unauthorizedScopes)
tokenErr(w, errInvalidRequest, msg, http.StatusBadRequest)
return return
} }
scopes = requestedScopes scopes = requestedScopes

View file

@ -52,6 +52,7 @@ func tokenErr(w http.ResponseWriter, typ, description string, statusCode int) {
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Length", strconv.Itoa(len(body))) w.Header().Set("Content-Length", strconv.Itoa(len(body)))
w.WriteHeader(statusCode)
w.Write(body) w.Write(body)
} }

View file

@ -131,6 +131,99 @@ func TestDiscovery(t *testing.T) {
} }
func TestOAuth2CodeFlow(t *testing.T) { func TestOAuth2CodeFlow(t *testing.T) {
clientID := "testclient"
clientSecret := "testclientsecret"
requestedScopes := []string{oidc.ScopeOpenID, "email", "offline_access"}
tests := []struct {
name string
handleToken func(context.Context, *oidc.Provider, *oauth2.Config, *oauth2.Token) error
}{
{
name: "verify ID Token",
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
idToken, ok := token.Extra("id_token").(string)
if !ok {
return fmt.Errorf("no id token found")
}
if _, err := p.NewVerifier(ctx).Verify(idToken); err != nil {
return fmt.Errorf("failed to verify id token: %v", err)
}
return nil
},
},
{
name: "refresh token",
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
// have to use time.Now because the OAuth2 package uses it.
token.Expiry = time.Now().Add(time.Second * -10)
if token.Valid() {
return errors.New("token shouldn't be valid")
}
newToken, err := config.TokenSource(ctx, token).Token()
if err != nil {
return fmt.Errorf("failed to refresh token: %v", err)
}
if token.RefreshToken == newToken.RefreshToken {
return fmt.Errorf("old refresh token was the same as the new token %q", token.RefreshToken)
}
return nil
},
},
{
name: "refresh with explicit scopes",
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
v := url.Values{}
v.Add("client_id", clientID)
v.Add("client_secret", clientSecret)
v.Add("grant_type", "refresh_token")
v.Add("refresh_token", token.RefreshToken)
v.Add("scope", strings.Join(requestedScopes, " "))
resp, err := http.PostForm(p.TokenURL, v)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
dump, err := httputil.DumpResponse(resp, true)
if err != nil {
panic(err)
}
return fmt.Errorf("unexpected response: %s", dump)
}
return nil
},
},
{
name: "refresh with unauthorized scopes",
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
v := url.Values{}
v.Add("client_id", clientID)
v.Add("client_secret", clientSecret)
v.Add("grant_type", "refresh_token")
v.Add("refresh_token", token.RefreshToken)
// Request a scope that wasn't requestd initially.
v.Add("scope", strings.Join(append(requestedScopes, "profile"), " "))
resp, err := http.PostForm(p.TokenURL, v)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
dump, err := httputil.DumpResponse(resp, true)
if err != nil {
panic(err)
}
return fmt.Errorf("unexpected response: %s", dump)
}
return nil
},
},
}
for _, tc := range tests {
func() {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@ -176,27 +269,12 @@ func TestOAuth2CodeFlow(t *testing.T) {
t.Errorf("failed to exchange code for token: %v", err) t.Errorf("failed to exchange code for token: %v", err)
return return
} }
idToken, ok := token.Extra("id_token").(string) err = tc.handleToken(ctx, p, oauth2Config, token)
if !ok {
t.Errorf("no id token found: %v", err)
return
}
// TODO(ericchiang): validate id token
_ = idToken
token.Expiry = time.Now().Add(time.Second * -10)
if token.Valid() {
t.Errorf("token shouldn't be valid")
}
newToken, err := oauth2Config.TokenSource(ctx, token).Token()
if err != nil { if err != nil {
t.Errorf("failed to refresh token: %v", err) t.Errorf("%s: %v", tc.name, err)
}
return return
}
if token.RefreshToken == newToken.RefreshToken {
t.Errorf("old refresh token was the same as the new token %q", token.RefreshToken)
}
} }
if gotState := q.Get("state"); gotState != state { if gotState := q.Get("state"); gotState != state {
t.Errorf("state did not match, want=%q got=%q", state, gotState) t.Errorf("state did not match, want=%q got=%q", state, gotState)
@ -211,8 +289,8 @@ func TestOAuth2CodeFlow(t *testing.T) {
redirectURL := oauth2Server.URL + "/callback" redirectURL := oauth2Server.URL + "/callback"
client := storage.Client{ client := storage.Client{
ID: "testclient", ID: clientID,
Secret: "testclientsecret", Secret: clientSecret,
RedirectURIs: []string{redirectURL}, RedirectURIs: []string{redirectURL},
} }
if err := s.storage.CreateClient(client); err != nil { if err := s.storage.CreateClient(client); err != nil {
@ -223,7 +301,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
ClientID: client.ID, ClientID: client.ID,
ClientSecret: client.Secret, ClientSecret: client.Secret,
Endpoint: p.Endpoint(), Endpoint: p.Endpoint(),
Scopes: []string{oidc.ScopeOpenID, "profile", "email", "offline_access"}, Scopes: requestedScopes,
RedirectURL: redirectURL, RedirectURL: redirectURL,
} }
@ -237,6 +315,8 @@ func TestOAuth2CodeFlow(t *testing.T) {
if respDump, err = httputil.DumpResponse(resp, true); err != nil { if respDump, err = httputil.DumpResponse(resp, true); err != nil {
t.Fatal(err) t.Fatal(err)
} }
}()
}
} }
type nonceSource struct { type nonceSource struct {