package session

import (
	"crypto/rand"
	"encoding/base64"
	"errors"
	"fmt"
	"net/url"
	"time"

	"github.com/jonboulle/clockwork"

	"github.com/coreos/go-oidc/oidc"
)

// GenerateCodeFunc produces a new code string, or an error if one could not
// be generated. SessionManager calls it to mint both session IDs and session
// keys.
type GenerateCodeFunc func() (string, error)

// DefaultGenerateCode reads 8 bytes from crypto/rand and returns them encoded
// with the URL-safe base64 alphabet (padding included). It returns an error
// if the random source fails or yields fewer bytes than requested.
func DefaultGenerateCode() (string, error) {
	const codeBytes = 8

	var raw [codeBytes]byte
	read, err := rand.Read(raw[:])
	switch {
	case err != nil:
		return "", err
	case read != codeBytes:
		// Defensive: a short read without an error still must not yield a code.
		return "", errors.New("unable to read enough random bytes")
	}

	return base64.URLEncoding.EncodeToString(raw[:]), nil
}

// NewSessionManager returns a SessionManager backed by the given session and
// session-key repositories. The manager is configured with
// DefaultGenerateCode, a real wall clock, and DefaultSessionValidityWindow;
// callers may override the exported fields afterwards.
func NewSessionManager(sRepo SessionRepo, skRepo SessionKeyRepo) *SessionManager {
	mgr := new(SessionManager)

	// Backing stores.
	mgr.sessions = sRepo
	mgr.keys = skRepo

	// Overridable defaults.
	mgr.GenerateCode = DefaultGenerateCode
	mgr.Clock = clockwork.NewRealClock()
	mgr.ValidityWindow = DefaultSessionValidityWindow

	return mgr
}

// SessionManager creates sessions and session keys and moves sessions through
// their states, persisting everything through the configured repositories.
type SessionManager struct {
	// GenerateCode mints the strings used as session IDs and session keys.
	GenerateCode GenerateCodeFunc
	// Clock supplies the current time when a new session is stamped.
	Clock clockwork.Clock
	// ValidityWindow is added to a session's creation time to compute its
	// ExpiresAt.
	ValidityWindow time.Duration
	// sessions is the store that sessions are created in, read from, and
	// updated in.
	sessions SessionRepo
	// keys is the store that session keys are pushed to and popped from.
	keys SessionKeyRepo
}

// NewSession creates a session in SessionStateNew populated with the supplied
// connector, client, redirect, nonce, registration flag, and scope values,
// and stores it in the session repository. CreatedAt is the manager clock's
// current time and ExpiresAt is that time plus ValidityWindow.
//
// It returns the generated session ID, or an error if the ID could not be
// generated or the session could not be stored.
func (m *SessionManager) NewSession(connectorID, clientID, clientState string, redirectURL url.URL, nonce string, register bool, scope []string) (string, error) {
	id, err := m.GenerateCode()
	if err != nil {
		return "", err
	}

	createdAt := m.Clock.Now()
	expiresAt := createdAt.Add(m.ValidityWindow)

	if err := m.sessions.Create(Session{
		ID:          id,
		State:       SessionStateNew,
		ConnectorID: connectorID,
		ClientID:    clientID,
		ClientState: clientState,
		RedirectURL: redirectURL,
		Nonce:       nonce,
		Register:    register,
		Scope:       scope,
		CreatedAt:   createdAt,
		ExpiresAt:   expiresAt,
	}); err != nil {
		return "", err
	}

	return id, nil
}

// NewSessionKey generates a fresh key, associates it with sessionID, and
// pushes the pair to the key repository with sessionKeyValidityWindow.
//
// It returns the generated key, or an error if generation or the push fails.
func (m *SessionManager) NewSessionKey(sessionID string) (string, error) {
	code, err := m.GenerateCode()
	if err != nil {
		return "", err
	}

	entry := SessionKey{SessionID: sessionID, Key: code}
	if err := m.keys.Push(entry, sessionKeyValidityWindow); err != nil {
		return "", err
	}

	return entry.Key, nil
}

// ExchangeKey pops key from the key repository and returns whatever string
// and error the repository reports.
// NOTE(review): the returned string is presumably the session ID that was
// stored with the key in NewSessionKey — confirm against SessionKeyRepo.
func (m *SessionManager) ExchangeKey(key string) (string, error) {
	value, err := m.keys.Pop(key)
	return value, err
}

// getSessionInState loads the session identified by sessionID and returns it
// only when its State equals state. A repository error is returned as-is; a
// state mismatch yields an error naming both the actual and expected states.
func (m *SessionManager) getSessionInState(sessionID string, state SessionState) (*Session, error) {
	found, err := m.sessions.Get(sessionID)
	if err != nil {
		return nil, err
	}

	if found.State == state {
		return found, nil
	}

	return nil, fmt.Errorf("session state %s, expect %s", found.State, state)
}

// AttachRemoteIdentity records ident on the session identified by sessionID
// and advances it from SessionStateNew to SessionStateRemoteAttached.
//
// It fails if the session cannot be loaded, is not in SessionStateNew, or the
// update cannot be persisted. On success the updated session is returned.
func (m *SessionManager) AttachRemoteIdentity(sessionID string, ident oidc.Identity) (*Session, error) {
	sess, err := m.getSessionInState(sessionID, SessionStateNew)
	if err != nil {
		return nil, err
	}

	sess.State = SessionStateRemoteAttached
	sess.Identity = ident

	err = m.sessions.Update(*sess)
	if err != nil {
		return nil, err
	}

	return sess, nil
}

// AttachUser records userID on the session identified by sessionID and
// advances it from SessionStateRemoteAttached to SessionStateIdentified.
//
// It fails if the session cannot be loaded, is not in
// SessionStateRemoteAttached, or the update cannot be persisted. On success
// the updated session is returned.
func (m *SessionManager) AttachUser(sessionID string, userID string) (*Session, error) {
	sess, err := m.getSessionInState(sessionID, SessionStateRemoteAttached)
	if err != nil {
		return nil, err
	}

	sess.State = SessionStateIdentified
	sess.UserID = userID

	err = m.sessions.Update(*sess)
	if err != nil {
		return nil, err
	}

	return sess, nil
}

// Kill marks the session identified by sessionID as SessionStateDead,
// regardless of its current state, and persists the change.
//
// It returns the updated session, or an error if the session cannot be loaded
// or the update cannot be persisted.
func (m *SessionManager) Kill(sessionID string) (*Session, error) {
	sess, err := m.sessions.Get(sessionID)
	if err != nil {
		return nil, err
	}

	sess.State = SessionStateDead

	err = m.sessions.Update(*sess)
	if err != nil {
		return nil, err
	}

	return sess, nil
}

// Get looks up the session identified by sessionID in the session repository
// and returns the repository's result unchanged; no state check is applied.
func (m *SessionManager) Get(sessionID string) (*Session, error) {
	sess, err := m.sessions.Get(sessionID)
	return sess, err
}