Allow CORS on keys and token endpoints

This commit is contained in:
Simon HEGE 2017-01-14 10:18:48 +01:00
parent ca7d2b8f9e
commit 415a68f977
5 changed files with 43 additions and 106 deletions

View file

@ -99,11 +99,11 @@ type OAuth2 struct {
// Web is the config format for the HTTP server. // Web is the config format for the HTTP server.
type Web struct { type Web struct {
HTTP string `json:"http"` HTTP string `json:"http"`
HTTPS string `json:"https"` HTTPS string `json:"https"`
TLSCert string `json:"tlsCert"` TLSCert string `json:"tlsCert"`
TLSKey string `json:"tlsKey"` TLSKey string `json:"tlsKey"`
DiscoveryAllowedOrigins []string `json:"discoveryAllowedOrigins"` AllowedOrigins []string `json:"allowedOrigins"`
} }
// GRPC is the config for the gRPC API. // GRPC is the config for the gRPC API.

View file

@ -179,24 +179,24 @@ func serve(cmd *cobra.Command, args []string) error {
if c.OAuth2.SkipApprovalScreen { if c.OAuth2.SkipApprovalScreen {
logger.Infof("config skipping approval screen") logger.Infof("config skipping approval screen")
} }
if len(c.Web.DiscoveryAllowedOrigins) > 0 { if len(c.Web.AllowedOrigins) > 0 {
logger.Infof("config discovery allowed origins: %s", c.Web.DiscoveryAllowedOrigins) logger.Infof("config allowed origins: %s", c.Web.AllowedOrigins)
} }
// explicitly convert to UTC. // explicitly convert to UTC.
now := func() time.Time { return time.Now().UTC() } now := func() time.Time { return time.Now().UTC() }
serverConfig := server.Config{ serverConfig := server.Config{
SupportedResponseTypes: c.OAuth2.ResponseTypes, SupportedResponseTypes: c.OAuth2.ResponseTypes,
SkipApprovalScreen: c.OAuth2.SkipApprovalScreen, SkipApprovalScreen: c.OAuth2.SkipApprovalScreen,
DiscoveryAllowedOrigins: c.Web.DiscoveryAllowedOrigins, AllowedOrigins: c.Web.AllowedOrigins,
Issuer: c.Issuer, Issuer: c.Issuer,
Connectors: connectors, Connectors: connectors,
Storage: s, Storage: s,
Web: c.Frontend, Web: c.Frontend,
EnablePasswordDB: c.EnablePasswordDB, EnablePasswordDB: c.EnablePasswordDB,
Logger: logger, Logger: logger,
Now: now, Now: now,
} }
if c.Expiry.SigningKeys != "" { if c.Expiry.SigningKeys != "" {
signingKeys, err := time.ParseDuration(c.Expiry.SigningKeys) signingKeys, err := time.ParseDuration(c.Expiry.SigningKeys)

View file

@ -12,7 +12,6 @@ import (
"strings" "strings"
"time" "time"
"github.com/gorilla/handlers"
"github.com/gorilla/mux" "github.com/gorilla/mux"
jose "gopkg.in/square/go-jose.v2" jose "gopkg.in/square/go-jose.v2"
@ -104,7 +103,7 @@ type discovery struct {
Claims []string `json:"claims_supported"` Claims []string `json:"claims_supported"`
} }
func (s *Server) discoveryHandler() (http.Handler, error) { func (s *Server) discoveryHandler() (http.HandlerFunc, error) {
d := discovery{ d := discovery{
Issuer: s.issuerURL.String(), Issuer: s.issuerURL.String(),
Auth: s.absURL("/auth"), Auth: s.absURL("/auth"),
@ -130,18 +129,11 @@ func (s *Server) discoveryHandler() (http.Handler, error) {
return nil, fmt.Errorf("failed to marshal discovery data: %v", err) return nil, fmt.Errorf("failed to marshal discovery data: %v", err)
} }
var discoveryHandler http.Handler return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
discoveryHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Length", strconv.Itoa(len(data))) w.Header().Set("Content-Length", strconv.Itoa(len(data)))
w.Write(data) w.Write(data)
}) }), nil
if len(s.discoveryAllowedOrigins) > 0 {
corsOption := handlers.AllowedOrigins(s.discoveryAllowedOrigins)
discoveryHandler = handlers.CORS(corsOption)(discoveryHandler)
}
return discoveryHandler, nil
} }
// handleAuthorization handles the OAuth2 auth endpoint. // handleAuthorization handles the OAuth2 auth endpoint.

View file

@ -22,61 +22,3 @@ func TestHandleHealth(t *testing.T) {
} }
} }
var discoveryHandlerCORSTests = []struct {
DiscoveryAllowedOrigins []string
Origin string
ResponseAllowOrigin string //The expected response: same as Origin in case of valid CORS flow
}{
{nil, "http://foo.example", ""}, //Default behavior: cross origin requests not allowed
{[]string{}, "http://foo.example", ""},
{[]string{"http://foo.example"}, "http://foo.example", "http://foo.example"},
{[]string{"http://bar.example", "http://foo.example"}, "http://foo.example", "http://foo.example"},
{[]string{"*"}, "http://foo.example", "http://foo.example"},
{[]string{"http://bar.example"}, "http://foo.example", ""},
}
func TestDiscoveryHandlerCORS(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
for _, testcase := range discoveryHandlerCORSTests {
httpServer, server := newTestServer(ctx, t, func(c *Config) {
c.DiscoveryAllowedOrigins = testcase.DiscoveryAllowedOrigins
})
defer httpServer.Close()
discoveryHandler, err := server.discoveryHandler()
if err != nil {
t.Fatalf("failed to get discovery handler: %v", err)
}
//Perform preflight request
rrPreflight := httptest.NewRecorder()
reqPreflight := httptest.NewRequest("OPTIONS", "/.well-kown/openid-configuration", nil)
reqPreflight.Header.Set("Origin", testcase.Origin)
reqPreflight.Header.Set("Access-Control-Request-Method", "GET")
discoveryHandler.ServeHTTP(rrPreflight, reqPreflight)
if rrPreflight.Code != http.StatusOK {
t.Errorf("expected 200 got %d", rrPreflight.Code)
}
headerAccessControlPreflight := rrPreflight.HeaderMap.Get("Access-Control-Allow-Origin")
if headerAccessControlPreflight != testcase.ResponseAllowOrigin {
t.Errorf("expected '%s' got '%s'", testcase.ResponseAllowOrigin, headerAccessControlPreflight)
}
//Perform request
rr := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/.well-kown/openid-configuration", nil)
req.Header.Set("Origin", testcase.Origin)
discoveryHandler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("expected 200 got %d", rr.Code)
}
headerAccessControl := rr.HeaderMap.Get("Access-Control-Allow-Origin")
if headerAccessControl != testcase.ResponseAllowOrigin {
t.Errorf("expected '%s' got '%s'", testcase.ResponseAllowOrigin, headerAccessControl)
}
}
}

