forked from mystiq/dex
server: cache signing keys
This commit is contained in:
parent
d313e5d493
commit
4cbe9bbc82
2 changed files with 112 additions and 3 deletions
|
@ -6,6 +6,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"path"
|
"path"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
@ -93,9 +94,14 @@ func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
s := &Server{
|
s := &Server{
|
||||||
issuerURL: *issuerURL,
|
issuerURL: *issuerURL,
|
||||||
connectors: make(map[string]Connector),
|
connectors: make(map[string]Connector),
|
||||||
storage: storageWithKeyRotation(c.Storage, rotationStrategy, now),
|
storage: newKeyCacher(
|
||||||
|
storageWithKeyRotation(
|
||||||
|
c.Storage, rotationStrategy, now,
|
||||||
|
),
|
||||||
|
now,
|
||||||
|
),
|
||||||
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
|
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
|
||||||
now: now,
|
now: now,
|
||||||
}
|
}
|
||||||
|
@ -139,3 +145,35 @@ func (s *Server) absURL(pathItems ...string) string {
|
||||||
u.Path = s.absPath(pathItems...)
|
u.Path = s.absPath(pathItems...)
|
||||||
return u.String()
|
return u.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// newKeyCacher returns a storage which caches keys so long as the next
|
||||||
|
func newKeyCacher(s storage.Storage, now func() time.Time) storage.Storage {
|
||||||
|
if now == nil {
|
||||||
|
now = time.Now
|
||||||
|
}
|
||||||
|
return &keyCacher{Storage: s, now: now}
|
||||||
|
}
|
||||||
|
|
||||||
|
type keyCacher struct {
|
||||||
|
storage.Storage
|
||||||
|
|
||||||
|
now func() time.Time
|
||||||
|
keys atomic.Value // Always holds nil or type *storage.Keys.
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *keyCacher) GetKeys() (storage.Keys, error) {
|
||||||
|
keys, ok := k.keys.Load().(*storage.Keys)
|
||||||
|
if ok && keys != nil && k.now().Before(keys.NextRotation) {
|
||||||
|
return *keys, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
storageKeys, err := k.Storage.GetKeys()
|
||||||
|
if err != nil {
|
||||||
|
return storageKeys, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if k.now().Before(storageKeys.NextRotation) {
|
||||||
|
k.keys.Store(&storageKeys)
|
||||||
|
}
|
||||||
|
return storageKeys, nil
|
||||||
|
}
|
||||||
|
|
|
@ -219,3 +219,74 @@ func TestOAuth2Flow(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type storageWithKeysTrigger struct {
|
||||||
|
storage.Storage
|
||||||
|
f func()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s storageWithKeysTrigger) GetKeys() (storage.Keys, error) {
|
||||||
|
s.f()
|
||||||
|
return s.Storage.GetKeys()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeyCacher(t *testing.T) {
|
||||||
|
tNow := time.Now()
|
||||||
|
now := func() time.Time { return tNow }
|
||||||
|
|
||||||
|
s := memory.New()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
before func()
|
||||||
|
wantCallToStorage bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
before: func() {},
|
||||||
|
wantCallToStorage: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
before: func() {
|
||||||
|
s.UpdateKeys(func(old storage.Keys) (storage.Keys, error) {
|
||||||
|
old.NextRotation = tNow.Add(time.Minute)
|
||||||
|
return old, nil
|
||||||
|
})
|
||||||
|
},
|
||||||
|
wantCallToStorage: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
before: func() {},
|
||||||
|
wantCallToStorage: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
before: func() {
|
||||||
|
tNow = tNow.Add(time.Hour)
|
||||||
|
},
|
||||||
|
wantCallToStorage: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
before: func() {
|
||||||
|
tNow = tNow.Add(time.Hour)
|
||||||
|
s.UpdateKeys(func(old storage.Keys) (storage.Keys, error) {
|
||||||
|
old.NextRotation = tNow.Add(time.Minute)
|
||||||
|
return old, nil
|
||||||
|
})
|
||||||
|
},
|
||||||
|
wantCallToStorage: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
before: func() {},
|
||||||
|
wantCallToStorage: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
gotCall := false
|
||||||
|
s = newKeyCacher(storageWithKeysTrigger{s, func() { gotCall = true }}, now)
|
||||||
|
for i, tc := range tests {
|
||||||
|
gotCall = false
|
||||||
|
tc.before()
|
||||||
|
s.GetKeys()
|
||||||
|
if gotCall != tc.wantCallToStorage {
|
||||||
|
t.Errorf("case %d: expected call to storage=%t got call to storage=%t", i, tc.wantCallToStorage, gotCall)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue