forked from mystiq/dex
Use GitLab's refresh_token during Refresh. (#2352)
Signed-off-by: Daniel Haus <dhaus@redhat.com>
This commit is contained in:
parent
d564cc7200
commit
100246328b
2 changed files with 108 additions and 29 deletions
|
@ -9,6 +9,7 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
|
@ -61,8 +62,9 @@ func (c *Config) Open(id string, logger log.Logger) (connector.Connector, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type connectorData struct {
|
type connectorData struct {
|
||||||
// GitLab's OAuth2 tokens never expire. We don't need a refresh token.
|
// Support GitLab's Access Tokens and Refresh tokens.
|
||||||
AccessToken string `json:"accessToken"`
|
AccessToken string `json:"accessToken"`
|
||||||
|
RefreshToken string `json:"refreshToken"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -135,6 +137,11 @@ func (c *gitlabConnector) HandleCallback(s connector.Scopes, r *http.Request) (i
|
||||||
return identity, fmt.Errorf("gitlab: failed to get token: %v", err)
|
return identity, fmt.Errorf("gitlab: failed to get token: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return c.identity(ctx, s, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *gitlabConnector) identity(ctx context.Context, s connector.Scopes, token *oauth2.Token) (identity connector.Identity, err error) {
|
||||||
|
oauth2Config := c.oauth2Config(s)
|
||||||
client := oauth2Config.Client(ctx, token)
|
client := oauth2Config.Client(ctx, token)
|
||||||
|
|
||||||
user, err := c.user(ctx, client)
|
user, err := c.user(ctx, client)
|
||||||
|
@ -146,6 +153,7 @@ func (c *gitlabConnector) HandleCallback(s connector.Scopes, r *http.Request) (i
|
||||||
if username == "" {
|
if username == "" {
|
||||||
username = user.Email
|
username = user.Email
|
||||||
}
|
}
|
||||||
|
|
||||||
identity = connector.Identity{
|
identity = connector.Identity{
|
||||||
UserID: strconv.Itoa(user.ID),
|
UserID: strconv.Itoa(user.ID),
|
||||||
Username: username,
|
Username: username,
|
||||||
|
@ -166,10 +174,10 @@ func (c *gitlabConnector) HandleCallback(s connector.Scopes, r *http.Request) (i
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.OfflineAccess {
|
if s.OfflineAccess {
|
||||||
data := connectorData{AccessToken: token.AccessToken}
|
data := connectorData{RefreshToken: token.RefreshToken, AccessToken: token.AccessToken}
|
||||||
connData, err := json.Marshal(data)
|
connData, err := json.Marshal(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return identity, fmt.Errorf("marshal connector data: %v", err)
|
return identity, fmt.Errorf("gitlab: marshal connector data: %v", err)
|
||||||
}
|
}
|
||||||
identity.ConnectorData = connData
|
identity.ConnectorData = connData
|
||||||
}
|
}
|
||||||
|
@ -178,37 +186,39 @@ func (c *gitlabConnector) HandleCallback(s connector.Scopes, r *http.Request) (i
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *gitlabConnector) Refresh(ctx context.Context, s connector.Scopes, ident connector.Identity) (connector.Identity, error) {
|
func (c *gitlabConnector) Refresh(ctx context.Context, s connector.Scopes, ident connector.Identity) (connector.Identity, error) {
|
||||||
if len(ident.ConnectorData) == 0 {
|
|
||||||
return ident, errors.New("no upstream access token found")
|
|
||||||
}
|
|
||||||
|
|
||||||
var data connectorData
|
var data connectorData
|
||||||
if err := json.Unmarshal(ident.ConnectorData, &data); err != nil {
|
if err := json.Unmarshal(ident.ConnectorData, &data); err != nil {
|
||||||
return ident, fmt.Errorf("gitlab: unmarshal access token: %v", err)
|
return ident, fmt.Errorf("gitlab: unmarshal connector data: %v", err)
|
||||||
|
}
|
||||||
|
oauth2Config := c.oauth2Config(s)
|
||||||
|
|
||||||
|
if c.httpClient != nil {
|
||||||
|
ctx = context.WithValue(ctx, oauth2.HTTPClient, c.httpClient)
|
||||||
}
|
}
|
||||||
|
|
||||||
client := c.oauth2Config(s).Client(ctx, &oauth2.Token{AccessToken: data.AccessToken})
|
switch {
|
||||||
user, err := c.user(ctx, client)
|
case data.RefreshToken != "":
|
||||||
if err != nil {
|
{
|
||||||
return ident, fmt.Errorf("gitlab: get user: %v", err)
|
t := &oauth2.Token{
|
||||||
}
|
RefreshToken: data.RefreshToken,
|
||||||
|
Expiry: time.Now().Add(-time.Hour),
|
||||||
username := user.Name
|
}
|
||||||
if username == "" {
|
token, err := oauth2Config.TokenSource(ctx, t).Token()
|
||||||
username = user.Email
|
if err != nil {
|
||||||
}
|
return ident, fmt.Errorf("gitlab: failed to get refresh token: %v", err)
|
||||||
ident.Username = username
|
}
|
||||||
ident.PreferredUsername = user.Username
|
return c.identity(ctx, s, token)
|
||||||
ident.Email = user.Email
|
|
||||||
|
|
||||||
if c.groupsRequired(s.Groups) {
|
|
||||||
groups, err := c.getGroups(ctx, client, s.Groups, user.Username)
|
|
||||||
if err != nil {
|
|
||||||
return ident, fmt.Errorf("gitlab: get groups: %v", err)
|
|
||||||
}
|
}
|
||||||
ident.Groups = groups
|
case data.AccessToken != "":
|
||||||
|
{
|
||||||
|
token := &oauth2.Token{
|
||||||
|
AccessToken: data.AccessToken,
|
||||||
|
}
|
||||||
|
return c.identity(ctx, s, token)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return ident, errors.New("no refresh or access token found")
|
||||||
}
|
}
|
||||||
return ident, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *gitlabConnector) groupsRequired(groupScope bool) bool {
|
func (c *gitlabConnector) groupsRequired(groupScope bool) bool {
|
||||||
|
|
|
@ -180,6 +180,75 @@ func TestLoginWithTeamNonWhitelisted(t *testing.T) {
|
||||||
expectEquals(t, err.Error(), "gitlab: get groups: gitlab: user \"joebloggs\" is not in any of the required groups")
|
expectEquals(t, err.Error(), "gitlab: get groups: gitlab: user \"joebloggs\" is not in any of the required groups")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRefresh(t *testing.T) {
|
||||||
|
s := newTestServer(map[string]interface{}{
|
||||||
|
"/api/v4/user": gitlabUser{Email: "some@email.com", ID: 12345678},
|
||||||
|
"/oauth/token": map[string]interface{}{
|
||||||
|
"access_token": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9",
|
||||||
|
"refresh_token": "oRzxVjCnohYRHEYEhZshkmakKmoyVoTjfUGC",
|
||||||
|
"expires_in": "30",
|
||||||
|
},
|
||||||
|
"/oauth/userinfo": userInfo{
|
||||||
|
Groups: []string{"team-1"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
hostURL, err := url.Parse(s.URL)
|
||||||
|
expectNil(t, err)
|
||||||
|
|
||||||
|
req, err := http.NewRequest("GET", hostURL.String(), nil)
|
||||||
|
expectNil(t, err)
|
||||||
|
|
||||||
|
c := gitlabConnector{baseURL: s.URL, httpClient: newClient()}
|
||||||
|
|
||||||
|
expectedConnectorData, err := json.Marshal(connectorData{
|
||||||
|
RefreshToken: "oRzxVjCnohYRHEYEhZshkmakKmoyVoTjfUGC",
|
||||||
|
AccessToken: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9",
|
||||||
|
})
|
||||||
|
expectNil(t, err)
|
||||||
|
|
||||||
|
identity, err := c.HandleCallback(connector.Scopes{OfflineAccess: true}, req)
|
||||||
|
expectNil(t, err)
|
||||||
|
expectEquals(t, identity.Username, "some@email.com")
|
||||||
|
expectEquals(t, identity.UserID, "12345678")
|
||||||
|
expectEquals(t, identity.ConnectorData, expectedConnectorData)
|
||||||
|
|
||||||
|
identity, err = c.Refresh(context.Background(), connector.Scopes{OfflineAccess: true}, identity)
|
||||||
|
expectNil(t, err)
|
||||||
|
expectEquals(t, identity.Username, "some@email.com")
|
||||||
|
expectEquals(t, identity.UserID, "12345678")
|
||||||
|
expectEquals(t, identity.ConnectorData, expectedConnectorData)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshWithEmptyConnectorData(t *testing.T) {
|
||||||
|
s := newTestServer(map[string]interface{}{
|
||||||
|
"/api/v4/user": gitlabUser{Email: "some@email.com", ID: 12345678},
|
||||||
|
"/oauth/token": map[string]interface{}{
|
||||||
|
"access_token": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9",
|
||||||
|
"refresh_token": "oRzxVjCnohYRHEYEhZshkmakKmoyVoTjfUGC",
|
||||||
|
"expires_in": "30",
|
||||||
|
},
|
||||||
|
"/oauth/userinfo": userInfo{
|
||||||
|
Groups: []string{"team-1"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
emptyConnectorData, err := json.Marshal(connectorData{
|
||||||
|
RefreshToken: "",
|
||||||
|
AccessToken: "",
|
||||||
|
})
|
||||||
|
expectNil(t, err)
|
||||||
|
|
||||||
|
c := gitlabConnector{baseURL: s.URL, httpClient: newClient()}
|
||||||
|
emptyIdentity := connector.Identity{ConnectorData: emptyConnectorData}
|
||||||
|
|
||||||
|
identity, err := c.Refresh(context.Background(), connector.Scopes{OfflineAccess: true}, emptyIdentity)
|
||||||
|
expectNotNil(t, err, "Refresh error")
|
||||||
|
expectEquals(t, emptyIdentity, identity)
|
||||||
|
}
|
||||||
|
|
||||||
func newTestServer(responses map[string]interface{}) *httptest.Server {
|
func newTestServer(responses map[string]interface{}) *httptest.Server {
|
||||||
return httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
response := responses[r.RequestURI]
|
response := responses[r.RequestURI]
|
||||||
|
|
Loading…
Reference in a new issue