forked from mystiq/dex
dex-worker: add TLS support.
Add two new flags '--cert-file' and '--key-file'. If scheme == 'https', then we will use the two new flags to get the cert/key pair for TLS connection. Also add '--ca-file' to the example app to allow TLS connection to the dex-worker using a specified ca file.
This commit is contained in:
parent
5abc7633fb
commit
3da456efa8
2 changed files with 50 additions and 7 deletions
|
@ -27,8 +27,11 @@ func init() {
|
|||
|
||||
func main() {
|
||||
fs := flag.NewFlagSet("dex-worker", flag.ExitOnError)
|
||||
listen := fs.String("listen", "http://0.0.0.0:5556", "")
|
||||
issuer := fs.String("issuer", "http://127.0.0.1:5556", "")
|
||||
listen := fs.String("listen", "http://127.0.0.1:5556", "the address that the server will listen on")
|
||||
issuer := fs.String("issuer", "http://127.0.0.1:5556", "the issuer's location")
|
||||
certFile := fs.String("tls-cert-file", "", "the server's certificate file for TLS connection")
|
||||
keyFile := fs.String("tls-key-file", "", "the server's private key file for TLS connection")
|
||||
|
||||
templates := fs.String("html-assets", "./static/html", "directory of html template files")
|
||||
|
||||
emailTemplateDirs := flagutil.StringSliceFlag{"./static/email"}
|
||||
|
@ -75,13 +78,30 @@ func main() {
|
|||
log.EnableTimestamps()
|
||||
}
|
||||
|
||||
// Validate listen address.
|
||||
lu, err := url.Parse(*listen)
|
||||
if err != nil {
|
||||
log.Fatalf("Unable to use --listen flag: %v", err)
|
||||
log.Fatalf("Invalid listen address %q: %v", *listen, err)
|
||||
}
|
||||
|
||||
if lu.Scheme != "http" {
|
||||
log.Fatalf("Unable to listen using scheme %s", lu.Scheme)
|
||||
switch lu.Scheme {
|
||||
case "http":
|
||||
case "https":
|
||||
if *certFile == "" || *keyFile == "" {
|
||||
log.Fatalf("Must provide certificate file and private key file")
|
||||
}
|
||||
default:
|
||||
log.Fatalf("Only 'http' and 'https' schemes are supported")
|
||||
}
|
||||
|
||||
// Validate issuer address.
|
||||
iu, err := url.Parse(*issuer)
|
||||
if err != nil {
|
||||
log.Fatalf("Invalid issuer URL %q: %v", *issuer, err)
|
||||
}
|
||||
|
||||
if iu.Scheme != "http" && iu.Scheme != "https" {
|
||||
log.Fatalf("Only 'http' and 'https' schemes are supported")
|
||||
}
|
||||
|
||||
scfg := server.ServerConfig{
|
||||
|
@ -145,7 +165,11 @@ func main() {
|
|||
|
||||
log.Infof("Binding to %s...", httpsrv.Addr)
|
||||
go func() {
|
||||
log.Fatal(httpsrv.ListenAndServe())
|
||||
if lu.Scheme == "http" {
|
||||
log.Fatal(httpsrv.ListenAndServe())
|
||||
} else {
|
||||
log.Fatal(httpsrv.ListenAndServeTLS(*certFile, *keyFile))
|
||||
}
|
||||
}()
|
||||
|
||||
<-srv.Run()
|
||||
|
|
|
@ -2,9 +2,12 @@ package main
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
@ -27,6 +30,8 @@ func main() {
|
|||
redirectURL := fs.String("redirect-url", "http://127.0.0.1:5555/callback", "")
|
||||
clientID := fs.String("client-id", "", "")
|
||||
clientSecret := fs.String("client-secret", "", "")
|
||||
caFile := fs.String("ca-file", "", "the TLS CA file, if empty then the host's root CA will be used")
|
||||
|
||||
discovery := fs.String("discovery", "https://accounts.google.com", "")
|
||||
logDebug := fs.Bool("log-debug", false, "log debug-level information")
|
||||
logTimestamps := fs.Bool("log-timestamps", false, "prefix log lines with timestamps")
|
||||
|
@ -71,9 +76,22 @@ func main() {
|
|||
Secret: *clientSecret,
|
||||
}
|
||||
|
||||
var tlsConfig tls.Config
|
||||
if *caFile != "" {
|
||||
roots := x509.NewCertPool()
|
||||
pemBlock, err := ioutil.ReadFile(*caFile)
|
||||
if err != nil {
|
||||
log.Fatalf("Unable to read ca file: %v", err)
|
||||
}
|
||||
roots.AppendCertsFromPEM(pemBlock)
|
||||
tlsConfig.RootCAs = roots
|
||||
}
|
||||
|
||||
httpClient := &http.Client{Transport: &http.Transport{TLSClientConfig: &tlsConfig}}
|
||||
|
||||
var cfg oidc.ProviderConfig
|
||||
for {
|
||||
cfg, err = oidc.FetchProviderConfig(http.DefaultClient, *discovery)
|
||||
cfg, err = oidc.FetchProviderConfig(httpClient, *discovery)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
|
@ -86,6 +104,7 @@ func main() {
|
|||
log.Infof("Fetched provider config from %s: %#v", *discovery, cfg)
|
||||
|
||||
ccfg := oidc.ClientConfig{
|
||||
HTTPClient: httpClient,
|
||||
ProviderConfig: cfg,
|
||||
Credentials: cc,
|
||||
RedirectURL: *redirectURL,
|
||||
|
|
Loading…
Reference in a new issue