forked from mystiq/dex
3a3a2bcc86
When constructing the host address string, the address is not wrapped in square brackets. This does not work in IPv6 Kubernetes deployments. This commit adds square brackets around the address. IPv4 was also tested to ensure it works with wrapped address. Signed-off-by: Jerry Sun <jerry.sun@windriver.com>
486 lines
13 KiB
Go
486 lines
13 KiB
Go
package kubernetes
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"encoding/base32"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"hash"
|
|
"hash/fnv"
|
|
"io"
|
|
"io/ioutil"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"path"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/ghodss/yaml"
|
|
"github.com/gtank/cryptopasta"
|
|
"golang.org/x/net/http2"
|
|
|
|
"github.com/dexidp/dex/pkg/log"
|
|
"github.com/dexidp/dex/storage"
|
|
"github.com/dexidp/dex/storage/kubernetes/k8sapi"
|
|
)
|
|
|
|
type client struct {
|
|
client *http.Client
|
|
baseURL string
|
|
namespace string
|
|
logger log.Logger
|
|
|
|
// Hash function to map IDs (which could span a large range) to Kubernetes names.
|
|
// While this is not currently upgradable, it could be in the future.
|
|
//
|
|
// The default hash is a non-cryptographic hash, because cryptographic hashes
|
|
// always produce sums too long to fit into a Kubernetes name. Because of this,
|
|
// gets, updates, and deletes are _always_ checked for collisions.
|
|
hash func() hash.Hash
|
|
|
|
// API version of the oidc resources. For example "oidc.coreos.com". This is
|
|
// currently not configurable, but could be in the future.
|
|
apiVersion string
|
|
|
|
// This is called once the client's Close method is called to signal goroutines,
|
|
// such as the one creating third party resources, to stop.
|
|
cancel context.CancelFunc
|
|
}
|
|
|
|
// idToName maps an arbitrary ID, such as an email or client ID to a Kubernetes object name.
|
|
func (cli *client) idToName(s string) string {
|
|
return idToName(s, cli.hash)
|
|
}
|
|
|
|
// offlineTokenName maps two arbitrary IDs, to a single Kubernetes object name.
|
|
// This is used when more than one field is used to uniquely identify the object.
|
|
func (cli *client) offlineTokenName(userID string, connID string) string {
|
|
return offlineTokenName(userID, connID, cli.hash)
|
|
}
|
|
|
|
// Kubernetes names must match the regexp '[a-z0-9]([-a-z0-9]*[a-z0-9])?'.
|
|
var encoding = base32.NewEncoding("abcdefghijklmnopqrstuvwxyz234567")
|
|
|
|
func idToName(s string, h func() hash.Hash) string {
|
|
return strings.TrimRight(encoding.EncodeToString(h().Sum([]byte(s))), "=")
|
|
}
|
|
|
|
func offlineTokenName(userID string, connID string, h func() hash.Hash) string {
|
|
hash := h()
|
|
hash.Write([]byte(userID))
|
|
hash.Write([]byte(connID))
|
|
return strings.TrimRight(encoding.EncodeToString(hash.Sum(nil)), "=")
|
|
}
|
|
|
|
func (cli *client) urlFor(apiVersion, namespace, resource, name string) string {
|
|
basePath := "apis/"
|
|
if apiVersion == "v1" {
|
|
basePath = "api/"
|
|
}
|
|
|
|
var p string
|
|
if namespace != "" {
|
|
p = path.Join(basePath, apiVersion, "namespaces", namespace, resource, name)
|
|
} else {
|
|
p = path.Join(basePath, apiVersion, resource, name)
|
|
}
|
|
if strings.HasSuffix(cli.baseURL, "/") {
|
|
return cli.baseURL + p
|
|
}
|
|
return cli.baseURL + "/" + p
|
|
}
|
|
|
|
// Define an error interface so we can get at the underlying status code if it's
|
|
// absolutely necessary. For instance when we need to see if an error indicates
|
|
// a resource already exists.
|
|
type httpError interface {
|
|
StatusCode() int
|
|
}
|
|
|
|
var _ httpError = (*httpErr)(nil)
|
|
|
|
type httpErr struct {
|
|
method string
|
|
url string
|
|
status int
|
|
body []byte
|
|
}
|
|
|
|
func (e *httpErr) StatusCode() int {
|
|
return e.status
|
|
}
|
|
|
|
func (e *httpErr) Error() string {
|
|
return fmt.Sprintf("%s %s %s: response from server \"%s\"", e.method, e.url, http.StatusText(e.status), bytes.TrimSpace(e.body))
|
|
}
|
|
|
|
// checkHTTPErr returns nil when r's status code is one of validStatusCodes.
// Otherwise it reads up to 64 KiB of the body and reports:
//   - storage.ErrNotFound for 404,
//   - storage.ErrAlreadyExists for 409 in response to a POST,
//   - an *httpErr carrying method, URL, status and body for anything else.
//
// r.Request may be nil (e.g. for responses not produced by an http.Client);
// in that case method and URL are simply left empty instead of panicking.
func checkHTTPErr(r *http.Response, validStatusCodes ...int) error {
	for _, status := range validStatusCodes {
		if r.StatusCode == status {
			return nil
		}
	}

	body, err := ioutil.ReadAll(io.LimitReader(r.Body, 2<<15)) // 64 KiB
	if err != nil {
		return fmt.Errorf("read response body: %v", err)
	}

	// Resolve request details once, guarding against a nil Request before
	// any field of it is touched.
	var url, method string
	if r.Request != nil {
		method = r.Request.Method
		if r.Request.URL != nil {
			url = r.Request.URL.String()
		}
	}

	// Check these cases after we read the body so the connection can be reused.
	if r.StatusCode == http.StatusNotFound {
		return storage.ErrNotFound
	}
	if method == http.MethodPost && r.StatusCode == http.StatusConflict {
		return storage.ErrAlreadyExists
	}

	return &httpErr{method, url, r.StatusCode, body}
}
|
|
|
|
// closeResp drains whatever is left of the response body and then closes it,
// so the underlying connection can be reused. Both steps are best effort:
// their errors are intentionally ignored.
func closeResp(r *http.Response) {
	body := r.Body
	defer body.Close()
	_, _ = io.Copy(ioutil.Discard, body)
}
|
|
|
|
func (cli *client) get(resource, name string, v interface{}) error {
|
|
return cli.getResource(cli.apiVersion, cli.namespace, resource, name, v)
|
|
}
|
|
|
|
// getResource GETs a resource and JSON-decodes the body into v. Transport
// errors are returned as-is; a non-200 status is turned into an error by
// checkHTTPErr (404 becomes storage.ErrNotFound).
func (cli *client) getResource(apiVersion, namespace, resource, name string, v interface{}) error {
	resp, err := cli.client.Get(cli.urlFor(apiVersion, namespace, resource, name))
	if err != nil {
		return err
	}
	defer closeResp(resp)

	if err = checkHTTPErr(resp, http.StatusOK); err != nil {
		return err
	}
	return json.NewDecoder(resp.Body).Decode(v)
}
|
|
|
|
func (cli *client) list(resource string, v interface{}) error {
|
|
return cli.get(resource, "", v)
|
|
}
|
|
|
|
func (cli *client) post(resource string, v interface{}) error {
|
|
return cli.postResource(cli.apiVersion, cli.namespace, resource, v)
|
|
}
|
|
|
|
// postResource JSON-encodes v and POSTs it to the resource's collection.
// Only 201 Created counts as success; a 409 is reported as
// storage.ErrAlreadyExists by checkHTTPErr.
func (cli *client) postResource(apiVersion, namespace, resource string, v interface{}) error {
	payload, err := json.Marshal(v)
	if err != nil {
		return fmt.Errorf("marshal object: %v", err)
	}

	endpoint := cli.urlFor(apiVersion, namespace, resource, "")
	resp, err := cli.client.Post(endpoint, "application/json", bytes.NewReader(payload))
	if err != nil {
		return err
	}
	defer closeResp(resp)

	return checkHTTPErr(resp, http.StatusCreated)
}
|
|
|
|
// delete removes the named resource from the client's namespace. Only a 200
// response counts as success.
func (cli *client) delete(resource, name string) error {
	target := cli.urlFor(cli.apiVersion, cli.namespace, resource, name)
	req, err := http.NewRequest(http.MethodDelete, target, nil)
	if err != nil {
		return fmt.Errorf("create delete request: %v", err)
	}

	resp, err := cli.client.Do(req)
	if err != nil {
		return fmt.Errorf("delete request: %v", err)
	}
	defer closeResp(resp)

	return checkHTTPErr(resp, http.StatusOK)
}
|
|
|
|
func (cli *client) deleteAll(resource string) error {
|
|
var list struct {
|
|
k8sapi.TypeMeta `json:",inline"`
|
|
k8sapi.ListMeta `json:"metadata,omitempty"`
|
|
Items []struct {
|
|
k8sapi.TypeMeta `json:",inline"`
|
|
k8sapi.ObjectMeta `json:"metadata,omitempty"`
|
|
} `json:"items"`
|
|
}
|
|
if err := cli.list(resource, &list); err != nil {
|
|
return err
|
|
}
|
|
for _, item := range list.Items {
|
|
if err := cli.delete(resource, item.Name); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// put replaces the named resource with the JSON encoding of v via an HTTP
// PUT in the client's API version and namespace. Only a 200 response counts
// as success; any other status is turned into an error by checkHTTPErr.
func (cli *client) put(resource, name string, v interface{}) error {
	body, err := json.Marshal(v)
	if err != nil {
		return fmt.Errorf("marshal object: %v", err)
	}

	url := cli.urlFor(cli.apiVersion, cli.namespace, resource, name)
	req, err := http.NewRequest(http.MethodPut, url, bytes.NewReader(body))
	if err != nil {
		// This is a PUT (full replacement), not a PATCH; say so in errors.
		return fmt.Errorf("create put request: %v", err)
	}

	// Declare the payload type, as postResource does for the same JSON body.
	req.Header.Set("Content-Type", "application/json")
	// NewRequest already derives ContentLength from the bytes.Reader; the
	// explicit header is redundant but kept for compatibility.
	req.Header.Set("Content-Length", strconv.Itoa(len(body)))

	resp, err := cli.client.Do(req)
	if err != nil {
		return fmt.Errorf("put request: %v", err)
	}
	defer closeResp(resp)

	return checkHTTPErr(resp, http.StatusOK)
}
|
|
|
|
// newClient builds a client for the API server described by cluster,
// authenticating as user, with namespace as its default namespace.
//
// TLS: when a CA bundle is configured (inline base64 data or a file path) it
// becomes the only set of trusted roots; a client certificate is used only
// when both the certificate and the key are present.
//
// Authentication: a bearer token and/or basic-auth credentials are injected
// into every request by wrapping the transport. If both are configured the
// bearer token wins, because its wrapper is innermost and sets the
// Authorization header last.
func newClient(cluster k8sapi.Cluster, user k8sapi.AuthInfo, namespace string, logger log.Logger) (*client, error) {
	tlsConfig := cryptopasta.DefaultTLSConfig()

	// data resolves a credential given either inline (base64 in b) or by path
	// (file). Inline data wins; when neither is set it returns (nil, nil).
	data := func(b string, file string) ([]byte, error) {
		if b != "" {
			return base64.StdEncoding.DecodeString(b)
		}
		if file == "" {
			return nil, nil
		}
		return ioutil.ReadFile(file)
	}

	caData, err := data(cluster.CertificateAuthorityData, cluster.CertificateAuthority)
	if err != nil {
		return nil, err
	}
	if caData != nil {
		tlsConfig.RootCAs = x509.NewCertPool()
		if !tlsConfig.RootCAs.AppendCertsFromPEM(caData) {
			// AppendCertsFromPEM reports failure only via its bool result and
			// err is nil here, so describe the problem instead of formatting err.
			return nil, errors.New("no certificate data found in cluster certificate authority")
		}
	}

	clientCert, err := data(user.ClientCertificateData, user.ClientCertificate)
	if err != nil {
		return nil, err
	}
	clientKey, err := data(user.ClientKeyData, user.ClientKey)
	if err != nil {
		return nil, err
	}
	if clientCert != nil && clientKey != nil {
		cert, err := tls.X509KeyPair(clientCert, clientKey)
		if err != nil {
			return nil, fmt.Errorf("failed to load client cert: %v", err)
		}
		tlsConfig.Certificates = []tls.Certificate{cert}
	}

	httpTransport := &http.Transport{
		Proxy: http.ProxyFromEnvironment,
		// DialContext replaces the deprecated Dial field.
		DialContext: (&net.Dialer{
			Timeout:   30 * time.Second,
			KeepAlive: 30 * time.Second,
		}).DialContext,
		TLSClientConfig:       tlsConfig,
		TLSHandshakeTimeout:   10 * time.Second,
		ExpectContinueTimeout: 1 * time.Second,
	}

	// Since we set a custom TLS client config we have to explicitly
	// enable HTTP/2.
	//
	// https://github.com/golang/go/blob/go1.7.4/src/net/http/transport.go#L200-L206
	if err := http2.ConfigureTransport(httpTransport); err != nil {
		return nil, err
	}

	var t http.RoundTripper = httpTransport

	if user.Token != "" {
		t = transport{
			updateReq: func(r *http.Request) {
				r.Header.Set("Authorization", "Bearer "+user.Token)
			},
			base: t,
		}
	}

	if user.Username != "" && user.Password != "" {
		t = transport{
			updateReq: func(r *http.Request) {
				r.SetBasicAuth(user.Username, user.Password)
			},
			base: t,
		}
	}

	apiVersion := "dex.coreos.com/v1"

	logger.Infof("kubernetes client apiVersion = %s", apiVersion)
	return &client{
		client: &http.Client{
			Transport: t,
			Timeout:   15 * time.Second,
		},
		baseURL:    cluster.Server,
		hash:       func() hash.Hash { return fnv.New64() },
		namespace:  namespace,
		apiVersion: apiVersion,
		logger:     logger,
	}, nil
}
|
|
|
|
type transport struct {
|
|
updateReq func(r *http.Request)
|
|
base http.RoundTripper
|
|
}
|
|
|
|
func (t transport) RoundTrip(r *http.Request) (*http.Response, error) {
|
|
// shallow copy of the struct
|
|
r2 := new(http.Request)
|
|
*r2 = *r
|
|
// deep copy of the Header
|
|
r2.Header = make(http.Header, len(r.Header))
|
|
for k, s := range r.Header {
|
|
r2.Header[k] = append([]string(nil), s...)
|
|
}
|
|
t.updateReq(r2)
|
|
return t.base.RoundTrip(r2)
|
|
}
|
|
|
|
// loadKubeConfig reads the kubeconfig file at kubeConfigPath and returns the
// cluster, user and namespace selected by its current context. When the
// context has no namespace, "default" is used. On error the namespace is
// left empty rather than being defaulted alongside the error.
func loadKubeConfig(kubeConfigPath string) (cluster k8sapi.Cluster, user k8sapi.AuthInfo, namespace string, err error) {
	data, err := ioutil.ReadFile(kubeConfigPath)
	if err != nil {
		err = fmt.Errorf("read %s: %v", kubeConfigPath, err)
		return
	}

	var c k8sapi.Config
	if err = yaml.Unmarshal(data, &c); err != nil {
		err = fmt.Errorf("unmarshal %s: %v", kubeConfigPath, err)
		return
	}

	cluster, user, namespace, err = currentContext(&c)
	if err != nil {
		return cluster, user, "", err
	}
	if namespace == "" {
		namespace = "default"
	}
	return cluster, user, namespace, nil
}
|
|
|
|
// namespaceFromServiceAccountJWT extracts the namespace claim from a
// Kubernetes service account token. The token is handled as a JWT: its second
// dot-separated segment is decoded as unpadded base64url and parsed as JSON.
// The signature is NOT verified; this only reads a claim.
func namespaceFromServiceAccountJWT(s string) (string, error) {
	segments := strings.Split(s, ".")
	if n := len(segments); n < 2 {
		// The token is a credential: never put it in an error or a log line.
		return "", fmt.Errorf("malformed service account token: expected 3 parts got %d", n)
	}

	payload, err := base64.RawURLEncoding.DecodeString(segments[1])
	if err != nil {
		return "", fmt.Errorf("malformed service account token: %v", err)
	}

	// Kubernetes records the owning namespace under this claim; see
	// https://github.com/kubernetes/kubernetes/blob/v1.4.3/pkg/serviceaccount/jwt.go#L42
	var claims struct {
		Namespace string `json:"kubernetes.io/serviceaccount/namespace"`
	}
	if err = json.Unmarshal(payload, &claims); err != nil {
		return "", fmt.Errorf("malformed service account token: %v", err)
	}

	if ns := claims.Namespace; ns != "" {
		return ns, nil
	}
	return "", errors.New(`jwt claim "kubernetes.io/serviceaccount/namespace" not found`)
}
|
|
|
|
// inClusterConfig builds connection settings for dex running inside a pod:
//   - the API server address comes from KUBERNETES_SERVICE_HOST and
//     KUBERNETES_SERVICE_PORT (both required),
//   - the CA bundle and bearer token come from the mounted service account,
//   - the namespace comes from KUBERNETES_POD_NAMESPACE or, when that is
//     unset, from the namespace claim of the service account token.
func inClusterConfig() (cluster k8sapi.Cluster, user k8sapi.AuthInfo, namespace string, err error) {
	host, port := os.Getenv("KUBERNETES_SERVICE_HOST"), os.Getenv("KUBERNETES_SERVICE_PORT")
	if len(host) == 0 || len(port) == 0 {
		err = fmt.Errorf("unable to load in-cluster configuration, KUBERNETES_SERVICE_HOST and KUBERNETES_SERVICE_PORT must be defined")
		return
	}

	// net.JoinHostPort adds square brackets only when the host contains a
	// colon (an IPv6 literal): "[fd00::1]:443" for IPv6, and the standard
	// unbracketed "10.96.0.1:443" for IPv4 addresses and DNS names.
	cluster = k8sapi.Cluster{
		Server:               "https://" + net.JoinHostPort(host, port),
		CertificateAuthority: "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt",
	}
	token, err := ioutil.ReadFile("/var/run/secrets/kubernetes.io/serviceaccount/token")
	if err != nil {
		return
	}
	// Trim surrounding whitespace: a trailing newline in the token file would
	// otherwise make the Authorization header value invalid.
	user = k8sapi.AuthInfo{Token: strings.TrimSpace(string(token))}

	if namespace = os.Getenv("KUBERNETES_POD_NAMESPACE"); namespace == "" {
		namespace, err = namespaceFromServiceAccountJWT(user.Token)
		if err != nil {
			err = fmt.Errorf("failed to inspect service account token: %v", err)
			return
		}
	}

	return
}
|
|
|
|
func currentContext(config *k8sapi.Config) (cluster k8sapi.Cluster, user k8sapi.AuthInfo, ns string, err error) {
|
|
if config.CurrentContext == "" {
|
|
if len(config.Contexts) == 1 {
|
|
config.CurrentContext = config.Contexts[0].Name
|
|
} else {
|
|
return cluster, user, "", errors.New("kubeconfig has no current context")
|
|
}
|
|
}
|
|
context, ok := func() (k8sapi.Context, bool) {
|
|
for _, namedContext := range config.Contexts {
|
|
if namedContext.Name == config.CurrentContext {
|
|
return namedContext.Context, true
|
|
}
|
|
}
|
|
return k8sapi.Context{}, false
|
|
}()
|
|
if !ok {
|
|
return cluster, user, "", fmt.Errorf("no context named %q found", config.CurrentContext)
|
|
}
|
|
|
|
cluster, ok = func() (k8sapi.Cluster, bool) {
|
|
for _, namedCluster := range config.Clusters {
|
|
if namedCluster.Name == context.Cluster {
|
|
return namedCluster.Cluster, true
|
|
}
|
|
}
|
|
return k8sapi.Cluster{}, false
|
|
}()
|
|
if !ok {
|
|
return cluster, user, "", fmt.Errorf("no cluster named %q found", context.Cluster)
|
|
}
|
|
|
|
user, ok = func() (k8sapi.AuthInfo, bool) {
|
|
for _, namedAuthInfo := range config.AuthInfos {
|
|
if namedAuthInfo.Name == context.AuthInfo {
|
|
return namedAuthInfo.AuthInfo, true
|
|
}
|
|
}
|
|
return k8sapi.AuthInfo{}, false
|
|
}()
|
|
if !ok {
|
|
return cluster, user, "", fmt.Errorf("no user named %q found", context.AuthInfo)
|
|
}
|
|
return cluster, user, context.Namespace, nil
|
|
}
|