From 5c6ddbb6dc09020bd683b1a3503ea55261401be4 Mon Sep 17 00:00:00 2001 From: Eric Chiang Date: Wed, 1 Mar 2017 12:03:28 -0800 Subject: [PATCH] server: fix expiry detection for verification keys --- server/rotation.go | 29 ++++++++---- server/rotation_test.go | 100 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 119 insertions(+), 10 deletions(-) diff --git a/server/rotation.go b/server/rotation.go index 8516e9bd..fb790c62 100644 --- a/server/rotation.go +++ b/server/rotation.go @@ -22,8 +22,9 @@ type rotationStrategy struct { // Time between rotations. rotationFrequency time.Duration - // After being rotated how long can a key validate signatues? - verifyFor time.Duration + // After being rotated how long should the key be kept around for validating + // signatues? + idTokenValidFor time.Duration // Keys are always RSA keys. Though cryptopasta recommends ECDSA keys, not every // client may support these (e.g. github.com/coreos/go-oidc/oidc). @@ -35,17 +36,17 @@ func staticRotationStrategy(key *rsa.PrivateKey) rotationStrategy { return rotationStrategy{ // Setting these values to 100 years is easier than having a flag indicating no rotation. rotationFrequency: time.Hour * 8760 * 100, - verifyFor: time.Hour * 8760 * 100, + idTokenValidFor: time.Hour * 8760 * 100, key: func() (*rsa.PrivateKey, error) { return key, nil }, } } // defaultRotationStrategy returns a strategy which rotates keys every provided period, // holding onto the public parts for some specified amount of time. -func defaultRotationStrategy(rotationFrequency, verifyFor time.Duration) rotationStrategy { +func defaultRotationStrategy(rotationFrequency, idTokenValidFor time.Duration) rotationStrategy { return rotationStrategy{ rotationFrequency: rotationFrequency, - verifyFor: verifyFor, + idTokenValidFor: idTokenValidFor, key: func() (*rsa.PrivateKey, error) { return rsa.GenerateKey(rand.Reader, 2048) }, @@ -128,11 +129,14 @@ func (k keyRotater) rotate() error { return storage.Keys{}, errors.New("keys already rotated") } - // Remove expired verification keys. - i := 0 + expired := func(key storage.VerificationKey) bool { + return tNow.After(key.Expiry) + } + // Remove any verification keys that have expired. + i := 0 for _, key := range keys.VerificationKeys { - if !key.Expiry.After(tNow) { + if !expired(key) { keys.VerificationKeys[i] = key i++ } @@ -140,10 +144,15 @@ func (k keyRotater) rotate() error { keys.VerificationKeys = keys.VerificationKeys[:i] if keys.SigningKeyPub != nil { - // Move current signing key to a verification only key. + // Move current signing key to a verification only key, throwing + // away the private part. verificationKey := storage.VerificationKey{ PublicKey: keys.SigningKeyPub, - Expiry: tNow.Add(k.strategy.verifyFor), + // After demoting the signing key, keep the token around for at least + // the amount of time an ID Token is valid for. This ensures the + // verification key won't expire until all ID Tokens it's signed + // expired as well. + Expiry: tNow.Add(k.strategy.idTokenValidFor), } keys.VerificationKeys = append(keys.VerificationKeys, verificationKey) } diff --git a/server/rotation_test.go b/server/rotation_test.go index abb4e431..a792d7b6 100644 --- a/server/rotation_test.go +++ b/server/rotation_test.go @@ -1 +1,101 @@ package server + +import ( + "os" + "sort" + "testing" + "time" + + "github.com/Sirupsen/logrus" + "github.com/coreos/dex/storage" + "github.com/coreos/dex/storage/memory" +) + +func signingKeyID(t *testing.T, s storage.Storage) string { + keys, err := s.GetKeys() + if err != nil { + t.Fatal(err) + } + return keys.SigningKey.KeyID +} + +func verificationKeyIDs(t *testing.T, s storage.Storage) (ids []string) { + keys, err := s.GetKeys() + if err != nil { + t.Fatal(err) + } + for _, key := range keys.VerificationKeys { + ids = append(ids, key.PublicKey.KeyID) + } + return ids +} + +// slicesEq compare two string slices without modifying the ordering +// of the slices. +func slicesEq(s1, s2 []string) bool { + if len(s1) != len(s2) { + return false + } + + cp := func(s []string) []string { + c := make([]string, len(s)) + copy(c, s) + return c + } + + cp1 := cp(s1) + cp2 := cp(s2) + sort.Strings(cp1) + sort.Strings(cp2) + + for i, el := range cp1 { + if el != cp2[i] { + return false + } + } + return true +} + +func TestKeyRotater(t *testing.T) { + now := time.Now() + + delta := time.Millisecond + rotationFrequency := time.Second * 5 + validFor := time.Second * 21 + + // Only the last 5 verification keys are expected to be kept around. + maxVerificationKeys := 5 + + l := &logrus.Logger{ + Out: os.Stderr, + Formatter: &logrus.TextFormatter{DisableColors: true}, + Level: logrus.DebugLevel, + } + + r := &keyRotater{ + Storage: memory.New(l), + strategy: defaultRotationStrategy(rotationFrequency, validFor), + now: func() time.Time { return now }, + logger: l, + } + + var expVerificationKeys []string + + for i := 0; i < 10; i++ { + now = now.Add(rotationFrequency + delta) + if err := r.rotate(); err != nil { + t.Fatal(err) + } + + got := verificationKeyIDs(t, r.Storage) + + if !slicesEq(expVerificationKeys, got) { + t.Errorf("after %d rotation, expected varification keys %q, got %q", i+1, expVerificationKeys, got) + } + + expVerificationKeys = append(expVerificationKeys, signingKeyID(t, r.Storage)) + if n := len(expVerificationKeys); n > maxVerificationKeys { + expVerificationKeys = expVerificationKeys[n-maxVerificationKeys:] + } + } +}