diff --git a/cmd/dex/config.go b/cmd/dex/config.go index 19a87fc3..74a1b091 100644 --- a/cmd/dex/config.go +++ b/cmd/dex/config.go @@ -99,11 +99,11 @@ type OAuth2 struct { // Web is the config format for the HTTP server. type Web struct { - HTTP string `json:"http"` - HTTPS string `json:"https"` - TLSCert string `json:"tlsCert"` - TLSKey string `json:"tlsKey"` - DiscoveryAllowedOrigins []string `json:"discoveryAllowedOrigins"` + HTTP string `json:"http"` + HTTPS string `json:"https"` + TLSCert string `json:"tlsCert"` + TLSKey string `json:"tlsKey"` + AllowedOrigins []string `json:"allowedOrigins"` } // GRPC is the config for the gRPC API. diff --git a/cmd/dex/serve.go b/cmd/dex/serve.go index 2dbc7fe0..f08e65ee 100644 --- a/cmd/dex/serve.go +++ b/cmd/dex/serve.go @@ -179,24 +179,24 @@ func serve(cmd *cobra.Command, args []string) error { if c.OAuth2.SkipApprovalScreen { logger.Infof("config skipping approval screen") } - if len(c.Web.DiscoveryAllowedOrigins) > 0 { - logger.Infof("config discovery allowed origins: %s", c.Web.DiscoveryAllowedOrigins) + if len(c.Web.AllowedOrigins) > 0 { + logger.Infof("config allowed origins: %s", c.Web.AllowedOrigins) } // explicitly convert to UTC. now := func() time.Time { return time.Now().UTC() } serverConfig := server.Config{ - SupportedResponseTypes: c.OAuth2.ResponseTypes, - SkipApprovalScreen: c.OAuth2.SkipApprovalScreen, - DiscoveryAllowedOrigins: c.Web.DiscoveryAllowedOrigins, - Issuer: c.Issuer, - Connectors: connectors, - Storage: s, - Web: c.Frontend, - EnablePasswordDB: c.EnablePasswordDB, - Logger: logger, - Now: now, + SupportedResponseTypes: c.OAuth2.ResponseTypes, + SkipApprovalScreen: c.OAuth2.SkipApprovalScreen, + AllowedOrigins: c.Web.AllowedOrigins, + Issuer: c.Issuer, + Connectors: connectors, + Storage: s, + Web: c.Frontend, + EnablePasswordDB: c.EnablePasswordDB, + Logger: logger, + Now: now, } if c.Expiry.SigningKeys != "" { signingKeys, err := time.ParseDuration(c.Expiry.SigningKeys) diff --git a/server/handlers.go b/server/handlers.go index 0d38121b..7b9f3a94 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -12,7 +12,6 @@ import ( "strings" "time" - "github.com/gorilla/handlers" "github.com/gorilla/mux" jose "gopkg.in/square/go-jose.v2" @@ -104,7 +103,7 @@ type discovery struct { Claims []string `json:"claims_supported"` } -func (s *Server) discoveryHandler() (http.Handler, error) { +func (s *Server) discoveryHandler() (http.HandlerFunc, error) { d := discovery{ Issuer: s.issuerURL.String(), Auth: s.absURL("/auth"), @@ -130,18 +129,11 @@ func (s *Server) discoveryHandler() (http.Handler, error) { return nil, fmt.Errorf("failed to marshal discovery data: %v", err) } - var discoveryHandler http.Handler - discoveryHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Length", strconv.Itoa(len(data))) w.Write(data) - }) - if len(s.discoveryAllowedOrigins) > 0 { - corsOption := handlers.AllowedOrigins(s.discoveryAllowedOrigins) - discoveryHandler = handlers.CORS(corsOption)(discoveryHandler) - } - - return discoveryHandler, nil + }), nil } // handleAuthorization handles the OAuth2 auth endpoint. diff --git a/server/handlers_test.go b/server/handlers_test.go index 6470a54c..233af279 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -22,61 +22,3 @@ func TestHandleHealth(t *testing.T) { } } - -var discoveryHandlerCORSTests = []struct { - DiscoveryAllowedOrigins []string - Origin string - ResponseAllowOrigin string //The expected response: same as Origin in case of valid CORS flow -}{ - {nil, "http://foo.example", ""}, //Default behavior: cross origin requests not allowed - {[]string{}, "http://foo.example", ""}, - {[]string{"http://foo.example"}, "http://foo.example", "http://foo.example"}, - {[]string{"http://bar.example", "http://foo.example"}, "http://foo.example", "http://foo.example"}, - {[]string{"*"}, "http://foo.example", "http://foo.example"}, - {[]string{"http://bar.example"}, "http://foo.example", ""}, -} - -func TestDiscoveryHandlerCORS(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - for _, testcase := range discoveryHandlerCORSTests { - - httpServer, server := newTestServer(ctx, t, func(c *Config) { - c.DiscoveryAllowedOrigins = testcase.DiscoveryAllowedOrigins - }) - defer httpServer.Close() - - discoveryHandler, err := server.discoveryHandler() - if err != nil { - t.Fatalf("failed to get discovery handler: %v", err) - } - - //Perform preflight request - rrPreflight := httptest.NewRecorder() - reqPreflight := httptest.NewRequest("OPTIONS", "/.well-kown/openid-configuration", nil) - reqPreflight.Header.Set("Origin", testcase.Origin) - reqPreflight.Header.Set("Access-Control-Request-Method", "GET") - discoveryHandler.ServeHTTP(rrPreflight, reqPreflight) - if rrPreflight.Code != http.StatusOK { - t.Errorf("expected 200 got %d", rrPreflight.Code) - } - headerAccessControlPreflight := rrPreflight.HeaderMap.Get("Access-Control-Allow-Origin") - if headerAccessControlPreflight != testcase.ResponseAllowOrigin { - t.Errorf("expected '%s' got '%s'", testcase.ResponseAllowOrigin, headerAccessControlPreflight) - } - - //Perform request - rr := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/.well-kown/openid-configuration", nil) - req.Header.Set("Origin", testcase.Origin) - discoveryHandler.ServeHTTP(rr, req) - if rr.Code != http.StatusOK { - t.Errorf("expected 200 got %d", rr.Code) - } - headerAccessControl := rr.HeaderMap.Get("Access-Control-Allow-Origin") - if headerAccessControl != testcase.ResponseAllowOrigin { - t.Errorf("expected '%s' got '%s'", testcase.ResponseAllowOrigin, headerAccessControl) - } - } -} diff --git a/server/server.go b/server/server.go index c0b194a6..012802f2 100644 --- a/server/server.go +++ b/server/server.go @@ -13,6 +13,7 @@ import ( "golang.org/x/net/context" "github.com/Sirupsen/logrus" + "github.com/gorilla/handlers" "github.com/gorilla/mux" "github.com/coreos/dex/connector" @@ -42,10 +43,10 @@ type Config struct { // flow. If no response types are supplied this value defaults to "code". SupportedResponseTypes []string - // List of allowed origins for CORS requests on discovery endpoint. + // List of allowed origins for CORS requests on discovery, token and keys endpoint. // If none are indicated, CORS requests are disabled. Passing in "*" will allow any // domain. - DiscoveryAllowedOrigins []string + AllowedOrigins []string // If enabled, the server won't prompt the user to approve authorization requests. // Logging in implies approval. @@ -116,8 +117,6 @@ type Server struct { supportedResponseTypes map[string]bool - discoveryAllowedOrigins []string - now func() time.Time idTokensValidFor time.Duration @@ -185,16 +184,15 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) } s := &Server{ - issuerURL: *issuerURL, - connectors: make(map[string]Connector), - storage: newKeyCacher(c.Storage, now), - supportedResponseTypes: supported, - discoveryAllowedOrigins: c.DiscoveryAllowedOrigins, - idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour), - skipApproval: c.SkipApprovalScreen, - now: now, - templates: tmpls, - logger: c.Logger, + issuerURL: *issuerURL, + connectors: make(map[string]Connector), + storage: newKeyCacher(c.Storage, now), + supportedResponseTypes: supported, + idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour), + skipApproval: c.SkipApprovalScreen, + now: now, + templates: tmpls, + logger: c.Logger, } for _, conn := range c.Connectors { @@ -205,24 +203,29 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) handleFunc := func(p string, h http.HandlerFunc) { r.HandleFunc(path.Join(issuerURL.Path, p), h) } - handle := func(p string, h http.Handler) { - r.Handle(path.Join(issuerURL.Path, p), h) - } handlePrefix := func(p string, h http.Handler) { prefix := path.Join(issuerURL.Path, p) r.PathPrefix(prefix).Handler(http.StripPrefix(prefix, h)) } + handleWithCORS := func(p string, h http.HandlerFunc) { + var handler http.Handler = h + if len(c.AllowedOrigins) > 0 { + corsOption := handlers.AllowedOrigins(c.AllowedOrigins) + handler = handlers.CORS(corsOption)(handler) + } + r.Handle(path.Join(issuerURL.Path, p), handler) + } r.NotFoundHandler = http.HandlerFunc(http.NotFound) discoveryHandler, err := s.discoveryHandler() if err != nil { return nil, err } - handle("/.well-known/openid-configuration", discoveryHandler) + handleWithCORS("/.well-known/openid-configuration", discoveryHandler) // TODO(ericchiang): rate limit certain paths based on IP. - handleFunc("/token", s.handleToken) - handleFunc("/keys", s.handlePublicKeys) + handleWithCORS("/token", s.handleToken) + handleWithCORS("/keys", s.handlePublicKeys) handleFunc("/auth", s.handleAuthorization) handleFunc("/auth/{connector}", s.handleConnectorLogin) handleFunc("/callback", s.handleConnectorCallback)