forked from mystiq/dex
Use GitLab's refresh_token during Refresh. (#2352)
Signed-off-by: Daniel Haus <dhaus@redhat.com>
This commit is contained in:
parent
d564cc7200
commit
100246328b
2 changed files with 108 additions and 29 deletions
|
@ -9,6 +9,7 @@ import (
|
|||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
|
@ -61,8 +62,9 @@ func (c *Config) Open(id string, logger log.Logger) (connector.Connector, error)
|
|||
}
|
||||
|
||||
type connectorData struct {
|
||||
// GitLab's OAuth2 tokens never expire. We don't need a refresh token.
|
||||
// Support GitLab's Access Tokens and Refresh tokens.
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
}
|
||||
|
||||
var (
|
||||
|
@ -135,6 +137,11 @@ func (c *gitlabConnector) HandleCallback(s connector.Scopes, r *http.Request) (i
|
|||
return identity, fmt.Errorf("gitlab: failed to get token: %v", err)
|
||||
}
|
||||
|
||||
return c.identity(ctx, s, token)
|
||||
}
|
||||
|
||||
func (c *gitlabConnector) identity(ctx context.Context, s connector.Scopes, token *oauth2.Token) (identity connector.Identity, err error) {
|
||||
oauth2Config := c.oauth2Config(s)
|
||||
client := oauth2Config.Client(ctx, token)
|
||||
|
||||
user, err := c.user(ctx, client)
|
||||
|
@ -146,6 +153,7 @@ func (c *gitlabConnector) HandleCallback(s connector.Scopes, r *http.Request) (i
|
|||
if username == "" {
|
||||
username = user.Email
|
||||
}
|
||||
|
||||
identity = connector.Identity{
|
||||
UserID: strconv.Itoa(user.ID),
|
||||
Username: username,
|
||||
|
@ -166,10 +174,10 @@ func (c *gitlabConnector) HandleCallback(s connector.Scopes, r *http.Request) (i
|
|||
}
|
||||
|
||||
if s.OfflineAccess {
|
||||
data := connectorData{AccessToken: token.AccessToken}
|
||||
data := connectorData{RefreshToken: token.RefreshToken, AccessToken: token.AccessToken}
|
||||
connData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return identity, fmt.Errorf("marshal connector data: %v", err)
|
||||
return identity, fmt.Errorf("gitlab: marshal connector data: %v", err)
|
||||
}
|
||||
identity.ConnectorData = connData
|
||||
}
|
||||
|
@ -178,37 +186,39 @@ func (c *gitlabConnector) HandleCallback(s connector.Scopes, r *http.Request) (i
|
|||
}
|
||||
|
||||
func (c *gitlabConnector) Refresh(ctx context.Context, s connector.Scopes, ident connector.Identity) (connector.Identity, error) {
|
||||
if len(ident.ConnectorData) == 0 {
|
||||
return ident, errors.New("no upstream access token found")
|
||||
}
|
||||
|
||||
var data connectorData
|
||||
if err := json.Unmarshal(ident.ConnectorData, &data); err != nil {
|
||||
return ident, fmt.Errorf("gitlab: unmarshal access token: %v", err)
|
||||
return ident, fmt.Errorf("gitlab: unmarshal connector data: %v", err)
|
||||
}
|
||||
oauth2Config := c.oauth2Config(s)
|
||||
|
||||
if c.httpClient != nil {
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, c.httpClient)
|
||||
}
|
||||
|
||||
client := c.oauth2Config(s).Client(ctx, &oauth2.Token{AccessToken: data.AccessToken})
|
||||
user, err := c.user(ctx, client)
|
||||
switch {
|
||||
case data.RefreshToken != "":
|
||||
{
|
||||
t := &oauth2.Token{
|
||||
RefreshToken: data.RefreshToken,
|
||||
Expiry: time.Now().Add(-time.Hour),
|
||||
}
|
||||
token, err := oauth2Config.TokenSource(ctx, t).Token()
|
||||
if err != nil {
|
||||
return ident, fmt.Errorf("gitlab: get user: %v", err)
|
||||
return ident, fmt.Errorf("gitlab: failed to get refresh token: %v", err)
|
||||
}
|
||||
|
||||
username := user.Name
|
||||
if username == "" {
|
||||
username = user.Email
|
||||
return c.identity(ctx, s, token)
|
||||
}
|
||||
ident.Username = username
|
||||
ident.PreferredUsername = user.Username
|
||||
ident.Email = user.Email
|
||||
|
||||
if c.groupsRequired(s.Groups) {
|
||||
groups, err := c.getGroups(ctx, client, s.Groups, user.Username)
|
||||
if err != nil {
|
||||
return ident, fmt.Errorf("gitlab: get groups: %v", err)
|
||||
case data.AccessToken != "":
|
||||
{
|
||||
token := &oauth2.Token{
|
||||
AccessToken: data.AccessToken,
|
||||
}
|
||||
ident.Groups = groups
|
||||
return c.identity(ctx, s, token)
|
||||
}
|
||||
default:
|
||||
return ident, errors.New("no refresh or access token found")
|
||||
}
|
||||
return ident, nil
|
||||
}
|
||||
|
||||
func (c *gitlabConnector) groupsRequired(groupScope bool) bool {
|
||||
|
|
|
@ -180,6 +180,75 @@ func TestLoginWithTeamNonWhitelisted(t *testing.T) {
|
|||
expectEquals(t, err.Error(), "gitlab: get groups: gitlab: user \"joebloggs\" is not in any of the required groups")
|
||||
}
|
||||
|
||||
func TestRefresh(t *testing.T) {
|
||||
s := newTestServer(map[string]interface{}{
|
||||
"/api/v4/user": gitlabUser{Email: "some@email.com", ID: 12345678},
|
||||
"/oauth/token": map[string]interface{}{
|
||||
"access_token": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9",
|
||||
"refresh_token": "oRzxVjCnohYRHEYEhZshkmakKmoyVoTjfUGC",
|
||||
"expires_in": "30",
|
||||
},
|
||||
"/oauth/userinfo": userInfo{
|
||||
Groups: []string{"team-1"},
|
||||
},
|
||||
})
|
||||
defer s.Close()
|
||||
|
||||
hostURL, err := url.Parse(s.URL)
|
||||
expectNil(t, err)
|
||||
|
||||
req, err := http.NewRequest("GET", hostURL.String(), nil)
|
||||
expectNil(t, err)
|
||||
|
||||
c := gitlabConnector{baseURL: s.URL, httpClient: newClient()}
|
||||
|
||||
expectedConnectorData, err := json.Marshal(connectorData{
|
||||
RefreshToken: "oRzxVjCnohYRHEYEhZshkmakKmoyVoTjfUGC",
|
||||
AccessToken: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9",
|
||||
})
|
||||
expectNil(t, err)
|
||||
|
||||
identity, err := c.HandleCallback(connector.Scopes{OfflineAccess: true}, req)
|
||||
expectNil(t, err)
|
||||
expectEquals(t, identity.Username, "some@email.com")
|
||||
expectEquals(t, identity.UserID, "12345678")
|
||||
expectEquals(t, identity.ConnectorData, expectedConnectorData)
|
||||
|
||||
identity, err = c.Refresh(context.Background(), connector.Scopes{OfflineAccess: true}, identity)
|
||||
expectNil(t, err)
|
||||
expectEquals(t, identity.Username, "some@email.com")
|
||||
expectEquals(t, identity.UserID, "12345678")
|
||||
expectEquals(t, identity.ConnectorData, expectedConnectorData)
|
||||
}
|
||||
|
||||
func TestRefreshWithEmptyConnectorData(t *testing.T) {
|
||||
s := newTestServer(map[string]interface{}{
|
||||
"/api/v4/user": gitlabUser{Email: "some@email.com", ID: 12345678},
|
||||
"/oauth/token": map[string]interface{}{
|
||||
"access_token": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9",
|
||||
"refresh_token": "oRzxVjCnohYRHEYEhZshkmakKmoyVoTjfUGC",
|
||||
"expires_in": "30",
|
||||
},
|
||||
"/oauth/userinfo": userInfo{
|
||||
Groups: []string{"team-1"},
|
||||
},
|
||||
})
|
||||
defer s.Close()
|
||||
|
||||
emptyConnectorData, err := json.Marshal(connectorData{
|
||||
RefreshToken: "",
|
||||
AccessToken: "",
|
||||
})
|
||||
expectNil(t, err)
|
||||
|
||||
c := gitlabConnector{baseURL: s.URL, httpClient: newClient()}
|
||||
emptyIdentity := connector.Identity{ConnectorData: emptyConnectorData}
|
||||
|
||||
identity, err := c.Refresh(context.Background(), connector.Scopes{OfflineAccess: true}, emptyIdentity)
|
||||
expectNotNil(t, err, "Refresh error")
|
||||
expectEquals(t, emptyIdentity, identity)
|
||||
}
|
||||
|
||||
func newTestServer(responses map[string]interface{}) *httptest.Server {
|
||||
return httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
response := responses[r.RequestURI]
|
||||
|
|
Loading…
Reference in a new issue