diff --git a/cmd/dex-overlord/main.go b/cmd/dex-overlord/main.go index 419b6f2c..be56db47 100644 --- a/cmd/dex-overlord/main.go +++ b/cmd/dex-overlord/main.go @@ -28,7 +28,10 @@ func init() { func main() { fs := flag.NewFlagSet("dex-overlord", flag.ExitOnError) - secret := fs.String("key-secret", "", "symmetric key used to encrypt/decrypt signing key data in DB") + + keySecrets := pflag.NewBase64List(32) + fs.Var(keySecrets, "key-secrets", "A comma-separated list of base64 encoded 32 byte strings used as symmetric keys used to encrypt/decrypt signing key data in DB. The first key is considered the active key and used for encryption, while the others are used to decrypt.") + dbURL := fs.String("db-url", "", "DSN-formatted database connection string") dbMigrate := fs.Bool("db-migrate", true, "perform database migrations when starting up overlord. This includes the initial DB objects creation.") @@ -59,10 +62,6 @@ func main() { log.EnableTimestamps() } - if len(*secret) == 0 { - log.Fatalf("--key-secret unset") - } - adminURL, err := url.Parse(*adminListen) if err != nil { log.Fatalf("Unable to use --admin-listen flag: %v", err) @@ -96,11 +95,32 @@ func main() { userManager := user.NewManager(userRepo, pwiRepo, db.TransactionFactory(dbc), user.ManagerOptions{}) adminAPI := admin.NewAdminAPI(userManager, userRepo, pwiRepo, *localConnectorID) - kRepo, err := db.NewPrivateKeySetRepo(dbc, *secret) + kRepo, err := db.NewPrivateKeySetRepo(dbc, keySecrets.BytesSlice()...) if err != nil { log.Fatalf(err.Error()) } + var sleep time.Duration + for { + var done bool + _, err := kRepo.Get() + switch err { + case nil: + done = true + case key.ErrorNoKeys: + done = true + case db.ErrorCannotDecryptKeys: + log.Fatalf("Cannot decrypt keys using any of the given key secrets. The key secrets must be changed to include one that can decrypt the existing keys, or the existing keys must be deleted.") + } + + if done { + break + } + sleep = ptime.ExpBackoff(sleep, time.Minute) + log.Errorf("Unable to get keys from repository, retrying in %v: %v", sleep, err) + time.Sleep(sleep) + } + krot := key.NewPrivateKeyRotator(kRepo, *keyPeriod) s := server.NewAdminServer(adminAPI, krot) h := s.HTTPHandler() diff --git a/cmd/dex-worker/main.go b/cmd/dex-worker/main.go index 2a93b481..e9a50ea2 100644 --- a/cmd/dex-worker/main.go +++ b/cmd/dex-worker/main.go @@ -41,7 +41,10 @@ func main() { // ignored if --no-db is set dbURL := fs.String("db-url", "", "DSN-formatted database connection string") - keySecret := fs.String("key-secret", "", "symmetric key used to encrypt/decrypt signing key data in DB") + + keySecrets := pflag.NewBase64List(32) + fs.Var(keySecrets, "key-secrets", "A comma-separated list of base64 encoded 32 byte strings used as symmetric keys used to encrypt/decrypt signing key data in DB. The first key is considered the active key and used for encryption, while the others are used to decrypt.") + dbMaxIdleConns := fs.Int("db-max-idle-conns", 0, "maximum number of connections in the idle connection pool") dbMaxOpenConns := fs.Int("db-max-open-conns", 0, "maximum number of open connections to the database") @@ -109,7 +112,7 @@ func main() { MaxOpenConnections: *dbMaxOpenConns, } scfg.StateConfig = &server.MultiServerConfig{ - KeySecret: *keySecret, + KeySecrets: keySecrets.BytesSlice(), DatabaseConfig: dbCfg, } } diff --git a/db/key.go b/db/key.go index 1858f9c2..16759136 100644 --- a/db/key.go +++ b/db/key.go @@ -18,6 +18,10 @@ const ( keyTableName = "key" ) +var ( + ErrorCannotDecryptKeys = errors.New("Cannot Decrypt Keys") +) + func init() { register(table{ name: keyTableName, @@ -85,23 +89,24 @@ type privateKeySetBlob struct { Value []byte `db:"value"` } -func NewPrivateKeySetRepo(dbm *gorp.DbMap, secret string) (*PrivateKeySetRepo, error) { - bsecret := []byte(secret) - if len(bsecret) != 32 { - return nil, errors.New("expected 32-byte secret") +func NewPrivateKeySetRepo(dbm *gorp.DbMap, secrets ...[]byte) (*PrivateKeySetRepo, error) { + for i, secret := range secrets { + if len(secret) != 32 { + return nil, fmt.Errorf("key secret %d: expected 32-byte secret", i) + } } r := &PrivateKeySetRepo{ - dbMap: dbm, - secret: []byte(secret), + dbMap: dbm, + secrets: secrets, } return r, nil } type PrivateKeySetRepo struct { - dbMap *gorp.DbMap - secret []byte + dbMap *gorp.DbMap + secrets [][]byte } func (r *PrivateKeySetRepo) Set(ks key.KeySet) error { @@ -126,7 +131,7 @@ func (r *PrivateKeySetRepo) Set(ks key.KeySet) error { return err } - v, err := pcrypto.AESEncrypt(j, r.secret) + v, err := pcrypto.AESEncrypt(j, r.active()) if err != nil { return err } @@ -151,20 +156,32 @@ func (r *PrivateKeySetRepo) Get() (key.KeySet, error) { return nil, errors.New("unable to cast to KeySet") } - j, err := pcrypto.AESDecrypt(b.Value, r.secret) + var pks *key.PrivateKeySet + for _, secret := range r.secrets { + var j []byte + j, err = pcrypto.AESDecrypt(b.Value, secret) + if err != nil { + continue + } + + var m privateKeySetModel + if err = json.Unmarshal(j, &m); err != nil { + continue + } + + pks, err = m.PrivateKeySet() + if err != nil { + continue + } + break + } + if err != nil { - return nil, errors.New("unable to decrypt key set") + return nil, ErrorCannotDecryptKeys } - - var m privateKeySetModel - if err := json.Unmarshal(j, &m); err != nil { - return nil, err - } - - pks, err := m.PrivateKeySet() - if err != nil { - return nil, err - } - return key.KeySet(pks), nil } + +func (r *PrivateKeySetRepo) active() []byte { + return r.secrets[0] +} diff --git a/db/key_test.go b/db/key_test.go index 634c9bf3..fe2c2570 100644 --- a/db/key_test.go +++ b/db/key_test.go @@ -5,7 +5,7 @@ import ( ) func TestNewPrivateKeySetRepoInvalidKey(t *testing.T) { - _, err := NewPrivateKeySetRepo(nil, "sharks") + _, err := NewPrivateKeySetRepo(nil, []byte("sharks")) if err == nil { t.Fatalf("Expected non-nil error") } diff --git a/functional/db_test.go b/functional/db_test.go index c308526f..1177fa4e 100644 --- a/functional/db_test.go +++ b/functional/db_test.go @@ -114,33 +114,75 @@ func TestDBSessionRepoCreateUpdate(t *testing.T) { } func TestDBPrivateKeySetRepoSetGet(t *testing.T) { - r, err := db.NewPrivateKeySetRepo(connect(t), "roflroflroflroflroflroflroflrofl") - if err != nil { - t.Fatalf(err.Error()) + s1 := []byte("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx") + s2 := []byte("oooooooooooooooooooooooooooooooo") + s3 := []byte("wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww") + + keys := []*key.PrivateKey{} + for i := 0; i < 2; i++ { + k, err := key.GeneratePrivateKey() + if err != nil { + t.Fatalf("Unable to generate RSA key: %v", err) + } + keys = append(keys, k) } - k1, err := key.GeneratePrivateKey() - if err != nil { - t.Fatalf("Unable to generate RSA key: %v", err) + ks := key.NewPrivateKeySet( + []*key.PrivateKey{keys[0], keys[1]}, time.Now().Add(time.Minute)) + + tests := []struct { + setSecrets [][]byte + getSecrets [][]byte + wantErr bool + }{ + { + // same secrets used to encrypt, decrypt + setSecrets: [][]byte{s1, s2}, + getSecrets: [][]byte{s1, s2}, + }, + { + // setSecrets got rotated, but getSecrets didn't yet. + setSecrets: [][]byte{s2, s3}, + getSecrets: [][]byte{s1, s2}, + }, + { + // getSecrets doesn't have s3 + setSecrets: [][]byte{s3}, + getSecrets: [][]byte{s1, s2}, + wantErr: true, + }, } - k2, err := key.GeneratePrivateKey() - if err != nil { - t.Fatalf("Unable to generate RSA key: %v", err) - } + for i, tt := range tests { + setRepo, err := db.NewPrivateKeySetRepo(connect(t), tt.setSecrets...) + if err != nil { + t.Fatalf(err.Error()) + } - ks := key.NewPrivateKeySet([]*key.PrivateKey{k1, k2}, time.Now().Add(time.Minute)) - if err := r.Set(ks); err != nil { - t.Fatalf("Unexpected error: %v", err) - } + getRepo, err := db.NewPrivateKeySetRepo(connect(t), tt.getSecrets...) + if err != nil { + t.Fatalf(err.Error()) + } - got, err := r.Get() - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } + if err := setRepo.Set(ks); err != nil { + t.Fatalf("case %d: Unexpected error: %v", i, err) + } + + got, err := getRepo.Get() + if tt.wantErr { + if err == nil { + t.Errorf("case %d: want err, got nil", i) + } + continue + } + if err != nil { + t.Fatalf("case %d: Unexpected error: %v", i, err) + } + + if diff := pretty.Compare(ks, got); diff != "" { + t.Fatalf("case %d:Retrieved incorrect KeySet: Compare(want,got): %v", i, diff) + } - if diff := pretty.Compare(ks, got); diff != "" { - t.Fatalf("Retrieved incorrect KeySet: Compare(want,got): %v", diff) } } diff --git a/pkg/crypto/aes.go b/pkg/crypto/aes.go index 82ffad5b..d5dad127 100644 --- a/pkg/crypto/aes.go +++ b/pkg/crypto/aes.go @@ -84,11 +84,13 @@ func AESDecrypt(ciphertext, key []byte) ([]byte, error) { } mode := cipher.NewCBCDecrypter(block, iv) - mode.CryptBlocks(ciphertext, ciphertext) - if len(ciphertext)%aes.BlockSize != 0 { + plaintext := make([]byte, len(ciphertext)) + mode.CryptBlocks(plaintext, ciphertext) + + if len(plaintext)%aes.BlockSize != 0 { return nil, errors.New("ciphertext is not a multiple of the block size") } - return unpad(ciphertext) + return unpad(plaintext) } diff --git a/pkg/flag/base64.go b/pkg/flag/base64.go new file mode 100644 index 00000000..f60267cd --- /dev/null +++ b/pkg/flag/base64.go @@ -0,0 +1,86 @@ +package flag + +import ( + "encoding/base64" + "fmt" + "strings" +) + +// Base64 implements flag.Value, and is used to populate []byte values from baes64 encoded strings. +type Base64 struct { + val []byte + len int +} + +// NewBase64 returns a Base64 which accepts values which decode to len byte strings. +func NewBase64(len int) *Base64 { + return &Base64{ + len: len, + } +} + +func (f *Base64) String() string { + return base64.StdEncoding.EncodeToString(f.val) +} + +// Set will set the []byte value of the Base64 to the base64 decoded values of the string, returning an error if it cannot be decoded or is of the wrong length. +func (f *Base64) Set(s string) error { + b, err := base64.StdEncoding.DecodeString(s) + if err != nil { + return err + } + + if len(b) != f.len { + return fmt.Errorf("expected %d-byte secret", f.len) + } + + f.val = b + return nil +} + +// Bytes returns the set []byte value. +// If no value has been set, a nil []byte is returned. +func (f *Base64) Bytes() []byte { + return f.val +} + +// NewBase64List returns a Base64List which accepts a comma-separated list of strings which must decode to len byte strings. +func NewBase64List(len int) *Base64List { + return &Base64List{ + len: len, + } +} + +// Base64List implements flag.Value and is used to populate [][]byte values from a comma-separated list of base64 encoded strings. +type Base64List struct { + val [][]byte + len int +} + +// Set will set the [][]byte value of the Base64List to the base64 decoded values of the comma-separated strings, returning an error on the first error it encounters. +func (f *Base64List) Set(ss string) error { + if ss == "" { + return nil + } + for i, s := range strings.Split(ss, ",") { + b64 := NewBase64(f.len) + err := b64.Set(s) + if err != nil { + return fmt.Errorf("error decoding string %d: %q", i, err) + } + f.val = append(f.val, b64.Bytes()) + } + return nil +} + +func (f *Base64List) String() string { + ss := []string{} + for _, b := range f.val { + ss = append(ss, base64.StdEncoding.EncodeToString(b)) + } + return strings.Join(ss, ",") +} + +func (f *Base64List) BytesSlice() [][]byte { + return f.val +} diff --git a/pkg/flag/base64_test.go b/pkg/flag/base64_test.go new file mode 100644 index 00000000..f2aa25e8 --- /dev/null +++ b/pkg/flag/base64_test.go @@ -0,0 +1,134 @@ +package flag + +import ( + "encoding/base64" + "strings" + "testing" + + "github.com/kylelemons/godebug/pretty" +) + +func TestBase64(t *testing.T) { + toB64 := func(b []byte) string { + return base64.StdEncoding.EncodeToString(b) + } + + tests := []struct { + s string + l int + b []byte + wantError bool + }{ + { + s: toB64([]byte("123456")), + l: 6, + b: []byte("123456"), + }, + { + s: toB64([]byte("123456")), + l: 5, + wantError: true, + }, + { + s: "not base64", + l: 5, + wantError: true, + }, + } + + for i, tt := range tests { + b64 := NewBase64(tt.l) + err := b64.Set(tt.s) + if tt.wantError { + if err == nil { + t.Errorf("case %d: want err, got nil", i) + } + continue + } + + if err != nil { + t.Errorf("case %d: unexpected error %q", i, err) + } + + if diff := pretty.Compare(tt.b, b64.Bytes()); diff != "" { + t.Errorf("case %d: Compare(want, got) = %v", i, + diff) + } + + if b64.String() != tt.s { + t.Errorf("case %d: want=%q, got=%q", i, b64.String(), tt.s) + } + } +} + +func TestBase64List(t *testing.T) { + // toCSB64 == to comma separated base 64 + toCSB64 := func(bb ...[]byte) string { + ss := []string{} + for _, b := range bb { + ss = append(ss, base64.StdEncoding.EncodeToString(b)) + } + return strings.Join(ss, ",") + } + + b123 := []byte("123456") + b567 := []byte("567890") + bShort := []byte("1234") + + tests := []struct { + s string + l int + bb [][]byte + wantError bool + }{ + { + s: toCSB64(b123, b567), + l: 6, + bb: [][]byte{b123, b567}, + }, + { + s: toCSB64(b123), + l: 6, + bb: [][]byte{b123}, + }, + { + s: "", + l: 6, + bb: [][]byte{}, + }, + { + s: toCSB64(b123, bShort), + l: 6, + wantError: true, + }, + { + s: toCSB64(bShort, b123), + l: 6, + wantError: true, + }, + } + + for i, tt := range tests { + b64 := NewBase64List(tt.l) + err := b64.Set(tt.s) + if tt.wantError { + if err == nil { + t.Errorf("case %d: want err, got nil", i) + } + continue + } + + if err != nil { + t.Errorf("case %d: unexpected error %q", i, err) + } + + if diff := pretty.Compare(tt.bb, b64.BytesSlice()); diff != "" { + t.Errorf("case %d: Compare(want, got) = %v", i, + diff) + } + + if b64.String() != tt.s { + t.Errorf("case %d: want=%q, got=%q", i, b64.String(), tt.s) + } + } +} diff --git a/server/config.go b/server/config.go index 17171d52..b44ad702 100644 --- a/server/config.go +++ b/server/config.go @@ -44,7 +44,7 @@ type SingleServerConfig struct { } type MultiServerConfig struct { - KeySecret string + KeySecrets [][]byte DatabaseConfig db.Config } @@ -141,7 +141,7 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error { } func (cfg *MultiServerConfig) Configure(srv *Server) error { - if cfg.KeySecret == "" { + if len(cfg.KeySecrets) == 0 { return errors.New("missing key secret") } @@ -154,7 +154,7 @@ func (cfg *MultiServerConfig) Configure(srv *Server) error { return fmt.Errorf("unable to initialize database connection: %v", err) } - kRepo, err := db.NewPrivateKeySetRepo(dbc, cfg.KeySecret) + kRepo, err := db.NewPrivateKeySetRepo(dbc, cfg.KeySecrets...) if err != nil { return fmt.Errorf("unable to create PrivateKeySetRepo: %v", err) }