View file

@ -13,6 +13,7 @@ import (
"golang.org/x/net/context" "golang.org/x/net/context"
"github.com/Sirupsen/logrus" "github.com/Sirupsen/logrus"
"github.com/gorilla/handlers"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
@ -42,10 +43,10 @@ type Config struct {
// flow. If no response types are supplied this value defaults to "code". // flow. If no response types are supplied this value defaults to "code".
SupportedResponseTypes []string SupportedResponseTypes []string
// List of allowed origins for CORS requests on discovery endpoint. // List of allowed origins for CORS requests on discovery, token and keys endpoint.
// If none are indicated, CORS requests are disabled. Passing in "*" will allow any // If none are indicated, CORS requests are disabled. Passing in "*" will allow any
// domain. // domain.
DiscoveryAllowedOrigins []string AllowedOrigins []string
// If enabled, the server won't prompt the user to approve authorization requests. // If enabled, the server won't prompt the user to approve authorization requests.
// Logging in implies approval. // Logging in implies approval.
@ -116,8 +117,6 @@ type Server struct {
supportedResponseTypes map[string]bool supportedResponseTypes map[string]bool
discoveryAllowedOrigins []string
now func() time.Time now func() time.Time
idTokensValidFor time.Duration idTokensValidFor time.Duration
@ -185,16 +184,15 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
} }
s := &Server{ s := &Server{
issuerURL: *issuerURL, issuerURL: *issuerURL,
connectors: make(map[string]Connector), connectors: make(map[string]Connector),
storage: newKeyCacher(c.Storage, now), storage: newKeyCacher(c.Storage, now),
supportedResponseTypes: supported, supportedResponseTypes: supported,
discoveryAllowedOrigins: c.DiscoveryAllowedOrigins, idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour), skipApproval: c.SkipApprovalScreen,
skipApproval: c.SkipApprovalScreen, now: now,
now: now, templates: tmpls,
templates: tmpls, logger: c.Logger,
logger: c.Logger,
} }
for _, conn := range c.Connectors { for _, conn := range c.Connectors {
@ -205,24 +203,29 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
handleFunc := func(p string, h http.HandlerFunc) { handleFunc := func(p string, h http.HandlerFunc) {
r.HandleFunc(path.Join(issuerURL.Path, p), h) r.HandleFunc(path.Join(issuerURL.Path, p), h)
} }
handle := func(p string, h http.Handler) {
r.Handle(path.Join(issuerURL.Path, p), h)
}
handlePrefix := func(p string, h http.Handler) { handlePrefix := func(p string, h http.Handler) {
prefix := path.Join(issuerURL.Path, p) prefix := path.Join(issuerURL.Path, p)
r.PathPrefix(prefix).Handler(http.StripPrefix(prefix, h)) r.PathPrefix(prefix).Handler(http.StripPrefix(prefix, h))
} }
handleWithCORS := func(p string, h http.HandlerFunc) {
var handler http.Handler = h
if len(c.AllowedOrigins) > 0 {
corsOption := handlers.AllowedOrigins(c.AllowedOrigins)
handler = handlers.CORS(corsOption)(handler)
}
r.Handle(path.Join(issuerURL.Path, p), handler)
}
r.NotFoundHandler = http.HandlerFunc(http.NotFound) r.NotFoundHandler = http.HandlerFunc(http.NotFound)
discoveryHandler, err := s.discoveryHandler() discoveryHandler, err := s.discoveryHandler()
if err != nil { if err != nil {
return nil, err return nil, err
} }
handle("/.well-known/openid-configuration", discoveryHandler) handleWithCORS("/.well-known/openid-configuration", discoveryHandler)
// TODO(ericchiang): rate limit certain paths based on IP. // TODO(ericchiang): rate limit certain paths based on IP.
handleFunc("/token", s.handleToken) handleWithCORS("/token", s.handleToken)
handleFunc("/keys", s.handlePublicKeys) handleWithCORS("/keys", s.handlePublicKeys)
handleFunc("/auth", s.handleAuthorization) handleFunc("/auth", s.handleAuthorization)
handleFunc("/auth/{connector}", s.handleConnectorLogin) handleFunc("/auth/{connector}", s.handleConnectorLogin)
handleFunc("/callback", s.handleConnectorCallback) handleFunc("/callback", s.handleConnectorCallback)