forked from mystiq/dex
*: verify "state" field before passing request to callback connectors
Let the server handle the state token instead of the connector. As a result it can throw out bad requests earlier. It can also use that token to determine which connector was used to generate the request allowing all connectors to share the same callback URL. Callbacks now all look like: https://dex.example.com/callback Instead of: https://dex.example.com/callback/(connector id) Even when multiple connectors are being used.
This commit is contained in:
parent
ba9f6c6cd6
commit
a3235d022a
6 changed files with 79 additions and 41 deletions
|
@ -33,7 +33,7 @@ type PasswordConnector interface {
|
||||||
// CallbackConnector is an optional interface for callback based connectors.
|
// CallbackConnector is an optional interface for callback based connectors.
|
||||||
type CallbackConnector interface {
|
type CallbackConnector interface {
|
||||||
LoginURL(callbackURL, state string) (string, error)
|
LoginURL(callbackURL, state string) (string, error)
|
||||||
HandleCallback(r *http.Request) (identity Identity, state string, err error)
|
HandleCallback(r *http.Request) (identity Identity, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GroupsConnector is an optional interface for connectors which can map a user to groups.
|
// GroupsConnector is an optional interface for connectors which can map a user to groups.
|
||||||
|
|
|
@ -84,28 +84,28 @@ func (e *oauth2Error) Error() string {
|
||||||
return e.error + ": " + e.errorDescription
|
return e.error + ": " + e.errorDescription
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *githubConnector) HandleCallback(r *http.Request) (identity connector.Identity, state string, err error) {
|
func (c *githubConnector) HandleCallback(r *http.Request) (identity connector.Identity, err error) {
|
||||||
q := r.URL.Query()
|
q := r.URL.Query()
|
||||||
if errType := q.Get("error"); errType != "" {
|
if errType := q.Get("error"); errType != "" {
|
||||||
return identity, "", &oauth2Error{errType, q.Get("error_description")}
|
return identity, &oauth2Error{errType, q.Get("error_description")}
|
||||||
}
|
}
|
||||||
token, err := c.oauth2Config.Exchange(c.ctx, q.Get("code"))
|
token, err := c.oauth2Config.Exchange(c.ctx, q.Get("code"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return identity, "", fmt.Errorf("github: failed to get token: %v", err)
|
return identity, fmt.Errorf("github: failed to get token: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := c.oauth2Config.Client(c.ctx, token).Get(baseURL + "/user")
|
resp, err := c.oauth2Config.Client(c.ctx, token).Get(baseURL + "/user")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return identity, "", fmt.Errorf("github: get URL %v", err)
|
return identity, fmt.Errorf("github: get URL %v", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
body, err := ioutil.ReadAll(resp.Body)
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return identity, "", fmt.Errorf("github: read body: %v", err)
|
return identity, fmt.Errorf("github: read body: %v", err)
|
||||||
}
|
}
|
||||||
return identity, "", fmt.Errorf("%s: %s", resp.Status, body)
|
return identity, fmt.Errorf("%s: %s", resp.Status, body)
|
||||||
}
|
}
|
||||||
var user struct {
|
var user struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
|
@ -114,13 +114,13 @@ func (c *githubConnector) HandleCallback(r *http.Request) (identity connector.Id
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
}
|
}
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&user); err != nil {
|
if err := json.NewDecoder(resp.Body).Decode(&user); err != nil {
|
||||||
return identity, "", fmt.Errorf("failed to decode response: %v", err)
|
return identity, fmt.Errorf("failed to decode response: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
data := connectorData{AccessToken: token.AccessToken}
|
data := connectorData{AccessToken: token.AccessToken}
|
||||||
connData, err := json.Marshal(data)
|
connData, err := json.Marshal(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return identity, "", fmt.Errorf("marshal connector data: %v", err)
|
return identity, fmt.Errorf("marshal connector data: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
username := user.Name
|
username := user.Name
|
||||||
|
@ -134,7 +134,7 @@ func (c *githubConnector) HandleCallback(r *http.Request) (identity connector.Id
|
||||||
EmailVerified: true,
|
EmailVerified: true,
|
||||||
ConnectorData: connData,
|
ConnectorData: connData,
|
||||||
}
|
}
|
||||||
return identity, q.Get("state"), nil
|
return identity, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *githubConnector) Groups(identity connector.Identity) ([]string, error) {
|
func (c *githubConnector) Groups(identity connector.Identity) ([]string, error) {
|
||||||
|
|
|
@ -41,14 +41,14 @@ func (m callbackConnector) LoginURL(callbackURL, state string) (string, error) {
|
||||||
|
|
||||||
var connectorData = []byte("foobar")
|
var connectorData = []byte("foobar")
|
||||||
|
|
||||||
func (m callbackConnector) HandleCallback(r *http.Request) (connector.Identity, string, error) {
|
func (m callbackConnector) HandleCallback(r *http.Request) (connector.Identity, error) {
|
||||||
return connector.Identity{
|
return connector.Identity{
|
||||||
UserID: "0-385-28089-0",
|
UserID: "0-385-28089-0",
|
||||||
Username: "Kilgore Trout",
|
Username: "Kilgore Trout",
|
||||||
Email: "kilgore@kilgore.trout",
|
Email: "kilgore@kilgore.trout",
|
||||||
EmailVerified: true,
|
EmailVerified: true,
|
||||||
ConnectorData: connectorData,
|
ConnectorData: connectorData,
|
||||||
}, r.URL.Query().Get("state"), nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m callbackConnector) Groups(identity connector.Identity) ([]string, error) {
|
func (m callbackConnector) Groups(identity connector.Identity) ([]string, error) {
|
||||||
|
|
|
@ -95,23 +95,23 @@ func (e *oauth2Error) Error() string {
|
||||||
return e.error + ": " + e.errorDescription
|
return e.error + ": " + e.errorDescription
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *oidcConnector) HandleCallback(r *http.Request) (identity connector.Identity, state string, err error) {
|
func (c *oidcConnector) HandleCallback(r *http.Request) (identity connector.Identity, err error) {
|
||||||
q := r.URL.Query()
|
q := r.URL.Query()
|
||||||
if errType := q.Get("error"); errType != "" {
|
if errType := q.Get("error"); errType != "" {
|
||||||
return identity, "", &oauth2Error{errType, q.Get("error_description")}
|
return identity, &oauth2Error{errType, q.Get("error_description")}
|
||||||
}
|
}
|
||||||
token, err := c.oauth2Config.Exchange(c.ctx, q.Get("code"))
|
token, err := c.oauth2Config.Exchange(c.ctx, q.Get("code"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return identity, "", fmt.Errorf("oidc: failed to get token: %v", err)
|
return identity, fmt.Errorf("oidc: failed to get token: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
rawIDToken, ok := token.Extra("id_token").(string)
|
rawIDToken, ok := token.Extra("id_token").(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return identity, "", errors.New("oidc: no id_token in token response")
|
return identity, errors.New("oidc: no id_token in token response")
|
||||||
}
|
}
|
||||||
idToken, err := c.verifier.Verify(rawIDToken)
|
idToken, err := c.verifier.Verify(rawIDToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return identity, "", fmt.Errorf("oidc: failed to verify ID Token: %v", err)
|
return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var claims struct {
|
var claims struct {
|
||||||
|
@ -120,7 +120,7 @@ func (c *oidcConnector) HandleCallback(r *http.Request) (identity connector.Iden
|
||||||
EmailVerified bool `json:"email_verified"`
|
EmailVerified bool `json:"email_verified"`
|
||||||
}
|
}
|
||||||
if err := idToken.Claims(&claims); err != nil {
|
if err := idToken.Claims(&claims); err != nil {
|
||||||
return identity, "", fmt.Errorf("oidc: failed to decode claims: %v", err)
|
return identity, fmt.Errorf("oidc: failed to decode claims: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
identity = connector.Identity{
|
identity = connector.Identity{
|
||||||
|
@ -129,5 +129,5 @@ func (c *oidcConnector) HandleCallback(r *http.Request) (identity connector.Iden
|
||||||
Email: claims.Email,
|
Email: claims.Email,
|
||||||
EmailVerified: claims.EmailVerified,
|
EmailVerified: claims.EmailVerified,
|
||||||
}
|
}
|
||||||
return identity, q.Get("state"), nil
|
return identity, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,7 +2,6 @@ package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -180,14 +179,26 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
authReqID := r.FormValue("state")
|
||||||
|
|
||||||
// TODO(ericchiang): cache user identity.
|
// TODO(ericchiang): cache user identity.
|
||||||
|
|
||||||
state := r.FormValue("state")
|
|
||||||
switch r.Method {
|
switch r.Method {
|
||||||
case "GET":
|
case "GET":
|
||||||
|
// Set the connector being used for the login.
|
||||||
|
updater := func(a storage.AuthRequest) (storage.AuthRequest, error) {
|
||||||
|
a.ConnectorID = connID
|
||||||
|
return a, nil
|
||||||
|
}
|
||||||
|
if err := s.storage.UpdateAuthRequest(authReqID, updater); err != nil {
|
||||||
|
log.Printf("Failed to set connector ID on auth request: %v", err)
|
||||||
|
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
switch conn := conn.Connector.(type) {
|
switch conn := conn.Connector.(type) {
|
||||||
case connector.CallbackConnector:
|
case connector.CallbackConnector:
|
||||||
callbackURL, err := conn.LoginURL(s.absURL("/callback", connID), state)
|
callbackURL, err := conn.LoginURL(s.absURL("/callback"), authReqID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Connector %q returned error when creating callback: %v", connID, err)
|
log.Printf("Connector %q returned error when creating callback: %v", connID, err)
|
||||||
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
||||||
|
@ -195,7 +206,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
http.Redirect(w, r, callbackURL, http.StatusFound)
|
http.Redirect(w, r, callbackURL, http.StatusFound)
|
||||||
case connector.PasswordConnector:
|
case connector.PasswordConnector:
|
||||||
s.templates.password(w, state, r.URL.String(), "", false)
|
s.templates.password(w, authReqID, r.URL.String(), "", false)
|
||||||
default:
|
default:
|
||||||
s.notFound(w, r)
|
s.notFound(w, r)
|
||||||
}
|
}
|
||||||
|
@ -216,10 +227,16 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !ok {
|
if !ok {
|
||||||
s.templates.password(w, state, r.URL.String(), username, true)
|
s.templates.password(w, authReqID, r.URL.String(), username, true)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
redirectURL, err := s.finalizeLogin(identity, state, connID, conn.Connector)
|
authReq, err := s.storage.GetAuthRequest(authReqID)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to get auth request: %v", err)
|
||||||
|
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
redirectURL, err := s.finalizeLogin(identity, authReq, conn.Connector)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to finalize login: %v", err)
|
log.Printf("Failed to finalize login: %v", err)
|
||||||
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
||||||
|
@ -233,8 +250,31 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
connID := mux.Vars(r)["connector"]
|
// SAML redirect bindings use the "RelayState" URL query field. When we support
|
||||||
conn, ok := s.connectors[connID]
|
// SAML, we'll have to check that field too and possibly let callback connectors
|
||||||
|
// indicate which field is used to determine the state.
|
||||||
|
//
|
||||||
|
// See:
|
||||||
|
// https://docs.oasis-open.org/security/saml/v2.0/saml-bindings-2.0-os.pdf
|
||||||
|
// Section: "3.4.3 RelayState"
|
||||||
|
state := r.URL.Query().Get("state")
|
||||||
|
if state == "" {
|
||||||
|
s.renderError(w, http.StatusBadRequest, errInvalidRequest, "no 'state' parameter provided")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
authReq, err := s.storage.GetAuthRequest(state)
|
||||||
|
if err != nil {
|
||||||
|
if err == storage.ErrNotFound {
|
||||||
|
s.renderError(w, http.StatusBadRequest, errInvalidRequest, "invalid 'state' parameter provided")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("Failed to get auth request: %v", err)
|
||||||
|
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, ok := s.connectors[authReq.ConnectorID]
|
||||||
if !ok {
|
if !ok {
|
||||||
s.notFound(w, r)
|
s.notFound(w, r)
|
||||||
return
|
return
|
||||||
|
@ -245,14 +285,14 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
identity, state, err := callbackConnector.HandleCallback(r)
|
identity, err := callbackConnector.HandleCallback(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to authenticate: %v", err)
|
log.Printf("Failed to authenticate: %v", err)
|
||||||
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
redirectURL, err := s.finalizeLogin(identity, state, connID, conn.Connector)
|
redirectURL, err := s.finalizeLogin(identity, authReq, conn.Connector)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to finalize login: %v", err)
|
log.Printf("Failed to finalize login: %v", err)
|
||||||
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
||||||
|
@ -262,10 +302,11 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request)
|
||||||
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
|
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) finalizeLogin(identity connector.Identity, authReqID, connectorID string, conn connector.Connector) (string, error) {
|
func (s *Server) finalizeLogin(identity connector.Identity, authReq storage.AuthRequest, conn connector.Connector) (string, error) {
|
||||||
if authReqID == "" {
|
if authReq.ConnectorID == "" {
|
||||||
return "", errors.New("no auth request ID passed")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
claims := storage.Claims{
|
claims := storage.Claims{
|
||||||
UserID: identity.UserID,
|
UserID: identity.UserID,
|
||||||
Username: identity.Username,
|
Username: identity.Username,
|
||||||
|
@ -275,10 +316,6 @@ func (s *Server) finalizeLogin(identity connector.Identity, authReqID, connector
|
||||||
|
|
||||||
groupsConn, ok := conn.(connector.GroupsConnector)
|
groupsConn, ok := conn.(connector.GroupsConnector)
|
||||||
if ok {
|
if ok {
|
||||||
authReq, err := s.storage.GetAuthRequest(authReqID)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("get auth request: %v", err)
|
|
||||||
}
|
|
||||||
reqGroups := func() bool {
|
reqGroups := func() bool {
|
||||||
for _, scope := range authReq.Scopes {
|
for _, scope := range authReq.Scopes {
|
||||||
if scope == scopeGroups {
|
if scope == scopeGroups {
|
||||||
|
@ -288,23 +325,24 @@ func (s *Server) finalizeLogin(identity connector.Identity, authReqID, connector
|
||||||
return false
|
return false
|
||||||
}()
|
}()
|
||||||
if reqGroups {
|
if reqGroups {
|
||||||
if claims.Groups, err = groupsConn.Groups(identity); err != nil {
|
groups, err := groupsConn.Groups(identity)
|
||||||
|
if err != nil {
|
||||||
return "", fmt.Errorf("getting groups: %v", err)
|
return "", fmt.Errorf("getting groups: %v", err)
|
||||||
}
|
}
|
||||||
|
claims.Groups = groups
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
updater := func(a storage.AuthRequest) (storage.AuthRequest, error) {
|
updater := func(a storage.AuthRequest) (storage.AuthRequest, error) {
|
||||||
a.LoggedIn = true
|
a.LoggedIn = true
|
||||||
a.Claims = claims
|
a.Claims = claims
|
||||||
a.ConnectorID = connectorID
|
|
||||||
a.ConnectorData = identity.ConnectorData
|
a.ConnectorData = identity.ConnectorData
|
||||||
return a, nil
|
return a, nil
|
||||||
}
|
}
|
||||||
if err := s.storage.UpdateAuthRequest(authReqID, updater); err != nil {
|
if err := s.storage.UpdateAuthRequest(authReq.ID, updater); err != nil {
|
||||||
return "", fmt.Errorf("failed to update auth request: %v", err)
|
return "", fmt.Errorf("failed to update auth request: %v", err)
|
||||||
}
|
}
|
||||||
return path.Join(s.issuerURL.Path, "/approval") + "?state=" + authReqID, nil
|
return path.Join(s.issuerURL.Path, "/approval") + "?state=" + authReq.ID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
|
@ -172,7 +172,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
|
||||||
handleFunc("/keys", s.handlePublicKeys)
|
handleFunc("/keys", s.handlePublicKeys)
|
||||||
handleFunc("/auth", s.handleAuthorization)
|
handleFunc("/auth", s.handleAuthorization)
|
||||||
handleFunc("/auth/{connector}", s.handleConnectorLogin)
|
handleFunc("/auth/{connector}", s.handleConnectorLogin)
|
||||||
handleFunc("/callback/{connector}", s.handleConnectorCallback)
|
handleFunc("/callback", s.handleConnectorCallback)
|
||||||
handleFunc("/approval", s.handleApproval)
|
handleFunc("/approval", s.handleApproval)
|
||||||
handleFunc("/healthz", s.handleHealth)
|
handleFunc("/healthz", s.handleHealth)
|
||||||
s.mux = r
|
s.mux = r
|
||||||
|
|
Loading…
Reference in a new issue