Merge pull request #602 from ericchiang/dev-add-garbage-collect-method-to-storage

dev branch: add garbage collect method to storage
This commit is contained in:
Eric Chiang 2016-10-12 22:08:53 -07:00 committed by GitHub
commit 5bec61d73f
20 changed files with 265 additions and 357 deletions

View file

@ -9,6 +9,7 @@ import (
"net/http" "net/http"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/net/context"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
yaml "gopkg.in/yaml.v2" yaml "gopkg.in/yaml.v2"
@ -124,7 +125,7 @@ func serve(cmd *cobra.Command, args []string) error {
EnablePasswordDB: c.EnablePasswordDB, EnablePasswordDB: c.EnablePasswordDB,
} }
serv, err := server.NewServer(serverConfig) serv, err := server.NewServer(context.Background(), serverConfig)
if err != nil { if err != nil {
return fmt.Errorf("initializing server: %v", err) return fmt.Errorf("initializing server: %v", err)
} }

View file

@ -143,6 +143,7 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
s.renderError(w, http.StatusInternalServerError, err.Type, err.Description) s.renderError(w, http.StatusInternalServerError, err.Type, err.Description)
return return
} }
authReq.Expiry = s.now().Add(time.Minute * 30)
if err := s.storage.CreateAuthRequest(authReq); err != nil { if err := s.storage.CreateAuthRequest(authReq); err != nil {
log.Printf("Failed to create authorization request: %v", err) log.Printf("Failed to create authorization request: %v", err)
s.renderError(w, http.StatusInternalServerError, errServerError, "") s.renderError(w, http.StatusInternalServerError, errServerError, "")
@ -342,7 +343,7 @@ func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) {
} }
func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authReq storage.AuthRequest) { func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authReq storage.AuthRequest) {
if authReq.Expiry.After(s.now()) { if s.now().After(authReq.Expiry) {
s.renderError(w, http.StatusBadRequest, errInvalidRequest, "Authorization request period has expired.") s.renderError(w, http.StatusBadRequest, errInvalidRequest, "Authorization request period has expired.")
return return
} }
@ -373,7 +374,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
Nonce: authReq.Nonce, Nonce: authReq.Nonce,
Scopes: authReq.Scopes, Scopes: authReq.Scopes,
Claims: authReq.Claims, Claims: authReq.Claims,
Expiry: s.now().Add(time.Minute * 5), Expiry: s.now().Add(time.Minute * 30),
RedirectURI: authReq.RedirectURI, RedirectURI: authReq.RedirectURI,
} }
if err := s.storage.CreateAuthCode(code); err != nil { if err := s.storage.CreateAuthCode(code); err != nil {

View file

@ -4,10 +4,15 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"golang.org/x/net/context"
) )
func TestHandleHealth(t *testing.T) { func TestHandleHealth(t *testing.T) {
httpServer, server := newTestServer(t, nil) ctx, cancel := context.WithCancel(context.Background())
defer cancel()
httpServer, server := newTestServer(t, ctx, nil)
defer httpServer.Close() defer httpServer.Close()
rr := httptest.NewRecorder() rr := httptest.NewRecorder()

View file

@ -56,40 +56,34 @@ type keyRotater struct {
storage.Storage storage.Storage
strategy rotationStrategy strategy rotationStrategy
cancel context.CancelFunc now func() time.Time
now func() time.Time
} }
func storageWithKeyRotation(s storage.Storage, strategy rotationStrategy, now func() time.Time) storage.Storage { // startKeyRotation begins key rotation in a new goroutine, closing once the context is canceled.
if now == nil { //
now = time.Now // The method blocks until after the first attempt to rotate keys has completed. That way
} // healthy storages will return from this call with valid keys.
ctx, cancel := context.WithCancel(context.Background()) func startKeyRotation(ctx context.Context, s storage.Storage, strategy rotationStrategy, now func() time.Time) {
rotater := keyRotater{s, strategy, cancel, now} rotater := keyRotater{s, strategy, now}
// Try to rotate immediately so properly configured storages will return a // Try to rotate immediately so properly configured storages will have keys.
// storage with keys.
if err := rotater.rotate(); err != nil { if err := rotater.rotate(); err != nil {
log.Printf("failed to rotate keys: %v", err) log.Printf("failed to rotate keys: %v", err)
} }
go func() { go func() {
select { for {
case <-ctx.Done(): select {
return case <-ctx.Done():
case <-time.After(time.Second * 30): return
if err := rotater.rotate(); err != nil { case <-time.After(strategy.period):
log.Printf("failed to rotate keys: %v", err) if err := rotater.rotate(); err != nil {
log.Printf("failed to rotate keys: %v", err)
}
} }
} }
}() }()
return rotater return
}
func (k keyRotater) Close() error {
k.cancel()
return k.Storage.Close()
} }
func (k keyRotater) rotate() error { func (k keyRotater) rotate() error {

View file

@ -11,6 +11,7 @@ import (
"time" "time"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"golang.org/x/net/context"
"github.com/gorilla/mux" "github.com/gorilla/mux"
@ -48,6 +49,8 @@ type Config struct {
RotateKeysAfter time.Duration // Defaults to 6 hours. RotateKeysAfter time.Duration // Defaults to 6 hours.
IDTokensValidFor time.Duration // Defaults to 24 hours IDTokensValidFor time.Duration // Defaults to 24 hours
GCFrequency time.Duration // Defaults to 5 minutes
// If specified, the server will use this function for determining time. // If specified, the server will use this function for determining time.
Now func() time.Time Now func() time.Time
@ -87,14 +90,14 @@ type Server struct {
} }
// NewServer constructs a server from the provided config. // NewServer constructs a server from the provided config.
func NewServer(c Config) (*Server, error) { func NewServer(ctx context.Context, c Config) (*Server, error) {
return newServer(c, defaultRotationStrategy( return newServer(ctx, c, defaultRotationStrategy(
value(c.RotateKeysAfter, 6*time.Hour), value(c.RotateKeysAfter, 6*time.Hour),
value(c.IDTokensValidFor, 24*time.Hour), value(c.IDTokensValidFor, 24*time.Hour),
)) ))
} }
func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) { func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) (*Server, error) {
issuerURL, err := url.Parse(c.Issuer) issuerURL, err := url.Parse(c.Issuer)
if err != nil { if err != nil {
return nil, fmt.Errorf("server: can't parse issuer URL") return nil, fmt.Errorf("server: can't parse issuer URL")
@ -138,14 +141,9 @@ func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) {
} }
s := &Server{ s := &Server{
issuerURL: *issuerURL, issuerURL: *issuerURL,
connectors: make(map[string]Connector), connectors: make(map[string]Connector),
storage: newKeyCacher( storage: newKeyCacher(c.Storage, now),
storageWithKeyRotation(
c.Storage, rotationStrategy, now,
),
now,
),
supportedResponseTypes: supported, supportedResponseTypes: supported,
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour), idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
skipApproval: c.SkipApprovalScreen, skipApproval: c.SkipApprovalScreen,
@ -179,6 +177,9 @@ func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) {
handleFunc("/healthz", s.handleHealth) handleFunc("/healthz", s.handleHealth)
s.mux = r s.mux = r
startKeyRotation(ctx, c.Storage, rotationStrategy, now)
startGarbageCollection(ctx, c.Storage, value(c.GCFrequency, 5*time.Minute), now)
return s, nil return s, nil
} }
@ -262,3 +263,21 @@ func (k *keyCacher) GetKeys() (storage.Keys, error) {
} }
return storageKeys, nil return storageKeys, nil
} }
func startGarbageCollection(ctx context.Context, s storage.Storage, frequency time.Duration, now func() time.Time) {
go func() {
for {
select {
case <-ctx.Done():
return
case <-time.After(frequency):
if r, err := s.GarbageCollect(now()); err != nil {
log.Printf("garbage collection failed: %v", err)
} else {
log.Printf("garbage collection run, delete auth requests=%d, auth codes=%d", r.AuthRequests, r.AuthCodes)
}
}
}
}()
return
}

View file

@ -69,7 +69,7 @@ FDWV28nTP9sqbtsmU8Tem2jzMvZ7C/Q0AuDoKELFUpux8shm8wfIhyaPnXUGZoAZ
Np4vUwMSYV5mopESLWOg3loBxKyLGFtgGKVCjGiQvy6zISQ4fQo= Np4vUwMSYV5mopESLWOg3loBxKyLGFtgGKVCjGiQvy6zISQ4fQo=
-----END RSA PRIVATE KEY-----`) -----END RSA PRIVATE KEY-----`)
func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server, *Server) { func newTestServer(t *testing.T, ctx context.Context, updateConfig func(c *Config)) (*httptest.Server, *Server) {
var server *Server var server *Server
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
server.ServeHTTP(w, r) server.ServeHTTP(w, r)
@ -91,7 +91,7 @@ func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server
s.URL = config.Issuer s.URL = config.Issuer
var err error var err error
if server, err = newServer(config, staticRotationStrategy(testKey)); err != nil { if server, err = newServer(ctx, config, staticRotationStrategy(testKey)); err != nil {
t.Fatal(err) t.Fatal(err)
} }
server.skipApproval = true // Don't prompt for approval, just immediately redirect with code. server.skipApproval = true // Don't prompt for approval, just immediately redirect with code.
@ -99,14 +99,16 @@ func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server
} }
func TestNewTestServer(t *testing.T) { func TestNewTestServer(t *testing.T) {
newTestServer(t, nil) ctx, cancel := context.WithCancel(context.Background())
defer cancel()
newTestServer(t, ctx, nil)
} }
func TestDiscovery(t *testing.T) { func TestDiscovery(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
httpServer, _ := newTestServer(t, func(c *Config) { httpServer, _ := newTestServer(t, ctx, func(c *Config) {
c.Issuer = c.Issuer + "/non-root-path" c.Issuer = c.Issuer + "/non-root-path"
}) })
defer httpServer.Close() defer httpServer.Close()
@ -255,7 +257,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
httpServer, s := newTestServer(t, func(c *Config) { httpServer, s := newTestServer(t, ctx, func(c *Config) {
c.Issuer = c.Issuer + "/non-root-path" c.Issuer = c.Issuer + "/non-root-path"
}) })
defer httpServer.Close() defer httpServer.Close()
@ -368,7 +370,7 @@ func TestOAuth2ImplicitFlow(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
httpServer, s := newTestServer(t, func(c *Config) { httpServer, s := newTestServer(t, ctx, func(c *Config) {
// Enable support for the implicit flow. // Enable support for the implicit flow.
c.SupportedResponseTypes = []string{"code", "token"} c.SupportedResponseTypes = []string{"code", "token"}
}) })
@ -498,7 +500,7 @@ func TestCrossClientScopes(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
httpServer, s := newTestServer(t, func(c *Config) { httpServer, s := newTestServer(t, ctx, func(c *Config) {
c.Issuer = c.Issuer + "/non-root-path" c.Issuer = c.Issuer + "/non-root-path"
}) })
defer httpServer.Close() defer httpServer.Close()

View file

@ -18,12 +18,10 @@ import (
// ensure that values being tested on never expire. // ensure that values being tested on never expire.
var neverExpire = time.Now().UTC().Add(time.Hour * 24 * 365 * 100) var neverExpire = time.Now().UTC().Add(time.Hour * 24 * 365 * 100)
// StorageFactory is a method for creating a new storage. The returned storage sould be initialized // RunTests runs a set of conformance tests against a storage. newStorage should
// but shouldn't have any existing data in it. // return an initialized but empty storage. The storage will be closed at the
type StorageFactory func() storage.Storage // end of each test run.
func RunTests(t *testing.T, newStorage func() storage.Storage) {
// RunTestSuite runs a set of conformance tests against a storage.
func RunTestSuite(t *testing.T, sf StorageFactory) {
tests := []struct { tests := []struct {
name string name string
run func(t *testing.T, s storage.Storage) run func(t *testing.T, s storage.Storage)
@ -33,10 +31,13 @@ func RunTestSuite(t *testing.T, sf StorageFactory) {
{"ClientCRUD", testClientCRUD}, {"ClientCRUD", testClientCRUD},
{"RefreshTokenCRUD", testRefreshTokenCRUD}, {"RefreshTokenCRUD", testRefreshTokenCRUD},
{"PasswordCRUD", testPasswordCRUD}, {"PasswordCRUD", testPasswordCRUD},
{"GarbageCollection", testGC},
} }
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
test.run(t, sf()) s := newStorage()
test.run(t, s)
s.Close()
}) })
} }
} }
@ -276,3 +277,92 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) {
t.Errorf("after deleting password expected storage.ErrNotFound, got %v", err) t.Errorf("after deleting password expected storage.ErrNotFound, got %v", err)
} }
} }
func testGC(t *testing.T, s storage.Storage) {
n := time.Now()
c := storage.AuthCode{
ID: storage.NewID(),
ClientID: "foobar",
RedirectURI: "https://localhost:80/callback",
Nonce: "foobar",
Scopes: []string{"openid", "email"},
Expiry: n.Add(time.Second),
ConnectorID: "ldap",
ConnectorData: []byte(`{"some":"data"}`),
Claims: storage.Claims{
UserID: "1",
Username: "jane",
Email: "jane.doe@example.com",
EmailVerified: true,
Groups: []string{"a", "b"},
},
}
if err := s.CreateAuthCode(c); err != nil {
t.Fatalf("failed creating auth code: %v", err)
}
if _, err := s.GarbageCollect(n); err != nil {
t.Errorf("garbage collection failed: %v", err)
}
if _, err := s.GetAuthCode(c.ID); err != nil {
t.Errorf("expected to be able to get auth code after GC: %v", err)
}
if r, err := s.GarbageCollect(n.Add(time.Minute)); err != nil {
t.Errorf("garbage collection failed: %v", err)
} else if r.AuthCodes != 1 {
t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthCodes)
}
if _, err := s.GetAuthCode(c.ID); err == nil {
t.Errorf("expected auth code to be GC'd")
} else if err != storage.ErrNotFound {
t.Errorf("expected storage.ErrNotFound, got %v", err)
}
a := storage.AuthRequest{
ID: storage.NewID(),
ClientID: "foobar",
ResponseTypes: []string{"code"},
Scopes: []string{"openid", "email"},
RedirectURI: "https://localhost:80/callback",
Nonce: "foo",
State: "bar",
ForceApprovalPrompt: true,
LoggedIn: true,
Expiry: n,
ConnectorID: "ldap",
ConnectorData: []byte(`{"some":"data"}`),
Claims: storage.Claims{
UserID: "1",
Username: "jane",
Email: "jane.doe@example.com",
EmailVerified: true,
Groups: []string{"a", "b"},
},
}
if err := s.CreateAuthRequest(a); err != nil {
t.Fatalf("failed creating auth request: %v", err)
}
if _, err := s.GarbageCollect(n); err != nil {
t.Errorf("garbage collection failed: %v", err)
}
if _, err := s.GetAuthRequest(a.ID); err != nil {
t.Errorf("expected to be able to get auth code after GC: %v", err)
}
if r, err := s.GarbageCollect(n.Add(time.Minute)); err != nil {
t.Errorf("garbage collection failed: %v", err)
} else if r.AuthRequests != 1 {
t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthRequests)
}
if _, err := s.GetAuthRequest(a.ID); err == nil {
t.Errorf("expected auth code to be GC'd")
} else if err != storage.ErrNotFound {
t.Errorf("expected storage.ErrNotFound, got %v", err)
}
}

