diff --git a/cmd/dex/serve.go b/cmd/dex/serve.go index 7afb8851..6164dabb 100644 --- a/cmd/dex/serve.go +++ b/cmd/dex/serve.go @@ -14,6 +14,9 @@ import ( "syscall" "time" + gosundheit "github.com/AppsFlyer/go-sundheit" + "github.com/AppsFlyer/go-sundheit/checks" + gosundheithttp "github.com/AppsFlyer/go-sundheit/http" "github.com/ghodss/yaml" grpcprometheus "github.com/grpc-ecosystem/go-grpc-prometheus" "github.com/oklog/run" @@ -272,6 +275,8 @@ func runServe(options serveOptions) error { // explicitly convert to UTC. now := func() time.Time { return time.Now().UTC() } + healthChecker := gosundheit.New() + serverConfig := server.Config{ SupportedResponseTypes: c.OAuth2.ResponseTypes, SkipApprovalScreen: c.OAuth2.SkipApprovalScreen, @@ -284,6 +289,7 @@ func runServe(options serveOptions) error { Logger: logger, Now: now, PrometheusRegistry: prometheusRegistry, + HealthChecker: healthChecker, } if c.Expiry.SigningKeys != "" { signingKeys, err := time.ParseDuration(c.Expiry.SigningKeys) @@ -322,12 +328,33 @@ func runServe(options serveOptions) error { return fmt.Errorf("failed to initialize server: %v", err) } - telemetryServ := http.NewServeMux() - telemetryServ.Handle("/metrics", promhttp.HandlerFor(prometheusRegistry, promhttp.HandlerOpts{})) + telemetryRouter := http.NewServeMux() + telemetryRouter.Handle("/metrics", promhttp.HandlerFor(prometheusRegistry, promhttp.HandlerOpts{})) + + // Configure health checker + { + handler := gosundheithttp.HandleHealthJSON(healthChecker) + telemetryRouter.Handle("/healthz", handler) + + // Kubernetes style health checks + telemetryRouter.HandleFunc("/healthz/live", func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("ok")) + }) + telemetryRouter.Handle("/healthz/ready", handler) + } + + healthChecker.RegisterCheck(&gosundheit.Config{ + Check: &checks.CustomCheck{ + CheckName: "storage", + CheckFunc: storage.NewCustomHealthCheckFunc(serverConfig.Storage, serverConfig.Now), + }, + ExecutionPeriod: 15 * time.Second, + InitiallyPassing: true, + }) var gr run.Group if c.Telemetry.HTTP != "" { - telemetrySrv := &http.Server{Addr: c.Telemetry.HTTP, Handler: telemetryServ} + telemetrySrv := &http.Server{Addr: c.Telemetry.HTTP, Handler: telemetryRouter} defer telemetrySrv.Close() if err := listenAndShutdownGracefully(logger, &gr, telemetrySrv, "http/telemetry"); err != nil { diff --git a/go.mod b/go.mod index 2b279879..b08c64d4 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/dexidp/dex go 1.15 require ( + github.com/AppsFlyer/go-sundheit v0.3.1 github.com/Microsoft/hcsshim v0.8.14 // indirect github.com/beevik/etree v1.1.0 github.com/coreos/go-oidc/v3 v3.0.0 diff --git a/go.sum b/go.sum index 1147f6e5..9affe62c 100644 --- a/go.sum +++ b/go.sum @@ -13,6 +13,8 @@ cloud.google.com/go/firestore v1.1.0/go.mod h1:ulACoGHTpvq5r8rxGJ4ddJZBZqakUQqCl cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I= cloud.google.com/go/storage v1.0.0/go.mod h1:IhtSnM/ZTZV8YYJWCY8RULGVqBDmpoyjwiyrjsg+URw= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= +github.com/AppsFlyer/go-sundheit v0.3.1 h1:Zqnr3wV3WQmXonc234k9XZAoV2KHUHw3osR5k2iHQZE= +github.com/AppsFlyer/go-sundheit v0.3.1/go.mod h1:iZ8zWMS7idcvmqewf5mEymWWgoOiG/0WD4+aeh+heX4= github.com/Azure/go-ansiterm v0.0.0-20170929234023-d6e3b3328b78 h1:w+iIsaOQNcT7OZ575w+acHgRric5iCyQh+xv+KJ4HB8= github.com/Azure/go-ansiterm v0.0.0-20170929234023-d6e3b3328b78/go.mod h1:LmzpDX56iTiv29bbRTIsUNlaFfuhWRQBWjQdVyAevI8= github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= @@ -102,6 +104,8 @@ github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7 github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/felixge/httpsnoop v1.0.1 h1:lvB5Jl89CsZtGIWuTcDM1E/vkVs49/Ml7JJe07l8SPQ= github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw= +github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g= github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk= @@ -339,6 +343,8 @@ github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An github.com/spf13/viper v1.7.0/go.mod h1:8WkrPz2fc9jxqZNCJI/76HCieCp4Q8HaLFoCha5qpdg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.2.0 h1:Hbg2NidpLE8veEBkEZTL3CvlkUIVzuU9jDplZO54c48= +github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= diff --git a/server/handlers.go b/server/handlers.go index 348700df..eb65f490 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -1,7 +1,6 @@ package server import ( - "context" "crypto/sha256" "encoding/base64" "encoding/json" @@ -13,7 +12,6 @@ import ( "sort" "strconv" "strings" - "sync" "time" "github.com/coreos/go-oidc/v3/oidc" @@ -30,90 +28,6 @@ const ( CodeChallengeMethodS256 = "S256" ) -// newHealthChecker returns the healthz handler. The handler runs until the -// provided context is canceled. -func (s *Server) newHealthChecker(ctx context.Context) http.Handler { - h := &healthChecker{s: s} - - // Perform one health check synchronously so the returned handler returns - // valid data immediately. - h.runHealthCheck() - - go func() { - for { - select { - case <-ctx.Done(): - return - case <-time.After(time.Second * 15): - } - h.runHealthCheck() - } - }() - return h -} - -// healthChecker periodically performs health checks on server dependencies. -// Currently, it only checks that the storage layer is available. -type healthChecker struct { - s *Server - - // Result of the last health check: any error and the amount of time it took - // to query the storage. - mu sync.RWMutex - // Guarded by the mutex - err error - passed time.Duration -} - -// runHealthCheck performs a single health check and makes the result available -// for any clients performing and HTTP request against the healthChecker. -func (h *healthChecker) runHealthCheck() { - t := h.s.now() - err := checkStorageHealth(h.s.storage, h.s.now) - passed := h.s.now().Sub(t) - if err != nil { - h.s.logger.Errorf("Storage health check failed: %v", err) - } - - // Make sure to only hold the mutex to access the fields, and not while - // we're querying the storage object. - h.mu.Lock() - h.err = err - h.passed = passed - h.mu.Unlock() -} - -func checkStorageHealth(s storage.Storage, now func() time.Time) error { - a := storage.AuthRequest{ - ID: storage.NewID(), - ClientID: storage.NewID(), - - // Set a short expiry so if the delete fails this will be cleaned up quickly by garbage collection. - Expiry: now().Add(time.Minute), - } - - if err := s.CreateAuthRequest(a); err != nil { - return fmt.Errorf("create auth request: %v", err) - } - if err := s.DeleteAuthRequest(a.ID); err != nil { - return fmt.Errorf("delete auth request: %v", err) - } - return nil -} - -func (h *healthChecker) ServeHTTP(w http.ResponseWriter, r *http.Request) { - h.mu.RLock() - err := h.err - t := h.passed - h.mu.RUnlock() - - if err != nil { - h.s.renderError(r, w, http.StatusInternalServerError, "Health check failed.") - return - } - fmt.Fprintf(w, "Health check passed in %s", t) -} - func (s *Server) handlePublicKeys(w http.ResponseWriter, r *http.Request) { // TODO(ericchiang): Cache this. keys, err := s.storage.GetKeys() diff --git a/server/handlers_test.go b/server/handlers_test.go index 4ca182f2..8ad59d94 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -10,6 +10,8 @@ import ( "testing" "time" + gosundheit "github.com/AppsFlyer/go-sundheit" + "github.com/AppsFlyer/go-sundheit/checks" "github.com/coreos/go-oidc/v3/oidc" "github.com/gorilla/mux" "github.com/stretchr/testify/require" @@ -33,20 +35,23 @@ func TestHandleHealth(t *testing.T) { } } -type badStorage struct { - storage.Storage -} - -func (b *badStorage) CreateAuthRequest(r storage.AuthRequest) error { - return errors.New("storage unavailable") -} - func TestHandleHealthFailure(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() httpServer, server := newTestServer(ctx, t, func(c *Config) { - c.Storage = &badStorage{c.Storage} + c.HealthChecker = gosundheit.New() + + c.HealthChecker.RegisterCheck(&gosundheit.Config{ + Check: &checks.CustomCheck{ + CheckName: "fail", + CheckFunc: func() (details interface{}, err error) { + return nil, errors.New("error") + }, + }, + InitiallyPassing: false, + ExecutionPeriod: 1 * time.Second, + }) }) defer httpServer.Close() diff --git a/server/server.go b/server/server.go index 6fd4d8b7..a79b7cfd 100644 --- a/server/server.go +++ b/server/server.go @@ -15,6 +15,7 @@ import ( "sync/atomic" "time" + gosundheit "github.com/AppsFlyer/go-sundheit" "github.com/felixge/httpsnoop" "github.com/gorilla/handlers" "github.com/gorilla/mux" @@ -93,6 +94,8 @@ type Config struct { Logger log.Logger PrometheusRegistry *prometheus.Registry + + HealthChecker gosundheit.Health } // WebConfig holds the server's frontend templates and asset configuration. @@ -333,7 +336,13 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) // "authproxy" connector. handleFunc("/callback/{connector}", s.handleConnectorCallback) handleFunc("/approval", s.handleApproval) - handle("/healthz", s.newHealthChecker(ctx)) + handle("/healthz", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !c.HealthChecker.IsHealthy() { + s.renderError(r, w, http.StatusInternalServerError, "Health check failed.") + return + } + fmt.Fprintf(w, "Health check passed") + })) handlePrefix("/static", static) handlePrefix("/theme", theme) s.mux = r diff --git a/server/server_test.go b/server/server_test.go index 3a918434..87ca6c17 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -21,6 +21,7 @@ import ( "testing" "time" + gosundheit "github.com/AppsFlyer/go-sundheit" "github.com/coreos/go-oidc/v3/oidc" "github.com/kylelemons/godebug/pretty" "github.com/prometheus/client_golang/prometheus" @@ -96,6 +97,7 @@ func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Confi }, Logger: logger, PrometheusRegistry: prometheus.NewRegistry(), + HealthChecker: gosundheit.New(), } if updateConfig != nil { updateConfig(&config) diff --git a/storage/health.go b/storage/health.go new file mode 100644 index 00000000..5df0373d --- /dev/null +++ b/storage/health.go @@ -0,0 +1,29 @@ +package storage + +import ( + "fmt" + "time" +) + +// NewCustomHealthCheckFunc returns a new health check function. +func NewCustomHealthCheckFunc(s Storage, now func() time.Time) func() (details interface{}, err error) { + return func() (details interface{}, err error) { + a := AuthRequest{ + ID: NewID(), + ClientID: NewID(), + + // Set a short expiry so if the delete fails this will be cleaned up quickly by garbage collection. + Expiry: now().Add(time.Minute), + } + + if err := s.CreateAuthRequest(a); err != nil { + return nil, fmt.Errorf("create auth request: %v", err) + } + + if err := s.DeleteAuthRequest(a.ID); err != nil { + return nil, fmt.Errorf("delete auth request: %v", err) + } + + return nil, nil + } +}