forked from mystiq/dex
Merge pull request #87 from bobbyrullo/keyspace
Base64 Encode secrets, and allow >1 of them
This commit is contained in:
commit
5abc7633fb
9 changed files with 361 additions and 57 deletions
|
@ -28,7 +28,10 @@ func init() {
|
|||
|
||||
func main() {
|
||||
fs := flag.NewFlagSet("dex-overlord", flag.ExitOnError)
|
||||
secret := fs.String("key-secret", "", "symmetric key used to encrypt/decrypt signing key data in DB")
|
||||
|
||||
keySecrets := pflag.NewBase64List(32)
|
||||
fs.Var(keySecrets, "key-secrets", "A comma-separated list of base64 encoded 32 byte strings used as symmetric keys used to encrypt/decrypt signing key data in DB. The first key is considered the active key and used for encryption, while the others are used to decrypt.")
|
||||
|
||||
dbURL := fs.String("db-url", "", "DSN-formatted database connection string")
|
||||
|
||||
dbMigrate := fs.Bool("db-migrate", true, "perform database migrations when starting up overlord. This includes the initial DB objects creation.")
|
||||
|
@ -59,10 +62,6 @@ func main() {
|
|||
log.EnableTimestamps()
|
||||
}
|
||||
|
||||
if len(*secret) == 0 {
|
||||
log.Fatalf("--key-secret unset")
|
||||
}
|
||||
|
||||
adminURL, err := url.Parse(*adminListen)
|
||||
if err != nil {
|
||||
log.Fatalf("Unable to use --admin-listen flag: %v", err)
|
||||
|
@ -96,11 +95,32 @@ func main() {
|
|||
userManager := user.NewManager(userRepo,
|
||||
pwiRepo, db.TransactionFactory(dbc), user.ManagerOptions{})
|
||||
adminAPI := admin.NewAdminAPI(userManager, userRepo, pwiRepo, *localConnectorID)
|
||||
kRepo, err := db.NewPrivateKeySetRepo(dbc, *secret)
|
||||
kRepo, err := db.NewPrivateKeySetRepo(dbc, keySecrets.BytesSlice()...)
|
||||
if err != nil {
|
||||
log.Fatalf(err.Error())
|
||||
}
|
||||
|
||||
var sleep time.Duration
|
||||
for {
|
||||
var done bool
|
||||
_, err := kRepo.Get()
|
||||
switch err {
|
||||
case nil:
|
||||
done = true
|
||||
case key.ErrorNoKeys:
|
||||
done = true
|
||||
case db.ErrorCannotDecryptKeys:
|
||||
log.Fatalf("Cannot decrypt keys using any of the given key secrets. The key secrets must be changed to include one that can decrypt the existing keys, or the existing keys must be deleted.")
|
||||
}
|
||||
|
||||
if done {
|
||||
break
|
||||
}
|
||||
sleep = ptime.ExpBackoff(sleep, time.Minute)
|
||||
log.Errorf("Unable to get keys from repository, retrying in %v: %v", sleep, err)
|
||||
time.Sleep(sleep)
|
||||
}
|
||||
|
||||
krot := key.NewPrivateKeyRotator(kRepo, *keyPeriod)
|
||||
s := server.NewAdminServer(adminAPI, krot)
|
||||
h := s.HTTPHandler()
|
||||
|
|
|
@ -41,7 +41,10 @@ func main() {
|
|||
|
||||
// ignored if --no-db is set
|
||||
dbURL := fs.String("db-url", "", "DSN-formatted database connection string")
|
||||
keySecret := fs.String("key-secret", "", "symmetric key used to encrypt/decrypt signing key data in DB")
|
||||
|
||||
keySecrets := pflag.NewBase64List(32)
|
||||
fs.Var(keySecrets, "key-secrets", "A comma-separated list of base64 encoded 32 byte strings used as symmetric keys used to encrypt/decrypt signing key data in DB. The first key is considered the active key and used for encryption, while the others are used to decrypt.")
|
||||
|
||||
dbMaxIdleConns := fs.Int("db-max-idle-conns", 0, "maximum number of connections in the idle connection pool")
|
||||
dbMaxOpenConns := fs.Int("db-max-open-conns", 0, "maximum number of open connections to the database")
|
||||
|
||||
|
@ -109,7 +112,7 @@ func main() {
|
|||
MaxOpenConnections: *dbMaxOpenConns,
|
||||
}
|
||||
scfg.StateConfig = &server.MultiServerConfig{
|
||||
KeySecret: *keySecret,
|
||||
KeySecrets: keySecrets.BytesSlice(),
|
||||
DatabaseConfig: dbCfg,
|
||||
}
|
||||
}
|
||||
|
|
61
db/key.go
61
db/key.go
|
@ -18,6 +18,10 @@ const (
|
|||
keyTableName = "key"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrorCannotDecryptKeys = errors.New("Cannot Decrypt Keys")
|
||||
)
|
||||
|
||||
func init() {
|
||||
register(table{
|
||||
name: keyTableName,
|
||||
|
@ -85,23 +89,24 @@ type privateKeySetBlob struct {
|
|||
Value []byte `db:"value"`
|
||||
}
|
||||
|
||||
func NewPrivateKeySetRepo(dbm *gorp.DbMap, secret string) (*PrivateKeySetRepo, error) {
|
||||
bsecret := []byte(secret)
|
||||
if len(bsecret) != 32 {
|
||||
return nil, errors.New("expected 32-byte secret")
|
||||
func NewPrivateKeySetRepo(dbm *gorp.DbMap, secrets ...[]byte) (*PrivateKeySetRepo, error) {
|
||||
for i, secret := range secrets {
|
||||
if len(secret) != 32 {
|
||||
return nil, fmt.Errorf("key secret %d: expected 32-byte secret", i)
|
||||
}
|
||||
}
|
||||
|
||||
r := &PrivateKeySetRepo{
|
||||
dbMap: dbm,
|
||||
secret: []byte(secret),
|
||||
dbMap: dbm,
|
||||
secrets: secrets,
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
type PrivateKeySetRepo struct {
|
||||
dbMap *gorp.DbMap
|
||||
secret []byte
|
||||
dbMap *gorp.DbMap
|
||||
secrets [][]byte
|
||||
}
|
||||
|
||||
func (r *PrivateKeySetRepo) Set(ks key.KeySet) error {
|
||||
|
@ -126,7 +131,7 @@ func (r *PrivateKeySetRepo) Set(ks key.KeySet) error {
|
|||
return err
|
||||
}
|
||||
|
||||
v, err := pcrypto.AESEncrypt(j, r.secret)
|
||||
v, err := pcrypto.AESEncrypt(j, r.active())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -151,20 +156,32 @@ func (r *PrivateKeySetRepo) Get() (key.KeySet, error) {
|
|||
return nil, errors.New("unable to cast to KeySet")
|
||||
}
|
||||
|
||||
j, err := pcrypto.AESDecrypt(b.Value, r.secret)
|
||||
var pks *key.PrivateKeySet
|
||||
for _, secret := range r.secrets {
|
||||
var j []byte
|
||||
j, err = pcrypto.AESDecrypt(b.Value, secret)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var m privateKeySetModel
|
||||
if err = json.Unmarshal(j, &m); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
pks, err = m.PrivateKeySet()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, errors.New("unable to decrypt key set")
|
||||
return nil, ErrorCannotDecryptKeys
|
||||
}
|
||||
|
||||
var m privateKeySetModel
|
||||
if err := json.Unmarshal(j, &m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pks, err := m.PrivateKeySet()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return key.KeySet(pks), nil
|
||||
}
|
||||
|
||||
func (r *PrivateKeySetRepo) active() []byte {
|
||||
return r.secrets[0]
|
||||
}
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
)
|
||||
|
||||
func TestNewPrivateKeySetRepoInvalidKey(t *testing.T) {
|
||||
_, err := NewPrivateKeySetRepo(nil, "sharks")
|
||||
_, err := NewPrivateKeySetRepo(nil, []byte("sharks"))
|
||||
if err == nil {
|
||||
t.Fatalf("Expected non-nil error")
|
||||
}
|
||||
|
|
|
@ -114,33 +114,75 @@ func TestDBSessionRepoCreateUpdate(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDBPrivateKeySetRepoSetGet(t *testing.T) {
|
||||
r, err := db.NewPrivateKeySetRepo(connect(t), "roflroflroflroflroflroflroflrofl")
|
||||
if err != nil {
|
||||
t.Fatalf(err.Error())
|
||||
s1 := []byte("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
|
||||
s2 := []byte("oooooooooooooooooooooooooooooooo")
|
||||
s3 := []byte("wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww")
|
||||
|
||||
keys := []*key.PrivateKey{}
|
||||
for i := 0; i < 2; i++ {
|
||||
k, err := key.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to generate RSA key: %v", err)
|
||||
}
|
||||
keys = append(keys, k)
|
||||
}
|
||||
|
||||
k1, err := key.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to generate RSA key: %v", err)
|
||||
ks := key.NewPrivateKeySet(
|
||||
[]*key.PrivateKey{keys[0], keys[1]}, time.Now().Add(time.Minute))
|
||||
|
||||
tests := []struct {
|
||||
setSecrets [][]byte
|
||||
getSecrets [][]byte
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
// same secrets used to encrypt, decrypt
|
||||
setSecrets: [][]byte{s1, s2},
|
||||
getSecrets: [][]byte{s1, s2},
|
||||
},
|
||||
{
|
||||
// setSecrets got rotated, but getSecrets didn't yet.
|
||||
setSecrets: [][]byte{s2, s3},
|
||||
getSecrets: [][]byte{s1, s2},
|
||||
},
|
||||
{
|
||||
// getSecrets doesn't have s3
|
||||
setSecrets: [][]byte{s3},
|
||||
getSecrets: [][]byte{s1, s2},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
k2, err := key.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to generate RSA key: %v", err)
|
||||
}
|
||||
for i, tt := range tests {
|
||||
setRepo, err := db.NewPrivateKeySetRepo(connect(t), tt.setSecrets...)
|
||||
if err != nil {
|
||||
t.Fatalf(err.Error())
|
||||
}
|
||||
|
||||
ks := key.NewPrivateKeySet([]*key.PrivateKey{k1, k2}, time.Now().Add(time.Minute))
|
||||
if err := r.Set(ks); err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
getRepo, err := db.NewPrivateKeySetRepo(connect(t), tt.getSecrets...)
|
||||
if err != nil {
|
||||
t.Fatalf(err.Error())
|
||||
}
|
||||
|
||||
got, err := r.Get()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if err := setRepo.Set(ks); err != nil {
|
||||
t.Fatalf("case %d: Unexpected error: %v", i, err)
|
||||
}
|
||||
|
||||
got, err := getRepo.Get()
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("case %d: want err, got nil", i)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("case %d: Unexpected error: %v", i, err)
|
||||
}
|
||||
|
||||
if diff := pretty.Compare(ks, got); diff != "" {
|
||||
t.Fatalf("case %d:Retrieved incorrect KeySet: Compare(want,got): %v", i, diff)
|
||||
}
|
||||
|
||||
if diff := pretty.Compare(ks, got); diff != "" {
|
||||
t.Fatalf("Retrieved incorrect KeySet: Compare(want,got): %v", diff)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -84,11 +84,13 @@ func AESDecrypt(ciphertext, key []byte) ([]byte, error) {
|
|||
}
|
||||
|
||||
mode := cipher.NewCBCDecrypter(block, iv)
|
||||
mode.CryptBlocks(ciphertext, ciphertext)
|
||||
|
||||
if len(ciphertext)%aes.BlockSize != 0 {
|
||||
plaintext := make([]byte, len(ciphertext))
|
||||
mode.CryptBlocks(plaintext, ciphertext)
|
||||
|
||||
if len(plaintext)%aes.BlockSize != 0 {
|
||||
return nil, errors.New("ciphertext is not a multiple of the block size")
|
||||
}
|
||||
|
||||
return unpad(ciphertext)
|
||||
return unpad(plaintext)
|
||||
}
|
||||
|
|
86
pkg/flag/base64.go
Normal file
86
pkg/flag/base64.go
Normal file
|
@ -0,0 +1,86 @@
|
|||
package flag
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Base64 implements flag.Value, and is used to populate []byte values from baes64 encoded strings.
|
||||
type Base64 struct {
|
||||
val []byte
|
||||
len int
|
||||
}
|
||||
|
||||
// NewBase64 returns a Base64 which accepts values which decode to len byte strings.
|
||||
func NewBase64(len int) *Base64 {
|
||||
return &Base64{
|
||||
len: len,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Base64) String() string {
|
||||
return base64.StdEncoding.EncodeToString(f.val)
|
||||
}
|
||||
|
||||
// Set will set the []byte value of the Base64 to the base64 decoded values of the string, returning an error if it cannot be decoded or is of the wrong length.
|
||||
func (f *Base64) Set(s string) error {
|
||||
b, err := base64.StdEncoding.DecodeString(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(b) != f.len {
|
||||
return fmt.Errorf("expected %d-byte secret", f.len)
|
||||
}
|
||||
|
||||
f.val = b
|
||||
return nil
|
||||
}
|
||||
|
||||
// Bytes returns the set []byte value.
|
||||
// If no value has been set, a nil []byte is returned.
|
||||
func (f *Base64) Bytes() []byte {
|
||||
return f.val
|
||||
}
|
||||
|
||||
// NewBase64List returns a Base64List which accepts a comma-separated list of strings which must decode to len byte strings.
|
||||
func NewBase64List(len int) *Base64List {
|
||||
return &Base64List{
|
||||
len: len,
|
||||
}
|
||||
}
|
||||
|
||||
// Base64List implements flag.Value and is used to populate [][]byte values from a comma-separated list of base64 encoded strings.
|
||||
type Base64List struct {
|
||||
val [][]byte
|
||||
len int
|
||||
}
|
||||
|
||||
// Set will set the [][]byte value of the Base64List to the base64 decoded values of the comma-separated strings, returning an error on the first error it encounters.
|
||||
func (f *Base64List) Set(ss string) error {
|
||||
if ss == "" {
|
||||
return nil
|
||||
}
|
||||
for i, s := range strings.Split(ss, ",") {
|
||||
b64 := NewBase64(f.len)
|
||||
err := b64.Set(s)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error decoding string %d: %q", i, err)
|
||||
}
|
||||
f.val = append(f.val, b64.Bytes())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *Base64List) String() string {
|
||||
ss := []string{}
|
||||
for _, b := range f.val {
|
||||
ss = append(ss, base64.StdEncoding.EncodeToString(b))
|
||||
}
|
||||
return strings.Join(ss, ",")
|
||||
}
|
||||
|
||||
func (f *Base64List) BytesSlice() [][]byte {
|
||||
return f.val
|
||||
}
|
134
pkg/flag/base64_test.go
Normal file
134
pkg/flag/base64_test.go
Normal file
|
@ -0,0 +1,134 @@
|
|||
package flag
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/kylelemons/godebug/pretty"
|
||||
)
|
||||
|
||||
func TestBase64(t *testing.T) {
|
||||
toB64 := func(b []byte) string {
|
||||
return base64.StdEncoding.EncodeToString(b)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
s string
|
||||
l int
|
||||
b []byte
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
s: toB64([]byte("123456")),
|
||||
l: 6,
|
||||
b: []byte("123456"),
|
||||
},
|
||||
{
|
||||
s: toB64([]byte("123456")),
|
||||
l: 5,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
s: "not base64",
|
||||
l: 5,
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
b64 := NewBase64(tt.l)
|
||||
err := b64.Set(tt.s)
|
||||
if tt.wantError {
|
||||
if err == nil {
|
||||
t.Errorf("case %d: want err, got nil", i)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("case %d: unexpected error %q", i, err)
|
||||
}
|
||||
|
||||
if diff := pretty.Compare(tt.b, b64.Bytes()); diff != "" {
|
||||
t.Errorf("case %d: Compare(want, got) = %v", i,
|
||||
diff)
|
||||
}
|
||||
|
||||
if b64.String() != tt.s {
|
||||
t.Errorf("case %d: want=%q, got=%q", i, b64.String(), tt.s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBase64List(t *testing.T) {
|
||||
// toCSB64 == to comma separated base 64
|
||||
toCSB64 := func(bb ...[]byte) string {
|
||||
ss := []string{}
|
||||
for _, b := range bb {
|
||||
ss = append(ss, base64.StdEncoding.EncodeToString(b))
|
||||
}
|
||||
return strings.Join(ss, ",")
|
||||
}
|
||||
|
||||
b123 := []byte("123456")
|
||||
b567 := []byte("567890")
|
||||
bShort := []byte("1234")
|
||||
|
||||
tests := []struct {
|
||||
s string
|
||||
l int
|
||||
bb [][]byte
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
s: toCSB64(b123, b567),
|
||||
l: 6,
|
||||
bb: [][]byte{b123, b567},
|
||||
},
|
||||
{
|
||||
s: toCSB64(b123),
|
||||
l: 6,
|
||||
bb: [][]byte{b123},
|
||||
},
|
||||
{
|
||||
s: "",
|
||||
l: 6,
|
||||
bb: [][]byte{},
|
||||
},
|
||||
{
|
||||
s: toCSB64(b123, bShort),
|
||||
l: 6,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
s: toCSB64(bShort, b123),
|
||||
l: 6,
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
b64 := NewBase64List(tt.l)
|
||||
err := b64.Set(tt.s)
|
||||
if tt.wantError {
|
||||
if err == nil {
|
||||
t.Errorf("case %d: want err, got nil", i)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("case %d: unexpected error %q", i, err)
|
||||
}
|
||||
|
||||
if diff := pretty.Compare(tt.bb, b64.BytesSlice()); diff != "" {
|
||||
t.Errorf("case %d: Compare(want, got) = %v", i,
|
||||
diff)
|
||||
}
|
||||
|
||||
if b64.String() != tt.s {
|
||||
t.Errorf("case %d: want=%q, got=%q", i, b64.String(), tt.s)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -44,7 +44,7 @@ type SingleServerConfig struct {
|
|||
}
|
||||
|
||||
type MultiServerConfig struct {
|
||||
KeySecret string
|
||||
KeySecrets [][]byte
|
||||
DatabaseConfig db.Config
|
||||
}
|
||||
|
||||
|
@ -141,7 +141,7 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
|
|||
}
|
||||
|
||||
func (cfg *MultiServerConfig) Configure(srv *Server) error {
|
||||
if cfg.KeySecret == "" {
|
||||
if len(cfg.KeySecrets) == 0 {
|
||||
return errors.New("missing key secret")
|
||||
}
|
||||
|
||||
|
@ -154,7 +154,7 @@ func (cfg *MultiServerConfig) Configure(srv *Server) error {
|
|||
return fmt.Errorf("unable to initialize database connection: %v", err)
|
||||
}
|
||||
|
||||
kRepo, err := db.NewPrivateKeySetRepo(dbc, cfg.KeySecret)
|
||||
kRepo, err := db.NewPrivateKeySetRepo(dbc, cfg.KeySecrets...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create PrivateKeySetRepo: %v", err)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue