This repository has been archived on 2022-08-17. You can view files and clone it, but cannot push or open issues or pull requests.
dex/server/server.go
2016-07-19 11:23:04 -07:00

734 lines
22 KiB
Go

package server
import (
"errors"
"fmt"
"html/template"
"net/http"
"net/url"
"path"
"sort"
"time"
"github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/key"
"github.com/coreos/go-oidc/oauth2"
"github.com/coreos/go-oidc/oidc"
"github.com/coreos/pkg/health"
"github.com/go-gorp/gorp"
"github.com/jonboulle/clockwork"
"github.com/coreos/dex/client"
clientmanager "github.com/coreos/dex/client/manager"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/refresh"
"github.com/coreos/dex/scope"
"github.com/coreos/dex/session"
sessionmanager "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user"
usersapi "github.com/coreos/dex/user/api"
useremail "github.com/coreos/dex/user/email"
usermanager "github.com/coreos/dex/user/manager"
)
// File names of the HTML templates used to render the server's own pages.
const (
	LoginPageTemplateName              = "login.html"
	RegisterTemplateName               = "register.html"
	VerifyEmailTemplateName            = "verify-email.html"
	SendResetPasswordEmailTemplateName = "send-reset-password.html"
	ResetPasswordTemplateName          = "reset-password.html"
	OOBTemplateName                    = "oob-template.html"
)

// APIVersion is the path segment joined onto httpPathAPI to form the base
// path under which the user management API is mounted.
const APIVersion = "v1"
// OIDCServer is the set of OpenID Connect provider operations exposed by the
// server. *Server implements every method below.
type OIDCServer interface {
	// Client returns the client registered under clientID.
	Client(clientID string) (client.Client, error)

	// NewSession starts an authentication session for the given connector and
	// client and returns a key that refers to it.
	NewSession(connectorID, clientID, clientState string, redirectURL url.URL, nonce string, register bool, scope []string) (string, error)

	// Login attaches an identity reported by a connector to the session
	// referred to by sessionKey and returns the URL to send the user agent to.
	Login(ident oidc.Identity, sessionKey string) (string, error)

	// CodeToken exchanges a code for an ID token and a refresh token string on success.
	CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, error)

	// ClientCredsToken issues an ID token for an authenticated, non-public client.
	ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, error)

	// RefreshToken takes a previously generated refresh token and returns a new ID token
	// if the token is valid.
	RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, error)

	// KillSession ends the session referred to by sessionKey.
	KillSession(sessionKey string) error

	// CrossClientAuthAllowed reports whether requestingClientID appears among
	// the trusted peers of authorizingClientID.
	CrossClientAuthAllowed(requestingClientID, authorizingClientID string) (bool, error)
}

