diff --git a/storage/kubernetes/storage.go b/storage/kubernetes/storage.go index baf1d567..c0d6eb91 100644 --- a/storage/kubernetes/storage.go +++ b/storage/kubernetes/storage.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "net/http" "strings" "time" @@ -512,6 +513,7 @@ func (cli *client) UpdateKeys(updater func(old storage.Keys) (storage.Keys, erro } firstUpdate = true } + var oldKeys storage.Keys if !firstUpdate { oldKeys = toStorageKeys(keys) @@ -521,12 +523,32 @@ func (cli *client) UpdateKeys(updater func(old storage.Keys) (storage.Keys, erro if err != nil { return err } + newKeys := cli.fromStorageKeys(updated) if firstUpdate { - return cli.post(resourceKeys, newKeys) + err = cli.post(resourceKeys, newKeys) + if err != nil && errors.Is(err, storage.ErrAlreadyExists) { + // We need to tolerate conflicts here in case of HA mode. + cli.logger.Debugf("Keys creation failed: %v. It is possible that keys have already been created by another dex instance.", err) + return errors.New("keys already created by another server instance") + } + + return err } + newKeys.ObjectMeta = keys.ObjectMeta - return cli.put(resourceKeys, keysName, newKeys) + + err = cli.put(resourceKeys, keysName, newKeys) + if httpErr, ok := err.(httpError); ok { + // We need to tolerate conflicts here in case of HA mode. + // Dex instances run keys rotation at the same time because they use SigningKey.nextRotation CR field as a trigger. + if httpErr.StatusCode() == http.StatusConflict { + cli.logger.Debugf("Keys rotation failed: %v. It is possible that keys have already been rotated by another dex instance.", err) + return errors.New("keys already rotated by another server instance") + } + } + + return err } func (cli *client) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error { diff --git a/storage/kubernetes/storage_test.go b/storage/kubernetes/storage_test.go index 2c9deeb2..e2c77a62 100644 --- a/storage/kubernetes/storage_test.go +++ b/storage/kubernetes/storage_test.go @@ -1,7 +1,12 @@ package kubernetes import ( + "crypto/tls" + "errors" + "fmt" "io/ioutil" + "net/http" + "net/http/httptest" "os" "strings" "testing" @@ -150,3 +155,120 @@ func TestURLFor(t *testing.T) { } } } + +func TestUpdateKeys(t *testing.T) { + fakeUpdater := func(old storage.Keys) (storage.Keys, error) { return storage.Keys{}, nil } + + tests := []struct { + name string + updater func(old storage.Keys) (storage.Keys, error) + getResponseCode int + actionResponseCode int + wantErr bool + exactErr error + }{ + { + "Create OK test", + fakeUpdater, + 404, + 201, + false, + nil, + }, + { + "Update should be OK", + fakeUpdater, + 200, + 200, + false, + nil, + }, + { + "Create conflict should be OK", + fakeUpdater, + 404, + 409, + true, + errors.New("keys already created by another server instance"), + }, + { + "Update conflict should be OK", + fakeUpdater, + 200, + 409, + true, + errors.New("keys already rotated by another server instance"), + }, + { + "Client error is error", + fakeUpdater, + 404, + 500, + true, + nil, + }, + { + "Client error during update is error", + fakeUpdater, + 200, + 500, + true, + nil, + }, + { + "Get error is error", + fakeUpdater, + 500, + 200, + true, + nil, + }, + { + "Updater error is error", + func(old storage.Keys) (storage.Keys, error) { return storage.Keys{}, fmt.Errorf("test") }, + 200, + 201, + true, + nil, + }, + } + + for _, test := range tests { + client := newStatusCodesResponseTestClient(test.getResponseCode, test.actionResponseCode) + + err := client.UpdateKeys(test.updater) + if err != nil { + if !test.wantErr { + t.Fatalf("Test %q: %v", test.name, err) + } + + if test.exactErr != nil && test.exactErr.Error() != err.Error() { + t.Fatalf("Test %q: %v, wanted: %v", test.name, err, test.exactErr) + } + } + } +} + +func newStatusCodesResponseTestClient(getResponseCode, actionResponseCode int) *client { + s := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + w.WriteHeader(getResponseCode) + } else { + w.WriteHeader(actionResponseCode) + } + w.Write([]byte(`{}`)) // Empty json is enough, we will test only response codes here + })) + + tr := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + return &client{ + client: &http.Client{Transport: tr}, + baseURL: s.URL, + logger: &logrus.Logger{ + Out: os.Stderr, + Formatter: &logrus.TextFormatter{DisableColors: true}, + Level: logrus.DebugLevel, + }, + } +}