diff --git a/cmd/dex-worker/main.go b/cmd/dex-worker/main.go index e9a50ea2..8a7166bf 100644 --- a/cmd/dex-worker/main.go +++ b/cmd/dex-worker/main.go @@ -27,8 +27,11 @@ func init() { func main() { fs := flag.NewFlagSet("dex-worker", flag.ExitOnError) - listen := fs.String("listen", "http://0.0.0.0:5556", "") - issuer := fs.String("issuer", "http://127.0.0.1:5556", "") + listen := fs.String("listen", "http://127.0.0.1:5556", "the address that the server will listen on") + issuer := fs.String("issuer", "http://127.0.0.1:5556", "the issuer's location") + certFile := fs.String("tls-cert-file", "", "the server's certificate file for TLS connection") + keyFile := fs.String("tls-key-file", "", "the server's private key file for TLS connection") + templates := fs.String("html-assets", "./static/html", "directory of html template files") emailTemplateDirs := flagutil.StringSliceFlag{"./static/email"} @@ -75,13 +78,30 @@ func main() { log.EnableTimestamps() } + // Validate listen address. lu, err := url.Parse(*listen) if err != nil { - log.Fatalf("Unable to use --listen flag: %v", err) + log.Fatalf("Invalid listen address %q: %v", *listen, err) } - if lu.Scheme != "http" { - log.Fatalf("Unable to listen using scheme %s", lu.Scheme) + switch lu.Scheme { + case "http": + case "https": + if *certFile == "" || *keyFile == "" { + log.Fatalf("Must provide certificate file and private key file") + } + default: + log.Fatalf("Only 'http' and 'https' schemes are supported") + } + + // Validate issuer address. + iu, err := url.Parse(*issuer) + if err != nil { + log.Fatalf("Invalid issuer URL %q: %v", *issuer, err) + } + + if iu.Scheme != "http" && iu.Scheme != "https" { + log.Fatalf("Only 'http' and 'https' schemes are supported") } scfg := server.ServerConfig{ @@ -145,7 +165,11 @@ func main() { log.Infof("Binding to %s...", httpsrv.Addr) go func() { - log.Fatal(httpsrv.ListenAndServe()) + if lu.Scheme == "http" { + log.Fatal(httpsrv.ListenAndServe()) + } else { + log.Fatal(httpsrv.ListenAndServeTLS(*certFile, *keyFile)) + } }() <-srv.Run() diff --git a/examples/app/main.go b/examples/app/main.go index 68bb12e5..a8d1ed58 100644 --- a/examples/app/main.go +++ b/examples/app/main.go @@ -2,9 +2,12 @@ package main import ( "bytes" + "crypto/tls" + "crypto/x509" "encoding/json" "flag" "fmt" + "io/ioutil" "net" "net/http" "net/url" @@ -27,6 +30,8 @@ func main() { redirectURL := fs.String("redirect-url", "http://127.0.0.1:5555/callback", "") clientID := fs.String("client-id", "", "") clientSecret := fs.String("client-secret", "", "") + caFile := fs.String("ca-file", "", "the TLS CA file, if empty then the host's root CA will be used") + discovery := fs.String("discovery", "https://accounts.google.com", "") logDebug := fs.Bool("log-debug", false, "log debug-level information") logTimestamps := fs.Bool("log-timestamps", false, "prefix log lines with timestamps") @@ -71,9 +76,22 @@ func main() { Secret: *clientSecret, } + var tlsConfig tls.Config + if *caFile != "" { + roots := x509.NewCertPool() + pemBlock, err := ioutil.ReadFile(*caFile) + if err != nil { + log.Fatalf("Unable to read ca file: %v", err) + } + roots.AppendCertsFromPEM(pemBlock) + tlsConfig.RootCAs = roots + } + + httpClient := &http.Client{Transport: &http.Transport{TLSClientConfig: &tlsConfig}} + var cfg oidc.ProviderConfig for { - cfg, err = oidc.FetchProviderConfig(http.DefaultClient, *discovery) + cfg, err = oidc.FetchProviderConfig(httpClient, *discovery) if err == nil { break } @@ -86,6 +104,7 @@ func main() { log.Infof("Fetched provider config from %s: %#v", *discovery, cfg) ccfg := oidc.ClientConfig{ + HTTPClient: httpClient, ProviderConfig: cfg, Credentials: cc, RedirectURL: *redirectURL,