// JWTVerifierFactory builds an oidc.JWTVerifier configured for the given client ID.
type JWTVerifierFactory func(clientID string) oidc.JWTVerifier
// Server is dex's OpenID Connect provider. It implements OIDCServer and
// exposes its endpoints through HTTPHandler.
type Server struct {
	// IssuerURL identifies this provider; absolute endpoint URLs are built
	// by joining paths onto it.
	IssuerURL url.URL

	// Templates is handed to every connector created by AddConnector; the
	// remaining templates render the server's own pages.
	Templates                      *template.Template
	LoginTemplate                  *template.Template
	RegisterTemplate               *template.Template
	VerifyEmailTemplate            *template.Template
	SendResetPasswordEmailTemplate *template.Template
	ResetPasswordTemplate          *template.Template
	OOBTemplate                    *template.Template

	// HealthChecks are served by the health endpoint alongside every
	// registered connector.
	HealthChecks []health.Checkable

	// TODO(ericchiang): Make this a map of ID to connector.
	// Connectors is kept sorted by ID by AddConnector.
	Connectors []connector.Connector

	// Repositories.
	ClientRepo          client.ClientRepo
	ConnectorConfigRepo connector.ConnectorConfigRepo
	KeySetRepo          key.PrivateKeySetRepo
	RefreshTokenRepo    refresh.RefreshTokenRepo
	UserRepo            user.UserRepo
	PasswordInfoRepo    user.PasswordInfoRepo

	// Managers.
	ClientManager  *clientmanager.ClientManager
	KeyManager     key.PrivateKeyManager
	SessionManager *sessionmanager.SessionManager
	UserManager    *usermanager.UserManager
	UserEmailer    *useremail.UserEmailer

	// Feature switches. EnableRegistration and EnableClientRegistration
	// control whether the corresponding endpoints are mounted;
	// RegisterOnFirstLogin lets Login create a user for an unseen remote
	// identity coming from a non-local connector.
	EnableRegistration       bool
	EnableClientRegistration bool
	RegisterOnFirstLogin     bool

	dbMap *gorp.DbMap
	// localConnectorID is the ID of the LocalConnector, recorded by AddConnector.
	localConnectorID string
}
// Run calls Run on a key-set syncer and Sync on every connector, collecting the
// channel each one returns. It returns a stop channel; once that channel
// yields a value or is closed, every collected channel is closed.
func (s *Server) Run() chan struct{} {
	stopCh := make(chan struct{})

	stoppers := make([]chan struct{}, 0, len(s.Connectors)+1)
	stoppers = append(stoppers, key.NewKeySetSyncer(s.KeySetRepo, s.KeyManager).Run())
	for i := range s.Connectors {
		stoppers = append(stoppers, s.Connectors[i].Sync())
	}

	go func() {
		<-stopCh
		for _, c := range stoppers {
			close(c)
		}
	}()

	return stopCh
}
// KillSession resolves sessionKey to its session ID and kills that session,
// returning the first error encountered.
func (s *Server) KillSession(sessionKey string) error {
	id, err := s.SessionManager.ExchangeKey(sessionKey)
	if err == nil {
		_, err = s.SessionManager.Kill(id)
	}
	return err
}
// ProviderConfig returns the provider metadata advertised by the discovery
// endpoint: issuer, auth/token/keys endpoints, supported grant and response
// types, and — when client registration is enabled — the registration endpoint.
func (s *Server) ProviderConfig() oidc.ProviderConfig {
	endpoint := func(p string) *url.URL {
		u := s.absURL(p)
		return &u
	}

	cfg := oidc.ProviderConfig{
		Issuer:        &s.IssuerURL,
		AuthEndpoint:  endpoint(httpPathAuth),
		TokenEndpoint: endpoint(httpPathToken),
		KeysEndpoint:  endpoint(httpPathKeys),

		GrantTypesSupported:               []string{oauth2.GrantTypeAuthCode, oauth2.GrantTypeClientCreds},
		ResponseTypesSupported:            []string{"code"},
		SubjectTypesSupported:             []string{"public"},
		IDTokenSigningAlgValues:           []string{"RS256"},
		TokenEndpointAuthMethodsSupported: []string{"client_secret_basic"},
	}
	if s.EnableClientRegistration {
		cfg.RegistrationEndpoint = endpoint(httpPathClientRegistration)
	}
	return cfg
}
// absURL returns a copy of the issuer URL whose path is the issuer's own path
// joined (and cleaned) with the given path segments.
func (s *Server) absURL(paths ...string) url.URL {
	u := s.IssuerURL
	u.Path = path.Join(append([]string{u.Path}, paths...)...)
	return u
}
// AddConnector instantiates the connector described by cfg, rooted at
// <issuer path>/<httpPathAuth>/<connector ID>, and registers it on the server.
// s.Connectors is kept sorted by connector ID.
//
// A *connector.LocalConnector additionally needs UserRepo and PasswordInfoRepo.
// Those are validated before the server is modified, so a failing call leaves
// s.Connectors and the recorded local connector ID unchanged.
func (s *Server) AddConnector(cfg connector.ConnectorConfig) error {
	connectorID := cfg.ConnectorID()
	ns := s.IssuerURL
	ns.Path = path.Join(ns.Path, httpPathAuth, connectorID)

	idpc, err := cfg.Connector(ns, s.Login, s.Templates)
	if err != nil {
		return err
	}

	// We handle the LocalConnector specially because it needs access to the
	// UserRepo and the PasswordInfoRepo; if it turns out that other connectors
	// need access to these resources we'll figure out how to provide it in a
	// cleaner manner.
	if localConn, ok := idpc.(*connector.LocalConnector); ok {
		if s.UserRepo == nil {
			return errors.New("UserRepo cannot be nil")
		}
		if s.PasswordInfoRepo == nil {
			return errors.New("PasswordInfoRepo cannot be nil")
		}
		localConn.SetLocalIdentityProvider(&connector.LocalIdentityProvider{
			UserRepo:         s.UserRepo,
			PasswordInfoRepo: s.PasswordInfoRepo,
		})
		s.localConnectorID = connectorID
	}

	// Only register the connector once it is fully configured.
	s.Connectors = append(s.Connectors, idpc)
	sort.Sort(sortableIDPCs(s.Connectors))

	log.Infof("Loaded IdP connector: id=%s type=%s", connectorID, cfg.ConnectorType())
	return nil
}
// HTTPHandler assembles the server's routes on a new http.ServeMux. It mounts:
//   - discovery, auth, OOB, token, keys and health;
//   - email verification, password reset and invitation acceptance;
//   - debug vars;
//   - each connector's own routes;
//   - the versioned user management API;
//   - user registration and client registration, each only when enabled.
func (s *Server) HTTPHandler() http.Handler {
	// The health endpoint reports the configured checks plus every connector.
	checks := make([]health.Checkable, 0, len(s.HealthChecks)+len(s.Connectors))
	checks = append(checks, s.HealthChecks...)
	for _, idpc := range s.Connectors {
		checks = append(checks, idpc)
	}

	pcfg := s.ProviderConfig()
	clock := clockwork.NewRealClock()
	mux := http.NewServeMux()
	mux.HandleFunc(httpPathDiscovery, handleDiscoveryFunc(pcfg))
	mux.HandleFunc(httpPathAuth, handleAuthFunc(s, s.Connectors, s.LoginTemplate, s.EnableRegistration))
	mux.HandleFunc(httpPathOOB, handleOOBFunc(s, s.OOBTemplate))
	mux.HandleFunc(httpPathToken, handleTokenFunc(s))
	mux.HandleFunc(httpPathKeys, handleKeysFunc(s.KeyManager, clock))
	mux.Handle(httpPathHealth, makeHealthHandler(checks))
	if s.EnableRegistration {
		mux.HandleFunc(httpPathRegister, handleRegisterFunc(s, s.RegisterTemplate))
	}
	mux.HandleFunc(httpPathEmailVerify, handleEmailVerifyFunc(s.VerifyEmailTemplate,
		s.IssuerURL, s.KeyManager.PublicKeys, s.UserManager))
	mux.Handle(httpPathVerifyEmailResend, s.NewClientTokenAuthHandler(handleVerifyEmailResendFunc(s.IssuerURL,
		s.KeyManager.PublicKeys,
		s.UserEmailer,
		s.UserRepo,
		s.ClientManager)))
	mux.Handle(httpPathSendResetPassword, &SendResetPasswordEmailHandler{
		tpl:     s.SendResetPasswordEmailTemplate,
		emailer: s.UserEmailer,
		sm:      s.SessionManager,
		cm:      s.ClientManager,
	})
	mux.Handle(httpPathResetPassword, &ResetPasswordHandler{
		tpl:       s.ResetPasswordTemplate,
		issuerURL: s.IssuerURL,
		um:        s.UserManager,
		keysFunc:  s.KeyManager.PublicKeys,
	})
	mux.Handle(httpPathAcceptInvitation, &InvitationHandler{
		passwordResetURL:       s.absURL(httpPathResetPassword),
		issuerURL:              s.IssuerURL,
		um:                     s.UserManager,
		keysFunc:               s.KeyManager.PublicKeys,
		signerFunc:             s.KeyManager.Signer,
		redirectValidityWindow: s.SessionManager.ValidityWindow,
	})
	if s.EnableClientRegistration {
		mux.HandleFunc(httpPathClientRegistration, s.handleClientRegistration)
	}
	mux.HandleFunc(httpPathDebugVars, health.ExpvarHandler)

	// Each connector registers its own routes and is given an error URL: the
	// auth endpoint with connector_id set. The query is built with url.Values
	// (rather than string formatting + re-parsing) so the connector ID is
	// escaped correctly and no parse failure is possible.
	for _, idpc := range s.Connectors {
		errorURL := *pcfg.AuthEndpoint
		q := errorURL.Query()
		q.Set("connector_id", idpc.ID())
		errorURL.RawQuery = q.Encode()
		idpc.Register(mux, errorURL)
	}

	apiBasePath := path.Join(httpPathAPI, APIVersion)
	registerDiscoveryResource(apiBasePath, mux)
	usersAPI := usersapi.NewUsersAPI(s.UserManager, s.ClientManager, s.RefreshTokenRepo, s.UserEmailer, s.localConnectorID)
	handler := NewUserMgmtServer(usersAPI, s.JWTVerifierFactory(), s.UserManager, s.ClientManager).HTTPHandler()
	mux.Handle(apiBasePath+"/", handler)

	return mux
}
// NewClientTokenAuthHandler wraps handler in middleware that requires a client
// Bearer token. Tokens are checked against this server's issuer URL, client
// manager and public keys.
func (s *Server) NewClientTokenAuthHandler(handler http.Handler) http.Handler {
	mw := &clientTokenMiddleware{next: handler}
	mw.issuerURL = s.IssuerURL.String()
	mw.ciManager = s.ClientManager
	mw.keysFunc = s.KeyManager.PublicKeys
	return mw
}
// Client looks up the client registered under clientID via the ClientManager.
func (s *Server) Client(clientID string) (client.Client, error) {
	manager := s.ClientManager
	return manager.Get(clientID)
}
// NewSession creates a session for the given connector and client, logs its
// creation, and returns a fresh key that refers to it.
func (s *Server) NewSession(connectorID, clientID, clientState string, redirectURL url.URL, nonce string, register bool, scope []string) (string, error) {
	sm := s.SessionManager
	sessionID, err := sm.NewSession(connectorID, clientID, clientState, redirectURL, nonce, register, scope)
	if err != nil {
		return "", err
	}
	log.Infof("Session %s created: clientID=%s clientState=%s", sessionID, clientID, clientState)
	return sm.NewSessionKey(sessionID)
}
// connector returns the first registered connector whose ID equals id, and
// reports whether one was found.
func (s *Server) connector(id string) (connector.Connector, bool) {
	for i := range s.Connectors {
		if c := s.Connectors[i]; c.ID() == id {
			return c, true
		}
	}
	return nil, false
}
// Login completes an authentication attempt for the session referred to by
// key, using the identity ident reported by the connector.
//
// It attaches ident to the session. When the groups scope was requested, it
// also fetches and attaches the user's groups. It then returns the URL the
// user agent should be sent to next:
//   - the registration page (with a new code) if the session was started for
//     registration;
//   - the login page with a "register-maybe" or "wrong-connector" message when
//     no user is linked to ident and none is created automatically;
//   - otherwise the client's redirect URL (or the server's OOB page when the
//     redirect is client.OOBRedirectURI) carrying a new code and the client
//     state.
//
// With RegisterOnFirstLogin set, an unseen remote identity from a non-local
// connector is registered as a new user. Disabled users yield
// user.ErrorNotFound.
func (s *Server) Login(ident oidc.Identity, key string) (string, error) {
	sessionID, err := s.SessionManager.ExchangeKey(key)
	if err != nil {
		return "", err
	}
	ses, err := s.SessionManager.AttachRemoteIdentity(sessionID, ident)
	if err != nil {
		return "", err
	}
	log.Infof("Session %s remote identity attached: clientID=%s identity=%#v", sessionID, ses.ClientID, ident)

	// Get the connector used to log the user in.
	conn, ok := s.connector(ses.ConnectorID)
	if !ok {
		return "", fmt.Errorf("session contained invalid connector ID (%s)", ses.ConnectorID)
	}

	// If the client has requested access to groups, add them here.
	if ses.Scope.HasScope(scope.ScopeGroups) {
		grouper, ok := conn.(connector.GroupsConnector)
		if !ok {
			return "", fmt.Errorf("scope %q provided but connector does not support groups", scope.ScopeGroups)
		}
		groups, err := grouper.Groups(ident.ID)
		if err != nil {
			return "", fmt.Errorf("failed to retrieve user groups for %q: %v", ident.ID, err)
		}
		// Update the session; keep the underlying error so storage failures
		// are diagnosable.
		if ses, err = s.SessionManager.AttachGroups(sessionID, groups); err != nil {
			return "", fmt.Errorf("failed to save groups: %v", err)
		}
	}

	// The session was started for registration: hand off to the
	// registration page with a fresh code.
	if ses.Register {
		code, err := s.SessionManager.NewSessionKey(sessionID)
		if err != nil {
			return "", err
		}
		ru := s.absURL(httpPathRegister)
		q := ru.Query()
		q.Set("code", code)
		q.Set("state", ses.ClientState)
		ru.RawQuery = q.Encode()
		return ru.String(), nil
	}

	remoteIdentity := user.RemoteIdentity{ConnectorID: ses.ConnectorID, ID: ses.Identity.ID}
	usr, err := s.UserRepo.GetByRemoteIdentity(nil, remoteIdentity)
	if err == user.ErrorNotFound {
		if ses.Identity.Email == "" {
			// User doesn't have an existing account. Ask them to register.
			u := newLoginURLFromSession(s.IssuerURL, ses, true, []string{ses.ConnectorID}, "register-maybe")
			return u.String(), nil
		}

		// Does the user have an existing account with a different connector?
		if connID, err := getConnectorForUserByEmail(s.UserRepo, ses.Identity.Email); err == nil {
			// Ask user to sign in through existing account.
			u := newLoginURLFromSession(s.IssuerURL, ses, false, []string{connID}, "wrong-connector")
			return u.String(), nil
		}

		// RegisterOnFirstLogin doesn't work for the local connector
		tryToRegister := s.RegisterOnFirstLogin && (ses.ConnectorID != s.localConnectorID)
		if !tryToRegister {
			// User doesn't have an existing account. Ask them to register.
			u := newLoginURLFromSession(s.IssuerURL, ses, true, []string{ses.ConnectorID}, "register-maybe")
			return u.String(), nil
		}

		// First time logging in through a remote connector. Attempt to register.
		emailVerified := conn.TrustedEmailProvider()
		usrID, err := s.UserManager.RegisterWithRemoteIdentity(ses.Identity.Email, emailVerified, remoteIdentity)
		if err != nil {
			return "", fmt.Errorf("failed to register user: %v", err)
		}
		usr, err = s.UserManager.Get(usrID)
		if err != nil {
			return "", fmt.Errorf("getting created user: %v", err)
		}
	} else if err != nil {
		return "", fmt.Errorf("getting user: %v", err)
	}

	if usr.Disabled {
		log.Errorf("user %s disabled", ses.Identity.Email)
		return "", user.ErrorNotFound
	}

	ses, err = s.SessionManager.AttachUser(sessionID, usr.ID)
	if err != nil {
		return "", fmt.Errorf("attaching user to session: %v", err)
	}
	log.Infof("Session %s user identified: clientID=%s user=%#v", sessionID, ses.ClientID, usr)

	code, err := s.SessionManager.NewSessionKey(sessionID)
	if err != nil {
		return "", fmt.Errorf("creating new session key: %v", err)
	}

	ru := ses.RedirectURL
	if ru.String() == client.OOBRedirectURI {
		ru = s.absURL(httpPathOOB)
	}
	q := ru.Query()
	q.Set("code", code)
	q.Set("state", ses.ClientState)
	ru.RawQuery = q.Encode()
	return ru.String(), nil
}
// ClientCredsToken issues a signed ID token for a client authenticating with
// its own credentials. Public clients are rejected with invalid_client, as are
// bad credentials. The token expires after the session manager's validity
// window.
func (s *Server) ClientCredsToken(creds oidc.ClientCredentials) (*jose.JWT, error) {
	cli, err := s.Client(creds.ID)
	if err != nil {
		return nil, err
	}
	if cli.Public {
		return nil, oauth2.NewError(oauth2.ErrorInvalidClient)
	}

	authenticated, err := s.ClientManager.Authenticate(creds)
	switch {
	case err != nil:
		log.Errorf("Failed fetching client %s from manager: %v", creds.ID, err)
		return nil, oauth2.NewError(oauth2.ErrorServerError)
	case !authenticated:
		return nil, oauth2.NewError(oauth2.ErrorInvalidClient)
	}

	signer, err := s.KeyManager.Signer()
	if err != nil {
		log.Errorf("Failed to generate ID token: %v", err)
		return nil, oauth2.NewError(oauth2.ErrorServerError)
	}

	issuedAt := time.Now()
	expiresAt := issuedAt.Add(s.SessionManager.ValidityWindow)
	// creds.ID is supplied for both the subject and audience arguments.
	claims := oidc.NewClaims(s.IssuerURL.String(), creds.ID, creds.ID, issuedAt, expiresAt)
	claims.Add("name", creds.ID)

	token, err := jose.NewSignedJWT(claims, signer)
	if err != nil {
		log.Errorf("Failed to generate ID token: %v", err)
		return nil, oauth2.NewError(oauth2.ErrorServerError)
	}

	log.Infof("Client token sent: clientID=%s", creds.ID)
	return token, nil
}
// CodeToken exchanges an authorization code (sessionKey) for a signed ID token.
// A refresh token is also issued when the session's scopes include
// "offline_access"; otherwise the returned refresh token is "".
//
// Every returned error is an oauth2 error:
//   - invalid_client for bad client credentials;
//   - invalid_grant when the code is unknown or belongs to another client;
//   - invalid_request when the session cannot be killed;
//   - the error from addClaimsFromScope when a requested cross-client scope is
//     not permitted;
//   - server_error otherwise.
func (s *Server) CodeToken(creds oidc.ClientCredentials, sessionKey string) (*jose.JWT, string, error) {
	ok, err := s.ClientManager.Authenticate(creds)
	if err != nil {
		log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
		return nil, "", oauth2.NewError(oauth2.ErrorServerError)
	}
	if !ok {
		log.Errorf("Failed to Authenticate client %s", creds.ID)
		return nil, "", oauth2.NewError(oauth2.ErrorInvalidClient)
	}

	sessionID, err := s.SessionManager.ExchangeKey(sessionKey)
	if err != nil {
		return nil, "", oauth2.NewError(oauth2.ErrorInvalidGrant)
	}
	ses, err := s.SessionManager.Kill(sessionID)
	if err != nil {
		return nil, "", oauth2.NewError(oauth2.ErrorInvalidRequest)
	}
	if ses.ClientID != creds.ID {
		return nil, "", oauth2.NewError(oauth2.ErrorInvalidGrant)
	}

	signer, err := s.KeyManager.Signer()
	if err != nil {
		log.Errorf("Failed to generate ID token: %v", err)
		return nil, "", oauth2.NewError(oauth2.ErrorServerError)
	}

	usr, err := s.UserRepo.Get(nil, ses.UserID)
	if err != nil {
		log.Errorf("Failed to fetch user %q from repo: %v: ", ses.UserID, err)
		return nil, "", oauth2.NewError(oauth2.ErrorServerError)
	}

	claims := ses.Claims(s.IssuerURL.String())
	usr.AddToClaims(claims)
	// addClaimsFromScope rejects cross-client scopes the client is not
	// trusted for; that must abort the exchange instead of silently issuing
	// a token.
	if err := s.addClaimsFromScope(claims, ses.Scope, ses.ClientID); err != nil {
		log.Errorf("Failed to add claims from scope for client %s: %v", creds.ID, err)
		return nil, "", err
	}

	jwt, err := jose.NewSignedJWT(claims, signer)
	if err != nil {
		log.Errorf("Failed to generate ID token: %v", err)
		return nil, "", oauth2.NewError(oauth2.ErrorServerError)
	}

	// Generate refresh token when 'scope' contains 'offline_access'.
	var refreshToken string
	for _, sc := range ses.Scope {
		if sc != "offline_access" {
			continue
		}
		log.Infof("Session %s requests offline access, will generate refresh token", sessionID)
		if refreshToken, err = s.RefreshTokenRepo.Create(ses.UserID, creds.ID, ses.ConnectorID, ses.Scope); err != nil {
			log.Errorf("Failed to generate refresh token: %v", err)
			return nil, "", oauth2.NewError(oauth2.ErrorServerError)
		}
		break
	}

	log.Infof("Session %s token sent: clientID=%s", sessionID, creds.ID)
	return jwt, refreshToken, nil
}
// RefreshToken authenticates the client, verifies its refresh token and
// returns a new signed ID token for the token's user.
//
// If scopes is empty, the scopes stored with the refresh token are used.
// Otherwise the request is rejected with invalid_request unless
// rtScopes.Contains(scopes). When the stored scopes include the groups scope,
// the user's groups are re-fetched from the originating connector.
//
// Every returned error is an oauth2 error:
//   - invalid_client for bad credentials or a token issued to another client;
//   - invalid_request for an invalid token or disallowed scopes;
//   - the error from addClaimsFromScope when a cross-client scope is not
//     permitted;
//   - server_error otherwise.
func (s *Server) RefreshToken(creds oidc.ClientCredentials, scopes scope.Scopes, token string) (*jose.JWT, error) {
	ok, err := s.ClientManager.Authenticate(creds)
	if err != nil {
		log.Errorf("Failed fetching client %s from repo: %v", creds.ID, err)
		return nil, oauth2.NewError(oauth2.ErrorServerError)
	}
	if !ok {
		log.Errorf("Failed to Authenticate client %s", creds.ID)
		return nil, oauth2.NewError(oauth2.ErrorInvalidClient)
	}

	userID, connectorID, rtScopes, err := s.RefreshTokenRepo.Verify(creds.ID, token)
	switch err {
	case nil:
	case refresh.ErrorInvalidToken:
		return nil, oauth2.NewError(oauth2.ErrorInvalidRequest)
	case refresh.ErrorInvalidClientID:
		return nil, oauth2.NewError(oauth2.ErrorInvalidClient)
	default:
		log.Errorf("Failed to verify refresh token for client %s: %v", creds.ID, err)
		return nil, oauth2.NewError(oauth2.ErrorServerError)
	}

	if len(scopes) == 0 {
		scopes = rtScopes
	} else if !rtScopes.Contains(scopes) {
		return nil, oauth2.NewError(oauth2.ErrorInvalidRequest)
	}

	usr, err := s.UserRepo.Get(nil, userID)
	if err != nil {
		// The error can be user.ErrorNotFound, but we are not deleting
		// user at this moment, so this shouldn't happen.
		log.Errorf("Failed to fetch user %q from repo: %v: ", userID, err)
		return nil, oauth2.NewError(oauth2.ErrorServerError)
	}

	// NOTE(review): the groups decision uses the scopes stored with the
	// refresh token, not the (possibly narrower) requested scopes.
	var groups []string
	if rtScopes.HasScope(scope.ScopeGroups) {
		conn, ok := s.connector(connectorID)
		if !ok {
			log.Errorf("refresh token contained invalid connector ID (%s)", connectorID)
			return nil, oauth2.NewError(oauth2.ErrorServerError)
		}
		grouper, ok := conn.(connector.GroupsConnector)
		if !ok {
			log.Errorf("refresh token requested groups for connector (%s) that doesn't support groups", connectorID)
			return nil, oauth2.NewError(oauth2.ErrorServerError)
		}

		remoteIdentities, err := s.UserRepo.GetRemoteIdentities(nil, userID)
		if err != nil {
			log.Errorf("failed to get remote identities: %v", err)
			return nil, oauth2.NewError(oauth2.ErrorServerError)
		}
		// Find the user's identity at the connector that issued the token.
		var (
			remoteIdentity user.RemoteIdentity
			found          bool
		)
		for _, ri := range remoteIdentities {
			if ri.ConnectorID == connectorID {
				remoteIdentity, found = ri, true
				break
			}
		}
		if !found {
			log.Errorf("failed to get remote identity for connector %s", connectorID)
			return nil, oauth2.NewError(oauth2.ErrorServerError)
		}

		if groups, err = grouper.Groups(remoteIdentity.ID); err != nil {
			log.Errorf("failed to get groups for refresh token from connector %s: %v", connectorID, err)
			return nil, oauth2.NewError(oauth2.ErrorServerError)
		}
	}

	signer, err := s.KeyManager.Signer()
	if err != nil {
		log.Errorf("Failed to refresh ID token: %v", err)
		return nil, oauth2.NewError(oauth2.ErrorServerError)
	}

	now := time.Now()
	expireAt := now.Add(session.DefaultSessionValidityWindow)
	claims := oidc.NewClaims(s.IssuerURL.String(), usr.ID, creds.ID, now, expireAt)
	usr.AddToClaims(claims)
	if rtScopes.HasScope(scope.ScopeGroups) {
		// Always emit a list (never null) when groups were granted.
		if groups == nil {
			groups = []string{}
		}
		claims["groups"] = groups
	}

	// A disallowed cross-client scope must abort the refresh rather than be
	// silently dropped.
	if err := s.addClaimsFromScope(claims, scopes, creds.ID); err != nil {
		log.Errorf("Failed to add claims from scope for client %s: %v", creds.ID, err)
		return nil, err
	}

	jwt, err := jose.NewSignedJWT(claims, signer)
	if err != nil {
		log.Errorf("Failed to generate ID token: %v", err)
		return nil, oauth2.NewError(oauth2.ErrorServerError)
	}

	log.Infof("New token sent: clientID=%s", creds.ID)
	return jwt, nil
}
// CrossClientAuthAllowed reports whether requestingClientID is listed among the
// trusted peers of authorizingClientID in the ClientRepo.
func (s *Server) CrossClientAuthAllowed(requestingClientID, authorizingClientID string) (bool, error) {
	trusted, err := s.ClientRepo.GetTrustedPeers(nil, authorizingClientID)
	if err != nil {
		return false, err
	}
	allowed := false
	for i := 0; i < len(trusted) && !allowed; i++ {
		allowed = trusted[i] == requestingClientID
	}
	return allowed, nil
}
// JWTVerifierFactory returns a factory for per-client JWT verifiers bound to
// this server's issuer. The verifiers use a no-op sync function and take keys
// from the KeyManager. A key-lookup failure is logged and yields an empty key
// set.
func (s *Server) JWTVerifierFactory() JWTVerifierFactory {
	syncKeys := func() error { return nil }
	publicKeys := func() []key.PublicKey {
		keys, err := s.KeyManager.PublicKeys()
		if err == nil {
			return keys
		}
		log.Errorf("error getting public keys from manager: %v", err)
		return []key.PublicKey{}
	}
	return func(clientID string) oidc.JWTVerifier {
		return oidc.NewJWTVerifier(s.IssuerURL.String(), clientID, syncKeys, publicKeys)
	}
}
// addClaimsFromScope adds claims derived from the scopes the client requested.
// Currently these are the cross-client claims: "aud" is set to the requested
// cross-client IDs (a bare string when there is exactly one), and "azp" is set
// to clientID.
//
// It returns an oauth2 invalid_request error when clientID is not a trusted
// peer of a requested client. It returns server_error when the trust lookup
// fails. claims is left unchanged when no cross-client scope was requested.
func (s *Server) addClaimsFromScope(claims jose.Claims, scopes scope.Scopes, clientID string) error {
	crossClientIDs := scopes.CrossClientIDs()
	if len(crossClientIDs) == 0 {
		return nil
	}

	var aud []string
	for _, id := range crossClientIDs {
		// A client may always name itself; anyone else must trust it.
		if id != clientID {
			allowed, err := s.CrossClientAuthAllowed(clientID, id)
			if err != nil {
				log.Errorf("Failed to check cross client auth. reqClientID %v; authClient:ID %v; err: %v", clientID, id, err)
				return oauth2.NewError(oauth2.ErrorServerError)
			}
			if !allowed {
				denied := oauth2.NewError(oauth2.ErrorInvalidRequest)
				denied.Description = fmt.Sprintf(
					"%q is not authorized to perform cross-client requests for %q",
					clientID, id)
				return denied
			}
		}
		aud = append(aud, id)
	}

	if len(aud) == 1 {
		claims.Add("aud", aud[0])
	} else {
		claims.Add("aud", aud)
	}
	claims.Add("azp", clientID)
	return nil
}
// sortableIDPCs implements sort.Interface over connectors, ordering them by ID.
type sortableIDPCs []connector.Connector

func (s sortableIDPCs) Len() int { return len(s) }

func (s sortableIDPCs) Less(i, j int) bool { return s[i].ID() < s[j].ID() }

func (s sortableIDPCs) Swap(i, j int) { s[i], s[j] = s[j], s[i] }