server: cache signing keys
This commit is contained in:
parent
d313e5d493
commit
4cbe9bbc82
2 changed files with 112 additions and 3 deletions
|
@ -6,6 +6,7 @@ import (
|
|||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
@ -93,9 +94,14 @@ func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) {
|
|||
}
|
||||
|
||||
s := &Server{
|
||||
issuerURL: *issuerURL,
|
||||
connectors: make(map[string]Connector),
|
||||
storage: storageWithKeyRotation(c.Storage, rotationStrategy, now),
|
||||
issuerURL: *issuerURL,
|
||||
connectors: make(map[string]Connector),
|
||||
storage: newKeyCacher(
|
||||
storageWithKeyRotation(
|
||||
c.Storage, rotationStrategy, now,
|
||||
),
|
||||
now,
|
||||
),
|
||||
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
|
||||
now: now,
|
||||
}
|
||||
|
@ -139,3 +145,35 @@ func (s *Server) absURL(pathItems ...string) string {
|
|||
u.Path = s.absPath(pathItems...)
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// newKeyCacher returns a storage which caches keys so long as the next
|
||||
func newKeyCacher(s storage.Storage, now func() time.Time) storage.Storage {
|
||||
if now == nil {
|
||||
now = time.Now
|
||||
}
|
||||
return &keyCacher{Storage: s, now: now}
|
||||
}
|
||||
|
||||
type keyCacher struct {
|
||||
storage.Storage
|
||||
|
||||
now func() time.Time
|
||||
keys atomic.Value // Always holds nil or type *storage.Keys.
|
||||
}
|
||||
|
||||
func (k *keyCacher) GetKeys() (storage.Keys, error) {
|
||||
keys, ok := k.keys.Load().(*storage.Keys)
|
||||
if ok && keys != nil && k.now().Before(keys.NextRotation) {
|
||||
return *keys, nil
|
||||
}
|
||||
|
||||
storageKeys, err := k.Storage.GetKeys()
|
||||
if err != nil {
|
||||
return storageKeys, err
|
||||
}
|
||||
|
||||
if k.now().Before(storageKeys.NextRotation) {
|
||||
k.keys.Store(&storageKeys)
|
||||
}
|
||||
return storageKeys, nil
|
||||
}
|
||||
|
|
|
@ -219,3 +219,74 @@ func TestOAuth2Flow(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
type storageWithKeysTrigger struct {
|
||||
storage.Storage
|
||||
f func()
|
||||
}
|
||||
|
||||
func (s storageWithKeysTrigger) GetKeys() (storage.Keys, error) {
|
||||
s.f()
|
||||
return s.Storage.GetKeys()
|
||||
}
|
||||
|
||||
func TestKeyCacher(t *testing.T) {
|
||||
tNow := time.Now()
|
||||
now := func() time.Time { return tNow }
|
||||
|
||||
s := memory.New()
|
||||
|
||||
tests := []struct {
|
||||
before func()
|
||||
wantCallToStorage bool
|
||||
}{
|
||||
{
|
||||
before: func() {},
|
||||
wantCallToStorage: true,
|
||||
},
|
||||
{
|
||||
before: func() {
|
||||
s.UpdateKeys(func(old storage.Keys) (storage.Keys, error) {
|
||||
old.NextRotation = tNow.Add(time.Minute)
|
||||
return old, nil
|
||||
})
|
||||
},
|
||||
wantCallToStorage: true,
|
||||
},
|
||||
{
|
||||
before: func() {},
|
||||
wantCallToStorage: false,
|
||||
},
|
||||
{
|
||||
before: func() {
|
||||
tNow = tNow.Add(time.Hour)
|
||||
},
|
||||
wantCallToStorage: true,
|
||||
},
|
||||
{
|
||||
before: func() {
|
||||
tNow = tNow.Add(time.Hour)
|
||||
s.UpdateKeys(func(old storage.Keys) (storage.Keys, error) {
|
||||
old.NextRotation = tNow.Add(time.Minute)
|
||||
return old, nil
|
||||
})
|
||||
},
|
||||
wantCallToStorage: true,
|
||||
},
|
||||
{
|
||||
before: func() {},
|
||||
wantCallToStorage: false,
|
||||
},
|
||||
}
|
||||
|
||||
gotCall := false
|
||||
s = newKeyCacher(storageWithKeysTrigger{s, func() { gotCall = true }}, now)
|
||||
for i, tc := range tests {
|
||||
gotCall = false
|
||||
tc.before()
|
||||
s.GetKeys()
|
||||
if gotCall != tc.wantCallToStorage {
|
||||
t.Errorf("case %d: expected call to storage=%t got call to storage=%t", i, tc.wantCallToStorage, gotCall)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Reference in a new issue