Compare commits
10 commits
Author | SHA1 | Date | |
---|---|---|---|
11d4d78d6b | |||
5586a3c1fd | |||
29af9b4c75 | |||
5c6ddbb6dc | |||
1bb4edb25c | |||
0d894e3186 | |||
565805b716 | |||
2cfcdfb80f | |||
f8aec4c1c5 | |||
7968f283f2 |
|
@ -22,8 +22,9 @@ type rotationStrategy struct {
|
||||||
// Time between rotations.
|
// Time between rotations.
|
||||||
rotationFrequency time.Duration
|
rotationFrequency time.Duration
|
||||||
|
|
||||||
// After being rotated how long can a key validate signatues?
|
// After being rotated how long should the key be kept around for validating
|
||||||
verifyFor time.Duration
|
// signatues?
|
||||||
|
idTokenValidFor time.Duration
|
||||||
|
|
||||||
// Keys are always RSA keys. Though cryptopasta recommends ECDSA keys, not every
|
// Keys are always RSA keys. Though cryptopasta recommends ECDSA keys, not every
|
||||||
// client may support these (e.g. github.com/coreos/go-oidc/oidc).
|
// client may support these (e.g. github.com/coreos/go-oidc/oidc).
|
||||||
|
@ -35,17 +36,17 @@ func staticRotationStrategy(key *rsa.PrivateKey) rotationStrategy {
|
||||||
return rotationStrategy{
|
return rotationStrategy{
|
||||||
// Setting these values to 100 years is easier than having a flag indicating no rotation.
|
// Setting these values to 100 years is easier than having a flag indicating no rotation.
|
||||||
rotationFrequency: time.Hour * 8760 * 100,
|
rotationFrequency: time.Hour * 8760 * 100,
|
||||||
verifyFor: time.Hour * 8760 * 100,
|
idTokenValidFor: time.Hour * 8760 * 100,
|
||||||
key: func() (*rsa.PrivateKey, error) { return key, nil },
|
key: func() (*rsa.PrivateKey, error) { return key, nil },
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// defaultRotationStrategy returns a strategy which rotates keys every provided period,
|
// defaultRotationStrategy returns a strategy which rotates keys every provided period,
|
||||||
// holding onto the public parts for some specified amount of time.
|
// holding onto the public parts for some specified amount of time.
|
||||||
func defaultRotationStrategy(rotationFrequency, verifyFor time.Duration) rotationStrategy {
|
func defaultRotationStrategy(rotationFrequency, idTokenValidFor time.Duration) rotationStrategy {
|
||||||
return rotationStrategy{
|
return rotationStrategy{
|
||||||
rotationFrequency: rotationFrequency,
|
rotationFrequency: rotationFrequency,
|
||||||
verifyFor: verifyFor,
|
idTokenValidFor: idTokenValidFor,
|
||||||
key: func() (*rsa.PrivateKey, error) {
|
key: func() (*rsa.PrivateKey, error) {
|
||||||
return rsa.GenerateKey(rand.Reader, 2048)
|
return rsa.GenerateKey(rand.Reader, 2048)
|
||||||
},
|
},
|
||||||
|
@ -128,11 +129,14 @@ func (k keyRotater) rotate() error {
|
||||||
return storage.Keys{}, errors.New("keys already rotated")
|
return storage.Keys{}, errors.New("keys already rotated")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove expired verification keys.
|
expired := func(key storage.VerificationKey) bool {
|
||||||
i := 0
|
return tNow.After(key.Expiry)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove any verification keys that have expired.
|
||||||
|
i := 0
|
||||||
for _, key := range keys.VerificationKeys {
|
for _, key := range keys.VerificationKeys {
|
||||||
if !key.Expiry.After(tNow) {
|
if !expired(key) {
|
||||||
keys.VerificationKeys[i] = key
|
keys.VerificationKeys[i] = key
|
||||||
i++
|
i++
|
||||||
}
|
}
|
||||||
|
@ -140,10 +144,15 @@ func (k keyRotater) rotate() error {
|
||||||
keys.VerificationKeys = keys.VerificationKeys[:i]
|
keys.VerificationKeys = keys.VerificationKeys[:i]
|
||||||
|
|
||||||
if keys.SigningKeyPub != nil {
|
if keys.SigningKeyPub != nil {
|
||||||
// Move current signing key to a verification only key.
|
// Move current signing key to a verification only key, throwing
|
||||||
|
// away the private part.
|
||||||
verificationKey := storage.VerificationKey{
|
verificationKey := storage.VerificationKey{
|
||||||
PublicKey: keys.SigningKeyPub,
|
PublicKey: keys.SigningKeyPub,
|
||||||
Expiry: tNow.Add(k.strategy.verifyFor),
|
// After demoting the signing key, keep the token around for at least
|
||||||
|
// the amount of time an ID Token is valid for. This ensures the
|
||||||
|
// verification key won't expire until all ID Tokens it's signed
|
||||||
|
// expired as well.
|
||||||
|
Expiry: tNow.Add(k.strategy.idTokenValidFor),
|
||||||
}
|
}
|
||||||
keys.VerificationKeys = append(keys.VerificationKeys, verificationKey)
|
keys.VerificationKeys = append(keys.VerificationKeys, verificationKey)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1 +1,101 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"sort"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Sirupsen/logrus"
|
||||||
|
"github.com/coreos/dex/storage"
|
||||||
|
"github.com/coreos/dex/storage/memory"
|
||||||
|
)
|
||||||
|
|
||||||
|
func signingKeyID(t *testing.T, s storage.Storage) string {
|
||||||
|
keys, err := s.GetKeys()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
return keys.SigningKey.KeyID
|
||||||
|
}
|
||||||
|
|
||||||
|
func verificationKeyIDs(t *testing.T, s storage.Storage) (ids []string) {
|
||||||
|
keys, err := s.GetKeys()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
for _, key := range keys.VerificationKeys {
|
||||||
|
ids = append(ids, key.PublicKey.KeyID)
|
||||||
|
}
|
||||||
|
return ids
|
||||||
|
}
|
||||||
|
|
||||||
|
// slicesEq compare two string slices without modifying the ordering
|
||||||
|
// of the slices.
|
||||||
|
func slicesEq(s1, s2 []string) bool {
|
||||||
|
if len(s1) != len(s2) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
cp := func(s []string) []string {
|
||||||
|
c := make([]string, len(s))
|
||||||
|
copy(c, s)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
cp1 := cp(s1)
|
||||||
|
cp2 := cp(s2)
|
||||||
|
sort.Strings(cp1)
|
||||||
|
sort.Strings(cp2)
|
||||||
|
|
||||||
|
for i, el := range cp1 {
|
||||||
|
if el != cp2[i] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeyRotater(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
delta := time.Millisecond
|
||||||
|
rotationFrequency := time.Second * 5
|
||||||
|
validFor := time.Second * 21
|
||||||
|
|
||||||
|
// Only the last 5 verification keys are expected to be kept around.
|
||||||
|
maxVerificationKeys := 5
|
||||||
|
|
||||||
|
l := &logrus.Logger{
|
||||||
|
Out: os.Stderr,
|
||||||
|
Formatter: &logrus.TextFormatter{DisableColors: true},
|
||||||
|
Level: logrus.DebugLevel,
|
||||||
|
}
|
||||||
|
|
||||||
|
r := &keyRotater{
|
||||||
|
Storage: memory.New(l),
|
||||||
|
strategy: defaultRotationStrategy(rotationFrequency, validFor),
|
||||||
|
now: func() time.Time { return now },
|
||||||
|
logger: l,
|
||||||
|
}
|
||||||
|
|
||||||
|
var expVerificationKeys []string
|
||||||
|
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
now = now.Add(rotationFrequency + delta)
|
||||||
|
if err := r.rotate(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got := verificationKeyIDs(t, r.Storage)
|
||||||
|
|
||||||
|
if !slicesEq(expVerificationKeys, got) {
|
||||||
|
t.Errorf("after %d rotation, expected varification keys %q, got %q", i+1, expVerificationKeys, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
expVerificationKeys = append(expVerificationKeys, signingKeyID(t, r.Storage))
|
||||||
|
if n := len(expVerificationKeys); n > maxVerificationKeys {
|
||||||
|
expVerificationKeys = expVerificationKeys[n-maxVerificationKeys:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -269,6 +269,32 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
|
||||||
|
|
||||||
getAndCompare(id, refresh)
|
getAndCompare(id, refresh)
|
||||||
|
|
||||||
|
id2 := storage.NewID()
|
||||||
|
refresh2 := storage.RefreshToken{
|
||||||
|
ID: id2,
|
||||||
|
Token: "bar_2",
|
||||||
|
Nonce: "foo_2",
|
||||||
|
ClientID: "client_id_2",
|
||||||
|
ConnectorID: "client_secret",
|
||||||
|
Scopes: []string{"openid", "email", "profile"},
|
||||||
|
CreatedAt: time.Now().UTC().Round(time.Millisecond),
|
||||||
|
LastUsed: time.Now().UTC().Round(time.Millisecond),
|
||||||
|
Claims: storage.Claims{
|
||||||
|
UserID: "2",
|
||||||
|
Username: "john",
|
||||||
|
Email: "john.doe@example.com",
|
||||||
|
EmailVerified: true,
|
||||||
|
Groups: []string{"a", "b"},
|
||||||
|
},
|
||||||
|
ConnectorData: []byte(`{"some":"data"}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.CreateRefresh(refresh2); err != nil {
|
||||||
|
t.Fatalf("create second refresh token: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
getAndCompare(id2, refresh2)
|
||||||
|
|
||||||
updatedAt := time.Now().UTC().Round(time.Millisecond)
|
updatedAt := time.Now().UTC().Round(time.Millisecond)
|
||||||
|
|
||||||
updater := func(r storage.RefreshToken) (storage.RefreshToken, error) {
|
updater := func(r storage.RefreshToken) (storage.RefreshToken, error) {
|
||||||
|
@ -283,6 +309,9 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
|
||||||
refresh.LastUsed = updatedAt
|
refresh.LastUsed = updatedAt
|
||||||
getAndCompare(id, refresh)
|
getAndCompare(id, refresh)
|
||||||
|
|
||||||
|
// Ensure that updating the first token doesn't impact the second. Issue #847.
|
||||||
|
getAndCompare(id2, refresh2)
|
||||||
|
|
||||||
if err := s.DeleteRefresh(id); err != nil {
|
if err := s.DeleteRefresh(id); err != nil {
|
||||||
t.Fatalf("failed to delete refresh request: %v", err)
|
t.Fatalf("failed to delete refresh request: %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -72,9 +72,10 @@ func idToName(s string, h func() hash.Hash) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
func offlineTokenName(userID string, connID string, h func() hash.Hash) string {
|
func offlineTokenName(userID string, connID string, h func() hash.Hash) string {
|
||||||
h().Write([]byte(userID))
|
hash := h()
|
||||||
h().Write([]byte(connID))
|
hash.Write([]byte(userID))
|
||||||
return strings.TrimRight(encoding.EncodeToString(h().Sum(nil)), "=")
|
hash.Write([]byte(connID))
|
||||||
|
return strings.TrimRight(encoding.EncodeToString(hash.Sum(nil)), "=")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) urlFor(apiVersion, namespace, resource, name string) string {
|
func (c *client) urlFor(apiVersion, namespace, resource, name string) string {
|
||||||
|
@ -135,6 +136,9 @@ func checkHTTPErr(r *http.Response, validStatusCodes ...int) error {
|
||||||
if r.StatusCode == http.StatusNotFound {
|
if r.StatusCode == http.StatusNotFound {
|
||||||
return storage.ErrNotFound
|
return storage.ErrNotFound
|
||||||
}
|
}
|
||||||
|
if r.Request.Method == "POST" && r.StatusCode == http.StatusConflict {
|
||||||
|
return storage.ErrAlreadyExists
|
||||||
|
}
|
||||||
|
|
||||||
var url, method string
|
var url, method string
|
||||||
if r.Request != nil {
|
if r.Request != nil {
|
||||||
|
|
|
@ -29,6 +29,19 @@ func TestIDToName(t *testing.T) {
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOfflineTokenName(t *testing.T) {
|
||||||
|
h := func() hash.Hash { return fnv.New64() }
|
||||||
|
|
||||||
|
userID1 := "john"
|
||||||
|
userID2 := "jane"
|
||||||
|
|
||||||
|
id1 := offlineTokenName(userID1, "local", h)
|
||||||
|
id2 := offlineTokenName(userID2, "local", h)
|
||||||
|
if id1 == id2 {
|
||||||
|
t.Errorf("expected offlineTokenName to produce different hashes")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestNamespaceFromServiceAccountJWT(t *testing.T) {
|
func TestNamespaceFromServiceAccountJWT(t *testing.T) {
|
||||||
namespace, err := namespaceFromServiceAccountJWT(serviceAccountToken)
|
namespace, err := namespaceFromServiceAccountJWT(serviceAccountToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -3,7 +3,6 @@ package kubernetes
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -31,7 +30,7 @@ const (
|
||||||
resourceRefreshToken = "refreshtokens"
|
resourceRefreshToken = "refreshtokens"
|
||||||
resourceKeys = "signingkeies" // Kubernetes attempts to pluralize.
|
resourceKeys = "signingkeies" // Kubernetes attempts to pluralize.
|
||||||
resourcePassword = "passwords"
|
resourcePassword = "passwords"
|
||||||
resourceOfflineSessions = "offlinesessions"
|
resourceOfflineSessions = "offlinesessionses" // Again attempts to pluralize.
|
||||||
)
|
)
|
||||||
|
|
||||||
// Config values for the Kubernetes storage type.
|
// Config values for the Kubernetes storage type.
|
||||||
|
@ -42,15 +41,19 @@ type Config struct {
|
||||||
|
|
||||||
// Open returns a storage using Kubernetes third party resource.
|
// Open returns a storage using Kubernetes third party resource.
|
||||||
func (c *Config) Open(logger logrus.FieldLogger) (storage.Storage, error) {
|
func (c *Config) Open(logger logrus.FieldLogger) (storage.Storage, error) {
|
||||||
cli, err := c.open(logger)
|
cli, err := c.open(logger, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return cli, nil
|
return cli, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// open returns a client with no garbage collection.
|
// open returns a kubernetes client, initializing the third party resources used
|
||||||
func (c *Config) open(logger logrus.FieldLogger) (*client, error) {
|
// by dex.
|
||||||
|
//
|
||||||
|
// errOnTPRs controls if errors creating the resources cause this method to return
|
||||||
|
// immediately (used during testing), or if the client will asynchronously retry.
|
||||||
|
func (c *Config) open(logger logrus.FieldLogger, errOnTPRs bool) (*client, error) {
|
||||||
if c.InCluster && (c.KubeConfigFile != "") {
|
if c.InCluster && (c.KubeConfigFile != "") {
|
||||||
return nil, errors.New("cannot specify both 'inCluster' and 'kubeConfigFile'")
|
return nil, errors.New("cannot specify both 'inCluster' and 'kubeConfigFile'")
|
||||||
}
|
}
|
||||||
|
@ -80,16 +83,18 @@ func (c *Config) open(logger logrus.FieldLogger) (*client, error) {
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
// Try to synchronously create the third party resources once. This doesn't mean
|
if !cli.createThirdPartyResources() {
|
||||||
// they'll immediately be available, but ensures that the client will actually try
|
if errOnTPRs {
|
||||||
// once.
|
return nil, fmt.Errorf("failed creating third party resources")
|
||||||
if err := cli.createThirdPartyResources(); err != nil {
|
}
|
||||||
|
|
||||||
|
// Try to synchronously create the third party resources once. This doesn't mean
|
||||||
|
// they'll immediately be available, but ensures that the client will actually try
|
||||||
|
// once.
|
||||||
logger.Errorf("failed creating third party resources: %v", err)
|
logger.Errorf("failed creating third party resources: %v", err)
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
if err := cli.createThirdPartyResources(); err != nil {
|
if cli.createThirdPartyResources() {
|
||||||
logger.Errorf("failed creating third party resources: %v", err)
|
|
||||||
} else {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -108,27 +113,33 @@ func (c *Config) open(logger logrus.FieldLogger) (*client, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// createThirdPartyResources attempts to create the third party resources dex
|
// createThirdPartyResources attempts to create the third party resources dex
|
||||||
// requires or identifies that they're already enabled.
|
// requires or identifies that they're already enabled. It logs all errors,
|
||||||
|
// returning true if the third party resources were created successfully.
|
||||||
//
|
//
|
||||||
// Creating a third party resource does not mean that they'll be immediately available.
|
// Creating a third party resource does not mean that they'll be immediately available.
|
||||||
//
|
//
|
||||||
// TODO(ericchiang): Provide an option to wait for the third party resources
|
// TODO(ericchiang): Provide an option to wait for the third party resources
|
||||||
// to actually be available.
|
// to actually be available.
|
||||||
func (cli *client) createThirdPartyResources() error {
|
func (cli *client) createThirdPartyResources() (ok bool) {
|
||||||
|
ok = true
|
||||||
for _, r := range thirdPartyResources {
|
for _, r := range thirdPartyResources {
|
||||||
err := cli.postResource("extensions/v1beta1", "", "thirdpartyresources", r)
|
err := cli.postResource("extensions/v1beta1", "", "thirdpartyresources", r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if e, ok := err.(httpError); ok {
|
switch err {
|
||||||
if e.StatusCode() == http.StatusConflict {
|
case storage.ErrAlreadyExists:
|
||||||
cli.logger.Errorf("third party resource already created %q", r.ObjectMeta.Name)
|
cli.logger.Errorf("third party resource already created %s", r.ObjectMeta.Name)
|
||||||
continue
|
case storage.ErrNotFound:
|
||||||
}
|
cli.logger.Errorf("third party resources not found, please enable API group extensions/v1beta1")
|
||||||
|
ok = false
|
||||||
|
default:
|
||||||
|
cli.logger.Errorf("creating third party resource %s: %v", r.ObjectMeta.Name, err)
|
||||||
|
ok = false
|
||||||
}
|
}
|
||||||
return err
|
continue
|
||||||
}
|
}
|
||||||
cli.logger.Errorf("create third party resource %q", r.ObjectMeta.Name)
|
cli.logger.Errorf("create third party resource %s", r.ObjectMeta.Name)
|
||||||
}
|
}
|
||||||
return nil
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cli *client) Close() error {
|
func (cli *client) Close() error {
|
||||||
|
|
|
@ -28,7 +28,7 @@ func loadClient(t *testing.T) *client {
|
||||||
Formatter: &logrus.TextFormatter{DisableColors: true},
|
Formatter: &logrus.TextFormatter{DisableColors: true},
|
||||||
Level: logrus.DebugLevel,
|
Level: logrus.DebugLevel,
|
||||||
}
|
}
|
||||||
s, err := config.open(logger)
|
s, err := config.open(logger, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -502,9 +502,14 @@ func (cli *client) fromStorageOfflineSessions(o storage.OfflineSessions) Offline
|
||||||
}
|
}
|
||||||
|
|
||||||
func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions {
|
func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions {
|
||||||
return storage.OfflineSessions{
|
s := storage.OfflineSessions{
|
||||||
UserID: o.UserID,
|
UserID: o.UserID,
|
||||||
ConnID: o.ConnID,
|
ConnID: o.ConnID,
|
||||||
Refresh: o.Refresh,
|
Refresh: o.Refresh,
|
||||||
}
|
}
|
||||||
|
if s.Refresh == nil {
|
||||||
|
// Server code assumes this will be non-nil.
|
||||||
|
s.Refresh = make(map[string]*storage.RefreshTokenRef)
|
||||||
|
}
|
||||||
|
return s
|
||||||
}
|
}
|
||||||
|
|
|
@ -299,12 +299,14 @@ func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshTok
|
||||||
token = $11,
|
token = $11,
|
||||||
created_at = $12,
|
created_at = $12,
|
||||||
last_used = $13
|
last_used = $13
|
||||||
|
where
|
||||||
|
id = $14
|
||||||
`,
|
`,
|
||||||
r.ClientID, encoder(r.Scopes), r.Nonce,
|
r.ClientID, encoder(r.Scopes), r.Nonce,
|
||||||
r.Claims.UserID, r.Claims.Username, r.Claims.Email, r.Claims.EmailVerified,
|
r.Claims.UserID, r.Claims.Username, r.Claims.Email, r.Claims.EmailVerified,
|
||||||
encoder(r.Claims.Groups),
|
encoder(r.Claims.Groups),
|
||||||
r.ConnectorID, r.ConnectorData,
|
r.ConnectorID, r.ConnectorData,
|
||||||
r.Token, r.CreatedAt, r.LastUsed,
|
r.Token, r.CreatedAt, r.LastUsed, id,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("update refresh token: %v", err)
|
return fmt.Errorf("update refresh token: %v", err)
|
||||||
|
|
Reference in a new issue