forked from mystiq/dex
9fad0602ec
Signed-off-by: m.nabokikh <maksim.nabokikh@flant.com>
294 lines
6.3 KiB
Go
294 lines
6.3 KiB
Go
package kubernetes
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/sirupsen/logrus"
|
|
"github.com/stretchr/testify/require"
|
|
"github.com/stretchr/testify/suite"
|
|
|
|
"github.com/dexidp/dex/storage"
|
|
"github.com/dexidp/dex/storage/conformance"
|
|
)
|
|
|
|
// kubeconfigPathVariableName is the environment variable holding the path to
// the kubeconfig used by the kubernetes storage integration tests. When it is
// empty or unset, the integration suite is skipped.
const (
	kubeconfigPathVariableName = "DEX_KUBERNETES_CONFIG_PATH"
)
|
|
|
|
// TestStorage runs the kubernetes storage integration suite against the
// cluster described by the kubeconfig named in $DEX_KUBERNETES_CONFIG_PATH.
// The suite is skipped when that variable is empty, so a plain `go test`
// does not require a running cluster.
func TestStorage(t *testing.T) {
	if os.Getenv(kubeconfigPathVariableName) == "" {
		// Skipf formats the message itself; the testing package adds its own
		// line termination, so no trailing newline is needed.
		t.Skipf("variable %q not set, skipping kubernetes storage tests", kubeconfigPathVariableName)
	}

	suite.Run(t, new(StorageTestSuite))
}
|
|
|
|
type StorageTestSuite struct {
|
|
suite.Suite
|
|
client *client
|
|
}
|
|
|
|
// expandDir normalizes a path taken from the environment. It strips any
// surrounding double quotes, and it replaces a leading "~/" with the current
// user's home directory. Any other path is returned after unquoting only.
// The current test fails if the home directory cannot be determined.
func (s *StorageTestSuite) expandDir(dir string) string {
	unquoted := strings.Trim(dir, `"`)
	if !strings.HasPrefix(unquoted, "~/") {
		return unquoted
	}

	home, err := os.UserHomeDir()
	s.Require().NoError(err)

	return filepath.Join(home, unquoted[len("~/"):])
}
|
|
|
|
// SetupTest runs before every test method of the suite. It opens a client
// from the kubeconfig path found in the environment (after ~/ expansion) and
// stores it on the suite. A debug-level, colorless stderr logger is used.
// Any error from opening the client fails the current test.
func (s *StorageTestSuite) SetupTest() {
	cfg := Config{
		KubeConfigFile: s.expandDir(os.Getenv(kubeconfigPathVariableName)),
	}

	debugLogger := &logrus.Logger{
		Out:       os.Stderr,
		Formatter: &logrus.TextFormatter{DisableColors: true},
		Level:     logrus.DebugLevel,
	}

	// NOTE(review): the meaning of the boolean passed to open is defined
	// elsewhere in the package; confirm against Config.open.
	kubeClient, err := cfg.open(debugLogger, true)
	s.Require().NoError(err)

	s.client = kubeClient
}
|
|
|
|
func (s *StorageTestSuite) TestStorage() {
|
|
newStorage := func() storage.Storage {
|
|
for _, resource := range []string{
|
|
resourceAuthCode,
|
|
resourceAuthRequest,
|
|
resourceDeviceRequest,
|
|
resourceDeviceToken,
|
|
resourceClient,
|
|
resourceRefreshToken,
|
|
resourceKeys,
|
|
resourcePassword,
|
|
} {
|
|
if err := s.client.deleteAll(resource); err != nil {
|
|
s.T().Fatalf("delete all %q failed: %v", resource, err)
|
|
}
|
|
}
|
|
return s.client
|
|
}
|
|
|
|
conformance.RunTests(s.T(), newStorage)
|
|
conformance.RunTransactionTests(s.T(), newStorage)
|
|
}
|
|
|
|
// TestURLFor checks how client.urlFor assembles resource URLs:
//   - the core "v1" version is served under /api, and grouped versions under /apis;
//   - a trailing slash on the base URL does not produce a double slash;
//   - the namespaces/<ns> segment is omitted when the namespace is empty.
func TestURLFor(t *testing.T) {
	tests := []struct {
		apiVersion, namespace, resource, name string

		baseURL string
		want    string
	}{
		{
			"v1", "default", "pods", "a",
			"https://k8s.example.com",
			"https://k8s.example.com/api/v1/namespaces/default/pods/a",
		},
		{
			"foo/v1", "default", "bar", "a",
			"https://k8s.example.com",
			"https://k8s.example.com/apis/foo/v1/namespaces/default/bar/a",
		},
		{
			// trailing slash on the base URL
			"foo/v1", "default", "bar", "a",
			"https://k8s.example.com/",
			"https://k8s.example.com/apis/foo/v1/namespaces/default/bar/a",
		},
		{
			// no namespace
			"foo/v1", "", "bar", "a",
			"https://k8s.example.com",
			"https://k8s.example.com/apis/foo/v1/bar/a",
		},
	}

	for _, test := range tests {
		c := &client{baseURL: test.baseURL}
		got := c.urlFor(test.apiVersion, test.namespace, test.resource, test.name)
		if got != test.want {
			t.Errorf("(&client{baseURL:%q}).urlFor(%q, %q, %q, %q): expected %q got %q",
				test.baseURL,
				test.apiVersion, test.namespace, test.resource, test.name,
				test.want, got,
			)
		}
	}
}
|
|
|
|
// TestUpdateKeys drives client.UpdateKeys against a fake TLS API server. The
// server answers GET requests with getResponseCode and every other method
// with actionResponseCode. Each case declares whether an error is expected.
// When exactErr is non-nil, the returned error's message must also match it
// exactly.
func TestUpdateKeys(t *testing.T) {
	fakeUpdater := func(old storage.Keys) (storage.Keys, error) { return storage.Keys{}, nil }

	tests := []struct {
		name               string
		updater            func(old storage.Keys) (storage.Keys, error)
		getResponseCode    int
		actionResponseCode int
		wantErr            bool
		exactErr           error
	}{
		{
			"Create OK test",
			fakeUpdater,
			404,
			201,
			false,
			nil,
		},
		{
			"Update should be OK",
			fakeUpdater,
			200,
			200,
			false,
			nil,
		},
		{
			"Create conflict should be OK",
			fakeUpdater,
			404,
			409,
			true,
			errors.New("keys already created by another server instance"),
		},
		{
			"Update conflict should be OK",
			fakeUpdater,
			200,
			409,
			true,
			errors.New("keys already rotated by another server instance"),
		},
		{
			"Client error is error",
			fakeUpdater,
			404,
			500,
			true,
			nil,
		},
		{
			"Client error during update is error",
			fakeUpdater,
			200,
			500,
			true,
			nil,
		},
		{
			"Get error is error",
			fakeUpdater,
			500,
			200,
			true,
			nil,
		},
		{
			"Updater error is error",
			func(old storage.Keys) (storage.Keys, error) { return storage.Keys{}, fmt.Errorf("test") },
			200,
			201,
			true,
			nil,
		},
	}

	for _, test := range tests {
		// Subtests are not parallel, so capturing the loop variable is safe.
		t.Run(test.name, func(t *testing.T) {
			cli := newStatusCodesResponseTestClient(test.getResponseCode, test.actionResponseCode)

			err := cli.UpdateKeys(test.updater)
			if !test.wantErr {
				if err != nil {
					t.Fatalf("Test %q: %v", test.name, err)
				}
				return
			}

			// An error was expected, so a nil result is itself a failure.
			if err == nil {
				t.Fatalf("Test %q: expected an error, got nil", test.name)
			}
			if test.exactErr != nil && test.exactErr.Error() != err.Error() {
				t.Fatalf("Test %q: %v, wanted: %v", test.name, err, test.exactErr)
			}
		})
	}
}
|
|
|
|
// newStatusCodesResponseTestClient starts an httptest TLS server and returns a
// client pointed at it. The server replies to GET requests with
// getResponseCode and to every other method with actionResponseCode. Each
// reply has the body `{}`. The client skips TLS certificate verification so it
// accepts the server's self-signed certificate. It logs at debug level to
// stderr.
//
// NOTE(review): the server is never closed, so it lives until the test binary
// exits.
func newStatusCodesResponseTestClient(getResponseCode, actionResponseCode int) *client {
	handler := func(w http.ResponseWriter, r *http.Request) {
		status := actionResponseCode
		if r.Method == http.MethodGet {
			status = getResponseCode
		}
		w.WriteHeader(status)
		// Empty json is enough; only response codes are under test. The write
		// result is deliberately ignored in this fake.
		_, _ = w.Write([]byte(`{}`))
	}
	srv := httptest.NewTLSServer(http.HandlerFunc(handler))

	insecureTransport := &http.Transport{
		TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
	}
	debugLogger := &logrus.Logger{
		Out:       os.Stderr,
		Formatter: &logrus.TextFormatter{DisableColors: true},
		Level:     logrus.DebugLevel,
	}

	return &client{
		client:  &http.Client{Transport: insecureTransport},
		baseURL: srv.URL,
		logger:  debugLogger,
	}
}
|
|
|
|
// TestRetryOnConflict checks the error that retryOnConflict returns for
// several kinds of action:
//   - a 409 httpErr that never clears ends in the "maximum timeout reached"
//     error;
//   - a non-conflict httpErr and a plain error are returned as they are;
//   - a successful action yields nil.
//
// NOTE(review): the 409 case presumably waits out retryOnConflict's internal
// retry deadline; confirm this keeps the test fast enough.
func TestRetryOnConflict(t *testing.T) {
	cases := []struct {
		name     string
		action   func() error
		exactErr string
	}{
		{
			name:     "Timeout reached",
			action:   func() error { return &httpErr{status: 409} },
			exactErr: "maximum timeout reached while retrying a conflicted request: Conflict: response from server \"\"",
		},
		{
			name:     "HTTP Error",
			action:   func() error { return &httpErr{status: 500} },
			exactErr: " Internal Server Error: response from server \"\"",
		},
		{
			name:     "Error",
			action:   func() error { return errors.New("test") },
			exactErr: "test",
		},
		{
			name:     "OK",
			action:   func() error { return nil },
			exactErr: "",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := retryOnConflict(context.TODO(), tc.action)
			if tc.exactErr == "" {
				require.NoError(t, err)
				return
			}
			require.EqualError(t, err, tc.exactErr)
		})
	}
}
|