From 9284ffb8c0d96eef422378b72479a7083e409ba6 Mon Sep 17 00:00:00 2001 From: Joshua Winters Date: Fri, 4 May 2018 12:43:09 -0400 Subject: [PATCH] Add generic oauth connector Co-authored-by: Shash Reddy Signed-off-by: Joshua Winters --- connector/oauth/oauth.go | 242 ++++++++++++++++++++++++++++++++++ connector/oauth/oauth_test.go | 234 ++++++++++++++++++++++++++++++++ server/server.go | 2 + 3 files changed, 478 insertions(+) create mode 100644 connector/oauth/oauth.go create mode 100644 connector/oauth/oauth_test.go diff --git a/connector/oauth/oauth.go b/connector/oauth/oauth.go new file mode 100644 index 00000000..7bf480cd --- /dev/null +++ b/connector/oauth/oauth.go @@ -0,0 +1,242 @@ +package oauth + +import ( + "context" + "crypto/tls" + "crypto/x509" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net" + "net/http" + "strings" + "time" + + "github.com/dexidp/dex/connector" + "github.com/dexidp/dex/pkg/log" + "golang.org/x/oauth2" +) + +type oauthConnector struct { + clientID string + clientSecret string + redirectURI string + tokenURL string + authorizationURL string + userInfoURL string + scopes []string + groupsKey string + httpClient *http.Client + logger log.Logger +} + +type connectorData struct { + AccessToken string +} + +type Config struct { + ClientID string `json:"clientID"` + ClientSecret string `json:"clientSecret"` + RedirectURI string `json:"redirectURI"` + TokenURL string `json:"tokenURL"` + AuthorizationURL string `json:"authorizationURL"` + UserInfoURL string `json:"userInfoURL"` + Scopes []string `json:"scopes"` + GroupsKey string `json:"groupsKey"` + RootCAs []string `json:"rootCAs"` + InsecureSkipVerify bool `json:"insecureSkipVerify"` +} + +func (c *Config) Open(id string, logger log.Logger) (connector.Connector, error) { + var err error + + oauthConn := &oauthConnector{ + clientID: c.ClientID, + clientSecret: c.ClientSecret, + tokenURL: c.TokenURL, + authorizationURL: c.AuthorizationURL, + userInfoURL: c.UserInfoURL, + scopes: c.Scopes, + groupsKey: c.GroupsKey, + redirectURI: c.RedirectURI, + logger: logger, + } + + oauthConn.httpClient, err = newHTTPClient(c.RootCAs, c.InsecureSkipVerify) + if err != nil { + return nil, err + } + + return oauthConn, err +} + +func newHTTPClient(rootCAs []string, insecureSkipVerify bool) (*http.Client, error) { + pool, err := x509.SystemCertPool() + if err != nil { + return nil, err + } + + tlsConfig := tls.Config{RootCAs: pool, InsecureSkipVerify: insecureSkipVerify} + for _, rootCA := range rootCAs { + rootCABytes, err := ioutil.ReadFile(rootCA) + if err != nil { + return nil, fmt.Errorf("failed to read root-ca: %v", err) + } + if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCABytes) { + return nil, fmt.Errorf("no certs found in root CA file %q", rootCA) + } + } + + return &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tlsConfig, + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + }, + }, nil +} + +func (c *oauthConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { + + if c.redirectURI != callbackURL { + return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) + } + + oauth2Config := &oauth2.Config{ + ClientID: c.clientID, + ClientSecret: c.clientSecret, + Endpoint: oauth2.Endpoint{TokenURL: c.tokenURL, AuthURL: c.authorizationURL}, + RedirectURL: c.redirectURI, + Scopes: c.scopes, + } + + return oauth2Config.AuthCodeURL(state), nil +} + +func (c *oauthConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { + + q := r.URL.Query() + if errType := q.Get("error"); errType != "" { + return identity, errors.New(q.Get("error_description")) + } + + oauth2Config := &oauth2.Config{ + ClientID: c.clientID, + ClientSecret: c.clientSecret, + Endpoint: oauth2.Endpoint{TokenURL: c.tokenURL, AuthURL: c.authorizationURL}, + RedirectURL: c.redirectURI, + Scopes: c.scopes, + } + + ctx := context.WithValue(r.Context(), oauth2.HTTPClient, c.httpClient) + + token, err := oauth2Config.Exchange(ctx, q.Get("code")) + if err != nil { + return identity, fmt.Errorf("OAuth connector: failed to get token: %v", err) + } + + client := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token)) + + userInfoResp, err := client.Get(c.userInfoURL) + if err != nil { + return identity, fmt.Errorf("OAuth Connector: failed to execute request to userinfo: %v", err) + } + + if userInfoResp.StatusCode != http.StatusOK { + return identity, fmt.Errorf("OAuth Connector: failed to execute request to userinfo: status %d", userInfoResp.StatusCode) + } + + defer userInfoResp.Body.Close() + + var userInfoResult map[string]interface{} + err = json.NewDecoder(userInfoResp.Body).Decode(&userInfoResult) + + if err != nil { + return identity, fmt.Errorf("OAuth Connector: failed to parse userinfo: %v", err) + } + + identity.UserID, _ = userInfoResult["user_id"].(string) + identity.Name, _ = userInfoResult["name"].(string) + identity.Username, _ = userInfoResult["user_name"].(string) + identity.Email, _ = userInfoResult["email"].(string) + identity.EmailVerified, _ = userInfoResult["email_verified"].(bool) + + if s.Groups { + if c.groupsKey == "" { + c.groupsKey = "groups" + } + + groups := map[string]bool{} + + c.addGroupsFromMap(groups, userInfoResult) + c.addGroupsFromToken(groups, token.AccessToken) + + for groupName, _ := range groups { + identity.Groups = append(identity.Groups, groupName) + } + } + + if s.OfflineAccess { + data := connectorData{AccessToken: token.AccessToken} + connData, err := json.Marshal(data) + if err != nil { + return identity, fmt.Errorf("OAuth Connector: failed to parse connector data for offline access: %v", err) + } + identity.ConnectorData = connData + } + + return identity, nil +} + +func (c *oauthConnector) addGroupsFromMap(groups map[string]bool, result map[string]interface{}) error { + groupsClaim, ok := result[c.groupsKey].([]interface{}) + if !ok { + return errors.New("Cant convert to array") + } + + for _, group := range groupsClaim { + if groupString, ok := group.(string); ok { + groups[groupString] = true + } + } + + return nil +} + +func (c *oauthConnector) addGroupsFromToken(groups map[string]bool, token string) error { + parts := strings.Split(token, ".") + if len(parts) < 2 { + return errors.New("Invalid token") + } + + decoded, err := decode(parts[1]) + if err != nil { + return err + } + + var claimsMap map[string]interface{} + err = json.Unmarshal(decoded, &claimsMap) + if err != nil { + return err + } + + return c.addGroupsFromMap(groups, claimsMap) +} + +func decode(seg string) ([]byte, error) { + if l := len(seg) % 4; l > 0 { + seg += strings.Repeat("=", 4-l) + } + + return base64.URLEncoding.DecodeString(seg) +} diff --git a/connector/oauth/oauth_test.go b/connector/oauth/oauth_test.go new file mode 100644 index 00000000..2a43d72d --- /dev/null +++ b/connector/oauth/oauth_test.go @@ -0,0 +1,234 @@ +package oauth + +import ( + "crypto/rand" + "crypto/rsa" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "reflect" + "sort" + "testing" + + "github.com/dexidp/dex/connector" + "github.com/sirupsen/logrus" + jose "gopkg.in/square/go-jose.v2" +) + +func TestOpen(t *testing.T) { + tokenClaims := map[string]interface{}{} + userInfoClaims := map[string]interface{}{} + + testServer := testSetup(t, tokenClaims, userInfoClaims) + defer testServer.Close() + + conn := newConnector(t, testServer.URL) + + sort.Strings(conn.scopes) + + expectEqual(t, conn.clientID, "testClient") + expectEqual(t, conn.clientSecret, "testSecret") + expectEqual(t, conn.redirectURI, testServer.URL+"/callback") + expectEqual(t, conn.tokenURL, testServer.URL+"/token") + expectEqual(t, conn.authorizationURL, testServer.URL+"/authorize") + expectEqual(t, conn.userInfoURL, testServer.URL+"/userinfo") + expectEqual(t, len(conn.scopes), 2) + expectEqual(t, conn.scopes[0], "groups") + expectEqual(t, conn.scopes[1], "openid") +} + +func TestLoginURL(t *testing.T) { + tokenClaims := map[string]interface{}{} + userInfoClaims := map[string]interface{}{} + + testServer := testSetup(t, tokenClaims, userInfoClaims) + defer testServer.Close() + + conn := newConnector(t, testServer.URL) + + loginURL, err := conn.LoginURL(connector.Scopes{}, conn.redirectURI, "some-state") + expectEqual(t, err, nil) + + expectedURL, err := url.Parse(testServer.URL + "/authorize") + expectEqual(t, err, nil) + + values := url.Values{} + values.Add("client_id", "testClient") + values.Add("redirect_uri", conn.redirectURI) + values.Add("response_type", "code") + values.Add("scope", "openid groups") + values.Add("state", "some-state") + expectedURL.RawQuery = values.Encode() + + expectEqual(t, loginURL, expectedURL.String()) +} + +func TestHandleCallBackForGroupsInUserInfo(t *testing.T) { + + tokenClaims := map[string]interface{}{} + + userInfoClaims := map[string]interface{}{ + "name": "test-name", + "user_name": "test-username", + "user_id": "test-user-id", + "email": "test-email", + "email_verified": true, + "groups_key": []string{"admin-group", "user-group"}, + } + + testServer := testSetup(t, tokenClaims, userInfoClaims) + defer testServer.Close() + + conn := newConnector(t, testServer.URL) + req := newRequestWithAuthCode(t, testServer.URL, "some-code") + + identity, err := conn.HandleCallback(connector.Scopes{Groups: true}, req) + expectEqual(t, err, nil) + + sort.Strings(identity.Groups) + expectEqual(t, len(identity.Groups), 2) + expectEqual(t, identity.Groups[0], "admin-group") + expectEqual(t, identity.Groups[1], "user-group") + expectEqual(t, identity.Name, "test-name") + expectEqual(t, identity.Username, "test-username") + expectEqual(t, identity.Email, "test-email") + expectEqual(t, identity.EmailVerified, true) +} + +func TestHandleCallBackForGroupsInToken(t *testing.T) { + + tokenClaims := map[string]interface{}{ + "groups_key": []string{"test-group"}, + } + + userInfoClaims := map[string]interface{}{ + "name": "test-name", + "user_name": "test-username", + "user_id": "test-user-id", + "email": "test-email", + "email_verified": true, + } + + testServer := testSetup(t, tokenClaims, userInfoClaims) + defer testServer.Close() + + conn := newConnector(t, testServer.URL) + req := newRequestWithAuthCode(t, testServer.URL, "some-code") + + identity, err := conn.HandleCallback(connector.Scopes{Groups: true}, req) + expectEqual(t, err, nil) + + expectEqual(t, len(identity.Groups), 1) + expectEqual(t, identity.Groups[0], "test-group") + expectEqual(t, identity.Name, "test-name") + expectEqual(t, identity.Username, "test-username") + expectEqual(t, identity.Email, "test-email") + expectEqual(t, identity.EmailVerified, true) +} + +func testSetup(t *testing.T, tokenClaims map[string]interface{}, userInfoClaims map[string]interface{}) *httptest.Server { + + key, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + t.Fatal("Failed to generate rsa key", err) + } + + jwk := jose.JSONWebKey{ + Key: key, + KeyID: "some-key", + Algorithm: "RSA", + } + + mux := http.NewServeMux() + + mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { + token, err := newToken(&jwk, tokenClaims) + if err != nil { + t.Fatal("unable to generate token", err) + } + + w.Header().Add("Content-Type", "application/json") + json.NewEncoder(w).Encode(&map[string]string{ + "access_token": token, + "id_token": token, + "token_type": "Bearer", + }) + }) + + mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Content-Type", "application/json") + json.NewEncoder(w).Encode(userInfoClaims) + }) + + return httptest.NewServer(mux) +} + +func newToken(key *jose.JSONWebKey, claims map[string]interface{}) (string, error) { + signingKey := jose.SigningKey{Key: key, Algorithm: jose.RS256} + + signer, err := jose.NewSigner(signingKey, &jose.SignerOptions{}) + if err != nil { + return "", fmt.Errorf("new signer: %v", err) + } + + payload, err := json.Marshal(claims) + if err != nil { + return "", fmt.Errorf("marshaling claims: %v", err) + } + + signature, err := signer.Sign(payload) + if err != nil { + return "", fmt.Errorf("signing payload: %v", err) + } + + return signature.CompactSerialize() +} + +func newConnector(t *testing.T, serverURL string) *oauthConnector { + testConfig := Config{ + ClientID: "testClient", + ClientSecret: "testSecret", + RedirectURI: serverURL + "/callback", + TokenURL: serverURL + "/token", + AuthorizationURL: serverURL + "/authorize", + UserInfoURL: serverURL + "/userinfo", + Scopes: []string{"openid", "groups"}, + GroupsKey: "groups_key", + } + + log := logrus.New() + + conn, err := testConfig.Open("id", log) + if err != nil { + t.Fatal(err) + } + + oauthConn, ok := conn.(*oauthConnector) + if !ok { + t.Fatal(errors.New("failed to convert to oauthConnector")) + } + + return oauthConn +} + +func newRequestWithAuthCode(t *testing.T, serverURL string, code string) *http.Request { + req, err := http.NewRequest("GET", serverURL, nil) + if err != nil { + t.Fatal("failed to create request", err) + } + + values := req.URL.Query() + values.Add("code", code) + req.URL.RawQuery = values.Encode() + + return req +} + +func expectEqual(t *testing.T, a interface{}, b interface{}) { + if !reflect.DeepEqual(a, b) { + t.Fatalf("Expected %+v to equal %+v", a, b) + } +} diff --git a/server/server.go b/server/server.go index ecd6c935..6b653fdb 100755 --- a/server/server.go +++ b/server/server.go @@ -38,6 +38,7 @@ import ( "github.com/dexidp/dex/connector/linkedin" "github.com/dexidp/dex/connector/microsoft" "github.com/dexidp/dex/connector/mock" + "github.com/dexidp/dex/connector/oauth" "github.com/dexidp/dex/connector/oidc" "github.com/dexidp/dex/connector/openshift" "github.com/dexidp/dex/connector/saml" @@ -538,6 +539,7 @@ var ConnectorsConfig = map[string]func() ConnectorConfig{ "gitlab": func() ConnectorConfig { return new(gitlab.Config) }, "google": func() ConnectorConfig { return new(google.Config) }, "oidc": func() ConnectorConfig { return new(oidc.Config) }, + "oauth": func() ConnectorConfig { return new(oauth.Config) }, "saml": func() ConnectorConfig { return new(saml.Config) }, "authproxy": func() ConnectorConfig { return new(authproxy.Config) }, "linkedin": func() ConnectorConfig { return new(linkedin.Config) },