forked from mystiq/dex
When constructing the host address string, the address was not wrapped in square brackets, which does not work in IPv6 Kubernetes deployments. This commit adds square brackets around the address. IPv4 was also tested to ensure it works with a wrapped address. Signed-off-by: Jerry Sun <>
486 lines
13 KiB
486 lines
13 KiB
package kubernetes
import (
type client struct {
client *http.Client
baseURL string
namespace string
logger log.Logger
// Hash function to map IDs (which could span a large range) to Kubernetes names.
// While this is not currently upgradable, it could be in the future.
// The default hash is a non-cryptographic hash, because cryptographic hashes
// always produce sums too long to fit into a Kubernetes name. Because of this,
// gets, updates, and deletes are _always_ checked for collisions.
hash func() hash.Hash
// API version of the oidc resources. For example "". This is
// currently not configurable, but could be in the future.
apiVersion string
// This is called once the client's Close method is called to signal goroutines,
// such as the one creating third party resources, to stop.
cancel context.CancelFunc
// idToName maps an arbitrary ID, such as an email or client ID to a Kubernetes object name.
func (cli *client) idToName(s string) string {
return idToName(s, cli.hash)
// offlineTokenName maps two arbitrary IDs, to a single Kubernetes object name.
// This is used when more than one field is used to uniquely identify the object.
func (cli *client) offlineTokenName(userID string, connID string) string {
return offlineTokenName(userID, connID, cli.hash)
// Kubernetes names must match the regexp '[a-z0-9]([-a-z0-9]*[a-z0-9])?'.
// This lowercase base32 alphabet only produces characters from that set; the
// "=" padding it emits is trimmed by the functions below.
var encoding = base32.NewEncoding("abcdefghijklmnopqrstuvwxyz234567")

// idToName maps an arbitrary ID to a Kubernetes object name.
//
// NOTE: h().Sum([]byte(s)) appends the digest of an *empty* hash to s rather
// than hashing s, so the name grows with len(s). This is left unchanged
// because changing the mapping would presumably change the names of objects
// that are already stored.
func idToName(s string, h func() hash.Hash) string {
	return strings.TrimRight(encoding.EncodeToString(h().Sum([]byte(s))), "=")
}

// offlineTokenName maps a (userID, connID) pair to a single Kubernetes object
// name by hashing userID followed by connID with a fresh h().
//
// The two IDs are written back to back without a separator, so pairs such as
// ("ab", "c") and ("a", "bc") share a name; like any hash-derived name, callers
// must tolerate collisions.
func offlineTokenName(userID string, connID string, h func() hash.Hash) string {
	digest := h()
	// hash.Hash.Write never returns an error.
	digest.Write([]byte(userID))
	digest.Write([]byte(connID))
	return strings.TrimRight(encoding.EncodeToString(digest.Sum(nil)), "=")
}
func (cli *client) urlFor(apiVersion, namespace, resource, name string) string {
basePath := "apis/"
if apiVersion == "v1" {
basePath = "api/"
var p string
if namespace != "" {
p = path.Join(basePath, apiVersion, "namespaces", namespace, resource, name)
} else {
p = path.Join(basePath, apiVersion, resource, name)
if strings.HasSuffix(cli.baseURL, "/") {
return cli.baseURL + p
return cli.baseURL + "/" + p
// Define an error interface so we can get at the underlying status code if it's
// absolutely necessary. For instance when we need to see if an error indicates
// a resource already exists.
type httpError interface {
StatusCode() int
var _ httpError = (*httpErr)(nil)
type httpErr struct {
method string
url string
status int
body []byte
func (e *httpErr) StatusCode() int {
return e.status
func (e *httpErr) Error() string {
return fmt.Sprintf("%s %s %s: response from server \"%s\"", e.method, e.url, http.StatusText(e.status), bytes.TrimSpace(e.body))
func checkHTTPErr(r *http.Response, validStatusCodes error {
for _, status := range validStatusCodes {
if r.StatusCode == status {
return nil
body, err := ioutil.ReadAll(io.LimitReader(r.Body, 2<<15)) // 64 KiB
if err != nil {
return fmt.Errorf("read response body: %v", err)
// Check this case after we read the body so the connection can be reused.
if r.StatusCode == http.StatusNotFound {
return storage.ErrNotFound
if r.Request.Method == http.MethodPost && r.StatusCode == http.StatusConflict {
return storage.ErrAlreadyExists
var url, method string
if r.Request != nil {
method = r.Request.Method
url = r.Request.URL.String()
return &httpErr{method, url, r.StatusCode, body}
// closeResp drains and then closes the response body. Draining it first lets
// the HTTP transport reuse the underlying connection.
func closeResp(r *http.Response) {
	// Errors are deliberately ignored: this runs from defers after the
	// response has been handled, and there is nothing useful left to do.
	io.Copy(ioutil.Discard, r.Body)
	r.Body.Close()
}
func (cli *client) get(resource, name string, v interface{}) error {
return cli.getResource(cli.apiVersion, cli.namespace, resource, name, v)
func (cli *client) getResource(apiVersion, namespace, resource, name string, v interface{}) error {
url := cli.urlFor(apiVersion, namespace, resource, name)
resp, err := cli.client.Get(url)
if err != nil {
return err
defer closeResp(resp)
if err := checkHTTPErr(resp, http.StatusOK); err != nil {
return err
return json.NewDecoder(resp.Body).Decode(v)
func (cli *client) list(resource string, v interface{}) error {
return cli.get(resource, "", v)
func (cli *client) post(resource string, v interface{}) error {
return cli.postResource(cli.apiVersion, cli.namespace, resource, v)
func (cli *client) postResource(apiVersion, namespace, resource string, v interface{}) error {
body, err := json.Marshal(v)
if err != nil {
return fmt.Errorf("marshal object: %v", err)
url := cli.urlFor(apiVersion, namespace, resource, "")
resp, err := cli.client.Post(url, "application/json", bytes.NewReader(body))
if err != nil {
return err
defer closeResp(resp)
return checkHTTPErr(resp, http.StatusCreated)
func (cli *client) delete(resource, name string) error {
url := cli.urlFor(cli.apiVersion, cli.namespace, resource, name)
req, err := http.NewRequest("DELETE", url, nil)
if err != nil {
return fmt.Errorf("create delete request: %v", err)
resp, err := cli.client.Do(req)
if err != nil {
return fmt.Errorf("delete request: %v", err)
defer closeResp(resp)
return checkHTTPErr(resp, http.StatusOK)
func (cli *client) deleteAll(resource string) error {
var list struct {
k8sapi.TypeMeta `json:",inline"`
k8sapi.ListMeta `json:"metadata,omitempty"`
Items []struct {
k8sapi.TypeMeta `json:",inline"`
k8sapi.ObjectMeta `json:"metadata,omitempty"`
} `json:"items"`
if err := cli.list(resource, &list); err != nil {
return err
for _, item := range list.Items {
if err := cli.delete(resource, item.Name); err != nil {
return err
return nil
func (cli *client) put(resource, name string, v interface{}) error {
body, err := json.Marshal(v)
if err != nil {
return fmt.Errorf("marshal object: %v", err)
url := cli.urlFor(cli.apiVersion, cli.namespace, resource, name)
req, err := http.NewRequest("PUT", url, bytes.NewReader(body))
if err != nil {
return fmt.Errorf("create patch request: %v", err)
req.Header.Set("Content-Length", strconv.Itoa(len(body)))
resp, err := cli.client.Do(req)
if err != nil {
return fmt.Errorf("patch request: %v", err)
defer closeResp(resp)
return checkHTTPErr(resp, http.StatusOK)
func newClient(cluster k8sapi.Cluster, user k8sapi.AuthInfo, namespace string, logger log.Logger) (*client, error) {
tlsConfig := cryptopasta.DefaultTLSConfig()
data := func(b string, file string) ([]byte, error) {
if b != "" {
return base64.StdEncoding.DecodeString(b)
if file == "" {
return nil, nil
return ioutil.ReadFile(file)
if caData, err := data(cluster.CertificateAuthorityData, cluster.CertificateAuthority); err != nil {
return nil, err
} else if caData != nil {
tlsConfig.RootCAs = x509.NewCertPool()
if !tlsConfig.RootCAs.AppendCertsFromPEM(caData) {
return nil, fmt.Errorf("no certificate data found: %v", err)
clientCert, err := data(user.ClientCertificateData, user.ClientCertificate)
if err != nil {
return nil, err
clientKey, err := data(user.ClientKeyData, user.ClientKey)
if err != nil {
return nil, err
if clientCert != nil && clientKey != nil {
cert, err := tls.X509KeyPair(clientCert, clientKey)
if err != nil {
return nil, fmt.Errorf("failed to load client cert: %v", err)
tlsConfig.Certificates = []tls.Certificate{cert}
var t http.RoundTripper
httpTransport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
TLSClientConfig: tlsConfig,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
// Since we set a custom TLS client config we have to explicitly
// enable HTTP/2.
if err := http2.ConfigureTransport(httpTransport); err != nil {
return nil, err
t = httpTransport
if user.Token != "" {
t = transport{
updateReq: func(r *http.Request) {
r.Header.Set("Authorization", "Bearer "+user.Token)
base: t,
if user.Username != "" && user.Password != "" {
t = transport{
updateReq: func(r *http.Request) {
r.SetBasicAuth(user.Username, user.Password)
base: t,
apiVersion := ""
logger.Infof("kubernetes client apiVersion = %s", apiVersion)
return &client{
client: &http.Client{
Transport: t,
Timeout: 15 * time.Second,
baseURL: cluster.Server,
hash: func() hash.Hash { return fnv.New64() },
namespace: namespace,
apiVersion: apiVersion,
logger: logger,
}, nil
// transport is an http.RoundTripper that lets updateReq modify (for example,
// add credentials to) a copy of every outgoing request before handing it to
// base.
type transport struct {
	updateReq func(r *http.Request)
	base      http.RoundTripper
}

// RoundTrip implements http.RoundTripper. The caller's request is never
// modified: updateReq works on a copy whose header map is cloned, because
// the RoundTripper contract forbids mutating the original request.
func (t transport) RoundTrip(r *http.Request) (*http.Response, error) {
	// Shallow copy of the struct.
	r2 := new(http.Request)
	*r2 = *r
	// Deep copy of the header, so updateReq cannot mutate r.Header.
	r2.Header = make(http.Header, len(r.Header))
	for k, s := range r.Header {
		r2.Header[k] = append([]string(nil), s...)
	}
	if t.updateReq != nil {
		t.updateReq(r2)
	}
	return t.base.RoundTrip(r2)
}
func loadKubeConfig(kubeConfigPath string) (cluster k8sapi.Cluster, user k8sapi.AuthInfo, namespace string, err error) {
data, err := ioutil.ReadFile(kubeConfigPath)
if err != nil {
err = fmt.Errorf("read %s: %v", kubeConfigPath, err)
var c k8sapi.Config
if err = yaml.Unmarshal(data, &c); err != nil {
err = fmt.Errorf("unmarshal %s: %v", kubeConfigPath, err)
cluster, user, namespace, err = currentContext(&c)
if namespace == "" {
namespace = "default"
// namespaceFromServiceAccountJWT returns the namespace recorded in a
// Kubernetes service account token. The token's JWT payload is only decoded,
// never verified, and is used solely to read the namespace claim.
//
// Two layouts are recognized:
//   - legacy secret-based tokens carry a top-level
//     "kubernetes.io/serviceaccount/namespace" claim;
//   - bound (projected) tokens nest it as {"kubernetes.io": {"namespace": ...}}.
func namespaceFromServiceAccountJWT(s string) (string, error) {
	// The service account token is just a JWT. Parse it as such.
	parts := strings.Split(s, ".")
	if len(parts) < 2 {
		// It's extremely important we don't log the actual service account token.
		return "", fmt.Errorf("malformed service account token: expected 3 parts got %d", len(parts))
	}
	payload, err := base64.RawURLEncoding.DecodeString(parts[1])
	if err != nil {
		return "", fmt.Errorf("malformed service account token: %v", err)
	}

	var data struct {
		// Claim set by legacy (secret-based) service account tokens.
		Namespace string `json:"kubernetes.io/serviceaccount/namespace"`
		// Claim set by bound (projected) service account tokens.
		Kubernetes struct {
			Namespace string `json:"namespace"`
		} `json:"kubernetes.io"`
	}
	if err := json.Unmarshal(payload, &data); err != nil {
		return "", fmt.Errorf("malformed service account token: %v", err)
	}

	switch {
	case data.Namespace != "":
		return data.Namespace, nil
	case data.Kubernetes.Namespace != "":
		return data.Kubernetes.Namespace, nil
	}
	return "", errors.New(`jwt claim "kubernetes.io/serviceaccount/namespace" not found`)
}
func inClusterConfig() (cluster k8sapi.Cluster, user k8sapi.AuthInfo, namespace string, err error) {
host, port := os.Getenv("KUBERNETES_SERVICE_HOST"), os.Getenv("KUBERNETES_SERVICE_PORT")
if len(host) == 0 || len(port) == 0 {
err = fmt.Errorf("unable to load in-cluster configuration, KUBERNETES_SERVICE_HOST and KUBERNETES_SERVICE_PORT must be defined")
// we need to wrap IPv6 addresses in square brackets
// IPv4 also works with square brackets
host = "[" + host + "]"
cluster = k8sapi.Cluster{
Server: "https://" + host + ":" + port,
CertificateAuthority: "/var/run/secrets/",
token, err := ioutil.ReadFile("/var/run/secrets/")
if err != nil {
user = k8sapi.AuthInfo{Token: string(token)}
if namespace = os.Getenv("KUBERNETES_POD_NAMESPACE"); namespace == "" {
namespace, err = namespaceFromServiceAccountJWT(user.Token)
if err != nil {
err = fmt.Errorf("failed to inspect service account token: %v", err)
func currentContext(config *k8sapi.Config) (cluster k8sapi.Cluster, user k8sapi.AuthInfo, ns string, err error) {
if config.CurrentContext == "" {
if len(config.Contexts) == 1 {
config.CurrentContext = config.Contexts[0].Name
} else {
return cluster, user, "", errors.New("kubeconfig has no current context")
context, ok := func() (k8sapi.Context, bool) {
for _, namedContext := range config.Contexts {
if namedContext.Name == config.CurrentContext {
return namedContext.Context, true
return k8sapi.Context{}, false
if !ok {
return cluster, user, "", fmt.Errorf("no context named %q found", config.CurrentContext)
cluster, ok = func() (k8sapi.Cluster, bool) {
for _, namedCluster := range config.Clusters {
if namedCluster.Name == context.Cluster {
return namedCluster.Cluster, true
return k8sapi.Cluster{}, false
if !ok {
return cluster, user, "", fmt.Errorf("no cluster named %q found", context.Cluster)
user, ok = func() (k8sapi.AuthInfo, bool) {
for _, namedAuthInfo := range config.AuthInfos {
if namedAuthInfo.Name == context.AuthInfo {
return namedAuthInfo.AuthInfo, true
return k8sapi.AuthInfo{}, false
if !ok {
return cluster, user, "", fmt.Errorf("no user named %q found", context.AuthInfo)
return cluster, user, context.Namespace, nil