package server

import (
	"encoding/base64"
	"fmt"
	"net/http"
	"net/http/httptest"
	"net/url"
	"sort"
	"strings"
	"testing"

	"github.com/coreos/go-oidc/oidc"
	"github.com/kylelemons/godebug/pretty"

	"github.com/coreos/dex/client"
	"github.com/coreos/dex/connector"
	"github.com/coreos/dex/scope"
)

// makeCrossClientTestFixtures builds test fixtures that include three extra
// clients used to exercise cross-client authorization:
//
//   - client_a lists no trusted peers,
//   - client_b lists client_a as a trusted peer,
//   - client_c lists client_a and client_b as trusted peers.
//
// Each client's secret is the URL-safe base64 encoding of "<id>_secret", and
// its single redirect URI uses the client ID as both host and path. The
// package-level testClients are appended after the cross-client entries.
func makeCrossClientTestFixtures() (*testFixtures, error) {
	peers := []struct {
		id           string
		trustedPeers []string
	}{
		{
			id: "client_a",
		}, {
			id:           "client_b",
			trustedPeers: []string{"client_a"},
		}, {
			id:           "client_c",
			trustedPeers: []string{"client_a", "client_b"},
		},
	}

	xClients := make([]client.LoadableClient, 0, len(peers)+len(testClients))
	for _, cliData := range peers {
		u := url.URL{
			// Scheme is the bare scheme name; URL.String adds the "://"
			// separator itself, so including it here yields a malformed URI.
			Scheme: "https",
			Host:   cliData.id,
			Path:   cliData.id,
		}
		secret := base64.URLEncoding.EncodeToString([]byte(cliData.id + "_secret"))
		xClients = append(xClients, client.LoadableClient{
			Client: client.Client{
				Credentials: oidc.ClientCredentials{
					ID:     cliData.id,
					Secret: secret,
				},
				Metadata: oidc.ClientMetadata{
					RedirectURIs: []url.URL{u},
				},
			},
			TrustedPeers: cliData.trustedPeers,
		})
	}

	xClients = append(xClients, testClients...)
	f, err := makeTestFixturesWithOptions(testFixtureOptions{
		clients: xClients,
	})
	if err != nil {
		return nil, fmt.Errorf("couldn't make test fixtures: %v", err)
	}
	return f, nil
}

// TestServerCrossClientAuthAllowed calls f.srv.CrossClientAuthAllowed for
// pairs of requesting/authorizing clients and compares the answer with the
// trusted-peer lists configured by makeCrossClientTestFixtures.
func TestServerCrossClientAuthAllowed(t *testing.T) {
	fixtures, err := makeCrossClientTestFixtures()
	if err != nil {
		t.Fatalf("couldn't make test fixtures: %v", err)
	}

	type authCase struct {
		reqClient      string
		authClient     string
		wantAuthorized bool
		wantErr        bool
	}
	cases := []authCase{
		// client_a lists no trusted peers, so client_b is refused.
		{reqClient: "client_b", authClient: "client_a", wantAuthorized: false, wantErr: false},
		// client_b lists client_a as a trusted peer.
		{reqClient: "client_a", authClient: "client_b", wantAuthorized: true, wantErr: false},
		// client_c lists client_a as a trusted peer.
		{reqClient: "client_a", authClient: "client_c", wantAuthorized: true, wantErr: false},
		// client_b does not list client_c as a trusted peer.
		{reqClient: "client_c", authClient: "client_b", wantAuthorized: false, wantErr: false},
		// "nope" is not one of the cross-client fixtures; the expectation is
		// "not authorized" without an error.
		{reqClient: "client_c", authClient: "nope", wantErr: false},
	}

	for n, tc := range cases {
		allowed, err := fixtures.srv.CrossClientAuthAllowed(tc.reqClient, tc.authClient)

		switch {
		case tc.wantErr && err == nil:
			t.Errorf("case %d: want non-nil err", n)
			continue
		case tc.wantErr:
			// The expected error arrived; nothing else to verify.
			continue
		case err != nil:
			t.Errorf("case %d: unexpected err %v: ", n, err)
		}

		if allowed != tc.wantAuthorized {
			t.Errorf("case %d: want=%v, got=%v", n, tc.wantAuthorized, allowed)
		}
	}
}

