forked from mystiq/dex
Add support for refresh tokens for openshift connector.
Signed-off-by: Daniel Haus <dhaus@redhat.com>
This commit is contained in:
parent
e00e75b773
commit
6d55fe1c80
2 changed files with 111 additions and 3 deletions
|
@ -21,6 +21,11 @@ import (
|
||||||
"github.com/dexidp/dex/storage/kubernetes/k8sapi"
|
"github.com/dexidp/dex/storage/kubernetes/k8sapi"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
wellKnownURLPath = "/.well-known/oauth-authorization-server"
|
||||||
|
usersURLPath = "/apis/user.openshift.io/v1/users/~"
|
||||||
|
)
|
||||||
|
|
||||||
// Config holds configuration options for OpenShift login
|
// Config holds configuration options for OpenShift login
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Issuer string `json:"issuer"`
|
Issuer string `json:"issuer"`
|
||||||
|
@ -33,6 +38,7 @@ type Config struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ connector.CallbackConnector = (*openshiftConnector)(nil)
|
var _ connector.CallbackConnector = (*openshiftConnector)(nil)
|
||||||
|
var _ connector.RefreshConnector = (*openshiftConnector)(nil)
|
||||||
|
|
||||||
type openshiftConnector struct {
|
type openshiftConnector struct {
|
||||||
apiURL string
|
apiURL string
|
||||||
|
@ -61,7 +67,7 @@ type user struct {
|
||||||
func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, err error) {
|
func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, err error) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
wellKnownURL := strings.TrimSuffix(c.Issuer, "/") + "/.well-known/oauth-authorization-server"
|
wellKnownURL := strings.TrimSuffix(c.Issuer, "/") + wellKnownURLPath
|
||||||
req, err := http.NewRequest(http.MethodGet, wellKnownURL, nil)
|
req, err := http.NewRequest(http.MethodGet, wellKnownURL, nil)
|
||||||
|
|
||||||
openshiftConnector := openshiftConnector{
|
openshiftConnector := openshiftConnector{
|
||||||
|
@ -154,8 +160,23 @@ func (c *openshiftConnector) HandleCallback(s connector.Scopes, r *http.Request)
|
||||||
return identity, fmt.Errorf("oidc: failed to get token: %v", err)
|
return identity, fmt.Errorf("oidc: failed to get token: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
client := c.oauth2Config.Client(ctx, token)
|
return c.identity(ctx, s, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *openshiftConnector) Refresh(ctx context.Context, s connector.Scopes, oldID connector.Identity) (connector.Identity, error) {
|
||||||
|
var token oauth2.Token
|
||||||
|
err := json.Unmarshal(oldID.ConnectorData, &token)
|
||||||
|
if err != nil {
|
||||||
|
return connector.Identity{}, fmt.Errorf("parsing token: %w", err)
|
||||||
|
}
|
||||||
|
if c.httpClient != nil {
|
||||||
|
ctx = context.WithValue(ctx, oauth2.HTTPClient, c.httpClient)
|
||||||
|
}
|
||||||
|
return c.identity(ctx, s, &token)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *openshiftConnector) identity(ctx context.Context, s connector.Scopes, token *oauth2.Token) (identity connector.Identity, err error) {
|
||||||
|
client := c.oauth2Config.Client(ctx, token)
|
||||||
user, err := c.user(ctx, client)
|
user, err := c.user(ctx, client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return identity, fmt.Errorf("openshift: get user: %v", err)
|
return identity, fmt.Errorf("openshift: get user: %v", err)
|
||||||
|
@ -177,12 +198,20 @@ func (c *openshiftConnector) HandleCallback(s connector.Scopes, r *http.Request)
|
||||||
Groups: user.Groups,
|
Groups: user.Groups,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.OfflineAccess {
|
||||||
|
connData, err := json.Marshal(token)
|
||||||
|
if err != nil {
|
||||||
|
return identity, fmt.Errorf("marshal connector data: %v", err)
|
||||||
|
}
|
||||||
|
identity.ConnectorData = connData
|
||||||
|
}
|
||||||
|
|
||||||
return identity, nil
|
return identity, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// user function returns the OpenShift user associated with the authenticated user
|
// user function returns the OpenShift user associated with the authenticated user
|
||||||
func (c *openshiftConnector) user(ctx context.Context, client *http.Client) (u user, err error) {
|
func (c *openshiftConnector) user(ctx context.Context, client *http.Client) (u user, err error) {
|
||||||
url := c.apiURL + "/apis/user.openshift.io/v1/users/~"
|
url := c.apiURL + usersURLPath
|
||||||
|
|
||||||
req, err := http.NewRequest("GET", url, nil)
|
req, err := http.NewRequest("GET", url, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
@ -184,6 +185,78 @@ func TestCallbackIdentity(t *testing.T) {
|
||||||
expectEquals(t, identity.Groups[0], "users")
|
expectEquals(t, identity.Groups[0], "users")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRefreshIdentity(t *testing.T) {
|
||||||
|
s := newTestServer(map[string]interface{}{
|
||||||
|
usersURLPath: user{
|
||||||
|
ObjectMeta: k8sapi.ObjectMeta{
|
||||||
|
Name: "jdoe",
|
||||||
|
UID: "12345",
|
||||||
|
},
|
||||||
|
FullName: "John Doe",
|
||||||
|
Groups: []string{"users"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
h, err := newHTTPClient(true, "")
|
||||||
|
expectNil(t, err)
|
||||||
|
|
||||||
|
oc := openshiftConnector{apiURL: s.URL, httpClient: h, oauth2Config: &oauth2.Config{
|
||||||
|
Endpoint: oauth2.Endpoint{
|
||||||
|
AuthURL: fmt.Sprintf("%s/oauth/authorize", s.URL),
|
||||||
|
TokenURL: fmt.Sprintf("%s/oauth/token", s.URL),
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
|
||||||
|
data, err := json.Marshal(oauth2.Token{AccessToken: "fFAGRNJru1FTz70BzhT3Zg"})
|
||||||
|
expectNil(t, err)
|
||||||
|
|
||||||
|
oldID := connector.Identity{ConnectorData: data}
|
||||||
|
|
||||||
|
identity, err := oc.Refresh(context.Background(), connector.Scopes{Groups: true}, oldID)
|
||||||
|
|
||||||
|
expectNil(t, err)
|
||||||
|
expectEquals(t, identity.UserID, "12345")
|
||||||
|
expectEquals(t, identity.Username, "jdoe")
|
||||||
|
expectEquals(t, identity.PreferredUsername, "jdoe")
|
||||||
|
expectEquals(t, identity.Email, "jdoe")
|
||||||
|
expectEquals(t, len(identity.Groups), 1)
|
||||||
|
expectEquals(t, identity.Groups[0], "users")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshIdentityFailure(t *testing.T) {
|
||||||
|
s := newTestServer(map[string]interface{}{
|
||||||
|
usersURLPath: user{
|
||||||
|
ObjectMeta: k8sapi.ObjectMeta{
|
||||||
|
Name: "jdoe",
|
||||||
|
UID: "12345",
|
||||||
|
},
|
||||||
|
FullName: "John Doe",
|
||||||
|
Groups: []string{"users"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
h, err := newHTTPClient(true, "")
|
||||||
|
expectNil(t, err)
|
||||||
|
|
||||||
|
oc := openshiftConnector{apiURL: s.URL, httpClient: h, oauth2Config: &oauth2.Config{
|
||||||
|
Endpoint: oauth2.Endpoint{
|
||||||
|
AuthURL: fmt.Sprintf("%s/oauth/authorize", s.URL),
|
||||||
|
TokenURL: fmt.Sprintf("%s/oauth/token", s.URL),
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
|
||||||
|
data, err := json.Marshal(oauth2.Token{AccessToken: "oRzxVjCnohYRHEYEhZshkmakKmoyVoTjfUGC", Expiry: time.Now().Add(-time.Hour)})
|
||||||
|
expectNil(t, err)
|
||||||
|
|
||||||
|
oldID := connector.Identity{ConnectorData: data}
|
||||||
|
|
||||||
|
identity, err := oc.Refresh(context.Background(), connector.Scopes{Groups: true}, oldID)
|
||||||
|
expectNotNil(t, err)
|
||||||
|
expectEquals(t, connector.Identity{}, identity)
|
||||||
|
}
|
||||||
|
|
||||||
func newTestServer(responses map[string]interface{}) *httptest.Server {
|
func newTestServer(responses map[string]interface{}) *httptest.Server {
|
||||||
var s *httptest.Server
|
var s *httptest.Server
|
||||||
s = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
s = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
@ -216,3 +289,9 @@ func expectEquals(t *testing.T, a interface{}, b interface{}) {
|
||||||
t.Errorf("Expected %+v to equal %+v", a, b)
|
t.Errorf("Expected %+v to equal %+v", a, b)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func expectNotNil(t *testing.T, a interface{}) {
|
||||||
|
if a == nil {
|
||||||
|
t.Errorf("Expected %+v to not equal nil", a)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue