forked from mystiq/dex
dc828825e6
Instead of showing a cryptic message with nowhere to go, give users the choice to log in with that account or register.
557 lines
15 KiB
Go
557 lines
15 KiB
Go
package server
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"html/template"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/coreos/go-oidc/jose"
|
|
"github.com/coreos/go-oidc/key"
|
|
"github.com/coreos/go-oidc/oauth2"
|
|
"github.com/coreos/go-oidc/oidc"
|
|
"github.com/coreos/pkg/health"
|
|
"github.com/jonboulle/clockwork"
|
|
|
|
"github.com/coreos/dex/client"
|
|
"github.com/coreos/dex/connector"
|
|
phttp "github.com/coreos/dex/pkg/http"
|
|
"github.com/coreos/dex/pkg/log"
|
|
)
|
|
|
|
// Cache lifetimes used by the HTTP handlers in this file.
const (
	// lastSeenMaxAge is the lifetime of the LastSeen cookie built by
	// createLastSeenCookie.
	lastSeenMaxAge = 5 * time.Minute

	// discoveryMaxAge is the public max-age advertised for the discovery
	// document.
	discoveryMaxAge = 24 * time.Hour
)
|
|
|
|
// URL paths of the server's HTTP endpoints. NOTE(review): presumably these are
// registered with the mux elsewhere in the package — only httpPathAuth and
// httpPathRegister are referenced in this file.
var (
	httpPathDiscovery         = "/.well-known/openid-configuration"
	httpPathToken             = "/token"
	httpPathKeys              = "/keys"
	httpPathAuth              = "/auth"
	httpPathHealth            = "/health"
	httpPathAPI               = "/api"
	httpPathRegister          = "/register"
	httpPathEmailVerify       = "/verify-email"
	httpPathVerifyEmailResend = "/resend-verify-email"
	httpPathSendResetPassword = "/send-reset-password"
	httpPathResetPassword     = "/reset-password"
	httpPathAcceptInvitation  = "/accept-invitation"
	httpPathDebugVars         = "/debug/vars"
)

