forked from mystiq/dex
Merge pull request #602 from ericchiang/dev-add-garbage-collect-method-to-storage
dev branch: add garbage collect method to storage
This commit is contained in:
commit
5bec61d73f
20 changed files with 265 additions and 357 deletions
|
@ -9,6 +9,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
"golang.org/x/net/context"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
yaml "gopkg.in/yaml.v2"
|
yaml "gopkg.in/yaml.v2"
|
||||||
|
@ -124,7 +125,7 @@ func serve(cmd *cobra.Command, args []string) error {
|
||||||
EnablePasswordDB: c.EnablePasswordDB,
|
EnablePasswordDB: c.EnablePasswordDB,
|
||||||
}
|
}
|
||||||
|
|
||||||
serv, err := server.NewServer(serverConfig)
|
serv, err := server.NewServer(context.Background(), serverConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("initializing server: %v", err)
|
return fmt.Errorf("initializing server: %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -143,6 +143,7 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
|
||||||
s.renderError(w, http.StatusInternalServerError, err.Type, err.Description)
|
s.renderError(w, http.StatusInternalServerError, err.Type, err.Description)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
authReq.Expiry = s.now().Add(time.Minute * 30)
|
||||||
if err := s.storage.CreateAuthRequest(authReq); err != nil {
|
if err := s.storage.CreateAuthRequest(authReq); err != nil {
|
||||||
log.Printf("Failed to create authorization request: %v", err)
|
log.Printf("Failed to create authorization request: %v", err)
|
||||||
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
||||||
|
@ -342,7 +343,7 @@ func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authReq storage.AuthRequest) {
|
func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authReq storage.AuthRequest) {
|
||||||
if authReq.Expiry.After(s.now()) {
|
if s.now().After(authReq.Expiry) {
|
||||||
s.renderError(w, http.StatusBadRequest, errInvalidRequest, "Authorization request period has expired.")
|
s.renderError(w, http.StatusBadRequest, errInvalidRequest, "Authorization request period has expired.")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -373,7 +374,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
|
||||||
Nonce: authReq.Nonce,
|
Nonce: authReq.Nonce,
|
||||||
Scopes: authReq.Scopes,
|
Scopes: authReq.Scopes,
|
||||||
Claims: authReq.Claims,
|
Claims: authReq.Claims,
|
||||||
Expiry: s.now().Add(time.Minute * 5),
|
Expiry: s.now().Add(time.Minute * 30),
|
||||||
RedirectURI: authReq.RedirectURI,
|
RedirectURI: authReq.RedirectURI,
|
||||||
}
|
}
|
||||||
if err := s.storage.CreateAuthCode(code); err != nil {
|
if err := s.storage.CreateAuthCode(code); err != nil {
|
||||||
|
|
|
@ -4,10 +4,15 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestHandleHealth(t *testing.T) {
|
func TestHandleHealth(t *testing.T) {
|
||||||
httpServer, server := newTestServer(t, nil)
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
httpServer, server := newTestServer(t, ctx, nil)
|
||||||
defer httpServer.Close()
|
defer httpServer.Close()
|
||||||
|
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
|
|
|
@ -56,40 +56,34 @@ type keyRotater struct {
|
||||||
storage.Storage
|
storage.Storage
|
||||||
|
|
||||||
strategy rotationStrategy
|
strategy rotationStrategy
|
||||||
cancel context.CancelFunc
|
|
||||||
|
|
||||||
now func() time.Time
|
now func() time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func storageWithKeyRotation(s storage.Storage, strategy rotationStrategy, now func() time.Time) storage.Storage {
|
// startKeyRotation begins key rotation in a new goroutine, closing once the context is canceled.
|
||||||
if now == nil {
|
//
|
||||||
now = time.Now
|
// The method blocks until after the first attempt to rotate keys has completed. That way
|
||||||
}
|
// healthy storages will return from this call with valid keys.
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
func startKeyRotation(ctx context.Context, s storage.Storage, strategy rotationStrategy, now func() time.Time) {
|
||||||
rotater := keyRotater{s, strategy, cancel, now}
|
rotater := keyRotater{s, strategy, now}
|
||||||
|
|
||||||
// Try to rotate immediately so properly configured storages will return a
|
// Try to rotate immediately so properly configured storages will have keys.
|
||||||
// storage with keys.
|
|
||||||
if err := rotater.rotate(); err != nil {
|
if err := rotater.rotate(); err != nil {
|
||||||
log.Printf("failed to rotate keys: %v", err)
|
log.Printf("failed to rotate keys: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
case <-time.After(time.Second * 30):
|
case <-time.After(strategy.period):
|
||||||
if err := rotater.rotate(); err != nil {
|
if err := rotater.rotate(); err != nil {
|
||||||
log.Printf("failed to rotate keys: %v", err)
|
log.Printf("failed to rotate keys: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
|
||||||
return rotater
|
|
||||||
}
|
}
|
||||||
|
}()
|
||||||
func (k keyRotater) Close() error {
|
return
|
||||||
k.cancel()
|
|
||||||
return k.Storage.Close()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k keyRotater) rotate() error {
|
func (k keyRotater) rotate() error {
|
||||||
|
|
|
@ -11,6 +11,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
|
@ -48,6 +49,8 @@ type Config struct {
|
||||||
RotateKeysAfter time.Duration // Defaults to 6 hours.
|
RotateKeysAfter time.Duration // Defaults to 6 hours.
|
||||||
IDTokensValidFor time.Duration // Defaults to 24 hours
|
IDTokensValidFor time.Duration // Defaults to 24 hours
|
||||||
|
|
||||||
|
GCFrequency time.Duration // Defaults to 5 minutes
|
||||||
|
|
||||||
// If specified, the server will use this function for determining time.
|
// If specified, the server will use this function for determining time.
|
||||||
Now func() time.Time
|
Now func() time.Time
|
||||||
|
|
||||||
|
@ -87,14 +90,14 @@ type Server struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer constructs a server from the provided config.
|
// NewServer constructs a server from the provided config.
|
||||||
func NewServer(c Config) (*Server, error) {
|
func NewServer(ctx context.Context, c Config) (*Server, error) {
|
||||||
return newServer(c, defaultRotationStrategy(
|
return newServer(ctx, c, defaultRotationStrategy(
|
||||||
value(c.RotateKeysAfter, 6*time.Hour),
|
value(c.RotateKeysAfter, 6*time.Hour),
|
||||||
value(c.IDTokensValidFor, 24*time.Hour),
|
value(c.IDTokensValidFor, 24*time.Hour),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) {
|
func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) (*Server, error) {
|
||||||
issuerURL, err := url.Parse(c.Issuer)
|
issuerURL, err := url.Parse(c.Issuer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("server: can't parse issuer URL")
|
return nil, fmt.Errorf("server: can't parse issuer URL")
|
||||||
|
@ -140,12 +143,7 @@ func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) {
|
||||||
s := &Server{
|
s := &Server{
|
||||||
issuerURL: *issuerURL,
|
issuerURL: *issuerURL,
|
||||||
connectors: make(map[string]Connector),
|
connectors: make(map[string]Connector),
|
||||||
storage: newKeyCacher(
|
storage: newKeyCacher(c.Storage, now),
|
||||||
storageWithKeyRotation(
|
|
||||||
c.Storage, rotationStrategy, now,
|
|
||||||
),
|
|
||||||
now,
|
|
||||||
),
|
|
||||||
supportedResponseTypes: supported,
|
supportedResponseTypes: supported,
|
||||||
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
|
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
|
||||||
skipApproval: c.SkipApprovalScreen,
|
skipApproval: c.SkipApprovalScreen,
|
||||||
|
@ -179,6 +177,9 @@ func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) {
|
||||||
handleFunc("/healthz", s.handleHealth)
|
handleFunc("/healthz", s.handleHealth)
|
||||||
s.mux = r
|
s.mux = r
|
||||||
|
|
||||||
|
startKeyRotation(ctx, c.Storage, rotationStrategy, now)
|
||||||
|
startGarbageCollection(ctx, c.Storage, value(c.GCFrequency, 5*time.Minute), now)
|
||||||
|
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -262,3 +263,21 @@ func (k *keyCacher) GetKeys() (storage.Keys, error) {
|
||||||
}
|
}
|
||||||
return storageKeys, nil
|
return storageKeys, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func startGarbageCollection(ctx context.Context, s storage.Storage, frequency time.Duration, now func() time.Time) {
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-time.After(frequency):
|
||||||
|
if r, err := s.GarbageCollect(now()); err != nil {
|
||||||
|
log.Printf("garbage collection failed: %v", err)
|
||||||
|
} else {
|
||||||
|
log.Printf("garbage collection run, delete auth requests=%d, auth codes=%d", r.AuthRequests, r.AuthCodes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
@ -69,7 +69,7 @@ FDWV28nTP9sqbtsmU8Tem2jzMvZ7C/Q0AuDoKELFUpux8shm8wfIhyaPnXUGZoAZ
|
||||||
Np4vUwMSYV5mopESLWOg3loBxKyLGFtgGKVCjGiQvy6zISQ4fQo=
|
Np4vUwMSYV5mopESLWOg3loBxKyLGFtgGKVCjGiQvy6zISQ4fQo=
|
||||||
-----END RSA PRIVATE KEY-----`)
|
-----END RSA PRIVATE KEY-----`)
|
||||||
|
|
||||||
func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server, *Server) {
|
func newTestServer(t *testing.T, ctx context.Context, updateConfig func(c *Config)) (*httptest.Server, *Server) {
|
||||||
var server *Server
|
var server *Server
|
||||||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
server.ServeHTTP(w, r)
|
server.ServeHTTP(w, r)
|
||||||
|
@ -91,7 +91,7 @@ func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server
|
||||||
s.URL = config.Issuer
|
s.URL = config.Issuer
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
if server, err = newServer(config, staticRotationStrategy(testKey)); err != nil {
|
if server, err = newServer(ctx, config, staticRotationStrategy(testKey)); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
server.skipApproval = true // Don't prompt for approval, just immediately redirect with code.
|
server.skipApproval = true // Don't prompt for approval, just immediately redirect with code.
|
||||||
|
@ -99,14 +99,16 @@ func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewTestServer(t *testing.T) {
|
func TestNewTestServer(t *testing.T) {
|
||||||
newTestServer(t, nil)
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
newTestServer(t, ctx, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDiscovery(t *testing.T) {
|
func TestDiscovery(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
httpServer, _ := newTestServer(t, func(c *Config) {
|
httpServer, _ := newTestServer(t, ctx, func(c *Config) {
|
||||||
c.Issuer = c.Issuer + "/non-root-path"
|
c.Issuer = c.Issuer + "/non-root-path"
|
||||||
})
|
})
|
||||||
defer httpServer.Close()
|
defer httpServer.Close()
|
||||||
|
@ -255,7 +257,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
httpServer, s := newTestServer(t, func(c *Config) {
|
httpServer, s := newTestServer(t, ctx, func(c *Config) {
|
||||||
c.Issuer = c.Issuer + "/non-root-path"
|
c.Issuer = c.Issuer + "/non-root-path"
|
||||||
})
|
})
|
||||||
defer httpServer.Close()
|
defer httpServer.Close()
|
||||||
|
@ -368,7 +370,7 @@ func TestOAuth2ImplicitFlow(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
httpServer, s := newTestServer(t, func(c *Config) {
|
httpServer, s := newTestServer(t, ctx, func(c *Config) {
|
||||||
// Enable support for the implicit flow.
|
// Enable support for the implicit flow.
|
||||||
c.SupportedResponseTypes = []string{"code", "token"}
|
c.SupportedResponseTypes = []string{"code", "token"}
|
||||||
})
|
})
|
||||||
|
@ -498,7 +500,7 @@ func TestCrossClientScopes(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
httpServer, s := newTestServer(t, func(c *Config) {
|
httpServer, s := newTestServer(t, ctx, func(c *Config) {
|
||||||
c.Issuer = c.Issuer + "/non-root-path"
|
c.Issuer = c.Issuer + "/non-root-path"
|
||||||
})
|
})
|
||||||
defer httpServer.Close()
|
defer httpServer.Close()
|
||||||
|
|
|
@ -18,12 +18,10 @@ import (
|
||||||
// ensure that values being tested on never expire.
|
// ensure that values being tested on never expire.
|
||||||
var neverExpire = time.Now().UTC().Add(time.Hour * 24 * 365 * 100)
|
var neverExpire = time.Now().UTC().Add(time.Hour * 24 * 365 * 100)
|
||||||
|
|
||||||
// StorageFactory is a method for creating a new storage. The returned storage sould be initialized
|
// RunTests runs a set of conformance tests against a storage. newStorage should
|
||||||
// but shouldn't have any existing data in it.
|
// return an initialized but empty storage. The storage will be closed at the
|
||||||
type StorageFactory func() storage.Storage
|
// end of each test run.
|
||||||
|
func RunTests(t *testing.T, newStorage func() storage.Storage) {
|
||||||
// RunTestSuite runs a set of conformance tests against a storage.
|
|
||||||
func RunTestSuite(t *testing.T, sf StorageFactory) {
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
run func(t *testing.T, s storage.Storage)
|
run func(t *testing.T, s storage.Storage)
|
||||||
|
@ -33,10 +31,13 @@ func RunTestSuite(t *testing.T, sf StorageFactory) {
|
||||||
{"ClientCRUD", testClientCRUD},
|
{"ClientCRUD", testClientCRUD},
|
||||||
{"RefreshTokenCRUD", testRefreshTokenCRUD},
|
{"RefreshTokenCRUD", testRefreshTokenCRUD},
|
||||||
{"PasswordCRUD", testPasswordCRUD},
|
{"PasswordCRUD", testPasswordCRUD},
|
||||||
|
{"GarbageCollection", testGC},
|
||||||
}
|
}
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
test.run(t, sf())
|
s := newStorage()
|
||||||
|
test.run(t, s)
|
||||||
|
s.Close()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -276,3 +277,92 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) {
|
||||||
t.Errorf("after deleting password expected storage.ErrNotFound, got %v", err)
|
t.Errorf("after deleting password expected storage.ErrNotFound, got %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func testGC(t *testing.T, s storage.Storage) {
|
||||||
|
n := time.Now()
|
||||||
|
c := storage.AuthCode{
|
||||||
|
ID: storage.NewID(),
|
||||||
|
ClientID: "foobar",
|
||||||
|
RedirectURI: "https://localhost:80/callback",
|
||||||
|
Nonce: "foobar",
|
||||||
|
Scopes: []string{"openid", "email"},
|
||||||
|
Expiry: n.Add(time.Second),
|
||||||
|
ConnectorID: "ldap",
|
||||||
|
ConnectorData: []byte(`{"some":"data"}`),
|
||||||
|
Claims: storage.Claims{
|
||||||
|
UserID: "1",
|
||||||
|
Username: "jane",
|
||||||
|
Email: "jane.doe@example.com",
|
||||||
|
EmailVerified: true,
|
||||||
|
Groups: []string{"a", "b"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.CreateAuthCode(c); err != nil {
|
||||||
|
t.Fatalf("failed creating auth code: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := s.GarbageCollect(n); err != nil {
|
||||||
|
t.Errorf("garbage collection failed: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := s.GetAuthCode(c.ID); err != nil {
|
||||||
|
t.Errorf("expected to be able to get auth code after GC: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r, err := s.GarbageCollect(n.Add(time.Minute)); err != nil {
|
||||||
|
t.Errorf("garbage collection failed: %v", err)
|
||||||
|
} else if r.AuthCodes != 1 {
|
||||||
|
t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthCodes)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := s.GetAuthCode(c.ID); err == nil {
|
||||||
|
t.Errorf("expected auth code to be GC'd")
|
||||||
|
} else if err != storage.ErrNotFound {
|
||||||
|
t.Errorf("expected storage.ErrNotFound, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
a := storage.AuthRequest{
|
||||||
|
ID: storage.NewID(),
|
||||||
|
ClientID: "foobar",
|
||||||
|
ResponseTypes: []string{"code"},
|
||||||
|
Scopes: []string{"openid", "email"},
|
||||||
|
RedirectURI: "https://localhost:80/callback",
|
||||||
|
Nonce: "foo",
|
||||||
|
State: "bar",
|
||||||
|
ForceApprovalPrompt: true,
|
||||||
|
LoggedIn: true,
|
||||||
|
Expiry: n,
|
||||||
|
ConnectorID: "ldap",
|
||||||
|
ConnectorData: []byte(`{"some":"data"}`),
|
||||||
|
Claims: storage.Claims{
|
||||||
|
UserID: "1",
|
||||||
|
Username: "jane",
|
||||||
|
Email: "jane.doe@example.com",
|
||||||
|
EmailVerified: true,
|
||||||
|
Groups: []string{"a", "b"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.CreateAuthRequest(a); err != nil {
|
||||||
|
t.Fatalf("failed creating auth request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := s.GarbageCollect(n); err != nil {
|
||||||
|
t.Errorf("garbage collection failed: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := s.GetAuthRequest(a.ID); err != nil {
|
||||||
|
t.Errorf("expected to be able to get auth code after GC: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r, err := s.GarbageCollect(n.Add(time.Minute)); err != nil {
|
||||||
|
t.Errorf("garbage collection failed: %v", err)
|
||||||
|
} else if r.AuthRequests != 1 {
|
||||||
|
t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthRequests)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := s.GetAuthRequest(a.ID); err == nil {
|
||||||
|
t.Errorf("expected auth code to be GC'd")
|
||||||
|
} else if err != storage.ErrNotFound {
|
||||||
|
t.Errorf("expected storage.ErrNotFound, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -20,7 +20,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gtank/cryptopasta"
|
"github.com/gtank/cryptopasta"
|
||||||
"golang.org/x/net/context"
|
|
||||||
yaml "gopkg.in/yaml.v2"
|
yaml "gopkg.in/yaml.v2"
|
||||||
|
|
||||||
"github.com/coreos/dex/storage"
|
"github.com/coreos/dex/storage"
|
||||||
|
@ -35,9 +34,6 @@ type client struct {
|
||||||
|
|
||||||
now func() time.Time
|
now func() time.Time
|
||||||
|
|
||||||
// If not nil, the cancel function for stopping garbage colletion.
|
|
||||||
cancel context.CancelFunc
|
|
||||||
|
|
||||||
// BUG: currently each third party API group can only have one resource in it,
|
// BUG: currently each third party API group can only have one resource in it,
|
||||||
// so for each resource this storage uses, it need a unique API group.
|
// so for each resource this storage uses, it need a unique API group.
|
||||||
//
|
//
|
||||||
|
|
|
@ -1,58 +0,0 @@
|
||||||
package kubernetes
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/net/context"
|
|
||||||
)
|
|
||||||
|
|
||||||
// gc begins the gc process for Kubernetes.
|
|
||||||
func (cli *client) gc(ctx context.Context, every time.Duration) {
|
|
||||||
handleErr := func(err error) { log.Println(err.Error()) }
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
case <-time.After(every):
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(ericchiang): On failures, run garbage collection more often.
|
|
||||||
log.Println("kubernetes: running garbage collection")
|
|
||||||
cli.gcAuthRequests(handleErr)
|
|
||||||
cli.gcAuthCodes(handleErr)
|
|
||||||
log.Printf("kubernetes: garbage collection finished, next run at %s", cli.now().Add(every))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cli *client) gcAuthRequests(handleErr func(error)) {
|
|
||||||
var authRequests AuthRequestList
|
|
||||||
if err := cli.list(resourceAuthRequest, &authRequests); err != nil {
|
|
||||||
handleErr(fmt.Errorf("failed to list auth requests: %v", err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for _, authRequest := range authRequests.AuthRequests {
|
|
||||||
if cli.now().After(authRequest.Expiry) {
|
|
||||||
if err := cli.delete(resourceAuthRequest, authRequest.ObjectMeta.Name); err != nil {
|
|
||||||
handleErr(fmt.Errorf("failed to detele auth request: %v", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cli *client) gcAuthCodes(handleErr func(error)) {
|
|
||||||
var authCodes AuthCodeList
|
|
||||||
if err := cli.list(resourceAuthCode, &authCodes); err != nil {
|
|
||||||
handleErr(fmt.Errorf("failed to list auth codes: %v", err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for _, authCode := range authCodes.AuthCodes {
|
|
||||||
if cli.now().After(authCode.Expiry) {
|
|
||||||
if err := cli.delete(resourceAuthCode, authCode.ObjectMeta.Name); err != nil {
|
|
||||||
handleErr(fmt.Errorf("failed to delete auth code: %v", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,88 +0,0 @@
|
||||||
package kubernetes
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/coreos/dex/storage"
|
|
||||||
)
|
|
||||||
|
|
||||||
func muster(t *testing.T) func(err error) {
|
|
||||||
return func(err error) {
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGCAuthRequests(t *testing.T) {
|
|
||||||
cli := loadClient(t)
|
|
||||||
must := muster(t)
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
cli.now = func() time.Time { return now }
|
|
||||||
|
|
||||||
expiredID := storage.NewID()
|
|
||||||
goodID := storage.NewID()
|
|
||||||
|
|
||||||
must(cli.CreateAuthRequest(storage.AuthRequest{
|
|
||||||
ID: expiredID,
|
|
||||||
Expiry: now.Add(-time.Second),
|
|
||||||
}))
|
|
||||||
|
|
||||||
must(cli.CreateAuthRequest(storage.AuthRequest{
|
|
||||||
ID: goodID,
|
|
||||||
Expiry: now.Add(time.Second),
|
|
||||||
}))
|
|
||||||
|
|
||||||
handleErr := func(err error) { t.Error(err.Error()) }
|
|
||||||
cli.gcAuthRequests(handleErr)
|
|
||||||
|
|
||||||
if _, err := cli.GetAuthRequest(goodID); err != nil {
|
|
||||||
t.Errorf("failed to get good auth ID: %v", err)
|
|
||||||
}
|
|
||||||
_, err := cli.GetAuthRequest(expiredID)
|
|
||||||
switch {
|
|
||||||
case err == nil:
|
|
||||||
t.Errorf("gc did not remove expired auth request")
|
|
||||||
case err == storage.ErrNotFound:
|
|
||||||
default:
|
|
||||||
t.Errorf("expected storage.ErrNotFound, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGCAuthCodes(t *testing.T) {
|
|
||||||
cli := loadClient(t)
|
|
||||||
must := muster(t)
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
cli.now = func() time.Time { return now }
|
|
||||||
|
|
||||||
expiredID := storage.NewID()
|
|
||||||
goodID := storage.NewID()
|
|
||||||
|
|
||||||
must(cli.CreateAuthCode(storage.AuthCode{
|
|
||||||
ID: expiredID,
|
|
||||||
Expiry: now.Add(-time.Second),
|
|
||||||
}))
|
|
||||||
|
|
||||||
must(cli.CreateAuthCode(storage.AuthCode{
|
|
||||||
ID: goodID,
|
|
||||||
Expiry: now.Add(time.Second),
|
|
||||||
}))
|
|
||||||
|
|
||||||
handleErr := func(err error) { t.Error(err.Error()) }
|
|
||||||
cli.gcAuthCodes(handleErr)
|
|
||||||
|
|
||||||
if _, err := cli.GetAuthCode(goodID); err != nil {
|
|
||||||
t.Errorf("failed to get good auth ID: %v", err)
|
|
||||||
}
|
|
||||||
_, err := cli.GetAuthCode(expiredID)
|
|
||||||
switch {
|
|
||||||
case err == nil:
|
|
||||||
t.Errorf("gc did not remove expired auth request")
|
|
||||||
case err == storage.ErrNotFound:
|
|
||||||
default:
|
|
||||||
t.Errorf("expected storage.ErrNotFound, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -3,12 +3,12 @@ package kubernetes
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
homedir "github.com/mitchellh/go-homedir"
|
homedir "github.com/mitchellh/go-homedir"
|
||||||
"golang.org/x/net/context"
|
|
||||||
|
|
||||||
"github.com/coreos/dex/storage"
|
"github.com/coreos/dex/storage"
|
||||||
"github.com/coreos/dex/storage/kubernetes/k8sapi"
|
"github.com/coreos/dex/storage/kubernetes/k8sapi"
|
||||||
|
@ -46,14 +46,6 @@ func (c *Config) Open() (storage.Storage, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// start up garbage collection
|
|
||||||
gcFrequency := c.GCFrequency
|
|
||||||
if gcFrequency == 0 {
|
|
||||||
gcFrequency = 600
|
|
||||||
}
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
cli.cancel = cancel
|
|
||||||
go cli.gc(ctx, time.Duration(gcFrequency)*time.Second)
|
|
||||||
return cli, nil
|
return cli, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -93,9 +85,6 @@ func (c *Config) open() (*client, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cli *client) Close() error {
|
func (cli *client) Close() error {
|
||||||
if cli.cancel != nil {
|
|
||||||
cli.cancel()
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -291,3 +280,40 @@ func (cli *client) UpdateAuthRequest(id string, updater func(a storage.AuthReque
|
||||||
newReq.ObjectMeta = req.ObjectMeta
|
newReq.ObjectMeta = req.ObjectMeta
|
||||||
return cli.put(resourceAuthRequest, id, newReq)
|
return cli.put(resourceAuthRequest, id, newReq)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (cli *client) GarbageCollect(now time.Time) (result storage.GCResult, err error) {
|
||||||
|
var authRequests AuthRequestList
|
||||||
|
if err := cli.list(resourceAuthRequest, &authRequests); err != nil {
|
||||||
|
return result, fmt.Errorf("failed to list auth requests: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var delErr error
|
||||||
|
for _, authRequest := range authRequests.AuthRequests {
|
||||||
|
if now.After(authRequest.Expiry) {
|
||||||
|
if err := cli.delete(resourceAuthRequest, authRequest.ObjectMeta.Name); err != nil {
|
||||||
|
log.Printf("failed to delete auth request: %v", err)
|
||||||
|
delErr = fmt.Errorf("failed to delete auth request: %v", err)
|
||||||
|
}
|
||||||
|
result.AuthRequests++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if delErr != nil {
|
||||||
|
return result, delErr
|
||||||
|
}
|
||||||
|
|
||||||
|
var authCodes AuthCodeList
|
||||||
|
if err := cli.list(resourceAuthCode, &authCodes); err != nil {
|
||||||
|
return result, fmt.Errorf("failed to list auth codes: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, authCode := range authCodes.AuthCodes {
|
||||||
|
if now.After(authCode.Expiry) {
|
||||||
|
if err := cli.delete(resourceAuthCode, authCode.ObjectMeta.Name); err != nil {
|
||||||
|
log.Printf("failed to delete auth code %v", err)
|
||||||
|
delErr = fmt.Errorf("failed to delete auth code: %v", err)
|
||||||
|
}
|
||||||
|
result.AuthCodes++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, delErr
|
||||||
|
}
|
||||||
|
|
|
@ -74,7 +74,7 @@ func TestURLFor(t *testing.T) {
|
||||||
|
|
||||||
func TestStorage(t *testing.T) {
|
func TestStorage(t *testing.T) {
|
||||||
client := loadClient(t)
|
client := loadClient(t)
|
||||||
conformance.RunTestSuite(t, func() storage.Storage {
|
conformance.RunTests(t, func() storage.Storage {
|
||||||
for _, resource := range []string{
|
for _, resource := range []string{
|
||||||
resourceAuthCode,
|
resourceAuthCode,
|
||||||
resourceAuthRequest,
|
resourceAuthRequest,
|
||||||
|
|
|
@ -4,6 +4,7 @@ package memory
|
||||||
import (
|
import (
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/dex/storage"
|
"github.com/coreos/dex/storage"
|
||||||
)
|
)
|
||||||
|
@ -51,6 +52,24 @@ func (s *memStorage) tx(f func()) {
|
||||||
|
|
||||||
func (s *memStorage) Close() error { return nil }
|
func (s *memStorage) Close() error { return nil }
|
||||||
|
|
||||||
|
func (s *memStorage) GarbageCollect(now time.Time) (result storage.GCResult, err error) {
|
||||||
|
s.tx(func() {
|
||||||
|
for id, a := range s.authCodes {
|
||||||
|
if now.After(a.Expiry) {
|
||||||
|
delete(s.authCodes, id)
|
||||||
|
result.AuthCodes++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for id, a := range s.authReqs {
|
||||||
|
if now.After(a.Expiry) {
|
||||||
|
delete(s.authReqs, id)
|
||||||
|
result.AuthRequests++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *memStorage) CreateClient(c storage.Client) (err error) {
|
func (s *memStorage) CreateClient(c storage.Client) (err error) {
|
||||||
s.tx(func() {
|
s.tx(func() {
|
||||||
if _, ok := s.clients[c.ID]; ok {
|
if _, ok := s.clients[c.ID]; ok {
|
||||||
|
@ -240,29 +259,6 @@ func (s *memStorage) GetAuthCode(id string) (c storage.AuthCode, err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *memStorage) ClaimCode(id string) (err error) {
|
|
||||||
s.tx(func() {
|
|
||||||
if _, ok := s.authCodes[id]; !ok {
|
|
||||||
err = storage.ErrNotFound
|
|
||||||
return
|
|
||||||
}
|
|
||||||
delete(s.authCodes, id)
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *memStorage) ClaimRefresh(refreshToken string) (token storage.RefreshToken, err error) {
|
|
||||||
s.tx(func() {
|
|
||||||
var ok bool
|
|
||||||
if token, ok = s.refreshTokens[refreshToken]; !ok {
|
|
||||||
err = storage.ErrNotFound
|
|
||||||
return
|
|
||||||
}
|
|
||||||
delete(s.refreshTokens, refreshToken)
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *memStorage) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) (err error) {
|
func (s *memStorage) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) (err error) {
|
||||||
s.tx(func() {
|
s.tx(func() {
|
||||||
client, ok := s.clients[id]
|
client, ok := s.clients[id]
|
||||||
|
|
|
@ -7,5 +7,5 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestStorage(t *testing.T) {
|
func TestStorage(t *testing.T) {
|
||||||
conformance.RunTestSuite(t, New)
|
conformance.RunTests(t, New)
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/coreos/dex/storage"
|
"github.com/coreos/dex/storage"
|
||||||
)
|
)
|
||||||
|
@ -22,7 +21,7 @@ func (s *SQLite3) Open() (storage.Storage, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return withGC(conn, time.Now), nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLite3) open() (*conn, error) {
|
func (s *SQLite3) open() (*conn, error) {
|
||||||
|
@ -76,7 +75,7 @@ func (p *Postgres) Open() (storage.Storage, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return withGC(conn, time.Now), nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Postgres) open() (*conn, error) {
|
func (p *Postgres) open() (*conn, error) {
|
||||||
|
|
|
@ -54,7 +54,7 @@ func TestSQLite3(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
withTimeout(time.Second*10, func() {
|
withTimeout(time.Second*10, func() {
|
||||||
conformance.RunTestSuite(t, newStorage)
|
conformance.RunTests(t, newStorage)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -72,19 +72,24 @@ func TestPostgres(t *testing.T) {
|
||||||
},
|
},
|
||||||
ConnectionTimeout: 5,
|
ConnectionTimeout: 5,
|
||||||
}
|
}
|
||||||
conn, err := p.open()
|
|
||||||
if err != nil {
|
// t.Fatal has a bad habbit of not actually printing the error
|
||||||
t.Fatal(err)
|
fatal := func(i interface{}) {
|
||||||
|
fmt.Fprintln(os.Stdout, i)
|
||||||
|
t.Fatal(i)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
newStorage := func() storage.Storage {
|
newStorage := func() storage.Storage {
|
||||||
|
conn, err := p.open()
|
||||||
|
if err != nil {
|
||||||
|
fatal(err)
|
||||||
|
}
|
||||||
if err := cleanDB(conn); err != nil {
|
if err := cleanDB(conn); err != nil {
|
||||||
t.Fatal(err)
|
fatal(err)
|
||||||
}
|
}
|
||||||
return conn
|
return conn
|
||||||
}
|
}
|
||||||
withTimeout(time.Minute*1, func() {
|
withTimeout(time.Minute*1, func() {
|
||||||
conformance.RunTestSuite(t, newStorage)
|
conformance.RunTests(t, newStorage)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/dex/storage"
|
"github.com/coreos/dex/storage"
|
||||||
)
|
)
|
||||||
|
@ -83,6 +84,25 @@ type scanner interface {
|
||||||
Scan(dest ...interface{}) error
|
Scan(dest ...interface{}) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error) {
|
||||||
|
r, err := c.Exec(`delete from auth_request where expiry < $1`, now)
|
||||||
|
if err != nil {
|
||||||
|
return result, fmt.Errorf("gc auth_request: %v", err)
|
||||||
|
}
|
||||||
|
if n, err := r.RowsAffected(); err == nil {
|
||||||
|
result.AuthRequests = n
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err = c.Exec(`delete from auth_code where expiry < $1`, now)
|
||||||
|
if err != nil {
|
||||||
|
return result, fmt.Errorf("gc auth_code: %v", err)
|
||||||
|
}
|
||||||
|
if n, err := r.RowsAffected(); err == nil {
|
||||||
|
result.AuthCodes = n
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func (c *conn) CreateAuthRequest(a storage.AuthRequest) error {
|
func (c *conn) CreateAuthRequest(a storage.AuthRequest) error {
|
||||||
_, err := c.Exec(`
|
_, err := c.Exec(`
|
||||||
insert into auth_request (
|
insert into auth_request (
|
||||||
|
|
|
@ -1,53 +0,0 @@
|
||||||
package sql
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/coreos/dex/storage"
|
|
||||||
)
|
|
||||||
|
|
||||||
type gc struct {
|
|
||||||
now func() time.Time
|
|
||||||
conn *conn
|
|
||||||
}
|
|
||||||
|
|
||||||
func (gc gc) run() error {
|
|
||||||
for _, table := range []string{"auth_request", "auth_code"} {
|
|
||||||
_, err := gc.conn.Exec(`delete from `+table+` where expiry < $1`, gc.now())
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("gc %s: %v", table, err)
|
|
||||||
}
|
|
||||||
// TODO(ericchiang): when we have levelled logging print how many rows were gc'd
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type withCancel struct {
|
|
||||||
storage.Storage
|
|
||||||
cancel context.CancelFunc
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w withCancel) Close() error {
|
|
||||||
w.cancel()
|
|
||||||
return w.Storage.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func withGC(conn *conn, now func() time.Time) storage.Storage {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
run := (gc{now, conn}).run
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-time.After(time.Second * 30):
|
|
||||||
if err := run(); err != nil {
|
|
||||||
log.Printf("gc failed: %v", err)
|
|
||||||
}
|
|
||||||
case <-ctx.Done():
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return withCancel{conn, cancel}
|
|
||||||
}
|
|
|
@ -1,53 +0,0 @@
|
||||||
package sql
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/coreos/dex/storage"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestGC(t *testing.T) {
|
|
||||||
// TODO(ericchiang): Add a GarbageCollect method to the storage interface so
|
|
||||||
// we can write conformance tests instead of directly testing each implementation.
|
|
||||||
s := &SQLite3{":memory:"}
|
|
||||||
conn, err := s.open()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
clock := time.Now()
|
|
||||||
now := func() time.Time { return clock }
|
|
||||||
|
|
||||||
runGC := (gc{now, conn}).run
|
|
||||||
|
|
||||||
a := storage.AuthRequest{
|
|
||||||
ID: storage.NewID(),
|
|
||||||
Expiry: now().Add(time.Second),
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := conn.CreateAuthRequest(a); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := runGC(); err != nil {
|
|
||||||
t.Errorf("gc failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := conn.GetAuthRequest(a.ID); err != nil {
|
|
||||||
t.Errorf("failed to get auth request after gc: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
clock = clock.Add(time.Minute)
|
|
||||||
|
|
||||||
if err := runGC(); err != nil {
|
|
||||||
t.Errorf("gc failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := conn.GetAuthRequest(a.ID); err == nil {
|
|
||||||
t.Errorf("expected error after gc'ing auth request: %v", err)
|
|
||||||
} else if err != storage.ErrNotFound {
|
|
||||||
t.Errorf("expected error storage.NotFound got: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -38,6 +38,12 @@ func NewID() string {
|
||||||
return strings.TrimRight(encoding.EncodeToString(buff), "=")
|
return strings.TrimRight(encoding.EncodeToString(buff), "=")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GCResult returns the number of objects deleted by garbage collection.
|
||||||
|
type GCResult struct {
|
||||||
|
AuthRequests int64
|
||||||
|
AuthCodes int64
|
||||||
|
}
|
||||||
|
|
||||||
// Storage is the storage interface used by the server. Implementations, at minimum
|
// Storage is the storage interface used by the server. Implementations, at minimum
|
||||||
// require compare-and-swap atomic actions.
|
// require compare-and-swap atomic actions.
|
||||||
//
|
//
|
||||||
|
@ -80,8 +86,8 @@ type Storage interface {
|
||||||
UpdateAuthRequest(id string, updater func(a AuthRequest) (AuthRequest, error)) error
|
UpdateAuthRequest(id string, updater func(a AuthRequest) (AuthRequest, error)) error
|
||||||
UpdatePassword(email string, updater func(p Password) (Password, error)) error
|
UpdatePassword(email string, updater func(p Password) (Password, error)) error
|
||||||
|
|
||||||
// TODO(ericchiang): Add a GarbageCollect(now time.Time) method so conformance tests
|
// GarbageCollect deletes all expired AuthCodes and AuthRequests.
|
||||||
// can test implementations.
|
GarbageCollect(now time.Time) (GCResult, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client represents an OAuth2 client.
|
// Client represents an OAuth2 client.
|
||||||
|
|
Loading…
Reference in a new issue