forked from mystiq/dex
connector/oidc: expose oauth2.RegisterBrokenAuthHeaderProvider
This commit is contained in:
parent
d31bb1c8d5
commit
ac032e99f0
2 changed files with 78 additions and 0 deletions
|
@ -6,6 +6,9 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/Sirupsen/logrus"
|
||||
"github.com/coreos/go-oidc"
|
||||
|
@ -21,7 +24,50 @@ type Config struct {
|
|||
ClientSecret string `json:"clientSecret"`
|
||||
RedirectURI string `json:"redirectURI"`
|
||||
|
||||
// Causes client_secret to be passed as POST parameters instead of basic
|
||||
// auth. This is specifically "NOT RECOMMENDED" by the OAuth2 RFC, but some
|
||||
// providers require it.
|
||||
//
|
||||
// https://tools.ietf.org/html/rfc6749#section-2.3.1
|
||||
BasicAuthUnsupported *bool `json:"basicAuthUnsupported"`
|
||||
|
||||
Scopes []string `json:"scopes"` // defaults to "profile" and "email"
|
||||
|
||||
}
|
||||
|
||||
// Domains that don't support basic auth. golang.org/x/oauth2 has an internal
|
||||
// list, but it only matches specific URLs, not top level domains.
|
||||
var brokenAuthHeaderDomains = []string{
|
||||
// See: https://github.com/coreos/dex/issues/859
|
||||
"okta.com",
|
||||
"oktapreview.com",
|
||||
}
|
||||
|
||||
// Detect auth header provider issues for known providers. This lets users
|
||||
// avoid having to explicitly set "basicAuthUnsupported" in their config.
|
||||
//
|
||||
// Setting the config field always overrides values returned by this function.
|
||||
func knownBrokenAuthHeaderProvider(issuerURL string) bool {
|
||||
if u, err := url.Parse(issuerURL); err == nil {
|
||||
for _, host := range brokenAuthHeaderDomains {
|
||||
if u.Host == host || strings.HasSuffix(u.Host, "."+host) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// golang.org/x/oauth2 doesn't do internal locking. Need to do it in this
|
||||
// package ourselves and hope that other packages aren't calling it at the
|
||||
// same time.
|
||||
var registerMu = new(sync.Mutex)
|
||||
|
||||
func registerBrokenAuthHeaderProvider(url string) {
|
||||
registerMu.Lock()
|
||||
defer registerMu.Unlock()
|
||||
|
||||
oauth2.RegisterBrokenAuthHeaderProvider(url)
|
||||
}
|
||||
|
||||
// Open returns a connector which can be used to login users through an upstream
|
||||
|
@ -35,6 +81,15 @@ func (c *Config) Open(logger logrus.FieldLogger) (conn connector.Connector, err
|
|||
return nil, fmt.Errorf("failed to get provider: %v", err)
|
||||
}
|
||||
|
||||
if c.BasicAuthUnsupported != nil {
|
||||
// Setting "basicAuthUnsupported" always overrides our detection.
|
||||
if *c.BasicAuthUnsupported {
|
||||
registerBrokenAuthHeaderProvider(provider.Endpoint().TokenURL)
|
||||
}
|
||||
} else if knownBrokenAuthHeaderProvider(c.Issuer) {
|
||||
registerBrokenAuthHeaderProvider(provider.Endpoint().TokenURL)
|
||||
}
|
||||
|
||||
scopes := []string{oidc.ScopeOpenID}
|
||||
if len(c.Scopes) > 0 {
|
||||
scopes = append(scopes, c.Scopes...)
|
||||
|
|
23
connector/oidc/oidc_test.go
Normal file
23
connector/oidc/oidc_test.go
Normal file
|
@ -0,0 +1,23 @@
|
|||
package oidc
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestKnownBrokenAuthHeaderProvider(t *testing.T) {
|
||||
tests := []struct {
|
||||
issuerURL string
|
||||
expect bool
|
||||
}{
|
||||
{"https://dev.oktapreview.com", true},
|
||||
{"https://dev.okta.com", true},
|
||||
{"https://okta.com", true},
|
||||
{"https://dev.oktaaccounts.com", false},
|
||||
{"https://accounts.google.com", false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
got := knownBrokenAuthHeaderProvider(tc.issuerURL)
|
||||
if got != tc.expect {
|
||||
t.Errorf("knownBrokenAuthHeaderProvider(%q), want=%t, got=%t", tc.issuerURL, tc.expect, got)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue