{cmd,server}: move garbage collection logic to server

This commit is contained in:
Eric Chiang 2016-10-12 18:51:32 -07:00
parent 3e20a080fe
commit 4296604f11
5 changed files with 64 additions and 43 deletions

View file

@ -9,6 +9,7 @@ import (
"net/http" "net/http"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/net/context"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
yaml "gopkg.in/yaml.v2" yaml "gopkg.in/yaml.v2"
@ -124,7 +125,7 @@ func serve(cmd *cobra.Command, args []string) error {
EnablePasswordDB: c.EnablePasswordDB, EnablePasswordDB: c.EnablePasswordDB,
} }
serv, err := server.NewServer(serverConfig) serv, err := server.NewServer(context.Background(), serverConfig)
if err != nil { if err != nil {
return fmt.Errorf("initializing server: %v", err) return fmt.Errorf("initializing server: %v", err)
} }

View file

@ -4,10 +4,15 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
"golang.org/x/net/context"
) )
func TestHandleHealth(t *testing.T) { func TestHandleHealth(t *testing.T) {
httpServer, server := newTestServer(t, nil) ctx, cancel := context.WithCancel(context.Background())
defer cancel()
httpServer, server := newTestServer(t, ctx, nil)
defer httpServer.Close() defer httpServer.Close()
rr := httptest.NewRecorder() rr := httptest.NewRecorder()

View file

@ -56,40 +56,34 @@ type keyRotater struct {
storage.Storage storage.Storage
strategy rotationStrategy strategy rotationStrategy
cancel context.CancelFunc now func() time.Time
now func() time.Time
} }
func storageWithKeyRotation(s storage.Storage, strategy rotationStrategy, now func() time.Time) storage.Storage { // startKeyRotation begins key rotation in a new goroutine, closing once the context is canceled.
if now == nil { //
now = time.Now // The method blocks until after the first attempt to rotate keys has completed. That way
} // healthy storages will return from this call with valid keys.
ctx, cancel := context.WithCancel(context.Background()) func startKeyRotation(ctx context.Context, s storage.Storage, strategy rotationStrategy, now func() time.Time) {
rotater := keyRotater{s, strategy, cancel, now} rotater := keyRotater{s, strategy, now}
// Try to rotate immediately so properly configured storages will return a // Try to rotate immediately so properly configured storages will have keys.
// storage with keys.
if err := rotater.rotate(); err != nil { if err := rotater.rotate(); err != nil {
log.Printf("failed to rotate keys: %v", err) log.Printf("failed to rotate keys: %v", err)
} }
go func() { go func() {
select { for {
case <-ctx.Done(): select {
return case <-ctx.Done():
case <-time.After(time.Second * 30): return
if err := rotater.rotate(); err != nil { case <-time.After(strategy.period):
log.Printf("failed to rotate keys: %v", err) if err := rotater.rotate(); err != nil {
log.Printf("failed to rotate keys: %v", err)
}
} }
} }
}() }()
return rotater return
}
func (k keyRotater) Close() error {
k.cancel()
return k.Storage.Close()
} }
func (k keyRotater) rotate() error { func (k keyRotater) rotate() error {

View file

@ -11,6 +11,7 @@ import (
"time" "time"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"golang.org/x/net/context"
"github.com/gorilla/mux" "github.com/gorilla/mux"
@ -48,6 +49,8 @@ type Config struct {
RotateKeysAfter time.Duration // Defaults to 6 hours. RotateKeysAfter time.Duration // Defaults to 6 hours.
IDTokensValidFor time.Duration // Defaults to 24 hours IDTokensValidFor time.Duration // Defaults to 24 hours
GCFrequency time.Duration // Defaults to 5 minutes
// If specified, the server will use this function for determining time. // If specified, the server will use this function for determining time.
Now func() time.Time Now func() time.Time
@ -87,14 +90,14 @@ type Server struct {
} }
// NewServer constructs a server from the provided config. // NewServer constructs a server from the provided config.
func NewServer(c Config) (*Server, error) { func NewServer(ctx context.Context, c Config) (*Server, error) {
return newServer(c, defaultRotationStrategy( return newServer(ctx, c, defaultRotationStrategy(
value(c.RotateKeysAfter, 6*time.Hour), value(c.RotateKeysAfter, 6*time.Hour),
value(c.IDTokensValidFor, 24*time.Hour), value(c.IDTokensValidFor, 24*time.Hour),
)) ))
} }
func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) { func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) (*Server, error) {
issuerURL, err := url.Parse(c.Issuer) issuerURL, err := url.Parse(c.Issuer)
if err != nil { if err != nil {
return nil, fmt.Errorf("server: can't parse issuer URL") return nil, fmt.Errorf("server: can't parse issuer URL")
@ -138,14 +141,9 @@ func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) {
} }
s := &Server{ s := &Server{
issuerURL: *issuerURL, issuerURL: *issuerURL,
connectors: make(map[string]Connector), connectors: make(map[string]Connector),
storage: newKeyCacher( storage: newKeyCacher(c.Storage, now),
storageWithKeyRotation(
c.Storage, rotationStrategy, now,
),
now,
),
supportedResponseTypes: supported, supportedResponseTypes: supported,
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour), idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
skipApproval: c.SkipApprovalScreen, skipApproval: c.SkipApprovalScreen,
@ -179,6 +177,9 @@ func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) {
handleFunc("/healthz", s.handleHealth) handleFunc("/healthz", s.handleHealth)
s.mux = r s.mux = r
startKeyRotation(ctx, c.Storage, rotationStrategy, now)
startGarbageCollection(ctx, c.Storage, value(c.GCFrequency, 5*time.Minute), now)
return s, nil return s, nil
} }
@ -262,3 +263,21 @@ func (k *keyCacher) GetKeys() (storage.Keys, error) {
} }
return storageKeys, nil return storageKeys, nil
} }
func startGarbageCollection(ctx context.Context, s storage.Storage, frequency time.Duration, now func() time.Time) {
go func() {
for {
select {
case <-ctx.Done():
return
case <-time.After(frequency):
if r, err := s.GarbageCollect(now()); err != nil {
log.Printf("garbage collection failed: %v", err)
} else {
log.Printf("garbage collection run, delete auth requests=%d, auth codes=%d", r.AuthRequests, r.AuthCodes)
}
}
}
}()
return
}

