This repository has been archived on 2022-08-17. You can view files and clone it, but cannot push or open issues or pull requests.
dex/cmd/dex/serve.go
Steven Danna 59f8b02d47
Set minimum TLS protocol version to TLSv1.2, set PreferServerCipherSuites
Some environments are subject to strict rules about the permitted TLS
protocol version and available ciphers. Setting TLSv1.2 as the minimum
version ensures we do not use weaker protocols. We've opted against
making this configurable given the age of TLSv1.2 and the increasing
push to deprecate TLSv1.1 and older.

The PreferServerCipherSuites setting is also commonly flagged by SSL
quality scanning tools. Since Go provides a relatively modern set of
default ciphers by default, defaulting this to true is unlikely to
make much practical difference.

Signed-off-by: Steven Danna <steve@chef.io>
2019-01-29 11:18:55 +00:00

356 lines
11 KiB
Go

package main
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"net"
"net/http"
"os"
"strings"
"time"
"github.com/ghodss/yaml"
grpcprometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"github.com/dexidp/dex/api"
"github.com/dexidp/dex/server"
"github.com/dexidp/dex/storage"
)
// commandServe builds the "serve" subcommand. The command hands its arguments
// to serve; if serve reports an error, the error is printed to stderr and the
// process exits with status 2.
func commandServe() *cobra.Command {
	// run adapts serve's error return to cobra's error-less Run signature.
	run := func(cmd *cobra.Command, args []string) {
		err := serve(cmd, args)
		if err == nil {
			return
		}
		fmt.Fprintln(os.Stderr, err)
		os.Exit(2)
	}

	serveCmd := &cobra.Command{
		Use:     "serve [ config file ]",
		Short:   "Connect to the storage and begin serving requests.",
		Long:    ``,
		Example: "dex serve config.yaml",
		Run:     run,
	}
	return serveCmd
}
// serve reads the YAML config file named by the single positional argument,
// validates it, and starts every listener the config asks for.
//
// The steps are, in order:
//  1. read and unmarshal the config file and build the logger;
//  2. run cheap sanity checks on the config;
//  3. register Go runtime, process and gRPC metrics in a fresh Prometheus registry;
//  4. build the gRPC server options (TLS and, optionally, client-cert auth);
//  5. open the storage and layer static clients, passwords and connectors on it;
//  6. construct the server via server.NewServer;
//  7. start the telemetry, HTTP, HTTPS and gRPC listeners that are configured,
//     each in its own goroutine.
//
// It returns an error as soon as any setup step fails. Otherwise it blocks
// until the first listener goroutine reports an error and returns that error.
// cmd is accepted for the caller's convenience and is not used.
func serve(cmd *cobra.Command, args []string) error {
	switch len(args) {
	default:
		return errors.New("surplus arguments")
	case 0:
		// TODO(ericchiang): Consider having a default config file location.
		return errors.New("no arguments provided")
	case 1:
	}

	configFile := args[0]
	configData, err := ioutil.ReadFile(configFile)
	if err != nil {
		return fmt.Errorf("failed to read config file %s: %v", configFile, err)
	}

	var c Config
	if err := yaml.Unmarshal(configData, &c); err != nil {
		return fmt.Errorf("error parse config file %s: %v", configFile, err)
	}

	logger, err := newLogger(c.Logger.Level, c.Logger.Format)
	if err != nil {
		return fmt.Errorf("invalid config: %v", err)
	}
	if c.Logger.Level != "" {
		logger.Infof("config using log level: %s", c.Logger.Level)
	}

	// Fast checks. Perform these first for a more responsive CLI.
	checks := []struct {
		bad    bool
		errMsg string
	}{
		{c.Issuer == "", "no issuer specified in config file"},
		{!c.EnablePasswordDB && len(c.StaticPasswords) != 0, "cannot specify static passwords without enabling password db"},
		{c.Storage.Config == nil, "no storage supplied in config file"},
		{c.Web.HTTP == "" && c.Web.HTTPS == "", "must supply a HTTP/HTTPS address to listen on"},
		{c.Web.HTTPS != "" && c.Web.TLSCert == "", "no cert specified for HTTPS"},
		{c.Web.HTTPS != "" && c.Web.TLSKey == "", "no private key specified for HTTPS"},
		{c.GRPC.TLSCert != "" && c.GRPC.Addr == "", "no address specified for gRPC"},
		{c.GRPC.TLSKey != "" && c.GRPC.Addr == "", "no address specified for gRPC"},
		{(c.GRPC.TLSCert == "") != (c.GRPC.TLSKey == ""), "must specific both a gRPC TLS cert and key"},
		{c.GRPC.TLSCert == "" && c.GRPC.TLSClientCA != "", "cannot specify gRPC TLS client CA without a gRPC TLS cert"},
	}

	for _, check := range checks {
		if check.bad {
			return fmt.Errorf("invalid config: %s", check.errMsg)
		}
	}

	logger.Infof("config issuer: %s", c.Issuer)

	prometheusRegistry := prometheus.NewRegistry()
	err = prometheusRegistry.Register(prometheus.NewGoCollector())
	if err != nil {
		return fmt.Errorf("failed to register Go runtime metrics: %v", err)
	}

	err = prometheusRegistry.Register(prometheus.NewProcessCollector(os.Getpid(), ""))
	if err != nil {
		return fmt.Errorf("failed to register process metrics: %v", err)
	}

	grpcMetrics := grpcprometheus.NewServerMetrics()
	err = prometheusRegistry.Register(grpcMetrics)
	if err != nil {
		return fmt.Errorf("failed to register gRPC server metrics: %v", err)
	}

	var grpcOptions []grpc.ServerOption
	if c.GRPC.TLSCert != "" {
		// Parse certificates from certificate file and key file for server.
		cert, err := tls.LoadX509KeyPair(c.GRPC.TLSCert, c.GRPC.TLSKey)
		if err != nil {
			return fmt.Errorf("invalid config: error parsing gRPC certificate file: %v", err)
		}

		tlsConfig := tls.Config{
			Certificates:             []tls.Certificate{cert},
			MinVersion:               tls.VersionTLS12,
			PreferServerCipherSuites: true,
		}

		if c.GRPC.TLSClientCA != "" {
			// Parse certificates from client CA file to a new CertPool.
			cPool := x509.NewCertPool()
			clientCert, err := ioutil.ReadFile(c.GRPC.TLSClientCA)
			if err != nil {
				return fmt.Errorf("invalid config: reading from client CA file: %v", err)
			}
			if !cPool.AppendCertsFromPEM(clientCert) {
				return errors.New("invalid config: failed to parse client CA")
			}

			tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
			tlsConfig.ClientCAs = cPool

			// Only add metrics if client auth is enabled
			grpcOptions = append(grpcOptions,
				grpc.StreamInterceptor(grpcMetrics.StreamServerInterceptor()),
				grpc.UnaryInterceptor(grpcMetrics.UnaryServerInterceptor()),
			)
		}

		grpcOptions = append(grpcOptions, grpc.Creds(credentials.NewTLS(&tlsConfig)))
	}

	s, err := c.Storage.Config.Open(logger)
	if err != nil {
		return fmt.Errorf("failed to initialize storage: %v", err)
	}
	logger.Infof("config storage: %s", c.Storage.Type)

	if len(c.StaticClients) > 0 {
		for _, client := range c.StaticClients {
			logger.Infof("config static client: %s", client.ID)
		}
		s = storage.WithStaticClients(s, c.StaticClients)
	}
	if len(c.StaticPasswords) > 0 {
		passwords := make([]storage.Password, len(c.StaticPasswords))
		for i, p := range c.StaticPasswords {
			passwords[i] = storage.Password(p)
		}
		s = storage.WithStaticPasswords(s, passwords, logger)
	}

	storageConnectors := make([]storage.Connector, len(c.StaticConnectors))
	// The loop variable is deliberately not named c so that it does not
	// shadow the config value used throughout the rest of this function.
	for i, sc := range c.StaticConnectors {
		if sc.ID == "" || sc.Name == "" || sc.Type == "" {
			return errors.New("invalid config: ID, Type and Name fields are required for a connector")
		}
		if sc.Config == nil {
			return fmt.Errorf("invalid config: no config field for connector %q", sc.ID)
		}
		logger.Infof("config connector: %s", sc.ID)

		// convert to a storage connector object
		conn, err := ToStorageConnector(sc)
		if err != nil {
			return fmt.Errorf("failed to initialize storage connectors: %v", err)
		}
		storageConnectors[i] = conn
	}

	if c.EnablePasswordDB {
		storageConnectors = append(storageConnectors, storage.Connector{
			ID:   server.LocalConnector,
			Name: "Email",
			Type: server.LocalConnector,
		})
		logger.Infof("config connector: local passwords enabled")
	}

	s = storage.WithStaticConnectors(s, storageConnectors)

	if len(c.OAuth2.ResponseTypes) > 0 {
		logger.Infof("config response types accepted: %s", c.OAuth2.ResponseTypes)
	}
	if c.OAuth2.SkipApprovalScreen {
		logger.Infof("config skipping approval screen")
	}
	if len(c.Web.AllowedOrigins) > 0 {
		logger.Infof("config allowed origins: %s", c.Web.AllowedOrigins)
	}

	// explicitly convert to UTC.
	now := func() time.Time { return time.Now().UTC() }

	serverConfig := server.Config{
		SupportedResponseTypes: c.OAuth2.ResponseTypes,
		SkipApprovalScreen:     c.OAuth2.SkipApprovalScreen,
		AllowedOrigins:         c.Web.AllowedOrigins,
		Issuer:                 c.Issuer,
		Storage:                s,
		Web:                    c.Frontend,
		Logger:                 logger,
		Now:                    now,
		PrometheusRegistry:     prometheusRegistry,
	}

	// Expiry settings are optional; a field is only overridden when the
	// config supplies a non-empty duration string for it.
	if c.Expiry.SigningKeys != "" {
		signingKeys, err := time.ParseDuration(c.Expiry.SigningKeys)
		if err != nil {
			return fmt.Errorf("invalid config value %q for signing keys expiry: %v", c.Expiry.SigningKeys, err)
		}
		logger.Infof("config signing keys expire after: %v", signingKeys)
		serverConfig.RotateKeysAfter = signingKeys
	}
	if c.Expiry.IDTokens != "" {
		idTokens, err := time.ParseDuration(c.Expiry.IDTokens)
		if err != nil {
			return fmt.Errorf("invalid config value %q for id token expiry: %v", c.Expiry.IDTokens, err)
		}
		logger.Infof("config id tokens valid for: %v", idTokens)
		serverConfig.IDTokensValidFor = idTokens
	}
	if c.Expiry.AuthRequests != "" {
		authRequests, err := time.ParseDuration(c.Expiry.AuthRequests)
		if err != nil {
			return fmt.Errorf("invalid config value %q for auth request expiry: %v", c.Expiry.AuthRequests, err)
		}
		logger.Infof("config auth requests valid for: %v", authRequests)
		serverConfig.AuthRequestsValidFor = authRequests
	}

	serv, err := server.NewServer(context.Background(), serverConfig)
	if err != nil {
		return fmt.Errorf("failed to initialize server: %v", err)
	}

	telemetryServ := http.NewServeMux()
	telemetryServ.Handle("/metrics", promhttp.HandlerFor(prometheusRegistry, promhttp.HandlerOpts{}))

	// Up to four listener goroutines may send here. One value is consumed by
	// the receive at the end of this function, so a buffer of three lets every
	// sender complete without blocking.
	errc := make(chan error, 3)
	if c.Telemetry.HTTP != "" {
		logger.Infof("listening (http/telemetry) on %s", c.Telemetry.HTTP)
		go func() {
			err := http.ListenAndServe(c.Telemetry.HTTP, telemetryServ)
			errc <- fmt.Errorf("listening on %s failed: %v", c.Telemetry.HTTP, err)
		}()
	}
	if c.Web.HTTP != "" {
		logger.Infof("listening (http) on %s", c.Web.HTTP)
		go func() {
			err := http.ListenAndServe(c.Web.HTTP, serv)
			errc <- fmt.Errorf("listening on %s failed: %v", c.Web.HTTP, err)
		}()
	}
	if c.Web.HTTPS != "" {
		httpsSrv := &http.Server{
			Addr:    c.Web.HTTPS,
			Handler: serv,
			TLSConfig: &tls.Config{
				PreferServerCipherSuites: true,
				MinVersion:               tls.VersionTLS12,
			},
		}

		logger.Infof("listening (https) on %s", c.Web.HTTPS)
		go func() {
			// Declare a goroutine-local err rather than assigning to the
			// enclosing function's err, which would be an unsynchronized
			// write to a variable owned by another goroutine.
			err := httpsSrv.ListenAndServeTLS(c.Web.TLSCert, c.Web.TLSKey)
			errc <- fmt.Errorf("listening on %s failed: %v", c.Web.HTTPS, err)
		}()
	}
	if c.GRPC.Addr != "" {
		logger.Infof("listening (grpc) on %s", c.GRPC.Addr)
		go func() {
			errc <- func() error {
				list, err := net.Listen("tcp", c.GRPC.Addr)
				if err != nil {
					return fmt.Errorf("listening on %s failed: %v", c.GRPC.Addr, err)
				}
				grpcSrv := grpc.NewServer(grpcOptions...)
				api.RegisterDexServer(grpcSrv, server.NewAPI(serverConfig.Storage, logger))
				grpcMetrics.InitializeMetrics(grpcSrv)
				err = grpcSrv.Serve(list)
				return fmt.Errorf("listening on %s failed: %v", c.GRPC.Addr, err)
			}()
		}()
	}

	// Block until the first listener reports a failure.
	return <-errc
}
// logLevels lists the log level names quoted in newLogger's error message
// when an unsupported level is supplied.
var logLevels = []string{"debug", "info", "error"}

