package server

import (
	"errors"
	"net/http"
	"net/http/httptest"
	"net/url"
	"reflect"
	"testing"

	"github.com/coreos/go-oidc/oauth2"
)

// TestWriteAPIError drives writeAPIError with API errors, a generic error,
// a nil error and a zero status code, and checks the status and body
// captured by the response recorder against the expected values.
func TestWriteAPIError(t *testing.T) {
	// Body expected whenever the error carries no usable error type.
	const fallbackBody = `{"error":"server_error"}`

	cases := []struct {
		err        error
		status     int
		wantStatus int
		wantBody   string
	}{
		// API error with both a type and a description
		{
			err:        newAPIError(errorInvalidRequest, "foo"),
			status:     http.StatusBadRequest,
			wantStatus: http.StatusBadRequest,
			wantBody:   `{"error":"invalid_request","error_description":"foo"}`,
		},
		// API error without a description
		{
			err:        newAPIError(errorInvalidRequest, ""),
			status:     http.StatusBadRequest,
			wantStatus: http.StatusBadRequest,
			wantBody:   `{"error":"invalid_request"}`,
		},
		// API error without a type
		{
			err:        newAPIError("", ""),
			status:     http.StatusBadRequest,
			wantStatus: http.StatusBadRequest,
			wantBody:   fallbackBody,
		},
		// plain error that is not an API error
		{
			err:        errors.New("generic failure"),
			status:     http.StatusTeapot,
			wantStatus: http.StatusTeapot,
			wantBody:   fallbackBody,
		},
		// nil error
		{
			err:        nil,
			status:     http.StatusTeapot,
			wantStatus: http.StatusTeapot,
			wantBody:   fallbackBody,
		},
		// zero status code: a 500 is expected instead
		{
			err:        nil,
			status:     0,
			wantStatus: http.StatusInternalServerError,
			wantBody:   fallbackBody,
		},
	}

	for idx, tc := range cases {
		rec := httptest.NewRecorder()
		writeAPIError(rec, tc.status, tc.err)

		if got := rec.Code; got != tc.wantStatus {
			t.Errorf("case %d: incorrect HTTP status: want=%d got=%d", idx, tc.wantStatus, got)
		}

		if got := rec.Body.String(); got != tc.wantBody {
			t.Errorf("case %d: incorrect HTTP body: want=%q got=%q", idx, tc.wantBody, got)
		}
	}
}

// TestWriteTokenError feeds writeTokenError a range of OAuth2 errors plus a
// generic error and compares the recorded status, the full header set and
// the body with the expected values for each case.
func TestWriteTokenError(t *testing.T) {
	// Expected header sets. They are only read by the comparisons below,
	// so sharing them between cases is safe.
	jsonOnly := http.Header{"Content-Type": []string{"application/json"}}
	jsonWithBasicChallenge := http.Header{
		"Content-Type":     []string{"application/json"},
		"Www-Authenticate": []string{"Basic"},
	}

	cases := []struct {
		err        error
		state      string
		wantStatus int
		wantHeader http.Header
		wantBody   string
	}{
		// state is echoed back in the body when provided
		{
			err:        oauth2.NewError(oauth2.ErrorInvalidRequest),
			state:      "bazinga",
			wantStatus: http.StatusBadRequest,
			wantHeader: jsonOnly,
			wantBody:   `{"error":"invalid_request","state":"bazinga"}`,
		},
		// no state: the body carries only the error
		{
			err:        oauth2.NewError(oauth2.ErrorInvalidRequest),
			wantStatus: http.StatusBadRequest,
			wantHeader: jsonOnly,
			wantBody:   `{"error":"invalid_request"}`,
		},
		{
			err:        oauth2.NewError(oauth2.ErrorInvalidGrant),
			wantStatus: http.StatusBadRequest,
			wantHeader: jsonOnly,
			wantBody:   `{"error":"invalid_grant"}`,
		},
		// invalid client: 401 with a Basic authentication challenge
		{
			err:        oauth2.NewError(oauth2.ErrorInvalidClient),
			wantStatus: http.StatusUnauthorized,
			wantHeader: jsonWithBasicChallenge,
			wantBody:   `{"error":"invalid_client"}`,
		},
		{
			err:        oauth2.NewError(oauth2.ErrorServerError),
			wantStatus: http.StatusBadRequest,
			wantHeader: jsonOnly,
			wantBody:   `{"error":"server_error"}`,
		},
		{
			err:        oauth2.NewError(oauth2.ErrorUnsupportedGrantType),
			wantStatus: http.StatusBadRequest,
			wantHeader: jsonOnly,
			wantBody:   `{"error":"unsupported_grant_type"}`,
		},
		// generic error: expected to be reported as server_error
		{
			err:        errors.New("generic failure"),
			wantStatus: http.StatusBadRequest,
			wantHeader: jsonOnly,
			wantBody:   `{"error":"server_error"}`,
		},
	}

	for idx, tc := range cases {
		rec := httptest.NewRecorder()
		writeTokenError(rec, tc.err, tc.state)

		if got := rec.Code; got != tc.wantStatus {
			t.Errorf("case %d: incorrect HTTP status: want=%d got=%d", idx, tc.wantStatus, got)
		}

		if got := rec.Header(); !reflect.DeepEqual(tc.wantHeader, got) {
			t.Errorf("case %d: incorrect HTTP headers: want=%#v got=%#v", idx, tc.wantHeader, got)
		}

		if got := rec.Body.String(); got != tc.wantBody {
			t.Errorf("case %d: incorrect HTTP body: want=%q got=%q", idx, tc.wantBody, got)
		}
	}
}

