diff --git a/server/server_test.go b/server/server_test.go index b438279c..7a4eb0c1 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -413,6 +413,7 @@ func TestOAuth2CodeFlow(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() + // Setup a dex server. logger := &logrus.Logger{ Out: os.Stderr, Formatter: &logrus.TextFormatter{DisableColors: true}, @@ -422,7 +423,9 @@ func TestOAuth2CodeFlow(t *testing.T) { c.Issuer = c.Issuer + "/non-root-path" c.Now = now c.IDTokensValidFor = idTokensValidFor - // Create a new mock callback connector for each test case. + + // Testing connector that redirects without interaction with + // the user. conn = mock.NewCallbackConnector(logger).(*mock.Callback) c.Connectors = []Connector{ { @@ -434,14 +437,17 @@ func TestOAuth2CodeFlow(t *testing.T) { }) defer httpServer.Close() + // Query server's provider metadata. p, err := oidc.NewProvider(ctx, httpServer.URL) if err != nil { t.Fatalf("failed to get provider: %v", err) } var ( - reqDump, respDump []byte + // If the OAuth2 client didn't get a response, we need + // to print the requests the user saw. gotCode bool + reqDump, respDump []byte // Auth step, not token. state = "a_state" ) defer func() { @@ -450,46 +456,56 @@ func TestOAuth2CodeFlow(t *testing.T) { } }() + // Setup OAuth2 client. var oauth2Config *oauth2.Config - oauth2Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/callback" { - q := r.URL.Query() - if errType := q.Get("error"); errType != "" { - if desc := q.Get("error_description"); desc != "" { - t.Errorf("got error from server %s: %s", errType, desc) - } else { - t.Errorf("got error from server %s", errType) - } - w.WriteHeader(http.StatusInternalServerError) - return - } - - if code := q.Get("code"); code != "" { - gotCode = true - token, err := oauth2Config.Exchange(ctx, code) - if err != nil { - t.Errorf("failed to exchange code for token: %v", err) - return - } - err = tc.handleToken(ctx, p, oauth2Config, token) - if err != nil { - t.Errorf("%s: %v", tc.name, err) - } - return - - } - if gotState := q.Get("state"); gotState != state { - t.Errorf("state did not match, want=%q got=%q", state, gotState) - } - w.WriteHeader(http.StatusOK) + oauth2Client := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/callback" { + // User is visiting app first time. Redirect to dex. + http.Redirect(w, r, oauth2Config.AuthCodeURL(state), http.StatusSeeOther) return } - http.Redirect(w, r, oauth2Config.AuthCodeURL(state), http.StatusSeeOther) + + // User is at '/callback' so they were just redirected _from_ dex. + q := r.URL.Query() + + // Did dex return an error? + if errType := q.Get("error"); errType != "" { + if desc := q.Get("error_description"); desc != "" { + t.Errorf("got error from server %s: %s", errType, desc) + } else { + t.Errorf("got error from server %s", errType) + } + w.WriteHeader(http.StatusInternalServerError) + return + } + + // Grab code, exchange for token. + if code := q.Get("code"); code != "" { + gotCode = true + token, err := oauth2Config.Exchange(ctx, code) + if err != nil { + t.Errorf("failed to exchange code for token: %v", err) + return + } + err = tc.handleToken(ctx, p, oauth2Config, token) + if err != nil { + t.Errorf("%s: %v", tc.name, err) + } + return + } + + // Ensure state matches. + if gotState := q.Get("state"); gotState != state { + t.Errorf("state did not match, want=%q got=%q", state, gotState) + } + w.WriteHeader(http.StatusOK) + return })) - defer oauth2Server.Close() + defer oauth2Client.Close() - redirectURL := oauth2Server.URL + "/callback" + // Regester the client above with dex. + redirectURL := oauth2Client.URL + "/callback" client := storage.Client{ ID: clientID, Secret: clientSecret, @@ -499,6 +515,7 @@ func TestOAuth2CodeFlow(t *testing.T) { t.Fatalf("failed to create client: %v", err) } + // Create the OAuth2 config. oauth2Config = &oauth2.Config{ ClientID: client.ID, ClientSecret: client.Secret, @@ -510,7 +527,14 @@ func TestOAuth2CodeFlow(t *testing.T) { oauth2Config.Scopes = tc.scopes } - resp, err := http.Get(oauth2Server.URL + "/login") + // Login! + // + // 1. First request to client, redirects to dex. + // 2. Dex "logs in" the user, redirects to client with "code". + // 3. Client exchanges "code" for "token" (id_token, refresh_token, etc.). + // 4. Test is run with OAuth2 token response. + // + resp, err := http.Get(oauth2Client.URL + "/login") if err != nil { t.Fatalf("get failed: %v", err) }