// logFormats lists the log format names quoted in newLogger's error message
// when an unsupported format is supplied.
var logFormats = []string{"json", "text"}
// utcFormatter decorates another logrus.Formatter so that every entry is
// rendered with its timestamp expressed in UTC.
type utcFormatter struct {
	// f is the wrapped formatter that produces the actual output.
	f logrus.Formatter
}

// Format rewrites the entry's Time field to UTC in place and then hands the
// entry to the wrapped formatter, returning whatever that formatter returns.
func (u *utcFormatter) Format(entry *logrus.Entry) ([]byte, error) {
	entry.Time = entry.Time.UTC()
	return u.f.Format(entry)
}
// newLogger constructs a logrus logger that writes to stderr.
//
// level selects the verbosity and is matched case-insensitively against
// "debug", "info" and "error"; the empty string means "info". format selects
// the output encoding and is matched case-insensitively against "text" and
// "json"; the empty string means "text" (rendered with colors disabled).
// Whichever formatter is chosen is wrapped in a utcFormatter.
//
// The level is validated before the format. An unsupported value for either
// yields a nil logger and an error naming the supported values.
func newLogger(level string, format string) (logrus.FieldLogger, error) {
	var verbosity logrus.Level
	switch strings.ToLower(level) {
	case "error":
		verbosity = logrus.ErrorLevel
	case "info", "":
		verbosity = logrus.InfoLevel
	case "debug":
		verbosity = logrus.DebugLevel
	default:
		return nil, fmt.Errorf("log level is not one of the supported values (%s): %s", strings.Join(logLevels, ", "), level)
	}

	var inner logrus.Formatter
	switch strings.ToLower(format) {
	case "json":
		inner = &logrus.JSONFormatter{}
	case "text", "":
		inner = &logrus.TextFormatter{DisableColors: true}
	default:
		return nil, fmt.Errorf("log format is not one of the supported values (%s): %s", strings.Join(logFormats, ", "), format)
	}

	logger := &logrus.Logger{
		Out:       os.Stderr,
		Formatter: &utcFormatter{f: inner},
		Level:     verbosity,
	}
	return logger, nil
}