// TestWriteAuthError calls writeAuthError with a generic error and two
// OAuth2 errors. Every case expects a 400 status and a JSON Content-Type as
// the only header; only the body differs between cases.
func TestWriteAuthError(t *testing.T) {
	const wantStatus = http.StatusBadRequest
	expectedHeader := http.Header{"Content-Type": []string{"application/json"}}

	// Columns: error passed in, state passed in, expected response body.
	cases := []struct {
		err      error
		state    string
		wantBody string
	}{
		{errors.New("foobar"), "bazinga", `{"error":"server_error","state":"bazinga"}`},
		{oauth2.NewError(oauth2.ErrorInvalidRequest), "foo", `{"error":"invalid_request","state":"foo"}`},
		{oauth2.NewError(oauth2.ErrorUnsupportedResponseType), "bar", `{"error":"unsupported_response_type","state":"bar"}`},
	}

	for idx, tc := range cases {
		rec := httptest.NewRecorder()
		writeAuthError(rec, tc.err, tc.state)

		if got := rec.Code; got != wantStatus {
			t.Errorf("case %d: incorrect HTTP status: want=%d got=%d", idx, wantStatus, got)
		}

		if got := rec.Header(); !reflect.DeepEqual(expectedHeader, got) {
			t.Errorf("case %d: incorrect HTTP headers: want=%#v got=%#v", idx, expectedHeader, got)
		}

		if got := rec.Body.String(); got != tc.wantBody {
			t.Errorf("case %d: incorrect HTTP body: want=%q got=%q", idx, tc.wantBody, got)
		}
	}
}

// TestRedirectAuthError calls redirectAuthError with a generic error and two
// OAuth2 errors. Each case expects a 302 status, a Location header (and no
// other header) carrying the error and state as query parameters, and an
// empty body.
func TestRedirectAuthError(t *testing.T) {
	const wantStatus = http.StatusFound

	// Every case redirects to the same base URL; url.URL is copied by value
	// into each table entry.
	exampleURL := url.URL{Scheme: "http", Host: "server.example.com"}

	cases := []struct {
		err          error
		state        string
		target       url.URL
		wantLocation string
	}{
		{
			err:          errors.New("foobar"),
			state:        "bazinga",
			target:       exampleURL,
			wantLocation: "http://server.example.com?error=server_error&state=bazinga",
		},
		{
			err:          oauth2.NewError(oauth2.ErrorInvalidRequest),
			state:        "foo",
			target:       exampleURL,
			wantLocation: "http://server.example.com?error=invalid_request&state=foo",
		},
		{
			err:          oauth2.NewError(oauth2.ErrorUnsupportedResponseType),
			state:        "bar",
			target:       exampleURL,
			wantLocation: "http://server.example.com?error=unsupported_response_type&state=bar",
		},
	}

	for idx, tc := range cases {
		rec := httptest.NewRecorder()
		redirectAuthError(rec, tc.err, tc.state, tc.target)

		if got := rec.Code; got != wantStatus {
			t.Errorf("case %d: incorrect HTTP status: want=%d got=%d", idx, wantStatus, got)
		}

		// The Location header must be the only header on the response.
		expectedHeader := http.Header{"Location": []string{tc.wantLocation}}
		if got := rec.Header(); !reflect.DeepEqual(expectedHeader, got) {
			t.Errorf("case %d: incorrect HTTP headers: want=%#v got=%#v", idx, expectedHeader, got)
		}

		if body := rec.Body.String(); len(body) != 0 {
			t.Errorf("case %d: incorrect empty HTTP body, got=%q", idx, body)
		}
	}
}