forked from mystiq/dex
Device flow token code exchange (#2)
* Added /device/token handler with associated business logic and storage tests. Perform user code exchange, flag the device code as complete. Moved device handler code into its own file for cleanliness. Cleanup * Removed PKCE code * Rate limiting for /device/token endpoint based on ietf standards * Configurable Device expiry Signed-off-by: justin-slowik <justin.slowik@thermofisher.com>
This commit is contained in:
parent
0d1a0e4129
commit
9bbdc721d5
20 changed files with 777 additions and 274 deletions
|
@ -283,6 +283,9 @@ type Expiry struct {
|
||||||
|
|
||||||
// AuthRequests defines the duration of time for which the AuthRequests will be valid.
|
// AuthRequests defines the duration of time for which the AuthRequests will be valid.
|
||||||
AuthRequests string `json:"authRequests"`
|
AuthRequests string `json:"authRequests"`
|
||||||
|
|
||||||
|
// DeviceRequests defines the duration of time for which the DeviceRequests will be valid.
|
||||||
|
DeviceRequests string `json:"deviceRequests"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Logger holds configuration required to customize logging for dex.
|
// Logger holds configuration required to customize logging for dex.
|
||||||
|
|
|
@ -119,6 +119,7 @@ expiry:
|
||||||
signingKeys: "7h"
|
signingKeys: "7h"
|
||||||
idTokens: "25h"
|
idTokens: "25h"
|
||||||
authRequests: "25h"
|
authRequests: "25h"
|
||||||
|
deviceRequests: "10m"
|
||||||
|
|
||||||
logger:
|
logger:
|
||||||
level: "debug"
|
level: "debug"
|
||||||
|
@ -200,6 +201,7 @@ logger:
|
||||||
SigningKeys: "7h",
|
SigningKeys: "7h",
|
||||||
IDTokens: "25h",
|
IDTokens: "25h",
|
||||||
AuthRequests: "25h",
|
AuthRequests: "25h",
|
||||||
|
DeviceRequests: "10m",
|
||||||
},
|
},
|
||||||
Logger: Logger{
|
Logger: Logger{
|
||||||
Level: "debug",
|
Level: "debug",
|
||||||
|
|
|
@ -269,7 +269,14 @@ func serve(cmd *cobra.Command, args []string) error {
|
||||||
logger.Infof("config auth requests valid for: %v", authRequests)
|
logger.Infof("config auth requests valid for: %v", authRequests)
|
||||||
serverConfig.AuthRequestsValidFor = authRequests
|
serverConfig.AuthRequestsValidFor = authRequests
|
||||||
}
|
}
|
||||||
|
if c.Expiry.DeviceRequests != "" {
|
||||||
|
deviceRequests, err := time.ParseDuration(c.Expiry.DeviceRequests)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid config value %q for device request expiry: %v", c.Expiry.AuthRequests, err)
|
||||||
|
}
|
||||||
|
logger.Infof("config device requests valid for: %v", deviceRequests)
|
||||||
|
serverConfig.DeviceRequestsValidFor = deviceRequests
|
||||||
|
}
|
||||||
serv, err := server.NewServer(context.Background(), serverConfig)
|
serv, err := server.NewServer(context.Background(), serverConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to initialize server: %v", err)
|
return fmt.Errorf("failed to initialize server: %v", err)
|
||||||
|
|
|
@ -64,6 +64,7 @@ telemetry:
|
||||||
|
|
||||||
# Uncomment this block to enable configuration for the expiration time durations.
|
# Uncomment this block to enable configuration for the expiration time durations.
|
||||||
# expiry:
|
# expiry:
|
||||||
|
# deviceRequests: "5m"
|
||||||
# signingKeys: "6h"
|
# signingKeys: "6h"
|
||||||
# idTokens: "24h"
|
# idTokens: "24h"
|
||||||
|
|
||||||
|
|
359
server/deviceHandlers.go
Normal file
359
server/deviceHandlers.go
Normal file
|
@ -0,0 +1,359 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"path"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/dexidp/dex/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
type deviceCodeResponse struct {
|
||||||
|
//The unique device code for device authentication
|
||||||
|
DeviceCode string `json:"device_code"`
|
||||||
|
//The code the user will exchange via a browser and log in
|
||||||
|
UserCode string `json:"user_code"`
|
||||||
|
//The url to verify the user code.
|
||||||
|
VerificationURI string `json:"verification_uri"`
|
||||||
|
//The verification uri with the user code appended for pre-filling form
|
||||||
|
VerificationURIComplete string `json:"verification_uri_complete"`
|
||||||
|
//The lifetime of the device code
|
||||||
|
ExpireTime int `json:"expires_in"`
|
||||||
|
//How often the device is allowed to poll to verify that the user login occurred
|
||||||
|
PollInterval int `json:"interval"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) getDeviceAuthURI() string {
|
||||||
|
return path.Join(s.issuerURL.Path, "/device/auth/verify_code")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleDeviceExchange(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodGet:
|
||||||
|
userCode := r.URL.Query().Get("user_code")
|
||||||
|
invalidAttempt, err := strconv.ParseBool(r.URL.Query().Get("invalid"))
|
||||||
|
if err != nil {
|
||||||
|
invalidAttempt = false
|
||||||
|
}
|
||||||
|
if err := s.templates.device(r, w, s.getDeviceAuthURI(), userCode, invalidAttempt); err != nil {
|
||||||
|
s.logger.Errorf("Server template error: %v", err)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
|
||||||
|
pollIntervalSeconds := 5
|
||||||
|
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodPost:
|
||||||
|
err := r.ParseForm()
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Errorf("Could not parse Device Request body: %v", err)
|
||||||
|
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
//Get the client id and scopes from the post
|
||||||
|
clientID := r.Form.Get("client_id")
|
||||||
|
scopes := r.Form["scope"]
|
||||||
|
|
||||||
|
s.logger.Infof("Received device request for client %v with scopes %v", clientID, scopes)
|
||||||
|
|
||||||
|
//Make device code
|
||||||
|
deviceCode := storage.NewDeviceCode()
|
||||||
|
|
||||||
|
//make user code
|
||||||
|
userCode, err := storage.NewUserCode()
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Errorf("Error generating user code: %v", err)
|
||||||
|
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
//Generate the expire time
|
||||||
|
expireTime := time.Now().Add(s.deviceRequestsValidFor)
|
||||||
|
|
||||||
|
//Store the Device Request
|
||||||
|
deviceReq := storage.DeviceRequest{
|
||||||
|
UserCode: userCode,
|
||||||
|
DeviceCode: deviceCode,
|
||||||
|
ClientID: clientID,
|
||||||
|
Scopes: scopes,
|
||||||
|
Expiry: expireTime,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.storage.CreateDeviceRequest(deviceReq); err != nil {
|
||||||
|
s.logger.Errorf("Failed to store device request; %v", err)
|
||||||
|
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
//Store the device token
|
||||||
|
deviceToken := storage.DeviceToken{
|
||||||
|
DeviceCode: deviceCode,
|
||||||
|
Status: deviceTokenPending,
|
||||||
|
Expiry: expireTime,
|
||||||
|
LastRequestTime: time.Now(),
|
||||||
|
PollIntervalSeconds: 5,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.storage.CreateDeviceToken(deviceToken); err != nil {
|
||||||
|
s.logger.Errorf("Failed to store device token %v", err)
|
||||||
|
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := url.Parse(s.issuerURL.String())
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Errorf("Could not parse issuer URL %v", err)
|
||||||
|
s.renderError(r, w, http.StatusInternalServerError, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
u.Path = path.Join(u.Path, "device")
|
||||||
|
vURI := u.String()
|
||||||
|
|
||||||
|
q := u.Query()
|
||||||
|
q.Set("user_code", userCode)
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
vURIComplete := u.String()
|
||||||
|
|
||||||
|
code := deviceCodeResponse{
|
||||||
|
DeviceCode: deviceCode,
|
||||||
|
UserCode: userCode,
|
||||||
|
VerificationURI: vURI,
|
||||||
|
VerificationURIComplete: vURIComplete,
|
||||||
|
ExpireTime: int(s.deviceRequestsValidFor.Seconds()),
|
||||||
|
PollInterval: pollIntervalSeconds,
|
||||||
|
}
|
||||||
|
|
||||||
|
enc := json.NewEncoder(w)
|
||||||
|
enc.SetIndent("", " ")
|
||||||
|
enc.Encode(code)
|
||||||
|
|
||||||
|
default:
|
||||||
|
s.renderError(r, w, http.StatusBadRequest, "Invalid device code request type")
|
||||||
|
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodPost:
|
||||||
|
err := r.ParseForm()
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Warnf("Could not parse Device Token Request body: %v", err)
|
||||||
|
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
deviceCode := r.Form.Get("device_code")
|
||||||
|
if deviceCode == "" {
|
||||||
|
s.tokenErrHelper(w, errInvalidRequest, "No device code received", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
grantType := r.PostFormValue("grant_type")
|
||||||
|
if grantType != grantTypeDeviceCode {
|
||||||
|
s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
now := s.now()
|
||||||
|
|
||||||
|
//Grab the device token
|
||||||
|
deviceToken, err := s.storage.GetDeviceToken(deviceCode)
|
||||||
|
if err != nil || now.After(deviceToken.Expiry) {
|
||||||
|
if err != storage.ErrNotFound {
|
||||||
|
s.logger.Errorf("failed to get device code: %v", err)
|
||||||
|
}
|
||||||
|
s.tokenErrHelper(w, errInvalidRequest, "Invalid or expired device code.", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
//Rate Limiting check
|
||||||
|
pollInterval := deviceToken.PollIntervalSeconds
|
||||||
|
minRequestTime := deviceToken.LastRequestTime.Add(time.Second * time.Duration(pollInterval))
|
||||||
|
if now.Before(minRequestTime) {
|
||||||
|
s.tokenErrHelper(w, deviceTokenSlowDown, "", http.StatusBadRequest)
|
||||||
|
//Continually increase the poll interval until the user waits the proper time
|
||||||
|
pollInterval += 5
|
||||||
|
} else {
|
||||||
|
pollInterval = 5
|
||||||
|
}
|
||||||
|
|
||||||
|
switch deviceToken.Status {
|
||||||
|
case deviceTokenPending:
|
||||||
|
updater := func(old storage.DeviceToken) (storage.DeviceToken, error) {
|
||||||
|
old.PollIntervalSeconds = pollInterval
|
||||||
|
old.LastRequestTime = now
|
||||||
|
return old, nil
|
||||||
|
}
|
||||||
|
// Update device token last request time in storage
|
||||||
|
if err := s.storage.UpdateDeviceToken(deviceCode, updater); err != nil {
|
||||||
|
s.logger.Errorf("failed to update device token: %v", err)
|
||||||
|
s.renderError(r, w, http.StatusInternalServerError, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized)
|
||||||
|
case deviceTokenComplete:
|
||||||
|
w.Write([]byte(deviceToken.Token))
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodGet:
|
||||||
|
userCode := r.FormValue("state")
|
||||||
|
code := r.FormValue("code")
|
||||||
|
|
||||||
|
if userCode == "" || code == "" {
|
||||||
|
s.renderError(r, w, http.StatusBadRequest, "Request was missing parameters")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authorization redirect callback from OAuth2 auth flow.
|
||||||
|
if errMsg := r.FormValue("error"); errMsg != "" {
|
||||||
|
http.Error(w, errMsg+": "+r.FormValue("error_description"), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
authCode, err := s.storage.GetAuthCode(code)
|
||||||
|
if err != nil || s.now().After(authCode.Expiry) {
|
||||||
|
if err != storage.ErrNotFound {
|
||||||
|
s.logger.Errorf("failed to get auth code: %v", err)
|
||||||
|
}
|
||||||
|
s.renderError(r, w, http.StatusBadRequest, "Invalid or expired auth code.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
//Grab the device request from storage
|
||||||
|
deviceReq, err := s.storage.GetDeviceRequest(userCode)
|
||||||
|
if err != nil || s.now().After(deviceReq.Expiry) {
|
||||||
|
if err != storage.ErrNotFound {
|
||||||
|
s.logger.Errorf("failed to get device code: %v", err)
|
||||||
|
}
|
||||||
|
s.renderError(r, w, http.StatusInternalServerError, "Invalid or expired device code.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
reqClient, err := s.storage.GetClient(deviceReq.ClientID)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Errorf("Failed to get reqClient %q: %v", deviceReq.ClientID, err)
|
||||||
|
s.renderError(r, w, http.StatusInternalServerError, "Failed to retrieve device client.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := s.exchangeAuthCode(w, authCode, reqClient)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Errorf("Could not exchange auth code for client %q: %v", deviceReq.ClientID, err)
|
||||||
|
s.renderError(r, w, http.StatusInternalServerError, "Failed to exchange auth code.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
//Grab the device request from storage
|
||||||
|
old, err := s.storage.GetDeviceToken(deviceReq.DeviceCode)
|
||||||
|
if err != nil || s.now().After(old.Expiry) {
|
||||||
|
if err != storage.ErrNotFound {
|
||||||
|
s.logger.Errorf("failed to get device token: %v", err)
|
||||||
|
}
|
||||||
|
s.renderError(r, w, http.StatusInternalServerError, "Invalid or expired device code.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
updater := func(old storage.DeviceToken) (storage.DeviceToken, error) {
|
||||||
|
if old.Status == deviceTokenComplete {
|
||||||
|
return old, errors.New("device token already complete")
|
||||||
|
}
|
||||||
|
respStr, err := json.MarshalIndent(resp, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Errorf("failed to marshal device token response: %v", err)
|
||||||
|
s.renderError(r, w, http.StatusInternalServerError, "")
|
||||||
|
return old, err
|
||||||
|
}
|
||||||
|
|
||||||
|
old.Token = string(respStr)
|
||||||
|
old.Status = deviceTokenComplete
|
||||||
|
return old, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update refresh token in the storage, store the token and mark as complete
|
||||||
|
if err := s.storage.UpdateDeviceToken(deviceReq.DeviceCode, updater); err != nil {
|
||||||
|
s.logger.Errorf("failed to update device token: %v", err)
|
||||||
|
s.renderError(r, w, http.StatusInternalServerError, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.templates.deviceSuccess(r, w, reqClient.Name); err != nil {
|
||||||
|
s.logger.Errorf("Server template error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
http.Error(w, fmt.Sprintf("method not implemented: %s", r.Method), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) verifyUserCode(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodPost:
|
||||||
|
err := r.ParseForm()
|
||||||
|
if err != nil {
|
||||||
|
message := "Could not parse user code verification Request body"
|
||||||
|
s.logger.Warnf("%s : %v", message, err)
|
||||||
|
s.tokenErrHelper(w, errInvalidRequest, message, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userCode := r.Form.Get("user_code")
|
||||||
|
if userCode == "" {
|
||||||
|
s.renderError(r, w, http.StatusBadRequest, "No user code received")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userCode = strings.ToUpper(userCode)
|
||||||
|
|
||||||
|
//Find the user code in the available requests
|
||||||
|
deviceRequest, err := s.storage.GetDeviceRequest(userCode)
|
||||||
|
if err != nil || s.now().After(deviceRequest.Expiry) {
|
||||||
|
if err != storage.ErrNotFound {
|
||||||
|
s.logger.Errorf("failed to get device request: %v", err)
|
||||||
|
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
if err := s.templates.device(r, w, s.getDeviceAuthURI(), userCode, true); err != nil {
|
||||||
|
s.logger.Errorf("Server template error: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
//Redirect to Dex Auth Endpoint
|
||||||
|
authURL := path.Join(s.issuerURL.Path, "/auth")
|
||||||
|
u, err := url.Parse(authURL)
|
||||||
|
if err != nil {
|
||||||
|
s.renderError(r, w, http.StatusInternalServerError, "Invalid auth URI.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
q := u.Query()
|
||||||
|
q.Set("client_id", deviceRequest.ClientID)
|
||||||
|
q.Set("state", deviceRequest.UserCode)
|
||||||
|
q.Set("response_type", "code")
|
||||||
|
q.Set("redirect_uri", path.Join(s.issuerURL.Path, "/device/callback"))
|
||||||
|
q.Set("scope", strings.Join(deviceRequest.Scopes, " "))
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
http.Redirect(w, r, u.String(), http.StatusFound)
|
||||||
|
|
||||||
|
default:
|
||||||
|
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
|
||||||
|
}
|
||||||
|
}
|
|
@ -152,6 +152,8 @@ type discovery struct {
|
||||||
Token string `json:"token_endpoint"`
|
Token string `json:"token_endpoint"`
|
||||||
Keys string `json:"jwks_uri"`
|
Keys string `json:"jwks_uri"`
|
||||||
UserInfo string `json:"userinfo_endpoint"`
|
UserInfo string `json:"userinfo_endpoint"`
|
||||||
|
DeviceEndpoint string `json:"device_authorization_endpoint"`
|
||||||
|
GrantTypes []string `json:"grant_types_supported"'`
|
||||||
ResponseTypes []string `json:"response_types_supported"`
|
ResponseTypes []string `json:"response_types_supported"`
|
||||||
Subjects []string `json:"subject_types_supported"`
|
Subjects []string `json:"subject_types_supported"`
|
||||||
IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"`
|
IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"`
|
||||||
|
@ -167,7 +169,9 @@ func (s *Server) discoveryHandler() (http.HandlerFunc, error) {
|
||||||
Token: s.absURL("/token"),
|
Token: s.absURL("/token"),
|
||||||
Keys: s.absURL("/keys"),
|
Keys: s.absURL("/keys"),
|
||||||
UserInfo: s.absURL("/userinfo"),
|
UserInfo: s.absURL("/userinfo"),
|
||||||
|
DeviceEndpoint: s.absURL("/device/code"),
|
||||||
Subjects: []string{"public"},
|
Subjects: []string{"public"},
|
||||||
|
GrantTypes: []string{grantTypeAuthorizationCode, grantTypeRefreshToken, grantTypeDeviceCode},
|
||||||
IDTokenAlgs: []string{string(jose.RS256)},
|
IDTokenAlgs: []string{string(jose.RS256)},
|
||||||
Scopes: []string{"openid", "email", "groups", "profile", "offline_access"},
|
Scopes: []string{"openid", "email", "groups", "profile", "offline_access"},
|
||||||
AuthMethods: []string{"client_secret_basic"},
|
AuthMethods: []string{"client_secret_basic"},
|
||||||
|
@ -783,24 +787,33 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tokenResponse, err := s.exchangeAuthCode(w, authCode, client)
|
||||||
|
if err != nil {
|
||||||
|
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.writeAccessToken(w, tokenResponse)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCode, client storage.Client) (*accessTokenReponse, error) {
|
||||||
accessToken, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID)
|
accessToken, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Errorf("failed to create new access token: %v", err)
|
s.logger.Errorf("failed to create new access token: %v", err)
|
||||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
idToken, expiry, err := s.newIDToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, accessToken, authCode.ConnectorID)
|
idToken, expiry, err := s.newIDToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, accessToken, authCode.ConnectorID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Errorf("failed to create ID token: %v", err)
|
s.logger.Errorf("failed to create ID token: %v", err)
|
||||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.storage.DeleteAuthCode(code); err != nil {
|
if err := s.storage.DeleteAuthCode(authCode.ID); err != nil {
|
||||||
s.logger.Errorf("failed to delete auth code: %v", err)
|
s.logger.Errorf("failed to delete auth code: %v", err)
|
||||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
reqRefresh := func() bool {
|
reqRefresh := func() bool {
|
||||||
|
@ -847,13 +860,13 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
|
||||||
if refreshToken, err = internal.Marshal(token); err != nil {
|
if refreshToken, err = internal.Marshal(token); err != nil {
|
||||||
s.logger.Errorf("failed to marshal refresh token: %v", err)
|
s.logger.Errorf("failed to marshal refresh token: %v", err)
|
||||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.storage.CreateRefresh(refresh); err != nil {
|
if err := s.storage.CreateRefresh(refresh); err != nil {
|
||||||
s.logger.Errorf("failed to create refresh token: %v", err)
|
s.logger.Errorf("failed to create refresh token: %v", err)
|
||||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// deleteToken determines if we need to delete the newly created refresh token
|
// deleteToken determines if we need to delete the newly created refresh token
|
||||||
|
@ -884,7 +897,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
|
||||||
s.logger.Errorf("failed to get offline session: %v", err)
|
s.logger.Errorf("failed to get offline session: %v", err)
|
||||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||||
deleteToken = true
|
deleteToken = true
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
offlineSessions := storage.OfflineSessions{
|
offlineSessions := storage.OfflineSessions{
|
||||||
UserID: refresh.Claims.UserID,
|
UserID: refresh.Claims.UserID,
|
||||||
|
@ -899,7 +912,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
|
||||||
s.logger.Errorf("failed to create offline session: %v", err)
|
s.logger.Errorf("failed to create offline session: %v", err)
|
||||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||||
deleteToken = true
|
deleteToken = true
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok {
|
if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok {
|
||||||
|
@ -908,7 +921,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
|
||||||
s.logger.Errorf("failed to delete refresh token: %v", err)
|
s.logger.Errorf("failed to delete refresh token: %v", err)
|
||||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||||
deleteToken = true
|
deleteToken = true
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -920,11 +933,11 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
|
||||||
s.logger.Errorf("failed to update offline session: %v", err)
|
s.logger.Errorf("failed to update offline session: %v", err)
|
||||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||||
deleteToken = true
|
deleteToken = true
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
s.writeAccessToken(w, idToken, accessToken, refreshToken, expiry)
|
return s.toAccessTokenResponse(idToken, accessToken, refreshToken, expiry), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// handle a refresh token request https://tools.ietf.org/html/rfc6749#section-6
|
// handle a refresh token request https://tools.ietf.org/html/rfc6749#section-6
|
||||||
|
@ -1120,7 +1133,8 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
s.writeAccessToken(w, idToken, accessToken, rawNewToken, expiry)
|
resp := s.toAccessTokenResponse(idToken, accessToken, rawNewToken, expiry)
|
||||||
|
s.writeAccessToken(w, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
|
||||||
|
@ -1378,12 +1392,26 @@ func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, accessToken, r
|
||||||
RefreshToken string `json:"refresh_token,omitempty"`
|
RefreshToken string `json:"refresh_token,omitempty"`
|
||||||
IDToken string `json:"id_token"`
|
IDToken string `json:"id_token"`
|
||||||
}{
|
}{
|
||||||
|
|
||||||
|
type accessTokenReponse struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
ExpiresIn int `json:"expires_in"`
|
||||||
|
RefreshToken string `json:"refresh_token,omitempty"`
|
||||||
|
IDToken string `json:"id_token"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) toAccessTokenResponse(idToken, accessToken, refreshToken string, expiry time.Time) *accessTokenReponse {
|
||||||
|
return &accessTokenReponse{
|
||||||
accessToken,
|
accessToken,
|
||||||
"bearer",
|
"bearer",
|
||||||
int(expiry.Sub(s.now()).Seconds()),
|
int(expiry.Sub(s.now()).Seconds()),
|
||||||
refreshToken,
|
refreshToken,
|
||||||
idToken,
|
idToken,
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) writeAccessToken(w http.ResponseWriter, resp *accessTokenReponse) {
|
||||||
data, err := json.Marshal(resp)
|
data, err := json.Marshal(resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Errorf("failed to marshal access token response: %v", err)
|
s.logger.Errorf("failed to marshal access token response: %v", err)
|
||||||
|
@ -1414,145 +1442,3 @@ func usernamePrompt(conn connector.PasswordConnector) string {
|
||||||
}
|
}
|
||||||
return "Username"
|
return "Username"
|
||||||
}
|
}
|
||||||
|
|
||||||
type deviceCodeResponse struct {
|
|
||||||
//The unique device code for device authentication
|
|
||||||
DeviceCode string `json:"device_code"`
|
|
||||||
//The code the user will exchange via a browser and log in
|
|
||||||
UserCode string `json:"user_code"`
|
|
||||||
//The url to verify the user code.
|
|
||||||
VerificationURI string `json:"verification_uri"`
|
|
||||||
//The lifetime of the device code
|
|
||||||
ExpireTime int `json:"expires_in"`
|
|
||||||
//How often the device is allowed to poll to verify that the user login occurred
|
|
||||||
PollInterval int `json:"interval"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
|
|
||||||
//TODO replace with configurable values
|
|
||||||
expireIntervalSeconds := 300
|
|
||||||
requestsPerMinute := 5
|
|
||||||
|
|
||||||
switch r.Method {
|
|
||||||
case http.MethodPost:
|
|
||||||
err := r.ParseForm()
|
|
||||||
if err != nil {
|
|
||||||
s.logger.Errorf("Could not parse Device Request body: %v", err)
|
|
||||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusNotFound)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
//Get the client id and scopes from the post
|
|
||||||
clientID := r.Form.Get("client_id")
|
|
||||||
scopes := r.Form["scope"]
|
|
||||||
|
|
||||||
s.logger.Infof("Received device request for client %v with scopes %v", clientID, scopes)
|
|
||||||
|
|
||||||
//Make device code
|
|
||||||
deviceCode := storage.NewDeviceCode()
|
|
||||||
|
|
||||||
//make user code
|
|
||||||
userCode, err := storage.NewUserCode()
|
|
||||||
if err != nil {
|
|
||||||
s.logger.Errorf("Error generating user code: %v", err)
|
|
||||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
|
|
||||||
//make a pkce verification code
|
|
||||||
pkceCode := storage.NewID()
|
|
||||||
|
|
||||||
//Generate the expire time
|
|
||||||
expireTime := time.Now().Add(time.Second * time.Duration(expireIntervalSeconds))
|
|
||||||
|
|
||||||
//Store the Device Request
|
|
||||||
deviceReq := storage.DeviceRequest{
|
|
||||||
UserCode: userCode,
|
|
||||||
DeviceCode: deviceCode,
|
|
||||||
ClientID: clientID,
|
|
||||||
Scopes: scopes,
|
|
||||||
PkceVerifier: pkceCode,
|
|
||||||
Expiry: expireTime,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.storage.CreateDeviceRequest(deviceReq); err != nil {
|
|
||||||
s.logger.Errorf("Failed to store device request; %v", err)
|
|
||||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
//Store the device token
|
|
||||||
deviceToken := storage.DeviceToken{
|
|
||||||
DeviceCode: deviceCode,
|
|
||||||
Status: deviceTokenPending,
|
|
||||||
Expiry: expireTime,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.storage.CreateDeviceToken(deviceToken); err != nil {
|
|
||||||
s.logger.Errorf("Failed to store device token %v", err)
|
|
||||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
code := deviceCodeResponse{
|
|
||||||
DeviceCode: deviceCode,
|
|
||||||
UserCode: userCode,
|
|
||||||
VerificationURI: path.Join(s.issuerURL.String(), "/device"),
|
|
||||||
ExpireTime: expireIntervalSeconds,
|
|
||||||
PollInterval: requestsPerMinute,
|
|
||||||
}
|
|
||||||
|
|
||||||
enc := json.NewEncoder(w)
|
|
||||||
enc.SetIndent("", " ")
|
|
||||||
enc.Encode(code)
|
|
||||||
|
|
||||||
default:
|
|
||||||
s.renderError(r, w, http.StatusBadRequest, "Invalid device code request type")
|
|
||||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
switch r.Method {
|
|
||||||
case http.MethodPost:
|
|
||||||
err := r.ParseForm()
|
|
||||||
if err != nil {
|
|
||||||
message := "Could not parse Device Token Request body"
|
|
||||||
s.logger.Warnf("%s : %v", message, err)
|
|
||||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
deviceCode := r.Form.Get("device_code")
|
|
||||||
if deviceCode == "" {
|
|
||||||
message := "No device code received"
|
|
||||||
s.tokenErrHelper(w, errInvalidRequest, message, http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
grantType := r.PostFormValue("grant_type")
|
|
||||||
if grantType != grantTypeDeviceCode {
|
|
||||||
s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
//Grab the device token from the db
|
|
||||||
deviceToken, err := s.storage.GetDeviceToken(deviceCode)
|
|
||||||
if err != nil || s.now().After(deviceToken.Expiry) {
|
|
||||||
if err != storage.ErrNotFound {
|
|
||||||
s.logger.Errorf("failed to get device code: %v", err)
|
|
||||||
}
|
|
||||||
s.tokenErrHelper(w, errInvalidRequest, "Invalid or expired device code.", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
switch deviceToken.Status {
|
|
||||||
case deviceTokenPending:
|
|
||||||
s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized)
|
|
||||||
case deviceTokenComplete:
|
|
||||||
w.Write([]byte(deviceToken.Token))
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -122,7 +122,7 @@ const (
|
||||||
grantTypeAuthorizationCode = "authorization_code"
|
grantTypeAuthorizationCode = "authorization_code"
|
||||||
grantTypeRefreshToken = "refresh_token"
|
grantTypeRefreshToken = "refresh_token"
|
||||||
grantTypePassword = "password"
|
grantTypePassword = "password"
|
||||||
grantTypeDeviceCode = "device_code"
|
grantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -134,6 +134,8 @@ const (
|
||||||
const (
|
const (
|
||||||
deviceTokenPending = "authorization_pending"
|
deviceTokenPending = "authorization_pending"
|
||||||
deviceTokenComplete = "complete"
|
deviceTokenComplete = "complete"
|
||||||
|
deviceTokenSlowDown = "slow_down"
|
||||||
|
deviceTokenExpired = "expired_token"
|
||||||
)
|
)
|
||||||
|
|
||||||
func parseScopes(scopes []string) connector.Scopes {
|
func parseScopes(scopes []string) connector.Scopes {
|
||||||
|
|
|
@ -78,6 +78,7 @@ type Config struct {
|
||||||
RotateKeysAfter time.Duration // Defaults to 6 hours.
|
RotateKeysAfter time.Duration // Defaults to 6 hours.
|
||||||
IDTokensValidFor time.Duration // Defaults to 24 hours
|
IDTokensValidFor time.Duration // Defaults to 24 hours
|
||||||
AuthRequestsValidFor time.Duration // Defaults to 24 hours
|
AuthRequestsValidFor time.Duration // Defaults to 24 hours
|
||||||
|
DeviceRequestsValidFor time.Duration // Defaults to 5 minutes
|
||||||
// If set, the server will use this connector to handle password grants
|
// If set, the server will use this connector to handle password grants
|
||||||
PasswordConnector string
|
PasswordConnector string
|
||||||
|
|
||||||
|
@ -158,6 +159,7 @@ type Server struct {
|
||||||
|
|
||||||
idTokensValidFor time.Duration
|
idTokensValidFor time.Duration
|
||||||
authRequestsValidFor time.Duration
|
authRequestsValidFor time.Duration
|
||||||
|
deviceRequestsValidFor time.Duration
|
||||||
|
|
||||||
logger log.Logger
|
logger log.Logger
|
||||||
}
|
}
|
||||||
|
@ -219,6 +221,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
|
||||||
supportedResponseTypes: supported,
|
supportedResponseTypes: supported,
|
||||||
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
|
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
|
||||||
authRequestsValidFor: value(c.AuthRequestsValidFor, 24*time.Hour),
|
authRequestsValidFor: value(c.AuthRequestsValidFor, 24*time.Hour),
|
||||||
|
deviceRequestsValidFor: value(c.DeviceRequestsValidFor, 5*time.Minute),
|
||||||
skipApproval: c.SkipApprovalScreen,
|
skipApproval: c.SkipApprovalScreen,
|
||||||
alwaysShowLogin: c.AlwaysShowLoginScreen,
|
alwaysShowLogin: c.AlwaysShowLoginScreen,
|
||||||
now: now,
|
now: now,
|
||||||
|
@ -302,8 +305,11 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
|
||||||
handleWithCORS("/userinfo", s.handleUserInfo)
|
handleWithCORS("/userinfo", s.handleUserInfo)
|
||||||
handleFunc("/auth", s.handleAuthorization)
|
handleFunc("/auth", s.handleAuthorization)
|
||||||
handleFunc("/auth/{connector}", s.handleConnectorLogin)
|
handleFunc("/auth/{connector}", s.handleConnectorLogin)
|
||||||
|
handleFunc("/device", s.handleDeviceExchange)
|
||||||
|
handleFunc("/device/auth/verify_code", s.verifyUserCode)
|
||||||
handleFunc("/device/code", s.handleDeviceCode)
|
handleFunc("/device/code", s.handleDeviceCode)
|
||||||
handleFunc("/device/token", s.handleDeviceToken)
|
handleFunc("/device/token", s.handleDeviceToken)
|
||||||
|
handleFunc("/device/callback", s.handleDeviceCallback)
|
||||||
r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) {
|
r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) {
|
||||||
// Strip the X-Remote-* headers to prevent security issues on
|
// Strip the X-Remote-* headers to prevent security issues on
|
||||||
// misconfigured authproxy connector setups.
|
// misconfigured authproxy connector setups.
|
||||||
|
|
|
@ -20,6 +20,8 @@ const (
|
||||||
tmplPassword = "password.html"
|
tmplPassword = "password.html"
|
||||||
tmplOOB = "oob.html"
|
tmplOOB = "oob.html"
|
||||||
tmplError = "error.html"
|
tmplError = "error.html"
|
||||||
|
tmplDevice = "device.html"
|
||||||
|
tmplDeviceSuccess = "device_success.html"
|
||||||
)
|
)
|
||||||
|
|
||||||
var requiredTmpls = []string{
|
var requiredTmpls = []string{
|
||||||
|
@ -28,6 +30,7 @@ var requiredTmpls = []string{
|
||||||
tmplPassword,
|
tmplPassword,
|
||||||
tmplOOB,
|
tmplOOB,
|
||||||
tmplError,
|
tmplError,
|
||||||
|
tmplDevice,
|
||||||
}
|
}
|
||||||
|
|
||||||
type templates struct {
|
type templates struct {
|
||||||
|
@ -36,6 +39,8 @@ type templates struct {
|
||||||
passwordTmpl *template.Template
|
passwordTmpl *template.Template
|
||||||
oobTmpl *template.Template
|
oobTmpl *template.Template
|
||||||
errorTmpl *template.Template
|
errorTmpl *template.Template
|
||||||
|
deviceTmpl *template.Template
|
||||||
|
deviceSuccessTmpl *template.Template
|
||||||
}
|
}
|
||||||
|
|
||||||
type webConfig struct {
|
type webConfig struct {
|
||||||
|
@ -157,6 +162,8 @@ func loadTemplates(c webConfig, templatesDir string) (*templates, error) {
|
||||||
passwordTmpl: tmpls.Lookup(tmplPassword),
|
passwordTmpl: tmpls.Lookup(tmplPassword),
|
||||||
oobTmpl: tmpls.Lookup(tmplOOB),
|
oobTmpl: tmpls.Lookup(tmplOOB),
|
||||||
errorTmpl: tmpls.Lookup(tmplError),
|
errorTmpl: tmpls.Lookup(tmplError),
|
||||||
|
deviceTmpl: tmpls.Lookup(tmplDevice),
|
||||||
|
deviceSuccessTmpl: tmpls.Lookup(tmplDeviceSuccess),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -242,6 +249,24 @@ func (n byName) Len() int { return len(n) }
|
||||||
func (n byName) Less(i, j int) bool { return n[i].Name < n[j].Name }
|
func (n byName) Less(i, j int) bool { return n[i].Name < n[j].Name }
|
||||||
func (n byName) Swap(i, j int) { n[i], n[j] = n[j], n[i] }
|
func (n byName) Swap(i, j int) { n[i], n[j] = n[j], n[i] }
|
||||||
|
|
||||||
|
func (t *templates) device(r *http.Request, w http.ResponseWriter, postURL string, userCode string, lastWasInvalid bool) error {
|
||||||
|
data := struct {
|
||||||
|
PostURL string
|
||||||
|
UserCode string
|
||||||
|
Invalid bool
|
||||||
|
ReqPath string
|
||||||
|
}{postURL, userCode, lastWasInvalid, r.URL.Path}
|
||||||
|
return renderTemplate(w, t.deviceTmpl, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *templates) deviceSuccess(r *http.Request, w http.ResponseWriter, clientName string) error {
|
||||||
|
data := struct {
|
||||||
|
ClientName string
|
||||||
|
ReqPath string
|
||||||
|
}{clientName, r.URL.Path}
|
||||||
|
return renderTemplate(w, t.deviceSuccessTmpl, data)
|
||||||
|
}
|
||||||
|
|
||||||
func (t *templates) login(r *http.Request, w http.ResponseWriter, connectors []connectorInfo, reqPath string) error {
|
func (t *templates) login(r *http.Request, w http.ResponseWriter, connectors []connectorInfo, reqPath string) error {
|
||||||
sort.Sort(byName(connectors))
|
sort.Sort(byName(connectors))
|
||||||
data := struct {
|
data := struct {
|
||||||
|
|
|
@ -847,7 +847,6 @@ func testGC(t *testing.T, s storage.Storage) {
|
||||||
DeviceCode: storage.NewID(),
|
DeviceCode: storage.NewID(),
|
||||||
ClientID: "client1",
|
ClientID: "client1",
|
||||||
Scopes: []string{"openid", "email"},
|
Scopes: []string{"openid", "email"},
|
||||||
PkceVerifier: storage.NewID(),
|
|
||||||
Expiry: expiry,
|
Expiry: expiry,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -974,7 +973,6 @@ func testDeviceRequestCRUD(t *testing.T, s storage.Storage) {
|
||||||
DeviceCode: storage.NewID(),
|
DeviceCode: storage.NewID(),
|
||||||
ClientID: "client1",
|
ClientID: "client1",
|
||||||
Scopes: []string{"openid", "email"},
|
Scopes: []string{"openid", "email"},
|
||||||
PkceVerifier: storage.NewID(),
|
|
||||||
Expiry: neverExpire,
|
Expiry: neverExpire,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -991,20 +989,44 @@ func testDeviceRequestCRUD(t *testing.T, s storage.Storage) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func testDeviceTokenCRUD(t *testing.T, s storage.Storage) {
|
func testDeviceTokenCRUD(t *testing.T, s storage.Storage) {
|
||||||
|
//Create a Token
|
||||||
d1 := storage.DeviceToken{
|
d1 := storage.DeviceToken{
|
||||||
DeviceCode: storage.NewID(),
|
DeviceCode: storage.NewID(),
|
||||||
Status: "pending",
|
Status: "pending",
|
||||||
Token: storage.NewID(),
|
Token: storage.NewID(),
|
||||||
Expiry: neverExpire,
|
Expiry: neverExpire,
|
||||||
|
LastRequestTime: time.Now(),
|
||||||
|
PollIntervalSeconds: 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.CreateDeviceToken(d1); err != nil {
|
if err := s.CreateDeviceToken(d1); err != nil {
|
||||||
t.Fatalf("failed creating device token: %v", err)
|
t.Fatalf("failed creating device token: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Attempt to create same DeviceRequest twice.
|
// Attempt to create same Device Token twice.
|
||||||
err := s.CreateDeviceToken(d1)
|
err := s.CreateDeviceToken(d1)
|
||||||
mustBeErrAlreadyExists(t, "device token", err)
|
mustBeErrAlreadyExists(t, "device token", err)
|
||||||
|
|
||||||
//TODO Add update / delete tests as functionality is put into main code
|
//Update the device token, simulate a redemption
|
||||||
|
if err := s.UpdateDeviceToken(d1.DeviceCode, func(old storage.DeviceToken) (storage.DeviceToken, error) {
|
||||||
|
old.Token = "token data"
|
||||||
|
old.Status = "complete"
|
||||||
|
return old, nil
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("failed to update device token: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
//Retrieve the device token
|
||||||
|
got, err := s.GetDeviceToken(d1.DeviceCode)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get device token: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
//Validate expected result set
|
||||||
|
if got.Status != "complete" {
|
||||||
|
t.Fatalf("update failed, wanted token status=%#v got %#v", "complete", got.Status)
|
||||||
|
}
|
||||||
|
if got.Token != "token data" {
|
||||||
|
t.Fatalf("update failed, wanted token =%#v got %#v", "token data", got.Token)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -570,6 +570,13 @@ func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error {
|
||||||
return c.txnCreate(ctx, keyID(deviceRequestPrefix, d.UserCode), fromStorageDeviceRequest(d))
|
return c.txnCreate(ctx, keyID(deviceRequestPrefix, d.UserCode), fromStorageDeviceRequest(d))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *conn) GetDeviceRequest(userCode string) (r storage.DeviceRequest, err error) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||||
|
defer cancel()
|
||||||
|
err = c.getKey(ctx, keyID(deviceRequestPrefix, userCode), &r)
|
||||||
|
return r, err
|
||||||
|
}
|
||||||
|
|
||||||
func (c *conn) listDeviceRequests(ctx context.Context) (requests []DeviceRequest, err error) {
|
func (c *conn) listDeviceRequests(ctx context.Context) (requests []DeviceRequest, err error) {
|
||||||
res, err := c.db.Get(ctx, deviceRequestPrefix, clientv3.WithPrefix())
|
res, err := c.db.Get(ctx, deviceRequestPrefix, clientv3.WithPrefix())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -612,3 +619,21 @@ func (c *conn) listDeviceTokens(ctx context.Context) (deviceTokens []DeviceToken
|
||||||
}
|
}
|
||||||
return deviceTokens, nil
|
return deviceTokens, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *conn) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||||
|
defer cancel()
|
||||||
|
return c.txnUpdate(ctx, keyID(deviceTokenPrefix, deviceCode), func(currentValue []byte) ([]byte, error) {
|
||||||
|
var current DeviceToken
|
||||||
|
if len(currentValue) > 0 {
|
||||||
|
if err := json.Unmarshal(currentValue, ¤t); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
updated, err := updater(toStorageDeviceToken(current))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return json.Marshal(fromStorageDeviceToken(updated))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -223,7 +223,6 @@ type DeviceRequest struct {
|
||||||
DeviceCode string `json:"device_code"`
|
DeviceCode string `json:"device_code"`
|
||||||
ClientID string `json:"client_id"`
|
ClientID string `json:"client_id"`
|
||||||
Scopes []string `json:"scopes"`
|
Scopes []string `json:"scopes"`
|
||||||
PkceVerifier string `json:"pkce_verifier"`
|
|
||||||
Expiry time.Time `json:"expiry"`
|
Expiry time.Time `json:"expiry"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -233,7 +232,6 @@ func fromStorageDeviceRequest(d storage.DeviceRequest) DeviceRequest {
|
||||||
DeviceCode: d.DeviceCode,
|
DeviceCode: d.DeviceCode,
|
||||||
ClientID: d.ClientID,
|
ClientID: d.ClientID,
|
||||||
Scopes: d.Scopes,
|
Scopes: d.Scopes,
|
||||||
PkceVerifier: d.PkceVerifier,
|
|
||||||
Expiry: d.Expiry,
|
Expiry: d.Expiry,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -244,6 +242,8 @@ type DeviceToken struct {
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
Token string `json:"token"`
|
Token string `json:"token"`
|
||||||
Expiry time.Time `json:"expiry"`
|
Expiry time.Time `json:"expiry"`
|
||||||
|
LastRequestTime time.Time `json:"last_request"`
|
||||||
|
PollIntervalSeconds int `json:"poll_interval"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func fromStorageDeviceToken(t storage.DeviceToken) DeviceToken {
|
func fromStorageDeviceToken(t storage.DeviceToken) DeviceToken {
|
||||||
|
@ -252,5 +252,18 @@ func fromStorageDeviceToken(t storage.DeviceToken) DeviceToken {
|
||||||
Status: t.Status,
|
Status: t.Status,
|
||||||
Token: t.Token,
|
Token: t.Token,
|
||||||
Expiry: t.Expiry,
|
Expiry: t.Expiry,
|
||||||
|
LastRequestTime: t.LastRequestTime,
|
||||||
|
PollIntervalSeconds: t.PollIntervalSeconds,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toStorageDeviceToken(t DeviceToken) storage.DeviceToken {
|
||||||
|
return storage.DeviceToken{
|
||||||
|
DeviceCode: t.DeviceCode,
|
||||||
|
Status: t.Status,
|
||||||
|
Token: t.Token,
|
||||||
|
Expiry: t.Expiry,
|
||||||
|
LastRequestTime: t.LastRequestTime,
|
||||||
|
PollIntervalSeconds: t.PollIntervalSeconds,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -638,6 +638,14 @@ func (cli *client) CreateDeviceRequest(d storage.DeviceRequest) error {
|
||||||
return cli.post(resourceDeviceRequest, cli.fromStorageDeviceRequest(d))
|
return cli.post(resourceDeviceRequest, cli.fromStorageDeviceRequest(d))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (cli *client) GetDeviceRequest(userCode string) (storage.DeviceRequest, error) {
|
||||||
|
var req DeviceRequest
|
||||||
|
if err := cli.get(resourceDeviceRequest, strings.ToLower(userCode), &req); err != nil {
|
||||||
|
return storage.DeviceRequest{}, err
|
||||||
|
}
|
||||||
|
return toStorageDeviceRequest(req), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (cli *client) CreateDeviceToken(t storage.DeviceToken) error {
|
func (cli *client) CreateDeviceToken(t storage.DeviceToken) error {
|
||||||
return cli.post(resourceDeviceToken, cli.fromStorageDeviceToken(t))
|
return cli.post(resourceDeviceToken, cli.fromStorageDeviceToken(t))
|
||||||
}
|
}
|
||||||
|
@ -649,3 +657,24 @@ func (cli *client) GetDeviceToken(deviceCode string) (storage.DeviceToken, error
|
||||||
}
|
}
|
||||||
return toStorageDeviceToken(token), nil
|
return toStorageDeviceToken(token), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (cli *client) getDeviceToken(deviceCode string) (t DeviceToken, err error) {
|
||||||
|
err = cli.get(resourceDeviceToken, deviceCode, &t)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cli *client) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
|
||||||
|
r, err := cli.getDeviceToken(deviceCode)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
updated, err := updater(toStorageDeviceToken(r))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
updated.DeviceCode = deviceCode
|
||||||
|
|
||||||
|
newToken := cli.fromStorageDeviceToken(updated)
|
||||||
|
newToken.ObjectMeta = r.ObjectMeta
|
||||||
|
return cli.put(resourceDeviceToken, r.ObjectMeta.Name, newToken)
|
||||||
|
}
|
||||||
|
|
|
@ -675,7 +675,6 @@ type DeviceRequest struct {
|
||||||
DeviceCode string `json:"device_code,omitempty"`
|
DeviceCode string `json:"device_code,omitempty"`
|
||||||
CLientID string `json:"client_id,omitempty"`
|
CLientID string `json:"client_id,omitempty"`
|
||||||
Scopes []string `json:"scopes,omitempty"`
|
Scopes []string `json:"scopes,omitempty"`
|
||||||
PkceVerifier string `json:"pkce_verifier,omitempty"`
|
|
||||||
Expiry time.Time `json:"expiry"`
|
Expiry time.Time `json:"expiry"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -699,12 +698,21 @@ func (cli *client) fromStorageDeviceRequest(a storage.DeviceRequest) DeviceReque
|
||||||
DeviceCode: a.DeviceCode,
|
DeviceCode: a.DeviceCode,
|
||||||
CLientID: a.ClientID,
|
CLientID: a.ClientID,
|
||||||
Scopes: a.Scopes,
|
Scopes: a.Scopes,
|
||||||
PkceVerifier: a.PkceVerifier,
|
|
||||||
Expiry: a.Expiry,
|
Expiry: a.Expiry,
|
||||||
}
|
}
|
||||||
return req
|
return req
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toStorageDeviceRequest(req DeviceRequest) storage.DeviceRequest {
|
||||||
|
return storage.DeviceRequest{
|
||||||
|
UserCode: strings.ToUpper(req.ObjectMeta.Name),
|
||||||
|
DeviceCode: req.DeviceCode,
|
||||||
|
ClientID: req.CLientID,
|
||||||
|
Scopes: req.Scopes,
|
||||||
|
Expiry: req.Expiry,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// DeviceToken is a mirrored struct from storage with JSON struct tags and
|
// DeviceToken is a mirrored struct from storage with JSON struct tags and
|
||||||
// Kubernetes type metadata.
|
// Kubernetes type metadata.
|
||||||
type DeviceToken struct {
|
type DeviceToken struct {
|
||||||
|
@ -714,6 +722,8 @@ type DeviceToken struct {
|
||||||
Status string `json:"status,omitempty"`
|
Status string `json:"status,omitempty"`
|
||||||
Token string `json:"token,omitempty"`
|
Token string `json:"token,omitempty"`
|
||||||
Expiry time.Time `json:"expiry"`
|
Expiry time.Time `json:"expiry"`
|
||||||
|
LastRequestTime time.Time `json:"last_request"`
|
||||||
|
PollIntervalSeconds int `json:"poll_interval"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeviceTokenList is a list of DeviceTokens.
|
// DeviceTokenList is a list of DeviceTokens.
|
||||||
|
@ -736,6 +746,8 @@ func (cli *client) fromStorageDeviceToken(t storage.DeviceToken) DeviceToken {
|
||||||
Status: t.Status,
|
Status: t.Status,
|
||||||
Token: t.Token,
|
Token: t.Token,
|
||||||
Expiry: t.Expiry,
|
Expiry: t.Expiry,
|
||||||
|
LastRequestTime: t.LastRequestTime,
|
||||||
|
PollIntervalSeconds: t.PollIntervalSeconds,
|
||||||
}
|
}
|
||||||
return req
|
return req
|
||||||
}
|
}
|
||||||
|
@ -746,5 +758,7 @@ func toStorageDeviceToken(t DeviceToken) storage.DeviceToken {
|
||||||
Status: t.Status,
|
Status: t.Status,
|
||||||
Token: t.Token,
|
Token: t.Token,
|
||||||
Expiry: t.Expiry,
|
Expiry: t.Expiry,
|
||||||
|
LastRequestTime: t.LastRequestTime,
|
||||||
|
PollIntervalSeconds: t.PollIntervalSeconds,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -493,6 +493,17 @@ func (s *memStorage) CreateDeviceRequest(d storage.DeviceRequest) (err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *memStorage) GetDeviceRequest(userCode string) (req storage.DeviceRequest, err error) {
|
||||||
|
s.tx(func() {
|
||||||
|
var ok bool
|
||||||
|
if req, ok = s.deviceRequests[userCode]; !ok {
|
||||||
|
err = storage.ErrNotFound
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func (s *memStorage) CreateDeviceToken(t storage.DeviceToken) (err error) {
|
func (s *memStorage) CreateDeviceToken(t storage.DeviceToken) (err error) {
|
||||||
s.tx(func() {
|
s.tx(func() {
|
||||||
if _, ok := s.deviceTokens[t.DeviceCode]; ok {
|
if _, ok := s.deviceTokens[t.DeviceCode]; ok {
|
||||||
|
@ -514,3 +525,17 @@ func (s *memStorage) GetDeviceToken(deviceCode string) (t storage.DeviceToken, e
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *memStorage) UpdateDeviceToken(deviceCode string, updater func(p storage.DeviceToken) (storage.DeviceToken, error)) (err error) {
|
||||||
|
s.tx(func() {
|
||||||
|
r, ok := s.deviceTokens[deviceCode]
|
||||||
|
if !ok {
|
||||||
|
err = storage.ErrNotFound
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r, err = updater(r); err == nil {
|
||||||
|
s.deviceTokens[deviceCode] = r
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
@ -888,12 +888,12 @@ func (c *conn) delete(table, field, id string) error {
|
||||||
func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error {
|
func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error {
|
||||||
_, err := c.Exec(`
|
_, err := c.Exec(`
|
||||||
insert into device_request (
|
insert into device_request (
|
||||||
user_code, device_code, client_id, scopes, pkce_verifier, expiry
|
user_code, device_code, client_id, scopes, expiry
|
||||||
)
|
)
|
||||||
values (
|
values (
|
||||||
$1, $2, $3, $4, $5, $6
|
$1, $2, $3, $4, $5
|
||||||
);`,
|
);`,
|
||||||
d.UserCode, d.DeviceCode, d.ClientID, encoder(d.Scopes), d.PkceVerifier, d.Expiry,
|
d.UserCode, d.DeviceCode, d.ClientID, encoder(d.Scopes), d.Expiry,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if c.alreadyExistsCheck(err) {
|
if c.alreadyExistsCheck(err) {
|
||||||
|
@ -907,12 +907,12 @@ func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error {
|
||||||
func (c *conn) CreateDeviceToken(t storage.DeviceToken) error {
|
func (c *conn) CreateDeviceToken(t storage.DeviceToken) error {
|
||||||
_, err := c.Exec(`
|
_, err := c.Exec(`
|
||||||
insert into device_token (
|
insert into device_token (
|
||||||
device_code, status, token, expiry
|
device_code, status, token, expiry, last_request, poll_interval
|
||||||
)
|
)
|
||||||
values (
|
values (
|
||||||
$1, $2, $3, $4
|
$1, $2, $3, $4, $5, $6
|
||||||
);`,
|
);`,
|
||||||
t.DeviceCode, t.Status, t.Token, t.Expiry,
|
t.DeviceCode, t.Status, t.Token, t.Expiry, t.LastRequestTime, t.PollIntervalSeconds,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if c.alreadyExistsCheck(err) {
|
if c.alreadyExistsCheck(err) {
|
||||||
|
@ -923,6 +923,28 @@ func (c *conn) CreateDeviceToken(t storage.DeviceToken) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *conn) GetDeviceRequest(userCode string) (storage.DeviceRequest, error) {
|
||||||
|
return getDeviceRequest(c, userCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getDeviceRequest(q querier, userCode string) (d storage.DeviceRequest, err error) {
|
||||||
|
err = q.QueryRow(`
|
||||||
|
select
|
||||||
|
device_code, client_id, scopes, expiry
|
||||||
|
from device_request where user_code = $1;
|
||||||
|
`, userCode).Scan(
|
||||||
|
&d.DeviceCode, &d.ClientID, decoder(&d.Scopes), &d.Expiry,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return d, storage.ErrNotFound
|
||||||
|
}
|
||||||
|
return d, fmt.Errorf("select device token: %v", err)
|
||||||
|
}
|
||||||
|
d.UserCode = userCode
|
||||||
|
return d, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *conn) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) {
|
func (c *conn) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) {
|
||||||
return getDeviceToken(c, deviceCode)
|
return getDeviceToken(c, deviceCode)
|
||||||
}
|
}
|
||||||
|
@ -930,10 +952,10 @@ func (c *conn) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) {
|
||||||
func getDeviceToken(q querier, deviceCode string) (a storage.DeviceToken, err error) {
|
func getDeviceToken(q querier, deviceCode string) (a storage.DeviceToken, err error) {
|
||||||
err = q.QueryRow(`
|
err = q.QueryRow(`
|
||||||
select
|
select
|
||||||
status, token, expiry
|
status, token, expiry, last_request, poll_interval
|
||||||
from device_token where device_code = $1;
|
from device_token where device_code = $1;
|
||||||
`, deviceCode).Scan(
|
`, deviceCode).Scan(
|
||||||
&a.Status, &a.Token, &a.Expiry,
|
&a.Status, &a.Token, &a.Expiry, &a.LastRequestTime, &a.PollIntervalSeconds,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
|
@ -944,3 +966,31 @@ func getDeviceToken(q querier, deviceCode string) (a storage.DeviceToken, err er
|
||||||
a.DeviceCode = deviceCode
|
a.DeviceCode = deviceCode
|
||||||
return a, nil
|
return a, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *conn) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
|
||||||
|
return c.ExecTx(func(tx *trans) error {
|
||||||
|
r, err := getDeviceToken(tx, deviceCode)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if r, err = updater(r); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = tx.Exec(`
|
||||||
|
update device_token
|
||||||
|
set
|
||||||
|
status = $1,
|
||||||
|
token = $2,
|
||||||
|
last_request = $3,
|
||||||
|
poll_interval = $4
|
||||||
|
where
|
||||||
|
device_code = $5
|
||||||
|
`,
|
||||||
|
r.Status, r.Token, r.LastRequestTime, r.PollIntervalSeconds, r.DeviceCode,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("update device token: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -236,7 +236,6 @@ var migrations = []migration{
|
||||||
device_code text not null,
|
device_code text not null,
|
||||||
client_id text not null,
|
client_id text not null,
|
||||||
scopes bytea not null, -- JSON array of strings
|
scopes bytea not null, -- JSON array of strings
|
||||||
pkce_verifier text not null,
|
|
||||||
expiry timestamptz not null
|
expiry timestamptz not null
|
||||||
);`,
|
);`,
|
||||||
`
|
`
|
||||||
|
@ -244,7 +243,9 @@ var migrations = []migration{
|
||||||
device_code text not null primary key,
|
device_code text not null primary key,
|
||||||
status text not null,
|
status text not null,
|
||||||
token text,
|
token text,
|
||||||
expiry timestamptz not null
|
expiry timestamptz not null,
|
||||||
|
last_request timestamptz not null,
|
||||||
|
poll_interval integer not null
|
||||||
);`,
|
);`,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
|
@ -82,6 +82,7 @@ type Storage interface {
|
||||||
GetPassword(email string) (Password, error)
|
GetPassword(email string) (Password, error)
|
||||||
GetOfflineSessions(userID string, connID string) (OfflineSessions, error)
|
GetOfflineSessions(userID string, connID string) (OfflineSessions, error)
|
||||||
GetConnector(id string) (Connector, error)
|
GetConnector(id string) (Connector, error)
|
||||||
|
GetDeviceRequest(userCode string) (DeviceRequest, error)
|
||||||
GetDeviceToken(deviceCode string) (DeviceToken, error)
|
GetDeviceToken(deviceCode string) (DeviceToken, error)
|
||||||
|
|
||||||
ListClients() ([]Client, error)
|
ListClients() ([]Client, error)
|
||||||
|
@ -119,6 +120,7 @@ type Storage interface {
|
||||||
UpdatePassword(email string, updater func(p Password) (Password, error)) error
|
UpdatePassword(email string, updater func(p Password) (Password, error)) error
|
||||||
UpdateOfflineSessions(userID string, connID string, updater func(s OfflineSessions) (OfflineSessions, error)) error
|
UpdateOfflineSessions(userID string, connID string, updater func(s OfflineSessions) (OfflineSessions, error)) error
|
||||||
UpdateConnector(id string, updater func(c Connector) (Connector, error)) error
|
UpdateConnector(id string, updater func(c Connector) (Connector, error)) error
|
||||||
|
UpdateDeviceToken(deviceCode string, updater func(t DeviceToken) (DeviceToken, error)) error
|
||||||
|
|
||||||
// GarbageCollect deletes all expired AuthCodes,AuthRequests, DeviceRequests, and DeviceTokens.
|
// GarbageCollect deletes all expired AuthCodes,AuthRequests, DeviceRequests, and DeviceTokens.
|
||||||
GarbageCollect(now time.Time) (GCResult, error)
|
GarbageCollect(now time.Time) (GCResult, error)
|
||||||
|
@ -392,8 +394,6 @@ type DeviceRequest struct {
|
||||||
ClientID string
|
ClientID string
|
||||||
//The scopes the device requests
|
//The scopes the device requests
|
||||||
Scopes []string
|
Scopes []string
|
||||||
//PKCE Verification
|
|
||||||
PkceVerifier string
|
|
||||||
//The expire time
|
//The expire time
|
||||||
Expiry time.Time
|
Expiry time.Time
|
||||||
}
|
}
|
||||||
|
@ -403,4 +403,6 @@ type DeviceToken struct {
|
||||||
Status string
|
Status string
|
||||||
Token string
|
Token string
|
||||||
Expiry time.Time
|
Expiry time.Time
|
||||||
|
LastRequestTime time.Time
|
||||||
|
PollIntervalSeconds int
|
||||||
}
|
}
|
||||||
|
|
23
web/templates/device.html
Normal file
23
web/templates/device.html
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
{{ template "header.html" . }}
|
||||||
|
|
||||||
|
<div class="theme-panel">
|
||||||
|
<h2 class="theme-heading">Enter User Code</h2>
|
||||||
|
<form method="post" action="{{ .PostURL }}" method="get">
|
||||||
|
<div class="theme-form-row">
|
||||||
|
{{ if( .UserCode )}}
|
||||||
|
<input tabindex="2" required id="user_code" name="user_code" type="text" class="theme-form-input" autocomplete="off" value="{{.UserCode}}" {{ if .Invalid }} autofocus {{ end }}/>
|
||||||
|
{{ else }}
|
||||||
|
<input tabindex="2" required id="user_code" name="user_code" type="text" class="theme-form-input" placeholder="XXXX-XXXX" autocomplete="off" {{ if .Invalid }} autofocus {{ end }}/>
|
||||||
|
{{ end }}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{{ if .Invalid }}
|
||||||
|
<div id="login-error" class="dex-error-box">
|
||||||
|
Invalid or Expired User Code
|
||||||
|
</div>
|
||||||
|
{{ end }}
|
||||||
|
<button tabindex="3" id="submit-login" type="submit" class="dex-btn theme-btn--primary">Submit</button>
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{{ template "footer.html" . }}
|
8
web/templates/device_success.html
Normal file
8
web/templates/device_success.html
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
{{ template "header.html" . }}
|
||||||
|
|
||||||
|
<div class="theme-panel">
|
||||||
|
<h2 class="theme-heading">Login Successful for {{ .ClientName }}</h2>
|
||||||
|
<p>Return to your device to continue</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{{ template "footer.html" . }}
|
Loading…
Reference in a new issue