View file

@ -69,7 +69,7 @@ FDWV28nTP9sqbtsmU8Tem2jzMvZ7C/Q0AuDoKELFUpux8shm8wfIhyaPnXUGZoAZ
Np4vUwMSYV5mopESLWOg3loBxKyLGFtgGKVCjGiQvy6zISQ4fQo= Np4vUwMSYV5mopESLWOg3loBxKyLGFtgGKVCjGiQvy6zISQ4fQo=
-----END RSA PRIVATE KEY-----`) -----END RSA PRIVATE KEY-----`)
func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server, *Server) { func newTestServer(t *testing.T, ctx context.Context, updateConfig func(c *Config)) (*httptest.Server, *Server) {
var server *Server var server *Server
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
server.ServeHTTP(w, r) server.ServeHTTP(w, r)
@ -91,7 +91,7 @@ func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server
s.URL = config.Issuer s.URL = config.Issuer
var err error var err error
if server, err = newServer(config, staticRotationStrategy(testKey)); err != nil { if server, err = newServer(ctx, config, staticRotationStrategy(testKey)); err != nil {
t.Fatal(err) t.Fatal(err)
} }
server.skipApproval = true // Don't prompt for approval, just immediately redirect with code. server.skipApproval = true // Don't prompt for approval, just immediately redirect with code.
@ -99,14 +99,16 @@ func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server
} }
func TestNewTestServer(t *testing.T) { func TestNewTestServer(t *testing.T) {
newTestServer(t, nil) ctx, cancel := context.WithCancel(context.Background())
defer cancel()
newTestServer(t, ctx, nil)
} }
func TestDiscovery(t *testing.T) { func TestDiscovery(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
httpServer, _ := newTestServer(t, func(c *Config) { httpServer, _ := newTestServer(t, ctx, func(c *Config) {
c.Issuer = c.Issuer + "/non-root-path" c.Issuer = c.Issuer + "/non-root-path"
}) })
defer httpServer.Close() defer httpServer.Close()
@ -227,7 +229,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
httpServer, s := newTestServer(t, func(c *Config) { httpServer, s := newTestServer(t, ctx, func(c *Config) {
c.Issuer = c.Issuer + "/non-root-path" c.Issuer = c.Issuer + "/non-root-path"
}) })
defer httpServer.Close() defer httpServer.Close()
@ -340,7 +342,7 @@ func TestOAuth2ImplicitFlow(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
httpServer, s := newTestServer(t, func(c *Config) { httpServer, s := newTestServer(t, ctx, func(c *Config) {
// Enable support for the implicit flow. // Enable support for the implicit flow.
c.SupportedResponseTypes = []string{"code", "token"} c.SupportedResponseTypes = []string{"code", "token"}
}) })
@ -470,7 +472,7 @@ func TestCrossClientScopes(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
httpServer, s := newTestServer(t, func(c *Config) { httpServer, s := newTestServer(t, ctx, func(c *Config) {
c.Issuer = c.Issuer + "/non-root-path" c.Issuer = c.Issuer + "/non-root-path"
}) })
defer httpServer.Close() defer httpServer.Close()