View file

@ -20,7 +20,6 @@ import (
"time" "time"
"github.com/gtank/cryptopasta" "github.com/gtank/cryptopasta"
"golang.org/x/net/context"
yaml "gopkg.in/yaml.v2" yaml "gopkg.in/yaml.v2"
"github.com/coreos/dex/storage" "github.com/coreos/dex/storage"
@ -35,9 +34,6 @@ type client struct {
now func() time.Time now func() time.Time
// If not nil, the cancel function for stopping garbage colletion.
cancel context.CancelFunc
// BUG: currently each third party API group can only have one resource in it, // BUG: currently each third party API group can only have one resource in it,
// so for each resource this storage uses, it need a unique API group. // so for each resource this storage uses, it need a unique API group.
// //

View file

@ -1,58 +0,0 @@
package kubernetes
import (
"fmt"
"log"
"time"
"golang.org/x/net/context"
)
// gc begins the gc process for Kubernetes.
func (cli *client) gc(ctx context.Context, every time.Duration) {
handleErr := func(err error) { log.Println(err.Error()) }
for {
select {
case <-ctx.Done():
return
case <-time.After(every):
}
// TODO(ericchiang): On failures, run garbage collection more often.
log.Println("kubernetes: running garbage collection")
cli.gcAuthRequests(handleErr)
cli.gcAuthCodes(handleErr)
log.Printf("kubernetes: garbage collection finished, next run at %s", cli.now().Add(every))
}
}
func (cli *client) gcAuthRequests(handleErr func(error)) {
var authRequests AuthRequestList
if err := cli.list(resourceAuthRequest, &authRequests); err != nil {
handleErr(fmt.Errorf("failed to list auth requests: %v", err))
return
}
for _, authRequest := range authRequests.AuthRequests {
if cli.now().After(authRequest.Expiry) {
if err := cli.delete(resourceAuthRequest, authRequest.ObjectMeta.Name); err != nil {
handleErr(fmt.Errorf("failed to detele auth request: %v", err))
}
}
}
}
func (cli *client) gcAuthCodes(handleErr func(error)) {
var authCodes AuthCodeList
if err := cli.list(resourceAuthCode, &authCodes); err != nil {
handleErr(fmt.Errorf("failed to list auth codes: %v", err))
return
}
for _, authCode := range authCodes.AuthCodes {
if cli.now().After(authCode.Expiry) {
if err := cli.delete(resourceAuthCode, authCode.ObjectMeta.Name); err != nil {
handleErr(fmt.Errorf("failed to delete auth code: %v", err))
}
}
}
}

View file

@ -1,88 +0,0 @@
package kubernetes
import (
"testing"
"time"
"github.com/coreos/dex/storage"
)
func muster(t *testing.T) func(err error) {
return func(err error) {
if err != nil {
t.Fatal(err)
}
}
}
func TestGCAuthRequests(t *testing.T) {
cli := loadClient(t)
must := muster(t)
now := time.Now()
cli.now = func() time.Time { return now }
expiredID := storage.NewID()
goodID := storage.NewID()
must(cli.CreateAuthRequest(storage.AuthRequest{
ID: expiredID,
Expiry: now.Add(-time.Second),
}))
must(cli.CreateAuthRequest(storage.AuthRequest{
ID: goodID,
Expiry: now.Add(time.Second),
}))
handleErr := func(err error) { t.Error(err.Error()) }
cli.gcAuthRequests(handleErr)
if _, err := cli.GetAuthRequest(goodID); err != nil {
t.Errorf("failed to get good auth ID: %v", err)
}
_, err := cli.GetAuthRequest(expiredID)
switch {
case err == nil:
t.Errorf("gc did not remove expired auth request")
case err == storage.ErrNotFound:
default:
t.Errorf("expected storage.ErrNotFound, got %v", err)
}
}
func TestGCAuthCodes(t *testing.T) {
cli := loadClient(t)
must := muster(t)
now := time.Now()
cli.now = func() time.Time { return now }
expiredID := storage.NewID()
goodID := storage.NewID()
must(cli.CreateAuthCode(storage.AuthCode{
ID: expiredID,
Expiry: now.Add(-time.Second),
}))
must(cli.CreateAuthCode(storage.AuthCode{
ID: goodID,
Expiry: now.Add(time.Second),
}))
handleErr := func(err error) { t.Error(err.Error()) }
cli.gcAuthCodes(handleErr)
if _, err := cli.GetAuthCode(goodID); err != nil {
t.Errorf("failed to get good auth ID: %v", err)
}
_, err := cli.GetAuthCode(expiredID)
switch {
case err == nil:
t.Errorf("gc did not remove expired auth request")
case err == storage.ErrNotFound:
default:
t.Errorf("expected storage.ErrNotFound, got %v", err)
}
}

View file

@ -3,12 +3,12 @@ package kubernetes
import ( import (
"errors" "errors"
"fmt" "fmt"
"log"
"os" "os"
"path/filepath" "path/filepath"
"time" "time"
homedir "github.com/mitchellh/go-homedir" homedir "github.com/mitchellh/go-homedir"
"golang.org/x/net/context"
"github.com/coreos/dex/storage" "github.com/coreos/dex/storage"
"github.com/coreos/dex/storage/kubernetes/k8sapi" "github.com/coreos/dex/storage/kubernetes/k8sapi"
@ -46,14 +46,6 @@ func (c *Config) Open() (storage.Storage, error) {
return nil, err return nil, err
} }
// start up garbage collection
gcFrequency := c.GCFrequency
if gcFrequency == 0 {
gcFrequency = 600
}
ctx, cancel := context.WithCancel(context.Background())
cli.cancel = cancel
go cli.gc(ctx, time.Duration(gcFrequency)*time.Second)
return cli, nil return cli, nil
} }
@ -93,9 +85,6 @@ func (c *Config) open() (*client, error) {
} }
func (cli *client) Close() error { func (cli *client) Close() error {
if cli.cancel != nil {
cli.cancel()
}
return nil return nil
} }
@ -291,3 +280,40 @@ func (cli *client) UpdateAuthRequest(id string, updater func(a storage.AuthReque
newReq.ObjectMeta = req.ObjectMeta newReq.ObjectMeta = req.ObjectMeta
return cli.put(resourceAuthRequest, id, newReq) return cli.put(resourceAuthRequest, id, newReq)
} }
func (cli *client) GarbageCollect(now time.Time) (result storage.GCResult, err error) {
var authRequests AuthRequestList
if err := cli.list(resourceAuthRequest, &authRequests); err != nil {
return result, fmt.Errorf("failed to list auth requests: %v", err)
}
var delErr error
for _, authRequest := range authRequests.AuthRequests {
if now.After(authRequest.Expiry) {
if err := cli.delete(resourceAuthRequest, authRequest.ObjectMeta.Name); err != nil {
log.Printf("failed to delete auth request: %v", err)
delErr = fmt.Errorf("failed to delete auth request: %v", err)
}
result.AuthRequests++
}
}
if delErr != nil {
return result, delErr
}
var authCodes AuthCodeList
if err := cli.list(resourceAuthCode, &authCodes); err != nil {
return result, fmt.Errorf("failed to list auth codes: %v", err)
}
for _, authCode := range authCodes.AuthCodes {
if now.After(authCode.Expiry) {
if err := cli.delete(resourceAuthCode, authCode.ObjectMeta.Name); err != nil {
log.Printf("failed to delete auth code %v", err)
delErr = fmt.Errorf("failed to delete auth code: %v", err)
}
result.AuthCodes++
}
}
return result, delErr
}