// TestHandleAuthCrossClient sends authorization requests carrying
// cross-client scopes to the auth handler and checks the resulting HTTP
// status: http.StatusFound when every requested peer lists the requesting
// client as trusted (per makeCrossClientTestFixtures), and
// http.StatusBadRequest otherwise.
func TestHandleAuthCrossClient(t *testing.T) {
	f, err := makeCrossClientTestFixtures()
	if err != nil {
		t.Fatalf("couldn't make test fixtures: %v", err)
	}

	tests := []struct {
		scopes   []string
		clientID string
		wantCode int
	}{
		{
			// client_a lists no trusted peers, so client_b is rejected.
			scopes:   []string{scope.ScopeGoogleCrossClient + "client_a"},
			clientID: "client_b",
			wantCode: http.StatusBadRequest,
		},
		{
			// client_b lists client_a as a trusted peer.
			scopes:   []string{scope.ScopeGoogleCrossClient + "client_b"},
			clientID: "client_a",
			wantCode: http.StatusFound,
		},
		{
			// client_c lists client_a as a trusted peer.
			scopes:   []string{scope.ScopeGoogleCrossClient + "client_c"},
			clientID: "client_a",
			wantCode: http.StatusFound,
		},
		{
			// Two clients that client_a is authorized to mint tokens for.
			scopes: []string{
				scope.ScopeGoogleCrossClient + "client_c",
				scope.ScopeGoogleCrossClient + "client_b",
			},
			clientID: "client_a",
			wantCode: http.StatusFound,
		},
		{
			// client_b is a trusted peer of client_c but not of client_a, so
			// the request as a whole must be rejected.
			scopes: []string{
				scope.ScopeGoogleCrossClient + "client_c",
				scope.ScopeGoogleCrossClient + "client_a",
			},
			clientID: "client_b",
			wantCode: http.StatusBadRequest,
		},
	}

	idpcs := []connector.Connector{
		&fakeConnector{loginURL: "http://fake.example.com"},
	}

	// The handler's inputs are identical for every case, so build it once.
	hdlr := handleAuthFunc(f.srv, idpcs, nil, true)

	for i, tt := range tests {
		w := httptest.NewRecorder()

		// "openid" is always requested alongside the cross-client scopes.
		scopes := append([]string{"openid"}, tt.scopes...)
		query := url.Values{
			"response_type": []string{"code"},
			"client_id":     []string{tt.clientID},
			"connector_id":  []string{"fake"},
			"scope":         []string{strings.Join(scopes, " ")},
		}
		u := fmt.Sprintf("http://server.example.com?%s", query.Encode())
		req, err := http.NewRequest("GET", u, nil)
		if err != nil {
			t.Errorf("case %d: unable to form HTTP request: %v", i, err)
			continue
		}

		hdlr.ServeHTTP(w, req)
		if tt.wantCode != w.Code {
			t.Errorf("case %d: HTTP code mismatch: want=%d got=%d", i, tt.wantCode, w.Code)
		}
	}
}

