diff --git a/server/handlers_test.go b/server/handlers_test.go index 3e0b1e81..395b7e72 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -1,7 +1,9 @@ package server import ( + "bytes" "context" + "encoding/json" "errors" "net/http" "net/http/httptest" @@ -48,3 +50,73 @@ func TestHandleHealthFailure(t *testing.T) { t.Errorf("expected 500 got %d", rr.Code) } } + +type emptyStorage struct { + storage.Storage +} + +func (*emptyStorage) GetAuthRequest(string) (storage.AuthRequest, error) { + return storage.AuthRequest{}, storage.ErrNotFound +} + +func TestHandleInvalidOAuth2Callbacks(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + httpServer, server := newTestServer(ctx, t, func(c *Config) { + c.Storage = &emptyStorage{c.Storage} + }) + defer httpServer.Close() + + tests := []struct { + TargetURI string + ExpectedCode int + }{ + {"/callback", http.StatusBadRequest}, + {"/callback?code=&state=", http.StatusBadRequest}, + {"/callback?code=AAAAAAA&state=BBBBBBB", http.StatusBadRequest}, + } + + rr := httptest.NewRecorder() + + for i, r := range tests { + server.ServeHTTP(rr, httptest.NewRequest("GET", r.TargetURI, nil)) + if rr.Code != r.ExpectedCode { + t.Fatalf("test %d expected %d, got %d", i, r.ExpectedCode, rr.Code) + } + } +} + +func TestHandleInvalidSAMLCallbacks(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + httpServer, server := newTestServer(ctx, t, func(c *Config) { + c.Storage = &emptyStorage{c.Storage} + }) + defer httpServer.Close() + + type requestForm struct { + RelayState string + } + tests := []struct { + RequestForm requestForm + ExpectedCode int + }{ + {requestForm{}, http.StatusBadRequest}, + {requestForm{RelayState: "AAAAAAA"}, http.StatusBadRequest}, + } + + rr := httptest.NewRecorder() + + for i, r := range tests { + jsonValue, err := json.Marshal(r.RequestForm) + if err != nil { + t.Fatal(err.Error()) + } + server.ServeHTTP(rr, httptest.NewRequest("POST", "/callback", bytes.NewBuffer(jsonValue))) + if rr.Code != r.ExpectedCode { + t.Fatalf("test %d expected %d, got %d", i, r.ExpectedCode, rr.Code) + } + } +}