forked from mystiq/dex
Fix response_type missing param
This commit fix problem with response_type param, which is required according to OIDC spec, when it is missing. At now, when connector_id url query param is not set, connector view use response_type that client request instead of default "code". Fixes #370
This commit is contained in:
parent
d660dbea8a
commit
821b242c83
2 changed files with 100 additions and 45 deletions
122
server/http.go
122
server/http.go
|
@ -255,7 +255,7 @@ func renderLoginPage(w http.ResponseWriter, r *http.Request, srv OIDCServer, idp
|
||||||
|
|
||||||
v := r.URL.Query()
|
v := r.URL.Query()
|
||||||
v.Set("connector_id", idpc.ID())
|
v.Set("connector_id", idpc.ID())
|
||||||
v.Set("response_type", "code")
|
v.Set("response_type", q.Get("response_type"))
|
||||||
link.URL = httpPathAuth + "?" + v.Encode()
|
link.URL = httpPathAuth + "?" + v.Encode()
|
||||||
td.Links = append(td.Links, link)
|
td.Links = append(td.Links, link)
|
||||||
}
|
}
|
||||||
|
@ -273,77 +273,92 @@ func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.T
|
||||||
}
|
}
|
||||||
|
|
||||||
q := r.URL.Query()
|
q := r.URL.Query()
|
||||||
register := q.Get("register") == "1" && registrationEnabled
|
|
||||||
e := q.Get("error")
|
// Retrieve client id
|
||||||
if e != "" {
|
clientid := q.Get("client_id")
|
||||||
sessionKey := q.Get("state")
|
|
||||||
if err := srv.KillSession(sessionKey); err != nil {
|
// Retrieve state
|
||||||
log.Errorf("Failed killing sessionKey %q: %v", sessionKey, err)
|
state := q.Get("state")
|
||||||
|
|
||||||
|
// Retrieve response_type
|
||||||
|
responseType := q.Get("response_type")
|
||||||
|
|
||||||
|
// Retrieve scopes
|
||||||
|
qscope := strings.Fields(q.Get("scope"))
|
||||||
|
|
||||||
|
// Check client ID param
|
||||||
|
if clientid == "" {
|
||||||
|
log.Errorf("Invalid auth request: no client_id received")
|
||||||
|
writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check redirect_uri param, but if it's empty we don't return any error here
|
||||||
|
qru := q.Get("redirect_uri")
|
||||||
|
var rURL *url.URL
|
||||||
|
if qru != "" {
|
||||||
|
ru, err := url.Parse(qru)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Invalid auth request: %v", err)
|
||||||
|
writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
renderLoginPage(w, r, srv, idpcs, register, tpl)
|
rURL = ru
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
connectorID := q.Get("connector_id")
|
cm, err := srv.ClientMetadata(clientid)
|
||||||
idpc, ok := idx[connectorID]
|
|
||||||
if !ok {
|
|
||||||
renderLoginPage(w, r, srv, idpcs, register, tpl)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
acr, err := oauth2.ParseAuthCodeRequest(q)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Invalid auth request: %v", err)
|
log.Errorf("Failed fetching client %q from repo: %v", clientid, err)
|
||||||
writeAuthError(w, err, acr.State)
|
writeAuthError(w, oauth2.NewError(oauth2.ErrorServerError), state)
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
cm, err := srv.ClientMetadata(acr.ClientID)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Failed fetching client %q from repo: %v", acr.ClientID, err)
|
|
||||||
writeAuthError(w, oauth2.NewError(oauth2.ErrorServerError), acr.State)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if cm == nil {
|
if cm == nil {
|
||||||
log.Errorf("Client %q not found", acr.ClientID)
|
log.Errorf("Client %q not found", clientid)
|
||||||
writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State)
|
writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(cm.RedirectURIs) == 0 {
|
if len(cm.RedirectURIs) == 0 {
|
||||||
log.Errorf("Client %q has no redirect URLs", acr.ClientID)
|
log.Errorf("Client %q has no redirect URLs", clientid)
|
||||||
writeAuthError(w, oauth2.NewError(oauth2.ErrorServerError), acr.State)
|
writeAuthError(w, oauth2.NewError(oauth2.ErrorServerError), state)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
redirectURL, err := client.ValidRedirectURL(acr.RedirectURL, cm.RedirectURIs)
|
redirectURL, err := client.ValidRedirectURL(rURL, cm.RedirectURIs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch err {
|
switch err {
|
||||||
case (client.ErrorCantChooseRedirectURL):
|
case (client.ErrorCantChooseRedirectURL):
|
||||||
log.Errorf("Request must provide redirect URL as client %q has registered many", acr.ClientID)
|
log.Errorf("Request must provide redirect URL as client %q has registered many", clientid)
|
||||||
writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State)
|
writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state)
|
||||||
return
|
return
|
||||||
case (client.ErrorInvalidRedirectURL):
|
case (client.ErrorInvalidRedirectURL):
|
||||||
log.Errorf("Request provided unregistered redirect URL: %s", acr.RedirectURL)
|
log.Errorf("Request provided unregistered redirect URL: %s", rURL)
|
||||||
writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State)
|
writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state)
|
||||||
return
|
return
|
||||||
case (client.ErrorNoValidRedirectURLs):
|
case (client.ErrorNoValidRedirectURLs):
|
||||||
log.Errorf("There are no registered URLs for the requested client: %s", acr.RedirectURL)
|
log.Errorf("There are no registered URLs for the requested client: %s", rURL)
|
||||||
writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State)
|
writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state)
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
log.Errorf("Unexpected error checking redirect URL for client %q: %v", clientid, err)
|
||||||
|
writeAuthError(w, oauth2.NewError(oauth2.ErrorServerError), state)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if acr.ResponseType != oauth2.ResponseTypeCode {
|
// Response type check
|
||||||
log.Errorf("unexpected ResponseType: %v: ", acr.ResponseType)
|
switch responseType {
|
||||||
redirectAuthError(w, oauth2.NewError(oauth2.ErrorUnsupportedResponseType), acr.State, redirectURL)
|
case "code": // Add more cases as we support more response types
|
||||||
|
default:
|
||||||
|
log.Errorf("Invalid auth request: unsupported response_type")
|
||||||
|
redirectAuthError(w, oauth2.NewError(oauth2.ErrorUnsupportedResponseType), state, redirectURL)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check scopes.
|
// Check scopes.
|
||||||
var scopes []string
|
var scopes []string
|
||||||
foundOpenIDScope := false
|
foundOpenIDScope := false
|
||||||
for _, scope := range acr.Scope {
|
for _, scope := range qscope {
|
||||||
switch scope {
|
switch scope {
|
||||||
case "openid":
|
case "openid":
|
||||||
foundOpenIDScope = true
|
foundOpenIDScope = true
|
||||||
|
@ -364,16 +379,33 @@ func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.T
|
||||||
|
|
||||||
if !foundOpenIDScope {
|
if !foundOpenIDScope {
|
||||||
log.Errorf("Invalid auth request: missing 'openid' in 'scope'")
|
log.Errorf("Invalid auth request: missing 'openid' in 'scope'")
|
||||||
writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), acr.State)
|
writeAuthError(w, oauth2.NewError(oauth2.ErrorInvalidRequest), state)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
register := q.Get("register") == "1" && registrationEnabled
|
||||||
|
e := q.Get("error")
|
||||||
|
if e != "" {
|
||||||
|
if err := srv.KillSession(state); err != nil {
|
||||||
|
log.Errorf("Failed killing sessionKey %q: %v", state, err)
|
||||||
|
}
|
||||||
|
renderLoginPage(w, r, srv, idpcs, register, tpl)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
connectorID := q.Get("connector_id")
|
||||||
|
idpc, ok := idx[connectorID]
|
||||||
|
if !ok {
|
||||||
|
renderLoginPage(w, r, srv, idpcs, register, tpl)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
nonce := q.Get("nonce")
|
nonce := q.Get("nonce")
|
||||||
|
|
||||||
key, err := srv.NewSession(connectorID, acr.ClientID, acr.State, redirectURL, nonce, register, acr.Scope)
|
key, err := srv.NewSession(connectorID, clientid, state, redirectURL, nonce, register, qscope)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Error creating new session: %v: ", err)
|
log.Errorf("Error creating new session: %v: ", err)
|
||||||
redirectAuthError(w, err, acr.State, redirectURL)
|
redirectAuthError(w, err, state, redirectURL)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -399,7 +431,7 @@ func handleAuthFunc(srv OIDCServer, idpcs []connector.Connector, tpl *template.T
|
||||||
lu, err := idpc.LoginURL(key, p)
|
lu, err := idpc.LoginURL(key, p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Connector.LoginURL failed: %v", err)
|
log.Errorf("Connector.LoginURL failed: %v", err)
|
||||||
redirectAuthError(w, err, acr.State, redirectURL)
|
redirectAuthError(w, err, state, redirectURL)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -175,6 +175,29 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
|
||||||
},
|
},
|
||||||
wantCode: http.StatusBadRequest,
|
wantCode: http.StatusBadRequest,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
// empty response_type
|
||||||
|
{
|
||||||
|
query: url.Values{
|
||||||
|
"redirect_uri": []string{"http://client.example.com/callback"},
|
||||||
|
"client_id": []string{"XXX"},
|
||||||
|
"connector_id": []string{"fake"},
|
||||||
|
"scope": []string{"openid"},
|
||||||
|
},
|
||||||
|
wantCode: http.StatusFound,
|
||||||
|
wantLocation: "http://client.example.com/callback?error=unsupported_response_type&state=",
|
||||||
|
},
|
||||||
|
|
||||||
|
// empty client_id
|
||||||
|
{
|
||||||
|
query: url.Values{
|
||||||
|
"response_type": []string{"code"},
|
||||||
|
"redirect_uri": []string{"http://unrecognized.example.com/callback"},
|
||||||
|
"connector_id": []string{"fake"},
|
||||||
|
"scope": []string{"openid"},
|
||||||
|
},
|
||||||
|
wantCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
|
|
Loading…
Reference in a new issue