forked from mystiq/dex
Merge pull request #602 from ericchiang/dev-add-garbage-collect-method-to-storage
dev branch: add garbage collect method to storage
This commit is contained in:
commit
5bec61d73f
20 changed files with 265 additions and 357 deletions
|
@ -9,6 +9,7 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/net/context"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
yaml "gopkg.in/yaml.v2"
|
||||
|
@ -124,7 +125,7 @@ func serve(cmd *cobra.Command, args []string) error {
|
|||
EnablePasswordDB: c.EnablePasswordDB,
|
||||
}
|
||||
|
||||
serv, err := server.NewServer(serverConfig)
|
||||
serv, err := server.NewServer(context.Background(), serverConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("initializing server: %v", err)
|
||||
}
|
||||
|
|
|
@ -143,6 +143,7 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
|
|||
s.renderError(w, http.StatusInternalServerError, err.Type, err.Description)
|
||||
return
|
||||
}
|
||||
authReq.Expiry = s.now().Add(time.Minute * 30)
|
||||
if err := s.storage.CreateAuthRequest(authReq); err != nil {
|
||||
log.Printf("Failed to create authorization request: %v", err)
|
||||
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
||||
|
@ -342,7 +343,7 @@ func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authReq storage.AuthRequest) {
|
||||
if authReq.Expiry.After(s.now()) {
|
||||
if s.now().After(authReq.Expiry) {
|
||||
s.renderError(w, http.StatusBadRequest, errInvalidRequest, "Authorization request period has expired.")
|
||||
return
|
||||
}
|
||||
|
@ -373,7 +374,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
|
|||
Nonce: authReq.Nonce,
|
||||
Scopes: authReq.Scopes,
|
||||
Claims: authReq.Claims,
|
||||
Expiry: s.now().Add(time.Minute * 5),
|
||||
Expiry: s.now().Add(time.Minute * 30),
|
||||
RedirectURI: authReq.RedirectURI,
|
||||
}
|
||||
if err := s.storage.CreateAuthCode(code); err != nil {
|
||||
|
|
|
@ -4,10 +4,15 @@ import (
|
|||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
func TestHandleHealth(t *testing.T) {
|
||||
httpServer, server := newTestServer(t, nil)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
httpServer, server := newTestServer(t, ctx, nil)
|
||||
defer httpServer.Close()
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
|
|
|
@ -56,40 +56,34 @@ type keyRotater struct {
|
|||
storage.Storage
|
||||
|
||||
strategy rotationStrategy
|
||||
cancel context.CancelFunc
|
||||
|
||||
now func() time.Time
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
func storageWithKeyRotation(s storage.Storage, strategy rotationStrategy, now func() time.Time) storage.Storage {
|
||||
if now == nil {
|
||||
now = time.Now
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
rotater := keyRotater{s, strategy, cancel, now}
|
||||
// startKeyRotation begins key rotation in a new goroutine, closing once the context is canceled.
|
||||
//
|
||||
// The method blocks until after the first attempt to rotate keys has completed. That way
|
||||
// healthy storages will return from this call with valid keys.
|
||||
func startKeyRotation(ctx context.Context, s storage.Storage, strategy rotationStrategy, now func() time.Time) {
|
||||
rotater := keyRotater{s, strategy, now}
|
||||
|
||||
// Try to rotate immediately so properly configured storages will return a
|
||||
// storage with keys.
|
||||
// Try to rotate immediately so properly configured storages will have keys.
|
||||
if err := rotater.rotate(); err != nil {
|
||||
log.Printf("failed to rotate keys: %v", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(time.Second * 30):
|
||||
if err := rotater.rotate(); err != nil {
|
||||
log.Printf("failed to rotate keys: %v", err)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(strategy.period):
|
||||
if err := rotater.rotate(); err != nil {
|
||||
log.Printf("failed to rotate keys: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
return rotater
|
||||
}
|
||||
|
||||
func (k keyRotater) Close() error {
|
||||
k.cancel()
|
||||
return k.Storage.Close()
|
||||
return
|
||||
}
|
||||
|
||||
func (k keyRotater) rotate() error {
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"time"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"golang.org/x/net/context"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
|
@ -48,6 +49,8 @@ type Config struct {
|
|||
RotateKeysAfter time.Duration // Defaults to 6 hours.
|
||||
IDTokensValidFor time.Duration // Defaults to 24 hours
|
||||
|
||||
GCFrequency time.Duration // Defaults to 5 minutes
|
||||
|
||||
// If specified, the server will use this function for determining time.
|
||||
Now func() time.Time
|
||||
|
||||
|
@ -87,14 +90,14 @@ type Server struct {
|
|||
}
|
||||
|
||||
// NewServer constructs a server from the provided config.
|
||||
func NewServer(c Config) (*Server, error) {
|
||||
return newServer(c, defaultRotationStrategy(
|
||||
func NewServer(ctx context.Context, c Config) (*Server, error) {
|
||||
return newServer(ctx, c, defaultRotationStrategy(
|
||||
value(c.RotateKeysAfter, 6*time.Hour),
|
||||
value(c.IDTokensValidFor, 24*time.Hour),
|
||||
))
|
||||
}
|
||||
|
||||
func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) {
|
||||
func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) (*Server, error) {
|
||||
issuerURL, err := url.Parse(c.Issuer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("server: can't parse issuer URL")
|
||||
|
@ -138,14 +141,9 @@ func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) {
|
|||
}
|
||||
|
||||
s := &Server{
|
||||
issuerURL: *issuerURL,
|
||||
connectors: make(map[string]Connector),
|
||||
storage: newKeyCacher(
|
||||
storageWithKeyRotation(
|
||||
c.Storage, rotationStrategy, now,
|
||||
),
|
||||
now,
|
||||
),
|
||||
issuerURL: *issuerURL,
|
||||
connectors: make(map[string]Connector),
|
||||
storage: newKeyCacher(c.Storage, now),
|
||||
supportedResponseTypes: supported,
|
||||
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
|
||||
skipApproval: c.SkipApprovalScreen,
|
||||
|
@ -179,6 +177,9 @@ func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) {
|
|||
handleFunc("/healthz", s.handleHealth)
|
||||
s.mux = r
|
||||
|
||||
startKeyRotation(ctx, c.Storage, rotationStrategy, now)
|
||||
startGarbageCollection(ctx, c.Storage, value(c.GCFrequency, 5*time.Minute), now)
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
|
@ -262,3 +263,21 @@ func (k *keyCacher) GetKeys() (storage.Keys, error) {
|
|||
}
|
||||
return storageKeys, nil
|
||||
}
|
||||
|
||||
func startGarbageCollection(ctx context.Context, s storage.Storage, frequency time.Duration, now func() time.Time) {
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(frequency):
|
||||
if r, err := s.GarbageCollect(now()); err != nil {
|
||||
log.Printf("garbage collection failed: %v", err)
|
||||
} else {
|
||||
log.Printf("garbage collection run, delete auth requests=%d, auth codes=%d", r.AuthRequests, r.AuthCodes)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
return
|
||||
}
|
||||
|
|
|
@ -69,7 +69,7 @@ FDWV28nTP9sqbtsmU8Tem2jzMvZ7C/Q0AuDoKELFUpux8shm8wfIhyaPnXUGZoAZ
|
|||
Np4vUwMSYV5mopESLWOg3loBxKyLGFtgGKVCjGiQvy6zISQ4fQo=
|
||||
-----END RSA PRIVATE KEY-----`)
|
||||
|
||||
func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server, *Server) {
|
||||
func newTestServer(t *testing.T, ctx context.Context, updateConfig func(c *Config)) (*httptest.Server, *Server) {
|
||||
var server *Server
|
||||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
server.ServeHTTP(w, r)
|
||||
|
@ -91,7 +91,7 @@ func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server
|
|||
s.URL = config.Issuer
|
||||
|
||||
var err error
|
||||
if server, err = newServer(config, staticRotationStrategy(testKey)); err != nil {
|
||||
if server, err = newServer(ctx, config, staticRotationStrategy(testKey)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
server.skipApproval = true // Don't prompt for approval, just immediately redirect with code.
|
||||
|
@ -99,14 +99,16 @@ func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server
|
|||
}
|
||||
|
||||
func TestNewTestServer(t *testing.T) {
|
||||
newTestServer(t, nil)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
newTestServer(t, ctx, nil)
|
||||
}
|
||||
|
||||
func TestDiscovery(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
httpServer, _ := newTestServer(t, func(c *Config) {
|
||||
httpServer, _ := newTestServer(t, ctx, func(c *Config) {
|
||||
c.Issuer = c.Issuer + "/non-root-path"
|
||||
})
|
||||
defer httpServer.Close()
|
||||
|
@ -255,7 +257,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
|
|||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
httpServer, s := newTestServer(t, func(c *Config) {
|
||||
httpServer, s := newTestServer(t, ctx, func(c *Config) {
|
||||
c.Issuer = c.Issuer + "/non-root-path"
|
||||
})
|
||||
defer httpServer.Close()
|
||||
|
@ -368,7 +370,7 @@ func TestOAuth2ImplicitFlow(t *testing.T) {
|
|||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
httpServer, s := newTestServer(t, func(c *Config) {
|
||||
httpServer, s := newTestServer(t, ctx, func(c *Config) {
|
||||
// Enable support for the implicit flow.
|
||||
c.SupportedResponseTypes = []string{"code", "token"}
|
||||
})
|
||||
|
@ -498,7 +500,7 @@ func TestCrossClientScopes(t *testing.T) {
|
|||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
httpServer, s := newTestServer(t, func(c *Config) {
|
||||
httpServer, s := newTestServer(t, ctx, func(c *Config) {
|
||||
c.Issuer = c.Issuer + "/non-root-path"
|
||||
})
|
||||
defer httpServer.Close()
|
||||
|
|
|
@ -18,12 +18,10 @@ import (
|
|||
// ensure that values being tested on never expire.
|
||||
var neverExpire = time.Now().UTC().Add(time.Hour * 24 * 365 * 100)
|
||||
|
||||
// StorageFactory is a method for creating a new storage. The returned storage sould be initialized
|
||||
// but shouldn't have any existing data in it.
|
||||
type StorageFactory func() storage.Storage
|
||||
|
||||
// RunTestSuite runs a set of conformance tests against a storage.
|
||||
func RunTestSuite(t *testing.T, sf StorageFactory) {
|
||||
// RunTests runs a set of conformance tests against a storage. newStorage should
|
||||
// return an initialized but empty storage. The storage will be closed at the
|
||||
// end of each test run.
|
||||
func RunTests(t *testing.T, newStorage func() storage.Storage) {
|
||||
tests := []struct {
|
||||
name string
|
||||
run func(t *testing.T, s storage.Storage)
|
||||
|
@ -33,10 +31,13 @@ func RunTestSuite(t *testing.T, sf StorageFactory) {
|
|||
{"ClientCRUD", testClientCRUD},
|
||||
{"RefreshTokenCRUD", testRefreshTokenCRUD},
|
||||
{"PasswordCRUD", testPasswordCRUD},
|
||||
{"GarbageCollection", testGC},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
test.run(t, sf())
|
||||
s := newStorage()
|
||||
test.run(t, s)
|
||||
s.Close()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -276,3 +277,92 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) {
|
|||
t.Errorf("after deleting password expected storage.ErrNotFound, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func testGC(t *testing.T, s storage.Storage) {
|
||||
n := time.Now()
|
||||
c := storage.AuthCode{
|
||||
ID: storage.NewID(),
|
||||
ClientID: "foobar",
|
||||
RedirectURI: "https://localhost:80/callback",
|
||||
Nonce: "foobar",
|
||||
Scopes: []string{"openid", "email"},
|
||||
Expiry: n.Add(time.Second),
|
||||
ConnectorID: "ldap",
|
||||
ConnectorData: []byte(`{"some":"data"}`),
|
||||
Claims: storage.Claims{
|
||||
UserID: "1",
|
||||
Username: "jane",
|
||||
Email: "jane.doe@example.com",
|
||||
EmailVerified: true,
|
||||
Groups: []string{"a", "b"},
|
||||
},
|
||||
}
|
||||
|
||||
if err := s.CreateAuthCode(c); err != nil {
|
||||
t.Fatalf("failed creating auth code: %v", err)
|
||||
}
|
||||
|
||||
if _, err := s.GarbageCollect(n); err != nil {
|
||||
t.Errorf("garbage collection failed: %v", err)
|
||||
}
|
||||
if _, err := s.GetAuthCode(c.ID); err != nil {
|
||||
t.Errorf("expected to be able to get auth code after GC: %v", err)
|
||||
}
|
||||
|
||||
if r, err := s.GarbageCollect(n.Add(time.Minute)); err != nil {
|
||||
t.Errorf("garbage collection failed: %v", err)
|
||||
} else if r.AuthCodes != 1 {
|
||||
t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthCodes)
|
||||
}
|
||||
|
||||
if _, err := s.GetAuthCode(c.ID); err == nil {
|
||||
t.Errorf("expected auth code to be GC'd")
|
||||
} else if err != storage.ErrNotFound {
|
||||
t.Errorf("expected storage.ErrNotFound, got %v", err)
|
||||
}
|
||||
|
||||
a := storage.AuthRequest{
|
||||
ID: storage.NewID(),
|
||||
ClientID: "foobar",
|
||||
ResponseTypes: []string{"code"},
|
||||
Scopes: []string{"openid", "email"},
|
||||
RedirectURI: "https://localhost:80/callback",
|
||||
Nonce: "foo",
|
||||
State: "bar",
|
||||
ForceApprovalPrompt: true,
|
||||
LoggedIn: true,
|
||||
Expiry: n,
|
||||
ConnectorID: "ldap",
|
||||
ConnectorData: []byte(`{"some":"data"}`),
|
||||
Claims: storage.Claims{
|
||||
UserID: "1",
|
||||
Username: "jane",
|
||||
Email: "jane.doe@example.com",
|
||||
EmailVerified: true,
|
||||
Groups: []string{"a", "b"},
|
||||
},
|
||||
}
|
||||
|
||||
if err := s.CreateAuthRequest(a); err != nil {
|
||||
t.Fatalf("failed creating auth request: %v", err)
|
||||
}
|
||||
|
||||
if _, err := s.GarbageCollect(n); err != nil {
|
||||
t.Errorf("garbage collection failed: %v", err)
|
||||
}
|
||||
if _, err := s.GetAuthRequest(a.ID); err != nil {
|
||||
t.Errorf("expected to be able to get auth code after GC: %v", err)
|
||||
}
|
||||
|
||||
if r, err := s.GarbageCollect(n.Add(time.Minute)); err != nil {
|
||||
t.Errorf("garbage collection failed: %v", err)
|
||||
} else if r.AuthRequests != 1 {
|
||||
t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthRequests)
|
||||
}
|
||||
|
||||
if _, err := s.GetAuthRequest(a.ID); err == nil {
|
||||
t.Errorf("expected auth code to be GC'd")
|
||||
} else if err != storage.ErrNotFound {
|
||||
t.Errorf("expected storage.ErrNotFound, got %v", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,7 +20,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/gtank/cryptopasta"
|
||||
"golang.org/x/net/context"
|
||||
yaml "gopkg.in/yaml.v2"
|
||||
|
||||
"github.com/coreos/dex/storage"
|
||||
|
@ -35,9 +34,6 @@ type client struct {
|
|||
|
||||
now func() time.Time
|
||||
|
||||
// If not nil, the cancel function for stopping garbage colletion.
|
||||
cancel context.CancelFunc
|
||||
|
||||
// BUG: currently each third party API group can only have one resource in it,
|
||||
// so for each resource this storage uses, it need a unique API group.
|
||||
//
|
||||
|
|
|
@ -1,58 +0,0 @@
|
|||
package kubernetes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// gc begins the gc process for Kubernetes.
|
||||
func (cli *client) gc(ctx context.Context, every time.Duration) {
|
||||
handleErr := func(err error) { log.Println(err.Error()) }
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(every):
|
||||
}
|
||||
|
||||
// TODO(ericchiang): On failures, run garbage collection more often.
|
||||
log.Println("kubernetes: running garbage collection")
|
||||
cli.gcAuthRequests(handleErr)
|
||||
cli.gcAuthCodes(handleErr)
|
||||
log.Printf("kubernetes: garbage collection finished, next run at %s", cli.now().Add(every))
|
||||
}
|
||||
}
|
||||
|
||||
func (cli *client) gcAuthRequests(handleErr func(error)) {
|
||||
var authRequests AuthRequestList
|
||||
if err := cli.list(resourceAuthRequest, &authRequests); err != nil {
|
||||
handleErr(fmt.Errorf("failed to list auth requests: %v", err))
|
||||
return
|
||||
}
|
||||
for _, authRequest := range authRequests.AuthRequests {
|
||||
if cli.now().After(authRequest.Expiry) {
|
||||
if err := cli.delete(resourceAuthRequest, authRequest.ObjectMeta.Name); err != nil {
|
||||
handleErr(fmt.Errorf("failed to detele auth request: %v", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cli *client) gcAuthCodes(handleErr func(error)) {
|
||||
var authCodes AuthCodeList
|
||||
if err := cli.list(resourceAuthCode, &authCodes); err != nil {
|
||||
handleErr(fmt.Errorf("failed to list auth codes: %v", err))
|
||||
return
|
||||
}
|
||||
for _, authCode := range authCodes.AuthCodes {
|
||||
if cli.now().After(authCode.Expiry) {
|
||||
if err := cli.delete(resourceAuthCode, authCode.ObjectMeta.Name); err != nil {
|
||||
handleErr(fmt.Errorf("failed to delete auth code: %v", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,88 +0,0 @@
|
|||
package kubernetes
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/dex/storage"
|
||||
)
|
||||
|
||||
func muster(t *testing.T) func(err error) {
|
||||
return func(err error) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGCAuthRequests(t *testing.T) {
|
||||
cli := loadClient(t)
|
||||
must := muster(t)
|
||||
|
||||
now := time.Now()
|
||||
cli.now = func() time.Time { return now }
|
||||
|
||||
expiredID := storage.NewID()
|
||||
goodID := storage.NewID()
|
||||
|
||||
must(cli.CreateAuthRequest(storage.AuthRequest{
|
||||
ID: expiredID,
|
||||
Expiry: now.Add(-time.Second),
|
||||
}))
|
||||
|
||||
must(cli.CreateAuthRequest(storage.AuthRequest{
|
||||
ID: goodID,
|
||||
Expiry: now.Add(time.Second),
|
||||
}))
|
||||
|
||||
handleErr := func(err error) { t.Error(err.Error()) }
|
||||
cli.gcAuthRequests(handleErr)
|
||||
|
||||
if _, err := cli.GetAuthRequest(goodID); err != nil {
|
||||
t.Errorf("failed to get good auth ID: %v", err)
|
||||
}
|
||||
_, err := cli.GetAuthRequest(expiredID)
|
||||
switch {
|
||||
case err == nil:
|
||||
t.Errorf("gc did not remove expired auth request")
|
||||
case err == storage.ErrNotFound:
|
||||
default:
|
||||
t.Errorf("expected storage.ErrNotFound, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGCAuthCodes(t *testing.T) {
|
||||
cli := loadClient(t)
|
||||
must := muster(t)
|
||||
|
||||
now := time.Now()
|
||||
cli.now = func() time.Time { return now }
|
||||
|
||||
expiredID := storage.NewID()
|
||||
goodID := storage.NewID()
|
||||
|
||||
must(cli.CreateAuthCode(storage.AuthCode{
|
||||
ID: expiredID,
|
||||
Expiry: now.Add(-time.Second),
|
||||
}))
|
||||
|
||||
must(cli.CreateAuthCode(storage.AuthCode{
|
||||
ID: goodID,
|
||||
Expiry: now.Add(time.Second),
|
||||
}))
|
||||
|
||||
handleErr := func(err error) { t.Error(err.Error()) }
|
||||
cli.gcAuthCodes(handleErr)
|
||||
|
||||
if _, err := cli.GetAuthCode(goodID); err != nil {
|
||||
t.Errorf("failed to get good auth ID: %v", err)
|
||||
}
|
||||
_, err := cli.GetAuthCode(expiredID)
|
||||
switch {
|
||||
case err == nil:
|
||||
t.Errorf("gc did not remove expired auth request")
|
||||
case err == storage.ErrNotFound:
|
||||
default:
|
||||
t.Errorf("expected storage.ErrNotFound, got %v", err)
|
||||
}
|
||||
}
|
|
@ -3,12 +3,12 @@ package kubernetes
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
homedir "github.com/mitchellh/go-homedir"
|
||||
"golang.org/x/net/context"
|
||||
|
||||
"github.com/coreos/dex/storage"
|
||||
"github.com/coreos/dex/storage/kubernetes/k8sapi"
|
||||
|
@ -46,14 +46,6 @@ func (c *Config) Open() (storage.Storage, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// start up garbage collection
|
||||
gcFrequency := c.GCFrequency
|
||||
if gcFrequency == 0 {
|
||||
gcFrequency = 600
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cli.cancel = cancel
|
||||
go cli.gc(ctx, time.Duration(gcFrequency)*time.Second)
|
||||
return cli, nil
|
||||
}
|
||||
|
||||
|
@ -93,9 +85,6 @@ func (c *Config) open() (*client, error) {
|
|||
}
|
||||
|
||||
func (cli *client) Close() error {
|
||||
if cli.cancel != nil {
|
||||
cli.cancel()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -291,3 +280,40 @@ func (cli *client) UpdateAuthRequest(id string, updater func(a storage.AuthReque
|
|||
newReq.ObjectMeta = req.ObjectMeta
|
||||
return cli.put(resourceAuthRequest, id, newReq)
|
||||
}
|
||||
|
||||
func (cli *client) GarbageCollect(now time.Time) (result storage.GCResult, err error) {
|
||||
var authRequests AuthRequestList
|
||||
if err := cli.list(resourceAuthRequest, &authRequests); err != nil {
|
||||
return result, fmt.Errorf("failed to list auth requests: %v", err)
|
||||
}
|
||||
|
||||
var delErr error
|
||||
for _, authRequest := range authRequests.AuthRequests {
|
||||
if now.After(authRequest.Expiry) {
|
||||
if err := cli.delete(resourceAuthRequest, authRequest.ObjectMeta.Name); err != nil {
|
||||
log.Printf("failed to delete auth request: %v", err)
|
||||
delErr = fmt.Errorf("failed to delete auth request: %v", err)
|
||||
}
|
||||
result.AuthRequests++
|
||||
}
|
||||
}
|
||||
if delErr != nil {
|
||||
return result, delErr
|
||||
}
|
||||
|
||||
var authCodes AuthCodeList
|
||||
if err := cli.list(resourceAuthCode, &authCodes); err != nil {
|
||||
return result, fmt.Errorf("failed to list auth codes: %v", err)
|
||||
}
|
||||
|
||||
for _, authCode := range authCodes.AuthCodes {
|
||||
if now.After(authCode.Expiry) {
|
||||
if err := cli.delete(resourceAuthCode, authCode.ObjectMeta.Name); err != nil {
|
||||
log.Printf("failed to delete auth code %v", err)
|
||||
delErr = fmt.Errorf("failed to delete auth code: %v", err)
|
||||
}
|
||||
result.AuthCodes++
|
||||
}
|
||||
}
|
||||
return result, delErr
|
||||
}
|
||||
|
|
|
@ -74,7 +74,7 @@ func TestURLFor(t *testing.T) {
|
|||
|
||||
func TestStorage(t *testing.T) {
|
||||
client := loadClient(t)
|
||||
conformance.RunTestSuite(t, func() storage.Storage {
|
||||
conformance.RunTests(t, func() storage.Storage {
|
||||
for _, resource := range []string{
|
||||
resourceAuthCode,
|
||||
resourceAuthRequest,
|
||||
|
|
|
@ -4,6 +4,7 @@ package memory
|
|||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/dex/storage"
|
||||
)
|
||||
|
@ -51,6 +52,24 @@ func (s *memStorage) tx(f func()) {
|
|||
|
||||
func (s *memStorage) Close() error { return nil }
|
||||
|
||||
func (s *memStorage) GarbageCollect(now time.Time) (result storage.GCResult, err error) {
|
||||
s.tx(func() {
|
||||
for id, a := range s.authCodes {
|
||||
if now.After(a.Expiry) {
|
||||
delete(s.authCodes, id)
|
||||
result.AuthCodes++
|
||||
}
|
||||
}
|
||||
for id, a := range s.authReqs {
|
||||
if now.After(a.Expiry) {
|
||||
delete(s.authReqs, id)
|
||||
result.AuthRequests++
|
||||
}
|
||||
}
|
||||
})
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *memStorage) CreateClient(c storage.Client) (err error) {
|
||||
s.tx(func() {
|
||||
if _, ok := s.clients[c.ID]; ok {
|
||||
|
@ -240,29 +259,6 @@ func (s *memStorage) GetAuthCode(id string) (c storage.AuthCode, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (s *memStorage) ClaimCode(id string) (err error) {
|
||||
s.tx(func() {
|
||||
if _, ok := s.authCodes[id]; !ok {
|
||||
err = storage.ErrNotFound
|
||||
return
|
||||
}
|
||||
delete(s.authCodes, id)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (s *memStorage) ClaimRefresh(refreshToken string) (token storage.RefreshToken, err error) {
|
||||
s.tx(func() {
|
||||
var ok bool
|
||||
if token, ok = s.refreshTokens[refreshToken]; !ok {
|
||||
err = storage.ErrNotFound
|
||||
return
|
||||
}
|
||||
delete(s.refreshTokens, refreshToken)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (s *memStorage) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) (err error) {
|
||||
s.tx(func() {
|
||||
client, ok := s.clients[id]
|
||||
|
|
|
@ -7,5 +7,5 @@ import (
|
|||
)
|
||||
|
||||
func TestStorage(t *testing.T) {
|
||||
conformance.RunTestSuite(t, New)
|
||||
conformance.RunTests(t, New)
|
||||
}
|
||||
|
|
|
@ -5,7 +5,6 @@ import (
|
|||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/dex/storage"
|
||||
)
|
||||
|
@ -22,7 +21,7 @@ func (s *SQLite3) Open() (storage.Storage, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return withGC(conn, time.Now), nil
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (s *SQLite3) open() (*conn, error) {
|
||||
|
@ -76,7 +75,7 @@ func (p *Postgres) Open() (storage.Storage, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return withGC(conn, time.Now), nil
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (p *Postgres) open() (*conn, error) {
|
||||
|
|
|
@ -54,7 +54,7 @@ func TestSQLite3(t *testing.T) {
|
|||
}
|
||||
|
||||
withTimeout(time.Second*10, func() {
|
||||
conformance.RunTestSuite(t, newStorage)
|
||||
conformance.RunTests(t, newStorage)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -72,19 +72,24 @@ func TestPostgres(t *testing.T) {
|
|||
},
|
||||
ConnectionTimeout: 5,
|
||||
}
|
||||
conn, err := p.open()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
// t.Fatal has a bad habbit of not actually printing the error
|
||||
fatal := func(i interface{}) {
|
||||
fmt.Fprintln(os.Stdout, i)
|
||||
t.Fatal(i)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
newStorage := func() storage.Storage {
|
||||
conn, err := p.open()
|
||||
if err != nil {
|
||||
fatal(err)
|
||||
}
|
||||
if err := cleanDB(conn); err != nil {
|
||||
t.Fatal(err)
|
||||
fatal(err)
|
||||
}
|
||||
return conn
|
||||
}
|
||||
withTimeout(time.Minute*1, func() {
|
||||
conformance.RunTestSuite(t, newStorage)
|
||||
conformance.RunTests(t, newStorage)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/dex/storage"
|
||||
)
|
||||
|
@ -83,6 +84,25 @@ type scanner interface {
|
|||
Scan(dest ...interface{}) error
|
||||
}
|
||||
|
||||
func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error) {
|
||||
r, err := c.Exec(`delete from auth_request where expiry < $1`, now)
|
||||
if err != nil {
|
||||
return result, fmt.Errorf("gc auth_request: %v", err)
|
||||
}
|
||||
if n, err := r.RowsAffected(); err == nil {
|
||||
result.AuthRequests = n
|
||||
}
|
||||
|
||||
r, err = c.Exec(`delete from auth_code where expiry < $1`, now)
|
||||
if err != nil {
|
||||
return result, fmt.Errorf("gc auth_code: %v", err)
|
||||
}
|
||||
if n, err := r.RowsAffected(); err == nil {
|
||||
result.AuthCodes = n
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *conn) CreateAuthRequest(a storage.AuthRequest) error {
|
||||
_, err := c.Exec(`
|
||||
insert into auth_request (
|
||||
|
|
|
@ -1,53 +0,0 @@
|
|||
package sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/dex/storage"
|
||||
)
|
||||
|
||||
type gc struct {
|
||||
now func() time.Time
|
||||
conn *conn
|
||||
}
|
||||
|
||||
func (gc gc) run() error {
|
||||
for _, table := range []string{"auth_request", "auth_code"} {
|
||||
_, err := gc.conn.Exec(`delete from `+table+` where expiry < $1`, gc.now())
|
||||
if err != nil {
|
||||
return fmt.Errorf("gc %s: %v", table, err)
|
||||
}
|
||||
// TODO(ericchiang): when we have levelled logging print how many rows were gc'd
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type withCancel struct {
|
||||
storage.Storage
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func (w withCancel) Close() error {
|
||||
w.cancel()
|
||||
return w.Storage.Close()
|
||||
}
|
||||
|
||||
func withGC(conn *conn, now func() time.Time) storage.Storage {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
run := (gc{now, conn}).run
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-time.After(time.Second * 30):
|
||||
if err := run(); err != nil {
|
||||
log.Printf("gc failed: %v", err)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
}()
|
||||
return withCancel{conn, cancel}
|
||||
}
|
|
@ -1,53 +0,0 @@
|
|||
package sql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/dex/storage"
|
||||
)
|
||||
|
||||
func TestGC(t *testing.T) {
|
||||
// TODO(ericchiang): Add a GarbageCollect method to the storage interface so
|
||||
// we can write conformance tests instead of directly testing each implementation.
|
||||
s := &SQLite3{":memory:"}
|
||||
conn, err := s.open()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
clock := time.Now()
|
||||
now := func() time.Time { return clock }
|
||||
|
||||
runGC := (gc{now, conn}).run
|
||||
|
||||
a := storage.AuthRequest{
|
||||
ID: storage.NewID(),
|
||||
Expiry: now().Add(time.Second),
|
||||
}
|
||||
|
||||
if err := conn.CreateAuthRequest(a); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := runGC(); err != nil {
|
||||
t.Errorf("gc failed: %v", err)
|
||||
}
|
||||
|
||||
if _, err := conn.GetAuthRequest(a.ID); err != nil {
|
||||
t.Errorf("failed to get auth request after gc: %v", err)
|
||||
}
|
||||
|
||||
clock = clock.Add(time.Minute)
|
||||
|
||||
if err := runGC(); err != nil {
|
||||
t.Errorf("gc failed: %v", err)
|
||||
}
|
||||
|
||||
if _, err := conn.GetAuthRequest(a.ID); err == nil {
|
||||
t.Errorf("expected error after gc'ing auth request: %v", err)
|
||||
} else if err != storage.ErrNotFound {
|
||||
t.Errorf("expected error storage.NotFound got: %v", err)
|
||||
}
|
||||
}
|
|
@ -38,6 +38,12 @@ func NewID() string {
|
|||
return strings.TrimRight(encoding.EncodeToString(buff), "=")
|
||||
}
|
||||
|
||||
// GCResult returns the number of objects deleted by garbage collection.
|
||||
type GCResult struct {
|
||||
AuthRequests int64
|
||||
AuthCodes int64
|
||||
}
|
||||
|
||||
// Storage is the storage interface used by the server. Implementations, at minimum
|
||||
// require compare-and-swap atomic actions.
|
||||
//
|
||||
|
@ -80,8 +86,8 @@ type Storage interface {
|
|||
UpdateAuthRequest(id string, updater func(a AuthRequest) (AuthRequest, error)) error
|
||||
UpdatePassword(email string, updater func(p Password) (Password, error)) error
|
||||
|
||||
// TODO(ericchiang): Add a GarbageCollect(now time.Time) method so conformance tests
|
||||
// can test implementations.
|
||||
// GarbageCollect deletes all expired AuthCodes and AuthRequests.
|
||||
GarbageCollect(now time.Time) (GCResult, error)
|
||||
}
|
||||
|
||||
// Client represents an OAuth2 client.
|
||||
|
|
Loading…
Reference in a new issue