diff --git a/cmd/dex/serve.go b/cmd/dex/serve.go index 17a63613..69ea7932 100644 --- a/cmd/dex/serve.go +++ b/cmd/dex/serve.go @@ -9,6 +9,7 @@ import ( "net/http" "github.com/spf13/cobra" + "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/credentials" yaml "gopkg.in/yaml.v2" @@ -124,7 +125,7 @@ func serve(cmd *cobra.Command, args []string) error { EnablePasswordDB: c.EnablePasswordDB, } - serv, err := server.NewServer(serverConfig) + serv, err := server.NewServer(context.Background(), serverConfig) if err != nil { return fmt.Errorf("initializing server: %v", err) } diff --git a/server/handlers_test.go b/server/handlers_test.go index ce058022..d9f91888 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -4,10 +4,15 @@ import ( "net/http" "net/http/httptest" "testing" + + "golang.org/x/net/context" ) func TestHandleHealth(t *testing.T) { - httpServer, server := newTestServer(t, nil) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + httpServer, server := newTestServer(t, ctx, nil) defer httpServer.Close() rr := httptest.NewRecorder() diff --git a/server/rotation.go b/server/rotation.go index 37924486..a725deb2 100644 --- a/server/rotation.go +++ b/server/rotation.go @@ -56,40 +56,34 @@ type keyRotater struct { storage.Storage strategy rotationStrategy - cancel context.CancelFunc - - now func() time.Time + now func() time.Time } -func storageWithKeyRotation(s storage.Storage, strategy rotationStrategy, now func() time.Time) storage.Storage { - if now == nil { - now = time.Now - } - ctx, cancel := context.WithCancel(context.Background()) - rotater := keyRotater{s, strategy, cancel, now} +// startKeyRotation begins key rotation in a new goroutine, closing once the context is canceled. +// +// The method blocks until after the first attempt to rotate keys has completed. That way +// healthy storages will return from this call with valid keys. +func startKeyRotation(ctx context.Context, s storage.Storage, strategy rotationStrategy, now func() time.Time) { + rotater := keyRotater{s, strategy, now} - // Try to rotate immediately so properly configured storages will return a - // storage with keys. + // Try to rotate immediately so properly configured storages will have keys. if err := rotater.rotate(); err != nil { log.Printf("failed to rotate keys: %v", err) } go func() { - select { - case <-ctx.Done(): - return - case <-time.After(time.Second * 30): - if err := rotater.rotate(); err != nil { - log.Printf("failed to rotate keys: %v", err) + for { + select { + case <-ctx.Done(): + return + case <-time.After(strategy.period): + if err := rotater.rotate(); err != nil { + log.Printf("failed to rotate keys: %v", err) + } } } }() - return rotater -} - -func (k keyRotater) Close() error { - k.cancel() - return k.Storage.Close() + return } func (k keyRotater) rotate() error { diff --git a/server/server.go b/server/server.go index 904d826e..e6825cb0 100644 --- a/server/server.go +++ b/server/server.go @@ -11,6 +11,7 @@ import ( "time" "golang.org/x/crypto/bcrypt" + "golang.org/x/net/context" "github.com/gorilla/mux" @@ -48,6 +49,8 @@ type Config struct { RotateKeysAfter time.Duration // Defaults to 6 hours. IDTokensValidFor time.Duration // Defaults to 24 hours + GCFrequency time.Duration // Defaults to 5 minutes + // If specified, the server will use this function for determining time. Now func() time.Time @@ -87,14 +90,14 @@ type Server struct { } // NewServer constructs a server from the provided config. -func NewServer(c Config) (*Server, error) { - return newServer(c, defaultRotationStrategy( +func NewServer(ctx context.Context, c Config) (*Server, error) { + return newServer(ctx, c, defaultRotationStrategy( value(c.RotateKeysAfter, 6*time.Hour), value(c.IDTokensValidFor, 24*time.Hour), )) } -func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) { +func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) (*Server, error) { issuerURL, err := url.Parse(c.Issuer) if err != nil { return nil, fmt.Errorf("server: can't parse issuer URL") @@ -138,14 +141,9 @@ func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) { } s := &Server{ - issuerURL: *issuerURL, - connectors: make(map[string]Connector), - storage: newKeyCacher( - storageWithKeyRotation( - c.Storage, rotationStrategy, now, - ), - now, - ), + issuerURL: *issuerURL, + connectors: make(map[string]Connector), + storage: newKeyCacher(c.Storage, now), supportedResponseTypes: supported, idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour), skipApproval: c.SkipApprovalScreen, @@ -179,6 +177,9 @@ func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) { handleFunc("/healthz", s.handleHealth) s.mux = r + startKeyRotation(ctx, c.Storage, rotationStrategy, now) + startGarbageCollection(ctx, c.Storage, value(c.GCFrequency, 5*time.Minute), now) + return s, nil } @@ -262,3 +263,21 @@ func (k *keyCacher) GetKeys() (storage.Keys, error) { } return storageKeys, nil } + +func startGarbageCollection(ctx context.Context, s storage.Storage, frequency time.Duration, now func() time.Time) { + go func() { + for { + select { + case <-ctx.Done(): + return + case <-time.After(frequency): + if r, err := s.GarbageCollect(now()); err != nil { + log.Printf("garbage collection failed: %v", err) + } else { + log.Printf("garbage collection run, delete auth requests=%d, auth codes=%d", r.AuthRequests, r.AuthCodes) + } + } + } + }() + return +} diff --git a/server/server_test.go b/server/server_test.go index 13f46ac1..fb383f00 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -69,7 +69,7 @@ FDWV28nTP9sqbtsmU8Tem2jzMvZ7C/Q0AuDoKELFUpux8shm8wfIhyaPnXUGZoAZ Np4vUwMSYV5mopESLWOg3loBxKyLGFtgGKVCjGiQvy6zISQ4fQo= -----END RSA PRIVATE KEY-----`) -func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server, *Server) { +func newTestServer(t *testing.T, ctx context.Context, updateConfig func(c *Config)) (*httptest.Server, *Server) { var server *Server s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server.ServeHTTP(w, r) @@ -91,7 +91,7 @@ func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server s.URL = config.Issuer var err error - if server, err = newServer(config, staticRotationStrategy(testKey)); err != nil { + if server, err = newServer(ctx, config, staticRotationStrategy(testKey)); err != nil { t.Fatal(err) } server.skipApproval = true // Don't prompt for approval, just immediately redirect with code. @@ -99,14 +99,16 @@ func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server } func TestNewTestServer(t *testing.T) { - newTestServer(t, nil) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + newTestServer(t, ctx, nil) } func TestDiscovery(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - httpServer, _ := newTestServer(t, func(c *Config) { + httpServer, _ := newTestServer(t, ctx, func(c *Config) { c.Issuer = c.Issuer + "/non-root-path" }) defer httpServer.Close() @@ -227,7 +229,7 @@ func TestOAuth2CodeFlow(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - httpServer, s := newTestServer(t, func(c *Config) { + httpServer, s := newTestServer(t, ctx, func(c *Config) { c.Issuer = c.Issuer + "/non-root-path" }) defer httpServer.Close() @@ -340,7 +342,7 @@ func TestOAuth2ImplicitFlow(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - httpServer, s := newTestServer(t, func(c *Config) { + httpServer, s := newTestServer(t, ctx, func(c *Config) { // Enable support for the implicit flow. c.SupportedResponseTypes = []string{"code", "token"} }) @@ -470,7 +472,7 @@ func TestCrossClientScopes(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - httpServer, s := newTestServer(t, func(c *Config) { + httpServer, s := newTestServer(t, ctx, func(c *Config) { c.Issuer = c.Issuer + "/non-root-path" }) defer httpServer.Close()