diff --git a/cmd/dex/config.go b/cmd/dex/config.go index 2d71a936..cc8a7273 100644 --- a/cmd/dex/config.go +++ b/cmd/dex/config.go @@ -283,6 +283,9 @@ type Expiry struct { // AuthRequests defines the duration of time for which the AuthRequests will be valid. AuthRequests string `json:"authRequests"` + + // DeviceRequests defines the duration of time for which the DeviceRequests will be valid. + DeviceRequests string `json:"deviceRequests"` } // Logger holds configuration required to customize logging for dex. diff --git a/cmd/dex/config_test.go b/cmd/dex/config_test.go index 12c7b218..bced93b5 100644 --- a/cmd/dex/config_test.go +++ b/cmd/dex/config_test.go @@ -119,6 +119,7 @@ expiry: signingKeys: "7h" idTokens: "25h" authRequests: "25h" + deviceRequests: "10m" logger: level: "debug" @@ -197,9 +198,10 @@ logger: }, }, Expiry: Expiry{ - SigningKeys: "7h", - IDTokens: "25h", - AuthRequests: "25h", + SigningKeys: "7h", + IDTokens: "25h", + AuthRequests: "25h", + DeviceRequests: "10m", }, Logger: Logger{ Level: "debug", diff --git a/cmd/dex/serve.go b/cmd/dex/serve.go index d0e8f9ac..ca740593 100644 --- a/cmd/dex/serve.go +++ b/cmd/dex/serve.go @@ -269,7 +269,14 @@ func serve(cmd *cobra.Command, args []string) error { logger.Infof("config auth requests valid for: %v", authRequests) serverConfig.AuthRequestsValidFor = authRequests } - + if c.Expiry.DeviceRequests != "" { + deviceRequests, err := time.ParseDuration(c.Expiry.DeviceRequests) + if err != nil { + return fmt.Errorf("invalid config value %q for device request expiry: %v", c.Expiry.AuthRequests, err) + } + logger.Infof("config device requests valid for: %v", deviceRequests) + serverConfig.DeviceRequestsValidFor = deviceRequests + } serv, err := server.NewServer(context.Background(), serverConfig) if err != nil { return fmt.Errorf("failed to initialize server: %v", err) diff --git a/examples/config-dev.yaml b/examples/config-dev.yaml index f7b011ba..9cd81fa6 100644 --- a/examples/config-dev.yaml +++ b/examples/config-dev.yaml @@ -64,6 +64,7 @@ telemetry: # Uncomment this block to enable configuration for the expiration time durations. # expiry: +# deviceRequests: "5m" # signingKeys: "6h" # idTokens: "24h" diff --git a/server/deviceHandlers.go b/server/deviceHandlers.go new file mode 100644 index 00000000..55255408 --- /dev/null +++ b/server/deviceHandlers.go @@ -0,0 +1,359 @@ +package server + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "path" + "strconv" + "strings" + "time" + + "github.com/dexidp/dex/storage" +) + +type deviceCodeResponse struct { + //The unique device code for device authentication + DeviceCode string `json:"device_code"` + //The code the user will exchange via a browser and log in + UserCode string `json:"user_code"` + //The url to verify the user code. + VerificationURI string `json:"verification_uri"` + //The verification uri with the user code appended for pre-filling form + VerificationURIComplete string `json:"verification_uri_complete"` + //The lifetime of the device code + ExpireTime int `json:"expires_in"` + //How often the device is allowed to poll to verify that the user login occurred + PollInterval int `json:"interval"` +} + +func (s *Server) getDeviceAuthURI() string { + return path.Join(s.issuerURL.Path, "/device/auth/verify_code") +} + +func (s *Server) handleDeviceExchange(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + userCode := r.URL.Query().Get("user_code") + invalidAttempt, err := strconv.ParseBool(r.URL.Query().Get("invalid")) + if err != nil { + invalidAttempt = false + } + if err := s.templates.device(r, w, s.getDeviceAuthURI(), userCode, invalidAttempt); err != nil { + s.logger.Errorf("Server template error: %v", err) + } + default: + s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.") + } +} + +func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) { + pollIntervalSeconds := 5 + + switch r.Method { + case http.MethodPost: + err := r.ParseForm() + if err != nil { + s.logger.Errorf("Could not parse Device Request body: %v", err) + s.tokenErrHelper(w, errInvalidRequest, "", http.StatusNotFound) + return + } + + //Get the client id and scopes from the post + clientID := r.Form.Get("client_id") + scopes := r.Form["scope"] + + s.logger.Infof("Received device request for client %v with scopes %v", clientID, scopes) + + //Make device code + deviceCode := storage.NewDeviceCode() + + //make user code + userCode, err := storage.NewUserCode() + if err != nil { + s.logger.Errorf("Error generating user code: %v", err) + s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError) + } + + //Generate the expire time + expireTime := time.Now().Add(s.deviceRequestsValidFor) + + //Store the Device Request + deviceReq := storage.DeviceRequest{ + UserCode: userCode, + DeviceCode: deviceCode, + ClientID: clientID, + Scopes: scopes, + Expiry: expireTime, + } + + if err := s.storage.CreateDeviceRequest(deviceReq); err != nil { + s.logger.Errorf("Failed to store device request; %v", err) + s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError) + return + } + + //Store the device token + deviceToken := storage.DeviceToken{ + DeviceCode: deviceCode, + Status: deviceTokenPending, + Expiry: expireTime, + LastRequestTime: time.Now(), + PollIntervalSeconds: 5, + } + + if err := s.storage.CreateDeviceToken(deviceToken); err != nil { + s.logger.Errorf("Failed to store device token %v", err) + s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError) + return + } + + u, err := url.Parse(s.issuerURL.String()) + if err != nil { + s.logger.Errorf("Could not parse issuer URL %v", err) + s.renderError(r, w, http.StatusInternalServerError, "") + return + } + u.Path = path.Join(u.Path, "device") + vURI := u.String() + + q := u.Query() + q.Set("user_code", userCode) + u.RawQuery = q.Encode() + vURIComplete := u.String() + + code := deviceCodeResponse{ + DeviceCode: deviceCode, + UserCode: userCode, + VerificationURI: vURI, + VerificationURIComplete: vURIComplete, + ExpireTime: int(s.deviceRequestsValidFor.Seconds()), + PollInterval: pollIntervalSeconds, + } + + enc := json.NewEncoder(w) + enc.SetIndent("", " ") + enc.Encode(code) + + default: + s.renderError(r, w, http.StatusBadRequest, "Invalid device code request type") + s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest) + } +} + +func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch r.Method { + case http.MethodPost: + err := r.ParseForm() + if err != nil { + s.logger.Warnf("Could not parse Device Token Request body: %v", err) + s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest) + return + } + + deviceCode := r.Form.Get("device_code") + if deviceCode == "" { + s.tokenErrHelper(w, errInvalidRequest, "No device code received", http.StatusBadRequest) + return + } + + grantType := r.PostFormValue("grant_type") + if grantType != grantTypeDeviceCode { + s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest) + return + } + + now := s.now() + + //Grab the device token + deviceToken, err := s.storage.GetDeviceToken(deviceCode) + if err != nil || now.After(deviceToken.Expiry) { + if err != storage.ErrNotFound { + s.logger.Errorf("failed to get device code: %v", err) + } + s.tokenErrHelper(w, errInvalidRequest, "Invalid or expired device code.", http.StatusBadRequest) + return + } + + //Rate Limiting check + pollInterval := deviceToken.PollIntervalSeconds + minRequestTime := deviceToken.LastRequestTime.Add(time.Second * time.Duration(pollInterval)) + if now.Before(minRequestTime) { + s.tokenErrHelper(w, deviceTokenSlowDown, "", http.StatusBadRequest) + //Continually increase the poll interval until the user waits the proper time + pollInterval += 5 + } else { + pollInterval = 5 + } + + switch deviceToken.Status { + case deviceTokenPending: + updater := func(old storage.DeviceToken) (storage.DeviceToken, error) { + old.PollIntervalSeconds = pollInterval + old.LastRequestTime = now + return old, nil + } + // Update device token last request time in storage + if err := s.storage.UpdateDeviceToken(deviceCode, updater); err != nil { + s.logger.Errorf("failed to update device token: %v", err) + s.renderError(r, w, http.StatusInternalServerError, "") + return + } + s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized) + case deviceTokenComplete: + w.Write([]byte(deviceToken.Token)) + } + default: + s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.") + } +} + +func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + userCode := r.FormValue("state") + code := r.FormValue("code") + + if userCode == "" || code == "" { + s.renderError(r, w, http.StatusBadRequest, "Request was missing parameters") + return + } + + // Authorization redirect callback from OAuth2 auth flow. + if errMsg := r.FormValue("error"); errMsg != "" { + http.Error(w, errMsg+": "+r.FormValue("error_description"), http.StatusBadRequest) + return + } + + authCode, err := s.storage.GetAuthCode(code) + if err != nil || s.now().After(authCode.Expiry) { + if err != storage.ErrNotFound { + s.logger.Errorf("failed to get auth code: %v", err) + } + s.renderError(r, w, http.StatusBadRequest, "Invalid or expired auth code.") + return + } + + //Grab the device request from storage + deviceReq, err := s.storage.GetDeviceRequest(userCode) + if err != nil || s.now().After(deviceReq.Expiry) { + if err != storage.ErrNotFound { + s.logger.Errorf("failed to get device code: %v", err) + } + s.renderError(r, w, http.StatusInternalServerError, "Invalid or expired device code.") + return + } + + reqClient, err := s.storage.GetClient(deviceReq.ClientID) + if err != nil { + s.logger.Errorf("Failed to get reqClient %q: %v", deviceReq.ClientID, err) + s.renderError(r, w, http.StatusInternalServerError, "Failed to retrieve device client.") + return + } + + resp, err := s.exchangeAuthCode(w, authCode, reqClient) + if err != nil { + s.logger.Errorf("Could not exchange auth code for client %q: %v", deviceReq.ClientID, err) + s.renderError(r, w, http.StatusInternalServerError, "Failed to exchange auth code.") + return + } + + //Grab the device request from storage + old, err := s.storage.GetDeviceToken(deviceReq.DeviceCode) + if err != nil || s.now().After(old.Expiry) { + if err != storage.ErrNotFound { + s.logger.Errorf("failed to get device token: %v", err) + } + s.renderError(r, w, http.StatusInternalServerError, "Invalid or expired device code.") + return + } + + updater := func(old storage.DeviceToken) (storage.DeviceToken, error) { + if old.Status == deviceTokenComplete { + return old, errors.New("device token already complete") + } + respStr, err := json.MarshalIndent(resp, "", " ") + if err != nil { + s.logger.Errorf("failed to marshal device token response: %v", err) + s.renderError(r, w, http.StatusInternalServerError, "") + return old, err + } + + old.Token = string(respStr) + old.Status = deviceTokenComplete + return old, nil + } + + // Update refresh token in the storage, store the token and mark as complete + if err := s.storage.UpdateDeviceToken(deviceReq.DeviceCode, updater); err != nil { + s.logger.Errorf("failed to update device token: %v", err) + s.renderError(r, w, http.StatusInternalServerError, "") + return + } + + if err := s.templates.deviceSuccess(r, w, reqClient.Name); err != nil { + s.logger.Errorf("Server template error: %v", err) + } + + default: + http.Error(w, fmt.Sprintf("method not implemented: %s", r.Method), http.StatusBadRequest) + return + } +} + +func (s *Server) verifyUserCode(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodPost: + err := r.ParseForm() + if err != nil { + message := "Could not parse user code verification Request body" + s.logger.Warnf("%s : %v", message, err) + s.tokenErrHelper(w, errInvalidRequest, message, http.StatusBadRequest) + return + } + + userCode := r.Form.Get("user_code") + if userCode == "" { + s.renderError(r, w, http.StatusBadRequest, "No user code received") + return + } + + userCode = strings.ToUpper(userCode) + + //Find the user code in the available requests + deviceRequest, err := s.storage.GetDeviceRequest(userCode) + if err != nil || s.now().After(deviceRequest.Expiry) { + if err != storage.ErrNotFound { + s.logger.Errorf("failed to get device request: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + } + if err := s.templates.device(r, w, s.getDeviceAuthURI(), userCode, true); err != nil { + s.logger.Errorf("Server template error: %v", err) + } + return + } + + //Redirect to Dex Auth Endpoint + authURL := path.Join(s.issuerURL.Path, "/auth") + u, err := url.Parse(authURL) + if err != nil { + s.renderError(r, w, http.StatusInternalServerError, "Invalid auth URI.") + return + } + q := u.Query() + q.Set("client_id", deviceRequest.ClientID) + q.Set("state", deviceRequest.UserCode) + q.Set("response_type", "code") + q.Set("redirect_uri", path.Join(s.issuerURL.Path, "/device/callback")) + q.Set("scope", strings.Join(deviceRequest.Scopes, " ")) + u.RawQuery = q.Encode() + + http.Redirect(w, r, u.String(), http.StatusFound) + + default: + s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.") + } +} diff --git a/server/handlers.go b/server/handlers.go index b059b98f..32a81b98 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -147,30 +147,34 @@ func (s *Server) handlePublicKeys(w http.ResponseWriter, r *http.Request) { } type discovery struct { - Issuer string `json:"issuer"` - Auth string `json:"authorization_endpoint"` - Token string `json:"token_endpoint"` - Keys string `json:"jwks_uri"` - UserInfo string `json:"userinfo_endpoint"` - ResponseTypes []string `json:"response_types_supported"` - Subjects []string `json:"subject_types_supported"` - IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"` - Scopes []string `json:"scopes_supported"` - AuthMethods []string `json:"token_endpoint_auth_methods_supported"` - Claims []string `json:"claims_supported"` + Issuer string `json:"issuer"` + Auth string `json:"authorization_endpoint"` + Token string `json:"token_endpoint"` + Keys string `json:"jwks_uri"` + UserInfo string `json:"userinfo_endpoint"` + DeviceEndpoint string `json:"device_authorization_endpoint"` + GrantTypes []string `json:"grant_types_supported"'` + ResponseTypes []string `json:"response_types_supported"` + Subjects []string `json:"subject_types_supported"` + IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"` + Scopes []string `json:"scopes_supported"` + AuthMethods []string `json:"token_endpoint_auth_methods_supported"` + Claims []string `json:"claims_supported"` } func (s *Server) discoveryHandler() (http.HandlerFunc, error) { d := discovery{ - Issuer: s.issuerURL.String(), - Auth: s.absURL("/auth"), - Token: s.absURL("/token"), - Keys: s.absURL("/keys"), - UserInfo: s.absURL("/userinfo"), - Subjects: []string{"public"}, - IDTokenAlgs: []string{string(jose.RS256)}, - Scopes: []string{"openid", "email", "groups", "profile", "offline_access"}, - AuthMethods: []string{"client_secret_basic"}, + Issuer: s.issuerURL.String(), + Auth: s.absURL("/auth"), + Token: s.absURL("/token"), + Keys: s.absURL("/keys"), + UserInfo: s.absURL("/userinfo"), + DeviceEndpoint: s.absURL("/device/code"), + Subjects: []string{"public"}, + GrantTypes: []string{grantTypeAuthorizationCode, grantTypeRefreshToken, grantTypeDeviceCode}, + IDTokenAlgs: []string{string(jose.RS256)}, + Scopes: []string{"openid", "email", "groups", "profile", "offline_access"}, + AuthMethods: []string{"client_secret_basic"}, Claims: []string{ "aud", "email", "email_verified", "exp", "iat", "iss", "locale", "name", "sub", @@ -783,24 +787,33 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s return } + tokenResponse, err := s.exchangeAuthCode(w, authCode, client) + if err != nil { + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + return + } + s.writeAccessToken(w, tokenResponse) +} + +func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCode, client storage.Client) (*accessTokenReponse, error) { accessToken, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID) if err != nil { s.logger.Errorf("failed to create new access token: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - return + return nil, err } idToken, expiry, err := s.newIDToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, accessToken, authCode.ConnectorID) if err != nil { s.logger.Errorf("failed to create ID token: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - return + return nil, err } - if err := s.storage.DeleteAuthCode(code); err != nil { + if err := s.storage.DeleteAuthCode(authCode.ID); err != nil { s.logger.Errorf("failed to delete auth code: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - return + return nil, err } reqRefresh := func() bool { @@ -847,13 +860,13 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s if refreshToken, err = internal.Marshal(token); err != nil { s.logger.Errorf("failed to marshal refresh token: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - return + return nil, err } if err := s.storage.CreateRefresh(refresh); err != nil { s.logger.Errorf("failed to create refresh token: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - return + return nil, err } // deleteToken determines if we need to delete the newly created refresh token @@ -884,7 +897,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s s.logger.Errorf("failed to get offline session: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) deleteToken = true - return + return nil, err } offlineSessions := storage.OfflineSessions{ UserID: refresh.Claims.UserID, @@ -899,7 +912,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s s.logger.Errorf("failed to create offline session: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) deleteToken = true - return + return nil, err } } else { if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok { @@ -908,7 +921,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s s.logger.Errorf("failed to delete refresh token: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) deleteToken = true - return + return nil, err } } @@ -920,11 +933,11 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s s.logger.Errorf("failed to update offline session: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) deleteToken = true - return + return nil, err } } } - s.writeAccessToken(w, idToken, accessToken, refreshToken, expiry) + return s.toAccessTokenResponse(idToken, accessToken, refreshToken, expiry), nil } // handle a refresh token request https://tools.ietf.org/html/rfc6749#section-6 @@ -1120,7 +1133,8 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie return } - s.writeAccessToken(w, idToken, accessToken, rawNewToken, expiry) + resp := s.toAccessTokenResponse(idToken, accessToken, rawNewToken, expiry) + s.writeAccessToken(w, resp) } func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) { @@ -1378,12 +1392,26 @@ func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, accessToken, r RefreshToken string `json:"refresh_token,omitempty"` IDToken string `json:"id_token"` }{ + +type accessTokenReponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + RefreshToken string `json:"refresh_token,omitempty"` + IDToken string `json:"id_token"` +} + +func (s *Server) toAccessTokenResponse(idToken, accessToken, refreshToken string, expiry time.Time) *accessTokenReponse { + return &accessTokenReponse{ accessToken, "bearer", int(expiry.Sub(s.now()).Seconds()), refreshToken, idToken, } +} + +func (s *Server) writeAccessToken(w http.ResponseWriter, resp *accessTokenReponse) { data, err := json.Marshal(resp) if err != nil { s.logger.Errorf("failed to marshal access token response: %v", err) @@ -1414,145 +1442,3 @@ func usernamePrompt(conn connector.PasswordConnector) string { } return "Username" } - -type deviceCodeResponse struct { - //The unique device code for device authentication - DeviceCode string `json:"device_code"` - //The code the user will exchange via a browser and log in - UserCode string `json:"user_code"` - //The url to verify the user code. - VerificationURI string `json:"verification_uri"` - //The lifetime of the device code - ExpireTime int `json:"expires_in"` - //How often the device is allowed to poll to verify that the user login occurred - PollInterval int `json:"interval"` -} - -func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) { - //TODO replace with configurable values - expireIntervalSeconds := 300 - requestsPerMinute := 5 - - switch r.Method { - case http.MethodPost: - err := r.ParseForm() - if err != nil { - s.logger.Errorf("Could not parse Device Request body: %v", err) - s.tokenErrHelper(w, errInvalidRequest, "", http.StatusNotFound) - return - } - - //Get the client id and scopes from the post - clientID := r.Form.Get("client_id") - scopes := r.Form["scope"] - - s.logger.Infof("Received device request for client %v with scopes %v", clientID, scopes) - - //Make device code - deviceCode := storage.NewDeviceCode() - - //make user code - userCode, err := storage.NewUserCode() - if err != nil { - s.logger.Errorf("Error generating user code: %v", err) - s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError) - } - - //make a pkce verification code - pkceCode := storage.NewID() - - //Generate the expire time - expireTime := time.Now().Add(time.Second * time.Duration(expireIntervalSeconds)) - - //Store the Device Request - deviceReq := storage.DeviceRequest{ - UserCode: userCode, - DeviceCode: deviceCode, - ClientID: clientID, - Scopes: scopes, - PkceVerifier: pkceCode, - Expiry: expireTime, - } - - if err := s.storage.CreateDeviceRequest(deviceReq); err != nil { - s.logger.Errorf("Failed to store device request; %v", err) - s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError) - return - } - - //Store the device token - deviceToken := storage.DeviceToken{ - DeviceCode: deviceCode, - Status: deviceTokenPending, - Expiry: expireTime, - } - - if err := s.storage.CreateDeviceToken(deviceToken); err != nil { - s.logger.Errorf("Failed to store device token %v", err) - s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError) - return - } - - code := deviceCodeResponse{ - DeviceCode: deviceCode, - UserCode: userCode, - VerificationURI: path.Join(s.issuerURL.String(), "/device"), - ExpireTime: expireIntervalSeconds, - PollInterval: requestsPerMinute, - } - - enc := json.NewEncoder(w) - enc.SetIndent("", " ") - enc.Encode(code) - - default: - s.renderError(r, w, http.StatusBadRequest, "Invalid device code request type") - s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest) - } -} - -func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - switch r.Method { - case http.MethodPost: - err := r.ParseForm() - if err != nil { - message := "Could not parse Device Token Request body" - s.logger.Warnf("%s : %v", message, err) - s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest) - return - } - - deviceCode := r.Form.Get("device_code") - if deviceCode == "" { - message := "No device code received" - s.tokenErrHelper(w, errInvalidRequest, message, http.StatusBadRequest) - return - } - - grantType := r.PostFormValue("grant_type") - if grantType != grantTypeDeviceCode { - s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest) - return - } - - //Grab the device token from the db - deviceToken, err := s.storage.GetDeviceToken(deviceCode) - if err != nil || s.now().After(deviceToken.Expiry) { - if err != storage.ErrNotFound { - s.logger.Errorf("failed to get device code: %v", err) - } - s.tokenErrHelper(w, errInvalidRequest, "Invalid or expired device code.", http.StatusBadRequest) - return - } - - switch deviceToken.Status { - case deviceTokenPending: - s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized) - case deviceTokenComplete: - w.Write([]byte(deviceToken.Token)) - } - default: - s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.") - } -} diff --git a/server/oauth2.go b/server/oauth2.go index ddeffc3f..59e132bc 100644 --- a/server/oauth2.go +++ b/server/oauth2.go @@ -122,7 +122,7 @@ const ( grantTypeAuthorizationCode = "authorization_code" grantTypeRefreshToken = "refresh_token" grantTypePassword = "password" - grantTypeDeviceCode = "device_code" + grantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code" ) const ( @@ -134,6 +134,8 @@ const ( const ( deviceTokenPending = "authorization_pending" deviceTokenComplete = "complete" + deviceTokenSlowDown = "slow_down" + deviceTokenExpired = "expired_token" ) func parseScopes(scopes []string) connector.Scopes { diff --git a/server/server.go b/server/server.go index b86dac04..90e96327 100644 --- a/server/server.go +++ b/server/server.go @@ -75,12 +75,13 @@ type Config struct { // If enabled, the connectors selection page will always be shown even if there's only one AlwaysShowLoginScreen bool - RotateKeysAfter time.Duration // Defaults to 6 hours. - IDTokensValidFor time.Duration // Defaults to 24 hours - AuthRequestsValidFor time.Duration // Defaults to 24 hours + RotateKeysAfter time.Duration // Defaults to 6 hours. + IDTokensValidFor time.Duration // Defaults to 24 hours + AuthRequestsValidFor time.Duration // Defaults to 24 hours + DeviceRequestsValidFor time.Duration // Defaults to 5 minutes // If set, the server will use this connector to handle password grants PasswordConnector string - + GCFrequency time.Duration // Defaults to 5 minutes // If specified, the server will use this function for determining time. @@ -156,8 +157,9 @@ type Server struct { now func() time.Time - idTokensValidFor time.Duration - authRequestsValidFor time.Duration + idTokensValidFor time.Duration + authRequestsValidFor time.Duration + deviceRequestsValidFor time.Duration logger log.Logger } @@ -219,6 +221,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) supportedResponseTypes: supported, idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour), authRequestsValidFor: value(c.AuthRequestsValidFor, 24*time.Hour), + deviceRequestsValidFor: value(c.DeviceRequestsValidFor, 5*time.Minute), skipApproval: c.SkipApprovalScreen, alwaysShowLogin: c.AlwaysShowLoginScreen, now: now, @@ -302,8 +305,11 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) handleWithCORS("/userinfo", s.handleUserInfo) handleFunc("/auth", s.handleAuthorization) handleFunc("/auth/{connector}", s.handleConnectorLogin) + handleFunc("/device", s.handleDeviceExchange) + handleFunc("/device/auth/verify_code", s.verifyUserCode) handleFunc("/device/code", s.handleDeviceCode) handleFunc("/device/token", s.handleDeviceToken) + handleFunc("/device/callback", s.handleDeviceCallback) r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) { // Strip the X-Remote-* headers to prevent security issues on // misconfigured authproxy connector setups. diff --git a/server/templates.go b/server/templates.go index 4947a102..934362b7 100644 --- a/server/templates.go +++ b/server/templates.go @@ -15,11 +15,13 @@ import ( ) const ( - tmplApproval = "approval.html" - tmplLogin = "login.html" - tmplPassword = "password.html" - tmplOOB = "oob.html" - tmplError = "error.html" + tmplApproval = "approval.html" + tmplLogin = "login.html" + tmplPassword = "password.html" + tmplOOB = "oob.html" + tmplError = "error.html" + tmplDevice = "device.html" + tmplDeviceSuccess = "device_success.html" ) var requiredTmpls = []string{ @@ -28,14 +30,17 @@ var requiredTmpls = []string{ tmplPassword, tmplOOB, tmplError, + tmplDevice, } type templates struct { - loginTmpl *template.Template - approvalTmpl *template.Template - passwordTmpl *template.Template - oobTmpl *template.Template - errorTmpl *template.Template + loginTmpl *template.Template + approvalTmpl *template.Template + passwordTmpl *template.Template + oobTmpl *template.Template + errorTmpl *template.Template + deviceTmpl *template.Template + deviceSuccessTmpl *template.Template } type webConfig struct { @@ -152,11 +157,13 @@ func loadTemplates(c webConfig, templatesDir string) (*templates, error) { return nil, fmt.Errorf("missing template(s): %s", missingTmpls) } return &templates{ - loginTmpl: tmpls.Lookup(tmplLogin), - approvalTmpl: tmpls.Lookup(tmplApproval), - passwordTmpl: tmpls.Lookup(tmplPassword), - oobTmpl: tmpls.Lookup(tmplOOB), - errorTmpl: tmpls.Lookup(tmplError), + loginTmpl: tmpls.Lookup(tmplLogin), + approvalTmpl: tmpls.Lookup(tmplApproval), + passwordTmpl: tmpls.Lookup(tmplPassword), + oobTmpl: tmpls.Lookup(tmplOOB), + errorTmpl: tmpls.Lookup(tmplError), + deviceTmpl: tmpls.Lookup(tmplDevice), + deviceSuccessTmpl: tmpls.Lookup(tmplDeviceSuccess), }, nil } @@ -242,6 +249,24 @@ func (n byName) Len() int { return len(n) } func (n byName) Less(i, j int) bool { return n[i].Name < n[j].Name } func (n byName) Swap(i, j int) { n[i], n[j] = n[j], n[i] } +func (t *templates) device(r *http.Request, w http.ResponseWriter, postURL string, userCode string, lastWasInvalid bool) error { + data := struct { + PostURL string + UserCode string + Invalid bool + ReqPath string + }{postURL, userCode, lastWasInvalid, r.URL.Path} + return renderTemplate(w, t.deviceTmpl, data) +} + +func (t *templates) deviceSuccess(r *http.Request, w http.ResponseWriter, clientName string) error { + data := struct { + ClientName string + ReqPath string + }{clientName, r.URL.Path} + return renderTemplate(w, t.deviceSuccessTmpl, data) +} + func (t *templates) login(r *http.Request, w http.ResponseWriter, connectors []connectorInfo, reqPath string) error { sort.Sort(byName(connectors)) data := struct { diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index 944d8a78..6edc8350 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -843,12 +843,11 @@ func testGC(t *testing.T, s storage.Storage) { } d := storage.DeviceRequest{ - UserCode: userCode, - DeviceCode: storage.NewID(), - ClientID: "client1", - Scopes: []string{"openid", "email"}, - PkceVerifier: storage.NewID(), - Expiry: expiry, + UserCode: userCode, + DeviceCode: storage.NewID(), + ClientID: "client1", + Scopes: []string{"openid", "email"}, + Expiry: expiry, } if err := s.CreateDeviceRequest(d); err != nil { @@ -970,12 +969,11 @@ func testDeviceRequestCRUD(t *testing.T, s storage.Storage) { panic(err) } d1 := storage.DeviceRequest{ - UserCode: userCode, - DeviceCode: storage.NewID(), - ClientID: "client1", - Scopes: []string{"openid", "email"}, - PkceVerifier: storage.NewID(), - Expiry: neverExpire, + UserCode: userCode, + DeviceCode: storage.NewID(), + ClientID: "client1", + Scopes: []string{"openid", "email"}, + Expiry: neverExpire, } if err := s.CreateDeviceRequest(d1); err != nil { @@ -991,20 +989,44 @@ func testDeviceRequestCRUD(t *testing.T, s storage.Storage) { } func testDeviceTokenCRUD(t *testing.T, s storage.Storage) { + //Create a Token d1 := storage.DeviceToken{ - DeviceCode: storage.NewID(), - Status: "pending", - Token: storage.NewID(), - Expiry: neverExpire, + DeviceCode: storage.NewID(), + Status: "pending", + Token: storage.NewID(), + Expiry: neverExpire, + LastRequestTime: time.Now(), + PollIntervalSeconds: 0, } if err := s.CreateDeviceToken(d1); err != nil { t.Fatalf("failed creating device token: %v", err) } - // Attempt to create same DeviceRequest twice. + // Attempt to create same Device Token twice. err := s.CreateDeviceToken(d1) mustBeErrAlreadyExists(t, "device token", err) - //TODO Add update / delete tests as functionality is put into main code + //Update the device token, simulate a redemption + if err := s.UpdateDeviceToken(d1.DeviceCode, func(old storage.DeviceToken) (storage.DeviceToken, error) { + old.Token = "token data" + old.Status = "complete" + return old, nil + }); err != nil { + t.Fatalf("failed to update device token: %v", err) + } + + //Retrieve the device token + got, err := s.GetDeviceToken(d1.DeviceCode) + if err != nil { + t.Fatalf("failed to get device token: %v", err) + } + + //Validate expected result set + if got.Status != "complete" { + t.Fatalf("update failed, wanted token status=%#v got %#v", "complete", got.Status) + } + if got.Token != "token data" { + t.Fatalf("update failed, wanted token =%#v got %#v", "token data", got.Token) + } } diff --git a/storage/etcd/etcd.go b/storage/etcd/etcd.go index bbb86651..f41831cd 100644 --- a/storage/etcd/etcd.go +++ b/storage/etcd/etcd.go @@ -570,6 +570,13 @@ func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error { return c.txnCreate(ctx, keyID(deviceRequestPrefix, d.UserCode), fromStorageDeviceRequest(d)) } +func (c *conn) GetDeviceRequest(userCode string) (r storage.DeviceRequest, err error) { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + err = c.getKey(ctx, keyID(deviceRequestPrefix, userCode), &r) + return r, err +} + func (c *conn) listDeviceRequests(ctx context.Context) (requests []DeviceRequest, err error) { res, err := c.db.Get(ctx, deviceRequestPrefix, clientv3.WithPrefix()) if err != nil { @@ -612,3 +619,21 @@ func (c *conn) listDeviceTokens(ctx context.Context) (deviceTokens []DeviceToken } return deviceTokens, nil } + +func (c *conn) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + return c.txnUpdate(ctx, keyID(deviceTokenPrefix, deviceCode), func(currentValue []byte) ([]byte, error) { + var current DeviceToken + if len(currentValue) > 0 { + if err := json.Unmarshal(currentValue, ¤t); err != nil { + return nil, err + } + } + updated, err := updater(toStorageDeviceToken(current)) + if err != nil { + return nil, err + } + return json.Marshal(fromStorageDeviceToken(updated)) + }) +} diff --git a/storage/etcd/types.go b/storage/etcd/types.go index ab7bce4c..cc8045c9 100644 --- a/storage/etcd/types.go +++ b/storage/etcd/types.go @@ -219,38 +219,51 @@ func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions { // DeviceRequest is a mirrored struct from storage with JSON struct tags type DeviceRequest struct { - UserCode string `json:"user_code"` - DeviceCode string `json:"device_code"` - ClientID string `json:"client_id"` - Scopes []string `json:"scopes"` - PkceVerifier string `json:"pkce_verifier"` - Expiry time.Time `json:"expiry"` + UserCode string `json:"user_code"` + DeviceCode string `json:"device_code"` + ClientID string `json:"client_id"` + Scopes []string `json:"scopes"` + Expiry time.Time `json:"expiry"` } func fromStorageDeviceRequest(d storage.DeviceRequest) DeviceRequest { return DeviceRequest{ - UserCode: d.UserCode, - DeviceCode: d.DeviceCode, - ClientID: d.ClientID, - Scopes: d.Scopes, - PkceVerifier: d.PkceVerifier, - Expiry: d.Expiry, + UserCode: d.UserCode, + DeviceCode: d.DeviceCode, + ClientID: d.ClientID, + Scopes: d.Scopes, + Expiry: d.Expiry, } } // DeviceToken is a mirrored struct from storage with JSON struct tags type DeviceToken struct { - DeviceCode string `json:"device_code"` - Status string `json:"status"` - Token string `json:"token"` - Expiry time.Time `json:"expiry"` + DeviceCode string `json:"device_code"` + Status string `json:"status"` + Token string `json:"token"` + Expiry time.Time `json:"expiry"` + LastRequestTime time.Time `json:"last_request"` + PollIntervalSeconds int `json:"poll_interval"` } func fromStorageDeviceToken(t storage.DeviceToken) DeviceToken { return DeviceToken{ - DeviceCode: t.DeviceCode, - Status: t.Status, - Token: t.Token, - Expiry: t.Expiry, + DeviceCode: t.DeviceCode, + Status: t.Status, + Token: t.Token, + Expiry: t.Expiry, + LastRequestTime: t.LastRequestTime, + PollIntervalSeconds: t.PollIntervalSeconds, + } +} + +func toStorageDeviceToken(t DeviceToken) storage.DeviceToken { + return storage.DeviceToken{ + DeviceCode: t.DeviceCode, + Status: t.Status, + Token: t.Token, + Expiry: t.Expiry, + LastRequestTime: t.LastRequestTime, + PollIntervalSeconds: t.PollIntervalSeconds, } } diff --git a/storage/kubernetes/storage.go b/storage/kubernetes/storage.go index 20f9daac..baf1d567 100644 --- a/storage/kubernetes/storage.go +++ b/storage/kubernetes/storage.go @@ -638,6 +638,14 @@ func (cli *client) CreateDeviceRequest(d storage.DeviceRequest) error { return cli.post(resourceDeviceRequest, cli.fromStorageDeviceRequest(d)) } +func (cli *client) GetDeviceRequest(userCode string) (storage.DeviceRequest, error) { + var req DeviceRequest + if err := cli.get(resourceDeviceRequest, strings.ToLower(userCode), &req); err != nil { + return storage.DeviceRequest{}, err + } + return toStorageDeviceRequest(req), nil +} + func (cli *client) CreateDeviceToken(t storage.DeviceToken) error { return cli.post(resourceDeviceToken, cli.fromStorageDeviceToken(t)) } @@ -649,3 +657,24 @@ func (cli *client) GetDeviceToken(deviceCode string) (storage.DeviceToken, error } return toStorageDeviceToken(token), nil } + +func (cli *client) getDeviceToken(deviceCode string) (t DeviceToken, err error) { + err = cli.get(resourceDeviceToken, deviceCode, &t) + return +} + +func (cli *client) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error { + r, err := cli.getDeviceToken(deviceCode) + if err != nil { + return err + } + updated, err := updater(toStorageDeviceToken(r)) + if err != nil { + return err + } + updated.DeviceCode = deviceCode + + newToken := cli.fromStorageDeviceToken(updated) + newToken.ObjectMeta = r.ObjectMeta + return cli.put(resourceDeviceToken, r.ObjectMeta.Name, newToken) +} diff --git a/storage/kubernetes/types.go b/storage/kubernetes/types.go index 66fe5780..61794ccf 100644 --- a/storage/kubernetes/types.go +++ b/storage/kubernetes/types.go @@ -672,11 +672,10 @@ type DeviceRequest struct { k8sapi.TypeMeta `json:",inline"` k8sapi.ObjectMeta `json:"metadata,omitempty"` - DeviceCode string `json:"device_code,omitempty"` - CLientID string `json:"client_id,omitempty"` - Scopes []string `json:"scopes,omitempty"` - PkceVerifier string `json:"pkce_verifier,omitempty"` - Expiry time.Time `json:"expiry"` + DeviceCode string `json:"device_code,omitempty"` + CLientID string `json:"client_id,omitempty"` + Scopes []string `json:"scopes,omitempty"` + Expiry time.Time `json:"expiry"` } // AuthRequestList is a list of AuthRequests. @@ -696,24 +695,35 @@ func (cli *client) fromStorageDeviceRequest(a storage.DeviceRequest) DeviceReque Name: strings.ToLower(a.UserCode), Namespace: cli.namespace, }, - DeviceCode: a.DeviceCode, - CLientID: a.ClientID, - Scopes: a.Scopes, - PkceVerifier: a.PkceVerifier, - Expiry: a.Expiry, + DeviceCode: a.DeviceCode, + CLientID: a.ClientID, + Scopes: a.Scopes, + Expiry: a.Expiry, } return req } +func toStorageDeviceRequest(req DeviceRequest) storage.DeviceRequest { + return storage.DeviceRequest{ + UserCode: strings.ToUpper(req.ObjectMeta.Name), + DeviceCode: req.DeviceCode, + ClientID: req.CLientID, + Scopes: req.Scopes, + Expiry: req.Expiry, + } +} + // DeviceToken is a mirrored struct from storage with JSON struct tags and // Kubernetes type metadata. type DeviceToken struct { k8sapi.TypeMeta `json:",inline"` k8sapi.ObjectMeta `json:"metadata,omitempty"` - Status string `json:"status,omitempty"` - Token string `json:"token,omitempty"` - Expiry time.Time `json:"expiry"` + Status string `json:"status,omitempty"` + Token string `json:"token,omitempty"` + Expiry time.Time `json:"expiry"` + LastRequestTime time.Time `json:"last_request"` + PollIntervalSeconds int `json:"poll_interval"` } // DeviceTokenList is a list of DeviceTokens. @@ -733,18 +743,22 @@ func (cli *client) fromStorageDeviceToken(t storage.DeviceToken) DeviceToken { Name: t.DeviceCode, Namespace: cli.namespace, }, - Status: t.Status, - Token: t.Token, - Expiry: t.Expiry, + Status: t.Status, + Token: t.Token, + Expiry: t.Expiry, + LastRequestTime: t.LastRequestTime, + PollIntervalSeconds: t.PollIntervalSeconds, } return req } func toStorageDeviceToken(t DeviceToken) storage.DeviceToken { return storage.DeviceToken{ - DeviceCode: t.ObjectMeta.Name, - Status: t.Status, - Token: t.Token, - Expiry: t.Expiry, + DeviceCode: t.ObjectMeta.Name, + Status: t.Status, + Token: t.Token, + Expiry: t.Expiry, + LastRequestTime: t.LastRequestTime, + PollIntervalSeconds: t.PollIntervalSeconds, } } diff --git a/storage/memory/memory.go b/storage/memory/memory.go index 32cfd415..82264205 100644 --- a/storage/memory/memory.go +++ b/storage/memory/memory.go @@ -493,6 +493,17 @@ func (s *memStorage) CreateDeviceRequest(d storage.DeviceRequest) (err error) { return } +func (s *memStorage) GetDeviceRequest(userCode string) (req storage.DeviceRequest, err error) { + s.tx(func() { + var ok bool + if req, ok = s.deviceRequests[userCode]; !ok { + err = storage.ErrNotFound + return + } + }) + return +} + func (s *memStorage) CreateDeviceToken(t storage.DeviceToken) (err error) { s.tx(func() { if _, ok := s.deviceTokens[t.DeviceCode]; ok { @@ -514,3 +525,17 @@ func (s *memStorage) GetDeviceToken(deviceCode string) (t storage.DeviceToken, e }) return } + +func (s *memStorage) UpdateDeviceToken(deviceCode string, updater func(p storage.DeviceToken) (storage.DeviceToken, error)) (err error) { + s.tx(func() { + r, ok := s.deviceTokens[deviceCode] + if !ok { + err = storage.ErrNotFound + return + } + if r, err = updater(r); err == nil { + s.deviceTokens[deviceCode] = r + } + }) + return +} diff --git a/storage/sql/crud.go b/storage/sql/crud.go index c52e67cf..a85c972b 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -888,12 +888,12 @@ func (c *conn) delete(table, field, id string) error { func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error { _, err := c.Exec(` insert into device_request ( - user_code, device_code, client_id, scopes, pkce_verifier, expiry + user_code, device_code, client_id, scopes, expiry ) values ( - $1, $2, $3, $4, $5, $6 + $1, $2, $3, $4, $5 );`, - d.UserCode, d.DeviceCode, d.ClientID, encoder(d.Scopes), d.PkceVerifier, d.Expiry, + d.UserCode, d.DeviceCode, d.ClientID, encoder(d.Scopes), d.Expiry, ) if err != nil { if c.alreadyExistsCheck(err) { @@ -907,12 +907,12 @@ func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error { func (c *conn) CreateDeviceToken(t storage.DeviceToken) error { _, err := c.Exec(` insert into device_token ( - device_code, status, token, expiry + device_code, status, token, expiry, last_request, poll_interval ) values ( - $1, $2, $3, $4 + $1, $2, $3, $4, $5, $6 );`, - t.DeviceCode, t.Status, t.Token, t.Expiry, + t.DeviceCode, t.Status, t.Token, t.Expiry, t.LastRequestTime, t.PollIntervalSeconds, ) if err != nil { if c.alreadyExistsCheck(err) { @@ -923,6 +923,28 @@ func (c *conn) CreateDeviceToken(t storage.DeviceToken) error { return nil } +func (c *conn) GetDeviceRequest(userCode string) (storage.DeviceRequest, error) { + return getDeviceRequest(c, userCode) +} + +func getDeviceRequest(q querier, userCode string) (d storage.DeviceRequest, err error) { + err = q.QueryRow(` + select + device_code, client_id, scopes, expiry + from device_request where user_code = $1; + `, userCode).Scan( + &d.DeviceCode, &d.ClientID, decoder(&d.Scopes), &d.Expiry, + ) + if err != nil { + if err == sql.ErrNoRows { + return d, storage.ErrNotFound + } + return d, fmt.Errorf("select device token: %v", err) + } + d.UserCode = userCode + return d, nil +} + func (c *conn) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) { return getDeviceToken(c, deviceCode) } @@ -930,10 +952,10 @@ func (c *conn) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) { func getDeviceToken(q querier, deviceCode string) (a storage.DeviceToken, err error) { err = q.QueryRow(` select - status, token, expiry + status, token, expiry, last_request, poll_interval from device_token where device_code = $1; `, deviceCode).Scan( - &a.Status, &a.Token, &a.Expiry, + &a.Status, &a.Token, &a.Expiry, &a.LastRequestTime, &a.PollIntervalSeconds, ) if err != nil { if err == sql.ErrNoRows { @@ -944,3 +966,31 @@ func getDeviceToken(q querier, deviceCode string) (a storage.DeviceToken, err er a.DeviceCode = deviceCode return a, nil } + +func (c *conn) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error { + return c.ExecTx(func(tx *trans) error { + r, err := getDeviceToken(tx, deviceCode) + if err != nil { + return err + } + if r, err = updater(r); err != nil { + return err + } + _, err = tx.Exec(` + update device_token + set + status = $1, + token = $2, + last_request = $3, + poll_interval = $4 + where + device_code = $5 + `, + r.Status, r.Token, r.LastRequestTime, r.PollIntervalSeconds, r.DeviceCode, + ) + if err != nil { + return fmt.Errorf("update device token: %v", err) + } + return nil + }) +} diff --git a/storage/sql/migrate.go b/storage/sql/migrate.go index 96cd6c0a..e399d2b8 100644 --- a/storage/sql/migrate.go +++ b/storage/sql/migrate.go @@ -236,7 +236,6 @@ var migrations = []migration{ device_code text not null, client_id text not null, scopes bytea not null, -- JSON array of strings - pkce_verifier text not null, expiry timestamptz not null );`, ` @@ -244,7 +243,9 @@ var migrations = []migration{ device_code text not null primary key, status text not null, token text, - expiry timestamptz not null + expiry timestamptz not null, + last_request timestamptz not null, + poll_interval integer not null );`, }, }, diff --git a/storage/storage.go b/storage/storage.go index 88ab71cd..005e9190 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -82,6 +82,7 @@ type Storage interface { GetPassword(email string) (Password, error) GetOfflineSessions(userID string, connID string) (OfflineSessions, error) GetConnector(id string) (Connector, error) + GetDeviceRequest(userCode string) (DeviceRequest, error) GetDeviceToken(deviceCode string) (DeviceToken, error) ListClients() ([]Client, error) @@ -119,6 +120,7 @@ type Storage interface { UpdatePassword(email string, updater func(p Password) (Password, error)) error UpdateOfflineSessions(userID string, connID string, updater func(s OfflineSessions) (OfflineSessions, error)) error UpdateConnector(id string, updater func(c Connector) (Connector, error)) error + UpdateDeviceToken(deviceCode string, updater func(t DeviceToken) (DeviceToken, error)) error // GarbageCollect deletes all expired AuthCodes,AuthRequests, DeviceRequests, and DeviceTokens. GarbageCollect(now time.Time) (GCResult, error) @@ -392,15 +394,15 @@ type DeviceRequest struct { ClientID string //The scopes the device requests Scopes []string - //PKCE Verification - PkceVerifier string //The expire time Expiry time.Time } type DeviceToken struct { - DeviceCode string - Status string - Token string - Expiry time.Time + DeviceCode string + Status string + Token string + Expiry time.Time + LastRequestTime time.Time + PollIntervalSeconds int } diff --git a/web/templates/device.html b/web/templates/device.html new file mode 100644 index 00000000..674cbdc3 --- /dev/null +++ b/web/templates/device.html @@ -0,0 +1,23 @@ +{{ template "header.html" . }} + +
+

Enter User Code

+
+
+ {{ if( .UserCode )}} + + {{ else }} + + {{ end }} +
+ + {{ if .Invalid }} +
+ Invalid or Expired User Code +
+ {{ end }} + +
+
+ +{{ template "footer.html" . }} diff --git a/web/templates/device_success.html b/web/templates/device_success.html new file mode 100644 index 00000000..53b09ce5 --- /dev/null +++ b/web/templates/device_success.html @@ -0,0 +1,8 @@ +{{ template "header.html" . }} + +
+

Login Successful for {{ .ClientName }}

+

Return to your device to continue

+
+ +{{ template "footer.html" . }}