fix: Handle kubernetes API conflicts properly for signing keys

Signed-off-by: m.nabokikh <maksim.nabokikh@flant.com>
This commit is contained in:
m.nabokikh 2020-10-11 23:43:10 +03:00
parent 3f41b26fb9
commit 4801b2c975
2 changed files with 146 additions and 2 deletions

View file

@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"time"
@ -512,6 +513,7 @@ func (cli *client) UpdateKeys(updater func(old storage.Keys) (storage.Keys, erro
}
firstUpdate = true
}
var oldKeys storage.Keys
if !firstUpdate {
oldKeys = toStorageKeys(keys)
@ -521,12 +523,32 @@ func (cli *client) UpdateKeys(updater func(old storage.Keys) (storage.Keys, erro
if err != nil {
return err
}
newKeys := cli.fromStorageKeys(updated)
if firstUpdate {
return cli.post(resourceKeys, newKeys)
err = cli.post(resourceKeys, newKeys)
if err != nil && errors.Is(err, storage.ErrAlreadyExists) {
// We need to tolerate conflicts here in case of HA mode.
cli.logger.Debugf("Keys creation failed: %v. It is possible that keys have already been created by another dex instance.", err)
return errors.New("keys already created by another server instance")
}
return err
}
newKeys.ObjectMeta = keys.ObjectMeta
return cli.put(resourceKeys, keysName, newKeys)
err = cli.put(resourceKeys, keysName, newKeys)
if httpErr, ok := err.(httpError); ok {
// We need to tolerate conflicts here in case of HA mode.
// Dex instances run keys rotation at the same time because they use SigningKey.nextRotation CR field as a trigger.
if httpErr.StatusCode() == http.StatusConflict {
cli.logger.Debugf("Keys rotation failed: %v. It is possible that keys have already been rotated by another dex instance.", err)
return errors.New("keys already rotated by another server instance")
}
}
return err
}
func (cli *client) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error {

View file

@ -1,7 +1,12 @@
package kubernetes
import (
"crypto/tls"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
@ -150,3 +155,120 @@ func TestURLFor(t *testing.T) {
}
}
}
func TestUpdateKeys(t *testing.T) {
fakeUpdater := func(old storage.Keys) (storage.Keys, error) { return storage.Keys{}, nil }
tests := []struct {
name string
updater func(old storage.Keys) (storage.Keys, error)
getResponseCode int
actionResponseCode int
wantErr bool
exactErr error
}{
{
"Create OK test",
fakeUpdater,
404,
201,
false,
nil,
},
{
"Update should be OK",
fakeUpdater,
200,
200,
false,
nil,
},
{
"Create conflict should be OK",
fakeUpdater,
404,
409,
true,
errors.New("keys already created by another server instance"),
},
{
"Update conflict should be OK",
fakeUpdater,
200,
409,
true,
errors.New("keys already rotated by another server instance"),
},
{
"Client error is error",
fakeUpdater,
404,
500,
true,
nil,
},
{
"Client error during update is error",
fakeUpdater,
200,
500,
true,
nil,
},
{
"Get error is error",
fakeUpdater,
500,
200,
true,
nil,
},
{
"Updater error is error",
func(old storage.Keys) (storage.Keys, error) { return storage.Keys{}, fmt.Errorf("test") },
200,
201,
true,
nil,
},
}
for _, test := range tests {
client := newStatusCodesResponseTestClient(test.getResponseCode, test.actionResponseCode)
err := client.UpdateKeys(test.updater)
if err != nil {
if !test.wantErr {
t.Fatalf("Test %q: %v", test.name, err)
}
if test.exactErr != nil && test.exactErr.Error() != err.Error() {
t.Fatalf("Test %q: %v, wanted: %v", test.name, err, test.exactErr)
}
}
}
}
func newStatusCodesResponseTestClient(getResponseCode, actionResponseCode int) *client {
s := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodGet {
w.WriteHeader(getResponseCode)
} else {
w.WriteHeader(actionResponseCode)
}
w.Write([]byte(`{}`)) // Empty json is enough, we will test only response codes here
}))
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
return &client{
client: &http.Client{Transport: tr},
baseURL: s.URL,
logger: &logrus.Logger{
Out: os.Stderr,
Formatter: &logrus.TextFormatter{DisableColors: true},
Level: logrus.DebugLevel,
},
}
}