{cmd,server}: move garbage collection logic to server

This commit is contained in:
Eric Chiang 2016-10-12 18:51:32 -07:00
parent 3e20a080fe
commit 4296604f11
5 changed files with 64 additions and 43 deletions

View file

@ -9,6 +9,7 @@ import (
"net/http"
"github.com/spf13/cobra"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
yaml "gopkg.in/yaml.v2"
@ -124,7 +125,7 @@ func serve(cmd *cobra.Command, args []string) error {
EnablePasswordDB: c.EnablePasswordDB,
}
serv, err := server.NewServer(serverConfig)
serv, err := server.NewServer(context.Background(), serverConfig)
if err != nil {
return fmt.Errorf("initializing server: %v", err)
}

View file

@ -4,10 +4,15 @@ import (
"net/http"
"net/http/httptest"
"testing"
"golang.org/x/net/context"
)
func TestHandleHealth(t *testing.T) {
httpServer, server := newTestServer(t, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
httpServer, server := newTestServer(t, ctx, nil)
defer httpServer.Close()
rr := httptest.NewRecorder()

View file

@ -56,40 +56,34 @@ type keyRotater struct {
storage.Storage
strategy rotationStrategy
cancel context.CancelFunc
now func() time.Time
}
func storageWithKeyRotation(s storage.Storage, strategy rotationStrategy, now func() time.Time) storage.Storage {
if now == nil {
now = time.Now
}
ctx, cancel := context.WithCancel(context.Background())
rotater := keyRotater{s, strategy, cancel, now}
// startKeyRotation begins key rotation in a new goroutine, closing once the context is canceled.
//
// The method blocks until after the first attempt to rotate keys has completed. That way
// healthy storages will return from this call with valid keys.
func startKeyRotation(ctx context.Context, s storage.Storage, strategy rotationStrategy, now func() time.Time) {
rotater := keyRotater{s, strategy, now}
// Try to rotate immediately so properly configured storages will return a
// storage with keys.
// Try to rotate immediately so properly configured storages will have keys.
if err := rotater.rotate(); err != nil {
log.Printf("failed to rotate keys: %v", err)
}
go func() {
for {
select {
case <-ctx.Done():
return
case <-time.After(time.Second * 30):
case <-time.After(strategy.period):
if err := rotater.rotate(); err != nil {
log.Printf("failed to rotate keys: %v", err)
}
}
}
}()
return rotater
}
func (k keyRotater) Close() error {
k.cancel()
return k.Storage.Close()
return
}
func (k keyRotater) rotate() error {

View file

@ -11,6 +11,7 @@ import (
"time"
"golang.org/x/crypto/bcrypt"
"golang.org/x/net/context"
"github.com/gorilla/mux"
@ -48,6 +49,8 @@ type Config struct {
RotateKeysAfter time.Duration // Defaults to 6 hours.
IDTokensValidFor time.Duration // Defaults to 24 hours
GCFrequency time.Duration // Defaults to 5 minutes
// If specified, the server will use this function for determining time.
Now func() time.Time
@ -87,14 +90,14 @@ type Server struct {
}
// NewServer constructs a server from the provided config.
func NewServer(c Config) (*Server, error) {
return newServer(c, defaultRotationStrategy(
func NewServer(ctx context.Context, c Config) (*Server, error) {
return newServer(ctx, c, defaultRotationStrategy(
value(c.RotateKeysAfter, 6*time.Hour),
value(c.IDTokensValidFor, 24*time.Hour),
))
}
func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) {
func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) (*Server, error) {
issuerURL, err := url.Parse(c.Issuer)
if err != nil {
return nil, fmt.Errorf("server: can't parse issuer URL")
@ -140,12 +143,7 @@ func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) {
s := &Server{
issuerURL: *issuerURL,
connectors: make(map[string]Connector),
storage: newKeyCacher(
storageWithKeyRotation(
c.Storage, rotationStrategy, now,
),
now,
),
storage: newKeyCacher(c.Storage, now),
supportedResponseTypes: supported,
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
skipApproval: c.SkipApprovalScreen,
@ -179,6 +177,9 @@ func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) {
handleFunc("/healthz", s.handleHealth)
s.mux = r
startKeyRotation(ctx, c.Storage, rotationStrategy, now)
startGarbageCollection(ctx, c.Storage, value(c.GCFrequency, 5*time.Minute), now)
return s, nil
}
@ -262,3 +263,21 @@ func (k *keyCacher) GetKeys() (storage.Keys, error) {
}
return storageKeys, nil
}
func startGarbageCollection(ctx context.Context, s storage.Storage, frequency time.Duration, now func() time.Time) {
go func() {
for {
select {
case <-ctx.Done():
return
case <-time.After(frequency):
if r, err := s.GarbageCollect(now()); err != nil {
log.Printf("garbage collection failed: %v", err)
} else {
log.Printf("garbage collection run, delete auth requests=%d, auth codes=%d", r.AuthRequests, r.AuthCodes)
}
}
}
}()
return
}

View file

@ -69,7 +69,7 @@ FDWV28nTP9sqbtsmU8Tem2jzMvZ7C/Q0AuDoKELFUpux8shm8wfIhyaPnXUGZoAZ
Np4vUwMSYV5mopESLWOg3loBxKyLGFtgGKVCjGiQvy6zISQ4fQo=
-----END RSA PRIVATE KEY-----`)
func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server, *Server) {
func newTestServer(t *testing.T, ctx context.Context, updateConfig func(c *Config)) (*httptest.Server, *Server) {
var server *Server
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
server.ServeHTTP(w, r)
@ -91,7 +91,7 @@ func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server
s.URL = config.Issuer
var err error
if server, err = newServer(config, staticRotationStrategy(testKey)); err != nil {
if server, err = newServer(ctx, config, staticRotationStrategy(testKey)); err != nil {
t.Fatal(err)
}
server.skipApproval = true // Don't prompt for approval, just immediately redirect with code.
@ -99,14 +99,16 @@ func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server
}
func TestNewTestServer(t *testing.T) {
newTestServer(t, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
newTestServer(t, ctx, nil)
}
func TestDiscovery(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
httpServer, _ := newTestServer(t, func(c *Config) {
httpServer, _ := newTestServer(t, ctx, func(c *Config) {
c.Issuer = c.Issuer + "/non-root-path"
})
defer httpServer.Close()
@ -227,7 +229,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
httpServer, s := newTestServer(t, func(c *Config) {
httpServer, s := newTestServer(t, ctx, func(c *Config) {
c.Issuer = c.Issuer + "/non-root-path"
})
defer httpServer.Close()
@ -340,7 +342,7 @@ func TestOAuth2ImplicitFlow(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
httpServer, s := newTestServer(t, func(c *Config) {
httpServer, s := newTestServer(t, ctx, func(c *Config) {
// Enable support for the implicit flow.
c.SupportedResponseTypes = []string{"code", "token"}
})
@ -470,7 +472,7 @@ func TestCrossClientScopes(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
httpServer, s := newTestServer(t, func(c *Config) {
httpServer, s := newTestServer(t, ctx, func(c *Config) {
c.Issuer = c.Issuer + "/non-root-path"
})
defer httpServer.Close()