package server import ( "encoding/json" "errors" "fmt" "math/big" "net/http" "net/http/httptest" "net/url" "reflect" "strings" "testing" "time" "github.com/jonboulle/clockwork" "github.com/coreos/dex/client" "github.com/coreos/dex/connector" "github.com/coreos/dex/session" "github.com/coreos/go-oidc/jose" "github.com/coreos/go-oidc/oauth2" "github.com/coreos/go-oidc/oidc" ) type fakeConnector struct { loginURL string } func (f *fakeConnector) ID() string { return "fake" } func (f *fakeConnector) Healthy() error { return nil } func (f *fakeConnector) LoginURL(sessionKey, prompt string) (string, error) { return f.loginURL, nil } func (f *fakeConnector) Register(mux *http.ServeMux, errorURL url.URL) {} func (f *fakeConnector) Sync() chan struct{} { return nil } func (c *fakeConnector) TrustedEmailProvider() bool { return false } func TestHandleAuthFuncMethodNotAllowed(t *testing.T) { for _, m := range []string{"POST", "PUT", "DELETE"} { hdlr := handleAuthFunc(nil, nil, nil) req, err := http.NewRequest(m, "http://example.com", nil) if err != nil { t.Errorf("case %s: unable to create HTTP request: %v", m, err) continue } w := httptest.NewRecorder() hdlr.ServeHTTP(w, req) want := http.StatusMethodNotAllowed got := w.Code if want != got { t.Errorf("case %s: expected HTTP %d, got %d", m, want, got) } } } func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) { idpcs := []connector.Connector{ &fakeConnector{loginURL: "http://fake.example.com"}, } srv := &Server{ IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, SessionManager: session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()), ClientIdentityRepo: client.NewClientIdentityRepo([]oidc.ClientIdentity{ oidc.ClientIdentity{ Credentials: oidc.ClientCredentials{ ID: "XXX", Secret: "secrete", }, Metadata: oidc.ClientMetadata{ RedirectURLs: []url.URL{ url.URL{Scheme: "http", Host: "client.example.com", Path: "/callback"}, }, }, }, }), } tests := []struct { query url.Values wantCode int wantLocation string }{ // no redirect_uri provided, but client only has one, so it's usable { query: url.Values{ "response_type": []string{"code"}, "client_id": []string{"XXX"}, "connector_id": []string{"fake"}, }, wantCode: http.StatusTemporaryRedirect, wantLocation: "http://fake.example.com", }, // provided redirect_uri matches client { query: url.Values{ "response_type": []string{"code"}, "redirect_uri": []string{"http://client.example.com/callback"}, "client_id": []string{"XXX"}, "connector_id": []string{"fake"}, }, wantCode: http.StatusTemporaryRedirect, wantLocation: "http://fake.example.com", }, // provided redirect_uri does not match client { query: url.Values{ "response_type": []string{"code"}, "redirect_uri": []string{"http://unrecognized.example.com/callback"}, "client_id": []string{"XXX"}, "connector_id": []string{"fake"}, }, wantCode: http.StatusBadRequest, }, // nonexistant client_id { query: url.Values{ "response_type": []string{"code"}, "redirect_uri": []string{"http://client.example.com/callback"}, "client_id": []string{"YYY"}, "connector_id": []string{"fake"}, }, wantCode: http.StatusBadRequest, }, // unsupported response type, redirects back to client { query: url.Values{ "response_type": []string{"token"}, "client_id": []string{"XXX"}, "connector_id": []string{"fake"}, }, wantCode: http.StatusTemporaryRedirect, wantLocation: "http://client.example.com/callback?error=unsupported_response_type&state=", }, } for i, tt := range tests { hdlr := handleAuthFunc(srv, idpcs, nil) w := httptest.NewRecorder() u := fmt.Sprintf("http://server.example.com?%s", tt.query.Encode()) req, err := http.NewRequest("GET", u, nil) if err != nil { t.Errorf("case %d: unable to form HTTP request: %v", i, err) continue } hdlr.ServeHTTP(w, req) if tt.wantCode != w.Code { t.Errorf("case %d: HTTP code mismatch: want=%d got=%d", i, tt.wantCode, w.Code) continue } gotLocation := w.Header().Get("Location") if tt.wantLocation != gotLocation { t.Errorf("case %d: HTTP Location header mismatch: want=%s got=%s", i, tt.wantLocation, gotLocation) } } } func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) { idpcs := []connector.Connector{ &fakeConnector{loginURL: "http://fake.example.com"}, } srv := &Server{ IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"}, SessionManager: session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()), ClientIdentityRepo: client.NewClientIdentityRepo([]oidc.ClientIdentity{ oidc.ClientIdentity{ Credentials: oidc.ClientCredentials{ ID: "XXX", Secret: "secrete", }, Metadata: oidc.ClientMetadata{ RedirectURLs: []url.URL{ url.URL{Scheme: "http", Host: "foo.example.com", Path: "/callback"}, url.URL{Scheme: "http", Host: "bar.example.com", Path: "/callback"}, }, }, }, }), } tests := []struct { query url.Values wantCode int wantLocation string }{ // provided redirect_uri matches client's first { query: url.Values{ "response_type": []string{"code"}, "redirect_uri": []string{"http://foo.example.com/callback"}, "client_id": []string{"XXX"}, "connector_id": []string{"fake"}, }, wantCode: http.StatusTemporaryRedirect, wantLocation: "http://fake.example.com", }, // provided redirect_uri matches client's second { query: url.Values{ "response_type": []string{"code"}, "redirect_uri": []string{"http://bar.example.com/callback"}, "client_id": []string{"XXX"}, "connector_id": []string{"fake"}, }, wantCode: http.StatusTemporaryRedirect, wantLocation: "http://fake.example.com", }, // provided redirect_uri does not match either of client's { query: url.Values{ "response_type": []string{"code"}, "redirect_uri": []string{"http://unrecognized.example.com/callback"}, "client_id": []string{"XXX"}, "connector_id": []string{"fake"}, }, wantCode: http.StatusBadRequest, }, // no redirect_uri provided { query: url.Values{ "response_type": []string{"code"}, "client_id": []string{"XXX"}, "connector_id": []string{"fake"}, }, wantCode: http.StatusBadRequest, }, } for i, tt := range tests { hdlr := handleAuthFunc(srv, idpcs, nil) w := httptest.NewRecorder() u := fmt.Sprintf("http://server.example.com?%s", tt.query.Encode()) req, err := http.NewRequest("GET", u, nil) if err != nil { t.Errorf("case %d: unable to form HTTP request: %v", i, err) continue } hdlr.ServeHTTP(w, req) if tt.wantCode != w.Code { t.Errorf("case %d: HTTP code mismatch: want=%d got=%d", i, tt.wantCode, w.Code) t.Errorf("case %d: BODY: %v", i, w.Body.String()) t.Errorf("case %d: LOCO: %v", i, w.HeaderMap.Get("Location")) continue } gotLocation := w.Header().Get("Location") if tt.wantLocation != gotLocation { t.Errorf("case %d: HTTP Location header mismatch: want=%s got=%s", i, tt.wantLocation, gotLocation) } } } func TestHandleTokenFuncMethodNotAllowed(t *testing.T) { for _, m := range []string{"GET", "PUT", "DELETE"} { hdlr := handleTokenFunc(nil) req, err := http.NewRequest(m, "http://example.com", nil) if err != nil { t.Errorf("case %s: unable to create HTTP request: %v", m, err) continue } w := httptest.NewRecorder() hdlr.ServeHTTP(w, req) want := http.StatusMethodNotAllowed got := w.Code if want != got { t.Errorf("case %s: expected HTTP %d, got %d", m, want, got) } } } func TestHandleTokenFuncState(t *testing.T) { want := "test-state" v := url.Values{ "state": {want}, } hdlr := handleTokenFunc(nil) req, err := http.NewRequest("POST", "http://example.com", strings.NewReader(v.Encode())) if err != nil { t.Errorf("unable to create HTTP request, error=%v", err) } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") w := httptest.NewRecorder() hdlr.ServeHTTP(w, req) // should have errored and returned state in the response body var resp map[string]string if err = json.Unmarshal(w.Body.Bytes(), &resp); err != nil { t.Errorf("error unmarshaling response, error=%v", err) } got := resp["state"] if want != got { t.Errorf("unexpected state, want=%v, got=%v", want, got) } } func TestHandleDiscoveryFuncMethodNotAllowed(t *testing.T) { for _, m := range []string{"POST", "PUT", "DELETE"} { hdlr := handleDiscoveryFunc(oidc.ProviderConfig{}) req, err := http.NewRequest(m, "http://example.com", nil) if err != nil { t.Errorf("case %s: unable to create HTTP request: %v", m, err) continue } w := httptest.NewRecorder() hdlr.ServeHTTP(w, req) want := http.StatusMethodNotAllowed got := w.Code if want != got { t.Errorf("case %s: expected HTTP %d, got %d", m, want, got) } } } func TestHandleDiscoveryFunc(t *testing.T) { u := "http://server.example.com" cfg := oidc.ProviderConfig{ Issuer: u, AuthEndpoint: u + httpPathAuth, TokenEndpoint: u + httpPathToken, KeysEndpoint: u + httpPathKeys, GrantTypesSupported: []string{oauth2.GrantTypeAuthCode}, ResponseTypesSupported: []string{"code"}, SubjectTypesSupported: []string{"public"}, IDTokenAlgValuesSupported: []string{"RS256"}, TokenEndpointAuthMethodsSupported: []string{"client_secret_basic"}, } req, err := http.NewRequest("GET", "http://server.example.com", nil) if err != nil { t.Fatalf("Failed creating HTTP request: err=%v", err) } w := httptest.NewRecorder() hdlr := handleDiscoveryFunc(cfg) hdlr.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Fatalf("Incorrect status code: want=200 got=%d", w.Code) } h := w.Header() if ct := h.Get("Content-Type"); ct != "application/json" { t.Fatalf("Incorrect Content-Type: want=application/json, got %s", ct) } gotCC := h.Get("Cache-Control") wantCC := "public, max-age=86400" if wantCC != gotCC { t.Fatalf("Incorrect Cache-Control header: want=%q, got=%q", wantCC, gotCC) } wantBody := `{"issuer":"http://server.example.com","authorization_endpoint":"http://server.example.com/auth","token_endpoint":"http://server.example.com/token","jwks_uri":"http://server.example.com/keys","response_types_supported":["code"],"grant_types_supported":["authorization_code"],"subject_types_supported":["public"],"id_token_alg_values_supported":["RS256"],"token_endpoint_auth_methods_supported":["client_secret_basic"]}` gotBody := w.Body.String() if wantBody != gotBody { t.Fatalf("Incorrect body: want=%s got=%s", wantBody, gotBody) } } func TestHandleKeysFuncMethodNotAllowed(t *testing.T) { for _, m := range []string{"POST", "PUT", "DELETE"} { hdlr := handleKeysFunc(nil, clockwork.NewRealClock()) req, err := http.NewRequest(m, "http://example.com", nil) if err != nil { t.Errorf("case %s: unable to create HTTP request: %v", m, err) continue } w := httptest.NewRecorder() hdlr.ServeHTTP(w, req) want := http.StatusMethodNotAllowed got := w.Code if want != got { t.Errorf("case %s: expected HTTP %d, got %d", m, want, got) } } } func TestHandleKeysFunc(t *testing.T) { fc := clockwork.NewFakeClock() exp := fc.Now().Add(13 * time.Second) km := &StaticKeyManager{ expiresAt: exp, keys: []jose.JWK{ jose.JWK{ ID: "1234", Type: "RSA", Alg: "RS256", Use: "sig", Exponent: 65537, Modulus: big.NewInt(int64(5716758339926702)), }, jose.JWK{ ID: "5678", Type: "RSA", Alg: "RS256", Use: "sig", Exponent: 65537, Modulus: big.NewInt(int64(1234294715519622)), }, }, } req, err := http.NewRequest("GET", "http://server.example.com", nil) if err != nil { t.Fatalf("Failed creating HTTP request: err=%v", err) } w := httptest.NewRecorder() hdlr := handleKeysFunc(km, fc) hdlr.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Fatalf("Incorrect status code: want=200 got=%d", w.Code) } wantHeader := http.Header{ "Content-Type": []string{"application/json"}, "Cache-Control": []string{"public, max-age=13"}, "Expires": []string{exp.Format(time.RFC1123)}, } gotHeader := w.Header() if !reflect.DeepEqual(wantHeader, gotHeader) { t.Fatalf("Incorrect headers: want=%#v got=%#v", wantHeader, gotHeader) } wantBody := `{"keys":[{"kid":"1234","kty":"RSA","alg":"RS256","use":"sig","e":"AQAB","n":"FE9chh46rg=="},{"kid":"5678","kty":"RSA","alg":"RS256","use":"sig","e":"AQAB","n":"BGKVohEShg=="}]}` gotBody := w.Body.String() if wantBody != gotBody { t.Fatalf("Incorrect body: want=%s got=%s", wantBody, gotBody) } } func TestShouldReprompt(t *testing.T) { tests := []struct { c *http.Cookie v bool }{ // No cookie { c: nil, v: false, }, // different cookie { c: &http.Cookie{ Name: "rando-cookie", }, v: false, }, // actual cookie we care about { c: &http.Cookie{ Name: "LastSeen", }, v: true, }, } for i, tt := range tests { r := &http.Request{Header: make(http.Header)} if tt.c != nil { r.AddCookie(tt.c) } want := tt.v got := shouldReprompt(r) if want != got { t.Errorf("case %d: want=%t, got=%t", i, want, got) } } } type checkable struct { healthy bool } func (c checkable) Healthy() (err error) { if !c.healthy { err = errors.New("im unhealthy") } return }