cmd/example-app: fix custom CA behavior

This commit is contained in:
Eric Chiang 2017-03-24 11:53:28 -07:00
parent 5d49e18478
commit 9b0e9ab2ca

View file

@ -37,8 +37,7 @@ type app struct {
// or does it use "access_type=offline" (e.g. Google)? // or does it use "access_type=offline" (e.g. Google)?
offlineAsScope bool offlineAsScope bool
ctx context.Context client *http.Client
cancel context.CancelFunc
} }
// return an HTTP client which trusts the provided root CAs. // return an HTTP client which trusts the provided root CAs.
@ -118,31 +117,31 @@ func cmd() *cobra.Command {
return fmt.Errorf("parse listen address: %v", err) return fmt.Errorf("parse listen address: %v", err)
} }
a.ctx, a.cancel = context.WithCancel(context.Background())
if rootCAs != "" { if rootCAs != "" {
client, err := httpClientForRootCAs(rootCAs) client, err := httpClientForRootCAs(rootCAs)
if err != nil { if err != nil {
return err return err
} }
a.client = client
// This sets the OAuth2 client and oidc client.
a.ctx = context.WithValue(a.ctx, oauth2.HTTPClient, client)
} }
if debug { if debug {
client, ok := a.ctx.Value(oauth2.HTTPClient).(*http.Client) if a.client == nil {
if ok { a.client = &http.Client{
client.Transport = debugTransport{client.Transport}
} else {
a.ctx = context.WithValue(a.ctx, oauth2.HTTPClient, &http.Client{
Transport: debugTransport{http.DefaultTransport}, Transport: debugTransport{http.DefaultTransport},
}) }
} else {
a.client.Transport = debugTransport{a.client.Transport}
} }
} }
if a.client == nil {
a.client = http.DefaultClient
}
// TODO(ericchiang): Retry with backoff // TODO(ericchiang): Retry with backoff
provider, err := oidc.NewProvider(a.ctx, issuerURL) ctx := oidc.ClientContext(context.Background(), a.client)
provider, err := oidc.NewProvider(ctx, issuerURL)
if err != nil { if err != nil {
return fmt.Errorf("Failed to query provider %q: %v", issuerURL, err) return fmt.Errorf("Failed to query provider %q: %v", issuerURL, err)
} }
@ -258,6 +257,8 @@ func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) {
err error err error
token *oauth2.Token token *oauth2.Token
) )
ctx := oidc.ClientContext(r.Context(), a.client)
oauth2Config := a.oauth2Config(nil) oauth2Config := a.oauth2Config(nil)
switch r.Method { switch r.Method {
case "GET": case "GET":
@ -275,7 +276,7 @@ func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) {
http.Error(w, fmt.Sprintf("expected state %q got %q", exampleAppState, state), http.StatusBadRequest) http.Error(w, fmt.Sprintf("expected state %q got %q", exampleAppState, state), http.StatusBadRequest)
return return
} }
token, err = oauth2Config.Exchange(a.ctx, code) token, err = oauth2Config.Exchange(ctx, code)
case "POST": case "POST":
// Form request from frontend to refresh a token. // Form request from frontend to refresh a token.
refresh := r.FormValue("refresh_token") refresh := r.FormValue("refresh_token")
@ -287,7 +288,7 @@ func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) {
RefreshToken: refresh, RefreshToken: refresh,
Expiry: time.Now().Add(-time.Hour), Expiry: time.Now().Add(-time.Hour),
} }
token, err = oauth2Config.TokenSource(r.Context(), t).Token() token, err = oauth2Config.TokenSource(ctx, t).Token()
default: default:
http.Error(w, fmt.Sprintf("method not implemented: %s", r.Method), http.StatusBadRequest) http.Error(w, fmt.Sprintf("method not implemented: %s", r.Method), http.StatusBadRequest)
return return