diff --git a/connector/openshift/openshift.go b/connector/openshift/openshift.go index 469dbd96..05919973 100644 --- a/connector/openshift/openshift.go +++ b/connector/openshift/openshift.go @@ -67,6 +67,18 @@ type user struct { // Open returns a connector which can be used to login users through an upstream // OpenShift OAuth2 provider. func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, err error) { + httpClient, err := newHTTPClient(c.InsecureCA, c.RootCA) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP client: %w", err) + } + + return c.OpenWithHTTPClient(id, logger, httpClient) +} + +// OpenWithHTTPClient returns a connector which can be used to login users through an upstream +// OpenShift OAuth2 provider. It provides the ability to inject a http.Client. +func (c *Config) OpenWithHTTPClient(id string, logger log.Logger, + httpClient *http.Client) (conn connector.Connector, err error) { ctx, cancel := context.WithCancel(context.Background()) wellKnownURL := strings.TrimSuffix(c.Issuer, "/") + wellKnownURLPath @@ -82,11 +94,7 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e redirectURI: c.RedirectURI, rootCA: c.RootCA, groups: c.Groups, - } - - if openshiftConnector.httpClient, err = newHTTPClient(c.InsecureCA, c.RootCA); err != nil { - cancel() - return nil, fmt.Errorf("failed to create HTTP client: %v", err) + httpClient: httpClient, } var metadata struct { @@ -97,14 +105,14 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e resp, err := openshiftConnector.httpClient.Do(req.WithContext(ctx)) if err != nil { cancel() - return nil, fmt.Errorf("failed to query OpenShift endpoint %v", err) + return nil, fmt.Errorf("failed to query OpenShift endpoint %w", err) } defer resp.Body.Close() if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil { cancel() - return nil, fmt.Errorf("discovery through endpoint %s failed to decode body: %v", + return nil, fmt.Errorf("discovery through endpoint %s failed to decode body: %w", wellKnownURL, err) } @@ -128,7 +136,8 @@ func (c *openshiftConnector) Close() error { // LoginURL returns the URL to redirect the user to login with. func (c *openshiftConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { if c.redirectURI != callbackURL { - return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) + return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", + callbackURL, c.redirectURI) } return c.oauth2Config.AuthCodeURL(state), nil } @@ -146,7 +155,8 @@ func (e *oauth2Error) Error() string { } // HandleCallback parses the request and returns the user's identity -func (c *openshiftConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { +func (c *openshiftConnector) HandleCallback(s connector.Scopes, + r *http.Request) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { return identity, &oauth2Error{errType, q.Get("error_description")} @@ -165,7 +175,8 @@ func (c *openshiftConnector) HandleCallback(s connector.Scopes, r *http.Request) return c.identity(ctx, s, token) } -func (c *openshiftConnector) Refresh(ctx context.Context, s connector.Scopes, oldID connector.Identity) (connector.Identity, error) { +func (c *openshiftConnector) Refresh(ctx context.Context, s connector.Scopes, + oldID connector.Identity) (connector.Identity, error) { var token oauth2.Token err := json.Unmarshal(oldID.ConnectorData, &token) if err != nil { @@ -177,7 +188,8 @@ func (c *openshiftConnector) Refresh(ctx context.Context, s connector.Scopes, ol return c.identity(ctx, s, &token) } -func (c *openshiftConnector) identity(ctx context.Context, s connector.Scopes, token *oauth2.Token) (identity connector.Identity, err error) { +func (c *openshiftConnector) identity(ctx context.Context, s connector.Scopes, + token *oauth2.Token) (identity connector.Identity, err error) { client := c.oauth2Config.Client(ctx, token) user, err := c.user(ctx, client) if err != nil { @@ -250,14 +262,13 @@ func validateAllowedGroups(userGroups, allowedGroups []string) bool { // newHTTPClient returns a new HTTP client func newHTTPClient(insecureCA bool, rootCA string) (*http.Client, error) { tlsConfig := tls.Config{} - if insecureCA { tlsConfig = tls.Config{InsecureSkipVerify: true} } else if rootCA != "" { tlsConfig = tls.Config{RootCAs: x509.NewCertPool()} rootCABytes, err := os.ReadFile(rootCA) if err != nil { - return nil, fmt.Errorf("failed to read root-ca: %v", err) + return nil, fmt.Errorf("failed to read root-ca: %w", err) } if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCABytes) { return nil, fmt.Errorf("no certs found in root CA file %q", rootCA)