forked from mystiq/dex
feat: Update token periodically if Dex is running in Kubernetes cluster
Signed-off-by: m.nabokikh <maksim.nabokikh@flant.com>
This commit is contained in:
parent
823484f024
commit
d413870f6e
4 changed files with 197 additions and 43 deletions
|
@ -303,7 +303,7 @@ func defaultTLSConfig() *tls.Config {
|
|||
}
|
||||
}
|
||||
|
||||
func newClient(cluster k8sapi.Cluster, user k8sapi.AuthInfo, namespace string, logger log.Logger) (*client, error) {
|
||||
func newClient(cluster k8sapi.Cluster, user k8sapi.AuthInfo, namespace string, logger log.Logger, inCluster bool) (*client, error) {
|
||||
tlsConfig := defaultTLSConfig()
|
||||
data := func(b string, file string) ([]byte, error) {
|
||||
if b != "" {
|
||||
|
@ -359,25 +359,7 @@ func newClient(cluster k8sapi.Cluster, user k8sapi.AuthInfo, namespace string, l
|
|||
if err := http2.ConfigureTransport(httpTransport); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t = httpTransport
|
||||
|
||||
if user.Token != "" {
|
||||
t = transport{
|
||||
updateReq: func(r *http.Request) {
|
||||
r.Header.Set("Authorization", "Bearer "+user.Token)
|
||||
},
|
||||
base: t,
|
||||
}
|
||||
}
|
||||
|
||||
if user.Username != "" && user.Password != "" {
|
||||
t = transport{
|
||||
updateReq: func(r *http.Request) {
|
||||
r.SetBasicAuth(user.Username, user.Password)
|
||||
},
|
||||
base: t,
|
||||
}
|
||||
}
|
||||
t = wrapRoundTripper(httpTransport, user, inCluster)
|
||||
|
||||
apiVersion := "dex.coreos.com/v1"
|
||||
|
||||
|
@ -396,24 +378,6 @@ func newClient(cluster k8sapi.Cluster, user k8sapi.AuthInfo, namespace string, l
|
|||
}, nil
|
||||
}
|
||||
|
||||
type transport struct {
|
||||
updateReq func(r *http.Request)
|
||||
base http.RoundTripper
|
||||
}
|
||||
|
||||
func (t transport) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
// shallow copy of the struct
|
||||
r2 := new(http.Request)
|
||||
*r2 = *r
|
||||
// deep copy of the Header
|
||||
r2.Header = make(http.Header, len(r.Header))
|
||||
for k, s := range r.Header {
|
||||
r2.Header[k] = append([]string(nil), s...)
|
||||
}
|
||||
t.updateReq(r2)
|
||||
return t.base.RoundTrip(r2)
|
||||
}
|
||||
|
||||
func loadKubeConfig(kubeConfigPath string) (cluster k8sapi.Cluster, user k8sapi.AuthInfo, namespace string, err error) {
|
||||
data, err := ioutil.ReadFile(kubeConfigPath)
|
||||
if err != nil {
|
||||
|
|
|
@ -3,12 +3,8 @@ package kubernetes
|
|||
import (
|
||||
"hash"
|
||||
"hash/fnv"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// This test does not have an explicit error condition but is used
|
||||
|
@ -46,6 +42,81 @@ func TestOfflineTokenName(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestInClusterTransport(t *testing.T) {
|
||||
logger := &logrus.Logger{
|
||||
Out: os.Stderr,
|
||||
Formatter: &logrus.TextFormatter{DisableColors: true},
|
||||
Level: logrus.DebugLevel,
|
||||
}
|
||||
|
||||
user := k8sapi.AuthInfo{Token: "abc"}
|
||||
cli, err := newClient(
|
||||
k8sapi.Cluster{},
|
||||
user,
|
||||
"test",
|
||||
logger,
|
||||
true,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
fpath := filepath.Join(os.TempDir(), "test.in_cluster")
|
||||
defer os.RemoveAll(fpath)
|
||||
|
||||
err = ioutil.WriteFile(fpath, []byte("def"), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
time func() time.Time
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Stale token",
|
||||
time: func() time.Time {
|
||||
return time.Now().Add(-24 * time.Hour)
|
||||
},
|
||||
expected: "def",
|
||||
},
|
||||
{
|
||||
name: "Normal token",
|
||||
time: func() time.Time {
|
||||
return time.Time{}
|
||||
},
|
||||
expected: "abc",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
helper := newInClusterTransportHelper(user)
|
||||
helper.now = tc.time
|
||||
helper.tokenLocation = fpath
|
||||
|
||||
cli.client.Transport = transport{
|
||||
updateReq: func(r *http.Request) {
|
||||
helper.UpdateToken()
|
||||
r.Header.Set("Authorization", "Bearer "+helper.GetToken())
|
||||
},
|
||||
base: cli.client.Transport,
|
||||
}
|
||||
|
||||
_ = cli.isCRDReady("test")
|
||||
require.Equal(t, tc.expected, helper.info.Token)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNamespaceFromServiceAccountJWT(t *testing.T) {
|
||||
namespace, err := namespaceFromServiceAccountJWT(serviceAccountToken)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
wantNamespace := "dex-test-namespace"
|
||||
if namespace != wantNamespace {
|
||||
t.Errorf("expected namespace %q got %q", wantNamespace, namespace)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClusterConfigNamespace(t *testing.T) {
|
||||
const namespaceENVVariableName = "TEST_GET_CLUSTER_CONFIG_NAMESPACE"
|
||||
{
|
||||
|
|
|
@ -83,7 +83,7 @@ func (c *Config) open(logger log.Logger, waitForResources bool) (*client, error)
|
|||
return nil, err
|
||||
}
|
||||
|
||||
cli, err := newClient(cluster, user, namespace, logger)
|
||||
cli, err := newClient(cluster, user, namespace, logger, c.InCluster)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create client: %v", err)
|
||||
}
|
||||
|
|
119
storage/kubernetes/transport.go
Normal file
119
storage/kubernetes/transport.go
Normal file
|
@ -0,0 +1,119 @@
|
|||
package kubernetes
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/dexidp/dex/storage/kubernetes/k8sapi"
|
||||
)
|
||||
|
||||
// transport is a simple http.Transport wrapper
|
||||
type transport struct {
|
||||
updateReq func(r *http.Request)
|
||||
base http.RoundTripper
|
||||
}
|
||||
|
||||
func (t transport) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
// shallow copy of the struct
|
||||
r2 := new(http.Request)
|
||||
*r2 = *r
|
||||
// deep copy of the Header
|
||||
r2.Header = make(http.Header, len(r.Header))
|
||||
for k, s := range r.Header {
|
||||
r2.Header[k] = append([]string(nil), s...)
|
||||
}
|
||||
t.updateReq(r2)
|
||||
return t.base.RoundTrip(r2)
|
||||
}
|
||||
|
||||
func wrapRoundTripper(base http.RoundTripper, user k8sapi.AuthInfo, inCluster bool) http.RoundTripper {
|
||||
if inCluster {
|
||||
inClusterTransportHelper := newInClusterTransportHelper(user)
|
||||
return transport{
|
||||
updateReq: func(r *http.Request) {
|
||||
inClusterTransportHelper.UpdateToken()
|
||||
r.Header.Set("Authorization", "Bearer "+inClusterTransportHelper.GetToken())
|
||||
},
|
||||
base: base,
|
||||
}
|
||||
}
|
||||
|
||||
if user.Token != "" {
|
||||
return transport{
|
||||
updateReq: func(r *http.Request) {
|
||||
r.Header.Set("Authorization", "Bearer "+user.Token)
|
||||
},
|
||||
base: base,
|
||||
}
|
||||
}
|
||||
|
||||
if user.Username != "" && user.Password != "" {
|
||||
return transport{
|
||||
updateReq: func(r *http.Request) {
|
||||
r.SetBasicAuth(user.Username, user.Password)
|
||||
},
|
||||
base: base,
|
||||
}
|
||||
}
|
||||
|
||||
return base
|
||||
}
|
||||
|
||||
// renewTokenPeriod is the interval after which dex will read the token from a well-known file.
|
||||
// By Kubernetes documentation, this interval should be at least one minute long.
|
||||
// Kubernetes client-go v0.15+ uses 10 seconds long interval.
|
||||
// Dex uses the reasonable value between these two.
|
||||
const renewTokenPeriod = 30 * time.Second
|
||||
|
||||
// inClusterTransportHelper is capable of safely updating the user token.
|
||||
// BoundServiceAccountTokenVolume feature is enabled in Kubernetes >=1.21 by default.
|
||||
// With this feature, the service account token in the pod becomes periodically updated.
|
||||
// Therefore, Dex needs to re-read the token from the disk after some time to be sure that it uses the valid token.
|
||||
type inClusterTransportHelper struct {
|
||||
mu sync.RWMutex
|
||||
info k8sapi.AuthInfo
|
||||
|
||||
expiry time.Time
|
||||
now func() time.Time
|
||||
|
||||
tokenLocation string
|
||||
}
|
||||
|
||||
func newInClusterTransportHelper(info k8sapi.AuthInfo) *inClusterTransportHelper {
|
||||
user := inClusterTransportHelper{
|
||||
info: info,
|
||||
now: time.Now,
|
||||
tokenLocation: "/var/run/secrets/kubernetes.io/serviceaccount/token",
|
||||
}
|
||||
user.UpdateToken()
|
||||
return &user
|
||||
}
|
||||
|
||||
func (c *inClusterTransportHelper) UpdateToken() {
|
||||
c.mu.RLock()
|
||||
exp := c.expiry
|
||||
c.mu.RUnlock()
|
||||
|
||||
if !c.now().After(exp) {
|
||||
// Do not need to update token yet
|
||||
return
|
||||
}
|
||||
|
||||
token, err := ioutil.ReadFile(c.tokenLocation)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.info.Token = string(token)
|
||||
c.expiry = c.now().Add(renewTokenPeriod)
|
||||
}
|
||||
|
||||
func (c *inClusterTransportHelper) GetToken() string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.info.Token
|
||||
}
|
Loading…
Reference in a new issue