diff --git a/cmd/dex/serve.go b/cmd/dex/serve.go index 17a63613..69ea7932 100644 --- a/cmd/dex/serve.go +++ b/cmd/dex/serve.go @@ -9,6 +9,7 @@ import ( "net/http" "github.com/spf13/cobra" + "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/credentials" yaml "gopkg.in/yaml.v2" @@ -124,7 +125,7 @@ func serve(cmd *cobra.Command, args []string) error { EnablePasswordDB: c.EnablePasswordDB, } - serv, err := server.NewServer(serverConfig) + serv, err := server.NewServer(context.Background(), serverConfig) if err != nil { return fmt.Errorf("initializing server: %v", err) } diff --git a/server/handlers.go b/server/handlers.go index ff9444fe..7c3d649d 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -143,6 +143,7 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) { s.renderError(w, http.StatusInternalServerError, err.Type, err.Description) return } + authReq.Expiry = s.now().Add(time.Minute * 30) if err := s.storage.CreateAuthRequest(authReq); err != nil { log.Printf("Failed to create authorization request: %v", err) s.renderError(w, http.StatusInternalServerError, errServerError, "") @@ -342,7 +343,7 @@ func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) { } func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authReq storage.AuthRequest) { - if authReq.Expiry.After(s.now()) { + if s.now().After(authReq.Expiry) { s.renderError(w, http.StatusBadRequest, errInvalidRequest, "Authorization request period has expired.") return } @@ -373,7 +374,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe Nonce: authReq.Nonce, Scopes: authReq.Scopes, Claims: authReq.Claims, - Expiry: s.now().Add(time.Minute * 5), + Expiry: s.now().Add(time.Minute * 30), RedirectURI: authReq.RedirectURI, } if err := s.storage.CreateAuthCode(code); err != nil { diff --git a/server/handlers_test.go b/server/handlers_test.go index ce058022..d9f91888 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -4,10 +4,15 @@ import ( "net/http" "net/http/httptest" "testing" + + "golang.org/x/net/context" ) func TestHandleHealth(t *testing.T) { - httpServer, server := newTestServer(t, nil) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + httpServer, server := newTestServer(t, ctx, nil) defer httpServer.Close() rr := httptest.NewRecorder() diff --git a/server/rotation.go b/server/rotation.go index 37924486..a725deb2 100644 --- a/server/rotation.go +++ b/server/rotation.go @@ -56,40 +56,34 @@ type keyRotater struct { storage.Storage strategy rotationStrategy - cancel context.CancelFunc - - now func() time.Time + now func() time.Time } -func storageWithKeyRotation(s storage.Storage, strategy rotationStrategy, now func() time.Time) storage.Storage { - if now == nil { - now = time.Now - } - ctx, cancel := context.WithCancel(context.Background()) - rotater := keyRotater{s, strategy, cancel, now} +// startKeyRotation begins key rotation in a new goroutine, closing once the context is canceled. +// +// The method blocks until after the first attempt to rotate keys has completed. That way +// healthy storages will return from this call with valid keys. +func startKeyRotation(ctx context.Context, s storage.Storage, strategy rotationStrategy, now func() time.Time) { + rotater := keyRotater{s, strategy, now} - // Try to rotate immediately so properly configured storages will return a - // storage with keys. + // Try to rotate immediately so properly configured storages will have keys. if err := rotater.rotate(); err != nil { log.Printf("failed to rotate keys: %v", err) } go func() { - select { - case <-ctx.Done(): - return - case <-time.After(time.Second * 30): - if err := rotater.rotate(); err != nil { - log.Printf("failed to rotate keys: %v", err) + for { + select { + case <-ctx.Done(): + return + case <-time.After(strategy.period): + if err := rotater.rotate(); err != nil { + log.Printf("failed to rotate keys: %v", err) + } } } }() - return rotater -} - -func (k keyRotater) Close() error { - k.cancel() - return k.Storage.Close() + return } func (k keyRotater) rotate() error { diff --git a/server/server.go b/server/server.go index 904d826e..e6825cb0 100644 --- a/server/server.go +++ b/server/server.go @@ -11,6 +11,7 @@ import ( "time" "golang.org/x/crypto/bcrypt" + "golang.org/x/net/context" "github.com/gorilla/mux" @@ -48,6 +49,8 @@ type Config struct { RotateKeysAfter time.Duration // Defaults to 6 hours. IDTokensValidFor time.Duration // Defaults to 24 hours + GCFrequency time.Duration // Defaults to 5 minutes + // If specified, the server will use this function for determining time. Now func() time.Time @@ -87,14 +90,14 @@ type Server struct { } // NewServer constructs a server from the provided config. -func NewServer(c Config) (*Server, error) { - return newServer(c, defaultRotationStrategy( +func NewServer(ctx context.Context, c Config) (*Server, error) { + return newServer(ctx, c, defaultRotationStrategy( value(c.RotateKeysAfter, 6*time.Hour), value(c.IDTokensValidFor, 24*time.Hour), )) } -func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) { +func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) (*Server, error) { issuerURL, err := url.Parse(c.Issuer) if err != nil { return nil, fmt.Errorf("server: can't parse issuer URL") @@ -138,14 +141,9 @@ func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) { } s := &Server{ - issuerURL: *issuerURL, - connectors: make(map[string]Connector), - storage: newKeyCacher( - storageWithKeyRotation( - c.Storage, rotationStrategy, now, - ), - now, - ), + issuerURL: *issuerURL, + connectors: make(map[string]Connector), + storage: newKeyCacher(c.Storage, now), supportedResponseTypes: supported, idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour), skipApproval: c.SkipApprovalScreen, @@ -179,6 +177,9 @@ func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) { handleFunc("/healthz", s.handleHealth) s.mux = r + startKeyRotation(ctx, c.Storage, rotationStrategy, now) + startGarbageCollection(ctx, c.Storage, value(c.GCFrequency, 5*time.Minute), now) + return s, nil } @@ -262,3 +263,21 @@ func (k *keyCacher) GetKeys() (storage.Keys, error) { } return storageKeys, nil } + +func startGarbageCollection(ctx context.Context, s storage.Storage, frequency time.Duration, now func() time.Time) { + go func() { + for { + select { + case <-ctx.Done(): + return + case <-time.After(frequency): + if r, err := s.GarbageCollect(now()); err != nil { + log.Printf("garbage collection failed: %v", err) + } else { + log.Printf("garbage collection run, delete auth requests=%d, auth codes=%d", r.AuthRequests, r.AuthCodes) + } + } + } + }() + return +} diff --git a/server/server_test.go b/server/server_test.go index 35bcf8a3..62fe38e6 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -69,7 +69,7 @@ FDWV28nTP9sqbtsmU8Tem2jzMvZ7C/Q0AuDoKELFUpux8shm8wfIhyaPnXUGZoAZ Np4vUwMSYV5mopESLWOg3loBxKyLGFtgGKVCjGiQvy6zISQ4fQo= -----END RSA PRIVATE KEY-----`) -func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server, *Server) { +func newTestServer(t *testing.T, ctx context.Context, updateConfig func(c *Config)) (*httptest.Server, *Server) { var server *Server s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server.ServeHTTP(w, r) @@ -91,7 +91,7 @@ func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server s.URL = config.Issuer var err error - if server, err = newServer(config, staticRotationStrategy(testKey)); err != nil { + if server, err = newServer(ctx, config, staticRotationStrategy(testKey)); err != nil { t.Fatal(err) } server.skipApproval = true // Don't prompt for approval, just immediately redirect with code. @@ -99,14 +99,16 @@ func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server } func TestNewTestServer(t *testing.T) { - newTestServer(t, nil) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + newTestServer(t, ctx, nil) } func TestDiscovery(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - httpServer, _ := newTestServer(t, func(c *Config) { + httpServer, _ := newTestServer(t, ctx, func(c *Config) { c.Issuer = c.Issuer + "/non-root-path" }) defer httpServer.Close() @@ -255,7 +257,7 @@ func TestOAuth2CodeFlow(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - httpServer, s := newTestServer(t, func(c *Config) { + httpServer, s := newTestServer(t, ctx, func(c *Config) { c.Issuer = c.Issuer + "/non-root-path" }) defer httpServer.Close() @@ -368,7 +370,7 @@ func TestOAuth2ImplicitFlow(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - httpServer, s := newTestServer(t, func(c *Config) { + httpServer, s := newTestServer(t, ctx, func(c *Config) { // Enable support for the implicit flow. c.SupportedResponseTypes = []string{"code", "token"} }) @@ -498,7 +500,7 @@ func TestCrossClientScopes(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - httpServer, s := newTestServer(t, func(c *Config) { + httpServer, s := newTestServer(t, ctx, func(c *Config) { c.Issuer = c.Issuer + "/non-root-path" }) defer httpServer.Close() diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index a2458680..dcc765c9 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -18,12 +18,10 @@ import ( // ensure that values being tested on never expire. var neverExpire = time.Now().UTC().Add(time.Hour * 24 * 365 * 100) -// StorageFactory is a method for creating a new storage. The returned storage sould be initialized -// but shouldn't have any existing data in it. -type StorageFactory func() storage.Storage - -// RunTestSuite runs a set of conformance tests against a storage. -func RunTestSuite(t *testing.T, sf StorageFactory) { +// RunTests runs a set of conformance tests against a storage. newStorage should +// return an initialized but empty storage. The storage will be closed at the +// end of each test run. +func RunTests(t *testing.T, newStorage func() storage.Storage) { tests := []struct { name string run func(t *testing.T, s storage.Storage) @@ -33,10 +31,13 @@ func RunTestSuite(t *testing.T, sf StorageFactory) { {"ClientCRUD", testClientCRUD}, {"RefreshTokenCRUD", testRefreshTokenCRUD}, {"PasswordCRUD", testPasswordCRUD}, + {"GarbageCollection", testGC}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - test.run(t, sf()) + s := newStorage() + test.run(t, s) + s.Close() }) } } @@ -276,3 +277,92 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) { t.Errorf("after deleting password expected storage.ErrNotFound, got %v", err) } } + +func testGC(t *testing.T, s storage.Storage) { + n := time.Now() + c := storage.AuthCode{ + ID: storage.NewID(), + ClientID: "foobar", + RedirectURI: "https://localhost:80/callback", + Nonce: "foobar", + Scopes: []string{"openid", "email"}, + Expiry: n.Add(time.Second), + ConnectorID: "ldap", + ConnectorData: []byte(`{"some":"data"}`), + Claims: storage.Claims{ + UserID: "1", + Username: "jane", + Email: "jane.doe@example.com", + EmailVerified: true, + Groups: []string{"a", "b"}, + }, + } + + if err := s.CreateAuthCode(c); err != nil { + t.Fatalf("failed creating auth code: %v", err) + } + + if _, err := s.GarbageCollect(n); err != nil { + t.Errorf("garbage collection failed: %v", err) + } + if _, err := s.GetAuthCode(c.ID); err != nil { + t.Errorf("expected to be able to get auth code after GC: %v", err) + } + + if r, err := s.GarbageCollect(n.Add(time.Minute)); err != nil { + t.Errorf("garbage collection failed: %v", err) + } else if r.AuthCodes != 1 { + t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthCodes) + } + + if _, err := s.GetAuthCode(c.ID); err == nil { + t.Errorf("expected auth code to be GC'd") + } else if err != storage.ErrNotFound { + t.Errorf("expected storage.ErrNotFound, got %v", err) + } + + a := storage.AuthRequest{ + ID: storage.NewID(), + ClientID: "foobar", + ResponseTypes: []string{"code"}, + Scopes: []string{"openid", "email"}, + RedirectURI: "https://localhost:80/callback", + Nonce: "foo", + State: "bar", + ForceApprovalPrompt: true, + LoggedIn: true, + Expiry: n, + ConnectorID: "ldap", + ConnectorData: []byte(`{"some":"data"}`), + Claims: storage.Claims{ + UserID: "1", + Username: "jane", + Email: "jane.doe@example.com", + EmailVerified: true, + Groups: []string{"a", "b"}, + }, + } + + if err := s.CreateAuthRequest(a); err != nil { + t.Fatalf("failed creating auth request: %v", err) + } + + if _, err := s.GarbageCollect(n); err != nil { + t.Errorf("garbage collection failed: %v", err) + } + if _, err := s.GetAuthRequest(a.ID); err != nil { + t.Errorf("expected to be able to get auth code after GC: %v", err) + } + + if r, err := s.GarbageCollect(n.Add(time.Minute)); err != nil { + t.Errorf("garbage collection failed: %v", err) + } else if r.AuthRequests != 1 { + t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthRequests) + } + + if _, err := s.GetAuthRequest(a.ID); err == nil { + t.Errorf("expected auth code to be GC'd") + } else if err != storage.ErrNotFound { + t.Errorf("expected storage.ErrNotFound, got %v", err) + } +} diff --git a/storage/kubernetes/client.go b/storage/kubernetes/client.go index ec703214..2c0910a7 100644 --- a/storage/kubernetes/client.go +++ b/storage/kubernetes/client.go @@ -20,7 +20,6 @@ import ( "time" "github.com/gtank/cryptopasta" - "golang.org/x/net/context" yaml "gopkg.in/yaml.v2" "github.com/coreos/dex/storage" @@ -35,9 +34,6 @@ type client struct { now func() time.Time - // If not nil, the cancel function for stopping garbage colletion. - cancel context.CancelFunc - // BUG: currently each third party API group can only have one resource in it, // so for each resource this storage uses, it need a unique API group. // diff --git a/storage/kubernetes/garbage_collection.go b/storage/kubernetes/garbage_collection.go deleted file mode 100644 index b58b0c89..00000000 --- a/storage/kubernetes/garbage_collection.go +++ /dev/null @@ -1,58 +0,0 @@ -package kubernetes - -import ( - "fmt" - "log" - "time" - - "golang.org/x/net/context" -) - -// gc begins the gc process for Kubernetes. -func (cli *client) gc(ctx context.Context, every time.Duration) { - handleErr := func(err error) { log.Println(err.Error()) } - - for { - select { - case <-ctx.Done(): - return - case <-time.After(every): - } - - // TODO(ericchiang): On failures, run garbage collection more often. - log.Println("kubernetes: running garbage collection") - cli.gcAuthRequests(handleErr) - cli.gcAuthCodes(handleErr) - log.Printf("kubernetes: garbage collection finished, next run at %s", cli.now().Add(every)) - } -} - -func (cli *client) gcAuthRequests(handleErr func(error)) { - var authRequests AuthRequestList - if err := cli.list(resourceAuthRequest, &authRequests); err != nil { - handleErr(fmt.Errorf("failed to list auth requests: %v", err)) - return - } - for _, authRequest := range authRequests.AuthRequests { - if cli.now().After(authRequest.Expiry) { - if err := cli.delete(resourceAuthRequest, authRequest.ObjectMeta.Name); err != nil { - handleErr(fmt.Errorf("failed to detele auth request: %v", err)) - } - } - } -} - -func (cli *client) gcAuthCodes(handleErr func(error)) { - var authCodes AuthCodeList - if err := cli.list(resourceAuthCode, &authCodes); err != nil { - handleErr(fmt.Errorf("failed to list auth codes: %v", err)) - return - } - for _, authCode := range authCodes.AuthCodes { - if cli.now().After(authCode.Expiry) { - if err := cli.delete(resourceAuthCode, authCode.ObjectMeta.Name); err != nil { - handleErr(fmt.Errorf("failed to delete auth code: %v", err)) - } - } - } -} diff --git a/storage/kubernetes/garbage_collection_test.go b/storage/kubernetes/garbage_collection_test.go deleted file mode 100644 index bb725683..00000000 --- a/storage/kubernetes/garbage_collection_test.go +++ /dev/null @@ -1,88 +0,0 @@ -package kubernetes - -import ( - "testing" - "time" - - "github.com/coreos/dex/storage" -) - -func muster(t *testing.T) func(err error) { - return func(err error) { - if err != nil { - t.Fatal(err) - } - } -} - -func TestGCAuthRequests(t *testing.T) { - cli := loadClient(t) - must := muster(t) - - now := time.Now() - cli.now = func() time.Time { return now } - - expiredID := storage.NewID() - goodID := storage.NewID() - - must(cli.CreateAuthRequest(storage.AuthRequest{ - ID: expiredID, - Expiry: now.Add(-time.Second), - })) - - must(cli.CreateAuthRequest(storage.AuthRequest{ - ID: goodID, - Expiry: now.Add(time.Second), - })) - - handleErr := func(err error) { t.Error(err.Error()) } - cli.gcAuthRequests(handleErr) - - if _, err := cli.GetAuthRequest(goodID); err != nil { - t.Errorf("failed to get good auth ID: %v", err) - } - _, err := cli.GetAuthRequest(expiredID) - switch { - case err == nil: - t.Errorf("gc did not remove expired auth request") - case err == storage.ErrNotFound: - default: - t.Errorf("expected storage.ErrNotFound, got %v", err) - } -} - -func TestGCAuthCodes(t *testing.T) { - cli := loadClient(t) - must := muster(t) - - now := time.Now() - cli.now = func() time.Time { return now } - - expiredID := storage.NewID() - goodID := storage.NewID() - - must(cli.CreateAuthCode(storage.AuthCode{ - ID: expiredID, - Expiry: now.Add(-time.Second), - })) - - must(cli.CreateAuthCode(storage.AuthCode{ - ID: goodID, - Expiry: now.Add(time.Second), - })) - - handleErr := func(err error) { t.Error(err.Error()) } - cli.gcAuthCodes(handleErr) - - if _, err := cli.GetAuthCode(goodID); err != nil { - t.Errorf("failed to get good auth ID: %v", err) - } - _, err := cli.GetAuthCode(expiredID) - switch { - case err == nil: - t.Errorf("gc did not remove expired auth request") - case err == storage.ErrNotFound: - default: - t.Errorf("expected storage.ErrNotFound, got %v", err) - } -} diff --git a/storage/kubernetes/storage.go b/storage/kubernetes/storage.go index 44920f6b..178a90db 100644 --- a/storage/kubernetes/storage.go +++ b/storage/kubernetes/storage.go @@ -3,12 +3,12 @@ package kubernetes import ( "errors" "fmt" + "log" "os" "path/filepath" "time" homedir "github.com/mitchellh/go-homedir" - "golang.org/x/net/context" "github.com/coreos/dex/storage" "github.com/coreos/dex/storage/kubernetes/k8sapi" @@ -46,14 +46,6 @@ func (c *Config) Open() (storage.Storage, error) { return nil, err } - // start up garbage collection - gcFrequency := c.GCFrequency - if gcFrequency == 0 { - gcFrequency = 600 - } - ctx, cancel := context.WithCancel(context.Background()) - cli.cancel = cancel - go cli.gc(ctx, time.Duration(gcFrequency)*time.Second) return cli, nil } @@ -93,9 +85,6 @@ func (c *Config) open() (*client, error) { } func (cli *client) Close() error { - if cli.cancel != nil { - cli.cancel() - } return nil } @@ -291,3 +280,40 @@ func (cli *client) UpdateAuthRequest(id string, updater func(a storage.AuthReque newReq.ObjectMeta = req.ObjectMeta return cli.put(resourceAuthRequest, id, newReq) } + +func (cli *client) GarbageCollect(now time.Time) (result storage.GCResult, err error) { + var authRequests AuthRequestList + if err := cli.list(resourceAuthRequest, &authRequests); err != nil { + return result, fmt.Errorf("failed to list auth requests: %v", err) + } + + var delErr error + for _, authRequest := range authRequests.AuthRequests { + if now.After(authRequest.Expiry) { + if err := cli.delete(resourceAuthRequest, authRequest.ObjectMeta.Name); err != nil { + log.Printf("failed to delete auth request: %v", err) + delErr = fmt.Errorf("failed to delete auth request: %v", err) + } + result.AuthRequests++ + } + } + if delErr != nil { + return result, delErr + } + + var authCodes AuthCodeList + if err := cli.list(resourceAuthCode, &authCodes); err != nil { + return result, fmt.Errorf("failed to list auth codes: %v", err) + } + + for _, authCode := range authCodes.AuthCodes { + if now.After(authCode.Expiry) { + if err := cli.delete(resourceAuthCode, authCode.ObjectMeta.Name); err != nil { + log.Printf("failed to delete auth code %v", err) + delErr = fmt.Errorf("failed to delete auth code: %v", err) + } + result.AuthCodes++ + } + } + return result, delErr +} diff --git a/storage/kubernetes/storage_test.go b/storage/kubernetes/storage_test.go index f41b01b1..043a1e9f 100644 --- a/storage/kubernetes/storage_test.go +++ b/storage/kubernetes/storage_test.go @@ -74,7 +74,7 @@ func TestURLFor(t *testing.T) { func TestStorage(t *testing.T) { client := loadClient(t) - conformance.RunTestSuite(t, func() storage.Storage { + conformance.RunTests(t, func() storage.Storage { for _, resource := range []string{ resourceAuthCode, resourceAuthRequest, diff --git a/storage/memory/memory.go b/storage/memory/memory.go index b85d68df..df88b442 100644 --- a/storage/memory/memory.go +++ b/storage/memory/memory.go @@ -4,6 +4,7 @@ package memory import ( "strings" "sync" + "time" "github.com/coreos/dex/storage" ) @@ -51,6 +52,24 @@ func (s *memStorage) tx(f func()) { func (s *memStorage) Close() error { return nil } +func (s *memStorage) GarbageCollect(now time.Time) (result storage.GCResult, err error) { + s.tx(func() { + for id, a := range s.authCodes { + if now.After(a.Expiry) { + delete(s.authCodes, id) + result.AuthCodes++ + } + } + for id, a := range s.authReqs { + if now.After(a.Expiry) { + delete(s.authReqs, id) + result.AuthRequests++ + } + } + }) + return result, nil +} + func (s *memStorage) CreateClient(c storage.Client) (err error) { s.tx(func() { if _, ok := s.clients[c.ID]; ok { @@ -240,29 +259,6 @@ func (s *memStorage) GetAuthCode(id string) (c storage.AuthCode, err error) { return } -func (s *memStorage) ClaimCode(id string) (err error) { - s.tx(func() { - if _, ok := s.authCodes[id]; !ok { - err = storage.ErrNotFound - return - } - delete(s.authCodes, id) - }) - return -} - -func (s *memStorage) ClaimRefresh(refreshToken string) (token storage.RefreshToken, err error) { - s.tx(func() { - var ok bool - if token, ok = s.refreshTokens[refreshToken]; !ok { - err = storage.ErrNotFound - return - } - delete(s.refreshTokens, refreshToken) - }) - return -} - func (s *memStorage) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) (err error) { s.tx(func() { client, ok := s.clients[id] diff --git a/storage/memory/memory_test.go b/storage/memory/memory_test.go index 56c5f93d..faa76283 100644 --- a/storage/memory/memory_test.go +++ b/storage/memory/memory_test.go @@ -7,5 +7,5 @@ import ( ) func TestStorage(t *testing.T) { - conformance.RunTestSuite(t, New) + conformance.RunTests(t, New) } diff --git a/storage/sql/config.go b/storage/sql/config.go index 4bff016b..f8cdf248 100644 --- a/storage/sql/config.go +++ b/storage/sql/config.go @@ -5,7 +5,6 @@ import ( "fmt" "net/url" "strconv" - "time" "github.com/coreos/dex/storage" ) @@ -22,7 +21,7 @@ func (s *SQLite3) Open() (storage.Storage, error) { if err != nil { return nil, err } - return withGC(conn, time.Now), nil + return conn, nil } func (s *SQLite3) open() (*conn, error) { @@ -76,7 +75,7 @@ func (p *Postgres) Open() (storage.Storage, error) { if err != nil { return nil, err } - return withGC(conn, time.Now), nil + return conn, nil } func (p *Postgres) open() (*conn, error) { diff --git a/storage/sql/config_test.go b/storage/sql/config_test.go index 55a9d69c..169e89ae 100644 --- a/storage/sql/config_test.go +++ b/storage/sql/config_test.go @@ -54,7 +54,7 @@ func TestSQLite3(t *testing.T) { } withTimeout(time.Second*10, func() { - conformance.RunTestSuite(t, newStorage) + conformance.RunTests(t, newStorage) }) } @@ -72,19 +72,24 @@ func TestPostgres(t *testing.T) { }, ConnectionTimeout: 5, } - conn, err := p.open() - if err != nil { - t.Fatal(err) + + // t.Fatal has a bad habbit of not actually printing the error + fatal := func(i interface{}) { + fmt.Fprintln(os.Stdout, i) + t.Fatal(i) } - defer conn.Close() newStorage := func() storage.Storage { + conn, err := p.open() + if err != nil { + fatal(err) + } if err := cleanDB(conn); err != nil { - t.Fatal(err) + fatal(err) } return conn } withTimeout(time.Minute*1, func() { - conformance.RunTestSuite(t, newStorage) + conformance.RunTests(t, newStorage) }) } diff --git a/storage/sql/crud.go b/storage/sql/crud.go index ca941f7c..0f8858de 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "strings" + "time" "github.com/coreos/dex/storage" ) @@ -83,6 +84,25 @@ type scanner interface { Scan(dest ...interface{}) error } +func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error) { + r, err := c.Exec(`delete from auth_request where expiry < $1`, now) + if err != nil { + return result, fmt.Errorf("gc auth_request: %v", err) + } + if n, err := r.RowsAffected(); err == nil { + result.AuthRequests = n + } + + r, err = c.Exec(`delete from auth_code where expiry < $1`, now) + if err != nil { + return result, fmt.Errorf("gc auth_code: %v", err) + } + if n, err := r.RowsAffected(); err == nil { + result.AuthCodes = n + } + return +} + func (c *conn) CreateAuthRequest(a storage.AuthRequest) error { _, err := c.Exec(` insert into auth_request ( diff --git a/storage/sql/gc.go b/storage/sql/gc.go deleted file mode 100644 index 1636c087..00000000 --- a/storage/sql/gc.go +++ /dev/null @@ -1,53 +0,0 @@ -package sql - -import ( - "context" - "fmt" - "log" - "time" - - "github.com/coreos/dex/storage" -) - -type gc struct { - now func() time.Time - conn *conn -} - -func (gc gc) run() error { - for _, table := range []string{"auth_request", "auth_code"} { - _, err := gc.conn.Exec(`delete from `+table+` where expiry < $1`, gc.now()) - if err != nil { - return fmt.Errorf("gc %s: %v", table, err) - } - // TODO(ericchiang): when we have levelled logging print how many rows were gc'd - } - return nil -} - -type withCancel struct { - storage.Storage - cancel context.CancelFunc -} - -func (w withCancel) Close() error { - w.cancel() - return w.Storage.Close() -} - -func withGC(conn *conn, now func() time.Time) storage.Storage { - ctx, cancel := context.WithCancel(context.Background()) - run := (gc{now, conn}).run - go func() { - for { - select { - case <-time.After(time.Second * 30): - if err := run(); err != nil { - log.Printf("gc failed: %v", err) - } - case <-ctx.Done(): - } - } - }() - return withCancel{conn, cancel} -} diff --git a/storage/sql/gc_test.go b/storage/sql/gc_test.go deleted file mode 100644 index ad6097e0..00000000 --- a/storage/sql/gc_test.go +++ /dev/null @@ -1,53 +0,0 @@ -package sql - -import ( - "testing" - "time" - - "github.com/coreos/dex/storage" -) - -func TestGC(t *testing.T) { - // TODO(ericchiang): Add a GarbageCollect method to the storage interface so - // we can write conformance tests instead of directly testing each implementation. - s := &SQLite3{":memory:"} - conn, err := s.open() - if err != nil { - t.Fatal(err) - } - defer conn.Close() - - clock := time.Now() - now := func() time.Time { return clock } - - runGC := (gc{now, conn}).run - - a := storage.AuthRequest{ - ID: storage.NewID(), - Expiry: now().Add(time.Second), - } - - if err := conn.CreateAuthRequest(a); err != nil { - t.Fatal(err) - } - - if err := runGC(); err != nil { - t.Errorf("gc failed: %v", err) - } - - if _, err := conn.GetAuthRequest(a.ID); err != nil { - t.Errorf("failed to get auth request after gc: %v", err) - } - - clock = clock.Add(time.Minute) - - if err := runGC(); err != nil { - t.Errorf("gc failed: %v", err) - } - - if _, err := conn.GetAuthRequest(a.ID); err == nil { - t.Errorf("expected error after gc'ing auth request: %v", err) - } else if err != storage.ErrNotFound { - t.Errorf("expected error storage.NotFound got: %v", err) - } -} diff --git a/storage/storage.go b/storage/storage.go index 78c162f9..032ba4ed 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -38,6 +38,12 @@ func NewID() string { return strings.TrimRight(encoding.EncodeToString(buff), "=") } +// GCResult returns the number of objects deleted by garbage collection. +type GCResult struct { + AuthRequests int64 + AuthCodes int64 +} + // Storage is the storage interface used by the server. Implementations, at minimum // require compare-and-swap atomic actions. // @@ -80,8 +86,8 @@ type Storage interface { UpdateAuthRequest(id string, updater func(a AuthRequest) (AuthRequest, error)) error UpdatePassword(email string, updater func(p Password) (Password, error)) error - // TODO(ericchiang): Add a GarbageCollect(now time.Time) method so conformance tests - // can test implementations. + // GarbageCollect deletes all expired AuthCodes and AuthRequests. + GarbageCollect(now time.Time) (GCResult, error) } // Client represents an OAuth2 client.