Use GitLab's refresh_token during Refresh. (#2352)

Signed-off-by: Daniel Haus <dhaus@redhat.com>
This commit is contained in:
dhaus67 2022-07-20 11:16:12 +02:00 committed by GitHub
parent d564cc7200
commit 100246328b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 108 additions and 29 deletions

View File

@ -9,6 +9,7 @@ import (
"io"
"net/http"
"strconv"
"time"
"golang.org/x/oauth2"
@ -61,8 +62,9 @@ func (c *Config) Open(id string, logger log.Logger) (connector.Connector, error)
}
type connectorData struct {
// GitLab's OAuth2 tokens never expire. We don't need a refresh token.
// Support GitLab's Access Tokens and Refresh tokens.
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
}
var (
@ -135,6 +137,11 @@ func (c *gitlabConnector) HandleCallback(s connector.Scopes, r *http.Request) (i
return identity, fmt.Errorf("gitlab: failed to get token: %v", err)
}
return c.identity(ctx, s, token)
}
func (c *gitlabConnector) identity(ctx context.Context, s connector.Scopes, token *oauth2.Token) (identity connector.Identity, err error) {
oauth2Config := c.oauth2Config(s)
client := oauth2Config.Client(ctx, token)
user, err := c.user(ctx, client)
@ -146,6 +153,7 @@ func (c *gitlabConnector) HandleCallback(s connector.Scopes, r *http.Request) (i
if username == "" {
username = user.Email
}
identity = connector.Identity{
UserID: strconv.Itoa(user.ID),
Username: username,
@ -166,10 +174,10 @@ func (c *gitlabConnector) HandleCallback(s connector.Scopes, r *http.Request) (i
}
if s.OfflineAccess {
data := connectorData{AccessToken: token.AccessToken}
data := connectorData{RefreshToken: token.RefreshToken, AccessToken: token.AccessToken}
connData, err := json.Marshal(data)
if err != nil {
return identity, fmt.Errorf("marshal connector data: %v", err)
return identity, fmt.Errorf("gitlab: marshal connector data: %v", err)
}
identity.ConnectorData = connData
}
@ -178,37 +186,39 @@ func (c *gitlabConnector) HandleCallback(s connector.Scopes, r *http.Request) (i
}
func (c *gitlabConnector) Refresh(ctx context.Context, s connector.Scopes, ident connector.Identity) (connector.Identity, error) {
if len(ident.ConnectorData) == 0 {
return ident, errors.New("no upstream access token found")
}
var data connectorData
if err := json.Unmarshal(ident.ConnectorData, &data); err != nil {
return ident, fmt.Errorf("gitlab: unmarshal access token: %v", err)
return ident, fmt.Errorf("gitlab: unmarshal connector data: %v", err)
}
oauth2Config := c.oauth2Config(s)
if c.httpClient != nil {
ctx = context.WithValue(ctx, oauth2.HTTPClient, c.httpClient)
}
client := c.oauth2Config(s).Client(ctx, &oauth2.Token{AccessToken: data.AccessToken})
user, err := c.user(ctx, client)
switch {
case data.RefreshToken != "":
{
t := &oauth2.Token{
RefreshToken: data.RefreshToken,
Expiry: time.Now().Add(-time.Hour),
}
token, err := oauth2Config.TokenSource(ctx, t).Token()
if err != nil {
return ident, fmt.Errorf("gitlab: get user: %v", err)
return ident, fmt.Errorf("gitlab: failed to get refresh token: %v", err)
}
username := user.Name
if username == "" {
username = user.Email
return c.identity(ctx, s, token)
}
ident.Username = username
ident.PreferredUsername = user.Username
ident.Email = user.Email
if c.groupsRequired(s.Groups) {
groups, err := c.getGroups(ctx, client, s.Groups, user.Username)
if err != nil {
return ident, fmt.Errorf("gitlab: get groups: %v", err)
case data.AccessToken != "":
{
token := &oauth2.Token{
AccessToken: data.AccessToken,
}
ident.Groups = groups
return c.identity(ctx, s, token)
}
default:
return ident, errors.New("no refresh or access token found")
}
return ident, nil
}
func (c *gitlabConnector) groupsRequired(groupScope bool) bool {

View File

@ -180,6 +180,75 @@ func TestLoginWithTeamNonWhitelisted(t *testing.T) {
expectEquals(t, err.Error(), "gitlab: get groups: gitlab: user \"joebloggs\" is not in any of the required groups")
}
func TestRefresh(t *testing.T) {
s := newTestServer(map[string]interface{}{
"/api/v4/user": gitlabUser{Email: "some@email.com", ID: 12345678},
"/oauth/token": map[string]interface{}{
"access_token": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9",
"refresh_token": "oRzxVjCnohYRHEYEhZshkmakKmoyVoTjfUGC",
"expires_in": "30",
},
"/oauth/userinfo": userInfo{
Groups: []string{"team-1"},
},
})
defer s.Close()
hostURL, err := url.Parse(s.URL)
expectNil(t, err)
req, err := http.NewRequest("GET", hostURL.String(), nil)
expectNil(t, err)
c := gitlabConnector{baseURL: s.URL, httpClient: newClient()}
expectedConnectorData, err := json.Marshal(connectorData{
RefreshToken: "oRzxVjCnohYRHEYEhZshkmakKmoyVoTjfUGC",
AccessToken: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9",
})
expectNil(t, err)
identity, err := c.HandleCallback(connector.Scopes{OfflineAccess: true}, req)
expectNil(t, err)
expectEquals(t, identity.Username, "some@email.com")
expectEquals(t, identity.UserID, "12345678")
expectEquals(t, identity.ConnectorData, expectedConnectorData)
identity, err = c.Refresh(context.Background(), connector.Scopes{OfflineAccess: true}, identity)
expectNil(t, err)
expectEquals(t, identity.Username, "some@email.com")
expectEquals(t, identity.UserID, "12345678")
expectEquals(t, identity.ConnectorData, expectedConnectorData)
}
func TestRefreshWithEmptyConnectorData(t *testing.T) {
s := newTestServer(map[string]interface{}{
"/api/v4/user": gitlabUser{Email: "some@email.com", ID: 12345678},
"/oauth/token": map[string]interface{}{
"access_token": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9",
"refresh_token": "oRzxVjCnohYRHEYEhZshkmakKmoyVoTjfUGC",
"expires_in": "30",
},
"/oauth/userinfo": userInfo{
Groups: []string{"team-1"},
},
})
defer s.Close()
emptyConnectorData, err := json.Marshal(connectorData{
RefreshToken: "",
AccessToken: "",
})
expectNil(t, err)
c := gitlabConnector{baseURL: s.URL, httpClient: newClient()}
emptyIdentity := connector.Identity{ConnectorData: emptyConnectorData}
identity, err := c.Refresh(context.Background(), connector.Scopes{OfflineAccess: true}, emptyIdentity)
expectNotNil(t, err, "Refresh error")
expectEquals(t, emptyIdentity, identity)
}
func newTestServer(responses map[string]interface{}) *httptest.Server {
return httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
response := responses[r.RequestURI]