View file

@ -74,7 +74,7 @@ func TestURLFor(t *testing.T) {
func TestStorage(t *testing.T) { func TestStorage(t *testing.T) {
client := loadClient(t) client := loadClient(t)
conformance.RunTestSuite(t, func() storage.Storage { conformance.RunTests(t, func() storage.Storage {
for _, resource := range []string{ for _, resource := range []string{
resourceAuthCode, resourceAuthCode,
resourceAuthRequest, resourceAuthRequest,

View file

@ -4,6 +4,7 @@ package memory
import ( import (
"strings" "strings"
"sync" "sync"
"time"
"github.com/coreos/dex/storage" "github.com/coreos/dex/storage"
) )
@ -51,6 +52,24 @@ func (s *memStorage) tx(f func()) {
func (s *memStorage) Close() error { return nil } func (s *memStorage) Close() error { return nil }
func (s *memStorage) GarbageCollect(now time.Time) (result storage.GCResult, err error) {
s.tx(func() {
for id, a := range s.authCodes {
if now.After(a.Expiry) {
delete(s.authCodes, id)
result.AuthCodes++
}
}
for id, a := range s.authReqs {
if now.After(a.Expiry) {
delete(s.authReqs, id)
result.AuthRequests++
}
}
})
return result, nil
}
func (s *memStorage) CreateClient(c storage.Client) (err error) { func (s *memStorage) CreateClient(c storage.Client) (err error) {
s.tx(func() { s.tx(func() {
if _, ok := s.clients[c.ID]; ok { if _, ok := s.clients[c.ID]; ok {
@ -240,29 +259,6 @@ func (s *memStorage) GetAuthCode(id string) (c storage.AuthCode, err error) {
return return
} }
func (s *memStorage) ClaimCode(id string) (err error) {
s.tx(func() {
if _, ok := s.authCodes[id]; !ok {
err = storage.ErrNotFound
return
}
delete(s.authCodes, id)
})
return
}
func (s *memStorage) ClaimRefresh(refreshToken string) (token storage.RefreshToken, err error) {
s.tx(func() {
var ok bool
if token, ok = s.refreshTokens[refreshToken]; !ok {
err = storage.ErrNotFound
return
}
delete(s.refreshTokens, refreshToken)
})
return
}
func (s *memStorage) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) (err error) { func (s *memStorage) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) (err error) {
s.tx(func() { s.tx(func() {
client, ok := s.clients[id] client, ok := s.clients[id]

View file

@ -7,5 +7,5 @@ import (
) )
func TestStorage(t *testing.T) { func TestStorage(t *testing.T) {
conformance.RunTestSuite(t, New) conformance.RunTests(t, New)
} }

View file

@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"net/url" "net/url"
"strconv" "strconv"
"time"
"github.com/coreos/dex/storage" "github.com/coreos/dex/storage"
) )
@ -22,7 +21,7 @@ func (s *SQLite3) Open() (storage.Storage, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return withGC(conn, time.Now), nil return conn, nil
} }
func (s *SQLite3) open() (*conn, error) { func (s *SQLite3) open() (*conn, error) {
@ -76,7 +75,7 @@ func (p *Postgres) Open() (storage.Storage, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return withGC(conn, time.Now), nil return conn, nil
} }
func (p *Postgres) open() (*conn, error) { func (p *Postgres) open() (*conn, error) {

View file

@ -54,7 +54,7 @@ func TestSQLite3(t *testing.T) {
} }
withTimeout(time.Second*10, func() { withTimeout(time.Second*10, func() {
conformance.RunTestSuite(t, newStorage) conformance.RunTests(t, newStorage)
}) })
} }
@ -72,19 +72,24 @@ func TestPostgres(t *testing.T) {
}, },
ConnectionTimeout: 5, ConnectionTimeout: 5,
} }
conn, err := p.open()
if err != nil { // t.Fatal has a bad habbit of not actually printing the error
t.Fatal(err) fatal := func(i interface{}) {
fmt.Fprintln(os.Stdout, i)
t.Fatal(i)
} }
defer conn.Close()
newStorage := func() storage.Storage { newStorage := func() storage.Storage {
conn, err := p.open()
if err != nil {
fatal(err)
}
if err := cleanDB(conn); err != nil { if err := cleanDB(conn); err != nil {
t.Fatal(err) fatal(err)
} }
return conn return conn
} }
withTimeout(time.Minute*1, func() { withTimeout(time.Minute*1, func() {
conformance.RunTestSuite(t, newStorage) conformance.RunTests(t, newStorage)
}) })
} }

View file

@ -7,6 +7,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"strings" "strings"
"time"
"github.com/coreos/dex/storage" "github.com/coreos/dex/storage"
) )
@ -83,6 +84,25 @@ type scanner interface {
Scan(dest ...interface{}) error Scan(dest ...interface{}) error
} }
func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error) {
r, err := c.Exec(`delete from auth_request where expiry < $1`, now)
if err != nil {
return result, fmt.Errorf("gc auth_request: %v", err)
}
if n, err := r.RowsAffected(); err == nil {
result.AuthRequests = n
}
r, err = c.Exec(`delete from auth_code where expiry < $1`, now)
if err != nil {
return result, fmt.Errorf("gc auth_code: %v", err)
}
if n, err := r.RowsAffected(); err == nil {
result.AuthCodes = n
}
return
}
func (c *conn) CreateAuthRequest(a storage.AuthRequest) error { func (c *conn) CreateAuthRequest(a storage.AuthRequest) error {
_, err := c.Exec(` _, err := c.Exec(`
insert into auth_request ( insert into auth_request (

View file

@ -1,53 +0,0 @@
package sql
import (
"context"
"fmt"
"log"
"time"
"github.com/coreos/dex/storage"
)
type gc struct {
now func() time.Time
conn *conn
}
func (gc gc) run() error {
for _, table := range []string{"auth_request", "auth_code"} {
_, err := gc.conn.Exec(`delete from `+table+` where expiry < $1`, gc.now())
if err != nil {
return fmt.Errorf("gc %s: %v", table, err)
}
// TODO(ericchiang): when we have levelled logging print how many rows were gc'd
}
return nil
}
type withCancel struct {
storage.Storage
cancel context.CancelFunc
}
func (w withCancel) Close() error {
w.cancel()
return w.Storage.Close()
}
func withGC(conn *conn, now func() time.Time) storage.Storage {
ctx, cancel := context.WithCancel(context.Background())
run := (gc{now, conn}).run
go func() {
for {
select {
case <-time.After(time.Second * 30):
if err := run(); err != nil {
log.Printf("gc failed: %v", err)
}
case <-ctx.Done():
}
}
}()
return withCancel{conn, cancel}
}

View file

@ -1,53 +0,0 @@
package sql
import (
"testing"
"time"
"github.com/coreos/dex/storage"
)
func TestGC(t *testing.T) {
// TODO(ericchiang): Add a GarbageCollect method to the storage interface so
// we can write conformance tests instead of directly testing each implementation.
s := &SQLite3{":memory:"}
conn, err := s.open()
if err != nil {
t.Fatal(err)
}
defer conn.Close()
clock := time.Now()
now := func() time.Time { return clock }
runGC := (gc{now, conn}).run
a := storage.AuthRequest{
ID: storage.NewID(),
Expiry: now().Add(time.Second),
}
if err := conn.CreateAuthRequest(a); err != nil {
t.Fatal(err)
}
if err := runGC(); err != nil {
t.Errorf("gc failed: %v", err)
}
if _, err := conn.GetAuthRequest(a.ID); err != nil {
t.Errorf("failed to get auth request after gc: %v", err)
}
clock = clock.Add(time.Minute)
if err := runGC(); err != nil {
t.Errorf("gc failed: %v", err)
}
if _, err := conn.GetAuthRequest(a.ID); err == nil {
t.Errorf("expected error after gc'ing auth request: %v", err)
} else if err != storage.ErrNotFound {
t.Errorf("expected error storage.NotFound got: %v", err)
}
}

View file

@ -38,6 +38,12 @@ func NewID() string {
return strings.TrimRight(encoding.EncodeToString(buff), "=") return strings.TrimRight(encoding.EncodeToString(buff), "=")
} }
// GCResult returns the number of objects deleted by garbage collection.
type GCResult struct {
AuthRequests int64
AuthCodes int64
}
// Storage is the storage interface used by the server. Implementations, at minimum // Storage is the storage interface used by the server. Implementations, at minimum
// require compare-and-swap atomic actions. // require compare-and-swap atomic actions.
// //
@ -80,8 +86,8 @@ type Storage interface {
UpdateAuthRequest(id string, updater func(a AuthRequest) (AuthRequest, error)) error UpdateAuthRequest(id string, updater func(a AuthRequest) (AuthRequest, error)) error
UpdatePassword(email string, updater func(p Password) (Password, error)) error UpdatePassword(email string, updater func(p Password) (Password, error)) error
// TODO(ericchiang): Add a GarbageCollect(now time.Time) method so conformance tests // GarbageCollect deletes all expired AuthCodes and AuthRequests.
// can test implementations. GarbageCollect(now time.Time) (GCResult, error)
} }
// Client represents an OAuth2 client. // Client represents an OAuth2 client.