Remove external setting, enable injection of HTTP client to config.

Signed-off-by: Daniel Haus <dhaus@redhat.com>
This commit is contained in:
Daniel Haus 2022-04-12 17:38:01 +02:00
parent 2b262ff5d6
commit 4088d4f897
No known key found for this signature in database
GPG key ID: 262B7643F39EB8A9
2 changed files with 57 additions and 57 deletions

View file

@ -35,7 +35,6 @@ type Config struct {
Groups []string `json:"groups"` Groups []string `json:"groups"`
InsecureCA bool `json:"insecureCA"` InsecureCA bool `json:"insecureCA"`
RootCA string `json:"rootCA"` RootCA string `json:"rootCA"`
IncludeSystemRootCAs bool `json:"includeSystemRootCAs"`
} }
var ( var (
@ -54,7 +53,6 @@ type openshiftConnector struct {
oauth2Config *oauth2.Config oauth2Config *oauth2.Config
insecureCA bool insecureCA bool
rootCA string rootCA string
includeSystemRootCAs bool
groups []string groups []string
} }
@ -69,6 +67,18 @@ type user struct {
// Open returns a connector which can be used to login users through an upstream // Open returns a connector which can be used to login users through an upstream
// OpenShift OAuth2 provider. // OpenShift OAuth2 provider.
func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, err error) { func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, err error) {
httpClient, err := newHTTPClient(c.InsecureCA, c.RootCA)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP client: %w", err)
}
return c.OpenWithHTTPClient(id, logger, httpClient)
}
// OpenWithHTTPClient returns a connector which can be used to login users through an upstream
// OpenShift OAuth2 provider. It provides the ability to inject a http.Client.
func (c *Config) OpenWithHTTPClient(id string, logger log.Logger,
httpClient *http.Client) (conn connector.Connector, err error) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
wellKnownURL := strings.TrimSuffix(c.Issuer, "/") + wellKnownURLPath wellKnownURL := strings.TrimSuffix(c.Issuer, "/") + wellKnownURLPath
@ -83,13 +93,8 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e
logger: logger, logger: logger,
redirectURI: c.RedirectURI, redirectURI: c.RedirectURI,
rootCA: c.RootCA, rootCA: c.RootCA,
includeSystemRootCAs: c.IncludeSystemRootCAs,
groups: c.Groups, groups: c.Groups,
} httpClient: httpClient,
if openshiftConnector.httpClient, err = newHTTPClient(c.InsecureCA, c.RootCA, c.IncludeSystemRootCAs); err != nil {
cancel()
return nil, fmt.Errorf("failed to create HTTP client: %v", err)
} }
var metadata struct { var metadata struct {
@ -100,14 +105,14 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e
resp, err := openshiftConnector.httpClient.Do(req.WithContext(ctx)) resp, err := openshiftConnector.httpClient.Do(req.WithContext(ctx))
if err != nil { if err != nil {
cancel() cancel()
return nil, fmt.Errorf("failed to query OpenShift endpoint %v", err) return nil, fmt.Errorf("failed to query OpenShift endpoint %w", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil { if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil {
cancel() cancel()
return nil, fmt.Errorf("discovery through endpoint %s failed to decode body: %v", return nil, fmt.Errorf("discovery through endpoint %s failed to decode body: %w",
wellKnownURL, err) wellKnownURL, err)
} }
@ -131,7 +136,8 @@ func (c *openshiftConnector) Close() error {
// LoginURL returns the URL to redirect the user to login with. // LoginURL returns the URL to redirect the user to login with.
func (c *openshiftConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { func (c *openshiftConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) {
if c.redirectURI != callbackURL { if c.redirectURI != callbackURL {
return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q",
callbackURL, c.redirectURI)
} }
return c.oauth2Config.AuthCodeURL(state), nil return c.oauth2Config.AuthCodeURL(state), nil
} }
@ -149,7 +155,8 @@ func (e *oauth2Error) Error() string {
} }
// HandleCallback parses the request and returns the user's identity // HandleCallback parses the request and returns the user's identity
func (c *openshiftConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { func (c *openshiftConnector) HandleCallback(s connector.Scopes,
r *http.Request) (identity connector.Identity, err error) {
q := r.URL.Query() q := r.URL.Query()
if errType := q.Get("error"); errType != "" { if errType := q.Get("error"); errType != "" {
return identity, &oauth2Error{errType, q.Get("error_description")} return identity, &oauth2Error{errType, q.Get("error_description")}
@ -168,7 +175,8 @@ func (c *openshiftConnector) HandleCallback(s connector.Scopes, r *http.Request)
return c.identity(ctx, s, token) return c.identity(ctx, s, token)
} }
func (c *openshiftConnector) Refresh(ctx context.Context, s connector.Scopes, oldID connector.Identity) (connector.Identity, error) { func (c *openshiftConnector) Refresh(ctx context.Context, s connector.Scopes,
oldID connector.Identity) (connector.Identity, error) {
var token oauth2.Token var token oauth2.Token
err := json.Unmarshal(oldID.ConnectorData, &token) err := json.Unmarshal(oldID.ConnectorData, &token)
if err != nil { if err != nil {
@ -180,7 +188,8 @@ func (c *openshiftConnector) Refresh(ctx context.Context, s connector.Scopes, ol
return c.identity(ctx, s, &token) return c.identity(ctx, s, &token)
} }
func (c *openshiftConnector) identity(ctx context.Context, s connector.Scopes, token *oauth2.Token) (identity connector.Identity, err error) { func (c *openshiftConnector) identity(ctx context.Context, s connector.Scopes,
token *oauth2.Token) (identity connector.Identity, err error) {
client := c.oauth2Config.Client(ctx, token) client := c.oauth2Config.Client(ctx, token)
user, err := c.user(ctx, client) user, err := c.user(ctx, client)
if err != nil { if err != nil {
@ -251,21 +260,12 @@ func validateAllowedGroups(userGroups, allowedGroups []string) bool {
} }
// newHTTPClient returns a new HTTP client // newHTTPClient returns a new HTTP client
func newHTTPClient(insecureCA bool, rootCA string, includeSystemRootCAs bool) (*http.Client, error) { func newHTTPClient(insecureCA bool, rootCA string) (*http.Client, error) {
tlsConfig := tls.Config{} tlsConfig := tls.Config{}
if insecureCA { if insecureCA {
tlsConfig = tls.Config{InsecureSkipVerify: true} tlsConfig = tls.Config{InsecureSkipVerify: true}
} else if rootCA != "" { } else if rootCA != "" {
if !includeSystemRootCAs {
tlsConfig = tls.Config{RootCAs: x509.NewCertPool()} tlsConfig = tls.Config{RootCAs: x509.NewCertPool()}
} else {
systemCAs, err := x509.SystemCertPool()
if err != nil {
return nil, fmt.Errorf("failed to read host CA: %w", err)
}
tlsConfig = tls.Config{RootCAs: systemCAs}
}
rootCABytes, err := os.ReadFile(rootCA) rootCABytes, err := os.ReadFile(rootCA)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to read root-ca: %w", err) return nil, fmt.Errorf("failed to read root-ca: %w", err)

View file

@ -70,7 +70,7 @@ func TestGetUser(t *testing.T) {
_, err = http.NewRequest("GET", hostURL.String(), nil) _, err = http.NewRequest("GET", hostURL.String(), nil)
expectNil(t, err) expectNil(t, err)
h, err := newHTTPClient(true, "", false) h, err := newHTTPClient(true, "")
expectNil(t, err) expectNil(t, err)
@ -128,7 +128,7 @@ func TestVerifyGroup(t *testing.T) {
_, err = http.NewRequest("GET", hostURL.String(), nil) _, err = http.NewRequest("GET", hostURL.String(), nil)
expectNil(t, err) expectNil(t, err)
h, err := newHTTPClient(true, "", false) h, err := newHTTPClient(true, "")
expectNil(t, err) expectNil(t, err)
@ -164,7 +164,7 @@ func TestCallbackIdentity(t *testing.T) {
req, err := http.NewRequest("GET", hostURL.String(), nil) req, err := http.NewRequest("GET", hostURL.String(), nil)
expectNil(t, err) expectNil(t, err)
h, err := newHTTPClient(true, "", false) h, err := newHTTPClient(true, "")
expectNil(t, err) expectNil(t, err)
@ -198,7 +198,7 @@ func TestRefreshIdentity(t *testing.T) {
}) })
defer s.Close() defer s.Close()
h, err := newHTTPClient(true, "", false) h, err := newHTTPClient(true, "")
expectNil(t, err) expectNil(t, err)
oc := openshiftConnector{apiURL: s.URL, httpClient: h, oauth2Config: &oauth2.Config{ oc := openshiftConnector{apiURL: s.URL, httpClient: h, oauth2Config: &oauth2.Config{
@ -237,7 +237,7 @@ func TestRefreshIdentityFailure(t *testing.T) {
}) })
defer s.Close() defer s.Close()
h, err := newHTTPClient(true, "", false) h, err := newHTTPClient(true, "")
expectNil(t, err) expectNil(t, err)
oc := openshiftConnector{apiURL: s.URL, httpClient: h, oauth2Config: &oauth2.Config{ oc := openshiftConnector{apiURL: s.URL, httpClient: h, oauth2Config: &oauth2.Config{