dex/cmd/dex/serve.go
m.nabokikh 06c8ab5aa7 Fixes of naming and code style
Signed-off-by: m.nabokikh <maksim.nabokikh@flant.com>
2021-02-10 23:37:57 +04:00

470 lines
14 KiB
Go

package main
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"net"
"net/http"
"os"
"strings"
"syscall"
"time"
"github.com/ghodss/yaml"
grpcprometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
"github.com/oklog/run"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/reflection"
"github.com/dexidp/dex/api/v2"
"github.com/dexidp/dex/pkg/log"
"github.com/dexidp/dex/server"
"github.com/dexidp/dex/storage"
)
// serveOptions carries everything the "serve" command collects from its
// command line before the server is started.
type serveOptions struct {
	// config is the path of the configuration file, taken from the
	// command's single positional argument.
	config string

	// The remaining fields are bound to command-line flags. A non-empty
	// value replaces the corresponding address read from the config file.
	webHTTPAddr, webHTTPSAddr string // --web-http-addr, --web-https-addr
	telemetryAddr             string // --telemetry-addr
	grpcAddr                  string // --grpc-addr
}
// commandServe builds the "serve" sub-command. The command takes exactly one
// positional argument, the config file path, plus optional listen-address
// flags, and hands the collected options to runServe.
func commandServe() *cobra.Command {
	var opts serveOptions

	// serve runs once cobra has parsed the flags into opts.
	serve := func(cmd *cobra.Command, args []string) error {
		// Stop cobra from printing usage or the error itself when
		// runServe fails.
		cmd.SilenceUsage = true
		cmd.SilenceErrors = true

		opts.config = args[0]
		return runServe(opts)
	}

	serveCmd := &cobra.Command{
		Use:     "serve [flags] [config file]",
		Short:   "Launch Dex",
		Example: "dex serve config.yaml",
		Args:    cobra.ExactArgs(1),
		RunE:    serve,
	}

	fs := serveCmd.Flags()
	fs.StringVar(&opts.webHTTPAddr, "web-http-addr", "", "Web HTTP address")
	fs.StringVar(&opts.webHTTPSAddr, "web-https-addr", "", "Web HTTPS address")
	fs.StringVar(&opts.telemetryAddr, "telemetry-addr", "", "Telemetry address")
	fs.StringVar(&opts.grpcAddr, "grpc-addr", "", "gRPC API address")

	return serveCmd
}
// listenAndShutdownGracefully opens a TCP listener on srv.Addr and registers
// srv with the run group gr as an actor.
//
// The listener is created immediately, so a bind failure is returned to the
// caller before the group runs; the underlying error is wrapped so callers
// can inspect it with errors.Is / errors.As. The execute function serves
// requests on that listener; the interrupt function calls srv.Shutdown with a
// one-minute deadline and logs (rather than returns) a shutdown failure.
//
// name is only used to label log and error messages.
func listenAndShutdownGracefully(logger log.Logger, gr *run.Group, srv *http.Server, name string) error {
	l, err := net.Listen("tcp", srv.Addr)
	if err != nil {
		return fmt.Errorf("listening (%s) on %s: %w", name, srv.Addr, err)
	}

	gr.Add(func() error {
		logger.Infof("listening (%s) on %s", name, srv.Addr)
		return srv.Serve(l)
	}, func(error) {
		// Bound the shutdown so a stuck connection cannot block it forever.
		ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
		defer cancel()

		logger.Debugf("starting graceful shutdown (%s)", name)
		if err := srv.Shutdown(ctx); err != nil {
			logger.Errorf("graceful shutdown (%s): %v", name, err)
		}
	})

	return nil
}
// runServe reads and parses the config file named in options, applies the
// command-line overrides, and starts Dex.
//
// In order, it: builds the logger, validates the config, registers the
// Prometheus collectors, prepares gRPC TLS options, opens storage and layers
// the static clients, passwords and connectors on top of it, builds the
// server.Config (including expiry settings and the refresh token policy),
// and finally adds every configured listener (telemetry, HTTP, HTTPS, gRPC)
// plus a signal handler for os.Interrupt and SIGTERM to a run.Group and runs
// it.
//
// A run.SignalError returned by the group is logged and reported as a nil
// error; every other failure is returned to the caller.
func runServe(options serveOptions) error {
	configFile := options.config
	configData, err := ioutil.ReadFile(configFile)
	if err != nil {
		return fmt.Errorf("failed to read config file %s: %v", configFile, err)
	}

	var c Config
	if err := yaml.Unmarshal(configData, &c); err != nil {
		return fmt.Errorf("error parse config file %s: %v", configFile, err)
	}

	// Flags win over the config file, so apply them before anything reads c.
	applyConfigOverrides(options, &c)

	logger, err := newLogger(c.Logger.Level, c.Logger.Format)
	if err != nil {
		return fmt.Errorf("invalid config: %v", err)
	}
	if c.Logger.Level != "" {
		logger.Infof("config using log level: %s", c.Logger.Level)
	}

	if err := c.Validate(); err != nil {
		return err
	}

	logger.Infof("config issuer: %s", c.Issuer)

	// Use a dedicated registry and register the collectors explicitly.
	prometheusRegistry := prometheus.NewRegistry()
	err = prometheusRegistry.Register(prometheus.NewGoCollector())
	if err != nil {
		return fmt.Errorf("failed to register Go runtime metrics: %v", err)
	}

	err = prometheusRegistry.Register(prometheus.NewProcessCollector(prometheus.ProcessCollectorOpts{}))
	if err != nil {
		return fmt.Errorf("failed to register process metrics: %v", err)
	}

	grpcMetrics := grpcprometheus.NewServerMetrics()
	err = prometheusRegistry.Register(grpcMetrics)
	if err != nil {
		return fmt.Errorf("failed to register gRPC server metrics: %v", err)
	}

	var grpcOptions []grpc.ServerOption

	// Cipher suites shared by the gRPC and HTTPS TLS configurations.
	allowedTLSCiphers := []uint16{
		tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
		tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
		tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
		tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
		tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
		tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
		tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
		tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
	}

	if c.GRPC.TLSCert != "" {
		// Parse certificates from certificate file and key file for server.
		cert, err := tls.LoadX509KeyPair(c.GRPC.TLSCert, c.GRPC.TLSKey)
		if err != nil {
			return fmt.Errorf("invalid config: error parsing gRPC certificate file: %v", err)
		}

		tlsConfig := tls.Config{
			Certificates:             []tls.Certificate{cert},
			MinVersion:               tls.VersionTLS12,
			CipherSuites:             allowedTLSCiphers,
			PreferServerCipherSuites: true,
		}

		if c.GRPC.TLSClientCA != "" {
			// Parse certificates from client CA file to a new CertPool.
			cPool := x509.NewCertPool()
			clientCert, err := ioutil.ReadFile(c.GRPC.TLSClientCA)
			if err != nil {
				return fmt.Errorf("invalid config: reading from client CA file: %v", err)
			}
			if !cPool.AppendCertsFromPEM(clientCert) {
				return errors.New("invalid config: failed to parse client CA")
			}

			tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
			tlsConfig.ClientCAs = cPool

			// Only add metrics if client auth is enabled
			grpcOptions = append(grpcOptions,
				grpc.StreamInterceptor(grpcMetrics.StreamServerInterceptor()),
				grpc.UnaryInterceptor(grpcMetrics.UnaryServerInterceptor()),
			)
		}

		grpcOptions = append(grpcOptions, grpc.Creds(credentials.NewTLS(&tlsConfig)))
	}

	s, err := c.Storage.Config.Open(logger)
	if err != nil {
		return fmt.Errorf("failed to initialize storage: %v", err)
	}
	logger.Infof("config storage: %s", c.Storage.Type)

	if len(c.StaticClients) > 0 {
		for i, client := range c.StaticClients {
			if client.Name == "" {
				return fmt.Errorf("invalid config: Name field is required for a client")
			}
			if client.ID == "" && client.IDEnv == "" {
				return fmt.Errorf("invalid config: ID or IDEnv field is required for a client")
			}
			if client.IDEnv != "" {
				if client.ID != "" {
					return fmt.Errorf("invalid config: ID and IDEnv fields are exclusive for client %q", client.ID)
				}
				// client is a copy; write the resolved value back into the slice.
				c.StaticClients[i].ID = os.Getenv(client.IDEnv)
			}
			if client.Secret == "" && client.SecretEnv == "" && !client.Public {
				return fmt.Errorf("invalid config: Secret or SecretEnv field is required for client %q", client.ID)
			}
			if client.SecretEnv != "" {
				if client.Secret != "" {
					return fmt.Errorf("invalid config: Secret and SecretEnv fields are exclusive for client %q", client.ID)
				}
				c.StaticClients[i].Secret = os.Getenv(client.SecretEnv)
			}
			logger.Infof("config static client: %s", client.Name)
		}
		s = storage.WithStaticClients(s, c.StaticClients)
	}

	if len(c.StaticPasswords) > 0 {
		passwords := make([]storage.Password, len(c.StaticPasswords))
		for i, p := range c.StaticPasswords {
			passwords[i] = storage.Password(p)
		}
		s = storage.WithStaticPasswords(s, passwords, logger)
	}

	storageConnectors := make([]storage.Connector, len(c.StaticConnectors))
	for i, connector := range c.StaticConnectors {
		if connector.ID == "" || connector.Name == "" || connector.Type == "" {
			return fmt.Errorf("invalid config: ID, Type and Name fields are required for a connector")
		}
		if connector.Config == nil {
			return fmt.Errorf("invalid config: no config field for connector %q", connector.ID)
		}
		logger.Infof("config connector: %s", connector.ID)

		// convert to a storage connector object
		storageConn, err := ToStorageConnector(connector)
		if err != nil {
			return fmt.Errorf("failed to initialize storage connectors: %v", err)
		}
		storageConnectors[i] = storageConn
	}

	if c.EnablePasswordDB {
		storageConnectors = append(storageConnectors, storage.Connector{
			ID:   server.LocalConnector,
			Name: "Email",
			Type: server.LocalConnector,
		})
		logger.Infof("config connector: local passwords enabled")
	}

	s = storage.WithStaticConnectors(s, storageConnectors)

	if len(c.OAuth2.ResponseTypes) > 0 {
		logger.Infof("config response types accepted: %s", c.OAuth2.ResponseTypes)
	}
	if c.OAuth2.SkipApprovalScreen {
		logger.Infof("config skipping approval screen")
	}
	if c.OAuth2.PasswordConnector != "" {
		logger.Infof("config using password grant connector: %s", c.OAuth2.PasswordConnector)
	}
	if len(c.Web.AllowedOrigins) > 0 {
		logger.Infof("config allowed origins: %s", c.Web.AllowedOrigins)
	}

	// explicitly convert to UTC.
	now := func() time.Time { return time.Now().UTC() }

	serverConfig := server.Config{
		SupportedResponseTypes: c.OAuth2.ResponseTypes,
		SkipApprovalScreen:     c.OAuth2.SkipApprovalScreen,
		AlwaysShowLoginScreen:  c.OAuth2.AlwaysShowLoginScreen,
		PasswordConnector:      c.OAuth2.PasswordConnector,
		AllowedOrigins:         c.Web.AllowedOrigins,
		Issuer:                 c.Issuer,
		Storage:                s,
		Web:                    c.Frontend,
		Logger:                 logger,
		Now:                    now,
		PrometheusRegistry:     prometheusRegistry,
	}

	// Expiry values are optional; an empty string leaves the corresponding
	// server.Config field at its zero value.
	if c.Expiry.SigningKeys != "" {
		signingKeys, err := time.ParseDuration(c.Expiry.SigningKeys)
		if err != nil {
			return fmt.Errorf("invalid config value %q for signing keys expiry: %v", c.Expiry.SigningKeys, err)
		}
		logger.Infof("config signing keys expire after: %v", signingKeys)
		serverConfig.RotateKeysAfter = signingKeys
	}
	if c.Expiry.IDTokens != "" {
		idTokens, err := time.ParseDuration(c.Expiry.IDTokens)
		if err != nil {
			return fmt.Errorf("invalid config value %q for id token expiry: %v", c.Expiry.IDTokens, err)
		}
		logger.Infof("config id tokens valid for: %v", idTokens)
		serverConfig.IDTokensValidFor = idTokens
	}
	if c.Expiry.AuthRequests != "" {
		authRequests, err := time.ParseDuration(c.Expiry.AuthRequests)
		if err != nil {
			return fmt.Errorf("invalid config value %q for auth request expiry: %v", c.Expiry.AuthRequests, err)
		}
		logger.Infof("config auth requests valid for: %v", authRequests)
		serverConfig.AuthRequestsValidFor = authRequests
	}
	if c.Expiry.DeviceRequests != "" {
		deviceRequests, err := time.ParseDuration(c.Expiry.DeviceRequests)
		if err != nil {
			// Report the value that actually failed to parse.
			return fmt.Errorf("invalid config value %q for device request expiry: %v", c.Expiry.DeviceRequests, err)
		}
		logger.Infof("config device requests valid for: %v", deviceRequests)
		serverConfig.DeviceRequestsValidFor = deviceRequests
	}

	refreshTokenPolicy, err := server.NewRefreshTokenPolicy(
		logger,
		c.Expiry.RefreshTokens.DisableRotation,
		c.Expiry.RefreshTokens.ValidIfNotUsedFor,
		c.Expiry.RefreshTokens.AbsoluteLifetime,
		c.Expiry.RefreshTokens.ReuseInterval,
	)
	if err != nil {
		return fmt.Errorf("invalid refresh token expiration policy config: %v", err)
	}
	serverConfig.RefreshTokenPolicy = refreshTokenPolicy

	serv, err := server.NewServer(context.Background(), serverConfig)
	if err != nil {
		return fmt.Errorf("failed to initialize server: %v", err)
	}

	telemetryServ := http.NewServeMux()
	telemetryServ.Handle("/metrics", promhttp.HandlerFor(prometheusRegistry, promhttp.HandlerOpts{}))

	// Each listener below is optional and becomes one actor in the group.
	var gr run.Group
	if c.Telemetry.HTTP != "" {
		telemetrySrv := &http.Server{Addr: c.Telemetry.HTTP, Handler: telemetryServ}

		defer telemetrySrv.Close()
		if err := listenAndShutdownGracefully(logger, &gr, telemetrySrv, "http/telemetry"); err != nil {
			return err
		}
	}

	if c.Web.HTTP != "" {
		httpSrv := &http.Server{Addr: c.Web.HTTP, Handler: serv}

		defer httpSrv.Close()
		if err := listenAndShutdownGracefully(logger, &gr, httpSrv, "http"); err != nil {
			return err
		}
	}

	if c.Web.HTTPS != "" {
		// NOTE(review): listenAndShutdownGracefully serves on a plain TCP
		// listener via srv.Serve, and no certificate is set on TLSConfig
		// here. Confirm where the key pair is loaded and the TLS handshake
		// is performed for this listener.
		httpsSrv := &http.Server{
			Addr:    c.Web.HTTPS,
			Handler: serv,
			TLSConfig: &tls.Config{
				CipherSuites:             allowedTLSCiphers,
				PreferServerCipherSuites: true,
				MinVersion:               tls.VersionTLS12,
			},
		}

		defer httpsSrv.Close()
		if err := listenAndShutdownGracefully(logger, &gr, httpsSrv, "https"); err != nil {
			return err
		}
	}

	if c.GRPC.Addr != "" {
		grpcListener, err := net.Listen("tcp", c.GRPC.Addr)
		if err != nil {
			return fmt.Errorf("listening (grpc) on %s: %w", c.GRPC.Addr, err)
		}

		grpcSrv := grpc.NewServer(grpcOptions...)
		api.RegisterDexServer(grpcSrv, server.NewAPI(serverConfig.Storage, logger))

		grpcMetrics.InitializeMetrics(grpcSrv)
		if c.GRPC.Reflection {
			logger.Info("enabling reflection in grpc service")
			reflection.Register(grpcSrv)
		}

		gr.Add(func() error {
			logger.Infof("listening (grpc) on %s", c.GRPC.Addr)
			return grpcSrv.Serve(grpcListener)
		}, func(err error) {
			logger.Debugf("starting graceful shutdown (grpc)")
			grpcSrv.GracefulStop()
		})
	}

	gr.Add(run.SignalHandler(context.Background(), os.Interrupt, syscall.SIGTERM))
	if err := gr.Run(); err != nil {
		// A signal is the expected way to stop; anything else is a failure.
		if _, ok := err.(run.SignalError); !ok {
			return fmt.Errorf("run groups: %w", err)
		}
		logger.Infof("%v, shutdown now", err)
	}
	return nil
}
// Accepted values for the logger configuration. They are only used to build
// the error messages returned by newLogger for an unsupported level or format.
var (
	logLevels = []string{
		"debug",
		"info",
		"error",
	}
	logFormats = []string{
		"json",
		"text",
	}
)
// utcFormatter wraps another logrus.Formatter and converts each entry's
// timestamp to UTC before delegating to it.
type utcFormatter struct {
	// f is the wrapped formatter that produces the actual output.
	f logrus.Formatter
}

// Format rewrites entry.Time in UTC (the entry itself is modified) and
// returns whatever the wrapped formatter produces for it.
func (u *utcFormatter) Format(entry *logrus.Entry) ([]byte, error) {
	utc := entry.Time.UTC()
	entry.Time = utc
	return u.f.Format(entry)
}
// newLogger builds a logrus-backed logger that writes to stderr with UTC
// timestamps.
//
// level is matched case-insensitively against "debug", "info" and "error";
// an empty level means "info". format is matched case-insensitively against
// "text" and "json"; an empty format means "text", which is rendered without
// colors. Any other value yields an error listing the supported values. The
// level is checked before the format.
func newLogger(level string, format string) (log.Logger, error) {
	// Start from the default and only override for the non-default names.
	lvl := logrus.InfoLevel
	switch strings.ToLower(level) {
	case "", "info":
		// keep the default
	case "debug":
		lvl = logrus.DebugLevel
	case "error":
		lvl = logrus.ErrorLevel
	default:
		return nil, fmt.Errorf("log level is not one of the supported values (%s): %s", strings.Join(logLevels, ", "), level)
	}

	var inner logrus.Formatter
	switch strings.ToLower(format) {
	case "json":
		inner = &logrus.JSONFormatter{}
	case "", "text":
		inner = &logrus.TextFormatter{DisableColors: true}
	default:
		return nil, fmt.Errorf("log format is not one of the supported values (%s): %s", strings.Join(logFormats, ", "), format)
	}

	logger := &logrus.Logger{
		Out:       os.Stderr,
		Formatter: &utcFormatter{f: inner},
		Level:     lvl,
	}
	return logger, nil
}
// applyConfigOverrides copies every non-empty address from options into
// config, replacing the value loaded from the config file. Empty options
// leave the corresponding config field untouched.
func applyConfigOverrides(options serveOptions, config *Config) {
	if addr := options.webHTTPAddr; addr != "" {
		config.Web.HTTP = addr
	}

	if addr := options.webHTTPSAddr; addr != "" {
		config.Web.HTTPS = addr
	}

	if addr := options.telemetryAddr; addr != "" {
		config.Telemetry.HTTP = addr
	}

	if addr := options.grpcAddr; addr != "" {
		config.GRPC.Addr = addr
	}
}