From 6d55fe1c800819951eeece196bbfc67fdf037ae0 Mon Sep 17 00:00:00 2001 From: Daniel Haus Date: Tue, 23 Nov 2021 19:39:23 +0100 Subject: [PATCH] Add support for refresh tokens for openshift connector. Signed-off-by: Daniel Haus --- connector/openshift/openshift.go | 35 +++++++++++- connector/openshift/openshift_test.go | 79 +++++++++++++++++++++++++++ 2 files changed, 111 insertions(+), 3 deletions(-) diff --git a/connector/openshift/openshift.go b/connector/openshift/openshift.go index 44a03edf..d25018b4 100644 --- a/connector/openshift/openshift.go +++ b/connector/openshift/openshift.go @@ -21,6 +21,11 @@ import ( "github.com/dexidp/dex/storage/kubernetes/k8sapi" ) +const ( + wellKnownURLPath = "/.well-known/oauth-authorization-server" + usersURLPath = "/apis/user.openshift.io/v1/users/~" +) + // Config holds configuration options for OpenShift login type Config struct { Issuer string `json:"issuer"` @@ -33,6 +38,7 @@ type Config struct { } var _ connector.CallbackConnector = (*openshiftConnector)(nil) +var _ connector.RefreshConnector = (*openshiftConnector)(nil) type openshiftConnector struct { apiURL string @@ -61,7 +67,7 @@ type user struct { func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, err error) { ctx, cancel := context.WithCancel(context.Background()) - wellKnownURL := strings.TrimSuffix(c.Issuer, "/") + "/.well-known/oauth-authorization-server" + wellKnownURL := strings.TrimSuffix(c.Issuer, "/") + wellKnownURLPath req, err := http.NewRequest(http.MethodGet, wellKnownURL, nil) openshiftConnector := openshiftConnector{ @@ -154,8 +160,23 @@ func (c *openshiftConnector) HandleCallback(s connector.Scopes, r *http.Request) return identity, fmt.Errorf("oidc: failed to get token: %v", err) } - client := c.oauth2Config.Client(ctx, token) + return c.identity(ctx, s, token) +} +func (c *openshiftConnector) Refresh(ctx context.Context, s connector.Scopes, oldID connector.Identity) (connector.Identity, error) { + var token oauth2.Token + err := json.Unmarshal(oldID.ConnectorData, &token) + if err != nil { + return connector.Identity{}, fmt.Errorf("parsing token: %w", err) + } + if c.httpClient != nil { + ctx = context.WithValue(ctx, oauth2.HTTPClient, c.httpClient) + } + return c.identity(ctx, s, &token) +} + +func (c *openshiftConnector) identity(ctx context.Context, s connector.Scopes, token *oauth2.Token) (identity connector.Identity, err error) { + client := c.oauth2Config.Client(ctx, token) user, err := c.user(ctx, client) if err != nil { return identity, fmt.Errorf("openshift: get user: %v", err) @@ -177,12 +198,20 @@ func (c *openshiftConnector) HandleCallback(s connector.Scopes, r *http.Request) Groups: user.Groups, } + if s.OfflineAccess { + connData, err := json.Marshal(token) + if err != nil { + return identity, fmt.Errorf("marshal connector data: %v", err) + } + identity.ConnectorData = connData + } + return identity, nil } // user function returns the OpenShift user associated with the authenticated user func (c *openshiftConnector) user(ctx context.Context, client *http.Client) (u user, err error) { - url := c.apiURL + "/apis/user.openshift.io/v1/users/~" + url := c.apiURL + usersURLPath req, err := http.NewRequest("GET", url, nil) if err != nil { diff --git a/connector/openshift/openshift_test.go b/connector/openshift/openshift_test.go index 90f1686c..ee25668c 100644 --- a/connector/openshift/openshift_test.go +++ b/connector/openshift/openshift_test.go @@ -9,6 +9,7 @@ import ( "net/url" "reflect" "testing" + "time" "github.com/sirupsen/logrus" "golang.org/x/oauth2" @@ -184,6 +185,78 @@ func TestCallbackIdentity(t *testing.T) { expectEquals(t, identity.Groups[0], "users") } +func TestRefreshIdentity(t *testing.T) { + s := newTestServer(map[string]interface{}{ + usersURLPath: user{ + ObjectMeta: k8sapi.ObjectMeta{ + Name: "jdoe", + UID: "12345", + }, + FullName: "John Doe", + Groups: []string{"users"}, + }, + }) + defer s.Close() + + h, err := newHTTPClient(true, "") + expectNil(t, err) + + oc := openshiftConnector{apiURL: s.URL, httpClient: h, oauth2Config: &oauth2.Config{ + Endpoint: oauth2.Endpoint{ + AuthURL: fmt.Sprintf("%s/oauth/authorize", s.URL), + TokenURL: fmt.Sprintf("%s/oauth/token", s.URL), + }, + }} + + data, err := json.Marshal(oauth2.Token{AccessToken: "fFAGRNJru1FTz70BzhT3Zg"}) + expectNil(t, err) + + oldID := connector.Identity{ConnectorData: data} + + identity, err := oc.Refresh(context.Background(), connector.Scopes{Groups: true}, oldID) + + expectNil(t, err) + expectEquals(t, identity.UserID, "12345") + expectEquals(t, identity.Username, "jdoe") + expectEquals(t, identity.PreferredUsername, "jdoe") + expectEquals(t, identity.Email, "jdoe") + expectEquals(t, len(identity.Groups), 1) + expectEquals(t, identity.Groups[0], "users") +} + +func TestRefreshIdentityFailure(t *testing.T) { + s := newTestServer(map[string]interface{}{ + usersURLPath: user{ + ObjectMeta: k8sapi.ObjectMeta{ + Name: "jdoe", + UID: "12345", + }, + FullName: "John Doe", + Groups: []string{"users"}, + }, + }) + defer s.Close() + + h, err := newHTTPClient(true, "") + expectNil(t, err) + + oc := openshiftConnector{apiURL: s.URL, httpClient: h, oauth2Config: &oauth2.Config{ + Endpoint: oauth2.Endpoint{ + AuthURL: fmt.Sprintf("%s/oauth/authorize", s.URL), + TokenURL: fmt.Sprintf("%s/oauth/token", s.URL), + }, + }} + + data, err := json.Marshal(oauth2.Token{AccessToken: "oRzxVjCnohYRHEYEhZshkmakKmoyVoTjfUGC", Expiry: time.Now().Add(-time.Hour)}) + expectNil(t, err) + + oldID := connector.Identity{ConnectorData: data} + + identity, err := oc.Refresh(context.Background(), connector.Scopes{Groups: true}, oldID) + expectNotNil(t, err) + expectEquals(t, connector.Identity{}, identity) +} + func newTestServer(responses map[string]interface{}) *httptest.Server { var s *httptest.Server s = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -216,3 +289,9 @@ func expectEquals(t *testing.T, a interface{}, b interface{}) { t.Errorf("Expected %+v to equal %+v", a, b) } } + +func expectNotNil(t *testing.T, a interface{}) { + if a == nil { + t.Errorf("Expected %+v to not equal nil", a) + } +}