forked from mystiq/dex
{cmd,server}: move garbage collection logic to server
This commit is contained in:
parent
3e20a080fe
commit
4296604f11
5 changed files with 64 additions and 43 deletions
|
@ -9,6 +9,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
"golang.org/x/net/context"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
yaml "gopkg.in/yaml.v2"
|
yaml "gopkg.in/yaml.v2"
|
||||||
|
@ -124,7 +125,7 @@ func serve(cmd *cobra.Command, args []string) error {
|
||||||
EnablePasswordDB: c.EnablePasswordDB,
|
EnablePasswordDB: c.EnablePasswordDB,
|
||||||
}
|
}
|
||||||
|
|
||||||
serv, err := server.NewServer(serverConfig)
|
serv, err := server.NewServer(context.Background(), serverConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("initializing server: %v", err)
|
return fmt.Errorf("initializing server: %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,10 +4,15 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestHandleHealth(t *testing.T) {
|
func TestHandleHealth(t *testing.T) {
|
||||||
httpServer, server := newTestServer(t, nil)
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
httpServer, server := newTestServer(t, ctx, nil)
|
||||||
defer httpServer.Close()
|
defer httpServer.Close()
|
||||||
|
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
|
|
|
@ -56,40 +56,34 @@ type keyRotater struct {
|
||||||
storage.Storage
|
storage.Storage
|
||||||
|
|
||||||
strategy rotationStrategy
|
strategy rotationStrategy
|
||||||
cancel context.CancelFunc
|
|
||||||
|
|
||||||
now func() time.Time
|
now func() time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func storageWithKeyRotation(s storage.Storage, strategy rotationStrategy, now func() time.Time) storage.Storage {
|
// startKeyRotation begins key rotation in a new goroutine, closing once the context is canceled.
|
||||||
if now == nil {
|
//
|
||||||
now = time.Now
|
// The method blocks until after the first attempt to rotate keys has completed. That way
|
||||||
}
|
// healthy storages will return from this call with valid keys.
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
func startKeyRotation(ctx context.Context, s storage.Storage, strategy rotationStrategy, now func() time.Time) {
|
||||||
rotater := keyRotater{s, strategy, cancel, now}
|
rotater := keyRotater{s, strategy, now}
|
||||||
|
|
||||||
// Try to rotate immediately so properly configured storages will return a
|
// Try to rotate immediately so properly configured storages will have keys.
|
||||||
// storage with keys.
|
|
||||||
if err := rotater.rotate(); err != nil {
|
if err := rotater.rotate(); err != nil {
|
||||||
log.Printf("failed to rotate keys: %v", err)
|
log.Printf("failed to rotate keys: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
case <-time.After(time.Second * 30):
|
case <-time.After(strategy.period):
|
||||||
if err := rotater.rotate(); err != nil {
|
if err := rotater.rotate(); err != nil {
|
||||||
log.Printf("failed to rotate keys: %v", err)
|
log.Printf("failed to rotate keys: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
return rotater
|
return
|
||||||
}
|
|
||||||
|
|
||||||
func (k keyRotater) Close() error {
|
|
||||||
k.cancel()
|
|
||||||
return k.Storage.Close()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k keyRotater) rotate() error {
|
func (k keyRotater) rotate() error {
|
||||||
|
|
|
@ -11,6 +11,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
|
@ -48,6 +49,8 @@ type Config struct {
|
||||||
RotateKeysAfter time.Duration // Defaults to 6 hours.
|
RotateKeysAfter time.Duration // Defaults to 6 hours.
|
||||||
IDTokensValidFor time.Duration // Defaults to 24 hours
|
IDTokensValidFor time.Duration // Defaults to 24 hours
|
||||||
|
|
||||||
|
GCFrequency time.Duration // Defaults to 5 minutes
|
||||||
|
|
||||||
// If specified, the server will use this function for determining time.
|
// If specified, the server will use this function for determining time.
|
||||||
Now func() time.Time
|
Now func() time.Time
|
||||||
|
|
||||||
|
@ -87,14 +90,14 @@ type Server struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer constructs a server from the provided config.
|
// NewServer constructs a server from the provided config.
|
||||||
func NewServer(c Config) (*Server, error) {
|
func NewServer(ctx context.Context, c Config) (*Server, error) {
|
||||||
return newServer(c, defaultRotationStrategy(
|
return newServer(ctx, c, defaultRotationStrategy(
|
||||||
value(c.RotateKeysAfter, 6*time.Hour),
|
value(c.RotateKeysAfter, 6*time.Hour),
|
||||||
value(c.IDTokensValidFor, 24*time.Hour),
|
value(c.IDTokensValidFor, 24*time.Hour),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) {
|
func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) (*Server, error) {
|
||||||
issuerURL, err := url.Parse(c.Issuer)
|
issuerURL, err := url.Parse(c.Issuer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("server: can't parse issuer URL")
|
return nil, fmt.Errorf("server: can't parse issuer URL")
|
||||||
|
@ -140,12 +143,7 @@ func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) {
|
||||||
s := &Server{
|
s := &Server{
|
||||||
issuerURL: *issuerURL,
|
issuerURL: *issuerURL,
|
||||||
connectors: make(map[string]Connector),
|
connectors: make(map[string]Connector),
|
||||||
storage: newKeyCacher(
|
storage: newKeyCacher(c.Storage, now),
|
||||||
storageWithKeyRotation(
|
|
||||||
c.Storage, rotationStrategy, now,
|
|
||||||
),
|
|
||||||
now,
|
|
||||||
),
|
|
||||||
supportedResponseTypes: supported,
|
supportedResponseTypes: supported,
|
||||||
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
|
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
|
||||||
skipApproval: c.SkipApprovalScreen,
|
skipApproval: c.SkipApprovalScreen,
|
||||||
|
@ -179,6 +177,9 @@ func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) {
|
||||||
handleFunc("/healthz", s.handleHealth)
|
handleFunc("/healthz", s.handleHealth)
|
||||||
s.mux = r
|
s.mux = r
|
||||||
|
|
||||||
|
startKeyRotation(ctx, c.Storage, rotationStrategy, now)
|
||||||
|
startGarbageCollection(ctx, c.Storage, value(c.GCFrequency, 5*time.Minute), now)
|
||||||
|
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -262,3 +263,21 @@ func (k *keyCacher) GetKeys() (storage.Keys, error) {
|
||||||
}
|
}
|
||||||
return storageKeys, nil
|
return storageKeys, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func startGarbageCollection(ctx context.Context, s storage.Storage, frequency time.Duration, now func() time.Time) {
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-time.After(frequency):
|
||||||
|
if r, err := s.GarbageCollect(now()); err != nil {
|
||||||
|
log.Printf("garbage collection failed: %v", err)
|
||||||
|
} else {
|
||||||
|
log.Printf("garbage collection run, delete auth requests=%d, auth codes=%d", r.AuthRequests, r.AuthCodes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
@ -69,7 +69,7 @@ FDWV28nTP9sqbtsmU8Tem2jzMvZ7C/Q0AuDoKELFUpux8shm8wfIhyaPnXUGZoAZ
|
||||||
Np4vUwMSYV5mopESLWOg3loBxKyLGFtgGKVCjGiQvy6zISQ4fQo=
|
Np4vUwMSYV5mopESLWOg3loBxKyLGFtgGKVCjGiQvy6zISQ4fQo=
|
||||||
-----END RSA PRIVATE KEY-----`)
|
-----END RSA PRIVATE KEY-----`)
|
||||||
|
|
||||||
func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server, *Server) {
|
func newTestServer(t *testing.T, ctx context.Context, updateConfig func(c *Config)) (*httptest.Server, *Server) {
|
||||||
var server *Server
|
var server *Server
|
||||||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
server.ServeHTTP(w, r)
|
server.ServeHTTP(w, r)
|
||||||
|
@ -91,7 +91,7 @@ func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server
|
||||||
s.URL = config.Issuer
|
s.URL = config.Issuer
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
if server, err = newServer(config, staticRotationStrategy(testKey)); err != nil {
|
if server, err = newServer(ctx, config, staticRotationStrategy(testKey)); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
server.skipApproval = true // Don't prompt for approval, just immediately redirect with code.
|
server.skipApproval = true // Don't prompt for approval, just immediately redirect with code.
|
||||||
|
@ -99,14 +99,16 @@ func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewTestServer(t *testing.T) {
|
func TestNewTestServer(t *testing.T) {
|
||||||
newTestServer(t, nil)
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
newTestServer(t, ctx, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDiscovery(t *testing.T) {
|
func TestDiscovery(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
httpServer, _ := newTestServer(t, func(c *Config) {
|
httpServer, _ := newTestServer(t, ctx, func(c *Config) {
|
||||||
c.Issuer = c.Issuer + "/non-root-path"
|
c.Issuer = c.Issuer + "/non-root-path"
|
||||||
})
|
})
|
||||||
defer httpServer.Close()
|
defer httpServer.Close()
|
||||||
|
@ -227,7 +229,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
httpServer, s := newTestServer(t, func(c *Config) {
|
httpServer, s := newTestServer(t, ctx, func(c *Config) {
|
||||||
c.Issuer = c.Issuer + "/non-root-path"
|
c.Issuer = c.Issuer + "/non-root-path"
|
||||||
})
|
})
|
||||||
defer httpServer.Close()
|
defer httpServer.Close()
|
||||||
|
@ -340,7 +342,7 @@ func TestOAuth2ImplicitFlow(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
httpServer, s := newTestServer(t, func(c *Config) {
|
httpServer, s := newTestServer(t, ctx, func(c *Config) {
|
||||||
// Enable support for the implicit flow.
|
// Enable support for the implicit flow.
|
||||||
c.SupportedResponseTypes = []string{"code", "token"}
|
c.SupportedResponseTypes = []string{"code", "token"}
|
||||||
})
|
})
|
||||||
|
@ -470,7 +472,7 @@ func TestCrossClientScopes(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
httpServer, s := newTestServer(t, func(c *Config) {
|
httpServer, s := newTestServer(t, ctx, func(c *Config) {
|
||||||
c.Issuer = c.Issuer + "/non-root-path"
|
c.Issuer = c.Issuer + "/non-root-path"
|
||||||
})
|
})
|
||||||
defer httpServer.Close()
|
defer httpServer.Close()
|
||||||
|
|
Loading…
Reference in a new issue