diff --git a/server/deviceflowhandlers.go b/server/deviceflowhandlers.go index f8462902..0efe3b2b 100644 --- a/server/deviceflowhandlers.go +++ b/server/deviceflowhandlers.go @@ -11,6 +11,8 @@ import ( "strings" "time" + "golang.org/x/net/html" + "github.com/dexidp/dex/pkg/log" "github.com/dexidp/dex/storage" ) @@ -251,7 +253,9 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) { // Authorization redirect callback from OAuth2 auth flow. if errMsg := r.FormValue("error"); errMsg != "" { - http.Error(w, errMsg+": "+r.FormValue("error_description"), http.StatusBadRequest) + // escape the message to prevent cross-site scripting + msg := html.EscapeString(errMsg + ": " + r.FormValue("error_description")) + http.Error(w, msg, http.StatusBadRequest) return } diff --git a/server/deviceflowhandlers_test.go b/server/deviceflowhandlers_test.go index 95ca46e0..225703a4 100644 --- a/server/deviceflowhandlers_test.go +++ b/server/deviceflowhandlers_test.go @@ -167,12 +167,13 @@ func TestDeviceCallback(t *testing.T) { } tests := []struct { - testName string - expectedResponseCode int - values formValues - testAuthCode storage.AuthCode - testDeviceRequest storage.DeviceRequest - testDeviceToken storage.DeviceToken + testName string + expectedResponseCode int + expectedServerResponse string + values formValues + testAuthCode storage.AuthCode + testDeviceRequest storage.DeviceRequest + testDeviceToken storage.DeviceToken }{ { testName: "Missing State", @@ -199,7 +200,8 @@ func TestDeviceCallback(t *testing.T) { code: "somecode", error: "Error Condition", }, - expectedResponseCode: http.StatusBadRequest, + expectedResponseCode: http.StatusBadRequest, + expectedServerResponse: "Error Condition: \n", }, { testName: "Expired Auth Code", @@ -321,6 +323,16 @@ func TestDeviceCallback(t *testing.T) { testDeviceToken: baseDeviceToken, expectedResponseCode: http.StatusOK, }, + { + testName: "Prevent cross-site scripting", + values: formValues{ + state: "XXXX-XXXX", + code: "somecode", + error: "", + }, + expectedResponseCode: http.StatusBadRequest, + expectedServerResponse: "<script>console.log(window);</script>: \n", + }, } for _, tc := range tests { t.Run(tc.testName, func(t *testing.T) { @@ -373,6 +385,13 @@ func TestDeviceCallback(t *testing.T) { if rr.Code != tc.expectedResponseCode { t.Errorf("%s: Unexpected Response Type. Expected %v got %v", tc.testName, tc.expectedResponseCode, rr.Code) } + + if len(tc.expectedServerResponse) > 0 { + result, _ := io.ReadAll(rr.Body) + if string(result) != tc.expectedServerResponse { + t.Errorf("%s: Unexpected Response. Expected %q got %q", tc.testName, tc.expectedServerResponse, result) + } + } }) } }