From 3d5a3befb40124b0fd939a7dab6589c64c2dc7a4 Mon Sep 17 00:00:00 2001 From: "m.nabokikh" Date: Mon, 11 Apr 2022 14:49:47 +0400 Subject: [PATCH] fix: prevent cross-site scripting for the device flow Signed-off-by: m.nabokikh --- server/deviceflowhandlers.go | 6 +++++- server/deviceflowhandlers_test.go | 33 ++++++++++++++++++++++++------- 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/server/deviceflowhandlers.go b/server/deviceflowhandlers.go index fb73f257..7f401874 100644 --- a/server/deviceflowhandlers.go +++ b/server/deviceflowhandlers.go @@ -11,6 +11,8 @@ import ( "strings" "time" + "golang.org/x/net/html" + "github.com/dexidp/dex/pkg/log" "github.com/dexidp/dex/storage" ) @@ -247,7 +249,9 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) { // Authorization redirect callback from OAuth2 auth flow. if errMsg := r.FormValue("error"); errMsg != "" { - http.Error(w, errMsg+": "+r.FormValue("error_description"), http.StatusBadRequest) + // escape the message to prevent cross-site scripting + msg := html.EscapeString(errMsg + ": " + r.FormValue("error_description")) + http.Error(w, msg, http.StatusBadRequest) return } diff --git a/server/deviceflowhandlers_test.go b/server/deviceflowhandlers_test.go index c387af43..a1db1231 100644 --- a/server/deviceflowhandlers_test.go +++ b/server/deviceflowhandlers_test.go @@ -160,12 +160,13 @@ func TestDeviceCallback(t *testing.T) { } tests := []struct { - testName string - expectedResponseCode int - values formValues - testAuthCode storage.AuthCode - testDeviceRequest storage.DeviceRequest - testDeviceToken storage.DeviceToken + testName string + expectedResponseCode int + expectedServerResponse string + values formValues + testAuthCode storage.AuthCode + testDeviceRequest storage.DeviceRequest + testDeviceToken storage.DeviceToken }{ { testName: "Missing State", @@ -192,7 +193,8 @@ func TestDeviceCallback(t *testing.T) { code: "somecode", error: "Error Condition", }, - expectedResponseCode: http.StatusBadRequest, + expectedResponseCode: http.StatusBadRequest, + expectedServerResponse: "Error Condition: \n", }, { testName: "Expired Auth Code", @@ -314,6 +316,16 @@ func TestDeviceCallback(t *testing.T) { testDeviceToken: baseDeviceToken, expectedResponseCode: http.StatusOK, }, + { + testName: "Prevent cross-site scripting", + values: formValues{ + state: "XXXX-XXXX", + code: "somecode", + error: "", + }, + expectedResponseCode: http.StatusBadRequest, + expectedServerResponse: "<script>console.log(window);</script>: \n", + }, } for _, tc := range tests { t.Run(tc.testName, func(t *testing.T) { @@ -366,6 +378,13 @@ func TestDeviceCallback(t *testing.T) { if rr.Code != tc.expectedResponseCode { t.Errorf("%s: Unexpected Response Type. Expected %v got %v", tc.testName, tc.expectedResponseCode, rr.Code) } + + if len(tc.expectedServerResponse) > 0 { + result, _ := io.ReadAll(rr.Body) + if string(result) != tc.expectedServerResponse { + t.Errorf("%s: Unexpected Response. Expected %q got %q", tc.testName, tc.expectedServerResponse, result) + } + } }) } }