forked from mystiq/dex
Add support for refresh tokens for openshift connector.
Signed-off-by: Daniel Haus <dhaus@redhat.com>
This commit is contained in:
parent
e00e75b773
commit
6d55fe1c80
2 changed files with 111 additions and 3 deletions
|
@ -21,6 +21,11 @@ import (
|
|||
"github.com/dexidp/dex/storage/kubernetes/k8sapi"
|
||||
)
|
||||
|
||||
const (
|
||||
wellKnownURLPath = "/.well-known/oauth-authorization-server"
|
||||
usersURLPath = "/apis/user.openshift.io/v1/users/~"
|
||||
)
|
||||
|
||||
// Config holds configuration options for OpenShift login
|
||||
type Config struct {
|
||||
Issuer string `json:"issuer"`
|
||||
|
@ -33,6 +38,7 @@ type Config struct {
|
|||
}
|
||||
|
||||
var _ connector.CallbackConnector = (*openshiftConnector)(nil)
|
||||
var _ connector.RefreshConnector = (*openshiftConnector)(nil)
|
||||
|
||||
type openshiftConnector struct {
|
||||
apiURL string
|
||||
|
@ -61,7 +67,7 @@ type user struct {
|
|||
func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, err error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
wellKnownURL := strings.TrimSuffix(c.Issuer, "/") + "/.well-known/oauth-authorization-server"
|
||||
wellKnownURL := strings.TrimSuffix(c.Issuer, "/") + wellKnownURLPath
|
||||
req, err := http.NewRequest(http.MethodGet, wellKnownURL, nil)
|
||||
|
||||
openshiftConnector := openshiftConnector{
|
||||
|
@ -154,8 +160,23 @@ func (c *openshiftConnector) HandleCallback(s connector.Scopes, r *http.Request)
|
|||
return identity, fmt.Errorf("oidc: failed to get token: %v", err)
|
||||
}
|
||||
|
||||
client := c.oauth2Config.Client(ctx, token)
|
||||
return c.identity(ctx, s, token)
|
||||
}
|
||||
|
||||
func (c *openshiftConnector) Refresh(ctx context.Context, s connector.Scopes, oldID connector.Identity) (connector.Identity, error) {
|
||||
var token oauth2.Token
|
||||
err := json.Unmarshal(oldID.ConnectorData, &token)
|
||||
if err != nil {
|
||||
return connector.Identity{}, fmt.Errorf("parsing token: %w", err)
|
||||
}
|
||||
if c.httpClient != nil {
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, c.httpClient)
|
||||
}
|
||||
return c.identity(ctx, s, &token)
|
||||
}
|
||||
|
||||
func (c *openshiftConnector) identity(ctx context.Context, s connector.Scopes, token *oauth2.Token) (identity connector.Identity, err error) {
|
||||
client := c.oauth2Config.Client(ctx, token)
|
||||
user, err := c.user(ctx, client)
|
||||
if err != nil {
|
||||
return identity, fmt.Errorf("openshift: get user: %v", err)
|
||||
|
@ -177,12 +198,20 @@ func (c *openshiftConnector) HandleCallback(s connector.Scopes, r *http.Request)
|
|||
Groups: user.Groups,
|
||||
}
|
||||
|
||||
if s.OfflineAccess {
|
||||
connData, err := json.Marshal(token)
|
||||
if err != nil {
|
||||
return identity, fmt.Errorf("marshal connector data: %v", err)
|
||||
}
|
||||
identity.ConnectorData = connData
|
||||
}
|
||||
|
||||
return identity, nil
|
||||
}
|
||||
|
||||
// user function returns the OpenShift user associated with the authenticated user
|
||||
func (c *openshiftConnector) user(ctx context.Context, client *http.Client) (u user, err error) {
|
||||
url := c.apiURL + "/apis/user.openshift.io/v1/users/~"
|
||||
url := c.apiURL + usersURLPath
|
||||
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"net/url"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/oauth2"
|
||||
|
@ -184,6 +185,78 @@ func TestCallbackIdentity(t *testing.T) {
|
|||
expectEquals(t, identity.Groups[0], "users")
|
||||
}
|
||||
|
||||
func TestRefreshIdentity(t *testing.T) {
|
||||
s := newTestServer(map[string]interface{}{
|
||||
usersURLPath: user{
|
||||
ObjectMeta: k8sapi.ObjectMeta{
|
||||
Name: "jdoe",
|
||||
UID: "12345",
|
||||
},
|
||||
FullName: "John Doe",
|
||||
Groups: []string{"users"},
|
||||
},
|
||||
})
|
||||
defer s.Close()
|
||||
|
||||
h, err := newHTTPClient(true, "")
|
||||
expectNil(t, err)
|
||||
|
||||
oc := openshiftConnector{apiURL: s.URL, httpClient: h, oauth2Config: &oauth2.Config{
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: fmt.Sprintf("%s/oauth/authorize", s.URL),
|
||||
TokenURL: fmt.Sprintf("%s/oauth/token", s.URL),
|
||||
},
|
||||
}}
|
||||
|
||||
data, err := json.Marshal(oauth2.Token{AccessToken: "fFAGRNJru1FTz70BzhT3Zg"})
|
||||
expectNil(t, err)
|
||||
|
||||
oldID := connector.Identity{ConnectorData: data}
|
||||
|
||||
identity, err := oc.Refresh(context.Background(), connector.Scopes{Groups: true}, oldID)
|
||||
|
||||
expectNil(t, err)
|
||||
expectEquals(t, identity.UserID, "12345")
|
||||
expectEquals(t, identity.Username, "jdoe")
|
||||
expectEquals(t, identity.PreferredUsername, "jdoe")
|
||||
expectEquals(t, identity.Email, "jdoe")
|
||||
expectEquals(t, len(identity.Groups), 1)
|
||||
expectEquals(t, identity.Groups[0], "users")
|
||||
}
|
||||
|
||||
func TestRefreshIdentityFailure(t *testing.T) {
|
||||
s := newTestServer(map[string]interface{}{
|
||||
usersURLPath: user{
|
||||
ObjectMeta: k8sapi.ObjectMeta{
|
||||
Name: "jdoe",
|
||||
UID: "12345",
|
||||
},
|
||||
FullName: "John Doe",
|
||||
Groups: []string{"users"},
|
||||
},
|
||||
})
|
||||
defer s.Close()
|
||||
|
||||
h, err := newHTTPClient(true, "")
|
||||
expectNil(t, err)
|
||||
|
||||
oc := openshiftConnector{apiURL: s.URL, httpClient: h, oauth2Config: &oauth2.Config{
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: fmt.Sprintf("%s/oauth/authorize", s.URL),
|
||||
TokenURL: fmt.Sprintf("%s/oauth/token", s.URL),
|
||||
},
|
||||
}}
|
||||
|
||||
data, err := json.Marshal(oauth2.Token{AccessToken: "oRzxVjCnohYRHEYEhZshkmakKmoyVoTjfUGC", Expiry: time.Now().Add(-time.Hour)})
|
||||
expectNil(t, err)
|
||||
|
||||
oldID := connector.Identity{ConnectorData: data}
|
||||
|
||||
identity, err := oc.Refresh(context.Background(), connector.Scopes{Groups: true}, oldID)
|
||||
expectNotNil(t, err)
|
||||
expectEquals(t, connector.Identity{}, identity)
|
||||
}
|
||||
|
||||
func newTestServer(responses map[string]interface{}) *httptest.Server {
|
||||
var s *httptest.Server
|
||||
s = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -216,3 +289,9 @@ func expectEquals(t *testing.T, a interface{}, b interface{}) {
|
|||
t.Errorf("Expected %+v to equal %+v", a, b)
|
||||
}
|
||||
}
|
||||
|
||||
func expectNotNil(t *testing.T, a interface{}) {
|
||||
if a == nil {
|
||||
t.Errorf("Expected %+v to not equal nil", a)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue