diff --git a/Dockerfile b/Dockerfile index 7a57db3e..7e980b2b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -29,8 +29,8 @@ ARG TARGETVARIANT ENV GOMPLATE_VERSION=v3.9.0 RUN wget -O /usr/local/bin/gomplate \ - "https://github.com/hairyhenderson/gomplate/releases/download/${GOMPLATE_VERSION}/gomplate_${TARGETOS:-linux}-${TARGETARCH:-amd64}${TARGETVARIANT}" \ - && chmod +x /usr/local/bin/gomplate + "https://github.com/hairyhenderson/gomplate/releases/download/${GOMPLATE_VERSION}/gomplate_${TARGETOS:-linux}-${TARGETARCH:-amd64}${TARGETVARIANT}" \ + && chmod +x /usr/local/bin/gomplate FROM alpine:3.15.0 diff --git a/README.md b/README.md index 6ae2a7b0..b69bdc0a 100644 --- a/README.md +++ b/README.md @@ -71,6 +71,7 @@ Dex implements the following connectors: | [SAML 2.0](https://dexidp.io/docs/connectors/saml/) | no | yes | no | stable | WARNING: Unmaintained and likely vulnerable to auth bypasses ([#1884](https://github.com/dexidp/dex/discussions/1884)) | | [GitLab](https://dexidp.io/docs/connectors/gitlab/) | yes | yes | yes | beta | | | [OpenID Connect](https://dexidp.io/docs/connectors/oidc/) | yes | yes | yes | beta | Includes Salesforce, Azure, etc. | +| [OAuth 2.0](https://dexidp.io/docs/connectors/oauth/) | no | yes | yes | alpha | | | [Google](https://dexidp.io/docs/connectors/google/) | yes | yes | yes | alpha | | | [LinkedIn](https://dexidp.io/docs/connectors/linkedin/) | yes | no | no | beta | | | [Microsoft](https://dexidp.io/docs/connectors/microsoft/) | yes | yes | no | beta | | diff --git a/connector/oauth/oauth.go b/connector/oauth/oauth.go new file mode 100644 index 00000000..e37932ad --- /dev/null +++ b/connector/oauth/oauth.go @@ -0,0 +1,292 @@ +package oauth + +import ( + "context" + "crypto/tls" + "crypto/x509" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "net" + "net/http" + "os" + "strings" + "time" + + "golang.org/x/oauth2" + + "github.com/dexidp/dex/connector" + "github.com/dexidp/dex/pkg/log" +) + +type oauthConnector struct { + clientID string + clientSecret string + redirectURI string + tokenURL string + authorizationURL string + userInfoURL string + scopes []string + userIDKey string + userNameKey string + preferredUsernameKey string + emailKey string + emailVerifiedKey string + groupsKey string + httpClient *http.Client + logger log.Logger +} + +type connectorData struct { + AccessToken string +} + +type Config struct { + ClientID string `json:"clientID"` + ClientSecret string `json:"clientSecret"` + RedirectURI string `json:"redirectURI"` + TokenURL string `json:"tokenURL"` + AuthorizationURL string `json:"authorizationURL"` + UserInfoURL string `json:"userInfoURL"` + Scopes []string `json:"scopes"` + RootCAs []string `json:"rootCAs"` + InsecureSkipVerify bool `json:"insecureSkipVerify"` + UserIDKey string `json:"userIDKey"` // defaults to "id" + ClaimMapping struct { + UserNameKey string `json:"userNameKey"` // defaults to "user_name" + PreferredUsernameKey string `json:"preferredUsernameKey"` // defaults to "preferred_username" + GroupsKey string `json:"groupsKey"` // defaults to "groups" + EmailKey string `json:"emailKey"` // defaults to "email" + EmailVerifiedKey string `json:"emailVerifiedKey"` // defaults to "email_verified" + } `json:"claimMapping"` +} + +func (c *Config) Open(id string, logger log.Logger) (connector.Connector, error) { + var err error + + userIDKey := c.UserIDKey + if userIDKey == "" { + userIDKey = "id" + } + + userNameKey := c.ClaimMapping.UserNameKey + if userNameKey == "" { + userNameKey = "user_name" + } + + preferredUsernameKey := c.ClaimMapping.PreferredUsernameKey + if preferredUsernameKey == "" { + preferredUsernameKey = "preferred_username" + } + + groupsKey := c.ClaimMapping.GroupsKey + if groupsKey == "" { + groupsKey = "groups" + } + + emailKey := c.ClaimMapping.EmailKey + if emailKey == "" { + emailKey = "email" + } + + emailVerifiedKey := c.ClaimMapping.EmailVerifiedKey + if emailVerifiedKey == "" { + emailVerifiedKey = "email_verified" + } + + oauthConn := &oauthConnector{ + clientID: c.ClientID, + clientSecret: c.ClientSecret, + tokenURL: c.TokenURL, + authorizationURL: c.AuthorizationURL, + userInfoURL: c.UserInfoURL, + scopes: c.Scopes, + redirectURI: c.RedirectURI, + logger: logger, + userIDKey: userIDKey, + userNameKey: userNameKey, + preferredUsernameKey: preferredUsernameKey, + groupsKey: groupsKey, + emailKey: emailKey, + emailVerifiedKey: emailVerifiedKey, + } + + oauthConn.httpClient, err = newHTTPClient(c.RootCAs, c.InsecureSkipVerify) + if err != nil { + return nil, err + } + + return oauthConn, err +} + +func newHTTPClient(rootCAs []string, insecureSkipVerify bool) (*http.Client, error) { + pool, err := x509.SystemCertPool() + if err != nil { + return nil, err + } + + tlsConfig := tls.Config{RootCAs: pool, InsecureSkipVerify: insecureSkipVerify} + for _, rootCA := range rootCAs { + rootCABytes, err := os.ReadFile(rootCA) + if err != nil { + return nil, fmt.Errorf("failed to read root-ca: %v", err) + } + if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCABytes) { + return nil, fmt.Errorf("no certs found in root CA file %q", rootCA) + } + } + + return &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tlsConfig, + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + }, + }, nil +} + +func (c *oauthConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { + if c.redirectURI != callbackURL { + return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) + } + + oauth2Config := &oauth2.Config{ + ClientID: c.clientID, + ClientSecret: c.clientSecret, + Endpoint: oauth2.Endpoint{TokenURL: c.tokenURL, AuthURL: c.authorizationURL}, + RedirectURL: c.redirectURI, + Scopes: c.scopes, + } + + return oauth2Config.AuthCodeURL(state), nil +} + +func (c *oauthConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { + q := r.URL.Query() + if errType := q.Get("error"); errType != "" { + return identity, errors.New(q.Get("error_description")) + } + + oauth2Config := &oauth2.Config{ + ClientID: c.clientID, + ClientSecret: c.clientSecret, + Endpoint: oauth2.Endpoint{TokenURL: c.tokenURL, AuthURL: c.authorizationURL}, + RedirectURL: c.redirectURI, + Scopes: c.scopes, + } + + ctx := context.WithValue(r.Context(), oauth2.HTTPClient, c.httpClient) + + token, err := oauth2Config.Exchange(ctx, q.Get("code")) + if err != nil { + return identity, fmt.Errorf("OAuth connector: failed to get token: %v", err) + } + + client := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token)) + + userInfoResp, err := client.Get(c.userInfoURL) + if err != nil { + return identity, fmt.Errorf("OAuth Connector: failed to execute request to userinfo: %v", err) + } + defer userInfoResp.Body.Close() + + if userInfoResp.StatusCode != http.StatusOK { + return identity, fmt.Errorf("OAuth Connector: failed to execute request to userinfo: status %d", userInfoResp.StatusCode) + } + + var userInfoResult map[string]interface{} + err = json.NewDecoder(userInfoResp.Body).Decode(&userInfoResult) + if err != nil { + return identity, fmt.Errorf("OAuth Connector: failed to parse userinfo: %v", err) + } + + userID, found := userInfoResult[c.userIDKey].(string) + if !found { + return identity, fmt.Errorf("OAuth Connector: not found %v claim", c.userIDKey) + } + + identity.UserID = userID + identity.Username, _ = userInfoResult[c.userNameKey].(string) + identity.PreferredUsername, _ = userInfoResult[c.preferredUsernameKey].(string) + identity.Email, _ = userInfoResult[c.emailKey].(string) + identity.EmailVerified, _ = userInfoResult[c.emailVerifiedKey].(bool) + + if s.Groups { + groups := map[string]struct{}{} + + c.addGroupsFromMap(groups, userInfoResult) + c.addGroupsFromToken(groups, token.AccessToken) + + for groupName := range groups { + identity.Groups = append(identity.Groups, groupName) + } + } + + if s.OfflineAccess { + data := connectorData{AccessToken: token.AccessToken} + connData, err := json.Marshal(data) + if err != nil { + return identity, fmt.Errorf("OAuth Connector: failed to parse connector data for offline access: %v", err) + } + identity.ConnectorData = connData + } + + return identity, nil +} + +func (c *oauthConnector) addGroupsFromMap(groups map[string]struct{}, result map[string]interface{}) error { + groupsClaim, ok := result[c.groupsKey].([]interface{}) + if !ok { + return errors.New("cannot convert to slice") + } + + for _, group := range groupsClaim { + if groupString, ok := group.(string); ok { + groups[groupString] = struct{}{} + } + if groupMap, ok := group.(map[string]interface{}); ok { + if groupName, ok := groupMap["name"].(string); ok { + groups[groupName] = struct{}{} + } + } + } + + return nil +} + +func (c *oauthConnector) addGroupsFromToken(groups map[string]struct{}, token string) error { + parts := strings.Split(token, ".") + if len(parts) < 2 { + return errors.New("invalid token") + } + + decoded, err := decode(parts[1]) + if err != nil { + return err + } + + var claimsMap map[string]interface{} + err = json.Unmarshal(decoded, &claimsMap) + if err != nil { + return err + } + + return c.addGroupsFromMap(groups, claimsMap) +} + +func decode(seg string) ([]byte, error) { + if l := len(seg) % 4; l > 0 { + seg += strings.Repeat("=", 4-l) + } + + return base64.URLEncoding.DecodeString(seg) +} diff --git a/connector/oauth/oauth_test.go b/connector/oauth/oauth_test.go new file mode 100644 index 00000000..082a3aa5 --- /dev/null +++ b/connector/oauth/oauth_test.go @@ -0,0 +1,271 @@ +package oauth + +import ( + "crypto/rand" + "crypto/rsa" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "sort" + "testing" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + jose "gopkg.in/square/go-jose.v2" + + "github.com/dexidp/dex/connector" +) + +func TestOpen(t *testing.T) { + tokenClaims := map[string]interface{}{} + userInfoClaims := map[string]interface{}{} + + testServer := testSetup(t, tokenClaims, userInfoClaims) + defer testServer.Close() + + conn := newConnector(t, testServer.URL) + + sort.Strings(conn.scopes) + + assert.Equal(t, conn.clientID, "testClient") + assert.Equal(t, conn.clientSecret, "testSecret") + assert.Equal(t, conn.redirectURI, testServer.URL+"/callback") + assert.Equal(t, conn.tokenURL, testServer.URL+"/token") + assert.Equal(t, conn.authorizationURL, testServer.URL+"/authorize") + assert.Equal(t, conn.userInfoURL, testServer.URL+"/userinfo") + assert.Equal(t, len(conn.scopes), 2) + assert.Equal(t, conn.scopes[0], "groups") + assert.Equal(t, conn.scopes[1], "openid") +} + +func TestLoginURL(t *testing.T) { + tokenClaims := map[string]interface{}{} + userInfoClaims := map[string]interface{}{} + + testServer := testSetup(t, tokenClaims, userInfoClaims) + defer testServer.Close() + + conn := newConnector(t, testServer.URL) + + loginURL, err := conn.LoginURL(connector.Scopes{}, conn.redirectURI, "some-state") + assert.Equal(t, err, nil) + + expectedURL, err := url.Parse(testServer.URL + "/authorize") + assert.Equal(t, err, nil) + + values := url.Values{} + values.Add("client_id", "testClient") + values.Add("redirect_uri", conn.redirectURI) + values.Add("response_type", "code") + values.Add("scope", "openid groups") + values.Add("state", "some-state") + expectedURL.RawQuery = values.Encode() + + assert.Equal(t, loginURL, expectedURL.String()) +} + +func TestHandleCallBackForGroupsInUserInfo(t *testing.T) { + tokenClaims := map[string]interface{}{} + + userInfoClaims := map[string]interface{}{ + "name": "test-name", + "user_id_key": "test-user-id", + "user_name_key": "test-username", + "preferred_username": "test-preferred-username", + "mail": "mod_mail", + "has_verified_email": false, + "groups_key": []string{"admin-group", "user-group"}, + } + + testServer := testSetup(t, tokenClaims, userInfoClaims) + defer testServer.Close() + + conn := newConnector(t, testServer.URL) + req := newRequestWithAuthCode(t, testServer.URL, "some-code") + + identity, err := conn.HandleCallback(connector.Scopes{Groups: true}, req) + assert.Equal(t, err, nil) + + sort.Strings(identity.Groups) + assert.Equal(t, len(identity.Groups), 2) + assert.Equal(t, identity.Groups[0], "admin-group") + assert.Equal(t, identity.Groups[1], "user-group") + assert.Equal(t, identity.UserID, "test-user-id") + assert.Equal(t, identity.Username, "test-username") + assert.Equal(t, identity.PreferredUsername, "test-preferred-username") + assert.Equal(t, identity.Email, "mod_mail") + assert.Equal(t, identity.EmailVerified, false) +} + +func TestHandleCallBackForGroupMapsInUserInfo(t *testing.T) { + tokenClaims := map[string]interface{}{} + + userInfoClaims := map[string]interface{}{ + "name": "test-name", + "user_id_key": "test-user-id", + "user_name_key": "test-username", + "preferred_username": "test-preferred-username", + "mail": "mod_mail", + "has_verified_email": false, + "groups_key": []interface{}{ + map[string]string{"name": "admin-group", "id": "111"}, + map[string]string{"name": "user-group", "id": "222"}, + }, + } + + testServer := testSetup(t, tokenClaims, userInfoClaims) + defer testServer.Close() + + conn := newConnector(t, testServer.URL) + req := newRequestWithAuthCode(t, testServer.URL, "some-code") + + identity, err := conn.HandleCallback(connector.Scopes{Groups: true}, req) + assert.Equal(t, err, nil) + + sort.Strings(identity.Groups) + assert.Equal(t, len(identity.Groups), 2) + assert.Equal(t, identity.Groups[0], "admin-group") + assert.Equal(t, identity.Groups[1], "user-group") + assert.Equal(t, identity.UserID, "test-user-id") + assert.Equal(t, identity.Username, "test-username") + assert.Equal(t, identity.PreferredUsername, "test-preferred-username") + assert.Equal(t, identity.Email, "mod_mail") + assert.Equal(t, identity.EmailVerified, false) +} + +func TestHandleCallBackForGroupsInToken(t *testing.T) { + tokenClaims := map[string]interface{}{ + "groups_key": []string{"test-group"}, + } + + userInfoClaims := map[string]interface{}{ + "name": "test-name", + "user_id_key": "test-user-id", + "user_name_key": "test-username", + "preferred_username": "test-preferred-username", + "email": "test-email", + "email_verified": true, + } + + testServer := testSetup(t, tokenClaims, userInfoClaims) + defer testServer.Close() + + conn := newConnector(t, testServer.URL) + req := newRequestWithAuthCode(t, testServer.URL, "some-code") + + identity, err := conn.HandleCallback(connector.Scopes{Groups: true}, req) + assert.Equal(t, err, nil) + + assert.Equal(t, len(identity.Groups), 1) + assert.Equal(t, identity.Groups[0], "test-group") + assert.Equal(t, identity.PreferredUsername, "test-preferred-username") + assert.Equal(t, identity.UserID, "test-user-id") + assert.Equal(t, identity.Username, "test-username") + assert.Equal(t, identity.Email, "") + assert.Equal(t, identity.EmailVerified, false) +} + +func testSetup(t *testing.T, tokenClaims map[string]interface{}, userInfoClaims map[string]interface{}) *httptest.Server { + key, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + t.Fatal("Failed to generate rsa key", err) + } + + jwk := jose.JSONWebKey{ + Key: key, + KeyID: "some-key", + Algorithm: "RSA", + } + + mux := http.NewServeMux() + + mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { + token, err := newToken(&jwk, tokenClaims) + if err != nil { + t.Fatal("unable to generate token", err) + } + + w.Header().Add("Content-Type", "application/json") + json.NewEncoder(w).Encode(&map[string]string{ + "access_token": token, + "id_token": token, + "token_type": "Bearer", + }) + }) + + mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Content-Type", "application/json") + json.NewEncoder(w).Encode(userInfoClaims) + }) + + return httptest.NewServer(mux) +} + +func newToken(key *jose.JSONWebKey, claims map[string]interface{}) (string, error) { + signingKey := jose.SigningKey{Key: key, Algorithm: jose.RS256} + + signer, err := jose.NewSigner(signingKey, &jose.SignerOptions{}) + if err != nil { + return "", fmt.Errorf("new signer: %v", err) + } + + payload, err := json.Marshal(claims) + if err != nil { + return "", fmt.Errorf("marshaling claims: %v", err) + } + + signature, err := signer.Sign(payload) + if err != nil { + return "", fmt.Errorf("signing payload: %v", err) + } + + return signature.CompactSerialize() +} + +func newConnector(t *testing.T, serverURL string) *oauthConnector { + testConfig := Config{ + ClientID: "testClient", + ClientSecret: "testSecret", + RedirectURI: serverURL + "/callback", + TokenURL: serverURL + "/token", + AuthorizationURL: serverURL + "/authorize", + UserInfoURL: serverURL + "/userinfo", + Scopes: []string{"openid", "groups"}, + UserIDKey: "user_id_key", + } + + testConfig.ClaimMapping.UserNameKey = "user_name_key" + testConfig.ClaimMapping.GroupsKey = "groups_key" + testConfig.ClaimMapping.EmailKey = "mail" + testConfig.ClaimMapping.EmailVerifiedKey = "has_verified_email" + + log := logrus.New() + + conn, err := testConfig.Open("id", log) + if err != nil { + t.Fatal(err) + } + + oauthConn, ok := conn.(*oauthConnector) + if !ok { + t.Fatal(errors.New("failed to convert to oauthConnector")) + } + + return oauthConn +} + +func newRequestWithAuthCode(t *testing.T, serverURL string, code string) *http.Request { + req, err := http.NewRequest("GET", serverURL, nil) + if err != nil { + t.Fatal("failed to create request", err) + } + + values := req.URL.Query() + values.Add("code", code) + req.URL.RawQuery = values.Encode() + + return req +} diff --git a/server/server.go b/server/server.go index ecd6c935..6b653fdb 100755 --- a/server/server.go +++ b/server/server.go @@ -38,6 +38,7 @@ import ( "github.com/dexidp/dex/connector/linkedin" "github.com/dexidp/dex/connector/microsoft" "github.com/dexidp/dex/connector/mock" + "github.com/dexidp/dex/connector/oauth" "github.com/dexidp/dex/connector/oidc" "github.com/dexidp/dex/connector/openshift" "github.com/dexidp/dex/connector/saml" @@ -538,6 +539,7 @@ var ConnectorsConfig = map[string]func() ConnectorConfig{ "gitlab": func() ConnectorConfig { return new(gitlab.Config) }, "google": func() ConnectorConfig { return new(google.Config) }, "oidc": func() ConnectorConfig { return new(oidc.Config) }, + "oauth": func() ConnectorConfig { return new(oauth.Config) }, "saml": func() ConnectorConfig { return new(saml.Config) }, "authproxy": func() ConnectorConfig { return new(authproxy.Config) }, "linkedin": func() ConnectorConfig { return new(linkedin.Config) },