diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go index 728bdf6a..0310266c 100644 --- a/connector/oidc/oidc.go +++ b/connector/oidc/oidc.go @@ -6,6 +6,9 @@ import ( "errors" "fmt" "net/http" + "net/url" + "strings" + "sync" "github.com/Sirupsen/logrus" "github.com/coreos/go-oidc" @@ -21,7 +24,50 @@ type Config struct { ClientSecret string `json:"clientSecret"` RedirectURI string `json:"redirectURI"` + // Causes client_secret to be passed as POST parameters instead of basic + // auth. This is specifically "NOT RECOMMENDED" by the OAuth2 RFC, but some + // providers require it. + // + // https://tools.ietf.org/html/rfc6749#section-2.3.1 + BasicAuthUnsupported *bool `json:"basicAuthUnsupported"` + Scopes []string `json:"scopes"` // defaults to "profile" and "email" + +} + +// Domains that don't support basic auth. golang.org/x/oauth2 has an internal +// list, but it only matches specific URLs, not top level domains. +var brokenAuthHeaderDomains = []string{ + // See: https://github.com/coreos/dex/issues/859 + "okta.com", + "oktapreview.com", +} + +// Detect auth header provider issues for known providers. This lets users +// avoid having to explicitly set "basicAuthUnsupported" in their config. +// +// Setting the config field always overrides values returned by this function. +func knownBrokenAuthHeaderProvider(issuerURL string) bool { + if u, err := url.Parse(issuerURL); err == nil { + for _, host := range brokenAuthHeaderDomains { + if u.Host == host || strings.HasSuffix(u.Host, "."+host) { + return true + } + } + } + return false +} + +// golang.org/x/oauth2 doesn't do internal locking. Need to do it in this +// package ourselves and hope that other packages aren't calling it at the +// same time. +var registerMu = new(sync.Mutex) + +func registerBrokenAuthHeaderProvider(url string) { + registerMu.Lock() + defer registerMu.Unlock() + + oauth2.RegisterBrokenAuthHeaderProvider(url) } // Open returns a connector which can be used to login users through an upstream @@ -35,6 +81,15 @@ func (c *Config) Open(logger logrus.FieldLogger) (conn connector.Connector, err return nil, fmt.Errorf("failed to get provider: %v", err) } + if c.BasicAuthUnsupported != nil { + // Setting "basicAuthUnsupported" always overrides our detection. + if *c.BasicAuthUnsupported { + registerBrokenAuthHeaderProvider(provider.Endpoint().TokenURL) + } + } else if knownBrokenAuthHeaderProvider(c.Issuer) { + registerBrokenAuthHeaderProvider(provider.Endpoint().TokenURL) + } + scopes := []string{oidc.ScopeOpenID} if len(c.Scopes) > 0 { scopes = append(scopes, c.Scopes...) diff --git a/connector/oidc/oidc_test.go b/connector/oidc/oidc_test.go new file mode 100644 index 00000000..b3f609d1 --- /dev/null +++ b/connector/oidc/oidc_test.go @@ -0,0 +1,23 @@ +package oidc + +import "testing" + +func TestKnownBrokenAuthHeaderProvider(t *testing.T) { + tests := []struct { + issuerURL string + expect bool + }{ + {"https://dev.oktapreview.com", true}, + {"https://dev.okta.com", true}, + {"https://okta.com", true}, + {"https://dev.oktaaccounts.com", false}, + {"https://accounts.google.com", false}, + } + + for _, tc := range tests { + got := knownBrokenAuthHeaderProvider(tc.issuerURL) + if got != tc.expect { + t.Errorf("knownBrokenAuthHeaderProvider(%q), want=%t, got=%t", tc.issuerURL, tc.expect, got) + } + } +}