package main import ( "bytes" "crypto/tls" "crypto/x509" "encoding/json" "errors" "fmt" "io/ioutil" "log" "net" "net/http" "net/url" "os" "strings" "time" "github.com/ericchiang/oidc" "github.com/spf13/cobra" "golang.org/x/net/context" "golang.org/x/oauth2" ) type app struct { clientID string clientSecret string redirectURI string verifier *oidc.IDTokenVerifier provider *oidc.Provider ctx context.Context cancel context.CancelFunc } // return an HTTP client which trusts the provided root CAs. func httpClientForRootCAs(rootCAs string) (*http.Client, error) { tlsConfig := tls.Config{RootCAs: x509.NewCertPool()} rootCABytes, err := ioutil.ReadFile(rootCAs) if err != nil { return nil, fmt.Errorf("failed to read root-ca: %v", err) } if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCABytes) { return nil, fmt.Errorf("no certs found in root CA file %q", rootCAs) } return &http.Client{ Transport: &http.Transport{ TLSClientConfig: &tlsConfig, Proxy: http.ProxyFromEnvironment, Dial: (&net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, }).Dial, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, }, }, nil } func cmd() *cobra.Command { var ( a app issuerURL string listen string tlsCert string tlsKey string rootCAs string ) c := cobra.Command{ Use: "example-app", Short: "An example OpenID Connect client", Long: "", RunE: func(cmd *cobra.Command, args []string) error { if len(args) != 0 { return errors.New("surplus arguments provided") } u, err := url.Parse(a.redirectURI) if err != nil { return fmt.Errorf("parse redirect-uri: %v", err) } listenURL, err := url.Parse(listen) if err != nil { return fmt.Errorf("parse listen address: %v", err) } a.ctx, a.cancel = context.WithCancel(context.Background()) if rootCAs != "" { client, err := httpClientForRootCAs(rootCAs) if err != nil { return err } // This sets the OAuth2 client and oidc client. a.ctx = context.WithValue(a.ctx, oauth2.HTTPClient, &client) } // TODO(ericchiang): Retry with backoff provider, err := oidc.NewProvider(a.ctx, issuerURL) if err != nil { return fmt.Errorf("Failed to query provider %q: %v", issuerURL, err) } a.provider = provider a.verifier = provider.NewVerifier(a.ctx, oidc.VerifyAudience(a.clientID)) http.HandleFunc("/", a.handleIndex) http.HandleFunc("/login", a.handleLogin) http.HandleFunc(u.Path, a.handleCallback) switch listenURL.Scheme { case "http": log.Printf("listening on %s", listen) return http.ListenAndServe(listenURL.Host, nil) case "https": log.Printf("listening on %s", listen) return http.ListenAndServeTLS(listenURL.Host, tlsCert, tlsKey, nil) default: return fmt.Errorf("listen address %q is not using http or https", listen) } }, } c.Flags().StringVar(&a.clientID, "client-id", "example-app", "OAuth2 client ID of this application.") c.Flags().StringVar(&a.clientSecret, "client-secret", "ZXhhbXBsZS1hcHAtc2VjcmV0", "OAuth2 client secret of this application.") c.Flags().StringVar(&a.redirectURI, "redirect-uri", "http://127.0.0.1:5555/callback", "Callback URL for OAuth2 responses.") c.Flags().StringVar(&issuerURL, "issuer", "http://127.0.0.1:5556", "URL of the OpenID Connect issuer.") c.Flags().StringVar(&listen, "listen", "http://127.0.0.1:5555", "HTTP(S) address to listen at.") c.Flags().StringVar(&tlsCert, "tls-cert", "", "X509 cert file to present when serving HTTPS.") c.Flags().StringVar(&tlsKey, "tls-key", "", "Private key for the HTTPS cert.") c.Flags().StringVar(&rootCAs, "issuer-root-ca", "", "Root certificate authorities for the issuer. Defaults to host certs.") return &c } func main() { if err := cmd().Execute(); err != nil { fmt.Fprintf(os.Stderr, "error: %v\n", err) os.Exit(2) } } func (a *app) handleIndex(w http.ResponseWriter, r *http.Request) { renderIndex(w) } func (a *app) oauth2Config(scopes []string) *oauth2.Config { return &oauth2.Config{ ClientID: a.clientID, ClientSecret: a.clientSecret, Endpoint: a.provider.Endpoint(), Scopes: scopes, RedirectURL: a.redirectURI, } } func (a *app) handleLogin(w http.ResponseWriter, r *http.Request) { var scopes []string if extraScopes := r.FormValue("extra_scopes"); extraScopes != "" { scopes = strings.Split(extraScopes, " ") } var clients []string if crossClients := r.FormValue("cross_client"); crossClients != "" { clients = strings.Split(crossClients, " ") } for _, client := range clients { scopes = append(scopes, "audience:server:client_id:"+client) } // TODO(ericchiang): Determine if provider does not support "offline_access" or has // some other mechanism for requesting refresh tokens. scopes = append(scopes, "openid", "profile", "email", "offline_access") http.Redirect(w, r, a.oauth2Config(scopes).AuthCodeURL(""), http.StatusSeeOther) } func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) { if errMsg := r.FormValue("error"); errMsg != "" { http.Error(w, errMsg+": "+r.FormValue("error_description"), http.StatusBadRequest) return } code := r.FormValue("code") refresh := r.FormValue("refresh_token") var ( err error token *oauth2.Token ) oauth2Config := a.oauth2Config(nil) switch { case code != "": token, err = oauth2Config.Exchange(a.ctx, code) case refresh != "": t := &oauth2.Token{ RefreshToken: refresh, Expiry: time.Now().Add(-time.Hour), } token, err = oauth2Config.TokenSource(a.ctx, t).Token() default: http.Error(w, "no code in request", http.StatusBadRequest) return } if err != nil { http.Error(w, fmt.Sprintf("failed to get token: %v", err), http.StatusInternalServerError) return } rawIDToken, ok := token.Extra("id_token").(string) if !ok { http.Error(w, "no id_token in token response", http.StatusInternalServerError) return } idToken, err := a.verifier.Verify(rawIDToken) if err != nil { http.Error(w, fmt.Sprintf("Failed to verify ID token: %v", err), http.StatusInternalServerError) return } var claims json.RawMessage idToken.Claims(&claims) buff := new(bytes.Buffer) json.Indent(buff, []byte(claims), "", " ") renderToken(w, a.redirectURI, rawIDToken, token.RefreshToken, buff.Bytes()) }