Device flow token code exchange (#2)

* Added /device/token handler with associated business logic and storage tests.

Perform user code exchange, flag the device code as complete.

Moved device handler code into its own file for cleanliness.  Cleanup

* Removed PKCE code

* Rate limiting for /device/token endpoint based on ietf standards

* Configurable Device expiry

Signed-off-by: justin-slowik <justin.slowik@thermofisher.com>
This commit is contained in:
Justin Slowik 2020-01-28 14:14:30 -05:00 committed by justin-slowik
parent 0d1a0e4129
commit 9bbdc721d5
20 changed files with 777 additions and 274 deletions

View file

@ -283,6 +283,9 @@ type Expiry struct {
// AuthRequests defines the duration of time for which the AuthRequests will be valid. // AuthRequests defines the duration of time for which the AuthRequests will be valid.
AuthRequests string `json:"authRequests"` AuthRequests string `json:"authRequests"`
// DeviceRequests defines the duration of time for which the DeviceRequests will be valid.
DeviceRequests string `json:"deviceRequests"`
} }
// Logger holds configuration required to customize logging for dex. // Logger holds configuration required to customize logging for dex.

View file

@ -119,6 +119,7 @@ expiry:
signingKeys: "7h" signingKeys: "7h"
idTokens: "25h" idTokens: "25h"
authRequests: "25h" authRequests: "25h"
deviceRequests: "10m"
logger: logger:
level: "debug" level: "debug"
@ -200,6 +201,7 @@ logger:
SigningKeys: "7h", SigningKeys: "7h",
IDTokens: "25h", IDTokens: "25h",
AuthRequests: "25h", AuthRequests: "25h",
DeviceRequests: "10m",
}, },
Logger: Logger{ Logger: Logger{
Level: "debug", Level: "debug",

View file

@ -269,7 +269,14 @@ func serve(cmd *cobra.Command, args []string) error {
logger.Infof("config auth requests valid for: %v", authRequests) logger.Infof("config auth requests valid for: %v", authRequests)
serverConfig.AuthRequestsValidFor = authRequests serverConfig.AuthRequestsValidFor = authRequests
} }
if c.Expiry.DeviceRequests != "" {
deviceRequests, err := time.ParseDuration(c.Expiry.DeviceRequests)
if err != nil {
return fmt.Errorf("invalid config value %q for device request expiry: %v", c.Expiry.AuthRequests, err)
}
logger.Infof("config device requests valid for: %v", deviceRequests)
serverConfig.DeviceRequestsValidFor = deviceRequests
}
serv, err := server.NewServer(context.Background(), serverConfig) serv, err := server.NewServer(context.Background(), serverConfig)
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize server: %v", err) return fmt.Errorf("failed to initialize server: %v", err)

View file

@ -64,6 +64,7 @@ telemetry:
# Uncomment this block to enable configuration for the expiration time durations. # Uncomment this block to enable configuration for the expiration time durations.
# expiry: # expiry:
# deviceRequests: "5m"
# signingKeys: "6h" # signingKeys: "6h"
# idTokens: "24h" # idTokens: "24h"

359
server/deviceHandlers.go Normal file
View file

@ -0,0 +1,359 @@
package server
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"path"
"strconv"
"strings"
"time"
"github.com/dexidp/dex/storage"
)
type deviceCodeResponse struct {
//The unique device code for device authentication
DeviceCode string `json:"device_code"`
//The code the user will exchange via a browser and log in
UserCode string `json:"user_code"`
//The url to verify the user code.
VerificationURI string `json:"verification_uri"`
//The verification uri with the user code appended for pre-filling form
VerificationURIComplete string `json:"verification_uri_complete"`
//The lifetime of the device code
ExpireTime int `json:"expires_in"`
//How often the device is allowed to poll to verify that the user login occurred
PollInterval int `json:"interval"`
}
func (s *Server) getDeviceAuthURI() string {
return path.Join(s.issuerURL.Path, "/device/auth/verify_code")
}
func (s *Server) handleDeviceExchange(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
userCode := r.URL.Query().Get("user_code")
invalidAttempt, err := strconv.ParseBool(r.URL.Query().Get("invalid"))
if err != nil {
invalidAttempt = false
}
if err := s.templates.device(r, w, s.getDeviceAuthURI(), userCode, invalidAttempt); err != nil {
s.logger.Errorf("Server template error: %v", err)
}
default:
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
}
}
func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
pollIntervalSeconds := 5
switch r.Method {
case http.MethodPost:
err := r.ParseForm()
if err != nil {
s.logger.Errorf("Could not parse Device Request body: %v", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusNotFound)
return
}
//Get the client id and scopes from the post
clientID := r.Form.Get("client_id")
scopes := r.Form["scope"]
s.logger.Infof("Received device request for client %v with scopes %v", clientID, scopes)
//Make device code
deviceCode := storage.NewDeviceCode()
//make user code
userCode, err := storage.NewUserCode()
if err != nil {
s.logger.Errorf("Error generating user code: %v", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
}
//Generate the expire time
expireTime := time.Now().Add(s.deviceRequestsValidFor)
//Store the Device Request
deviceReq := storage.DeviceRequest{
UserCode: userCode,
DeviceCode: deviceCode,
ClientID: clientID,
Scopes: scopes,
Expiry: expireTime,
}
if err := s.storage.CreateDeviceRequest(deviceReq); err != nil {
s.logger.Errorf("Failed to store device request; %v", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
return
}
//Store the device token
deviceToken := storage.DeviceToken{
DeviceCode: deviceCode,
Status: deviceTokenPending,
Expiry: expireTime,
LastRequestTime: time.Now(),
PollIntervalSeconds: 5,
}
if err := s.storage.CreateDeviceToken(deviceToken); err != nil {
s.logger.Errorf("Failed to store device token %v", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
return
}
u, err := url.Parse(s.issuerURL.String())
if err != nil {
s.logger.Errorf("Could not parse issuer URL %v", err)
s.renderError(r, w, http.StatusInternalServerError, "")
return
}
u.Path = path.Join(u.Path, "device")
vURI := u.String()
q := u.Query()
q.Set("user_code", userCode)
u.RawQuery = q.Encode()
vURIComplete := u.String()
code := deviceCodeResponse{
DeviceCode: deviceCode,
UserCode: userCode,
VerificationURI: vURI,
VerificationURIComplete: vURIComplete,
ExpireTime: int(s.deviceRequestsValidFor.Seconds()),
PollInterval: pollIntervalSeconds,
}
enc := json.NewEncoder(w)
enc.SetIndent("", " ")
enc.Encode(code)
default:
s.renderError(r, w, http.StatusBadRequest, "Invalid device code request type")
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
}
}
func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
switch r.Method {
case http.MethodPost:
err := r.ParseForm()
if err != nil {
s.logger.Warnf("Could not parse Device Token Request body: %v", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
return
}
deviceCode := r.Form.Get("device_code")
if deviceCode == "" {
s.tokenErrHelper(w, errInvalidRequest, "No device code received", http.StatusBadRequest)
return
}
grantType := r.PostFormValue("grant_type")
if grantType != grantTypeDeviceCode {
s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest)
return
}
now := s.now()
//Grab the device token
deviceToken, err := s.storage.GetDeviceToken(deviceCode)
if err != nil || now.After(deviceToken.Expiry) {
if err != storage.ErrNotFound {
s.logger.Errorf("failed to get device code: %v", err)
}
s.tokenErrHelper(w, errInvalidRequest, "Invalid or expired device code.", http.StatusBadRequest)
return
}
//Rate Limiting check
pollInterval := deviceToken.PollIntervalSeconds
minRequestTime := deviceToken.LastRequestTime.Add(time.Second * time.Duration(pollInterval))
if now.Before(minRequestTime) {
s.tokenErrHelper(w, deviceTokenSlowDown, "", http.StatusBadRequest)
//Continually increase the poll interval until the user waits the proper time
pollInterval += 5
} else {
pollInterval = 5
}
switch deviceToken.Status {
case deviceTokenPending:
updater := func(old storage.DeviceToken) (storage.DeviceToken, error) {
old.PollIntervalSeconds = pollInterval
old.LastRequestTime = now
return old, nil
}
// Update device token last request time in storage
if err := s.storage.UpdateDeviceToken(deviceCode, updater); err != nil {
s.logger.Errorf("failed to update device token: %v", err)
s.renderError(r, w, http.StatusInternalServerError, "")
return
}
s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized)
case deviceTokenComplete:
w.Write([]byte(deviceToken.Token))
}
default:
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
}
}
func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
userCode := r.FormValue("state")
code := r.FormValue("code")
if userCode == "" || code == "" {
s.renderError(r, w, http.StatusBadRequest, "Request was missing parameters")
return
}
// Authorization redirect callback from OAuth2 auth flow.
if errMsg := r.FormValue("error"); errMsg != "" {
http.Error(w, errMsg+": "+r.FormValue("error_description"), http.StatusBadRequest)
return
}
authCode, err := s.storage.GetAuthCode(code)
if err != nil || s.now().After(authCode.Expiry) {
if err != storage.ErrNotFound {
s.logger.Errorf("failed to get auth code: %v", err)
}
s.renderError(r, w, http.StatusBadRequest, "Invalid or expired auth code.")
return
}
//Grab the device request from storage
deviceReq, err := s.storage.GetDeviceRequest(userCode)
if err != nil || s.now().After(deviceReq.Expiry) {
if err != storage.ErrNotFound {
s.logger.Errorf("failed to get device code: %v", err)
}
s.renderError(r, w, http.StatusInternalServerError, "Invalid or expired device code.")
return
}
reqClient, err := s.storage.GetClient(deviceReq.ClientID)
if err != nil {
s.logger.Errorf("Failed to get reqClient %q: %v", deviceReq.ClientID, err)
s.renderError(r, w, http.StatusInternalServerError, "Failed to retrieve device client.")
return
}
resp, err := s.exchangeAuthCode(w, authCode, reqClient)
if err != nil {
s.logger.Errorf("Could not exchange auth code for client %q: %v", deviceReq.ClientID, err)
s.renderError(r, w, http.StatusInternalServerError, "Failed to exchange auth code.")
return
}
//Grab the device request from storage
old, err := s.storage.GetDeviceToken(deviceReq.DeviceCode)
if err != nil || s.now().After(old.Expiry) {
if err != storage.ErrNotFound {
s.logger.Errorf("failed to get device token: %v", err)
}
s.renderError(r, w, http.StatusInternalServerError, "Invalid or expired device code.")
return
}
updater := func(old storage.DeviceToken) (storage.DeviceToken, error) {
if old.Status == deviceTokenComplete {
return old, errors.New("device token already complete")
}
respStr, err := json.MarshalIndent(resp, "", " ")
if err != nil {
s.logger.Errorf("failed to marshal device token response: %v", err)
s.renderError(r, w, http.StatusInternalServerError, "")
return old, err
}
old.Token = string(respStr)
old.Status = deviceTokenComplete
return old, nil
}
// Update refresh token in the storage, store the token and mark as complete
if err := s.storage.UpdateDeviceToken(deviceReq.DeviceCode, updater); err != nil {
s.logger.Errorf("failed to update device token: %v", err)
s.renderError(r, w, http.StatusInternalServerError, "")
return
}
if err := s.templates.deviceSuccess(r, w, reqClient.Name); err != nil {
s.logger.Errorf("Server template error: %v", err)
}
default:
http.Error(w, fmt.Sprintf("method not implemented: %s", r.Method), http.StatusBadRequest)
return
}
}
func (s *Server) verifyUserCode(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodPost:
err := r.ParseForm()
if err != nil {
message := "Could not parse user code verification Request body"
s.logger.Warnf("%s : %v", message, err)
s.tokenErrHelper(w, errInvalidRequest, message, http.StatusBadRequest)
return
}
userCode := r.Form.Get("user_code")
if userCode == "" {
s.renderError(r, w, http.StatusBadRequest, "No user code received")
return
}
userCode = strings.ToUpper(userCode)
//Find the user code in the available requests
deviceRequest, err := s.storage.GetDeviceRequest(userCode)
if err != nil || s.now().After(deviceRequest.Expiry) {
if err != storage.ErrNotFound {
s.logger.Errorf("failed to get device request: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
}
if err := s.templates.device(r, w, s.getDeviceAuthURI(), userCode, true); err != nil {
s.logger.Errorf("Server template error: %v", err)
}
return
}
//Redirect to Dex Auth Endpoint
authURL := path.Join(s.issuerURL.Path, "/auth")
u, err := url.Parse(authURL)
if err != nil {
s.renderError(r, w, http.StatusInternalServerError, "Invalid auth URI.")
return
}
q := u.Query()
q.Set("client_id", deviceRequest.ClientID)
q.Set("state", deviceRequest.UserCode)
q.Set("response_type", "code")
q.Set("redirect_uri", path.Join(s.issuerURL.Path, "/device/callback"))
q.Set("scope", strings.Join(deviceRequest.Scopes, " "))
u.RawQuery = q.Encode()
http.Redirect(w, r, u.String(), http.StatusFound)
default:
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
}
}