// TestServerCodeTokenCrossClient exchanges a session key for tokens via
// f.srv.CodeToken and verifies the returned refresh token plus the 'aud' and
// 'azp' claims of the JWT, both for ordinary requests and for requests that
// carry cross-client scopes.
func TestServerCodeTokenCrossClient(t *testing.T) {
	f, err := makeCrossClientTestFixtures()
	if err != nil {
		t.Fatalf("Error creating test fixtures: %v", err)
	}
	sm := f.sessionManager

	tests := []struct {
		clientID     string
		offline      bool
		refreshToken string
		crossClients []string

		// wantAUD must be listed in sorted order: the audiences read from the
		// token are sorted before the comparison.
		wantAUD []string
		wantAZP string
	}{
		// First test the non-cross-client cases, make sure they're undisturbed:
		{
			// No 'offline_access' in scope, should get empty refresh token.
			clientID:     testClientID,
			refreshToken: "",

			wantAUD: []string{testClientID},
		},
		{
			// Have 'offline_access' in scope, should get non-empty refresh token.
			// NOTE(review): the exact value presumably relies on the fixtures
			// generating refresh tokens deterministically — confirm.
			clientID:     testClientID,
			offline:      true,
			refreshToken: fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),

			wantAUD: []string{testClientID},
		},
		// Now test cross-client cases:
		{
			clientID:     "client_a",
			crossClients: []string{"client_b"},

			wantAUD: []string{"client_b"},
			wantAZP: "client_a",
		},
		{
			clientID:     "client_a",
			crossClients: []string{"client_b", "client_a"},

			wantAUD: []string{"client_a", "client_b"},
			wantAZP: "client_a",
		},
	}

	for i, tt := range tests {
		scopes := []string{"openid"}
		if tt.offline {
			scopes = append(scopes, "offline_access")
		}
		// Named peerID rather than client to avoid shadowing the imported
		// client package.
		for _, peerID := range tt.crossClients {
			scopes = append(scopes, scope.ScopeGoogleCrossClient+peerID)
		}

		sessionID, err := sm.NewSession("bogus_idpc", tt.clientID, "bogus", url.URL{}, "", false, scopes)
		if err != nil {
			t.Fatalf("case %d: unexpected error: %v", i, err)
		}
		_, err = sm.AttachRemoteIdentity(sessionID, oidc.Identity{})
		if err != nil {
			t.Fatalf("case %d: unexpected error: %v", i, err)
		}

		_, err = sm.AttachUser(sessionID, "ID-1")
		if err != nil {
			t.Fatalf("case %d: unexpected error: %v", i, err)
		}

		key, err := sm.NewSessionKey(sessionID)
		if err != nil {
			t.Fatalf("case %d: unexpected error: %v", i, err)
		}

		jwt, token, err := f.srv.CodeToken(f.clientCreds[tt.clientID], key)
		if err != nil {
			t.Fatalf("case %d: unexpected error: %v", i, err)
		}
		if jwt == nil {
			t.Fatalf("case %d: expect non-nil jwt", i)
		}
		if token != tt.refreshToken {
			t.Errorf("case %d: expect refresh token %q, got %q", i, tt.refreshToken, token)
		}

		claims, err := jwt.Claims()
		if err != nil {
			t.Fatalf("case %d: unexpected error getting claims: %v", i, err)
		}

		// 'aud' is read as a single string when at most one audience is
		// expected and as a list of strings otherwise.
		var gotAUD []string
		if len(tt.wantAUD) < 2 {
			aud, _, err := claims.StringClaim("aud")
			if err != nil {
				t.Fatalf("case %d: unexpected error getting 'aud': %q: raw: %v", i, err, claims["aud"])
			}
			gotAUD = []string{aud}
		} else {
			gotAUD, _, err = claims.StringsClaim("aud")
			if err != nil {
				t.Fatalf("case %d: unexpected error getting 'aud': %v", i, err)
			}
		}

		sort.Strings(gotAUD)
		if diff := pretty.Compare(tt.wantAUD, gotAUD); diff != "" {
			t.Fatalf("case %d: pretty.Compare(tt.wantAUD, gotAUD): %v", i, diff)
		}

		gotAZP, _, err := claims.StringClaim("azp")
		if err != nil {
			t.Fatalf("case %d: unexpected error getting 'azp': %v", i, err)
		}

		if gotAZP != tt.wantAZP {
			t.Errorf("case %d: wantAZP=%v, gotAZP=%v", i, tt.wantAZP, gotAZP)
		}
	}
}