forked from mystiq/dex
Merge pull request #875 from ericchiang/fix-example-app-custom-ca
cmd/example-app: fix custom CA behavior
This commit is contained in:
commit
6e50c18458
1 changed files with 17 additions and 16 deletions
|
@ -37,8 +37,7 @@ type app struct {
|
|||
// or does it use "access_type=offline" (e.g. Google)?
|
||||
offlineAsScope bool
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// return an HTTP client which trusts the provided root CAs.
|
||||
|
@ -118,31 +117,31 @@ func cmd() *cobra.Command {
|
|||
return fmt.Errorf("parse listen address: %v", err)
|
||||
}
|
||||
|
||||
a.ctx, a.cancel = context.WithCancel(context.Background())
|
||||
|
||||
if rootCAs != "" {
|
||||
client, err := httpClientForRootCAs(rootCAs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// This sets the OAuth2 client and oidc client.
|
||||
a.ctx = context.WithValue(a.ctx, oauth2.HTTPClient, client)
|
||||
a.client = client
|
||||
}
|
||||
|
||||
if debug {
|
||||
client, ok := a.ctx.Value(oauth2.HTTPClient).(*http.Client)
|
||||
if ok {
|
||||
client.Transport = debugTransport{client.Transport}
|
||||
} else {
|
||||
a.ctx = context.WithValue(a.ctx, oauth2.HTTPClient, &http.Client{
|
||||
if a.client == nil {
|
||||
a.client = &http.Client{
|
||||
Transport: debugTransport{http.DefaultTransport},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
a.client.Transport = debugTransport{a.client.Transport}
|
||||
}
|
||||
}
|
||||
|
||||
if a.client == nil {
|
||||
a.client = http.DefaultClient
|
||||
}
|
||||
|
||||
// TODO(ericchiang): Retry with backoff
|
||||
provider, err := oidc.NewProvider(a.ctx, issuerURL)
|
||||
ctx := oidc.ClientContext(context.Background(), a.client)
|
||||
provider, err := oidc.NewProvider(ctx, issuerURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to query provider %q: %v", issuerURL, err)
|
||||
}
|
||||
|
@ -258,6 +257,8 @@ func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) {
|
|||
err error
|
||||
token *oauth2.Token
|
||||
)
|
||||
|
||||
ctx := oidc.ClientContext(r.Context(), a.client)
|
||||
oauth2Config := a.oauth2Config(nil)
|
||||
switch r.Method {
|
||||
case "GET":
|
||||
|
@ -275,7 +276,7 @@ func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) {
|
|||
http.Error(w, fmt.Sprintf("expected state %q got %q", exampleAppState, state), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
token, err = oauth2Config.Exchange(a.ctx, code)
|
||||
token, err = oauth2Config.Exchange(ctx, code)
|
||||
case "POST":
|
||||
// Form request from frontend to refresh a token.
|
||||
refresh := r.FormValue("refresh_token")
|
||||
|
@ -287,7 +288,7 @@ func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) {
|
|||
RefreshToken: refresh,
|
||||
Expiry: time.Now().Add(-time.Hour),
|
||||
}
|
||||
token, err = oauth2Config.TokenSource(r.Context(), t).Token()
|
||||
token, err = oauth2Config.TokenSource(ctx, t).Token()
|
||||
default:
|
||||
http.Error(w, fmt.Sprintf("method not implemented: %s", r.Method), http.StatusBadRequest)
|
||||
return
|
||||
|
|
Loading…
Reference in a new issue