dex/server/server.go

180 lines
4.1 KiB
Go
Raw Normal View History

2016-07-26 01:30:28 +05:30
package server
import (
"errors"
"fmt"
"net/http"
"net/url"
"path"
2016-08-11 09:21:58 +05:30
"sync/atomic"
2016-07-26 01:30:28 +05:30
"time"
"github.com/gorilla/mux"
2016-08-11 11:01:42 +05:30
"github.com/coreos/dex/connector"
"github.com/coreos/dex/storage"
2016-07-26 01:30:28 +05:30
)
// Connector is a connector with metadata.
type Connector struct {
	// ID identifies the connector; the server indexes its connectors by
	// this value.
	ID string
	// DisplayName is presumably the human-readable name for the connector
	// (not used in this file; confirm against the templates/handlers).
	DisplayName string
	// Connector is the underlying connector implementation.
	Connector connector.Connector
}
// Config holds the server's configuration options.
type Config struct {
	// Issuer is the server's issuer URL. It must parse as a URL; its path
	// is used as the prefix for every HTTP endpoint the server registers.
	Issuer string

	// The backing persistence layer. Must not be nil.
	Storage storage.Storage

	// Strategies for federated identity. At least one connector is required.
	Connectors []Connector

	// NOTE: Multiple servers using the same storage are expected to set rotation and
	// validity periods to the same values.
	RotateKeysAfter  time.Duration // Defaults to 6 hours.
	IDTokensValidFor time.Duration // Defaults to 24 hours.

	// If specified, the server will use this function for determining time.
	// When nil, time.Now is used.
	Now func() time.Time
}
// value returns val unless it is the zero duration, in which case
// defaultValue is returned instead. It is used to apply defaults to
// optional duration settings.
func value(val, defaultValue time.Duration) time.Duration {
	if val != 0 {
		return val
	}
	return defaultValue
}
// Server is the top level object.
type Server struct {
	// issuerURL is the parsed form of Config.Issuer. Its path prefixes all
	// registered routes and all paths/URLs built by absPath and absURL.
	issuerURL url.URL

	// Read-only map of connector IDs to connectors.
	connectors map[string]Connector

	// storage is the configured storage wrapped with key rotation and a
	// key cache.
	storage storage.Storage

	// mux is the router that ServeHTTP delegates to.
	mux http.Handler

	// If enabled, don't prompt user for approval after logging in through connector.
	// No package level API to set this, only used in tests.
	skipApproval bool

	// now reports the current time; it is Config.Now, or time.Now when
	// that was not set.
	now func() time.Time

	// idTokensValidFor is Config.IDTokensValidFor, defaulting to 24 hours.
	idTokensValidFor time.Duration
}
// New constructs a server from the provided config.
//
// Unset (zero) durations fall back to their defaults: keys rotate after
// 6 hours and ID tokens are valid for 24 hours.
func New(c Config) (*Server, error) {
	rotateAfter := value(c.RotateKeysAfter, 6*time.Hour)
	validFor := value(c.IDTokensValidFor, 24*time.Hour)
	return newServer(c, defaultRotationStrategy(rotateAfter, validFor))
}
func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) {
issuerURL, err := url.Parse(c.Issuer)
if err != nil {
return nil, fmt.Errorf("server: can't parse issuer URL")
}
if len(c.Connectors) == 0 {
return nil, errors.New("server: no connectors specified")
}
if c.Storage == nil {
return nil, errors.New("server: storage cannot be nil")
}
now := c.Now
if now == nil {
now = time.Now
}
s := &Server{
2016-08-11 09:21:58 +05:30
issuerURL: *issuerURL,
connectors: make(map[string]Connector),
storage: newKeyCacher(
storageWithKeyRotation(
c.Storage, rotationStrategy, now,
),
now,
),
2016-07-26 01:30:28 +05:30
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
now: now,
}
for _, conn := range c.Connectors {
s.connectors[conn.ID] = conn
}
r := mux.NewRouter()
handleFunc := func(p string, h http.HandlerFunc) {
r.HandleFunc(path.Join(issuerURL.Path, p), h)
}
r.NotFoundHandler = http.HandlerFunc(s.notFound)
// TODO(ericchiang): rate limit certain paths based on IP.
handleFunc("/.well-known/openid-configuration", s.handleDiscovery)
handleFunc("/token", s.handleToken)
handleFunc("/keys", s.handlePublicKeys)
handleFunc("/auth", s.handleAuthorization)
handleFunc("/auth/{connector}", s.handleConnectorLogin)
handleFunc("/callback/{connector}", s.handleConnectorCallback)
handleFunc("/approval", s.handleApproval)
s.mux = r
return s, nil
}
// ServeHTTP implements http.Handler by handing every request to the
// server's router.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	handler := s.mux
	handler.ServeHTTP(w, r)
}
// absPath joins pathItems onto the issuer URL's path and returns the
// cleaned result.
func (s *Server) absPath(pathItems ...string) string {
	// The issuer path always comes first; a fresh slice is built so the
	// caller's pathItems are never modified.
	segments := append([]string{s.issuerURL.Path}, pathItems...)
	return path.Join(segments...)
}
// absURL returns the issuer URL as a string with its path replaced by the
// issuer path joined with pathItems.
func (s *Server) absURL(pathItems ...string) string {
	// issuerURL is stored by value, so this assignment works on a copy and
	// leaves the server's own URL untouched.
	target := s.issuerURL
	target.Path = s.absPath(pathItems...)
	return target.String()
}
2016-08-11 09:21:58 +05:30
// newKeyCacher returns a storage which caches keys so long as the next
// rotation time of the cached keys has not yet been reached. All other
// storage operations go straight to s.
//
// If now is nil, time.Now is used.
func newKeyCacher(s storage.Storage, now func() time.Time) storage.Storage {
	cacher := &keyCacher{Storage: s, now: now}
	if cacher.now == nil {
		cacher.now = time.Now
	}
	return cacher
}
// keyCacher wraps a storage.Storage and overrides GetKeys to serve keys
// from an in-memory cache. Every other method comes from the embedded
// storage unchanged.
type keyCacher struct {
	// Storage is the wrapped storage that is consulted on a cache miss.
	storage.Storage

	// now reports the current time used to decide whether cached keys are
	// still usable.
	now func() time.Time
	// keys is the cache slot.
	keys atomic.Value // Always holds nil or type *storage.Keys.
}
// GetKeys returns the cached keys while the current time is before their
// NextRotation. Otherwise it fetches the keys from the underlying storage
// and caches them if their NextRotation is still in the future.
//
// An error from the underlying storage is returned as-is, together with
// whatever value the storage returned; nothing is cached in that case.
func (k *keyCacher) GetKeys() (storage.Keys, error) {
	// The comma-ok assertion yields a nil pointer when nothing has been
	// stored yet, so a single nil check covers both the empty and the
	// nil-pointer case.
	if cached, _ := k.keys.Load().(*storage.Keys); cached != nil && k.now().Before(cached.NextRotation) {
		return *cached, nil
	}

	fresh, err := k.Storage.GetKeys()
	if err != nil {
		return fresh, err
	}

	// Only cache keys that are not already due for rotation.
	if k.now().Before(fresh.NextRotation) {
		k.keys.Store(&fresh)
	}
	return fresh, nil
}