Merge pull request #2468 from flant/cwe-79-device-code

fix: prevent cross-site scripting for the device flow
This commit is contained in:
Márk Sági-Kazár 2022-06-30 22:52:33 +03:00 committed by GitHub
commit 1cc26fab2f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 31 additions and 8 deletions

View File

@ -11,6 +11,8 @@ import (
"strings"
"time"
"golang.org/x/net/html"
"github.com/dexidp/dex/pkg/log"
"github.com/dexidp/dex/storage"
)
@ -251,7 +253,9 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
// Authorization redirect callback from OAuth2 auth flow.
if errMsg := r.FormValue("error"); errMsg != "" {
http.Error(w, errMsg+": "+r.FormValue("error_description"), http.StatusBadRequest)
// escape the message to prevent cross-site scripting
msg := html.EscapeString(errMsg + ": " + r.FormValue("error_description"))
http.Error(w, msg, http.StatusBadRequest)
return
}

View File

@ -169,6 +169,7 @@ func TestDeviceCallback(t *testing.T) {
tests := []struct {
testName string
expectedResponseCode int
expectedServerResponse string
values formValues
testAuthCode storage.AuthCode
testDeviceRequest storage.DeviceRequest
@ -200,6 +201,7 @@ func TestDeviceCallback(t *testing.T) {
error: "Error Condition",
},
expectedResponseCode: http.StatusBadRequest,
expectedServerResponse: "Error Condition: \n",
},
{
testName: "Expired Auth Code",
@ -321,6 +323,16 @@ func TestDeviceCallback(t *testing.T) {
testDeviceToken: baseDeviceToken,
expectedResponseCode: http.StatusOK,
},
{
testName: "Prevent cross-site scripting",
values: formValues{
state: "XXXX-XXXX",
code: "somecode",
error: "<script>console.log(window);</script>",
},
expectedResponseCode: http.StatusBadRequest,
expectedServerResponse: "&lt;script&gt;console.log(window);&lt;/script&gt;: \n",
},
}
for _, tc := range tests {
t.Run(tc.testName, func(t *testing.T) {
@ -373,6 +385,13 @@ func TestDeviceCallback(t *testing.T) {
if rr.Code != tc.expectedResponseCode {
t.Errorf("%s: Unexpected Response Type. Expected %v got %v", tc.testName, tc.expectedResponseCode, rr.Code)
}
if len(tc.expectedServerResponse) > 0 {
result, _ := io.ReadAll(rr.Body)
if string(result) != tc.expectedServerResponse {
t.Errorf("%s: Unexpected Response. Expected %q got %q", tc.testName, tc.expectedServerResponse, result)
}
}
})
}
}