feat: Update token periodically if Dex is running in Kubernetes cluster

Signed-off-by: m.nabokikh <maksim.nabokikh@flant.com>
This commit is contained in:
m.nabokikh 2021-05-07 02:10:11 +04:00
parent 823484f024
commit d413870f6e
4 changed files with 197 additions and 43 deletions

View file

@ -303,7 +303,7 @@ func defaultTLSConfig() *tls.Config {
} }
} }
func newClient(cluster k8sapi.Cluster, user k8sapi.AuthInfo, namespace string, logger log.Logger) (*client, error) { func newClient(cluster k8sapi.Cluster, user k8sapi.AuthInfo, namespace string, logger log.Logger, inCluster bool) (*client, error) {
tlsConfig := defaultTLSConfig() tlsConfig := defaultTLSConfig()
data := func(b string, file string) ([]byte, error) { data := func(b string, file string) ([]byte, error) {
if b != "" { if b != "" {
@ -359,25 +359,7 @@ func newClient(cluster k8sapi.Cluster, user k8sapi.AuthInfo, namespace string, l
if err := http2.ConfigureTransport(httpTransport); err != nil { if err := http2.ConfigureTransport(httpTransport); err != nil {
return nil, err return nil, err
} }
t = httpTransport t = wrapRoundTripper(httpTransport, user, inCluster)
if user.Token != "" {
t = transport{
updateReq: func(r *http.Request) {
r.Header.Set("Authorization", "Bearer "+user.Token)
},
base: t,
}
}
if user.Username != "" && user.Password != "" {
t = transport{
updateReq: func(r *http.Request) {
r.SetBasicAuth(user.Username, user.Password)
},
base: t,
}
}
apiVersion := "dex.coreos.com/v1" apiVersion := "dex.coreos.com/v1"
@ -396,24 +378,6 @@ func newClient(cluster k8sapi.Cluster, user k8sapi.AuthInfo, namespace string, l
}, nil }, nil
} }
type transport struct {
updateReq func(r *http.Request)
base http.RoundTripper
}
func (t transport) RoundTrip(r *http.Request) (*http.Response, error) {
// shallow copy of the struct
r2 := new(http.Request)
*r2 = *r
// deep copy of the Header
r2.Header = make(http.Header, len(r.Header))
for k, s := range r.Header {
r2.Header[k] = append([]string(nil), s...)
}
t.updateReq(r2)
return t.base.RoundTrip(r2)
}
func loadKubeConfig(kubeConfigPath string) (cluster k8sapi.Cluster, user k8sapi.AuthInfo, namespace string, err error) { func loadKubeConfig(kubeConfigPath string) (cluster k8sapi.Cluster, user k8sapi.AuthInfo, namespace string, err error) {
data, err := ioutil.ReadFile(kubeConfigPath) data, err := ioutil.ReadFile(kubeConfigPath)
if err != nil { if err != nil {

View file

@ -3,12 +3,8 @@ package kubernetes
import ( import (
"hash" "hash"
"hash/fnv" "hash/fnv"
"io/ioutil"
"os"
"sync" "sync"
"testing" "testing"
"github.com/stretchr/testify/require"
) )
// This test does not have an explicit error condition but is used // This test does not have an explicit error condition but is used
@ -46,6 +42,81 @@ func TestOfflineTokenName(t *testing.T) {
} }
} }
func TestInClusterTransport(t *testing.T) {
logger := &logrus.Logger{
Out: os.Stderr,
Formatter: &logrus.TextFormatter{DisableColors: true},
Level: logrus.DebugLevel,
}
user := k8sapi.AuthInfo{Token: "abc"}
cli, err := newClient(
k8sapi.Cluster{},
user,
"test",
logger,
true,
)
require.NoError(t, err)
fpath := filepath.Join(os.TempDir(), "test.in_cluster")
defer os.RemoveAll(fpath)
err = ioutil.WriteFile(fpath, []byte("def"), 0644)
require.NoError(t, err)
tests := []struct {
name string
time func() time.Time
expected string
}{
{
name: "Stale token",
time: func() time.Time {
return time.Now().Add(-24 * time.Hour)
},
expected: "def",
},
{
name: "Normal token",
time: func() time.Time {
return time.Time{}
},
expected: "abc",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
helper := newInClusterTransportHelper(user)
helper.now = tc.time
helper.tokenLocation = fpath
cli.client.Transport = transport{
updateReq: func(r *http.Request) {
helper.UpdateToken()
r.Header.Set("Authorization", "Bearer "+helper.GetToken())
},
base: cli.client.Transport,
}
_ = cli.isCRDReady("test")
require.Equal(t, tc.expected, helper.info.Token)
})
}
}
func TestNamespaceFromServiceAccountJWT(t *testing.T) {
namespace, err := namespaceFromServiceAccountJWT(serviceAccountToken)
if err != nil {
t.Fatal(err)
}
wantNamespace := "dex-test-namespace"
if namespace != wantNamespace {
t.Errorf("expected namespace %q got %q", wantNamespace, namespace)
}
}
func TestGetClusterConfigNamespace(t *testing.T) { func TestGetClusterConfigNamespace(t *testing.T) {
const namespaceENVVariableName = "TEST_GET_CLUSTER_CONFIG_NAMESPACE" const namespaceENVVariableName = "TEST_GET_CLUSTER_CONFIG_NAMESPACE"
{ {

View file

@ -83,7 +83,7 @@ func (c *Config) open(logger log.Logger, waitForResources bool) (*client, error)
return nil, err return nil, err
} }
cli, err := newClient(cluster, user, namespace, logger) cli, err := newClient(cluster, user, namespace, logger, c.InCluster)
if err != nil { if err != nil {
return nil, fmt.Errorf("create client: %v", err) return nil, fmt.Errorf("create client: %v", err)
} }

View file

@ -0,0 +1,119 @@
package kubernetes
import (
"io/ioutil"
"net/http"
"sync"
"time"
"github.com/dexidp/dex/storage/kubernetes/k8sapi"
)
// transport is a simple http.Transport wrapper
type transport struct {
updateReq func(r *http.Request)
base http.RoundTripper
}
func (t transport) RoundTrip(r *http.Request) (*http.Response, error) {
// shallow copy of the struct
r2 := new(http.Request)
*r2 = *r
// deep copy of the Header
r2.Header = make(http.Header, len(r.Header))
for k, s := range r.Header {
r2.Header[k] = append([]string(nil), s...)
}
t.updateReq(r2)
return t.base.RoundTrip(r2)
}
func wrapRoundTripper(base http.RoundTripper, user k8sapi.AuthInfo, inCluster bool) http.RoundTripper {
if inCluster {
inClusterTransportHelper := newInClusterTransportHelper(user)
return transport{
updateReq: func(r *http.Request) {
inClusterTransportHelper.UpdateToken()
r.Header.Set("Authorization", "Bearer "+inClusterTransportHelper.GetToken())
},
base: base,
}
}
if user.Token != "" {
return transport{
updateReq: func(r *http.Request) {
r.Header.Set("Authorization", "Bearer "+user.Token)
},
base: base,
}
}
if user.Username != "" && user.Password != "" {
return transport{
updateReq: func(r *http.Request) {
r.SetBasicAuth(user.Username, user.Password)
},
base: base,
}
}
return base
}
// renewTokenPeriod is the interval after which dex will read the token from a well-known file.
// By Kubernetes documentation, this interval should be at least one minute long.
// Kubernetes client-go v0.15+ uses 10 seconds long interval.
// Dex uses the reasonable value between these two.
const renewTokenPeriod = 30 * time.Second
// inClusterTransportHelper is capable of safely updating the user token.
// BoundServiceAccountTokenVolume feature is enabled in Kubernetes >=1.21 by default.
// With this feature, the service account token in the pod becomes periodically updated.
// Therefore, Dex needs to re-read the token from the disk after some time to be sure that it uses the valid token.
type inClusterTransportHelper struct {
mu sync.RWMutex
info k8sapi.AuthInfo
expiry time.Time
now func() time.Time
tokenLocation string
}
func newInClusterTransportHelper(info k8sapi.AuthInfo) *inClusterTransportHelper {
user := inClusterTransportHelper{
info: info,
now: time.Now,
tokenLocation: "/var/run/secrets/kubernetes.io/serviceaccount/token",
}
user.UpdateToken()
return &user
}
func (c *inClusterTransportHelper) UpdateToken() {
c.mu.RLock()
exp := c.expiry
c.mu.RUnlock()
if !c.now().After(exp) {
// Do not need to update token yet
return
}
token, err := ioutil.ReadFile(c.tokenLocation)
if err != nil {
return
}
c.mu.Lock()
defer c.mu.Unlock()
c.info.Token = string(token)
c.expiry = c.now().Add(renewTokenPeriod)
}
func (c *inClusterTransportHelper) GetToken() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.info.Token
}