forked from mystiq/dex
225 lines
4.5 KiB
Go
225 lines
4.5 KiB
Go
package key
|
|
|
|
import (
|
|
"crypto/rsa"
|
|
"math/big"
|
|
"reflect"
|
|
"strconv"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/jonboulle/clockwork"
|
|
|
|
"github.com/coreos/go-oidc/jose"
|
|
)
|
|
|
|
var (
|
|
jwk1 jose.JWK
|
|
jwk2 jose.JWK
|
|
jwk3 jose.JWK
|
|
)
|
|
|
|
func init() {
|
|
jwk1 = jose.JWK{
|
|
ID: "1",
|
|
Type: "RSA",
|
|
Alg: "RS256",
|
|
Use: "sig",
|
|
Modulus: big.NewInt(1),
|
|
Exponent: 65537,
|
|
}
|
|
|
|
jwk2 = jose.JWK{
|
|
ID: "2",
|
|
Type: "RSA",
|
|
Alg: "RS256",
|
|
Use: "sig",
|
|
Modulus: big.NewInt(2),
|
|
Exponent: 65537,
|
|
}
|
|
|
|
jwk3 = jose.JWK{
|
|
ID: "3",
|
|
Type: "RSA",
|
|
Alg: "RS256",
|
|
Use: "sig",
|
|
Modulus: big.NewInt(3),
|
|
Exponent: 65537,
|
|
}
|
|
}
|
|
|
|
func generatePrivateKeyStatic(t *testing.T, idAndN int) *PrivateKey {
|
|
n := big.NewInt(int64(idAndN))
|
|
if n == nil {
|
|
t.Fatalf("Call to NewInt(%d) failed", idAndN)
|
|
}
|
|
|
|
pk := &rsa.PrivateKey{
|
|
PublicKey: rsa.PublicKey{N: n, E: 65537},
|
|
}
|
|
|
|
return &PrivateKey{
|
|
KeyID: strconv.Itoa(idAndN),
|
|
PrivateKey: pk,
|
|
}
|
|
}
|
|
|
|
func TestPrivateKeyManagerJWKsRotate(t *testing.T) {
|
|
k1 := generatePrivateKeyStatic(t, 1)
|
|
k2 := generatePrivateKeyStatic(t, 2)
|
|
k3 := generatePrivateKeyStatic(t, 3)
|
|
km := NewPrivateKeyManager()
|
|
err := km.Set(&PrivateKeySet{
|
|
keys: []*PrivateKey{k1, k2, k3},
|
|
ActiveKeyID: k1.KeyID,
|
|
expiresAt: time.Now().Add(time.Minute),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
want := []jose.JWK{jwk1, jwk2, jwk3}
|
|
got, err := km.JWKs()
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
if !reflect.DeepEqual(want, got) {
|
|
t.Fatalf("JWK mismatch: want=%#v got=%#v", want, got)
|
|
}
|
|
}
|
|
|
|
func TestPrivateKeyManagerSigner(t *testing.T) {
|
|
k := generatePrivateKeyStatic(t, 13)
|
|
|
|
km := NewPrivateKeyManager()
|
|
err := km.Set(&PrivateKeySet{
|
|
keys: []*PrivateKey{k},
|
|
ActiveKeyID: k.KeyID,
|
|
expiresAt: time.Now().Add(time.Minute),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
signer, err := km.Signer()
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
wantID := "13"
|
|
gotID := signer.ID()
|
|
if wantID != gotID {
|
|
t.Fatalf("Signer has incorrect ID: want=%s got=%s", wantID, gotID)
|
|
}
|
|
}
|
|
|
|
func TestPrivateKeyManagerHealthyFail(t *testing.T) {
|
|
keyFixture := generatePrivateKeyStatic(t, 1)
|
|
tests := []*privateKeyManager{
|
|
// keySet nil
|
|
&privateKeyManager{
|
|
keySet: nil,
|
|
clock: clockwork.NewRealClock(),
|
|
},
|
|
// zero keys
|
|
&privateKeyManager{
|
|
keySet: &PrivateKeySet{
|
|
keys: []*PrivateKey{},
|
|
expiresAt: time.Now().Add(time.Minute),
|
|
},
|
|
clock: clockwork.NewRealClock(),
|
|
},
|
|
// key set expired
|
|
&privateKeyManager{
|
|
keySet: &PrivateKeySet{
|
|
keys: []*PrivateKey{keyFixture},
|
|
expiresAt: time.Now().Add(-1 * time.Minute),
|
|
},
|
|
clock: clockwork.NewRealClock(),
|
|
},
|
|
}
|
|
|
|
for i, tt := range tests {
|
|
if err := tt.Healthy(); err == nil {
|
|
t.Errorf("case %d: nil error", i)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestPrivateKeyManagerHealthyFailsOtherMethods(t *testing.T) {
|
|
km := NewPrivateKeyManager()
|
|
if _, err := km.JWKs(); err == nil {
|
|
t.Fatalf("Expected non-nil error")
|
|
}
|
|
if _, err := km.Signer(); err == nil {
|
|
t.Fatalf("Expected non-nil error")
|
|
}
|
|
}
|
|
|
|
func TestPrivateKeyManagerExpiresAt(t *testing.T) {
|
|
fc := clockwork.NewFakeClock()
|
|
now := fc.Now().UTC()
|
|
|
|
k := generatePrivateKeyStatic(t, 17)
|
|
km := &privateKeyManager{
|
|
clock: fc,
|
|
}
|
|
|
|
want := fc.Now().UTC()
|
|
got := km.ExpiresAt()
|
|
if want != got {
|
|
t.Fatalf("Incorrect expiration time: want=%v got=%v", want, got)
|
|
}
|
|
|
|
err := km.Set(&PrivateKeySet{
|
|
keys: []*PrivateKey{k},
|
|
ActiveKeyID: k.KeyID,
|
|
expiresAt: now.Add(2 * time.Minute),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
want = fc.Now().UTC().Add(2 * time.Minute)
|
|
got = km.ExpiresAt()
|
|
if want != got {
|
|
t.Fatalf("Incorrect expiration time: want=%v got=%v", want, got)
|
|
}
|
|
}
|
|
|
|
func TestPublicKeys(t *testing.T) {
|
|
km := NewPrivateKeyManager()
|
|
k1 := generatePrivateKeyStatic(t, 1)
|
|
k2 := generatePrivateKeyStatic(t, 2)
|
|
k3 := generatePrivateKeyStatic(t, 3)
|
|
|
|
tests := [][]*PrivateKey{
|
|
[]*PrivateKey{k1},
|
|
[]*PrivateKey{k1, k2},
|
|
[]*PrivateKey{k1, k2, k3},
|
|
}
|
|
|
|
for i, tt := range tests {
|
|
ks := &PrivateKeySet{
|
|
keys: tt,
|
|
expiresAt: time.Now().Add(time.Hour),
|
|
}
|
|
km.Set(ks)
|
|
|
|
jwks, err := km.JWKs()
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
pks := NewPublicKeySet(jwks, time.Now().Add(time.Hour))
|
|
want := pks.Keys()
|
|
got, err := km.PublicKeys()
|
|
if err != nil {
|
|
t.Fatalf("Unexpected error: %v", err)
|
|
}
|
|
|
|
if !reflect.DeepEqual(want, got) {
|
|
t.Errorf("case %d: Invalid public keys: want=%v got=%v", i, want, got)
|
|
}
|
|
}
|
|
}
|