// Names of cookies read and written by the handlers.
var (
	cookieLastSeen                 = "LastSeen"
	cookieShowEmailVerifiedMessage = "ShowEmailVerifiedMessage"
)
|
|
|
|
func handleDiscoveryFunc(cfg oidc.ProviderConfig) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != "GET" {
|
|
w.Header().Set("Allow", "GET")
|
|
phttp.WriteError(w, http.StatusMethodNotAllowed, "GET only acceptable method")
|
|
return
|
|
}
|
|
|
|
b, err := json.Marshal(cfg)
|
|
if err != nil {
|
|
log.Errorf("Unable to marshal %#v to JSON: %v", cfg, err)
|
|
}
|
|
|
|
w.Header().Set("Cache-Control", fmt.Sprintf("public, max-age=%d", int(discoveryMaxAge.Seconds())))
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.Write(b)
|
|
}
|
|
}
|
|
|
|
func handleKeysFunc(km key.PrivateKeyManager, clock clockwork.Clock) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != "GET" {
|
|
w.Header().Set("Allow", "GET")
|
|
phttp.WriteError(w, http.StatusMethodNotAllowed, "GET only acceptable method")
|
|
return
|
|
}
|
|
|
|
jwks, err := km.JWKs()
|
|
if err != nil {
|
|
log.Errorf("Failed to get JWKs while serving HTTP request: %v", err)
|
|
phttp.WriteError(w, http.StatusInternalServerError, "")
|
|
return
|
|
}
|
|
|
|
keys := struct {
|
|
Keys []jose.JWK `json:"keys"`
|
|
}{
|
|
Keys: jwks,
|
|
}
|
|
|
|
b, err := json.Marshal(keys)
|
|
if err != nil {
|
|
log.Errorf("Unable to marshal signing key to JSON: %v", err)
|
|
}
|
|
|
|
exp := km.ExpiresAt()
|
|
w.Header().Set("Expires", exp.Format(time.RFC1123))
|
|
|
|
ttl := int(exp.Sub(clock.Now()).Seconds())
|
|
w.Header().Set("Cache-Control", fmt.Sprintf("public, max-age=%d", ttl))
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write(b)
|
|
}
|
|
}
|
|
|
|
// Link is one connector entry offered to the user on the login page; see
// renderLoginPage, which fills Links in templateData with these.
type Link struct {
	// URL starts the auth flow with this connector: the auth path plus the
	// current query with connector_id and response_type=code set.
	URL string
	// ID is the connector's ID.
	ID string
	// DisplayName is the label shown for the connector (from
	// connectorDisplayNameMap, falling back to the ID).
	DisplayName string
}
|
|
|
|
// templateData is the data renderLoginPage passes to the login page template.
type templateData struct {
	// Error is true when the page shows an error instead of the connector list.
	Error bool
	// Message is the headline, e.g. "Error", "Authentication Error" or
	// "Server Error".
	Message string
	// Instruction tells the user what to do next.
	Instruction string
	// Detail carries extra error context, e.g. the remote IdP's error text.
	Detail string
	// Register is true when the page is in registration rather than login mode.
	Register bool
	// RegisterOrLoginURL points at the same page with the "register" query
	// parameter toggled.
	RegisterOrLoginURL string
	// MsgCode is copied from the "msg_code" query parameter. NOTE(review):
	// presumably the template maps it to a message — confirm.
	MsgCode string
	// ShowEmailVerifiedMessage is true when the request carried the
	// ShowEmailVerifiedMessage cookie (which is consumed when read).
	ShowEmailVerifiedMessage bool
	// Links holds one entry per connector offered on the page.
	Links []Link
}
|
|
|
|
// connectorDisplayNameMap maps a connector ID to the label shown for it on the
// login page. IDs not listed here are displayed verbatim.
//
// TODO(sym3tri): store this with the connector config
var connectorDisplayNameMap = map[string]string{
	"bitbucket": "Bitbucket",
	"github":    "GitHub",
	"google":    "Google",
	"local":     "Email",
}
|
|
|
|
// Template is the rendering contract needed by execTemplate and
// execTemplateWithStatus. *html/template.Template satisfies it.
type Template interface {
	// Execute renders the template with data and writes the result to w.
	Execute(w io.Writer, data interface{}) error
}
|
|
|
|
func execTemplate(w http.ResponseWriter, tpl Template, data interface{}) {
|
|
execTemplateWithStatus(w, tpl, data, http.StatusOK)
|
|
}
|
|
|
|
func execTemplateWithStatus(w http.ResponseWriter, tpl Template, data interface{}, status int) {
|
|
w.WriteHeader(status)
|
|
if err := tpl.Execute(w, data); err != nil {
|
|
log.Errorf("Error loading page: %q", err)
|
|
phttp.WriteError(w, http.StatusInternalServerError, "error loading page")
|
|
return
|
|
}
|
|
}
|
|
|
|
func renderLoginPage(w http.ResponseWriter, r *http.Request, srv OIDCServer, idpcs []connector.Connector, register bool, tpl *template.Template) {
|
|
if tpl == nil {
|
|
phttp.WriteError(w, http.StatusInternalServerError, "error loading login page")
|
|
return
|
|
}
|
|
|
|
td := templateData{
|
|
Message: "Error",
|
|
Instruction: "Please try again or contact the system administrator",
|
|
Register: register,
|
|
ShowEmailVerifiedMessage: consumeShowEmailVerifiedCookie(r, w),
|
|
}
|
|
|
|
// Render error if remote IdP connector errored and redirected here.
|
|
q := r.URL.Query()
|
|
e := q.Get("error")
|
|
connectorID := q.Get("connector_id")
|
|
if e != "" {
|
|
td.Error = true
|
|
td.Message = "Authentication Error"
|
|
remoteMsg := q.Get("error_description")
|
|
if remoteMsg == "" {
|
|
remoteMsg = q.Get("error")
|
|
}
|
|
if connectorID == "" {
|
|
td.Detail = remoteMsg
|
|
} else {
|
|
td.Detail = fmt.Sprintf("Error from %s: %s.", connectorID, remoteMsg)
|
|
}
|
|
execTemplate(w, tpl, td)
|
|
return
|
|
}
|
|
|
|
if q.Get("msg_code") != "" {
|
|
td.MsgCode = q.Get("msg_code")
|
|
}
|
|
|
|
// Render error message if client id is invalid.
|
|
clientID := q.Get("client_id")
|
|
cm, err := srv.ClientMetadata(clientID)
|
|
if err != nil {
|
|
log.Errorf("Failed fetching client %q from repo: %v", clientID, err)
|
|
td.Error = true
|
|
td.Message = "Server Error"
|
|
execTemplate(w, tpl, td)
|
|
return
|
|
}
|
|
if cm == nil {
|
|
td.Error = true
|
|
td.Message = "Authentication Error"
|
|
td.Detail = "Invalid client ID"
|
|
execTemplate(w, tpl, td)
|
|
return
|
|
}
|
|
|
|
if len(idpcs) == 0 {
|
|
td.Error = true
|
|
td.Message = "Server Error"
|
|
td.Instruction = "Unable to authenticate users at this time"
|
|
td.Detail = "Authentication service may be misconfigured"
|
|
execTemplate(w, tpl, td)
|
|
return
|
|
}
|
|
|
|
link := *r.URL
|
|
linkParams := link.Query()
|
|
if !register {
|
|
linkParams.Set("register", "1")
|
|
} else {
|
|
linkParams.Del("register")
|
|
}
|
|
linkParams.Del("msg_code")
|
|
linkParams.Del("show_connectors")
|
|
link.RawQuery = linkParams.Encode()
|
|
td.RegisterOrLoginURL = link.String()
|
|
|
|
var showConnectors map[string]struct{}
|
|
|
|
// Only show the following connectors, if param is present
|
|
if q.Get("show_connectors") != "" {
|
|
conns := strings.Split(q.Get("show_connectors"), ",")
|
|
if len(conns) != 0 {
|
|
showConnectors = make(map[string]struct{})
|
|
for _, connID := range conns {
|
|
showConnectors[connID] = struct{}{}
|
|
}
|
|
}
|
|
}
|
|
|
|
for _, idpc := range idpcs {
|
|
id := idpc.ID()
|
|
if showConnectors != nil {
|
|
if _, ok := showConnectors[id]; !ok {
|
|
continue
|
|
}
|
|
}
|
|
var link Link
|
|
link.ID = id
|
|
|
|
displayName, ok := connectorDisplayNameMap[id]
|
|
if !ok {
|
|
displayName = id
|
|
}
|
|
link.DisplayName = displayName
|
|
|
|
v := r.URL.Query()
|
|
v.Set("connector_id", idpc.ID())
|
|
v.Set("response_type", "code")
|
|
link.URL = httpPathAuth + "?" + v.Encode()
|
|
td.Links = append(td.Links, link)
|
|
}
|
|
|
|
execTemplate(w, tpl, td)
|
|
}
|
|
|
|
func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.Template, registrationEnabled bool) http.HandlerFunc {
|
|
idx := makeConnectorMap(idpcs)
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != "GET" {
|
|
w.Header().Set("Allow", "GET")
|
|
phttp.WriteError(w, http.StatusMethodNotAllowed, "GET only acceptable method")
|
|
return
|
|
}
|
|
|
|
q := r.URL.Query()
|
|
register := q.Get("register") == "1" && registrationEnabled
|
|
e := q.Get("error")
|
|
if e != "" {
|
|
sessionKey := q.Get("state")
|
|
if err := srv.KillSession(sessionKey); err != nil {
|
|
log.Errorf("Failed killing sessionKey %q: %v", sessionKey, err)
|
|
}
|
|
renderLoginPage(w, r, srv, idpcs, register, tpl)
|
|
return
|
|
}
|
|
|
|
connectorID := q.Get("connector_id")
|
|
idpc, ok := idx[connectorID]
|
|
if !ok {
|
|
renderLoginPage(w, r, srv, idpcs, register, tpl)
|
|
return
|
|
}
|
|
|
|
acr, err := oauth2.ParseAuthCodeRequest(q)
|
|
if err != nil {
|
|
log.Errorf("Invalid auth request: %v", err)
|
|
writeAuthError(w, err, acr.State)
|
|
return
|
|
}
|
|
|
|
cm, err := srv.ClientMetadata(acr.ClientID)
|
|
if err != nil {
|
|
log.Errorf("Failed fetching client %q from repo: %v", acr.ClientID, err)
|
|
writeAuthError(w, oauth2.NewError(oauth2.ErrorServerError), acr.State)
|
|
return
|
|
}
|
|
if cm == nil {
|
|
log.Errorf("Client %q not found", acr.ClientID)
|
|
writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State)
|
|
return
|
|
}
|
|
|
|
if len(cm.RedirectURLs) == 0 {
|
|
log.Errorf("Client %q has no redirect URLs", acr.ClientID)
|
|
writeAuthError(w, oauth2.NewError(oauth2.ErrorServerError), acr.State)
|
|
return
|
|
}
|
|
|
|
redirectURL, err := client.ValidRedirectURL(acr.RedirectURL, cm.RedirectURLs)
|
|
if err != nil {
|
|
switch err {
|
|
case (client.ErrorCantChooseRedirectURL):
|
|
log.Errorf("Request must provide redirect URL as client %q has registered many", acr.ClientID)
|
|
writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State)
|
|
return
|
|
case (client.ErrorInvalidRedirectURL):
|
|
log.Errorf("Request provided unregistered redirect URL: %s", acr.RedirectURL)
|
|
writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State)
|
|
return
|
|
case (client.ErrorNoValidRedirectURLs):
|
|
log.Errorf("There are no registered URLs for the requested client: %s", acr.RedirectURL)
|
|
writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State)
|
|
return
|
|
}
|
|
}
|
|
|
|
if acr.ResponseType != oauth2.ResponseTypeCode {
|
|
log.Errorf("unexpected ResponseType: %v: ", acr.ResponseType)
|
|
redirectAuthError(w, oauth2.NewError(oauth2.ErrorUnsupportedResponseType), acr.State, redirectURL)
|
|
return
|
|
}
|
|
|
|
// Check scopes.
|
|
var scopes []string
|
|
foundOpenIDScope := false
|
|
for _, scope := range acr.Scope {
|
|
switch scope {
|
|
case "openid":
|
|
foundOpenIDScope = true
|
|
scopes = append(scopes, scope)
|
|
case "offline_access":
|
|
// According to the spec, for offline_access scope, the client must
|
|
// use a response_type value that would result in an Authorization Code.
|
|
// Currently oauth2.ResponseTypeCode is the only supported response type,
|
|
// and it's been checked above, so we don't need to check it again here.
|
|
//
|
|
// TODO(yifan): Verify that 'consent' should be in 'prompt'.
|
|
scopes = append(scopes, scope)
|
|
default:
|
|
// Pass all other scopes.
|
|
scopes = append(scopes, scope)
|
|
}
|
|
}
|
|
|
|
if !foundOpenIDScope {
|
|
log.Errorf("Invalid auth request: missing 'openid' in 'scope'")
|
|
writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State)
|
|
return
|
|
}
|
|
|
|
nonce := q.Get("nonce")
|
|
|
|
key, err := srv.NewSession(connectorID, acr.ClientID, acr.State, redirectURL, nonce, register, acr.Scope)
|
|
if err != nil {
|
|
log.Errorf("Error creating new session: %v: ", err)
|
|
redirectAuthError(w, err, acr.State, redirectURL)
|
|
return
|
|
}
|
|
|
|
if register {
|
|
_, ok := idpc.(*connector.LocalConnector)
|
|
if ok {
|
|
q := url.Values{}
|
|
q.Set("code", key)
|
|
ru := httpPathRegister + "?" + q.Encode()
|
|
w.Header().Set("Location", ru)
|
|
w.WriteHeader(http.StatusTemporaryRedirect)
|
|
return
|
|
}
|
|
}
|
|
|
|
var p string
|
|
if register {
|
|
p = "select_account consent"
|
|
}
|
|
if shouldReprompt(r) || register {
|
|
p = "select_account"
|
|
}
|
|
lu, err := idpc.LoginURL(key, p)
|
|
if err != nil {
|
|
log.Errorf("Connector.LoginURL failed: %v", err)
|
|
redirectAuthError(w, err, acr.State, redirectURL)
|
|
return
|
|
}
|
|
|
|
http.SetCookie(w, createLastSeenCookie())
|
|
w.Header().Set("Location", lu)
|
|
w.WriteHeader(http.StatusTemporaryRedirect)
|
|
return
|
|
}
|
|
}
|
|
|
|
func handleTokenFunc(srv OIDCServer) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != "POST" {
|
|
w.Header().Set("Allow", "POST")
|
|
phttp.WriteError(w, http.StatusMethodNotAllowed, fmt.Sprintf("POST only acceptable method"))
|
|
return
|
|
}
|
|
|
|
err := r.ParseForm()
|
|
if err != nil {
|
|
log.Errorf("error parsing request: %v", err)
|
|
writeTokenError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), "")
|
|
return
|
|
}
|
|
|
|
state := r.PostForm.Get("state")
|
|
|
|
user, password, ok := r.BasicAuth()
|
|
if !ok {
|
|
log.Errorf("error parsing basic auth")
|
|
writeTokenError(w, oauth2.NewError(oauth2.ErrorInvalidClient), state)
|
|
return
|
|
}
|
|
|
|
creds := oidc.ClientCredentials{ID: user, Secret: password}
|
|
|
|
var jwt *jose.JWT
|
|
var refreshToken string
|
|
grantType := r.PostForm.Get("grant_type")
|
|
|
|
switch grantType {
|
|
case oauth2.GrantTypeAuthCode:
|
|
code := r.PostForm.Get("code")
|
|
if code == "" {
|
|
log.Errorf("missing code param")
|
|
writeTokenError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state)
|
|
return
|
|
}
|
|
jwt, refreshToken, err = srv.CodeToken(creds, code)
|
|
if err != nil {
|
|
log.Errorf("couldn't exchange code for token: %v", err)
|
|
writeTokenError(w, err, state)
|
|
return
|
|
}
|
|
case oauth2.GrantTypeClientCreds:
|
|
jwt, err = srv.ClientCredsToken(creds)
|
|
if err != nil {
|
|
log.Errorf("couldn't creds for token: %v", err)
|
|
writeTokenError(w, err, state)
|
|
return
|
|
}
|
|
case oauth2.GrantTypeRefreshToken:
|
|
token := r.PostForm.Get("refresh_token")
|
|
if token == "" {
|
|
writeTokenError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state)
|
|
return
|
|
}
|
|
jwt, err = srv.RefreshToken(creds, token)
|
|
if err != nil {
|
|
writeTokenError(w, err, state)
|
|
return
|
|
}
|
|
default:
|
|
log.Errorf("unsupported grant: %v", grantType)
|
|
writeTokenError(w, oauth2.NewError(oauth2.ErrorUnsupportedGrantType), state)
|
|
return
|
|
}
|
|
|
|
t := oAuth2Token{
|
|
AccessToken: jwt.Encode(),
|
|
IDToken: jwt.Encode(),
|
|
TokenType: "bearer",
|
|
RefreshToken: refreshToken,
|
|
}
|
|
|
|
b, err := json.Marshal(t)
|
|
if err != nil {
|
|
log.Errorf("Failed marshaling %#v to JSON: %v", t, err)
|
|
writeTokenError(w, oauth2.NewError(oauth2.ErrorServerError), state)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write(b)
|
|
}
|
|
}
|
|
|
|
func makeHealthHandler(checks []health.Checkable) http.Handler {
|
|
return health.Checker{
|
|
Checks: checks,
|
|
}
|
|
}
|
|
|
|
// oAuth2Token is the JSON body written by the token endpoint (handleTokenFunc).
// Field order is significant: it determines the key order in the marshaled JSON.
type oAuth2Token struct {
	AccessToken string `json:"access_token"`
	IDToken string `json:"id_token"`
	TokenType string `json:"token_type"`
	// RefreshToken is omitted from the JSON when empty.
	RefreshToken string `json:"refresh_token,omitempty"`
}
|
|
|
|
func createLastSeenCookie() *http.Cookie {
|
|
now := time.Now()
|
|
return &http.Cookie{
|
|
HttpOnly: true,
|
|
Name: cookieLastSeen,
|
|
MaxAge: int(lastSeenMaxAge.Seconds()),
|
|
// For old IE, ignored by most browsers.
|
|
Expires: now.Add(lastSeenMaxAge),
|
|
}
|
|
}
|
|
|
|
// shouldReprompt determines if user should be re-prompted for login based on existence of a cookie.
|
|
func shouldReprompt(r *http.Request) bool {
|
|
_, err := r.Cookie(cookieLastSeen)
|
|
if err == nil {
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
func consumeShowEmailVerifiedCookie(r *http.Request, w http.ResponseWriter) bool {
|
|
_, err := r.Cookie(cookieShowEmailVerifiedMessage)
|
|
if err == nil {
|
|
deleteCookie(w, cookieShowEmailVerifiedMessage)
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
// deleteCookie tells the client to drop the cookie called name by sending it
// back already expired. It sets a negative MaxAge (serialized by net/http as
// "Max-Age=0") and an Expires date in the past.
//
// NOTE(review): no Path is set, so this only removes a cookie whose path matches
// the default path of the current request — confirm against where it is set.
func deleteCookie(w http.ResponseWriter, name string) {
	expired := &http.Cookie{Name: name, MaxAge: -100}
	expired.Expires = time.Now().Add(-100 * time.Second)
	http.SetCookie(w, expired)
}
|
|
|
|
func makeConnectorMap(idpcs []connector.Connector) map[string]connector.Connector {
|
|
idx := make(map[string]connector.Connector, len(idpcs))
|
|
for _, idpc := range idpcs {
|
|
idx[idpc.ID()] = idpc
|
|
}
|
|
return idx
|
|
}
|