Merge pull request #38 from coreos/oidc-connector

Add OpenID Connect connector
This commit is contained in:
Eric Chiang 2016-08-08 11:52:11 -07:00 committed by GitHub
commit b31dedc2b6
14 changed files with 388 additions and 115 deletions

View File

@ -7,6 +7,7 @@ import (
"github.com/coreos/poke/connector/github" "github.com/coreos/poke/connector/github"
"github.com/coreos/poke/connector/ldap" "github.com/coreos/poke/connector/ldap"
"github.com/coreos/poke/connector/mock" "github.com/coreos/poke/connector/mock"
"github.com/coreos/poke/connector/oidc"
"github.com/coreos/poke/storage" "github.com/coreos/poke/storage"
"github.com/coreos/poke/storage/kubernetes" "github.com/coreos/poke/storage/kubernetes"
"github.com/coreos/poke/storage/memory" "github.com/coreos/poke/storage/memory"
@ -100,33 +101,34 @@ func (c *Connector) UnmarshalYAML(unmarshal func(interface{}) error) error {
c.Name = connectorMetadata.Name c.Name = connectorMetadata.Name
c.ID = connectorMetadata.ID c.ID = connectorMetadata.ID
var err error
switch c.Type { switch c.Type {
case "mock": case "mock":
var config struct { var config struct {
Config mock.Config `yaml:"config"` Config mock.Config `yaml:"config"`
} }
if err := unmarshal(&config); err != nil { err = unmarshal(&config)
return err
}
c.Config = &config.Config c.Config = &config.Config
case "ldap": case "ldap":
var config struct { var config struct {
Config ldap.Config `yaml:"config"` Config ldap.Config `yaml:"config"`
} }
if err := unmarshal(&config); err != nil { err = unmarshal(&config)
return err
}
c.Config = &config.Config c.Config = &config.Config
case "github": case "github":
var config struct { var config struct {
Config github.Config `yaml:"config"` Config github.Config `yaml:"config"`
} }
if err := unmarshal(&config); err != nil { err = unmarshal(&config)
return err c.Config = &config.Config
case "oidc":
var config struct {
Config oidc.Config `yaml:"config"`
} }
err = unmarshal(&config)
c.Config = &config.Config c.Config = &config.Config
default: default:
return fmt.Errorf("unknown connector type %q", c.Type) return fmt.Errorf("unknown connector type %q", c.Type)
} }
return nil return err
} }

View File

