From b4c47910e40b11b1a6fe57764746eb70a0bb68ba Mon Sep 17 00:00:00 2001 From: Simon HEGE Date: Thu, 29 Dec 2016 10:25:16 +0100 Subject: [PATCH 1/2] Allow CORS on discovery endpoint --- cmd/dex/config.go | 9 ++++--- cmd/dex/serve.go | 22 +++++++++------- glide.lock | 2 ++ glide.yaml | 2 ++ server/handlers.go | 14 +++++++--- server/handlers_test.go | 58 +++++++++++++++++++++++++++++++++++++++++ server/server.go | 31 +++++++++++++++------- 7 files changed, 112 insertions(+), 26 deletions(-) diff --git a/cmd/dex/config.go b/cmd/dex/config.go index a2a653a7..40a4fe61 100644 --- a/cmd/dex/config.go +++ b/cmd/dex/config.go @@ -98,10 +98,11 @@ type OAuth2 struct { // Web is the config format for the HTTP server. type Web struct { - HTTP string `json:"http"` - HTTPS string `json:"https"` - TLSCert string `json:"tlsCert"` - TLSKey string `json:"tlsKey"` + HTTP string `json:"http"` + HTTPS string `json:"https"` + TLSCert string `json:"tlsCert"` + TLSKey string `json:"tlsKey"` + DiscoveryAllowedOrigins []string `json:"discoveryAllowedOrigins"` } // GRPC is the config for the gRPC API. diff --git a/cmd/dex/serve.go b/cmd/dex/serve.go index 640b122a..2dbc7fe0 100644 --- a/cmd/dex/serve.go +++ b/cmd/dex/serve.go @@ -179,20 +179,24 @@ func serve(cmd *cobra.Command, args []string) error { if c.OAuth2.SkipApprovalScreen { logger.Infof("config skipping approval screen") } + if len(c.Web.DiscoveryAllowedOrigins) > 0 { + logger.Infof("config discovery allowed origins: %s", c.Web.DiscoveryAllowedOrigins) + } // explicitly convert to UTC. now := func() time.Time { return time.Now().UTC() } serverConfig := server.Config{ - SupportedResponseTypes: c.OAuth2.ResponseTypes, - SkipApprovalScreen: c.OAuth2.SkipApprovalScreen, - Issuer: c.Issuer, - Connectors: connectors, - Storage: s, - Web: c.Frontend, - EnablePasswordDB: c.EnablePasswordDB, - Logger: logger, - Now: now, + SupportedResponseTypes: c.OAuth2.ResponseTypes, + SkipApprovalScreen: c.OAuth2.SkipApprovalScreen, + DiscoveryAllowedOrigins: c.Web.DiscoveryAllowedOrigins, + Issuer: c.Issuer, + Connectors: connectors, + Storage: s, + Web: c.Frontend, + EnablePasswordDB: c.EnablePasswordDB, + Logger: logger, + Now: now, } if c.Expiry.SigningKeys != "" { signingKeys, err := time.ParseDuration(c.Expiry.SigningKeys) diff --git a/glide.lock b/glide.lock index a4741bda..95a9ab21 100644 --- a/glide.lock +++ b/glide.lock @@ -18,6 +18,8 @@ imports: - protoc-gen-go - name: github.com/gorilla/context version: aed02d124ae4a0e94fea4541c8effd05bf0c8296 +- name: github.com/gorilla/handlers + version: 3a5767ca75ece5f7f1440b1d16975247f8d8b221 - name: github.com/gorilla/mux version: 9fa818a44c2bf1396a17f9d5a3c0f6dd39d2ff8e - name: github.com/gtank/cryptopasta diff --git a/glide.yaml b/glide.yaml index 650bd8bf..05d6ec42 100644 --- a/glide.yaml +++ b/glide.yaml @@ -53,6 +53,8 @@ import: version: 9fa818a44c2bf1396a17f9d5a3c0f6dd39d2ff8e - package: github.com/gorilla/context version: aed02d124ae4a0e94fea4541c8effd05bf0c8296 +- package: github.com/gorilla/handlers + version: 3a5767ca75ece5f7f1440b1d16975247f8d8b221 # Package with a bunch of sane crypto defaults. Consider just # copy the code (as recommended by the repo itself) instead diff --git a/server/handlers.go b/server/handlers.go index 832c262b..c962265f 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -11,6 +11,7 @@ import ( "strings" "time" + "github.com/gorilla/handlers" "github.com/gorilla/mux" jose "gopkg.in/square/go-jose.v2" @@ -101,7 +102,7 @@ type discovery struct { Claims []string `json:"claims_supported"` } -func (s *Server) discoveryHandler() (http.HandlerFunc, error) { +func (s *Server) discoveryHandler() (http.Handler, error) { d := discovery{ Issuer: s.issuerURL.String(), Auth: s.absURL("/auth"), @@ -127,11 +128,18 @@ func (s *Server) discoveryHandler() (http.HandlerFunc, error) { return nil, fmt.Errorf("failed to marshal discovery data: %v", err) } - return func(w http.ResponseWriter, r *http.Request) { + var discoveryHandler http.Handler + discoveryHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Length", strconv.Itoa(len(data))) w.Write(data) - }, nil + }) + if len(s.discoveryAllowedOrigins) > 0 { + corsOption := handlers.AllowedOrigins(s.discoveryAllowedOrigins) + discoveryHandler = handlers.CORS(corsOption)(discoveryHandler) + } + + return discoveryHandler, nil } // handleAuthorization handles the OAuth2 auth endpoint. diff --git a/server/handlers_test.go b/server/handlers_test.go index 233af279..6470a54c 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -22,3 +22,61 @@ func TestHandleHealth(t *testing.T) { } } + +var discoveryHandlerCORSTests = []struct { + DiscoveryAllowedOrigins []string + Origin string + ResponseAllowOrigin string //The expected response: same as Origin in case of valid CORS flow +}{ + {nil, "http://foo.example", ""}, //Default behavior: cross origin requests not allowed + {[]string{}, "http://foo.example", ""}, + {[]string{"http://foo.example"}, "http://foo.example", "http://foo.example"}, + {[]string{"http://bar.example", "http://foo.example"}, "http://foo.example", "http://foo.example"}, + {[]string{"*"}, "http://foo.example", "http://foo.example"}, + {[]string{"http://bar.example"}, "http://foo.example", ""}, +} + +func TestDiscoveryHandlerCORS(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + for _, testcase := range discoveryHandlerCORSTests { + + httpServer, server := newTestServer(ctx, t, func(c *Config) { + c.DiscoveryAllowedOrigins = testcase.DiscoveryAllowedOrigins + }) + defer httpServer.Close() + + discoveryHandler, err := server.discoveryHandler() + if err != nil { + t.Fatalf("failed to get discovery handler: %v", err) + } + + //Perform preflight request + rrPreflight := httptest.NewRecorder() + reqPreflight := httptest.NewRequest("OPTIONS", "/.well-kown/openid-configuration", nil) + reqPreflight.Header.Set("Origin", testcase.Origin) + reqPreflight.Header.Set("Access-Control-Request-Method", "GET") + discoveryHandler.ServeHTTP(rrPreflight, reqPreflight) + if rrPreflight.Code != http.StatusOK { + t.Errorf("expected 200 got %d", rrPreflight.Code) + } + headerAccessControlPreflight := rrPreflight.HeaderMap.Get("Access-Control-Allow-Origin") + if headerAccessControlPreflight != testcase.ResponseAllowOrigin { + t.Errorf("expected '%s' got '%s'", testcase.ResponseAllowOrigin, headerAccessControlPreflight) + } + + //Perform request + rr := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/.well-kown/openid-configuration", nil) + req.Header.Set("Origin", testcase.Origin) + discoveryHandler.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Errorf("expected 200 got %d", rr.Code) + } + headerAccessControl := rr.HeaderMap.Get("Access-Control-Allow-Origin") + if headerAccessControl != testcase.ResponseAllowOrigin { + t.Errorf("expected '%s' got '%s'", testcase.ResponseAllowOrigin, headerAccessControl) + } + } +} diff --git a/server/server.go b/server/server.go index 0c292d13..535e21be 100644 --- a/server/server.go +++ b/server/server.go @@ -42,6 +42,11 @@ type Config struct { // flow. If no response types are supplied this value defaults to "code". SupportedResponseTypes []string + // List of allowed origins for CORS requests on discovery endpoint. + // If none are indicated, CORS requests are disabled. Passing in "*" will allow any + // domain. + DiscoveryAllowedOrigins []string + // If enabled, the server won't prompt the user to approve authorization requests. // Logging in implies approval. SkipApprovalScreen bool @@ -111,6 +116,8 @@ type Server struct { supportedResponseTypes map[string]bool + discoveryAllowedOrigins []string + now func() time.Time idTokensValidFor time.Duration @@ -178,15 +185,16 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) } s := &Server{ - issuerURL: *issuerURL, - connectors: make(map[string]Connector), - storage: newKeyCacher(c.Storage, now), - supportedResponseTypes: supported, - idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour), - skipApproval: c.SkipApprovalScreen, - now: now, - templates: tmpls, - logger: c.Logger, + issuerURL: *issuerURL, + connectors: make(map[string]Connector), + storage: newKeyCacher(c.Storage, now), + supportedResponseTypes: supported, + discoveryAllowedOrigins: c.DiscoveryAllowedOrigins, + idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour), + skipApproval: c.SkipApprovalScreen, + now: now, + templates: tmpls, + logger: c.Logger, } for _, conn := range c.Connectors { @@ -197,6 +205,9 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) handleFunc := func(p string, h http.HandlerFunc) { r.HandleFunc(path.Join(issuerURL.Path, p), h) } + handle := func(p string, h http.Handler) { + r.Handle(path.Join(issuerURL.Path, p), h) + } handlePrefix := func(p string, h http.Handler) { prefix := path.Join(issuerURL.Path, p) r.PathPrefix(prefix).Handler(http.StripPrefix(prefix, h)) @@ -207,7 +218,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) if err != nil { return nil, err } - handleFunc("/.well-known/openid-configuration", discoveryHandler) + handle("/.well-known/openid-configuration", discoveryHandler) // TODO(ericchiang): rate limit certain paths based on IP. handleFunc("/token", s.handleToken) From 6cbf7125e0c370e07186cc3bcea6e5c54b5d3099 Mon Sep 17 00:00:00 2001 From: Simon HEGE Date: Thu, 29 Dec 2016 10:25:43 +0100 Subject: [PATCH 2/2] vendor: revendor --- glide.lock | 4 +- vendor/github.com/gorilla/handlers/LICENSE | 22 + .../github.com/gorilla/handlers/canonical.go | 74 ++++ .../github.com/gorilla/handlers/compress.go | 148 +++++++ vendor/github.com/gorilla/handlers/cors.go | 317 ++++++++++++++ vendor/github.com/gorilla/handlers/doc.go | 9 + .../github.com/gorilla/handlers/handlers.go | 403 ++++++++++++++++++ .../gorilla/handlers/proxy_headers.go | 120 ++++++ .../github.com/gorilla/handlers/recovery.go | 91 ++++ 9 files changed, 1186 insertions(+), 2 deletions(-) create mode 100644 vendor/github.com/gorilla/handlers/LICENSE create mode 100644 vendor/github.com/gorilla/handlers/canonical.go create mode 100644 vendor/github.com/gorilla/handlers/compress.go create mode 100644 vendor/github.com/gorilla/handlers/cors.go create mode 100644 vendor/github.com/gorilla/handlers/doc.go create mode 100644 vendor/github.com/gorilla/handlers/handlers.go create mode 100644 vendor/github.com/gorilla/handlers/proxy_headers.go create mode 100644 vendor/github.com/gorilla/handlers/recovery.go diff --git a/glide.lock b/glide.lock index 95a9ab21..81ab6445 100644 --- a/glide.lock +++ b/glide.lock @@ -1,5 +1,5 @@ -hash: 22c01c4265c210fbf3cd2d55e9451b924a3301e35053c82ebe35171ab7286c83 -updated: 2017-01-06T15:38:29.812891187-08:00 +hash: 4d7d84f09a330d27458fb821ae7ada243cfa825808dc7ab116db28a08f9166a2 +updated: 2017-01-08T19:23:40.352046548+01:00 imports: - name: github.com/cockroachdb/cockroach-go version: 31611c0501c812f437d4861d87d117053967c955 diff --git a/vendor/github.com/gorilla/handlers/LICENSE b/vendor/github.com/gorilla/handlers/LICENSE new file mode 100644 index 00000000..66ea3c8a --- /dev/null +++ b/vendor/github.com/gorilla/handlers/LICENSE @@ -0,0 +1,22 @@ +Copyright (c) 2013 The Gorilla Handlers Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/gorilla/handlers/canonical.go b/vendor/github.com/gorilla/handlers/canonical.go new file mode 100644 index 00000000..8437fefc --- /dev/null +++ b/vendor/github.com/gorilla/handlers/canonical.go @@ -0,0 +1,74 @@ +package handlers + +import ( + "net/http" + "net/url" + "strings" +) + +type canonical struct { + h http.Handler + domain string + code int +} + +// CanonicalHost is HTTP middleware that re-directs requests to the canonical +// domain. It accepts a domain and a status code (e.g. 301 or 302) and +// re-directs clients to this domain. The existing request path is maintained. +// +// Note: If the provided domain is considered invalid by url.Parse or otherwise +// returns an empty scheme or host, clients are not re-directed. +// +// Example: +// +// r := mux.NewRouter() +// canonical := handlers.CanonicalHost("http://www.gorillatoolkit.org", 302) +// r.HandleFunc("/route", YourHandler) +// +// log.Fatal(http.ListenAndServe(":7000", canonical(r))) +// +func CanonicalHost(domain string, code int) func(h http.Handler) http.Handler { + fn := func(h http.Handler) http.Handler { + return canonical{h, domain, code} + } + + return fn +} + +func (c canonical) ServeHTTP(w http.ResponseWriter, r *http.Request) { + dest, err := url.Parse(c.domain) + if err != nil { + // Call the next handler if the provided domain fails to parse. + c.h.ServeHTTP(w, r) + return + } + + if dest.Scheme == "" || dest.Host == "" { + // Call the next handler if the scheme or host are empty. + // Note that url.Parse won't fail on in this case. + c.h.ServeHTTP(w, r) + return + } + + if !strings.EqualFold(cleanHost(r.Host), dest.Host) { + // Re-build the destination URL + dest := dest.Scheme + "://" + dest.Host + r.URL.Path + if r.URL.RawQuery != "" { + dest += "?" + r.URL.RawQuery + } + http.Redirect(w, r, dest, c.code) + return + } + + c.h.ServeHTTP(w, r) +} + +// cleanHost cleans invalid Host headers by stripping anything after '/' or ' '. +// This is backported from Go 1.5 (in response to issue #11206) and attempts to +// mitigate malformed Host headers that do not match the format in RFC7230. +func cleanHost(in string) string { + if i := strings.IndexAny(in, " /"); i != -1 { + return in[:i] + } + return in +} diff --git a/vendor/github.com/gorilla/handlers/compress.go b/vendor/github.com/gorilla/handlers/compress.go new file mode 100644 index 00000000..e8345d79 --- /dev/null +++ b/vendor/github.com/gorilla/handlers/compress.go @@ -0,0 +1,148 @@ +// Copyright 2013 The Gorilla Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package handlers + +import ( + "compress/flate" + "compress/gzip" + "io" + "net/http" + "strings" +) + +type compressResponseWriter struct { + io.Writer + http.ResponseWriter + http.Hijacker + http.Flusher + http.CloseNotifier +} + +func (w *compressResponseWriter) WriteHeader(c int) { + w.ResponseWriter.Header().Del("Content-Length") + w.ResponseWriter.WriteHeader(c) +} + +func (w *compressResponseWriter) Header() http.Header { + return w.ResponseWriter.Header() +} + +func (w *compressResponseWriter) Write(b []byte) (int, error) { + h := w.ResponseWriter.Header() + if h.Get("Content-Type") == "" { + h.Set("Content-Type", http.DetectContentType(b)) + } + h.Del("Content-Length") + + return w.Writer.Write(b) +} + +type flusher interface { + Flush() error +} + +func (w *compressResponseWriter) Flush() { + // Flush compressed data if compressor supports it. + if f, ok := w.Writer.(flusher); ok { + f.Flush() + } + // Flush HTTP response. + if w.Flusher != nil { + w.Flusher.Flush() + } +} + +// CompressHandler gzip compresses HTTP responses for clients that support it +// via the 'Accept-Encoding' header. +// +// Compressing TLS traffic may leak the page contents to an attacker if the +// page contains user input: http://security.stackexchange.com/a/102015/12208 +func CompressHandler(h http.Handler) http.Handler { + return CompressHandlerLevel(h, gzip.DefaultCompression) +} + +// CompressHandlerLevel gzip compresses HTTP responses with specified compression level +// for clients that support it via the 'Accept-Encoding' header. +// +// The compression level should be gzip.DefaultCompression, gzip.NoCompression, +// or any integer value between gzip.BestSpeed and gzip.BestCompression inclusive. +// gzip.DefaultCompression is used in case of invalid compression level. +func CompressHandlerLevel(h http.Handler, level int) http.Handler { + if level < gzip.DefaultCompression || level > gzip.BestCompression { + level = gzip.DefaultCompression + } + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + L: + for _, enc := range strings.Split(r.Header.Get("Accept-Encoding"), ",") { + switch strings.TrimSpace(enc) { + case "gzip": + w.Header().Set("Content-Encoding", "gzip") + w.Header().Add("Vary", "Accept-Encoding") + + gw, _ := gzip.NewWriterLevel(w, level) + defer gw.Close() + + h, hok := w.(http.Hijacker) + if !hok { /* w is not Hijacker... oh well... */ + h = nil + } + + f, fok := w.(http.Flusher) + if !fok { + f = nil + } + + cn, cnok := w.(http.CloseNotifier) + if !cnok { + cn = nil + } + + w = &compressResponseWriter{ + Writer: gw, + ResponseWriter: w, + Hijacker: h, + Flusher: f, + CloseNotifier: cn, + } + + break L + case "deflate": + w.Header().Set("Content-Encoding", "deflate") + w.Header().Add("Vary", "Accept-Encoding") + + fw, _ := flate.NewWriter(w, level) + defer fw.Close() + + h, hok := w.(http.Hijacker) + if !hok { /* w is not Hijacker... oh well... */ + h = nil + } + + f, fok := w.(http.Flusher) + if !fok { + f = nil + } + + cn, cnok := w.(http.CloseNotifier) + if !cnok { + cn = nil + } + + w = &compressResponseWriter{ + Writer: fw, + ResponseWriter: w, + Hijacker: h, + Flusher: f, + CloseNotifier: cn, + } + + break L + } + } + + h.ServeHTTP(w, r) + }) +} diff --git a/vendor/github.com/gorilla/handlers/cors.go b/vendor/github.com/gorilla/handlers/cors.go new file mode 100644 index 00000000..1f92d1ad --- /dev/null +++ b/vendor/github.com/gorilla/handlers/cors.go @@ -0,0 +1,317 @@ +package handlers + +import ( + "net/http" + "strconv" + "strings" +) + +// CORSOption represents a functional option for configuring the CORS middleware. +type CORSOption func(*cors) error + +type cors struct { + h http.Handler + allowedHeaders []string + allowedMethods []string + allowedOrigins []string + allowedOriginValidator OriginValidator + exposedHeaders []string + maxAge int + ignoreOptions bool + allowCredentials bool +} + +// OriginValidator takes an origin string and returns whether or not that origin is allowed. +type OriginValidator func(string) bool + +var ( + defaultCorsMethods = []string{"GET", "HEAD", "POST"} + defaultCorsHeaders = []string{"Accept", "Accept-Language", "Content-Language", "Origin"} + // (WebKit/Safari v9 sends the Origin header by default in AJAX requests) +) + +const ( + corsOptionMethod string = "OPTIONS" + corsAllowOriginHeader string = "Access-Control-Allow-Origin" + corsExposeHeadersHeader string = "Access-Control-Expose-Headers" + corsMaxAgeHeader string = "Access-Control-Max-Age" + corsAllowMethodsHeader string = "Access-Control-Allow-Methods" + corsAllowHeadersHeader string = "Access-Control-Allow-Headers" + corsAllowCredentialsHeader string = "Access-Control-Allow-Credentials" + corsRequestMethodHeader string = "Access-Control-Request-Method" + corsRequestHeadersHeader string = "Access-Control-Request-Headers" + corsOriginHeader string = "Origin" + corsVaryHeader string = "Vary" + corsOriginMatchAll string = "*" +) + +func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get(corsOriginHeader) + if !ch.isOriginAllowed(origin) { + ch.h.ServeHTTP(w, r) + return + } + + if r.Method == corsOptionMethod { + if ch.ignoreOptions { + ch.h.ServeHTTP(w, r) + return + } + + if _, ok := r.Header[corsRequestMethodHeader]; !ok { + w.WriteHeader(http.StatusBadRequest) + return + } + + method := r.Header.Get(corsRequestMethodHeader) + if !ch.isMatch(method, ch.allowedMethods) { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + requestHeaders := strings.Split(r.Header.Get(corsRequestHeadersHeader), ",") + allowedHeaders := []string{} + for _, v := range requestHeaders { + canonicalHeader := http.CanonicalHeaderKey(strings.TrimSpace(v)) + if canonicalHeader == "" || ch.isMatch(canonicalHeader, defaultCorsHeaders) { + continue + } + + if !ch.isMatch(canonicalHeader, ch.allowedHeaders) { + w.WriteHeader(http.StatusForbidden) + return + } + + allowedHeaders = append(allowedHeaders, canonicalHeader) + } + + if len(allowedHeaders) > 0 { + w.Header().Set(corsAllowHeadersHeader, strings.Join(allowedHeaders, ",")) + } + + if ch.maxAge > 0 { + w.Header().Set(corsMaxAgeHeader, strconv.Itoa(ch.maxAge)) + } + + if !ch.isMatch(method, defaultCorsMethods) { + w.Header().Set(corsAllowMethodsHeader, method) + } + } else { + if len(ch.exposedHeaders) > 0 { + w.Header().Set(corsExposeHeadersHeader, strings.Join(ch.exposedHeaders, ",")) + } + } + + if ch.allowCredentials { + w.Header().Set(corsAllowCredentialsHeader, "true") + } + + if len(ch.allowedOrigins) > 1 { + w.Header().Set(corsVaryHeader, corsOriginHeader) + } + + w.Header().Set(corsAllowOriginHeader, origin) + + if r.Method == corsOptionMethod { + return + } + ch.h.ServeHTTP(w, r) +} + +// CORS provides Cross-Origin Resource Sharing middleware. +// Example: +// +// import ( +// "net/http" +// +// "github.com/gorilla/handlers" +// "github.com/gorilla/mux" +// ) +// +// func main() { +// r := mux.NewRouter() +// r.HandleFunc("/users", UserEndpoint) +// r.HandleFunc("/projects", ProjectEndpoint) +// +// // Apply the CORS middleware to our top-level router, with the defaults. +// http.ListenAndServe(":8000", handlers.CORS()(r)) +// } +// +func CORS(opts ...CORSOption) func(http.Handler) http.Handler { + return func(h http.Handler) http.Handler { + ch := parseCORSOptions(opts...) + ch.h = h + return ch + } +} + +func parseCORSOptions(opts ...CORSOption) *cors { + ch := &cors{ + allowedMethods: defaultCorsMethods, + allowedHeaders: defaultCorsHeaders, + allowedOrigins: []string{corsOriginMatchAll}, + } + + for _, option := range opts { + option(ch) + } + + return ch +} + +// +// Functional options for configuring CORS. +// + +// AllowedHeaders adds the provided headers to the list of allowed headers in a +// CORS request. +// This is an append operation so the headers Accept, Accept-Language, +// and Content-Language are always allowed. +// Content-Type must be explicitly declared if accepting Content-Types other than +// application/x-www-form-urlencoded, multipart/form-data, or text/plain. +func AllowedHeaders(headers []string) CORSOption { + return func(ch *cors) error { + for _, v := range headers { + normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v)) + if normalizedHeader == "" { + continue + } + + if !ch.isMatch(normalizedHeader, ch.allowedHeaders) { + ch.allowedHeaders = append(ch.allowedHeaders, normalizedHeader) + } + } + + return nil + } +} + +// AllowedMethods can be used to explicitly allow methods in the +// Access-Control-Allow-Methods header. +// This is a replacement operation so you must also +// pass GET, HEAD, and POST if you wish to support those methods. +func AllowedMethods(methods []string) CORSOption { + return func(ch *cors) error { + ch.allowedMethods = []string{} + for _, v := range methods { + normalizedMethod := strings.ToUpper(strings.TrimSpace(v)) + if normalizedMethod == "" { + continue + } + + if !ch.isMatch(normalizedMethod, ch.allowedMethods) { + ch.allowedMethods = append(ch.allowedMethods, normalizedMethod) + } + } + + return nil + } +} + +// AllowedOrigins sets the allowed origins for CORS requests, as used in the +// 'Allow-Access-Control-Origin' HTTP header. +// Note: Passing in a []string{"*"} will allow any domain. +func AllowedOrigins(origins []string) CORSOption { + return func(ch *cors) error { + for _, v := range origins { + if v == corsOriginMatchAll { + ch.allowedOrigins = []string{corsOriginMatchAll} + return nil + } + } + + ch.allowedOrigins = origins + return nil + } +} + +// AllowedOriginValidator sets a function for evaluating allowed origins in CORS requests, represented by the +// 'Allow-Access-Control-Origin' HTTP header. +func AllowedOriginValidator(fn OriginValidator) CORSOption { + return func(ch *cors) error { + ch.allowedOriginValidator = fn + return nil + } +} + +// ExposeHeaders can be used to specify headers that are available +// and will not be stripped out by the user-agent. +func ExposedHeaders(headers []string) CORSOption { + return func(ch *cors) error { + ch.exposedHeaders = []string{} + for _, v := range headers { + normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v)) + if normalizedHeader == "" { + continue + } + + if !ch.isMatch(normalizedHeader, ch.exposedHeaders) { + ch.exposedHeaders = append(ch.exposedHeaders, normalizedHeader) + } + } + + return nil + } +} + +// MaxAge determines the maximum age (in seconds) between preflight requests. A +// maximum of 10 minutes is allowed. An age above this value will default to 10 +// minutes. +func MaxAge(age int) CORSOption { + return func(ch *cors) error { + // Maximum of 10 minutes. + if age > 600 { + age = 600 + } + + ch.maxAge = age + return nil + } +} + +// IgnoreOptions causes the CORS middleware to ignore OPTIONS requests, instead +// passing them through to the next handler. This is useful when your application +// or framework has a pre-existing mechanism for responding to OPTIONS requests. +func IgnoreOptions() CORSOption { + return func(ch *cors) error { + ch.ignoreOptions = true + return nil + } +} + +// AllowCredentials can be used to specify that the user agent may pass +// authentication details along with the request. +func AllowCredentials() CORSOption { + return func(ch *cors) error { + ch.allowCredentials = true + return nil + } +} + +func (ch *cors) isOriginAllowed(origin string) bool { + if origin == "" { + return false + } + + if ch.allowedOriginValidator != nil { + return ch.allowedOriginValidator(origin) + } + + for _, allowedOrigin := range ch.allowedOrigins { + if allowedOrigin == origin || allowedOrigin == corsOriginMatchAll { + return true + } + } + + return false +} + +func (ch *cors) isMatch(needle string, haystack []string) bool { + for _, v := range haystack { + if v == needle { + return true + } + } + + return false +} diff --git a/vendor/github.com/gorilla/handlers/doc.go b/vendor/github.com/gorilla/handlers/doc.go new file mode 100644 index 00000000..944e5a8a --- /dev/null +++ b/vendor/github.com/gorilla/handlers/doc.go @@ -0,0 +1,9 @@ +/* +Package handlers is a collection of handlers (aka "HTTP middleware") for use +with Go's net/http package (or any framework supporting http.Handler). + +The package includes handlers for logging in standardised formats, compressing +HTTP responses, validating content types and other useful tools for manipulating +requests and responses. +*/ +package handlers diff --git a/vendor/github.com/gorilla/handlers/handlers.go b/vendor/github.com/gorilla/handlers/handlers.go new file mode 100644 index 00000000..9544d2f0 --- /dev/null +++ b/vendor/github.com/gorilla/handlers/handlers.go @@ -0,0 +1,403 @@ +// Copyright 2013 The Gorilla Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package handlers + +import ( + "bufio" + "fmt" + "io" + "net" + "net/http" + "net/url" + "sort" + "strconv" + "strings" + "time" + "unicode/utf8" +) + +// MethodHandler is an http.Handler that dispatches to a handler whose key in the +// MethodHandler's map matches the name of the HTTP request's method, eg: GET +// +// If the request's method is OPTIONS and OPTIONS is not a key in the map then +// the handler responds with a status of 200 and sets the Allow header to a +// comma-separated list of available methods. +// +// If the request's method doesn't match any of its keys the handler responds +// with a status of HTTP 405 "Method Not Allowed" and sets the Allow header to a +// comma-separated list of available methods. +type MethodHandler map[string]http.Handler + +func (h MethodHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if handler, ok := h[req.Method]; ok { + handler.ServeHTTP(w, req) + } else { + allow := []string{} + for k := range h { + allow = append(allow, k) + } + sort.Strings(allow) + w.Header().Set("Allow", strings.Join(allow, ", ")) + if req.Method == "OPTIONS" { + w.WriteHeader(http.StatusOK) + } else { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + } +} + +// loggingHandler is the http.Handler implementation for LoggingHandlerTo and its +// friends +type loggingHandler struct { + writer io.Writer + handler http.Handler +} + +// combinedLoggingHandler is the http.Handler implementation for LoggingHandlerTo +// and its friends +type combinedLoggingHandler struct { + writer io.Writer + handler http.Handler +} + +func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + t := time.Now() + logger := makeLogger(w) + url := *req.URL + h.handler.ServeHTTP(logger, req) + writeLog(h.writer, req, url, t, logger.Status(), logger.Size()) +} + +func (h combinedLoggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + t := time.Now() + logger := makeLogger(w) + url := *req.URL + h.handler.ServeHTTP(logger, req) + writeCombinedLog(h.writer, req, url, t, logger.Status(), logger.Size()) +} + +func makeLogger(w http.ResponseWriter) loggingResponseWriter { + var logger loggingResponseWriter = &responseLogger{w: w} + if _, ok := w.(http.Hijacker); ok { + logger = &hijackLogger{responseLogger{w: w}} + } + h, ok1 := logger.(http.Hijacker) + c, ok2 := w.(http.CloseNotifier) + if ok1 && ok2 { + return hijackCloseNotifier{logger, h, c} + } + if ok2 { + return &closeNotifyWriter{logger, c} + } + return logger +} + +type loggingResponseWriter interface { + http.ResponseWriter + http.Flusher + Status() int + Size() int +} + +// responseLogger is wrapper of http.ResponseWriter that keeps track of its HTTP +// status code and body size +type responseLogger struct { + w http.ResponseWriter + status int + size int +} + +func (l *responseLogger) Header() http.Header { + return l.w.Header() +} + +func (l *responseLogger) Write(b []byte) (int, error) { + if l.status == 0 { + // The status will be StatusOK if WriteHeader has not been called yet + l.status = http.StatusOK + } + size, err := l.w.Write(b) + l.size += size + return size, err +} + +func (l *responseLogger) WriteHeader(s int) { + l.w.WriteHeader(s) + l.status = s +} + +func (l *responseLogger) Status() int { + return l.status +} + +func (l *responseLogger) Size() int { + return l.size +} + +func (l *responseLogger) Flush() { + f, ok := l.w.(http.Flusher) + if ok { + f.Flush() + } +} + +type hijackLogger struct { + responseLogger +} + +func (l *hijackLogger) Hijack() (net.Conn, *bufio.ReadWriter, error) { + h := l.responseLogger.w.(http.Hijacker) + conn, rw, err := h.Hijack() + if err == nil && l.responseLogger.status == 0 { + // The status will be StatusSwitchingProtocols if there was no error and + // WriteHeader has not been called yet + l.responseLogger.status = http.StatusSwitchingProtocols + } + return conn, rw, err +} + +type closeNotifyWriter struct { + loggingResponseWriter + http.CloseNotifier +} + +type hijackCloseNotifier struct { + loggingResponseWriter + http.Hijacker + http.CloseNotifier +} + +const lowerhex = "0123456789abcdef" + +func appendQuoted(buf []byte, s string) []byte { + var runeTmp [utf8.UTFMax]byte + for width := 0; len(s) > 0; s = s[width:] { + r := rune(s[0]) + width = 1 + if r >= utf8.RuneSelf { + r, width = utf8.DecodeRuneInString(s) + } + if width == 1 && r == utf8.RuneError { + buf = append(buf, `\x`...) + buf = append(buf, lowerhex[s[0]>>4]) + buf = append(buf, lowerhex[s[0]&0xF]) + continue + } + if r == rune('"') || r == '\\' { // always backslashed + buf = append(buf, '\\') + buf = append(buf, byte(r)) + continue + } + if strconv.IsPrint(r) { + n := utf8.EncodeRune(runeTmp[:], r) + buf = append(buf, runeTmp[:n]...) + continue + } + switch r { + case '\a': + buf = append(buf, `\a`...) + case '\b': + buf = append(buf, `\b`...) + case '\f': + buf = append(buf, `\f`...) + case '\n': + buf = append(buf, `\n`...) + case '\r': + buf = append(buf, `\r`...) + case '\t': + buf = append(buf, `\t`...) + case '\v': + buf = append(buf, `\v`...) + default: + switch { + case r < ' ': + buf = append(buf, `\x`...) + buf = append(buf, lowerhex[s[0]>>4]) + buf = append(buf, lowerhex[s[0]&0xF]) + case r > utf8.MaxRune: + r = 0xFFFD + fallthrough + case r < 0x10000: + buf = append(buf, `\u`...) + for s := 12; s >= 0; s -= 4 { + buf = append(buf, lowerhex[r>>uint(s)&0xF]) + } + default: + buf = append(buf, `\U`...) + for s := 28; s >= 0; s -= 4 { + buf = append(buf, lowerhex[r>>uint(s)&0xF]) + } + } + } + } + return buf + +} + +// buildCommonLogLine builds a log entry for req in Apache Common Log Format. +// ts is the timestamp with which the entry should be logged. +// status and size are used to provide the response HTTP status and size. +func buildCommonLogLine(req *http.Request, url url.URL, ts time.Time, status int, size int) []byte { + username := "-" + if url.User != nil { + if name := url.User.Username(); name != "" { + username = name + } + } + + host, _, err := net.SplitHostPort(req.RemoteAddr) + + if err != nil { + host = req.RemoteAddr + } + + uri := req.RequestURI + + // Requests using the CONNECT method over HTTP/2.0 must use + // the authority field (aka r.Host) to identify the target. + // Refer: https://httpwg.github.io/specs/rfc7540.html#CONNECT + if req.ProtoMajor == 2 && req.Method == "CONNECT" { + uri = req.Host + } + if uri == "" { + uri = url.RequestURI() + } + + buf := make([]byte, 0, 3*(len(host)+len(username)+len(req.Method)+len(uri)+len(req.Proto)+50)/2) + buf = append(buf, host...) + buf = append(buf, " - "...) + buf = append(buf, username...) + buf = append(buf, " ["...) + buf = append(buf, ts.Format("02/Jan/2006:15:04:05 -0700")...) + buf = append(buf, `] "`...) + buf = append(buf, req.Method...) + buf = append(buf, " "...) + buf = appendQuoted(buf, uri) + buf = append(buf, " "...) + buf = append(buf, req.Proto...) + buf = append(buf, `" `...) + buf = append(buf, strconv.Itoa(status)...) + buf = append(buf, " "...) + buf = append(buf, strconv.Itoa(size)...) + return buf +} + +// writeLog writes a log entry for req to w in Apache Common Log Format. +// ts is the timestamp with which the entry should be logged. +// status and size are used to provide the response HTTP status and size. +func writeLog(w io.Writer, req *http.Request, url url.URL, ts time.Time, status, size int) { + buf := buildCommonLogLine(req, url, ts, status, size) + buf = append(buf, '\n') + w.Write(buf) +} + +// writeCombinedLog writes a log entry for req to w in Apache Combined Log Format. +// ts is the timestamp with which the entry should be logged. +// status and size are used to provide the response HTTP status and size. +func writeCombinedLog(w io.Writer, req *http.Request, url url.URL, ts time.Time, status, size int) { + buf := buildCommonLogLine(req, url, ts, status, size) + buf = append(buf, ` "`...) + buf = appendQuoted(buf, req.Referer()) + buf = append(buf, `" "`...) + buf = appendQuoted(buf, req.UserAgent()) + buf = append(buf, '"', '\n') + w.Write(buf) +} + +// CombinedLoggingHandler return a http.Handler that wraps h and logs requests to out in +// Apache Combined Log Format. +// +// See http://httpd.apache.org/docs/2.2/logs.html#combined for a description of this format. +// +// LoggingHandler always sets the ident field of the log to - +func CombinedLoggingHandler(out io.Writer, h http.Handler) http.Handler { + return combinedLoggingHandler{out, h} +} + +// LoggingHandler return a http.Handler that wraps h and logs requests to out in +// Apache Common Log Format (CLF). +// +// See http://httpd.apache.org/docs/2.2/logs.html#common for a description of this format. +// +// LoggingHandler always sets the ident field of the log to - +// +// Example: +// +// r := mux.NewRouter() +// r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { +// w.Write([]byte("This is a catch-all route")) +// }) +// loggedRouter := handlers.LoggingHandler(os.Stdout, r) +// http.ListenAndServe(":1123", loggedRouter) +// +func LoggingHandler(out io.Writer, h http.Handler) http.Handler { + return loggingHandler{out, h} +} + +// isContentType validates the Content-Type header matches the supplied +// contentType. That is, its type and subtype match. +func isContentType(h http.Header, contentType string) bool { + ct := h.Get("Content-Type") + if i := strings.IndexRune(ct, ';'); i != -1 { + ct = ct[0:i] + } + return ct == contentType +} + +// ContentTypeHandler wraps and returns a http.Handler, validating the request +// content type is compatible with the contentTypes list. It writes a HTTP 415 +// error if that fails. +// +// Only PUT, POST, and PATCH requests are considered. +func ContentTypeHandler(h http.Handler, contentTypes ...string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !(r.Method == "PUT" || r.Method == "POST" || r.Method == "PATCH") { + h.ServeHTTP(w, r) + return + } + + for _, ct := range contentTypes { + if isContentType(r.Header, ct) { + h.ServeHTTP(w, r) + return + } + } + http.Error(w, fmt.Sprintf("Unsupported content type %q; expected one of %q", r.Header.Get("Content-Type"), contentTypes), http.StatusUnsupportedMediaType) + }) +} + +const ( + // HTTPMethodOverrideHeader is a commonly used + // http header to override a request method. + HTTPMethodOverrideHeader = "X-HTTP-Method-Override" + // HTTPMethodOverrideFormKey is a commonly used + // HTML form key to override a request method. + HTTPMethodOverrideFormKey = "_method" +) + +// HTTPMethodOverrideHandler wraps and returns a http.Handler which checks for +// the X-HTTP-Method-Override header or the _method form key, and overrides (if +// valid) request.Method with its value. +// +// This is especially useful for HTTP clients that don't support many http verbs. +// It isn't secure to override e.g a GET to a POST, so only POST requests are +// considered. Likewise, the override method can only be a "write" method: PUT, +// PATCH or DELETE. +// +// Form method takes precedence over header method. +func HTTPMethodOverrideHandler(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "POST" { + om := r.FormValue(HTTPMethodOverrideFormKey) + if om == "" { + om = r.Header.Get(HTTPMethodOverrideHeader) + } + if om == "PUT" || om == "PATCH" || om == "DELETE" { + r.Method = om + } + } + h.ServeHTTP(w, r) + }) +} diff --git a/vendor/github.com/gorilla/handlers/proxy_headers.go b/vendor/github.com/gorilla/handlers/proxy_headers.go new file mode 100644 index 00000000..0be750fd --- /dev/null +++ b/vendor/github.com/gorilla/handlers/proxy_headers.go @@ -0,0 +1,120 @@ +package handlers + +import ( + "net/http" + "regexp" + "strings" +) + +var ( + // De-facto standard header keys. + xForwardedFor = http.CanonicalHeaderKey("X-Forwarded-For") + xForwardedHost = http.CanonicalHeaderKey("X-Forwarded-Host") + xForwardedProto = http.CanonicalHeaderKey("X-Forwarded-Proto") + xForwardedScheme = http.CanonicalHeaderKey("X-Forwarded-Scheme") + xRealIP = http.CanonicalHeaderKey("X-Real-IP") +) + +var ( + // RFC7239 defines a new "Forwarded: " header designed to replace the + // existing use of X-Forwarded-* headers. + // e.g. Forwarded: for=192.0.2.60;proto=https;by=203.0.113.43 + forwarded = http.CanonicalHeaderKey("Forwarded") + // Allows for a sub-match of the first value after 'for=' to the next + // comma, semi-colon or space. The match is case-insensitive. + forRegex = regexp.MustCompile(`(?i)(?:for=)([^(;|,| )]+)`) + // Allows for a sub-match for the first instance of scheme (http|https) + // prefixed by 'proto='. The match is case-insensitive. + protoRegex = regexp.MustCompile(`(?i)(?:proto=)(https|http)`) +) + +// ProxyHeaders inspects common reverse proxy headers and sets the corresponding +// fields in the HTTP request struct. These are X-Forwarded-For and X-Real-IP +// for the remote (client) IP address, X-Forwarded-Proto or X-Forwarded-Scheme +// for the scheme (http|https) and the RFC7239 Forwarded header, which may +// include both client IPs and schemes. +// +// NOTE: This middleware should only be used when behind a reverse +// proxy like nginx, HAProxy or Apache. Reverse proxies that don't (or are +// configured not to) strip these headers from client requests, or where these +// headers are accepted "as is" from a remote client (e.g. when Go is not behind +// a proxy), can manifest as a vulnerability if your application uses these +// headers for validating the 'trustworthiness' of a request. +func ProxyHeaders(h http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + // Set the remote IP with the value passed from the proxy. + if fwd := getIP(r); fwd != "" { + r.RemoteAddr = fwd + } + + // Set the scheme (proto) with the value passed from the proxy. + if scheme := getScheme(r); scheme != "" { + r.URL.Scheme = scheme + } + // Set the host with the value passed by the proxy + if r.Header.Get(xForwardedHost) != "" { + r.Host = r.Header.Get(xForwardedHost) + } + // Call the next handler in the chain. + h.ServeHTTP(w, r) + } + + return http.HandlerFunc(fn) +} + +// getIP retrieves the IP from the X-Forwarded-For, X-Real-IP and RFC7239 +// Forwarded headers (in that order). +func getIP(r *http.Request) string { + var addr string + + if fwd := r.Header.Get(xForwardedFor); fwd != "" { + // Only grab the first (client) address. Note that '192.168.0.1, + // 10.1.1.1' is a valid key for X-Forwarded-For where addresses after + // the first may represent forwarding proxies earlier in the chain. + s := strings.Index(fwd, ", ") + if s == -1 { + s = len(fwd) + } + addr = fwd[:s] + } else if fwd := r.Header.Get(xRealIP); fwd != "" { + // X-Real-IP should only contain one IP address (the client making the + // request). + addr = fwd + } else if fwd := r.Header.Get(forwarded); fwd != "" { + // match should contain at least two elements if the protocol was + // specified in the Forwarded header. The first element will always be + // the 'for=' capture, which we ignore. In the case of multiple IP + // addresses (for=8.8.8.8, 8.8.4.4,172.16.1.20 is valid) we only + // extract the first, which should be the client IP. + if match := forRegex.FindStringSubmatch(fwd); len(match) > 1 { + // IPv6 addresses in Forwarded headers are quoted-strings. We strip + // these quotes. + addr = strings.Trim(match[1], `"`) + } + } + + return addr +} + +// getScheme retrieves the scheme from the X-Forwarded-Proto and RFC7239 +// Forwarded headers (in that order). +func getScheme(r *http.Request) string { + var scheme string + + // Retrieve the scheme from X-Forwarded-Proto. + if proto := r.Header.Get(xForwardedProto); proto != "" { + scheme = strings.ToLower(proto) + } else if proto = r.Header.Get(xForwardedScheme); proto != "" { + scheme = strings.ToLower(proto) + } else if proto = r.Header.Get(forwarded); proto != "" { + // match should contain at least two elements if the protocol was + // specified in the Forwarded header. The first element will always be + // the 'proto=' capture, which we ignore. In the case of multiple proto + // parameters (invalid) we only extract the first. + if match := protoRegex.FindStringSubmatch(proto); len(match) > 1 { + scheme = strings.ToLower(match[1]) + } + } + + return scheme +} diff --git a/vendor/github.com/gorilla/handlers/recovery.go b/vendor/github.com/gorilla/handlers/recovery.go new file mode 100644 index 00000000..b1be9dc8 --- /dev/null +++ b/vendor/github.com/gorilla/handlers/recovery.go @@ -0,0 +1,91 @@ +package handlers + +import ( + "log" + "net/http" + "runtime/debug" +) + +// RecoveryHandlerLogger is an interface used by the recovering handler to print logs. +type RecoveryHandlerLogger interface { + Println(...interface{}) +} + +type recoveryHandler struct { + handler http.Handler + logger RecoveryHandlerLogger + printStack bool +} + +// RecoveryOption provides a functional approach to define +// configuration for a handler; such as setting the logging +// whether or not to print strack traces on panic. +type RecoveryOption func(http.Handler) + +func parseRecoveryOptions(h http.Handler, opts ...RecoveryOption) http.Handler { + for _, option := range opts { + option(h) + } + + return h +} + +// RecoveryHandler is HTTP middleware that recovers from a panic, +// logs the panic, writes http.StatusInternalServerError, and +// continues to the next handler. +// +// Example: +// +// r := mux.NewRouter() +// r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { +// panic("Unexpected error!") +// }) +// +// http.ListenAndServe(":1123", handlers.RecoveryHandler()(r)) +func RecoveryHandler(opts ...RecoveryOption) func(h http.Handler) http.Handler { + return func(h http.Handler) http.Handler { + r := &recoveryHandler{handler: h} + return parseRecoveryOptions(r, opts...) + } +} + +// RecoveryLogger is a functional option to override +// the default logger +func RecoveryLogger(logger RecoveryHandlerLogger) RecoveryOption { + return func(h http.Handler) { + r := h.(*recoveryHandler) + r.logger = logger + } +} + +// PrintRecoveryStack is a functional option to enable +// or disable printing stack traces on panic. +func PrintRecoveryStack(print bool) RecoveryOption { + return func(h http.Handler) { + r := h.(*recoveryHandler) + r.printStack = print + } +} + +func (h recoveryHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + defer func() { + if err := recover(); err != nil { + w.WriteHeader(http.StatusInternalServerError) + h.log(err) + } + }() + + h.handler.ServeHTTP(w, req) +} + +func (h recoveryHandler) log(v ...interface{}) { + if h.logger != nil { + h.logger.Println(v...) + } else { + log.Println(v...) + } + + if h.printStack { + debug.PrintStack() + } +}