connector/oidc: replace deprecated oauth2.RegisterBrokenAuthHeaderProvider with oauth2.Endpoint.AuthStyle

This commit is contained in:
Lars Lehtonen 2019-11-15 16:31:22 -08:00
parent e0f927c7a9
commit 8e0ae82034
No known key found for this signature in database
GPG key ID: 8137D474EBCB04F2
2 changed files with 7 additions and 16 deletions

View file

@ -9,7 +9,6 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"sync"
"time" "time"
"github.com/coreos/go-oidc" "github.com/coreos/go-oidc"
@ -85,18 +84,6 @@ func knownBrokenAuthHeaderProvider(issuerURL string) bool {
return false return false
} }
// golang.org/x/oauth2 doesn't do internal locking. Need to do it in this
// package ourselves and hope that other packages aren't calling it at the
// same time.
var registerMu = new(sync.Mutex)
func registerBrokenAuthHeaderProvider(url string) {
registerMu.Lock()
defer registerMu.Unlock()
oauth2.RegisterBrokenAuthHeaderProvider(url)
}
// Open returns a connector which can be used to login users through an upstream // Open returns a connector which can be used to login users through an upstream
// OpenID Connect provider. // OpenID Connect provider.
func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, err error) { func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, err error) {
@ -108,13 +95,15 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e
return nil, fmt.Errorf("failed to get provider: %v", err) return nil, fmt.Errorf("failed to get provider: %v", err)
} }
endpoint := provider.Endpoint()
if c.BasicAuthUnsupported != nil { if c.BasicAuthUnsupported != nil {
// Setting "basicAuthUnsupported" always overrides our detection. // Setting "basicAuthUnsupported" always overrides our detection.
if *c.BasicAuthUnsupported { if *c.BasicAuthUnsupported {
registerBrokenAuthHeaderProvider(provider.Endpoint().TokenURL) endpoint.AuthStyle = oauth2.AuthStyleInParams
} }
} else if knownBrokenAuthHeaderProvider(c.Issuer) { } else if knownBrokenAuthHeaderProvider(c.Issuer) {
registerBrokenAuthHeaderProvider(provider.Endpoint().TokenURL) endpoint.AuthStyle = oauth2.AuthStyleInParams
} }
scopes := []string{oidc.ScopeOpenID} scopes := []string{oidc.ScopeOpenID}
@ -131,7 +120,7 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e
oauth2Config: &oauth2.Config{ oauth2Config: &oauth2.Config{
ClientID: clientID, ClientID: clientID,
ClientSecret: c.ClientSecret, ClientSecret: c.ClientSecret,
Endpoint: provider.Endpoint(), Endpoint: endpoint,
Scopes: scopes, Scopes: scopes,
RedirectURL: c.RedirectURI, RedirectURL: c.RedirectURI,
}, },

View file

@ -111,6 +111,7 @@ func TestHandleCallback(t *testing.T) {
} }
defer testServer.Close() defer testServer.Close()
serverURL := testServer.URL serverURL := testServer.URL
basicAuth := true
config := Config{ config := Config{
Issuer: serverURL, Issuer: serverURL,
ClientID: "clientID", ClientID: "clientID",
@ -120,6 +121,7 @@ func TestHandleCallback(t *testing.T) {
UserIDKey: tc.userIDKey, UserIDKey: tc.userIDKey,
UserNameKey: tc.userNameKey, UserNameKey: tc.userNameKey,
InsecureSkipEmailVerified: tc.insecureSkipEmailVerified, InsecureSkipEmailVerified: tc.insecureSkipEmailVerified,
BasicAuthUnsupported: &basicAuth,
} }
conn, err := newConnector(config) conn, err := newConnector(config)