@ -1,2 +1,133 @@
// Package oidc implements logging in through OpenID Connect providers. // Package oidc implements logging in through OpenID Connect providers.
package oidc package oidc
import (
"errors"
"fmt"
"net/http"
"os"
"github.com/ericchiang/oidc"
"golang.org/x/net/context"
"golang.org/x/oauth2"
"github.com/coreos/poke/connector"
)
// Config holds configuration options for OpenID Connect logins.
type Config struct {
Issuer string `yaml:"issuer"`
ClientID string `yaml:"clientID"`
ClientSecret string `yaml:"clientSecret"`
RedirectURI string `yaml:"redirectURI"`
Scopes []string `yaml:"scopes"` // defaults to "profile" and "email"
}
// Open returns a connector which can be used to login users through an upstream
// OpenID Connect provider.
func (c *Config) Open() (conn connector.Connector, err error) {
ctx, cancel := context.WithCancel(context.Background())
provider, err := oidc.NewProvider(ctx, c.Issuer)
if err != nil {
cancel()
return nil, fmt.Errorf("failed to get provider: %v", err)
}
scopes := []string{oidc.ScopeOpenID}
if len(c.Scopes) > 0 {
scopes = append(scopes, c.Scopes...)
} else {
scopes = append(scopes, "profile", "email")
}
clientID := os.ExpandEnv(c.ClientID)
return &oidcConnector{
redirectURI: c.RedirectURI,
oauth2Config: &oauth2.Config{
ClientID: clientID,
ClientSecret: os.ExpandEnv(c.ClientSecret),
Endpoint: provider.Endpoint(),
Scopes: scopes,
RedirectURL: c.RedirectURI,
},
verifier: provider.NewVerifier(ctx,
oidc.VerifyExpiry(),
oidc.VerifyAudience(clientID),
),
}, nil
}
var (
_ connector.CallbackConnector = (*oidcConnector)(nil)
)
type oidcConnector struct {
redirectURI string
oauth2Config *oauth2.Config
verifier *oidc.IDTokenVerifier
ctx context.Context
cancel context.CancelFunc
}
func (c *oidcConnector) Close() error {
c.cancel()
return nil
}
func (c *oidcConnector) LoginURL(callbackURL, state string) (string, error) {
if c.redirectURI != callbackURL {
return "", fmt.Errorf("expected callback URL did not match the URL in the config")
}
return c.oauth2Config.AuthCodeURL(state), nil
}
type oauth2Error struct {
error string
errorDescription string
}
func (e *oauth2Error) Error() string {
if e.errorDescription == "" {
return e.error
}
return e.error + ": " + e.errorDescription
}
func (c *oidcConnector) HandleCallback(r *http.Request) (identity connector.Identity, state string, err error) {
q := r.URL.Query()
if errType := q.Get("error"); errType != "" {
return identity, "", &oauth2Error{errType, q.Get("error_description")}
}
token, err := c.oauth2Config.Exchange(c.ctx, q.Get("code"))
if err != nil {
return identity, "", fmt.Errorf("oidc: failed to get token: %v", err)
}
rawIDToken, ok := token.Extra("id_token").(string)
if !ok {
return identity, "", errors.New("oidc: no id_token in token response")
}
idToken, err := c.verifier.Verify(rawIDToken)
if err != nil {
return identity, "", fmt.Errorf("oidc: failed to verify ID Token: %v", err)
}
var claims struct {
Username string `json:"name"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
}
if err := idToken.Claims(&claims); err != nil {
return identity, "", fmt.Errorf("oidc: failed to decode claims: %v", err)
}
identity = connector.Identity{
UserID: idToken.Subject,
Username: claims.Username,
Email: claims.Email,
EmailVerified: claims.EmailVerified,
}
return identity, q.Get("state"), nil
}

View File

@ -18,6 +18,14 @@ connectors:
clientSecret: "$GITHUB_CLIENT_SECRET" clientSecret: "$GITHUB_CLIENT_SECRET"
redirectURI: http://127.0.0.1:5556/callback/github redirectURI: http://127.0.0.1:5556/callback/github
org: kubernetes org: kubernetes
- type: oidc
id: google
name: Google Account
config:
issuer: https://accounts.google.com
clientID: "$GOOGLE_OAUTH2_CLIENT_ID"
clientSecret: "$GOOGLE_OAUTH2_CLIENT_SECRET"
redirectURI: http://127.0.0.1:5556/callback/google
staticClients: staticClients:
- id: example-app - id: example-app

6
glide.lock generated
View File

@ -1,8 +1,8 @@
hash: 4442a097b81856345ae5f80101ad1a692a0b4e5d9b7627f5ad09cd20926122f4 hash: 2af4a276277d2ab2ba9de9b0fd67ab7d6b70c07f4171a9efb225f30306d6f3eb
updated: 2016-08-05T09:58:15.61704222-07:00 updated: 2016-08-08T11:20:44.300140564-07:00
imports: imports:
- name: github.com/ericchiang/oidc - name: github.com/ericchiang/oidc
version: 69fec81d167d815f4f455c741b2a94ffaf547ed2 version: 1907f0e61549f9081f26bdf269f11603496c9dee
- name: github.com/golang/protobuf - name: github.com/golang/protobuf
version: 874264fbbb43f4d91e999fecb4b40143ed611400 version: 874264fbbb43f4d91e999fecb4b40143ed611400
subpackages: subpackages:

View File

@ -43,7 +43,7 @@ import:
- bcrypt - bcrypt
- package: github.com/ericchiang/oidc - package: github.com/ericchiang/oidc
version: 69fec81d167d815f4f455c741b2a94ffaf547ed2 version: 1907f0e61549f9081f26bdf269f11603496c9dee
- package: github.com/pquerna/cachecontrol - package: github.com/pquerna/cachecontrol
version: c97913dcbd76de40b051a9b4cd827f7eaeb7a868 version: c97913dcbd76de40b051a9b4cd827f7eaeb7a868
- package: gopkg.in/square/go-jose.v1 - package: gopkg.in/square/go-jose.v1

13
vendor/github.com/ericchiang/oidc/.travis.yml generated vendored Normal file
View File

@ -0,0 +1,13 @@
language: go
go:
- 1.5.4
- 1.6.3
- tip
notifications:
email: false
matrix:
allow_failures:
- go: tip

View File

@ -58,6 +58,13 @@ Or the provider can be used to verify and inspect the OpenID Connect
verifier := provider.NewVerifier(ctx) verifier := provider.NewVerifier(ctx)
``` ```
The verifier itself can be constructed with addition checks, such as verifing a
token was issued for a specific client or hasn't expired.
```go
verifier := provier.NewVerifier(ctx, oidc.VerifyAudience(clientID), oidc.VerifyExpiry())
```
The returned verifier can be used to ensure the ID Token (a JWT) is signed by the provider. The returned verifier can be used to ensure the ID Token (a JWT) is signed by the provider.
```go ```go
@ -78,19 +85,19 @@ func handleOAuth2Callback(w http.ResponseWriter, r *http.Request) {
} }
// Verify that the ID Token is signed by the provider. // Verify that the ID Token is signed by the provider.
payload, err := verifier.Verify(rawIDToken) idToken, err := verifier.Verify(rawIDToken)
if err != nil { if err != nil {
http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError) http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError)
return return
} }
// Unmarshal ID Token for expected custom claims. // Unmarshal ID Token for expected custom claims.
var idToken struct { var claims struct {
Email string `json:"email"` Email string `json:"email"`
EmailVerified bool `json:"email_verified"` EmailVerified bool `json:"email_verified"`
} }
if err := json.Unmarshal(payload, &idToken); err != nil { if err := idToken.Claims(&claims); err != nil {
http.Error(w, "Failed to unmarshal ID Token: "+err.Error(), http.StatusInternalServerError) http.Error(w, "Failed to unmarshal ID Token claims: "+err.Error(), http.StatusInternalServerError)
return return
} }

View File

@ -65,19 +65,19 @@ including verifying the JWT signature. It then returns the payload.
} }
// Verify that the ID Token is signed by the provider. // Verify that the ID Token is signed by the provider.
payload, err := verifier.Verify(rawIDToken) idToken, err := verifier.Verify(rawIDToken)
if err != nil { if err != nil {
http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError) http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError)
return return
} }
// Unmarshal ID Token for expected custom claims. // Unmarshal ID Token for expected custom claims.
var idToken struct { var claims struct {
Email string `json:"email"` Email string `json:"email"`
EmailVerified bool `json:"email_verified"` EmailVerified bool `json:"email_verified"`
} }
if err := json.Unmarshal(payload, &idToken); err != nil { if err := idToken.Claims(&claims); err != nil {
http.Error(w, "Failed to unmarshal ID Token: "+err.Error(), http.StatusInternalServerError) http.Error(w, "Failed to unmarshal ID Token custom claims: "+err.Error(), http.StatusInternalServerError)
return return
} }
@ -123,7 +123,7 @@ The nonce enabled verifier can then be used to verify the nonce while unpacking
} }
// Verify that the ID Token is signed by the provider and verify the nonce. // Verify that the ID Token is signed by the provider and verify the nonce.
payload, err := nonceEnabledVerifier.Verify(rawIDToken) idToken, err := nonceEnabledVerifier.Verify(rawIDToken)
if err != nil { if err != nil {
http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError) http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError)
return return

View File

@ -59,8 +59,7 @@ func main() {
http.Error(w, "No id_token field in oauth2 token.", http.StatusInternalServerError) http.Error(w, "No id_token field in oauth2 token.", http.StatusInternalServerError)
return return
} }
log.Println(rawIDToken) idToken, err := verifier.Verify(rawIDToken)
idTokenPayload, err := verifier.Verify(rawIDToken)
if err != nil { if err != nil {
http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError) http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError)
return return
@ -68,11 +67,15 @@ func main() {
oauth2Token.AccessToken = "*REDACTED*" oauth2Token.AccessToken = "*REDACTED*"
rawMessage := json.RawMessage(idTokenPayload)
resp := struct { resp := struct {
OAuth2Token *oauth2.Token OAuth2Token *oauth2.Token
IDTokenClaims *json.RawMessage // ID Token payload is just JSON. IDTokenClaims *json.RawMessage // ID Token payload is just JSON.
}{oauth2Token, &rawMessage} }{oauth2Token, new(json.RawMessage)}
if err := idToken.Claims(&resp.IDTokenClaims); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
data, err := json.MarshalIndent(resp, "", " ") data, err := json.MarshalIndent(resp, "", " ")
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)

View File

@ -69,23 +69,28 @@ func main() {
http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError) http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError)
return return
} }
rawIDToken, ok := oauth2Token.Extra("id_token").(string) rawIDToken, ok := oauth2Token.Extra("id_token").(string)
if !ok { if !ok {
http.Error(w, "No id_token field in oauth2 token.", http.StatusInternalServerError) http.Error(w, "No id_token field in oauth2 token.", http.StatusInternalServerError)
return return
} }
// Verify the ID Token signature and nonce. // Verify the ID Token signature and nonce.
idTokenPayload, err := nonceEnabledVerifier.Verify(rawIDToken) idToken, err := nonceEnabledVerifier.Verify(rawIDToken)
if err != nil { if err != nil {
http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError) http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError)
return return
} }
rawMessage := json.RawMessage(idTokenPayload)
resp := struct { resp := struct {
OAuth2Token *oauth2.Token OAuth2Token *oauth2.Token
IDToken *json.RawMessage // ID Token payload is just JSON. IDTokenClaims *json.RawMessage // ID Token payload is just JSON.
}{oauth2Token, &rawMessage} }{oauth2Token, new(json.RawMessage)}
if err := idToken.Claims(&resp.IDTokenClaims); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
data, err := json.MarshalIndent(resp, "", " ") data, err := json.MarshalIndent(resp, "", " ")
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)

View File

@ -1,9 +1,7 @@
package oidc package oidc
import ( import (
"encoding/json"
"errors" "errors"
"fmt"
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
@ -29,13 +27,7 @@ type nonceVerifier struct {
nonceSource NonceSource nonceSource NonceSource
} }
func (n nonceVerifier) verifyIDTokenPayload(payload []byte) error { func (n nonceVerifier) verifyIDToken(token *IDToken) error {
var token struct {
Nonce string `json:"nonce"`
}
if err := json.Unmarshal(payload, &token); err != nil {
return fmt.Errorf("oidc: failed to unmarshal nonce: %v", err)
}
if token.Nonce == "" { if token.Nonce == "" {
return errors.New("oidc: no nonce present in ID Token") return errors.New("oidc: no nonce present in ID Token")
} }

View File

@ -15,9 +15,9 @@ import (
var ( var (
// ErrTokenExpired indicates that a token parsed by a verifier has expired. // ErrTokenExpired indicates that a token parsed by a verifier has expired.
ErrTokenExpired = errors.New("ID Token expired") ErrTokenExpired = errors.New("oidc: ID Token expired")
// ErrNotSupported indicates that the requested optional OpenID Connect endpoint is not supported by the provider. // ErrNotSupported indicates that the requested optional OpenID Connect endpoint is not supported by the provider.
ErrNotSupported = errors.New("endpoint not supported") ErrNotSupported = errors.New("oidc: endpoint not supported")
) )
const ( const (
@ -44,8 +44,8 @@ type Provider struct {
JWKSURL string `json:"jwks_uri"` JWKSURL string `json:"jwks_uri"`
UserInfoURL string `json:"userinfo_endpoint"` UserInfoURL string `json:"userinfo_endpoint"`
// Optionally contains extra claims. // Raw claims returned by the server.
raw map[string]interface{} rawClaims []byte
} }
// NewProvider uses the OpenID Connect disovery mechanism to construct a Provider. // NewProvider uses the OpenID Connect disovery mechanism to construct a Provider.
@ -67,20 +67,19 @@ func NewProvider(ctx context.Context, issuer string) (*Provider, error) {
if err := json.Unmarshal(body, &p); err != nil { if err := json.Unmarshal(body, &p); err != nil {
return nil, fmt.Errorf("oidc: failed to decode provider discovery object: %v", err) return nil, fmt.Errorf("oidc: failed to decode provider discovery object: %v", err)
} }
// raw claims do not get error checks p.rawClaims = body
json.Unmarshal(body, &p.raw)
if p.Issuer != issuer { if p.Issuer != issuer {
return nil, fmt.Errorf("oidc: issuer did not match the issuer returned by provider, expected %q got %q", issuer, p.Issuer) return nil, fmt.Errorf("oidc: issuer did not match the issuer returned by provider, expected %q got %q", issuer, p.Issuer)
} }
return &p, nil return &p, nil
} }
// Extra returns additional fields returned by the server during discovery. // Claims returns additional fields returned by the server during discovery.
func (p *Provider) Extra(key string) interface{} { func (p *Provider) Claims(v interface{}) error {
if p.raw != nil { if p.rawClaims == nil {
return p.raw[key] return errors.New("oidc: claims not set")
} }
return nil return json.Unmarshal(p.rawClaims, v)
} }
// Endpoint returns the OAuth2 auth and token endpoints for the given provider. // Endpoint returns the OAuth2 auth and token endpoints for the given provider.
@ -95,16 +94,15 @@ type UserInfo struct {
Email string `json:"email"` Email string `json:"email"`
EmailVerified bool `json:"email_verified"` EmailVerified bool `json:"email_verified"`
// Optionally contains extra claims. claims []byte
raw map[string]interface{}
} }
// Extra returns additional claims returned by the server. // Claims unmarshals the raw JSON object claims into the provided object.
func (u *UserInfo) Extra(key string) interface{} { func (u *UserInfo) Claims(v interface{}) error {
if u.raw != nil { if u.claims == nil {
return u.raw[key] return errors.New("oidc: claims not set")
} }
return nil return json.Unmarshal(u.claims, v)
} }
// UserInfo uses the token source to query the provider's user info endpoint. // UserInfo uses the token source to query the provider's user info endpoint.
@ -130,11 +128,101 @@ func (p *Provider) UserInfo(ctx context.Context, tokenSource oauth2.TokenSource)
if err := json.Unmarshal(body, &userInfo); err != nil { if err := json.Unmarshal(body, &userInfo); err != nil {
return nil, fmt.Errorf("oidc: failed to decode userinfo: %v", err) return nil, fmt.Errorf("oidc: failed to decode userinfo: %v", err)
} }
// raw claims do not get error checks userInfo.claims = body
json.Unmarshal(body, &userInfo.raw)
return &userInfo, nil return &userInfo, nil
} }
// IDToken is an OpenID Connect extension that provides a predictable representation
// of an authorization event.
//
// The ID Token only holds fields OpenID Connect requires. To access additional
// claims returned by the server, use the Claims method.
//
// idToken, err := idTokenVerifier.Verify(rawIDToken)
// if err != nil {
// // handle error
// }
// var claims struct {
// Email string `json:"email"`
// EmailVerified bool `json:"email_verified"`
// }
// if err := idToken.Claims(&claims); err != nil {
// // handle error
// }
//
type IDToken struct {
// The URL of the server which issued this token. This will always be the same
// as the URL used for initial discovery.
Issuer string
// The client, or set of clients, that this token is issued for.
Audience []string
// A unique string which identifies the end user.
Subject string
IssuedAt time.Time
Expiry time.Time
Nonce string
claims []byte
}
// Claims unmarshals the raw JSON payload of the ID Token into a provided struct.
func (i *IDToken) Claims(v interface{}) error {
if i.claims == nil {
return errors.New("oidc: claims not set")
}
return json.Unmarshal(i.claims, v)
}
type audience []string
func (a *audience) UnmarshalJSON(b []byte) error {
var s string
if json.Unmarshal(b, &s) == nil {
*a = audience{s}
return nil
}
var auds []string
if err := json.Unmarshal(b, &auds); err != nil {
return err
}
*a = audience(auds)
return nil
}
type jsonTime time.Time
func (j *jsonTime) UnmarshalJSON(b []byte) error {
var n json.Number
if err := json.Unmarshal(b, &n); err != nil {
return err
}
var unix int64
if t, err := n.Int64(); err == nil {
unix = t
} else {
f, err := n.Float64()
if err != nil {
return err
}
unix = int64(f)
}
*j = jsonTime(time.Unix(unix, 0))
return nil
}
type idToken struct {
Issuer string `json:"iss"`
Subject string `json:"sub"`
Audience audience `json:"aud"`
Expiry jsonTime `json:"exp"`
IssuedAt jsonTime `json:"iat"`
Nonce string `json:"nonce"`
}
// IDTokenVerifier provides verification for ID Tokens. // IDTokenVerifier provides verification for ID Tokens.
type IDTokenVerifier struct { type IDTokenVerifier struct {
issuer string issuer string
@ -143,31 +231,34 @@ type IDTokenVerifier struct {
} }
// Verify parse the raw ID Token, verifies it's been signed by the provider, preforms // Verify parse the raw ID Token, verifies it's been signed by the provider, preforms
// additional verification, such as checking the expiration, and returns the claims. // additional verification, and returns the claims.
func (v *IDTokenVerifier) Verify(rawIDToken string) (payload []byte, err error) { func (v *IDTokenVerifier) Verify(rawIDToken string) (*IDToken, error) {
payload, err = v.keySet.verifyJWT(rawIDToken) payload, err := v.keySet.verifyJWT(rawIDToken)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var token struct { var token idToken
Exp float64 `json:"exp"` // JSON numbers are always float64s.
Issuer string `json:"iss"`
}
if err := json.Unmarshal(payload, &token); err != nil { if err := json.Unmarshal(payload, &token); err != nil {
return nil, fmt.Errorf("oidc: failed to unmarshal claims: %v", err) return nil, fmt.Errorf("oidc: failed to unmarshal claims: %v", err)
} }
if v.issuer != token.Issuer { if v.issuer != token.Issuer {
return nil, fmt.Errorf("oidc: iss field did not match provider issuer") return nil, fmt.Errorf("oidc: iss field did not match provider issuer")
} }
if time.Unix(int64(token.Exp), 0).Before(time.Now().Round(time.Second)) { t := &IDToken{
return nil, ErrTokenExpired Issuer: token.Issuer,
Subject: token.Subject,
Audience: []string(token.Audience),
Expiry: time.Time(token.Expiry),
IssuedAt: time.Time(token.Expiry),
Nonce: token.Nonce,
claims: payload,
} }
for _, option := range v.options { for _, option := range v.options {
if err := option.verifyIDTokenPayload(payload); err != nil { if err := option.verifyIDToken(t); err != nil {
return nil, err return nil, err
} }
} }
return payload, nil return t, nil
} }
// NewVerifier returns an IDTokenVerifier that uses the provider's key set to verify JWTs. // NewVerifier returns an IDTokenVerifier that uses the provider's key set to verify JWTs.
@ -184,7 +275,7 @@ func (p *Provider) NewVerifier(ctx context.Context, options ...VerificationOptio
// VerificationOption is an option provided to Provider.NewVerifier. // VerificationOption is an option provided to Provider.NewVerifier.
type VerificationOption interface { type VerificationOption interface {
verifyIDTokenPayload(raw []byte) error verifyIDToken(token *IDToken) error
} }
// VerifyAudience ensures that an ID Token was issued for the specific client. // VerifyAudience ensures that an ID Token was issued for the specific client.
@ -199,25 +290,8 @@ type clientVerifier struct {
clientID string clientID string
} }
func (c clientVerifier) verifyIDTokenPayload(payload []byte) error { func (c clientVerifier) verifyIDToken(token *IDToken) error {
var token struct { for _, aud := range token.Audience {
Aud string `json:"aud"`
}
if err := json.Unmarshal(payload, &token); err == nil {
if token.Aud != c.clientID {
return errors.New("oidc: id token aud field did not match client_id")
}
return nil
}
// Aud can optionally be an array of strings
var token2 struct {
Aud []string `json:"aud"`
}
if err := json.Unmarshal(payload, &token2); err != nil {
return fmt.Errorf("oidc: failed to unmarshal aud claim: %v", err)
}
for _, aud := range token2.Aud {
if aud == c.clientID { if aud == c.clientID {
return nil return nil
} }
@ -225,6 +299,22 @@ func (c clientVerifier) verifyIDTokenPayload(payload []byte) error {
return errors.New("oidc: id token aud field did not match client_id") return errors.New("oidc: id token aud field did not match client_id")
} }
// VerifyExpiry ensures that an ID Token has not expired.
func VerifyExpiry() VerificationOption {
return expiryVerifier{time.Now}
}
type expiryVerifier struct {
now func() time.Time
}
func (e expiryVerifier) verifyIDToken(token *IDToken) error {
if e.now().After(token.Expiry) {
return ErrTokenExpired
}
return nil
}
// This method is internal to golang.org/x/oauth2. Just copy it. // This method is internal to golang.org/x/oauth2. Just copy it.
func contextClient(ctx context.Context) *http.Client { func contextClient(ctx context.Context) *http.Client {
if ctx != nil { if ctx != nil {

View File

@ -1,49 +1,40 @@
package oidc package oidc
import "testing" import (
"encoding/json"
"reflect"
"testing"
)
func TestClientVerifier(t *testing.T) { func TestClientVerifier(t *testing.T) {
tests := []struct { tests := []struct {
clientID string clientID string
payload string aud []string
wantErr bool wantErr bool
}{ }{
{ {
clientID: "1", clientID: "1",
payload: `{"aud":"1"}`, aud: []string{"1"},
}, },
{ {
clientID: "1", clientID: "1",
payload: `{"aud":"2"}`, aud: []string{"2"},
wantErr: true, wantErr: true,
}, },
{ {
clientID: "1", clientID: "1",
payload: `{"aud":["1"]}`, aud: []string{"2", "1"},
},
{
clientID: "1",
payload: `{"aud":["1", "2"]}`,
}, },
{ {
clientID: "3", clientID: "3",
payload: `{"aud":["1", "2"]}`, aud: []string{"1", "2"},
wantErr: true,
},
{
clientID: "3",
payload: `{"aud":}`, // invalid JSON
wantErr: true,
},
{
clientID: "1",
payload: `{}`,
wantErr: true, wantErr: true,
}, },
} }
for i, tc := range tests { for i, tc := range tests {
err := (clientVerifier{tc.clientID}).verifyIDTokenPayload([]byte(tc.payload)) token := IDToken{Audience: tc.aud}
err := (clientVerifier{tc.clientID}).verifyIDToken(&token)
if err != nil && !tc.wantErr { if err != nil && !tc.wantErr {
t.Errorf("case %d: %v", i) t.Errorf("case %d: %v", i)
} }
@ -52,3 +43,34 @@ func TestClientVerifier(t *testing.T) {
} }
} }
} }
func TestUnmarshalAudience(t *testing.T) {
tests := []struct {
data string
want audience
wantErr bool
}{
{`"foo"`, audience{"foo"}, false},
{`["foo","bar"]`, audience{"foo", "bar"}, false},
{"foo", nil, true}, // invalid JSON
}
for _, tc := range tests {
var a audience
if err := json.Unmarshal([]byte(tc.data), &a); err != nil {
if !tc.wantErr {
t.Errorf("failed to unmarshal %q: %v", tc.data, err)
}
continue
}
if tc.wantErr {
t.Errorf("did not expected to be able to unmarshal %q", tc.data)
continue
}
if !reflect.DeepEqual(tc.want, a) {
t.Errorf("from %q expected %q got %q", tc.data, tc.want, a)
}
}
}

View File

@ -3,7 +3,6 @@ package main
import ( import (
"crypto/rand" "crypto/rand"
"encoding/gob" "encoding/gob"
"encoding/json"
"flag" "flag"
"fmt" "fmt"
"io" "io"
@ -192,7 +191,7 @@ func handleCallback(w http.ResponseWriter, r *http.Request) {
return httpError(http.StatusInternalServerError, "Authentication failed") return httpError(http.StatusInternalServerError, "Authentication failed")
} }
payload, err := verifier.Verify(rawIDToken) idToken, err := verifier.Verify(rawIDToken)
if err != nil { if err != nil {
log.Printf("Failed to verify token: %v", err) log.Printf("Failed to verify token: %v", err)
return httpError(http.StatusInternalServerError, "Authentication failed") return httpError(http.StatusInternalServerError, "Authentication failed")
@ -201,7 +200,8 @@ func handleCallback(w http.ResponseWriter, r *http.Request) {
Email string `json:"email"` Email string `json:"email"`
EmailVerified bool `json:"email_verified"` EmailVerified bool `json:"email_verified"`
} }
if err := json.Unmarshal(payload, &claims); err != nil {
if err := idToken.Claims(&claims); err != nil {
log.Printf("Failed to decode claims: %v", err) log.Printf("Failed to decode claims: %v", err)
return httpError(http.StatusInternalServerError, "Authentication failed") return httpError(http.StatusInternalServerError, "Authentication failed")
} }