View file

@ -152,6 +152,8 @@ type discovery struct {
Token string `json:"token_endpoint"` Token string `json:"token_endpoint"`
Keys string `json:"jwks_uri"` Keys string `json:"jwks_uri"`
UserInfo string `json:"userinfo_endpoint"` UserInfo string `json:"userinfo_endpoint"`
DeviceEndpoint string `json:"device_authorization_endpoint"`
GrantTypes []string `json:"grant_types_supported"'`
ResponseTypes []string `json:"response_types_supported"` ResponseTypes []string `json:"response_types_supported"`
Subjects []string `json:"subject_types_supported"` Subjects []string `json:"subject_types_supported"`
IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"` IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"`
@ -167,7 +169,9 @@ func (s *Server) discoveryHandler() (http.HandlerFunc, error) {
Token: s.absURL("/token"), Token: s.absURL("/token"),
Keys: s.absURL("/keys"), Keys: s.absURL("/keys"),
UserInfo: s.absURL("/userinfo"), UserInfo: s.absURL("/userinfo"),
DeviceEndpoint: s.absURL("/device/code"),
Subjects: []string{"public"}, Subjects: []string{"public"},
GrantTypes: []string{grantTypeAuthorizationCode, grantTypeRefreshToken, grantTypeDeviceCode},
IDTokenAlgs: []string{string(jose.RS256)}, IDTokenAlgs: []string{string(jose.RS256)},
Scopes: []string{"openid", "email", "groups", "profile", "offline_access"}, Scopes: []string{"openid", "email", "groups", "profile", "offline_access"},
AuthMethods: []string{"client_secret_basic"}, AuthMethods: []string{"client_secret_basic"},
@ -783,24 +787,33 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
return return
} }
tokenResponse, err := s.exchangeAuthCode(w, authCode, client)
if err != nil {
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
s.writeAccessToken(w, tokenResponse)
}
func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCode, client storage.Client) (*accessTokenReponse, error) {
accessToken, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID) accessToken, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID)
if err != nil { if err != nil {
s.logger.Errorf("failed to create new access token: %v", err) s.logger.Errorf("failed to create new access token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return return nil, err
} }
idToken, expiry, err := s.newIDToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, accessToken, authCode.ConnectorID) idToken, expiry, err := s.newIDToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, accessToken, authCode.ConnectorID)
if err != nil { if err != nil {
s.logger.Errorf("failed to create ID token: %v", err) s.logger.Errorf("failed to create ID token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return return nil, err
} }
if err := s.storage.DeleteAuthCode(code); err != nil { if err := s.storage.DeleteAuthCode(authCode.ID); err != nil {
s.logger.Errorf("failed to delete auth code: %v", err) s.logger.Errorf("failed to delete auth code: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return return nil, err
} }
reqRefresh := func() bool { reqRefresh := func() bool {
@ -847,13 +860,13 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
if refreshToken, err = internal.Marshal(token); err != nil { if refreshToken, err = internal.Marshal(token); err != nil {
s.logger.Errorf("failed to marshal refresh token: %v", err) s.logger.Errorf("failed to marshal refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return return nil, err
} }
if err := s.storage.CreateRefresh(refresh); err != nil { if err := s.storage.CreateRefresh(refresh); err != nil {
s.logger.Errorf("failed to create refresh token: %v", err) s.logger.Errorf("failed to create refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return return nil, err
} }
// deleteToken determines if we need to delete the newly created refresh token // deleteToken determines if we need to delete the newly created refresh token
@ -884,7 +897,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
s.logger.Errorf("failed to get offline session: %v", err) s.logger.Errorf("failed to get offline session: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true deleteToken = true
return return nil, err
} }
offlineSessions := storage.OfflineSessions{ offlineSessions := storage.OfflineSessions{
UserID: refresh.Claims.UserID, UserID: refresh.Claims.UserID,
@ -899,7 +912,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
s.logger.Errorf("failed to create offline session: %v", err) s.logger.Errorf("failed to create offline session: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true deleteToken = true
return return nil, err
} }
} else { } else {
if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok { if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok {
@ -908,7 +921,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
s.logger.Errorf("failed to delete refresh token: %v", err) s.logger.Errorf("failed to delete refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true deleteToken = true
return return nil, err
} }
} }
@ -920,11 +933,11 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
s.logger.Errorf("failed to update offline session: %v", err) s.logger.Errorf("failed to update offline session: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true deleteToken = true
return return nil, err
} }
} }
} }
s.writeAccessToken(w, idToken, accessToken, refreshToken, expiry) return s.toAccessTokenResponse(idToken, accessToken, refreshToken, expiry), nil
} }
// handle a refresh token request https://tools.ietf.org/html/rfc6749#section-6 // handle a refresh token request https://tools.ietf.org/html/rfc6749#section-6
@ -1120,7 +1133,8 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
return return
} }
s.writeAccessToken(w, idToken, accessToken, rawNewToken, expiry) resp := s.toAccessTokenResponse(idToken, accessToken, rawNewToken, expiry)
s.writeAccessToken(w, resp)
} }
func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) { func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
@ -1378,12 +1392,26 @@ func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, accessToken, r
RefreshToken string `json:"refresh_token,omitempty"` RefreshToken string `json:"refresh_token,omitempty"`
IDToken string `json:"id_token"` IDToken string `json:"id_token"`
}{ }{
type accessTokenReponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
IDToken string `json:"id_token"`
}
func (s *Server) toAccessTokenResponse(idToken, accessToken, refreshToken string, expiry time.Time) *accessTokenReponse {
return &accessTokenReponse{
accessToken, accessToken,
"bearer", "bearer",
int(expiry.Sub(s.now()).Seconds()), int(expiry.Sub(s.now()).Seconds()),
refreshToken, refreshToken,
idToken, idToken,
} }
}
func (s *Server) writeAccessToken(w http.ResponseWriter, resp *accessTokenReponse) {
data, err := json.Marshal(resp) data, err := json.Marshal(resp)
if err != nil { if err != nil {
s.logger.Errorf("failed to marshal access token response: %v", err) s.logger.Errorf("failed to marshal access token response: %v", err)
@ -1414,145 +1442,3 @@ func usernamePrompt(conn connector.PasswordConnector) string {
} }
return "Username" return "Username"
} }
type deviceCodeResponse struct {
//The unique device code for device authentication
DeviceCode string `json:"device_code"`
//The code the user will exchange via a browser and log in
UserCode string `json:"user_code"`
//The url to verify the user code.
VerificationURI string `json:"verification_uri"`
//The lifetime of the device code
ExpireTime int `json:"expires_in"`
//How often the device is allowed to poll to verify that the user login occurred
PollInterval int `json:"interval"`
}
func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
//TODO replace with configurable values
expireIntervalSeconds := 300
requestsPerMinute := 5
switch r.Method {
case http.MethodPost:
err := r.ParseForm()
if err != nil {
s.logger.Errorf("Could not parse Device Request body: %v", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusNotFound)
return
}
//Get the client id and scopes from the post
clientID := r.Form.Get("client_id")
scopes := r.Form["scope"]
s.logger.Infof("Received device request for client %v with scopes %v", clientID, scopes)
//Make device code
deviceCode := storage.NewDeviceCode()
//make user code
userCode, err := storage.NewUserCode()
if err != nil {
s.logger.Errorf("Error generating user code: %v", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
}
//make a pkce verification code
pkceCode := storage.NewID()
//Generate the expire time
expireTime := time.Now().Add(time.Second * time.Duration(expireIntervalSeconds))
//Store the Device Request
deviceReq := storage.DeviceRequest{
UserCode: userCode,
DeviceCode: deviceCode,
ClientID: clientID,
Scopes: scopes,
PkceVerifier: pkceCode,
Expiry: expireTime,
}
if err := s.storage.CreateDeviceRequest(deviceReq); err != nil {
s.logger.Errorf("Failed to store device request; %v", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
return
}
//Store the device token
deviceToken := storage.DeviceToken{
DeviceCode: deviceCode,
Status: deviceTokenPending,
Expiry: expireTime,
}
if err := s.storage.CreateDeviceToken(deviceToken); err != nil {
s.logger.Errorf("Failed to store device token %v", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
return
}
code := deviceCodeResponse{
DeviceCode: deviceCode,
UserCode: userCode,
VerificationURI: path.Join(s.issuerURL.String(), "/device"),
ExpireTime: expireIntervalSeconds,
PollInterval: requestsPerMinute,
}
enc := json.NewEncoder(w)
enc.SetIndent("", " ")
enc.Encode(code)
default:
s.renderError(r, w, http.StatusBadRequest, "Invalid device code request type")
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
}
}
func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
switch r.Method {
case http.MethodPost:
err := r.ParseForm()
if err != nil {
message := "Could not parse Device Token Request body"
s.logger.Warnf("%s : %v", message, err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
return
}
deviceCode := r.Form.Get("device_code")
if deviceCode == "" {
message := "No device code received"
s.tokenErrHelper(w, errInvalidRequest, message, http.StatusBadRequest)
return
}
grantType := r.PostFormValue("grant_type")
if grantType != grantTypeDeviceCode {
s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest)
return
}
//Grab the device token from the db
deviceToken, err := s.storage.GetDeviceToken(deviceCode)
if err != nil || s.now().After(deviceToken.Expiry) {
if err != storage.ErrNotFound {
s.logger.Errorf("failed to get device code: %v", err)
}
s.tokenErrHelper(w, errInvalidRequest, "Invalid or expired device code.", http.StatusBadRequest)
return
}
switch deviceToken.Status {
case deviceTokenPending:
s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized)
case deviceTokenComplete:
w.Write([]byte(deviceToken.Token))
}
default:
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
}
}

View file

@ -122,7 +122,7 @@ const (
grantTypeAuthorizationCode = "authorization_code" grantTypeAuthorizationCode = "authorization_code"
grantTypeRefreshToken = "refresh_token" grantTypeRefreshToken = "refresh_token"
grantTypePassword = "password" grantTypePassword = "password"
grantTypeDeviceCode = "device_code" grantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code"
) )
const ( const (
@ -134,6 +134,8 @@ const (
const ( const (
deviceTokenPending = "authorization_pending" deviceTokenPending = "authorization_pending"
deviceTokenComplete = "complete" deviceTokenComplete = "complete"
deviceTokenSlowDown = "slow_down"
deviceTokenExpired = "expired_token"
) )
func parseScopes(scopes []string) connector.Scopes { func parseScopes(scopes []string) connector.Scopes {

View file

@ -78,6 +78,7 @@ type Config struct {
RotateKeysAfter time.Duration // Defaults to 6 hours. RotateKeysAfter time.Duration // Defaults to 6 hours.
IDTokensValidFor time.Duration // Defaults to 24 hours IDTokensValidFor time.Duration // Defaults to 24 hours
AuthRequestsValidFor time.Duration // Defaults to 24 hours AuthRequestsValidFor time.Duration // Defaults to 24 hours
DeviceRequestsValidFor time.Duration // Defaults to 5 minutes
// If set, the server will use this connector to handle password grants // If set, the server will use this connector to handle password grants
PasswordConnector string PasswordConnector string
@ -158,6 +159,7 @@ type Server struct {
idTokensValidFor time.Duration idTokensValidFor time.Duration
authRequestsValidFor time.Duration authRequestsValidFor time.Duration
deviceRequestsValidFor time.Duration
logger log.Logger logger log.Logger
} }
@ -219,6 +221,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
supportedResponseTypes: supported, supportedResponseTypes: supported,
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour), idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
authRequestsValidFor: value(c.AuthRequestsValidFor, 24*time.Hour), authRequestsValidFor: value(c.AuthRequestsValidFor, 24*time.Hour),
deviceRequestsValidFor: value(c.DeviceRequestsValidFor, 5*time.Minute),
skipApproval: c.SkipApprovalScreen, skipApproval: c.SkipApprovalScreen,
alwaysShowLogin: c.AlwaysShowLoginScreen, alwaysShowLogin: c.AlwaysShowLoginScreen,
now: now, now: now,
@ -302,8 +305,11 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
handleWithCORS("/userinfo", s.handleUserInfo) handleWithCORS("/userinfo", s.handleUserInfo)
handleFunc("/auth", s.handleAuthorization) handleFunc("/auth", s.handleAuthorization)
handleFunc("/auth/{connector}", s.handleConnectorLogin) handleFunc("/auth/{connector}", s.handleConnectorLogin)
handleFunc("/device", s.handleDeviceExchange)
handleFunc("/device/auth/verify_code", s.verifyUserCode)
handleFunc("/device/code", s.handleDeviceCode) handleFunc("/device/code", s.handleDeviceCode)
handleFunc("/device/token", s.handleDeviceToken) handleFunc("/device/token", s.handleDeviceToken)
handleFunc("/device/callback", s.handleDeviceCallback)
r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) { r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) {
// Strip the X-Remote-* headers to prevent security issues on // Strip the X-Remote-* headers to prevent security issues on
// misconfigured authproxy connector setups. // misconfigured authproxy connector setups.

View file

@ -20,6 +20,8 @@ const (
tmplPassword = "password.html" tmplPassword = "password.html"
tmplOOB = "oob.html" tmplOOB = "oob.html"
tmplError = "error.html" tmplError = "error.html"
tmplDevice = "device.html"
tmplDeviceSuccess = "device_success.html"
) )
var requiredTmpls = []string{ var requiredTmpls = []string{
@ -28,6 +30,7 @@ var requiredTmpls = []string{
tmplPassword, tmplPassword,
tmplOOB, tmplOOB,
tmplError, tmplError,
tmplDevice,
} }
type templates struct { type templates struct {
@ -36,6 +39,8 @@ type templates struct {
passwordTmpl *template.Template passwordTmpl *template.Template
oobTmpl *template.Template oobTmpl *template.Template
errorTmpl *template.Template errorTmpl *template.Template
deviceTmpl *template.Template
deviceSuccessTmpl *template.Template
} }
type webConfig struct { type webConfig struct {
@ -157,6 +162,8 @@ func loadTemplates(c webConfig, templatesDir string) (*templates, error) {
passwordTmpl: tmpls.Lookup(tmplPassword), passwordTmpl: tmpls.Lookup(tmplPassword),
oobTmpl: tmpls.Lookup(tmplOOB), oobTmpl: tmpls.Lookup(tmplOOB),
errorTmpl: tmpls.Lookup(tmplError), errorTmpl: tmpls.Lookup(tmplError),
deviceTmpl: tmpls.Lookup(tmplDevice),
deviceSuccessTmpl: tmpls.Lookup(tmplDeviceSuccess),
}, nil }, nil
} }
@ -242,6 +249,24 @@ func (n byName) Len() int { return len(n) }
func (n byName) Less(i, j int) bool { return n[i].Name < n[j].Name } func (n byName) Less(i, j int) bool { return n[i].Name < n[j].Name }
func (n byName) Swap(i, j int) { n[i], n[j] = n[j], n[i] } func (n byName) Swap(i, j int) { n[i], n[j] = n[j], n[i] }
func (t *templates) device(r *http.Request, w http.ResponseWriter, postURL string, userCode string, lastWasInvalid bool) error {
data := struct {
PostURL string
UserCode string
Invalid bool
ReqPath string
}{postURL, userCode, lastWasInvalid, r.URL.Path}
return renderTemplate(w, t.deviceTmpl, data)
}
func (t *templates) deviceSuccess(r *http.Request, w http.ResponseWriter, clientName string) error {
data := struct {
ClientName string
ReqPath string
}{clientName, r.URL.Path}
return renderTemplate(w, t.deviceSuccessTmpl, data)
}
func (t *templates) login(r *http.Request, w http.ResponseWriter, connectors []connectorInfo, reqPath string) error { func (t *templates) login(r *http.Request, w http.ResponseWriter, connectors []connectorInfo, reqPath string) error {
sort.Sort(byName(connectors)) sort.Sort(byName(connectors))
data := struct { data := struct {

View file

@ -847,7 +847,6 @@ func testGC(t *testing.T, s storage.Storage) {
DeviceCode: storage.NewID(), DeviceCode: storage.NewID(),
ClientID: "client1", ClientID: "client1",
Scopes: []string{"openid", "email"}, Scopes: []string{"openid", "email"},
PkceVerifier: storage.NewID(),
Expiry: expiry, Expiry: expiry,
} }
@ -974,7 +973,6 @@ func testDeviceRequestCRUD(t *testing.T, s storage.Storage) {
DeviceCode: storage.NewID(), DeviceCode: storage.NewID(),
ClientID: "client1", ClientID: "client1",
Scopes: []string{"openid", "email"}, Scopes: []string{"openid", "email"},
PkceVerifier: storage.NewID(),
Expiry: neverExpire, Expiry: neverExpire,
} }
@ -991,20 +989,44 @@ func testDeviceRequestCRUD(t *testing.T, s storage.Storage) {
} }
func testDeviceTokenCRUD(t *testing.T, s storage.Storage) { func testDeviceTokenCRUD(t *testing.T, s storage.Storage) {
//Create a Token
d1 := storage.DeviceToken{ d1 := storage.DeviceToken{
DeviceCode: storage.NewID(), DeviceCode: storage.NewID(),
Status: "pending", Status: "pending",
Token: storage.NewID(), Token: storage.NewID(),
Expiry: neverExpire, Expiry: neverExpire,
LastRequestTime: time.Now(),
PollIntervalSeconds: 0,
} }
if err := s.CreateDeviceToken(d1); err != nil { if err := s.CreateDeviceToken(d1); err != nil {
t.Fatalf("failed creating device token: %v", err) t.Fatalf("failed creating device token: %v", err)
} }
// Attempt to create same DeviceRequest twice. // Attempt to create same Device Token twice.
err := s.CreateDeviceToken(d1) err := s.CreateDeviceToken(d1)
mustBeErrAlreadyExists(t, "device token", err) mustBeErrAlreadyExists(t, "device token", err)
//TODO Add update / delete tests as functionality is put into main code //Update the device token, simulate a redemption
if err := s.UpdateDeviceToken(d1.DeviceCode, func(old storage.DeviceToken) (storage.DeviceToken, error) {
old.Token = "token data"
old.Status = "complete"
return old, nil
}); err != nil {
t.Fatalf("failed to update device token: %v", err)
}
//Retrieve the device token
got, err := s.GetDeviceToken(d1.DeviceCode)
if err != nil {
t.Fatalf("failed to get device token: %v", err)
}
//Validate expected result set
if got.Status != "complete" {
t.Fatalf("update failed, wanted token status=%#v got %#v", "complete", got.Status)
}
if got.Token != "token data" {
t.Fatalf("update failed, wanted token =%#v got %#v", "token data", got.Token)
}
} }

View file

@ -570,6 +570,13 @@ func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error {
return c.txnCreate(ctx, keyID(deviceRequestPrefix, d.UserCode), fromStorageDeviceRequest(d)) return c.txnCreate(ctx, keyID(deviceRequestPrefix, d.UserCode), fromStorageDeviceRequest(d))
} }
func (c *conn) GetDeviceRequest(userCode string) (r storage.DeviceRequest, err error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
defer cancel()
err = c.getKey(ctx, keyID(deviceRequestPrefix, userCode), &r)
return r, err
}
func (c *conn) listDeviceRequests(ctx context.Context) (requests []DeviceRequest, err error) { func (c *conn) listDeviceRequests(ctx context.Context) (requests []DeviceRequest, err error) {
res, err := c.db.Get(ctx, deviceRequestPrefix, clientv3.WithPrefix()) res, err := c.db.Get(ctx, deviceRequestPrefix, clientv3.WithPrefix())
if err != nil { if err != nil {
@ -612,3 +619,21 @@ func (c *conn) listDeviceTokens(ctx context.Context) (deviceTokens []DeviceToken
} }
return deviceTokens, nil return deviceTokens, nil
} }
func (c *conn) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
defer cancel()
return c.txnUpdate(ctx, keyID(deviceTokenPrefix, deviceCode), func(currentValue []byte) ([]byte, error) {
var current DeviceToken
if len(currentValue) > 0 {
if err := json.Unmarshal(currentValue, &current); err != nil {
return nil, err
}
}
updated, err := updater(toStorageDeviceToken(current))
if err != nil {
return nil, err
}
return json.Marshal(fromStorageDeviceToken(updated))
})
}

View file

@ -223,7 +223,6 @@ type DeviceRequest struct {
DeviceCode string `json:"device_code"` DeviceCode string `json:"device_code"`
ClientID string `json:"client_id"` ClientID string `json:"client_id"`
Scopes []string `json:"scopes"` Scopes []string `json:"scopes"`
PkceVerifier string `json:"pkce_verifier"`
Expiry time.Time `json:"expiry"` Expiry time.Time `json:"expiry"`
} }
@ -233,7 +232,6 @@ func fromStorageDeviceRequest(d storage.DeviceRequest) DeviceRequest {
DeviceCode: d.DeviceCode, DeviceCode: d.DeviceCode,
ClientID: d.ClientID, ClientID: d.ClientID,
Scopes: d.Scopes, Scopes: d.Scopes,
PkceVerifier: d.PkceVerifier,
Expiry: d.Expiry, Expiry: d.Expiry,
} }
} }
@ -244,6 +242,8 @@ type DeviceToken struct {
Status string `json:"status"` Status string `json:"status"`
Token string `json:"token"` Token string `json:"token"`
Expiry time.Time `json:"expiry"` Expiry time.Time `json:"expiry"`
LastRequestTime time.Time `json:"last_request"`
PollIntervalSeconds int `json:"poll_interval"`
} }
func fromStorageDeviceToken(t storage.DeviceToken) DeviceToken { func fromStorageDeviceToken(t storage.DeviceToken) DeviceToken {
@ -252,5 +252,18 @@ func fromStorageDeviceToken(t storage.DeviceToken) DeviceToken {
Status: t.Status, Status: t.Status,
Token: t.Token, Token: t.Token,
Expiry: t.Expiry, Expiry: t.Expiry,
LastRequestTime: t.LastRequestTime,
PollIntervalSeconds: t.PollIntervalSeconds,
}
}
func toStorageDeviceToken(t DeviceToken) storage.DeviceToken {
return storage.DeviceToken{
DeviceCode: t.DeviceCode,
Status: t.Status,
Token: t.Token,
Expiry: t.Expiry,
LastRequestTime: t.LastRequestTime,
PollIntervalSeconds: t.PollIntervalSeconds,
} }
} }

View file

@ -638,6 +638,14 @@ func (cli *client) CreateDeviceRequest(d storage.DeviceRequest) error {
return cli.post(resourceDeviceRequest, cli.fromStorageDeviceRequest(d)) return cli.post(resourceDeviceRequest, cli.fromStorageDeviceRequest(d))
} }
func (cli *client) GetDeviceRequest(userCode string) (storage.DeviceRequest, error) {
var req DeviceRequest
if err := cli.get(resourceDeviceRequest, strings.ToLower(userCode), &req); err != nil {
return storage.DeviceRequest{}, err
}
return toStorageDeviceRequest(req), nil
}
func (cli *client) CreateDeviceToken(t storage.DeviceToken) error { func (cli *client) CreateDeviceToken(t storage.DeviceToken) error {
return cli.post(resourceDeviceToken, cli.fromStorageDeviceToken(t)) return cli.post(resourceDeviceToken, cli.fromStorageDeviceToken(t))
} }
@ -649,3 +657,24 @@ func (cli *client) GetDeviceToken(deviceCode string) (storage.DeviceToken, error
} }
return toStorageDeviceToken(token), nil return toStorageDeviceToken(token), nil
} }
func (cli *client) getDeviceToken(deviceCode string) (t DeviceToken, err error) {
err = cli.get(resourceDeviceToken, deviceCode, &t)
return
}
func (cli *client) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
r, err := cli.getDeviceToken(deviceCode)
if err != nil {
return err
}
updated, err := updater(toStorageDeviceToken(r))
if err != nil {
return err
}
updated.DeviceCode = deviceCode
newToken := cli.fromStorageDeviceToken(updated)
newToken.ObjectMeta = r.ObjectMeta
return cli.put(resourceDeviceToken, r.ObjectMeta.Name, newToken)
}

View file

@ -675,7 +675,6 @@ type DeviceRequest struct {
DeviceCode string `json:"device_code,omitempty"` DeviceCode string `json:"device_code,omitempty"`
CLientID string `json:"client_id,omitempty"` CLientID string `json:"client_id,omitempty"`
Scopes []string `json:"scopes,omitempty"` Scopes []string `json:"scopes,omitempty"`
PkceVerifier string `json:"pkce_verifier,omitempty"`
Expiry time.Time `json:"expiry"` Expiry time.Time `json:"expiry"`
} }
@ -699,12 +698,21 @@ func (cli *client) fromStorageDeviceRequest(a storage.DeviceRequest) DeviceReque
DeviceCode: a.DeviceCode, DeviceCode: a.DeviceCode,
CLientID: a.ClientID, CLientID: a.ClientID,
Scopes: a.Scopes, Scopes: a.Scopes,
PkceVerifier: a.PkceVerifier,
Expiry: a.Expiry, Expiry: a.Expiry,
} }
return req return req
} }
func toStorageDeviceRequest(req DeviceRequest) storage.DeviceRequest {
return storage.DeviceRequest{
UserCode: strings.ToUpper(req.ObjectMeta.Name),
DeviceCode: req.DeviceCode,
ClientID: req.CLientID,
Scopes: req.Scopes,
Expiry: req.Expiry,
}
}
// DeviceToken is a mirrored struct from storage with JSON struct tags and // DeviceToken is a mirrored struct from storage with JSON struct tags and
// Kubernetes type metadata. // Kubernetes type metadata.
type DeviceToken struct { type DeviceToken struct {
@ -714,6 +722,8 @@ type DeviceToken struct {
Status string `json:"status,omitempty"` Status string `json:"status,omitempty"`
Token string `json:"token,omitempty"` Token string `json:"token,omitempty"`
Expiry time.Time `json:"expiry"` Expiry time.Time `json:"expiry"`
LastRequestTime time.Time `json:"last_request"`
PollIntervalSeconds int `json:"poll_interval"`
} }
// DeviceTokenList is a list of DeviceTokens. // DeviceTokenList is a list of DeviceTokens.
@ -736,6 +746,8 @@ func (cli *client) fromStorageDeviceToken(t storage.DeviceToken) DeviceToken {
Status: t.Status, Status: t.Status,
Token: t.Token, Token: t.Token,
Expiry: t.Expiry, Expiry: t.Expiry,
LastRequestTime: t.LastRequestTime,
PollIntervalSeconds: t.PollIntervalSeconds,
} }
return req return req
} }
@ -746,5 +758,7 @@ func toStorageDeviceToken(t DeviceToken) storage.DeviceToken {
Status: t.Status, Status: t.Status,
Token: t.Token, Token: t.Token,
Expiry: t.Expiry, Expiry: t.Expiry,
LastRequestTime: t.LastRequestTime,
PollIntervalSeconds: t.PollIntervalSeconds,
} }
} }

View file

@ -493,6 +493,17 @@ func (s *memStorage) CreateDeviceRequest(d storage.DeviceRequest) (err error) {
return return
} }
func (s *memStorage) GetDeviceRequest(userCode string) (req storage.DeviceRequest, err error) {
s.tx(func() {
var ok bool
if req, ok = s.deviceRequests[userCode]; !ok {
err = storage.ErrNotFound
return
}
})
return
}
func (s *memStorage) CreateDeviceToken(t storage.DeviceToken) (err error) { func (s *memStorage) CreateDeviceToken(t storage.DeviceToken) (err error) {
s.tx(func() { s.tx(func() {
if _, ok := s.deviceTokens[t.DeviceCode]; ok { if _, ok := s.deviceTokens[t.DeviceCode]; ok {
@ -514,3 +525,17 @@ func (s *memStorage) GetDeviceToken(deviceCode string) (t storage.DeviceToken, e
}) })
return return
} }
func (s *memStorage) UpdateDeviceToken(deviceCode string, updater func(p storage.DeviceToken) (storage.DeviceToken, error)) (err error) {
s.tx(func() {
r, ok := s.deviceTokens[deviceCode]
if !ok {
err = storage.ErrNotFound
return
}
if r, err = updater(r); err == nil {
s.deviceTokens[deviceCode] = r
}
})
return
}

View file

@ -888,12 +888,12 @@ func (c *conn) delete(table, field, id string) error {
func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error { func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error {
_, err := c.Exec(` _, err := c.Exec(`
insert into device_request ( insert into device_request (
user_code, device_code, client_id, scopes, pkce_verifier, expiry user_code, device_code, client_id, scopes, expiry
) )
values ( values (
$1, $2, $3, $4, $5, $6 $1, $2, $3, $4, $5
);`, );`,
d.UserCode, d.DeviceCode, d.ClientID, encoder(d.Scopes), d.PkceVerifier, d.Expiry, d.UserCode, d.DeviceCode, d.ClientID, encoder(d.Scopes), d.Expiry,
) )
if err != nil { if err != nil {
if c.alreadyExistsCheck(err) { if c.alreadyExistsCheck(err) {
@ -907,12 +907,12 @@ func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error {
func (c *conn) CreateDeviceToken(t storage.DeviceToken) error { func (c *conn) CreateDeviceToken(t storage.DeviceToken) error {
_, err := c.Exec(` _, err := c.Exec(`
insert into device_token ( insert into device_token (
device_code, status, token, expiry device_code, status, token, expiry, last_request, poll_interval
) )
values ( values (
$1, $2, $3, $4 $1, $2, $3, $4, $5, $6
);`, );`,
t.DeviceCode, t.Status, t.Token, t.Expiry, t.DeviceCode, t.Status, t.Token, t.Expiry, t.LastRequestTime, t.PollIntervalSeconds,
) )
if err != nil { if err != nil {
if c.alreadyExistsCheck(err) { if c.alreadyExistsCheck(err) {
@ -923,6 +923,28 @@ func (c *conn) CreateDeviceToken(t storage.DeviceToken) error {
return nil return nil
} }
func (c *conn) GetDeviceRequest(userCode string) (storage.DeviceRequest, error) {
return getDeviceRequest(c, userCode)
}
func getDeviceRequest(q querier, userCode string) (d storage.DeviceRequest, err error) {
err = q.QueryRow(`
select
device_code, client_id, scopes, expiry
from device_request where user_code = $1;
`, userCode).Scan(
&d.DeviceCode, &d.ClientID, decoder(&d.Scopes), &d.Expiry,
)
if err != nil {
if err == sql.ErrNoRows {
return d, storage.ErrNotFound
}
return d, fmt.Errorf("select device token: %v", err)
}
d.UserCode = userCode
return d, nil
}
func (c *conn) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) { func (c *conn) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) {
return getDeviceToken(c, deviceCode) return getDeviceToken(c, deviceCode)
} }
@ -930,10 +952,10 @@ func (c *conn) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) {
func getDeviceToken(q querier, deviceCode string) (a storage.DeviceToken, err error) { func getDeviceToken(q querier, deviceCode string) (a storage.DeviceToken, err error) {
err = q.QueryRow(` err = q.QueryRow(`
select select
status, token, expiry status, token, expiry, last_request, poll_interval
from device_token where device_code = $1; from device_token where device_code = $1;
`, deviceCode).Scan( `, deviceCode).Scan(
&a.Status, &a.Token, &a.Expiry, &a.Status, &a.Token, &a.Expiry, &a.LastRequestTime, &a.PollIntervalSeconds,
) )
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -944,3 +966,31 @@ func getDeviceToken(q querier, deviceCode string) (a storage.DeviceToken, err er
a.DeviceCode = deviceCode a.DeviceCode = deviceCode
return a, nil return a, nil
} }
func (c *conn) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
return c.ExecTx(func(tx *trans) error {
r, err := getDeviceToken(tx, deviceCode)
if err != nil {
return err
}
if r, err = updater(r); err != nil {
return err
}
_, err = tx.Exec(`
update device_token
set
status = $1,
token = $2,
last_request = $3,
poll_interval = $4
where
device_code = $5
`,
r.Status, r.Token, r.LastRequestTime, r.PollIntervalSeconds, r.DeviceCode,
)
if err != nil {
return fmt.Errorf("update device token: %v", err)
}
return nil
})
}

View file

@ -236,7 +236,6 @@ var migrations = []migration{
device_code text not null, device_code text not null,
client_id text not null, client_id text not null,
scopes bytea not null, -- JSON array of strings scopes bytea not null, -- JSON array of strings
pkce_verifier text not null,
expiry timestamptz not null expiry timestamptz not null
);`, );`,
` `
@ -244,7 +243,9 @@ var migrations = []migration{
device_code text not null primary key, device_code text not null primary key,
status text not null, status text not null,
token text, token text,
expiry timestamptz not null expiry timestamptz not null,
last_request timestamptz not null,
poll_interval integer not null
);`, );`,
}, },
}, },

View file

@ -82,6 +82,7 @@ type Storage interface {
GetPassword(email string) (Password, error) GetPassword(email string) (Password, error)
GetOfflineSessions(userID string, connID string) (OfflineSessions, error) GetOfflineSessions(userID string, connID string) (OfflineSessions, error)
GetConnector(id string) (Connector, error) GetConnector(id string) (Connector, error)
GetDeviceRequest(userCode string) (DeviceRequest, error)
GetDeviceToken(deviceCode string) (DeviceToken, error) GetDeviceToken(deviceCode string) (DeviceToken, error)
ListClients() ([]Client, error) ListClients() ([]Client, error)
@ -119,6 +120,7 @@ type Storage interface {
UpdatePassword(email string, updater func(p Password) (Password, error)) error UpdatePassword(email string, updater func(p Password) (Password, error)) error
UpdateOfflineSessions(userID string, connID string, updater func(s OfflineSessions) (OfflineSessions, error)) error UpdateOfflineSessions(userID string, connID string, updater func(s OfflineSessions) (OfflineSessions, error)) error
UpdateConnector(id string, updater func(c Connector) (Connector, error)) error UpdateConnector(id string, updater func(c Connector) (Connector, error)) error
UpdateDeviceToken(deviceCode string, updater func(t DeviceToken) (DeviceToken, error)) error
// GarbageCollect deletes all expired AuthCodes,AuthRequests, DeviceRequests, and DeviceTokens. // GarbageCollect deletes all expired AuthCodes,AuthRequests, DeviceRequests, and DeviceTokens.
GarbageCollect(now time.Time) (GCResult, error) GarbageCollect(now time.Time) (GCResult, error)
@ -392,8 +394,6 @@ type DeviceRequest struct {
ClientID string ClientID string
//The scopes the device requests //The scopes the device requests
Scopes []string Scopes []string
//PKCE Verification
PkceVerifier string
//The expire time //The expire time
Expiry time.Time Expiry time.Time
} }
@ -403,4 +403,6 @@ type DeviceToken struct {
Status string Status string
Token string Token string
Expiry time.Time Expiry time.Time
LastRequestTime time.Time
PollIntervalSeconds int
} }

23
web/templates/device.html Normal file
View file

@ -0,0 +1,23 @@
{{ template "header.html" . }}
<div class="theme-panel">
<h2 class="theme-heading">Enter User Code</h2>
<form method="post" action="{{ .PostURL }}" method="get">
<div class="theme-form-row">
{{ if( .UserCode )}}
<input tabindex="2" required id="user_code" name="user_code" type="text" class="theme-form-input" autocomplete="off" value="{{.UserCode}}" {{ if .Invalid }} autofocus {{ end }}/>
{{ else }}
<input tabindex="2" required id="user_code" name="user_code" type="text" class="theme-form-input" placeholder="XXXX-XXXX" autocomplete="off" {{ if .Invalid }} autofocus {{ end }}/>
{{ end }}
</div>
{{ if .Invalid }}
<div id="login-error" class="dex-error-box">
Invalid or Expired User Code
</div>
{{ end }}
<button tabindex="3" id="submit-login" type="submit" class="dex-btn theme-btn--primary">Submit</button>
</form>
</div>
{{ template "footer.html" . }}

View file

@ -0,0 +1,8 @@
{{ template "header.html" . }}
<div class="theme-panel">
<h2 class="theme-heading">Login Successful for {{ .ClientName }}</h2>
<p>Return to your device to continue</p>
</div>
{{ template "footer.html" . }}