diff --git a/connector/gitlab/gitlab.go b/connector/gitlab/gitlab.go index 7d8e8337..f35ac357 100644 --- a/connector/gitlab/gitlab.go +++ b/connector/gitlab/gitlab.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "strconv" + "time" "golang.org/x/oauth2" @@ -61,8 +62,9 @@ func (c *Config) Open(id string, logger log.Logger) (connector.Connector, error) } type connectorData struct { - // GitLab's OAuth2 tokens never expire. We don't need a refresh token. - AccessToken string `json:"accessToken"` + // Support GitLab's Access Tokens and Refresh tokens. + AccessToken string `json:"accessToken"` + RefreshToken string `json:"refreshToken"` } var ( @@ -135,6 +137,11 @@ func (c *gitlabConnector) HandleCallback(s connector.Scopes, r *http.Request) (i return identity, fmt.Errorf("gitlab: failed to get token: %v", err) } + return c.identity(ctx, s, token) +} + +func (c *gitlabConnector) identity(ctx context.Context, s connector.Scopes, token *oauth2.Token) (identity connector.Identity, err error) { + oauth2Config := c.oauth2Config(s) client := oauth2Config.Client(ctx, token) user, err := c.user(ctx, client) @@ -146,6 +153,7 @@ func (c *gitlabConnector) HandleCallback(s connector.Scopes, r *http.Request) (i if username == "" { username = user.Email } + identity = connector.Identity{ UserID: strconv.Itoa(user.ID), Username: username, @@ -166,10 +174,10 @@ func (c *gitlabConnector) HandleCallback(s connector.Scopes, r *http.Request) (i } if s.OfflineAccess { - data := connectorData{AccessToken: token.AccessToken} + data := connectorData{RefreshToken: token.RefreshToken, AccessToken: token.AccessToken} connData, err := json.Marshal(data) if err != nil { - return identity, fmt.Errorf("marshal connector data: %v", err) + return identity, fmt.Errorf("gitlab: marshal connector data: %v", err) } identity.ConnectorData = connData } @@ -178,37 +186,39 @@ func (c *gitlabConnector) HandleCallback(s connector.Scopes, r *http.Request) (i } func (c *gitlabConnector) Refresh(ctx context.Context, s connector.Scopes, ident connector.Identity) (connector.Identity, error) { - if len(ident.ConnectorData) == 0 { - return ident, errors.New("no upstream access token found") - } - var data connectorData if err := json.Unmarshal(ident.ConnectorData, &data); err != nil { - return ident, fmt.Errorf("gitlab: unmarshal access token: %v", err) + return ident, fmt.Errorf("gitlab: unmarshal connector data: %v", err) + } + oauth2Config := c.oauth2Config(s) + + if c.httpClient != nil { + ctx = context.WithValue(ctx, oauth2.HTTPClient, c.httpClient) } - client := c.oauth2Config(s).Client(ctx, &oauth2.Token{AccessToken: data.AccessToken}) - user, err := c.user(ctx, client) - if err != nil { - return ident, fmt.Errorf("gitlab: get user: %v", err) - } - - username := user.Name - if username == "" { - username = user.Email - } - ident.Username = username - ident.PreferredUsername = user.Username - ident.Email = user.Email - - if c.groupsRequired(s.Groups) { - groups, err := c.getGroups(ctx, client, s.Groups, user.Username) - if err != nil { - return ident, fmt.Errorf("gitlab: get groups: %v", err) + switch { + case data.RefreshToken != "": + { + t := &oauth2.Token{ + RefreshToken: data.RefreshToken, + Expiry: time.Now().Add(-time.Hour), + } + token, err := oauth2Config.TokenSource(ctx, t).Token() + if err != nil { + return ident, fmt.Errorf("gitlab: failed to get refresh token: %v", err) + } + return c.identity(ctx, s, token) } - ident.Groups = groups + case data.AccessToken != "": + { + token := &oauth2.Token{ + AccessToken: data.AccessToken, + } + return c.identity(ctx, s, token) + } + default: + return ident, errors.New("no refresh or access token found") } - return ident, nil } func (c *gitlabConnector) groupsRequired(groupScope bool) bool { diff --git a/connector/gitlab/gitlab_test.go b/connector/gitlab/gitlab_test.go index 23cf9aac..d828b8bd 100644 --- a/connector/gitlab/gitlab_test.go +++ b/connector/gitlab/gitlab_test.go @@ -180,6 +180,75 @@ func TestLoginWithTeamNonWhitelisted(t *testing.T) { expectEquals(t, err.Error(), "gitlab: get groups: gitlab: user \"joebloggs\" is not in any of the required groups") } +func TestRefresh(t *testing.T) { + s := newTestServer(map[string]interface{}{ + "/api/v4/user": gitlabUser{Email: "some@email.com", ID: 12345678}, + "/oauth/token": map[string]interface{}{ + "access_token": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9", + "refresh_token": "oRzxVjCnohYRHEYEhZshkmakKmoyVoTjfUGC", + "expires_in": "30", + }, + "/oauth/userinfo": userInfo{ + Groups: []string{"team-1"}, + }, + }) + defer s.Close() + + hostURL, err := url.Parse(s.URL) + expectNil(t, err) + + req, err := http.NewRequest("GET", hostURL.String(), nil) + expectNil(t, err) + + c := gitlabConnector{baseURL: s.URL, httpClient: newClient()} + + expectedConnectorData, err := json.Marshal(connectorData{ + RefreshToken: "oRzxVjCnohYRHEYEhZshkmakKmoyVoTjfUGC", + AccessToken: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9", + }) + expectNil(t, err) + + identity, err := c.HandleCallback(connector.Scopes{OfflineAccess: true}, req) + expectNil(t, err) + expectEquals(t, identity.Username, "some@email.com") + expectEquals(t, identity.UserID, "12345678") + expectEquals(t, identity.ConnectorData, expectedConnectorData) + + identity, err = c.Refresh(context.Background(), connector.Scopes{OfflineAccess: true}, identity) + expectNil(t, err) + expectEquals(t, identity.Username, "some@email.com") + expectEquals(t, identity.UserID, "12345678") + expectEquals(t, identity.ConnectorData, expectedConnectorData) +} + +func TestRefreshWithEmptyConnectorData(t *testing.T) { + s := newTestServer(map[string]interface{}{ + "/api/v4/user": gitlabUser{Email: "some@email.com", ID: 12345678}, + "/oauth/token": map[string]interface{}{ + "access_token": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9", + "refresh_token": "oRzxVjCnohYRHEYEhZshkmakKmoyVoTjfUGC", + "expires_in": "30", + }, + "/oauth/userinfo": userInfo{ + Groups: []string{"team-1"}, + }, + }) + defer s.Close() + + emptyConnectorData, err := json.Marshal(connectorData{ + RefreshToken: "", + AccessToken: "", + }) + expectNil(t, err) + + c := gitlabConnector{baseURL: s.URL, httpClient: newClient()} + emptyIdentity := connector.Identity{ConnectorData: emptyConnectorData} + + identity, err := c.Refresh(context.Background(), connector.Scopes{OfflineAccess: true}, emptyIdentity) + expectNotNil(t, err, "Refresh error") + expectEquals(t, emptyIdentity, identity) +} + func newTestServer(responses map[string]interface{}) *httptest.Server { return httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { response := responses[r.RequestURI]