package connector

import (
	"net/http"
	"net/url"
	"strings"

	"github.com/coreos/dex/pkg/log"
	chttp "github.com/coreos/go-oidc/http"
	"github.com/coreos/go-oidc/oauth2"
	"github.com/coreos/go-oidc/oidc"
)

// oauth2Connector is the provider-specific behavior that OAuth2Connector
// delegates to.
type oauth2Connector interface {
	// Client returns the OAuth2 client. In this file it is used to build the
	// authorization URL and to exchange an authorization code for a token.
	Client() *oauth2.Client

	// Identity uses an HTTP client authenticated as the end user to construct
	// an OIDC identity for that user.
	Identity(cli chttp.Client) (oidc.Identity, error)

	// Healthy should attempt to determine whether the connector's credentials
	// are valid.
	Healthy() error

	// TrustedEmailProvider presumably reports whether email addresses
	// supplied by this provider can be trusted — TODO confirm against the
	// implementations and callers.
	TrustedEmailProvider() bool
}

// OAuth2Connector wraps an oauth2Connector and exposes it over HTTP: it
// produces login URLs and serves the OAuth2 callback endpoint.
//
// Fields:
//   - id is the value returned by ID.
//   - loginFunc is handed to the callback handler by Register and is invoked
//     with the resolved identity and the "state" query value.
//   - cbURL is the callback URL; only its Path is read in this file, as the
//     pattern the callback handler is registered under.
//   - conn is the provider-specific implementation all other work is
//     delegated to.
type OAuth2Connector struct {
	id        string
	loginFunc oidc.LoginFunc
	cbURL     url.URL
	conn      oauth2Connector
}

// ID returns the identifier stored on the connector.
func (c *OAuth2Connector) ID() string {
	connectorID := c.id
	return connectorID
}

// Healthy delegates to the wrapped connector's Healthy check and returns its
// result unchanged.
func (c *OAuth2Connector) Healthy() error {
	err := c.conn.Healthy()
	return err
}

// Sync returns a newly allocated channel with a buffer of one. No goroutine
// is started here and nothing in this method sends to or receives from the
// channel.
//
// NOTE(review): presumably callers treat the channel as a stop signal —
// confirm against the interface this method satisfies.
func (c *OAuth2Connector) Sync() chan struct{} {
	return make(chan struct{}, 1)
}

// TrustedEmailProvider returns whatever the wrapped connector reports from
// its own TrustedEmailProvider method.
func (c *OAuth2Connector) TrustedEmailProvider() bool {
	trusted := c.conn.TrustedEmailProvider()
	return trusted
}

// LoginURL asks the wrapped connector's OAuth2 client for an authorization
// code URL, passing sessionKey, oauth2.GrantTypeAuthCode and prompt through
// in that order. The returned error is always nil.
func (c *OAuth2Connector) LoginURL(sessionKey, prompt string) (string, error) {
	client := c.conn.Client()
	loginURL := client.AuthCodeURL(sessionKey, oauth2.GrantTypeAuthCode, prompt)
	return loginURL, nil
}

// Register installs the OAuth2 callback handler on mux under the path of the
// connector's callback URL. The handler is built from the connector's login
// function and redirects to errorURL on failure.
func (c *OAuth2Connector) Register(mux *http.ServeMux, errorURL url.URL) {
	callback := c.handleCallbackFunc(c.loginFunc, errorURL)
	mux.Handle(c.cbURL.Path, callback)
}

// handleCallbackFunc builds the HTTP handler for the OAuth2 callback.
//
// The handler works through the steps below in order. On the first failure it
// hands the request's query values to redirectError together with errorURL;
// for failures detected here, "error" and "error_description" are set on
// those values first.
//
//  1. If the query already carries a non-empty "error", forward it untouched.
//  2. Require a non-empty "code" parameter.
//  3. Exchange the code for a token through the connector's OAuth2 client.
//  4. Resolve the user's identity using an HTTP client that authenticates
//     with that token.
//  5. Call lf with the identity and the "state" query value.
//
// On success the response is 302 Found with Location set to the URL returned
// by lf.
func (c *OAuth2Connector) handleCallbackFunc(lf oidc.LoginFunc, errorURL url.URL) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		params := r.URL.Query()

		// fail records an error code and description on the query values and
		// then redirects to the error URL.
		fail := func(code, description string) {
			params.Set("error", code)
			params.Set("error_description", description)
			redirectError(w, errorURL, params)
		}

		if params.Get("error") != "" {
			// An error was already reported in the query; pass it through
			// without modification.
			redirectError(w, errorURL, params)
			return
		}

		authCode := params.Get("code")
		if authCode == "" {
			fail(oauth2.ErrorInvalidRequest, "code query param must be set")
			return
		}
		state := params.Get("state")

		tok, err := c.conn.Client().RequestToken(oauth2.GrantTypeAuthCode, authCode)
		if err != nil {
			log.Errorf("Unable to verify auth code with issuer: %v", err)
			fail(oauth2.ErrorUnsupportedResponseType, "unable to verify auth code with issuer")
			return
		}

		identity, err := c.conn.Identity(newAuthenticatedClient(tok, http.DefaultClient))
		if err != nil {
			log.Errorf("Unable to retrieve identity: %v", err)
			fail(oauth2.ErrorUnsupportedResponseType, "unable to retrieve identity from issuer")
			return
		}

		target, err := lf(identity, state)
		if err != nil {
			log.Errorf("Unable to log in %#v: %v", identity, err)
			fail(oauth2.ErrorAccessDenied, "login failed")
			return
		}

		w.Header().Set("Location", target)
		w.WriteHeader(http.StatusFound)
	}
}

// authedClient wraps an HTTP client and authenticates all requests as the end
// user.
//
// Fields:
//   - token supplies the token type and access token placed in the
//     Authorization header of each request.
//   - cli is the underlying client that actually performs the request.
type authedClient struct {
	token oauth2.TokenResponse
	cli   chttp.Client
}

// newAuthenticatedClient wraps cli so that every request sent through the
// returned client carries an Authorization header derived from token.
func newAuthenticatedClient(token oauth2.TokenResponse, cli chttp.Client) chttp.Client {
	return &authedClient{
		token: token,
		cli:   cli,
	}
}

// Do sets the Authorization header on req to "<token type> <access token>"
// and then sends it with the wrapped client. The header is written directly
// onto the caller's request, replacing any existing Authorization value.
func (c *authedClient) Do(req *http.Request) (*http.Response, error) {
	scheme := tokenType(c.token)
	credentials := scheme + " " + c.token.AccessToken
	req.Header.Set("Authorization", credentials)
	return c.cli.Do(req)
}

// tokenType returns the canonical spelling of the token's type: "Bearer",
// "MAC" or "Basic" when the type matches one of those names
// case-insensitively, the type verbatim for any other non-empty value, and
// "Bearer" when the type is empty.
//
// The logic is taken from golang.org/x/oauth2.
func tokenType(token oauth2.TokenResponse) string {
	kind := token.TokenType
	switch {
	case strings.EqualFold(kind, "bearer"):
		return "Bearer"
	case strings.EqualFold(kind, "mac"):
		return "MAC"
	case strings.EqualFold(kind, "basic"):
		return "Basic"
	case kind == "":
		return "Bearer"
	default:
		return kind
	}
}