forked from mystiq/dex
Merge pull request #1706 from justin-slowik/device_flow
Implementing the OAuth2 Device Authorization Grant
This commit is contained in:
commit
336c73c0a2
26 changed files with 2529 additions and 322 deletions
|
@ -279,6 +279,9 @@ type Expiry struct {
|
|||
|
||||
// AuthRequests defines the duration of time for which the AuthRequests will be valid.
|
||||
AuthRequests string `json:"authRequests"`
|
||||
|
||||
// DeviceRequests defines the duration of time for which the DeviceRequests will be valid.
|
||||
DeviceRequests string `json:"deviceRequests"`
|
||||
}
|
||||
|
||||
// Logger holds configuration required to customize logging for dex.
|
||||
|
|
|
@ -119,6 +119,7 @@ expiry:
|
|||
signingKeys: "7h"
|
||||
idTokens: "25h"
|
||||
authRequests: "25h"
|
||||
deviceRequests: "10m"
|
||||
|
||||
logger:
|
||||
level: "debug"
|
||||
|
@ -200,6 +201,7 @@ logger:
|
|||
SigningKeys: "7h",
|
||||
IDTokens: "25h",
|
||||
AuthRequests: "25h",
|
||||
DeviceRequests: "10m",
|
||||
},
|
||||
Logger: Logger{
|
||||
Level: "debug",
|
||||
|
|
|
@ -269,7 +269,14 @@ func serve(cmd *cobra.Command, args []string) error {
|
|||
logger.Infof("config auth requests valid for: %v", authRequests)
|
||||
serverConfig.AuthRequestsValidFor = authRequests
|
||||
}
|
||||
|
||||
if c.Expiry.DeviceRequests != "" {
|
||||
deviceRequests, err := time.ParseDuration(c.Expiry.DeviceRequests)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid config value %q for device request expiry: %v", c.Expiry.AuthRequests, err)
|
||||
}
|
||||
logger.Infof("config device requests valid for: %v", deviceRequests)
|
||||
serverConfig.DeviceRequestsValidFor = deviceRequests
|
||||
}
|
||||
serv, err := server.NewServer(context.Background(), serverConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize server: %v", err)
|
||||
|
|
|
@ -64,6 +64,7 @@ telemetry:
|
|||
|
||||
# Uncomment this block to enable configuration for the expiration time durations.
|
||||
# expiry:
|
||||
# deviceRequests: "5m"
|
||||
# signingKeys: "6h"
|
||||
# idTokens: "24h"
|
||||
|
||||
|
@ -95,7 +96,11 @@ staticClients:
|
|||
- 'http://127.0.0.1:5555/callback'
|
||||
name: 'Example App'
|
||||
secret: ZXhhbXBsZS1hcHAtc2VjcmV0
|
||||
|
||||
# - id: example-device-client
|
||||
# redirectURIs:
|
||||
# - /device/callback
|
||||
# name: 'Static Client for Device Flow'
|
||||
# public: true
|
||||
connectors:
|
||||
- type: mockCallback
|
||||
id: mock
|
||||
|
|
12
scripts/manifests/crds/devicerequests.yaml
Normal file
12
scripts/manifests/crds/devicerequests.yaml
Normal file
|
@ -0,0 +1,12 @@
|
|||
apiVersion: apiextensions.k8s.io/v1beta1
|
||||
kind: CustomResourceDefinition
|
||||
metadata:
|
||||
name: devicerequests.dex.coreos.com
|
||||
spec:
|
||||
group: dex.coreos.com
|
||||
names:
|
||||
kind: DeviceRequest
|
||||
listKind: DeviceRequestList
|
||||
plural: devicerequests
|
||||
singular: devicerequest
|
||||
version: v1
|
12
scripts/manifests/crds/devicetokens.yaml
Normal file
12
scripts/manifests/crds/devicetokens.yaml
Normal file
|
@ -0,0 +1,12 @@
|
|||
apiVersion: apiextensions.k8s.io/v1beta1
|
||||
kind: CustomResourceDefinition
|
||||
metadata:
|
||||
name: devicetokens.dex.coreos.com
|
||||
spec:
|
||||
group: dex.coreos.com
|
||||
names:
|
||||
kind: DeviceToken
|
||||
listKind: DeviceTokenList
|
||||
plural: devicetokens
|
||||
singular: devicetoken
|
||||
version: v1
|
390
server/deviceflowhandlers.go
Normal file
390
server/deviceflowhandlers.go
Normal file
|
@ -0,0 +1,390 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
)
|
||||
|
||||
type deviceCodeResponse struct {
|
||||
//The unique device code for device authentication
|
||||
DeviceCode string `json:"device_code"`
|
||||
//The code the user will exchange via a browser and log in
|
||||
UserCode string `json:"user_code"`
|
||||
//The url to verify the user code.
|
||||
VerificationURI string `json:"verification_uri"`
|
||||
//The verification uri with the user code appended for pre-filling form
|
||||
VerificationURIComplete string `json:"verification_uri_complete"`
|
||||
//The lifetime of the device code
|
||||
ExpireTime int `json:"expires_in"`
|
||||
//How often the device is allowed to poll to verify that the user login occurred
|
||||
PollInterval int `json:"interval"`
|
||||
}
|
||||
|
||||
func (s *Server) getDeviceVerificationURI() string {
|
||||
return path.Join(s.issuerURL.Path, "/device/auth/verify_code")
|
||||
}
|
||||
|
||||
func (s *Server) handleDeviceExchange(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
// Grab the parameter(s) from the query.
|
||||
// If "user_code" is set, pre-populate the user code text field.
|
||||
// If "invalid" is set, set the invalidAttempt boolean, which will display a message to the user that they
|
||||
// attempted to redeem an invalid or expired user code.
|
||||
userCode := r.URL.Query().Get("user_code")
|
||||
invalidAttempt, err := strconv.ParseBool(r.URL.Query().Get("invalid"))
|
||||
if err != nil {
|
||||
invalidAttempt = false
|
||||
}
|
||||
if err := s.templates.device(r, w, s.getDeviceVerificationURI(), userCode, invalidAttempt); err != nil {
|
||||
s.logger.Errorf("Server template error: %v", err)
|
||||
s.renderError(r, w, http.StatusNotFound, "Page not found")
|
||||
}
|
||||
default:
|
||||
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
|
||||
pollIntervalSeconds := 5
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodPost:
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
s.logger.Errorf("Could not parse Device Request body: %v", err)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
//Get the client id and scopes from the post
|
||||
clientID := r.Form.Get("client_id")
|
||||
clientSecret := r.Form.Get("client_secret")
|
||||
scopes := strings.Fields(r.Form.Get("scope"))
|
||||
|
||||
s.logger.Infof("Received device request for client %v with scopes %v", clientID, scopes)
|
||||
|
||||
//Make device code
|
||||
deviceCode := storage.NewDeviceCode()
|
||||
|
||||
//make user code
|
||||
userCode, err := storage.NewUserCode()
|
||||
if err != nil {
|
||||
s.logger.Errorf("Error generating user code: %v", err)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
//Generate the expire time
|
||||
expireTime := time.Now().Add(s.deviceRequestsValidFor)
|
||||
|
||||
//Store the Device Request
|
||||
deviceReq := storage.DeviceRequest{
|
||||
UserCode: userCode,
|
||||
DeviceCode: deviceCode,
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
Scopes: scopes,
|
||||
Expiry: expireTime,
|
||||
}
|
||||
|
||||
if err := s.storage.CreateDeviceRequest(deviceReq); err != nil {
|
||||
s.logger.Errorf("Failed to store device request; %v", err)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
//Store the device token
|
||||
deviceToken := storage.DeviceToken{
|
||||
DeviceCode: deviceCode,
|
||||
Status: deviceTokenPending,
|
||||
Expiry: expireTime,
|
||||
LastRequestTime: s.now(),
|
||||
PollIntervalSeconds: 0,
|
||||
}
|
||||
|
||||
if err := s.storage.CreateDeviceToken(deviceToken); err != nil {
|
||||
s.logger.Errorf("Failed to store device token %v", err)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
u, err := url.Parse(s.issuerURL.String())
|
||||
if err != nil {
|
||||
s.logger.Errorf("Could not parse issuer URL %v", err)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
u.Path = path.Join(u.Path, "device")
|
||||
vURI := u.String()
|
||||
|
||||
q := u.Query()
|
||||
q.Set("user_code", userCode)
|
||||
u.RawQuery = q.Encode()
|
||||
vURIComplete := u.String()
|
||||
|
||||
code := deviceCodeResponse{
|
||||
DeviceCode: deviceCode,
|
||||
UserCode: userCode,
|
||||
VerificationURI: vURI,
|
||||
VerificationURIComplete: vURIComplete,
|
||||
ExpireTime: int(s.deviceRequestsValidFor.Seconds()),
|
||||
PollInterval: pollIntervalSeconds,
|
||||
}
|
||||
|
||||
enc := json.NewEncoder(w)
|
||||
enc.SetEscapeHTML(false)
|
||||
enc.SetIndent("", " ")
|
||||
enc.Encode(code)
|
||||
|
||||
default:
|
||||
s.renderError(r, w, http.StatusBadRequest, "Invalid device code request type")
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
switch r.Method {
|
||||
case http.MethodPost:
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
s.logger.Warnf("Could not parse Device Token Request body: %v", err)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
deviceCode := r.Form.Get("device_code")
|
||||
if deviceCode == "" {
|
||||
s.tokenErrHelper(w, errInvalidRequest, "No device code received", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
grantType := r.PostFormValue("grant_type")
|
||||
if grantType != grantTypeDeviceCode {
|
||||
s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
now := s.now()
|
||||
|
||||
//Grab the device token, check validity
|
||||
deviceToken, err := s.storage.GetDeviceToken(deviceCode)
|
||||
if err != nil {
|
||||
if err != storage.ErrNotFound {
|
||||
s.logger.Errorf("failed to get device code: %v", err)
|
||||
}
|
||||
s.tokenErrHelper(w, errInvalidRequest, "Invalid Device code.", http.StatusBadRequest)
|
||||
return
|
||||
} else if now.After(deviceToken.Expiry) {
|
||||
s.tokenErrHelper(w, deviceTokenExpired, "", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
//Rate Limiting check
|
||||
slowDown := false
|
||||
pollInterval := deviceToken.PollIntervalSeconds
|
||||
minRequestTime := deviceToken.LastRequestTime.Add(time.Second * time.Duration(pollInterval))
|
||||
if now.Before(minRequestTime) {
|
||||
slowDown = true
|
||||
//Continually increase the poll interval until the user waits the proper time
|
||||
pollInterval += 5
|
||||
} else {
|
||||
pollInterval = 5
|
||||
}
|
||||
|
||||
switch deviceToken.Status {
|
||||
case deviceTokenPending:
|
||||
updater := func(old storage.DeviceToken) (storage.DeviceToken, error) {
|
||||
old.PollIntervalSeconds = pollInterval
|
||||
old.LastRequestTime = now
|
||||
return old, nil
|
||||
}
|
||||
// Update device token last request time in storage
|
||||
if err := s.storage.UpdateDeviceToken(deviceCode, updater); err != nil {
|
||||
s.logger.Errorf("failed to update device token: %v", err)
|
||||
s.renderError(r, w, http.StatusInternalServerError, "")
|
||||
return
|
||||
}
|
||||
if slowDown {
|
||||
s.tokenErrHelper(w, deviceTokenSlowDown, "", http.StatusBadRequest)
|
||||
} else {
|
||||
s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized)
|
||||
}
|
||||
case deviceTokenComplete:
|
||||
w.Write([]byte(deviceToken.Token))
|
||||
}
|
||||
default:
|
||||
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
userCode := r.FormValue("state")
|
||||
code := r.FormValue("code")
|
||||
|
||||
if userCode == "" || code == "" {
|
||||
s.renderError(r, w, http.StatusBadRequest, "Request was missing parameters")
|
||||
return
|
||||
}
|
||||
|
||||
// Authorization redirect callback from OAuth2 auth flow.
|
||||
if errMsg := r.FormValue("error"); errMsg != "" {
|
||||
http.Error(w, errMsg+": "+r.FormValue("error_description"), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
authCode, err := s.storage.GetAuthCode(code)
|
||||
if err != nil || s.now().After(authCode.Expiry) {
|
||||
errCode := http.StatusBadRequest
|
||||
if err != nil && err != storage.ErrNotFound {
|
||||
s.logger.Errorf("failed to get auth code: %v", err)
|
||||
errCode = http.StatusInternalServerError
|
||||
}
|
||||
s.renderError(r, w, errCode, "Invalid or expired auth code.")
|
||||
return
|
||||
}
|
||||
|
||||
//Grab the device request from storage
|
||||
deviceReq, err := s.storage.GetDeviceRequest(userCode)
|
||||
if err != nil || s.now().After(deviceReq.Expiry) {
|
||||
errCode := http.StatusBadRequest
|
||||
if err != nil && err != storage.ErrNotFound {
|
||||
s.logger.Errorf("failed to get device code: %v", err)
|
||||
errCode = http.StatusInternalServerError
|
||||
}
|
||||
s.renderError(r, w, errCode, "Invalid or expired user code.")
|
||||
return
|
||||
}
|
||||
|
||||
client, err := s.storage.GetClient(deviceReq.ClientID)
|
||||
if err != nil {
|
||||
if err != storage.ErrNotFound {
|
||||
s.logger.Errorf("failed to get client: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
} else {
|
||||
s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized)
|
||||
}
|
||||
return
|
||||
}
|
||||
if client.Secret != deviceReq.ClientSecret {
|
||||
s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := s.exchangeAuthCode(w, authCode, client)
|
||||
if err != nil {
|
||||
s.logger.Errorf("Could not exchange auth code for client %q: %v", deviceReq.ClientID, err)
|
||||
s.renderError(r, w, http.StatusInternalServerError, "Failed to exchange auth code.")
|
||||
return
|
||||
}
|
||||
|
||||
//Grab the device token from storage
|
||||
old, err := s.storage.GetDeviceToken(deviceReq.DeviceCode)
|
||||
if err != nil || s.now().After(old.Expiry) {
|
||||
errCode := http.StatusBadRequest
|
||||
if err != nil && err != storage.ErrNotFound {
|
||||
s.logger.Errorf("failed to get device token: %v", err)
|
||||
errCode = http.StatusInternalServerError
|
||||
}
|
||||
s.renderError(r, w, errCode, "Invalid or expired device code.")
|
||||
return
|
||||
}
|
||||
|
||||
updater := func(old storage.DeviceToken) (storage.DeviceToken, error) {
|
||||
if old.Status == deviceTokenComplete {
|
||||
return old, errors.New("device token already complete")
|
||||
}
|
||||
respStr, err := json.MarshalIndent(resp, "", " ")
|
||||
if err != nil {
|
||||
s.logger.Errorf("failed to marshal device token response: %v", err)
|
||||
s.renderError(r, w, http.StatusInternalServerError, "")
|
||||
return old, err
|
||||
}
|
||||
|
||||
old.Token = string(respStr)
|
||||
old.Status = deviceTokenComplete
|
||||
return old, nil
|
||||
}
|
||||
|
||||
// Update refresh token in the storage, store the token and mark as complete
|
||||
if err := s.storage.UpdateDeviceToken(deviceReq.DeviceCode, updater); err != nil {
|
||||
s.logger.Errorf("failed to update device token: %v", err)
|
||||
s.renderError(r, w, http.StatusBadRequest, "")
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.templates.deviceSuccess(r, w, client.Name); err != nil {
|
||||
s.logger.Errorf("Server template error: %v", err)
|
||||
s.renderError(r, w, http.StatusNotFound, "Page not found")
|
||||
}
|
||||
|
||||
default:
|
||||
http.Error(w, fmt.Sprintf("method not implemented: %s", r.Method), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) verifyUserCode(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodPost:
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
s.logger.Warnf("Could not parse user code verification request body : %v", err)
|
||||
s.renderError(r, w, http.StatusBadRequest, "")
|
||||
return
|
||||
}
|
||||
|
||||
userCode := r.Form.Get("user_code")
|
||||
if userCode == "" {
|
||||
s.renderError(r, w, http.StatusBadRequest, "No user code received")
|
||||
return
|
||||
}
|
||||
|
||||
userCode = strings.ToUpper(userCode)
|
||||
|
||||
//Find the user code in the available requests
|
||||
deviceRequest, err := s.storage.GetDeviceRequest(userCode)
|
||||
if err != nil || s.now().After(deviceRequest.Expiry) {
|
||||
if err != nil && err != storage.ErrNotFound {
|
||||
s.logger.Errorf("failed to get device request: %v", err)
|
||||
}
|
||||
if err := s.templates.device(r, w, s.getDeviceVerificationURI(), userCode, true); err != nil {
|
||||
s.logger.Errorf("Server template error: %v", err)
|
||||
s.renderError(r, w, http.StatusNotFound, "Page not found")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
//Redirect to Dex Auth Endpoint
|
||||
authURL := path.Join(s.issuerURL.Path, "/auth")
|
||||
u, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
s.renderError(r, w, http.StatusInternalServerError, "Invalid auth URI.")
|
||||
return
|
||||
}
|
||||
q := u.Query()
|
||||
q.Set("client_id", deviceRequest.ClientID)
|
||||
q.Set("client_secret", deviceRequest.ClientSecret)
|
||||
q.Set("state", deviceRequest.UserCode)
|
||||
q.Set("response_type", "code")
|
||||
q.Set("redirect_uri", "/device/callback")
|
||||
q.Set("scope", strings.Join(deviceRequest.Scopes, " "))
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
http.Redirect(w, r, u.String(), http.StatusFound)
|
||||
|
||||
default:
|
||||
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
|
||||
}
|
||||
}
|
678
server/deviceflowhandlers_test.go
Normal file
678
server/deviceflowhandlers_test.go
Normal file
|
@ -0,0 +1,678 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
)
|
||||
|
||||
func TestDeviceVerificationURI(t *testing.T) {
|
||||
t0 := time.Now()
|
||||
|
||||
now := func() time.Time { return t0 }
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
// Setup a dex server.
|
||||
httpServer, s := newTestServer(ctx, t, func(c *Config) {
|
||||
c.Issuer = c.Issuer + "/non-root-path"
|
||||
c.Now = now
|
||||
})
|
||||
defer httpServer.Close()
|
||||
|
||||
u, err := url.Parse(s.issuerURL.String())
|
||||
if err != nil {
|
||||
t.Fatalf("Could not parse issuer URL %v", err)
|
||||
}
|
||||
u.Path = path.Join(u.Path, "/device/auth/verify_code")
|
||||
|
||||
uri := s.getDeviceVerificationURI()
|
||||
if uri != u.Path {
|
||||
t.Errorf("Invalid verification URI. Expected %v got %v", u.Path, uri)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleDeviceCode(t *testing.T) {
|
||||
t0 := time.Now()
|
||||
|
||||
now := func() time.Time { return t0 }
|
||||
|
||||
tests := []struct {
|
||||
testName string
|
||||
clientID string
|
||||
requestType string
|
||||
scopes []string
|
||||
expectedResponseCode int
|
||||
expectedServerResponse string
|
||||
}{
|
||||
{
|
||||
testName: "New Code",
|
||||
clientID: "test",
|
||||
requestType: "POST",
|
||||
scopes: []string{"openid", "profile", "email"},
|
||||
expectedResponseCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
testName: "Invalid request Type (GET)",
|
||||
clientID: "test",
|
||||
requestType: "GET",
|
||||
scopes: []string{"openid", "profile", "email"},
|
||||
expectedResponseCode: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.testName, func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Setup a dex server.
|
||||
httpServer, s := newTestServer(ctx, t, func(c *Config) {
|
||||
c.Issuer = c.Issuer + "/non-root-path"
|
||||
c.Now = now
|
||||
})
|
||||
defer httpServer.Close()
|
||||
|
||||
u, err := url.Parse(s.issuerURL.String())
|
||||
if err != nil {
|
||||
t.Fatalf("Could not parse issuer URL %v", err)
|
||||
}
|
||||
u.Path = path.Join(u.Path, "device/code")
|
||||
|
||||
data := url.Values{}
|
||||
data.Set("client_id", tc.clientID)
|
||||
for _, scope := range tc.scopes {
|
||||
data.Add("scope", scope)
|
||||
}
|
||||
req, _ := http.NewRequest(tc.requestType, u.String(), bytes.NewBufferString(data.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
s.ServeHTTP(rr, req)
|
||||
if rr.Code != tc.expectedResponseCode {
|
||||
t.Errorf("Unexpected Response Type. Expected %v got %v", tc.expectedResponseCode, rr.Code)
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(rr.Body)
|
||||
if err != nil {
|
||||
t.Errorf("Could read token response %v", err)
|
||||
}
|
||||
if tc.expectedResponseCode == http.StatusOK {
|
||||
var resp deviceCodeResponse
|
||||
if err := json.Unmarshal(body, &resp); err != nil {
|
||||
t.Errorf("Unexpected Device Code Response Format %v", string(body))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceCallback(t *testing.T) {
|
||||
t0 := time.Now()
|
||||
|
||||
now := func() time.Time { return t0 }
|
||||
|
||||
type formValues struct {
|
||||
state string
|
||||
code string
|
||||
error string
|
||||
}
|
||||
|
||||
// Base "Control" test values
|
||||
baseFormValues := formValues{
|
||||
state: "XXXX-XXXX",
|
||||
code: "somecode",
|
||||
}
|
||||
baseAuthCode := storage.AuthCode{
|
||||
ID: "somecode",
|
||||
ClientID: "testclient",
|
||||
RedirectURI: deviceCallbackURI,
|
||||
Nonce: "",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
ConnectorID: "mock",
|
||||
ConnectorData: nil,
|
||||
Claims: storage.Claims{},
|
||||
Expiry: now().Add(5 * time.Minute),
|
||||
}
|
||||
baseDeviceRequest := storage.DeviceRequest{
|
||||
UserCode: "XXXX-XXXX",
|
||||
DeviceCode: "devicecode",
|
||||
ClientID: "testclient",
|
||||
ClientSecret: "",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
Expiry: now().Add(5 * time.Minute),
|
||||
}
|
||||
baseDeviceToken := storage.DeviceToken{
|
||||
DeviceCode: "devicecode",
|
||||
Status: deviceTokenPending,
|
||||
Token: "",
|
||||
Expiry: now().Add(5 * time.Minute),
|
||||
LastRequestTime: time.Time{},
|
||||
PollIntervalSeconds: 0,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
testName string
|
||||
expectedResponseCode int
|
||||
values formValues
|
||||
testAuthCode storage.AuthCode
|
||||
testDeviceRequest storage.DeviceRequest
|
||||
testDeviceToken storage.DeviceToken
|
||||
}{
|
||||
{
|
||||
testName: "Missing State",
|
||||
values: formValues{
|
||||
state: "",
|
||||
code: "somecode",
|
||||
error: "",
|
||||
},
|
||||
expectedResponseCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
testName: "Missing Code",
|
||||
values: formValues{
|
||||
state: "XXXX-XXXX",
|
||||
code: "",
|
||||
error: "",
|
||||
},
|
||||
expectedResponseCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
testName: "Error During Authorization",
|
||||
values: formValues{
|
||||
state: "XXXX-XXXX",
|
||||
code: "somecode",
|
||||
error: "Error Condition",
|
||||
},
|
||||
expectedResponseCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
testName: "Expired Auth Code",
|
||||
values: baseFormValues,
|
||||
testAuthCode: storage.AuthCode{
|
||||
ID: "somecode",
|
||||
ClientID: "testclient",
|
||||
RedirectURI: deviceCallbackURI,
|
||||
Nonce: "",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
ConnectorID: "pic",
|
||||
ConnectorData: nil,
|
||||
Claims: storage.Claims{},
|
||||
Expiry: now().Add(-5 * time.Minute),
|
||||
},
|
||||
expectedResponseCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
testName: "Invalid Auth Code",
|
||||
values: baseFormValues,
|
||||
testAuthCode: storage.AuthCode{
|
||||
ID: "somecode",
|
||||
ClientID: "testclient",
|
||||
RedirectURI: deviceCallbackURI,
|
||||
Nonce: "",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
ConnectorID: "pic",
|
||||
ConnectorData: nil,
|
||||
Claims: storage.Claims{},
|
||||
Expiry: now().Add(5 * time.Minute),
|
||||
},
|
||||
expectedResponseCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
testName: "Expired Device Request",
|
||||
values: baseFormValues,
|
||||
testAuthCode: baseAuthCode,
|
||||
testDeviceRequest: storage.DeviceRequest{
|
||||
UserCode: "XXXX-XXXX",
|
||||
DeviceCode: "devicecode",
|
||||
ClientID: "testclient",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
Expiry: now().Add(-5 * time.Minute),
|
||||
},
|
||||
expectedResponseCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
testName: "Non-Existent User Code",
|
||||
values: baseFormValues,
|
||||
testAuthCode: baseAuthCode,
|
||||
testDeviceRequest: storage.DeviceRequest{
|
||||
UserCode: "ZZZZ-ZZZZ",
|
||||
DeviceCode: "devicecode",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
Expiry: now().Add(5 * time.Minute),
|
||||
},
|
||||
expectedResponseCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
testName: "Bad Device Request Client",
|
||||
values: baseFormValues,
|
||||
testAuthCode: baseAuthCode,
|
||||
testDeviceRequest: storage.DeviceRequest{
|
||||
UserCode: "XXXX-XXXX",
|
||||
DeviceCode: "devicecode",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
Expiry: now().Add(5 * time.Minute),
|
||||
},
|
||||
expectedResponseCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
testName: "Bad Device Request Secret",
|
||||
values: baseFormValues,
|
||||
testAuthCode: baseAuthCode,
|
||||
testDeviceRequest: storage.DeviceRequest{
|
||||
UserCode: "XXXX-XXXX",
|
||||
DeviceCode: "devicecode",
|
||||
ClientSecret: "foobar",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
Expiry: now().Add(5 * time.Minute),
|
||||
},
|
||||
expectedResponseCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
testName: "Expired Device Token",
|
||||
values: baseFormValues,
|
||||
testAuthCode: baseAuthCode,
|
||||
testDeviceRequest: baseDeviceRequest,
|
||||
testDeviceToken: storage.DeviceToken{
|
||||
DeviceCode: "devicecode",
|
||||
Status: deviceTokenPending,
|
||||
Token: "",
|
||||
Expiry: now().Add(-5 * time.Minute),
|
||||
LastRequestTime: time.Time{},
|
||||
PollIntervalSeconds: 0,
|
||||
},
|
||||
expectedResponseCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
testName: "Device Code Already Redeemed",
|
||||
values: baseFormValues,
|
||||
testAuthCode: baseAuthCode,
|
||||
testDeviceRequest: baseDeviceRequest,
|
||||
testDeviceToken: storage.DeviceToken{
|
||||
DeviceCode: "devicecode",
|
||||
Status: deviceTokenComplete,
|
||||
Token: "",
|
||||
Expiry: now().Add(5 * time.Minute),
|
||||
LastRequestTime: time.Time{},
|
||||
PollIntervalSeconds: 0,
|
||||
},
|
||||
expectedResponseCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
testName: "Successful Exchange",
|
||||
values: baseFormValues,
|
||||
testAuthCode: baseAuthCode,
|
||||
testDeviceRequest: baseDeviceRequest,
|
||||
testDeviceToken: baseDeviceToken,
|
||||
expectedResponseCode: http.StatusOK,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.testName, func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Setup a dex server.
|
||||
httpServer, s := newTestServer(ctx, t, func(c *Config) {
|
||||
//c.Issuer = c.Issuer + "/non-root-path"
|
||||
c.Now = now
|
||||
})
|
||||
defer httpServer.Close()
|
||||
|
||||
if err := s.storage.CreateAuthCode(tc.testAuthCode); err != nil {
|
||||
t.Fatalf("failed to create auth code: %v", err)
|
||||
}
|
||||
|
||||
if err := s.storage.CreateDeviceRequest(tc.testDeviceRequest); err != nil {
|
||||
t.Fatalf("failed to create device request: %v", err)
|
||||
}
|
||||
|
||||
if err := s.storage.CreateDeviceToken(tc.testDeviceToken); err != nil {
|
||||
t.Fatalf("failed to create device token: %v", err)
|
||||
}
|
||||
|
||||
client := storage.Client{
|
||||
ID: "testclient",
|
||||
Secret: "",
|
||||
RedirectURIs: []string{deviceCallbackURI},
|
||||
}
|
||||
if err := s.storage.CreateClient(client); err != nil {
|
||||
t.Fatalf("failed to create client: %v", err)
|
||||
}
|
||||
|
||||
u, err := url.Parse(s.issuerURL.String())
|
||||
if err != nil {
|
||||
t.Fatalf("Could not parse issuer URL %v", err)
|
||||
}
|
||||
u.Path = path.Join(u.Path, "device/callback")
|
||||
q := u.Query()
|
||||
q.Set("state", tc.values.state)
|
||||
q.Set("code", tc.values.code)
|
||||
q.Set("error", tc.values.error)
|
||||
u.RawQuery = q.Encode()
|
||||
req, _ := http.NewRequest("GET", u.String(), nil)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
s.ServeHTTP(rr, req)
|
||||
if rr.Code != tc.expectedResponseCode {
|
||||
t.Errorf("%s: Unexpected Response Type. Expected %v got %v", tc.testName, tc.expectedResponseCode, rr.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceTokenResponse(t *testing.T) {
|
||||
t0 := time.Now()
|
||||
|
||||
now := func() time.Time { return t0 }
|
||||
|
||||
baseDeviceRequest := storage.DeviceRequest{
|
||||
UserCode: "ABCD-WXYZ",
|
||||
DeviceCode: "foo",
|
||||
ClientID: "testclient",
|
||||
Scopes: []string{"openid", "profile", "offline_access"},
|
||||
Expiry: now().Add(5 * time.Minute),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
testName string
|
||||
testDeviceRequest storage.DeviceRequest
|
||||
testDeviceToken storage.DeviceToken
|
||||
testGrantType string
|
||||
testDeviceCode string
|
||||
expectedServerResponse string
|
||||
expectedResponseCode int
|
||||
}{
|
||||
{
|
||||
testName: "Valid but pending token",
|
||||
testDeviceRequest: baseDeviceRequest,
|
||||
testDeviceToken: storage.DeviceToken{
|
||||
DeviceCode: "f00bar",
|
||||
Status: deviceTokenPending,
|
||||
Token: "",
|
||||
Expiry: now().Add(5 * time.Minute),
|
||||
LastRequestTime: time.Time{},
|
||||
PollIntervalSeconds: 0,
|
||||
},
|
||||
testDeviceCode: "f00bar",
|
||||
expectedServerResponse: deviceTokenPending,
|
||||
expectedResponseCode: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
testName: "Invalid Grant Type",
|
||||
testDeviceRequest: baseDeviceRequest,
|
||||
testDeviceToken: storage.DeviceToken{
|
||||
DeviceCode: "f00bar",
|
||||
Status: deviceTokenPending,
|
||||
Token: "",
|
||||
Expiry: now().Add(5 * time.Minute),
|
||||
LastRequestTime: time.Time{},
|
||||
PollIntervalSeconds: 0,
|
||||
},
|
||||
testDeviceCode: "f00bar",
|
||||
testGrantType: grantTypeAuthorizationCode,
|
||||
expectedServerResponse: errInvalidGrant,
|
||||
expectedResponseCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
testName: "Test Slow Down State",
|
||||
testDeviceRequest: baseDeviceRequest,
|
||||
testDeviceToken: storage.DeviceToken{
|
||||
DeviceCode: "f00bar",
|
||||
Status: deviceTokenPending,
|
||||
Token: "",
|
||||
Expiry: now().Add(5 * time.Minute),
|
||||
LastRequestTime: now(),
|
||||
PollIntervalSeconds: 10,
|
||||
},
|
||||
testDeviceCode: "f00bar",
|
||||
expectedServerResponse: deviceTokenSlowDown,
|
||||
expectedResponseCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
testName: "Test Expired Device Token",
|
||||
testDeviceRequest: baseDeviceRequest,
|
||||
testDeviceToken: storage.DeviceToken{
|
||||
DeviceCode: "f00bar",
|
||||
Status: deviceTokenPending,
|
||||
Token: "",
|
||||
Expiry: now().Add(-5 * time.Minute),
|
||||
LastRequestTime: time.Time{},
|
||||
PollIntervalSeconds: 0,
|
||||
},
|
||||
testDeviceCode: "f00bar",
|
||||
expectedServerResponse: deviceTokenExpired,
|
||||
expectedResponseCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
testName: "Test Non-existent Device Code",
|
||||
testDeviceRequest: baseDeviceRequest,
|
||||
testDeviceToken: storage.DeviceToken{
|
||||
DeviceCode: "foo",
|
||||
Status: deviceTokenPending,
|
||||
Token: "",
|
||||
Expiry: now().Add(-5 * time.Minute),
|
||||
LastRequestTime: time.Time{},
|
||||
PollIntervalSeconds: 0,
|
||||
},
|
||||
testDeviceCode: "bar",
|
||||
expectedServerResponse: errInvalidRequest,
|
||||
expectedResponseCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
testName: "Empty Device Code in Request",
|
||||
testDeviceRequest: baseDeviceRequest,
|
||||
testDeviceToken: storage.DeviceToken{
|
||||
DeviceCode: "bar",
|
||||
Status: deviceTokenPending,
|
||||
Token: "",
|
||||
Expiry: now().Add(-5 * time.Minute),
|
||||
LastRequestTime: time.Time{},
|
||||
PollIntervalSeconds: 0,
|
||||
},
|
||||
testDeviceCode: "",
|
||||
expectedServerResponse: errInvalidRequest,
|
||||
expectedResponseCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
testName: "Claim validated token from Device Code",
|
||||
testDeviceRequest: baseDeviceRequest,
|
||||
testDeviceToken: storage.DeviceToken{
|
||||
DeviceCode: "foo",
|
||||
Status: deviceTokenComplete,
|
||||
Token: "{\"access_token\": \"foobar\"}",
|
||||
Expiry: now().Add(5 * time.Minute),
|
||||
LastRequestTime: time.Time{},
|
||||
PollIntervalSeconds: 0,
|
||||
},
|
||||
testDeviceCode: "foo",
|
||||
expectedServerResponse: "{\"access_token\": \"foobar\"}",
|
||||
expectedResponseCode: http.StatusOK,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.testName, func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Setup a dex server.
|
||||
httpServer, s := newTestServer(ctx, t, func(c *Config) {
|
||||
c.Issuer = c.Issuer + "/non-root-path"
|
||||
c.Now = now
|
||||
})
|
||||
defer httpServer.Close()
|
||||
|
||||
if err := s.storage.CreateDeviceRequest(tc.testDeviceRequest); err != nil {
|
||||
t.Fatalf("Failed to store device token %v", err)
|
||||
}
|
||||
|
||||
if err := s.storage.CreateDeviceToken(tc.testDeviceToken); err != nil {
|
||||
t.Fatalf("Failed to store device token %v", err)
|
||||
}
|
||||
|
||||
u, err := url.Parse(s.issuerURL.String())
|
||||
if err != nil {
|
||||
t.Fatalf("Could not parse issuer URL %v", err)
|
||||
}
|
||||
u.Path = path.Join(u.Path, "device/token")
|
||||
|
||||
data := url.Values{}
|
||||
grantType := grantTypeDeviceCode
|
||||
if tc.testGrantType != "" {
|
||||
grantType = tc.testGrantType
|
||||
}
|
||||
data.Set("grant_type", grantType)
|
||||
data.Set("device_code", tc.testDeviceCode)
|
||||
req, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(data.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
s.ServeHTTP(rr, req)
|
||||
if rr.Code != tc.expectedResponseCode {
|
||||
t.Errorf("Unexpected Response Type. Expected %v got %v", tc.expectedResponseCode, rr.Code)
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(rr.Body)
|
||||
if err != nil {
|
||||
t.Errorf("Could read token response %v", err)
|
||||
}
|
||||
if tc.expectedResponseCode == http.StatusBadRequest || tc.expectedResponseCode == http.StatusUnauthorized {
|
||||
expectJsonErrorResponse(tc.testName, body, tc.expectedServerResponse, t)
|
||||
} else if string(body) != tc.expectedServerResponse {
|
||||
t.Errorf("Unexpected Server Response. Expected %v got %v", tc.expectedServerResponse, string(body))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func expectJsonErrorResponse(testCase string, body []byte, expectedError string, t *testing.T) {
|
||||
jsonMap := make(map[string]interface{})
|
||||
err := json.Unmarshal(body, &jsonMap)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error unmarshalling response: %v", err)
|
||||
}
|
||||
if jsonMap["error"] != expectedError {
|
||||
t.Errorf("Test Case %s expected error %v, received %v", testCase, expectedError, jsonMap["error"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyCodeResponse(t *testing.T) {
|
||||
t0 := time.Now()
|
||||
|
||||
now := func() time.Time { return t0 }
|
||||
|
||||
tests := []struct {
|
||||
testName string
|
||||
testDeviceRequest storage.DeviceRequest
|
||||
userCode string
|
||||
expectedResponseCode int
|
||||
expectedRedirectPath string
|
||||
}{
|
||||
{
|
||||
testName: "Unknown user code",
|
||||
testDeviceRequest: storage.DeviceRequest{
|
||||
UserCode: "ABCD-WXYZ",
|
||||
DeviceCode: "f00bar",
|
||||
ClientID: "testclient",
|
||||
Scopes: []string{"openid", "profile", "offline_access"},
|
||||
Expiry: now().Add(5 * time.Minute),
|
||||
},
|
||||
userCode: "CODE-TEST",
|
||||
expectedResponseCode: http.StatusBadRequest,
|
||||
expectedRedirectPath: "",
|
||||
},
|
||||
{
|
||||
testName: "Expired user code",
|
||||
testDeviceRequest: storage.DeviceRequest{
|
||||
UserCode: "ABCD-WXYZ",
|
||||
DeviceCode: "f00bar",
|
||||
ClientID: "testclient",
|
||||
Scopes: []string{"openid", "profile", "offline_access"},
|
||||
Expiry: now().Add(-5 * time.Minute),
|
||||
},
|
||||
userCode: "ABCD-WXYZ",
|
||||
expectedResponseCode: http.StatusBadRequest,
|
||||
expectedRedirectPath: "",
|
||||
},
|
||||
{
|
||||
testName: "No user code",
|
||||
testDeviceRequest: storage.DeviceRequest{
|
||||
UserCode: "ABCD-WXYZ",
|
||||
DeviceCode: "f00bar",
|
||||
ClientID: "testclient",
|
||||
Scopes: []string{"openid", "profile", "offline_access"},
|
||||
Expiry: now().Add(-5 * time.Minute),
|
||||
},
|
||||
userCode: "",
|
||||
expectedResponseCode: http.StatusBadRequest,
|
||||
expectedRedirectPath: "",
|
||||
},
|
||||
{
|
||||
testName: "Valid user code, expect redirect to auth endpoint",
|
||||
testDeviceRequest: storage.DeviceRequest{
|
||||
UserCode: "ABCD-WXYZ",
|
||||
DeviceCode: "f00bar",
|
||||
ClientID: "testclient",
|
||||
Scopes: []string{"openid", "profile", "offline_access"},
|
||||
Expiry: now().Add(5 * time.Minute),
|
||||
},
|
||||
userCode: "ABCD-WXYZ",
|
||||
expectedResponseCode: http.StatusFound,
|
||||
expectedRedirectPath: "/auth",
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.testName, func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Setup a dex server.
|
||||
httpServer, s := newTestServer(ctx, t, func(c *Config) {
|
||||
c.Issuer = c.Issuer + "/non-root-path"
|
||||
c.Now = now
|
||||
})
|
||||
defer httpServer.Close()
|
||||
|
||||
if err := s.storage.CreateDeviceRequest(tc.testDeviceRequest); err != nil {
|
||||
t.Fatalf("Failed to store device token %v", err)
|
||||
}
|
||||
|
||||
u, err := url.Parse(s.issuerURL.String())
|
||||
if err != nil {
|
||||
t.Fatalf("Could not parse issuer URL %v", err)
|
||||
}
|
||||
|
||||
u.Path = path.Join(u.Path, "device/auth/verify_code")
|
||||
data := url.Values{}
|
||||
data.Set("user_code", tc.userCode)
|
||||
req, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(data.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
s.ServeHTTP(rr, req)
|
||||
if rr.Code != tc.expectedResponseCode {
|
||||
t.Errorf("Unexpected Response Type. Expected %v got %v", tc.expectedResponseCode, rr.Code)
|
||||
}
|
||||
|
||||
u, err = url.Parse(s.issuerURL.String())
|
||||
if err != nil {
|
||||
t.Errorf("Could not parse issuer URL %v", err)
|
||||
}
|
||||
u.Path = path.Join(u.Path, tc.expectedRedirectPath)
|
||||
|
||||
location := rr.Header().Get("Location")
|
||||
if rr.Code == http.StatusFound && !strings.HasPrefix(location, u.Path) {
|
||||
t.Errorf("Invalid Redirect. Expected %v got %v", u.Path, location)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -153,6 +153,8 @@ type discovery struct {
|
|||
Token string `json:"token_endpoint"`
|
||||
Keys string `json:"jwks_uri"`
|
||||
UserInfo string `json:"userinfo_endpoint"`
|
||||
DeviceEndpoint string `json:"device_authorization_endpoint"`
|
||||
GrantTypes []string `json:"grant_types_supported"`
|
||||
ResponseTypes []string `json:"response_types_supported"`
|
||||
Subjects []string `json:"subject_types_supported"`
|
||||
IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"`
|
||||
|
@ -168,7 +170,9 @@ func (s *Server) discoveryHandler() (http.HandlerFunc, error) {
|
|||
Token: s.absURL("/token"),
|
||||
Keys: s.absURL("/keys"),
|
||||
UserInfo: s.absURL("/userinfo"),
|
||||
DeviceEndpoint: s.absURL("/device/code"),
|
||||
Subjects: []string{"public"},
|
||||
GrantTypes: []string{grantTypeAuthorizationCode, grantTypeRefreshToken, grantTypeDeviceCode},
|
||||
IDTokenAlgs: []string{string(jose.RS256)},
|
||||
Scopes: []string{"openid", "email", "groups", "profile", "offline_access"},
|
||||
AuthMethods: []string{"client_secret_basic"},
|
||||
|
@ -784,24 +788,33 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
|
|||
return
|
||||
}
|
||||
|
||||
tokenResponse, err := s.exchangeAuthCode(w, authCode, client)
|
||||
if err != nil {
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
s.writeAccessToken(w, tokenResponse)
|
||||
}
|
||||
|
||||
func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCode, client storage.Client) (*accessTokenReponse, error) {
|
||||
accessToken, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID)
|
||||
if err != nil {
|
||||
s.logger.Errorf("failed to create new access token: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
idToken, expiry, err := s.newIDToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, accessToken, authCode.ConnectorID)
|
||||
if err != nil {
|
||||
s.logger.Errorf("failed to create ID token: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.storage.DeleteAuthCode(code); err != nil {
|
||||
if err := s.storage.DeleteAuthCode(authCode.ID); err != nil {
|
||||
s.logger.Errorf("failed to delete auth code: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reqRefresh := func() bool {
|
||||
|
@ -848,13 +861,13 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
|
|||
if refreshToken, err = internal.Marshal(token); err != nil {
|
||||
s.logger.Errorf("failed to marshal refresh token: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.storage.CreateRefresh(refresh); err != nil {
|
||||
s.logger.Errorf("failed to create refresh token: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// deleteToken determines if we need to delete the newly created refresh token
|
||||
|
@ -885,7 +898,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
|
|||
s.logger.Errorf("failed to get offline session: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
deleteToken = true
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
offlineSessions := storage.OfflineSessions{
|
||||
UserID: refresh.Claims.UserID,
|
||||
|
@ -900,7 +913,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
|
|||
s.logger.Errorf("failed to create offline session: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
deleteToken = true
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok {
|
||||
|
@ -909,7 +922,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
|
|||
s.logger.Errorf("failed to delete refresh token: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
deleteToken = true
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -921,11 +934,11 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
|
|||
s.logger.Errorf("failed to update offline session: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
deleteToken = true
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
s.writeAccessToken(w, idToken, accessToken, refreshToken, expiry)
|
||||
return s.toAccessTokenResponse(idToken, accessToken, refreshToken, expiry), nil
|
||||
}
|
||||
|
||||
// handle a refresh token request https://tools.ietf.org/html/rfc6749#section-6
|
||||
|
@ -1121,7 +1134,8 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
|
|||
return
|
||||
}
|
||||
|
||||
s.writeAccessToken(w, idToken, accessToken, rawNewToken, expiry)
|
||||
resp := s.toAccessTokenResponse(idToken, accessToken, rawNewToken, expiry)
|
||||
s.writeAccessToken(w, resp)
|
||||
}
|
||||
|
||||
func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -1368,23 +1382,29 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
|
|||
}
|
||||
}
|
||||
|
||||
s.writeAccessToken(w, idToken, accessToken, refreshToken, expiry)
|
||||
resp := s.toAccessTokenResponse(idToken, accessToken, refreshToken, expiry)
|
||||
s.writeAccessToken(w, resp)
|
||||
}
|
||||
|
||||
func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, accessToken, refreshToken string, expiry time.Time) {
|
||||
resp := struct {
|
||||
type accessTokenReponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
IDToken string `json:"id_token"`
|
||||
}{
|
||||
}
|
||||
|
||||
func (s *Server) toAccessTokenResponse(idToken, accessToken, refreshToken string, expiry time.Time) *accessTokenReponse {
|
||||
return &accessTokenReponse{
|
||||
accessToken,
|
||||
"bearer",
|
||||
int(expiry.Sub(s.now()).Seconds()),
|
||||
refreshToken,
|
||||
idToken,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) writeAccessToken(w http.ResponseWriter, resp *accessTokenReponse) {
|
||||
data, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
s.logger.Errorf("failed to marshal access token response: %v", err)
|
||||
|
|
|
@ -114,6 +114,10 @@ const (
|
|||
scopeCrossClientPrefix = "audience:server:client_id:"
|
||||
)
|
||||
|
||||
const (
|
||||
deviceCallbackURI = "/device/callback"
|
||||
)
|
||||
|
||||
const (
|
||||
redirectURIOOB = "urn:ietf:wg:oauth:2.0:oob"
|
||||
)
|
||||
|
@ -122,6 +126,7 @@ const (
|
|||
grantTypeAuthorizationCode = "authorization_code"
|
||||
grantTypeRefreshToken = "refresh_token"
|
||||
grantTypePassword = "password"
|
||||
grantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -130,6 +135,13 @@ const (
|
|||
responseTypeIDToken = "id_token" // ID Token in url fragment
|
||||
)
|
||||
|
||||
const (
|
||||
deviceTokenPending = "authorization_pending"
|
||||
deviceTokenComplete = "complete"
|
||||
deviceTokenSlowDown = "slow_down"
|
||||
deviceTokenExpired = "expired_token"
|
||||
)
|
||||
|
||||
func parseScopes(scopes []string) connector.Scopes {
|
||||
var s connector.Scopes
|
||||
for _, scope := range scopes {
|
||||
|
@ -425,6 +437,9 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques
|
|||
description := fmt.Sprintf("Unregistered redirect_uri (%q).", redirectURI)
|
||||
return nil, &authErr{"", "", errInvalidRequest, description}
|
||||
}
|
||||
if redirectURI == deviceCallbackURI && client.Public {
|
||||
redirectURI = s.issuerURL.Path + deviceCallbackURI
|
||||
}
|
||||
|
||||
// From here on out, we want to redirect back to the client with an error.
|
||||
newErr := func(typ, format string, a ...interface{}) *authErr {
|
||||
|
@ -566,7 +581,7 @@ func validateRedirectURI(client storage.Client, redirectURI string) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
if redirectURI == redirectURIOOB {
|
||||
if redirectURI == redirectURIOOB || redirectURI == deviceCallbackURI {
|
||||
return true
|
||||
}
|
||||
|
||||
|
|
|
@ -78,6 +78,7 @@ type Config struct {
|
|||
RotateKeysAfter time.Duration // Defaults to 6 hours.
|
||||
IDTokensValidFor time.Duration // Defaults to 24 hours
|
||||
AuthRequestsValidFor time.Duration // Defaults to 24 hours
|
||||
DeviceRequestsValidFor time.Duration // Defaults to 5 minutes
|
||||
// If set, the server will use this connector to handle password grants
|
||||
PasswordConnector string
|
||||
|
||||
|
@ -158,6 +159,7 @@ type Server struct {
|
|||
|
||||
idTokensValidFor time.Duration
|
||||
authRequestsValidFor time.Duration
|
||||
deviceRequestsValidFor time.Duration
|
||||
|
||||
logger log.Logger
|
||||
}
|
||||
|
@ -219,6 +221,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
|
|||
supportedResponseTypes: supported,
|
||||
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
|
||||
authRequestsValidFor: value(c.AuthRequestsValidFor, 24*time.Hour),
|
||||
deviceRequestsValidFor: value(c.DeviceRequestsValidFor, 5*time.Minute),
|
||||
skipApproval: c.SkipApprovalScreen,
|
||||
alwaysShowLogin: c.AlwaysShowLoginScreen,
|
||||
now: now,
|
||||
|
@ -302,6 +305,11 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
|
|||
handleWithCORS("/userinfo", s.handleUserInfo)
|
||||
handleFunc("/auth", s.handleAuthorization)
|
||||
handleFunc("/auth/{connector}", s.handleConnectorLogin)
|
||||
handleFunc("/device", s.handleDeviceExchange)
|
||||
handleFunc("/device/auth/verify_code", s.verifyUserCode)
|
||||
handleFunc("/device/code", s.handleDeviceCode)
|
||||
handleFunc("/device/token", s.handleDeviceToken)
|
||||
handleFunc(deviceCallbackURI, s.handleDeviceCallback)
|
||||
r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) {
|
||||
// Strip the X-Remote-* headers to prevent security issues on
|
||||
// misconfigured authproxy connector setups.
|
||||
|
@ -450,7 +458,8 @@ func (s *Server) startGarbageCollection(ctx context.Context, frequency time.Dura
|
|||
if r, err := s.storage.GarbageCollect(now()); err != nil {
|
||||
s.logger.Errorf("garbage collection failed: %v", err)
|
||||
} else if r.AuthRequests > 0 || r.AuthCodes > 0 {
|
||||
s.logger.Infof("garbage collection run, delete auth requests=%d, auth codes=%d", r.AuthRequests, r.AuthCodes)
|
||||
s.logger.Infof("garbage collection run, delete auth requests=%d, auth codes=%d, device requests =%d, device tokens=%d",
|
||||
r.AuthRequests, r.AuthCodes, r.DeviceRequests, r.DeviceTokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,11 +8,13 @@ import (
|
|||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
|
@ -203,41 +205,36 @@ func TestDiscovery(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// TestOAuth2CodeFlow runs integration tests against a test server. The tests stand up a server
|
||||
// which requires no interaction to login, logs in through a test client, then passes the client
|
||||
// and returned token to the test.
|
||||
func TestOAuth2CodeFlow(t *testing.T) {
|
||||
clientID := "testclient"
|
||||
clientSecret := "testclientsecret"
|
||||
type oauth2Tests struct {
|
||||
clientID string
|
||||
tests []test
|
||||
}
|
||||
|
||||
type test struct {
|
||||
name string
|
||||
// If specified these set of scopes will be used during the test case.
|
||||
scopes []string
|
||||
// handleToken provides the OAuth2 token response for the integration test.
|
||||
handleToken func(context.Context, *oidc.Provider, *oauth2.Config, *oauth2.Token, *mock.Callback) error
|
||||
}
|
||||
|
||||
func makeOAuth2Tests(clientID string, clientSecret string, now func() time.Time) oauth2Tests {
|
||||
requestedScopes := []string{oidc.ScopeOpenID, "email", "profile", "groups", "offline_access"}
|
||||
|
||||
t0 := time.Now()
|
||||
|
||||
// Always have the time function used by the server return the same time so
|
||||
// we can predict expected values of "expires_in" fields exactly.
|
||||
now := func() time.Time { return t0 }
|
||||
|
||||
// Used later when configuring test servers to set how long id_tokens will be valid for.
|
||||
//
|
||||
// The actual value of 30s is completely arbitrary. We just need to set a value
|
||||
// so tests can compute the expected "expires_in" field.
|
||||
idTokensValidFor := time.Second * 30
|
||||
|
||||
// Connector used by the tests.
|
||||
var conn *mock.Callback
|
||||
|
||||
oidcConfig := &oidc.Config{SkipClientIDCheck: true}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
// If specified these set of scopes will be used during the test case.
|
||||
scopes []string
|
||||
// handleToken provides the OAuth2 token response for the integration test.
|
||||
handleToken func(context.Context, *oidc.Provider, *oauth2.Config, *oauth2.Token) error
|
||||
}{
|
||||
return oauth2Tests{
|
||||
clientID: clientID,
|
||||
tests: []test{
|
||||
{
|
||||
name: "verify ID Token",
|
||||
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
|
||||
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error {
|
||||
idToken, ok := token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("no id token found")
|
||||
|
@ -250,7 +247,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "fetch userinfo",
|
||||
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
|
||||
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error {
|
||||
ui, err := p.UserInfo(ctx, config.TokenSource(ctx, token))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch userinfo: %v", err)
|
||||
|
@ -263,7 +260,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "verify id token and oauth2 token expiry",
|
||||
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
|
||||
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error {
|
||||
expectedExpiry := now().Add(idTokensValidFor)
|
||||
|
||||
timeEq := func(t1, t2 time.Time, within time.Duration) bool {
|
||||
|
@ -290,7 +287,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "verify at_hash",
|
||||
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
|
||||
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error {
|
||||
rawIDToken, ok := token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("no id token found")
|
||||
|
@ -322,7 +319,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "refresh token",
|
||||
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
|
||||
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error {
|
||||
// have to use time.Now because the OAuth2 package uses it.
|
||||
token.Expiry = time.Now().Add(time.Second * -10)
|
||||
if token.Valid() {
|
||||
|
@ -345,7 +342,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "refresh with explicit scopes",
|
||||
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
|
||||
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error {
|
||||
v := url.Values{}
|
||||
v.Add("client_id", clientID)
|
||||
v.Add("client_secret", clientSecret)
|
||||
|
@ -369,7 +366,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "refresh with extra spaces",
|
||||
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
|
||||
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error {
|
||||
v := url.Values{}
|
||||
v.Add("client_id", clientID)
|
||||
v.Add("client_secret", clientSecret)
|
||||
|
@ -398,7 +395,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
|
|||
{
|
||||
name: "refresh with unauthorized scopes",
|
||||
scopes: []string{"openid", "email"},
|
||||
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
|
||||
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error {
|
||||
v := url.Values{}
|
||||
v.Add("client_id", clientID)
|
||||
v.Add("client_secret", clientSecret)
|
||||
|
@ -425,7 +422,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
|
|||
// This test ensures that the connector.RefreshConnector interface is being
|
||||
// used when clients request a refresh token.
|
||||
name: "refresh with identity changes",
|
||||
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
|
||||
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error {
|
||||
// have to use time.Now because the OAuth2 package uses it.
|
||||
token.Expiry = time.Now().Add(time.Second * -10)
|
||||
if token.Valid() {
|
||||
|
@ -472,9 +469,35 @@ func TestOAuth2CodeFlow(t *testing.T) {
|
|||
return nil
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
// TestOAuth2CodeFlow runs integration tests against a test server. The tests stand up a server
|
||||
// which requires no interaction to login, logs in through a test client, then passes the client
|
||||
// and returned token to the test.
|
||||
func TestOAuth2CodeFlow(t *testing.T) {
|
||||
clientID := "testclient"
|
||||
clientSecret := "testclientsecret"
|
||||
requestedScopes := []string{oidc.ScopeOpenID, "email", "profile", "groups", "offline_access"}
|
||||
|
||||
t0 := time.Now()
|
||||
|
||||
// Always have the time function used by the server return the same time so
|
||||
// we can predict expected values of "expires_in" fields exactly.
|
||||
now := func() time.Time { return t0 }
|
||||
|
||||
// Used later when configuring test servers to set how long id_tokens will be valid for.
|
||||
//
|
||||
// The actual value of 30s is completely arbitrary. We just need to set a value
|
||||
// so tests can compute the expected "expires_in" field.
|
||||
idTokensValidFor := time.Second * 30
|
||||
|
||||
// Connector used by the tests.
|
||||
var conn *mock.Callback
|
||||
|
||||
tests := makeOAuth2Tests(clientID, clientSecret, now)
|
||||
for _, tc := range tests.tests {
|
||||
func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
@ -540,7 +563,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
|
|||
t.Errorf("failed to exchange code for token: %v", err)
|
||||
return
|
||||
}
|
||||
err = tc.handleToken(ctx, p, oauth2Config, token)
|
||||
err = tc.handleToken(ctx, p, oauth2Config, token, conn)
|
||||
if err != nil {
|
||||
t.Errorf("%s: %v", tc.name, err)
|
||||
}
|
||||
|
@ -1253,3 +1276,157 @@ func TestRefreshTokenFlow(t *testing.T) {
|
|||
t.Errorf("Token refreshed with invalid refresh token, error expected.")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuth2DeviceFlow runs device flow integration tests against a test server
|
||||
func TestOAuth2DeviceFlow(t *testing.T) {
|
||||
clientID := "testclient"
|
||||
clientSecret := ""
|
||||
requestedScopes := []string{oidc.ScopeOpenID, "email", "profile", "groups", "offline_access"}
|
||||
|
||||
t0 := time.Now()
|
||||
|
||||
// Always have the time function used by the server return the same time so
|
||||
// we can predict expected values of "expires_in" fields exactly.
|
||||
now := func() time.Time { return t0 }
|
||||
|
||||
// Connector used by the tests.
|
||||
var conn *mock.Callback
|
||||
idTokensValidFor := time.Second * 30
|
||||
|
||||
for _, tc := range makeOAuth2Tests(clientID, clientSecret, now).tests {
|
||||
func() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Setup a dex server.
|
||||
httpServer, s := newTestServer(ctx, t, func(c *Config) {
|
||||
c.Issuer = c.Issuer + "/non-root-path"
|
||||
c.Now = now
|
||||
c.IDTokensValidFor = idTokensValidFor
|
||||
})
|
||||
defer httpServer.Close()
|
||||
|
||||
mockConn := s.connectors["mock"]
|
||||
conn = mockConn.Connector.(*mock.Callback)
|
||||
|
||||
p, err := oidc.NewProvider(ctx, httpServer.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get provider: %v", err)
|
||||
}
|
||||
|
||||
//Add the Clients to the test server
|
||||
client := storage.Client{
|
||||
ID: clientID,
|
||||
RedirectURIs: []string{deviceCallbackURI},
|
||||
Public: true,
|
||||
}
|
||||
if err := s.storage.CreateClient(client); err != nil {
|
||||
t.Fatalf("failed to create client: %v", err)
|
||||
}
|
||||
|
||||
//Grab the issuer that we'll reuse for the different endpoints to hit
|
||||
issuer, err := url.Parse(s.issuerURL.String())
|
||||
if err != nil {
|
||||
t.Errorf("Could not parse issuer URL %v", err)
|
||||
}
|
||||
|
||||
//Send a new Device Request
|
||||
codeURL, _ := url.Parse(issuer.String())
|
||||
codeURL.Path = path.Join(codeURL.Path, "device/code")
|
||||
|
||||
data := url.Values{}
|
||||
data.Set("client_id", clientID)
|
||||
data.Add("scope", strings.Join(requestedScopes, " "))
|
||||
resp, err := http.PostForm(codeURL.String(), data)
|
||||
if err != nil {
|
||||
t.Errorf("Could not request device code: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
responseBody, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Errorf("Could read device code response %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("%v - Unexpected Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody))
|
||||
}
|
||||
|
||||
//Parse the code response
|
||||
var deviceCode deviceCodeResponse
|
||||
if err := json.Unmarshal(responseBody, &deviceCode); err != nil {
|
||||
t.Errorf("Unexpected Device Code Response Format %v", string(responseBody))
|
||||
}
|
||||
|
||||
//Mock the user hitting the verification URI and posting the form
|
||||
verifyURL, _ := url.Parse(issuer.String())
|
||||
verifyURL.Path = path.Join(verifyURL.Path, "/device/auth/verify_code")
|
||||
urlData := url.Values{}
|
||||
urlData.Set("user_code", deviceCode.UserCode)
|
||||
resp, err = http.PostForm(verifyURL.String(), urlData)
|
||||
if err != nil {
|
||||
t.Errorf("Error Posting Form: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
responseBody, err = ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Errorf("Could read verification response %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("%v - Unexpected Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody))
|
||||
}
|
||||
|
||||
//Hit the Token Endpoint, and try and get an access token
|
||||
tokenURL, _ := url.Parse(issuer.String())
|
||||
tokenURL.Path = path.Join(tokenURL.Path, "/device/token")
|
||||
v := url.Values{}
|
||||
v.Add("grant_type", grantTypeDeviceCode)
|
||||
v.Add("device_code", deviceCode.DeviceCode)
|
||||
resp, err = http.PostForm(tokenURL.String(), v)
|
||||
if err != nil {
|
||||
t.Errorf("Could not request device token: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
responseBody, err = ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Errorf("Could read device token response %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("%v - Unexpected Token Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody))
|
||||
}
|
||||
|
||||
//Parse the response
|
||||
var tokenRes accessTokenReponse
|
||||
if err := json.Unmarshal(responseBody, &tokenRes); err != nil {
|
||||
t.Errorf("Unexpected Device Access Token Response Format %v", string(responseBody))
|
||||
}
|
||||
|
||||
token := &oauth2.Token{
|
||||
AccessToken: tokenRes.AccessToken,
|
||||
TokenType: tokenRes.TokenType,
|
||||
RefreshToken: tokenRes.RefreshToken,
|
||||
}
|
||||
raw := make(map[string]interface{})
|
||||
json.Unmarshal(responseBody, &raw) // no error checks for optional fields
|
||||
token = token.WithExtra(raw)
|
||||
if secs := tokenRes.ExpiresIn; secs > 0 {
|
||||
token.Expiry = time.Now().Add(time.Duration(secs) * time.Second)
|
||||
}
|
||||
|
||||
//Run token tests to validate info is correct
|
||||
// Create the OAuth2 config.
|
||||
oauth2Config := &oauth2.Config{
|
||||
ClientID: client.ID,
|
||||
ClientSecret: client.Secret,
|
||||
Endpoint: p.Endpoint(),
|
||||
Scopes: requestedScopes,
|
||||
RedirectURL: deviceCallbackURI,
|
||||
}
|
||||
if len(tc.scopes) != 0 {
|
||||
oauth2Config.Scopes = tc.scopes
|
||||
}
|
||||
err = tc.handleToken(ctx, p, oauth2Config, token, conn)
|
||||
if err != nil {
|
||||
t.Errorf("%s: %v", tc.name, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,6 +20,8 @@ const (
|
|||
tmplPassword = "password.html"
|
||||
tmplOOB = "oob.html"
|
||||
tmplError = "error.html"
|
||||
tmplDevice = "device.html"
|
||||
tmplDeviceSuccess = "device_success.html"
|
||||
)
|
||||
|
||||
var requiredTmpls = []string{
|
||||
|
@ -28,6 +30,8 @@ var requiredTmpls = []string{
|
|||
tmplPassword,
|
||||
tmplOOB,
|
||||
tmplError,
|
||||
tmplDevice,
|
||||
tmplDeviceSuccess,
|
||||
}
|
||||
|
||||
type templates struct {
|
||||
|
@ -36,6 +40,8 @@ type templates struct {
|
|||
passwordTmpl *template.Template
|
||||
oobTmpl *template.Template
|
||||
errorTmpl *template.Template
|
||||
deviceTmpl *template.Template
|
||||
deviceSuccessTmpl *template.Template
|
||||
}
|
||||
|
||||
type webConfig struct {
|
||||
|
@ -157,6 +163,8 @@ func loadTemplates(c webConfig, templatesDir string) (*templates, error) {
|
|||
passwordTmpl: tmpls.Lookup(tmplPassword),
|
||||
oobTmpl: tmpls.Lookup(tmplOOB),
|
||||
errorTmpl: tmpls.Lookup(tmplError),
|
||||
deviceTmpl: tmpls.Lookup(tmplDevice),
|
||||
deviceSuccessTmpl: tmpls.Lookup(tmplDeviceSuccess),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -242,6 +250,27 @@ func (n byName) Len() int { return len(n) }
|
|||
func (n byName) Less(i, j int) bool { return n[i].Name < n[j].Name }
|
||||
func (n byName) Swap(i, j int) { n[i], n[j] = n[j], n[i] }
|
||||
|
||||
func (t *templates) device(r *http.Request, w http.ResponseWriter, postURL string, userCode string, lastWasInvalid bool) error {
|
||||
if lastWasInvalid {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}
|
||||
data := struct {
|
||||
PostURL string
|
||||
UserCode string
|
||||
Invalid bool
|
||||
ReqPath string
|
||||
}{postURL, userCode, lastWasInvalid, r.URL.Path}
|
||||
return renderTemplate(w, t.deviceTmpl, data)
|
||||
}
|
||||
|
||||
func (t *templates) deviceSuccess(r *http.Request, w http.ResponseWriter, clientName string) error {
|
||||
data := struct {
|
||||
ClientName string
|
||||
ReqPath string
|
||||
}{clientName, r.URL.Path}
|
||||
return renderTemplate(w, t.deviceSuccessTmpl, data)
|
||||
}
|
||||
|
||||
func (t *templates) login(r *http.Request, w http.ResponseWriter, connectors []connectorInfo, reqPath string) error {
|
||||
sort.Sort(byName(connectors))
|
||||
data := struct {
|
||||
|
|
|
@ -49,6 +49,8 @@ func RunTests(t *testing.T, newStorage func() storage.Storage) {
|
|||
{"ConnectorCRUD", testConnectorCRUD},
|
||||
{"GarbageCollection", testGC},
|
||||
{"TimezoneSupport", testTimezones},
|
||||
{"DeviceRequestCRUD", testDeviceRequestCRUD},
|
||||
{"DeviceTokenCRUD", testDeviceTokenCRUD},
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -834,6 +836,87 @@ func testGC(t *testing.T, s storage.Storage) {
|
|||
} else if err != storage.ErrNotFound {
|
||||
t.Errorf("expected storage.ErrNotFound, got %v", err)
|
||||
}
|
||||
|
||||
userCode, err := storage.NewUserCode()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected Error: %v", err)
|
||||
}
|
||||
|
||||
d := storage.DeviceRequest{
|
||||
UserCode: userCode,
|
||||
DeviceCode: storage.NewID(),
|
||||
ClientID: "client1",
|
||||
ClientSecret: "secret1",
|
||||
Scopes: []string{"openid", "email"},
|
||||
Expiry: expiry,
|
||||
}
|
||||
|
||||
if err := s.CreateDeviceRequest(d); err != nil {
|
||||
t.Fatalf("failed creating device request: %v", err)
|
||||
}
|
||||
|
||||
for _, tz := range []*time.Location{time.UTC, est, pst} {
|
||||
result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz))
|
||||
if err != nil {
|
||||
t.Errorf("garbage collection failed: %v", err)
|
||||
} else {
|
||||
if result.DeviceRequests != 0 {
|
||||
t.Errorf("expected no device garbage collection results, got %#v", result)
|
||||
}
|
||||
}
|
||||
if _, err := s.GetDeviceRequest(d.UserCode); err != nil {
|
||||
t.Errorf("expected to be able to get auth request after GC: %v", err)
|
||||
}
|
||||
}
|
||||
if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
|
||||
t.Errorf("garbage collection failed: %v", err)
|
||||
} else if r.DeviceRequests != 1 {
|
||||
t.Errorf("expected to garbage collect 1 device request, got %d", r.DeviceRequests)
|
||||
}
|
||||
|
||||
if _, err := s.GetDeviceRequest(d.UserCode); err == nil {
|
||||
t.Errorf("expected device request to be GC'd")
|
||||
} else if err != storage.ErrNotFound {
|
||||
t.Errorf("expected storage.ErrNotFound, got %v", err)
|
||||
}
|
||||
|
||||
dt := storage.DeviceToken{
|
||||
DeviceCode: storage.NewID(),
|
||||
Status: "pending",
|
||||
Token: "foo",
|
||||
Expiry: expiry,
|
||||
LastRequestTime: time.Now(),
|
||||
PollIntervalSeconds: 0,
|
||||
}
|
||||
|
||||
if err := s.CreateDeviceToken(dt); err != nil {
|
||||
t.Fatalf("failed creating device token: %v", err)
|
||||
}
|
||||
|
||||
for _, tz := range []*time.Location{time.UTC, est, pst} {
|
||||
result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz))
|
||||
if err != nil {
|
||||
t.Errorf("garbage collection failed: %v", err)
|
||||
} else {
|
||||
if result.DeviceTokens != 0 {
|
||||
t.Errorf("expected no device token garbage collection results, got %#v", result)
|
||||
}
|
||||
}
|
||||
if _, err := s.GetDeviceToken(dt.DeviceCode); err != nil {
|
||||
t.Errorf("expected to be able to get device token after GC: %v", err)
|
||||
}
|
||||
}
|
||||
if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
|
||||
t.Errorf("garbage collection failed: %v", err)
|
||||
} else if r.DeviceTokens != 1 {
|
||||
t.Errorf("expected to garbage collect 1 device token, got %d", r.DeviceTokens)
|
||||
}
|
||||
|
||||
if _, err := s.GetDeviceToken(dt.DeviceCode); err == nil {
|
||||
t.Errorf("expected device token to be GC'd")
|
||||
} else if err != storage.ErrNotFound {
|
||||
t.Errorf("expected storage.ErrNotFound, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// testTimezones tests that backends either fully support timezones or
|
||||
|
@ -881,3 +964,72 @@ func testTimezones(t *testing.T, s storage.Storage) {
|
|||
t.Fatalf("expected expiry %v got %v", wantTime, gotTime)
|
||||
}
|
||||
}
|
||||
|
||||
func testDeviceRequestCRUD(t *testing.T, s storage.Storage) {
|
||||
userCode, err := storage.NewUserCode()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
d1 := storage.DeviceRequest{
|
||||
UserCode: userCode,
|
||||
DeviceCode: storage.NewID(),
|
||||
ClientID: "client1",
|
||||
ClientSecret: "secret1",
|
||||
Scopes: []string{"openid", "email"},
|
||||
Expiry: neverExpire,
|
||||
}
|
||||
|
||||
if err := s.CreateDeviceRequest(d1); err != nil {
|
||||
t.Fatalf("failed creating device request: %v", err)
|
||||
}
|
||||
|
||||
// Attempt to create same DeviceRequest twice.
|
||||
err = s.CreateDeviceRequest(d1)
|
||||
mustBeErrAlreadyExists(t, "device request", err)
|
||||
|
||||
//No manual deletes for device requests, will be handled by garbage collection routines
|
||||
//see testGC
|
||||
}
|
||||
|
||||
func testDeviceTokenCRUD(t *testing.T, s storage.Storage) {
|
||||
//Create a Token
|
||||
d1 := storage.DeviceToken{
|
||||
DeviceCode: storage.NewID(),
|
||||
Status: "pending",
|
||||
Token: storage.NewID(),
|
||||
Expiry: neverExpire,
|
||||
LastRequestTime: time.Now(),
|
||||
PollIntervalSeconds: 0,
|
||||
}
|
||||
|
||||
if err := s.CreateDeviceToken(d1); err != nil {
|
||||
t.Fatalf("failed creating device token: %v", err)
|
||||
}
|
||||
|
||||
// Attempt to create same Device Token twice.
|
||||
err := s.CreateDeviceToken(d1)
|
||||
mustBeErrAlreadyExists(t, "device token", err)
|
||||
|
||||
//Update the device token, simulate a redemption
|
||||
if err := s.UpdateDeviceToken(d1.DeviceCode, func(old storage.DeviceToken) (storage.DeviceToken, error) {
|
||||
old.Token = "token data"
|
||||
old.Status = "complete"
|
||||
return old, nil
|
||||
}); err != nil {
|
||||
t.Fatalf("failed to update device token: %v", err)
|
||||
}
|
||||
|
||||
//Retrieve the device token
|
||||
got, err := s.GetDeviceToken(d1.DeviceCode)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get device token: %v", err)
|
||||
}
|
||||
|
||||
//Validate expected result set
|
||||
if got.Status != "complete" {
|
||||
t.Fatalf("update failed, wanted token status=%v got %v", "complete", got.Status)
|
||||
}
|
||||
if got.Token != "token data" {
|
||||
t.Fatalf("update failed, wanted token %v got %v", "token data", got.Token)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,6 +22,8 @@ const (
|
|||
offlineSessionPrefix = "offline_session/"
|
||||
connectorPrefix = "connector/"
|
||||
keysName = "openid-connect-keys"
|
||||
deviceRequestPrefix = "device_req/"
|
||||
deviceTokenPrefix = "device_token/"
|
||||
|
||||
// defaultStorageTimeout will be applied to all storage's operations.
|
||||
defaultStorageTimeout = 5 * time.Second
|
||||
|
@ -72,6 +74,36 @@ func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error
|
|||
result.AuthCodes++
|
||||
}
|
||||
}
|
||||
|
||||
deviceRequests, err := c.listDeviceRequests(ctx)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
for _, deviceRequest := range deviceRequests {
|
||||
if now.After(deviceRequest.Expiry) {
|
||||
if err := c.deleteKey(ctx, keyID(deviceRequestPrefix, deviceRequest.UserCode)); err != nil {
|
||||
c.logger.Errorf("failed to delete device request %v", err)
|
||||
delErr = fmt.Errorf("failed to delete device request: %v", err)
|
||||
}
|
||||
result.DeviceRequests++
|
||||
}
|
||||
}
|
||||
|
||||
deviceTokens, err := c.listDeviceTokens(ctx)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
for _, deviceToken := range deviceTokens {
|
||||
if now.After(deviceToken.Expiry) {
|
||||
if err := c.deleteKey(ctx, keyID(deviceTokenPrefix, deviceToken.DeviceCode)); err != nil {
|
||||
c.logger.Errorf("failed to delete device token %v", err)
|
||||
delErr = fmt.Errorf("failed to delete device token: %v", err)
|
||||
}
|
||||
result.DeviceTokens++
|
||||
}
|
||||
}
|
||||
return result, delErr
|
||||
}
|
||||
|
||||
|
@ -531,3 +563,77 @@ func keyEmail(prefix, email string) string { return prefix + strings.ToLower(ema
|
|||
func keySession(prefix, userID, connID string) string {
|
||||
return prefix + strings.ToLower(userID+"|"+connID)
|
||||
}
|
||||
|
||||
func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
return c.txnCreate(ctx, keyID(deviceRequestPrefix, d.UserCode), fromStorageDeviceRequest(d))
|
||||
}
|
||||
|
||||
func (c *conn) GetDeviceRequest(userCode string) (r storage.DeviceRequest, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
err = c.getKey(ctx, keyID(deviceRequestPrefix, userCode), &r)
|
||||
return r, err
|
||||
}
|
||||
|
||||
func (c *conn) listDeviceRequests(ctx context.Context) (requests []DeviceRequest, err error) {
|
||||
res, err := c.db.Get(ctx, deviceRequestPrefix, clientv3.WithPrefix())
|
||||
if err != nil {
|
||||
return requests, err
|
||||
}
|
||||
for _, v := range res.Kvs {
|
||||
var r DeviceRequest
|
||||
if err = json.Unmarshal(v.Value, &r); err != nil {
|
||||
return requests, err
|
||||
}
|
||||
requests = append(requests, r)
|
||||
}
|
||||
return requests, nil
|
||||
}
|
||||
|
||||
func (c *conn) CreateDeviceToken(t storage.DeviceToken) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
return c.txnCreate(ctx, keyID(deviceTokenPrefix, t.DeviceCode), fromStorageDeviceToken(t))
|
||||
}
|
||||
|
||||
func (c *conn) GetDeviceToken(deviceCode string) (t storage.DeviceToken, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
err = c.getKey(ctx, keyID(deviceTokenPrefix, deviceCode), &t)
|
||||
return t, err
|
||||
}
|
||||
|
||||
func (c *conn) listDeviceTokens(ctx context.Context) (deviceTokens []DeviceToken, err error) {
|
||||
res, err := c.db.Get(ctx, deviceTokenPrefix, clientv3.WithPrefix())
|
||||
if err != nil {
|
||||
return deviceTokens, err
|
||||
}
|
||||
for _, v := range res.Kvs {
|
||||
var dt DeviceToken
|
||||
if err = json.Unmarshal(v.Value, &dt); err != nil {
|
||||
return deviceTokens, err
|
||||
}
|
||||
deviceTokens = append(deviceTokens, dt)
|
||||
}
|
||||
return deviceTokens, nil
|
||||
}
|
||||
|
||||
func (c *conn) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
return c.txnUpdate(ctx, keyID(deviceTokenPrefix, deviceCode), func(currentValue []byte) ([]byte, error) {
|
||||
var current DeviceToken
|
||||
if len(currentValue) > 0 {
|
||||
if err := json.Unmarshal(currentValue, ¤t); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
updated, err := updater(toStorageDeviceToken(current))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(fromStorageDeviceToken(updated))
|
||||
})
|
||||
}
|
||||
|
|
|
@ -44,6 +44,8 @@ func cleanDB(c *conn) error {
|
|||
passwordPrefix,
|
||||
offlineSessionPrefix,
|
||||
connectorPrefix,
|
||||
deviceRequestPrefix,
|
||||
deviceTokenPrefix,
|
||||
} {
|
||||
_, err := c.db.Delete(ctx, prefix, clientv3.WithPrefix())
|
||||
if err != nil {
|
||||
|
|
|
@ -216,3 +216,56 @@ func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions {
|
|||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// DeviceRequest is a mirrored struct from storage with JSON struct tags
|
||||
type DeviceRequest struct {
|
||||
UserCode string `json:"user_code"`
|
||||
DeviceCode string `json:"device_code"`
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
Scopes []string `json:"scopes"`
|
||||
Expiry time.Time `json:"expiry"`
|
||||
}
|
||||
|
||||
func fromStorageDeviceRequest(d storage.DeviceRequest) DeviceRequest {
|
||||
return DeviceRequest{
|
||||
UserCode: d.UserCode,
|
||||
DeviceCode: d.DeviceCode,
|
||||
ClientID: d.ClientID,
|
||||
ClientSecret: d.ClientSecret,
|
||||
Scopes: d.Scopes,
|
||||
Expiry: d.Expiry,
|
||||
}
|
||||
}
|
||||
|
||||
// DeviceToken is a mirrored struct from storage with JSON struct tags
|
||||
type DeviceToken struct {
|
||||
DeviceCode string `json:"device_code"`
|
||||
Status string `json:"status"`
|
||||
Token string `json:"token"`
|
||||
Expiry time.Time `json:"expiry"`
|
||||
LastRequestTime time.Time `json:"last_request"`
|
||||
PollIntervalSeconds int `json:"poll_interval"`
|
||||
}
|
||||
|
||||
func fromStorageDeviceToken(t storage.DeviceToken) DeviceToken {
|
||||
return DeviceToken{
|
||||
DeviceCode: t.DeviceCode,
|
||||
Status: t.Status,
|
||||
Token: t.Token,
|
||||
Expiry: t.Expiry,
|
||||
LastRequestTime: t.LastRequestTime,
|
||||
PollIntervalSeconds: t.PollIntervalSeconds,
|
||||
}
|
||||
}
|
||||
|
||||
func toStorageDeviceToken(t DeviceToken) storage.DeviceToken {
|
||||
return storage.DeviceToken{
|
||||
DeviceCode: t.DeviceCode,
|
||||
Status: t.Status,
|
||||
Token: t.Token,
|
||||
Expiry: t.Expiry,
|
||||
LastRequestTime: t.LastRequestTime,
|
||||
PollIntervalSeconds: t.PollIntervalSeconds,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,6 +21,8 @@ const (
|
|||
kindPassword = "Password"
|
||||
kindOfflineSessions = "OfflineSessions"
|
||||
kindConnector = "Connector"
|
||||
kindDeviceRequest = "DeviceRequest"
|
||||
kindDeviceToken = "DeviceToken"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -32,6 +34,8 @@ const (
|
|||
resourcePassword = "passwords"
|
||||
resourceOfflineSessions = "offlinesessionses" // Again attempts to pluralize.
|
||||
resourceConnector = "connectors"
|
||||
resourceDeviceRequest = "devicerequests"
|
||||
resourceDeviceToken = "devicetokens"
|
||||
)
|
||||
|
||||
// Config values for the Kubernetes storage type.
|
||||
|
@ -593,5 +597,84 @@ func (cli *client) GarbageCollect(now time.Time) (result storage.GCResult, err e
|
|||
result.AuthCodes++
|
||||
}
|
||||
}
|
||||
|
||||
var deviceRequests DeviceRequestList
|
||||
if err := cli.list(resourceDeviceRequest, &deviceRequests); err != nil {
|
||||
return result, fmt.Errorf("failed to list device requests: %v", err)
|
||||
}
|
||||
|
||||
for _, deviceRequest := range deviceRequests.DeviceRequests {
|
||||
if now.After(deviceRequest.Expiry) {
|
||||
if err := cli.delete(resourceDeviceRequest, deviceRequest.ObjectMeta.Name); err != nil {
|
||||
cli.logger.Errorf("failed to delete device request: %v", err)
|
||||
delErr = fmt.Errorf("failed to delete device request: %v", err)
|
||||
}
|
||||
result.DeviceRequests++
|
||||
}
|
||||
}
|
||||
|
||||
var deviceTokens DeviceTokenList
|
||||
if err := cli.list(resourceDeviceToken, &deviceTokens); err != nil {
|
||||
return result, fmt.Errorf("failed to list device tokens: %v", err)
|
||||
}
|
||||
|
||||
for _, deviceToken := range deviceTokens.DeviceTokens {
|
||||
if now.After(deviceToken.Expiry) {
|
||||
if err := cli.delete(resourceDeviceToken, deviceToken.ObjectMeta.Name); err != nil {
|
||||
cli.logger.Errorf("failed to delete device token: %v", err)
|
||||
delErr = fmt.Errorf("failed to delete device token: %v", err)
|
||||
}
|
||||
result.DeviceTokens++
|
||||
}
|
||||
}
|
||||
|
||||
if delErr != nil {
|
||||
return result, delErr
|
||||
}
|
||||
return result, delErr
|
||||
}
|
||||
|
||||
func (cli *client) CreateDeviceRequest(d storage.DeviceRequest) error {
|
||||
return cli.post(resourceDeviceRequest, cli.fromStorageDeviceRequest(d))
|
||||
}
|
||||
|
||||
func (cli *client) GetDeviceRequest(userCode string) (storage.DeviceRequest, error) {
|
||||
var req DeviceRequest
|
||||
if err := cli.get(resourceDeviceRequest, strings.ToLower(userCode), &req); err != nil {
|
||||
return storage.DeviceRequest{}, err
|
||||
}
|
||||
return toStorageDeviceRequest(req), nil
|
||||
}
|
||||
|
||||
func (cli *client) CreateDeviceToken(t storage.DeviceToken) error {
|
||||
return cli.post(resourceDeviceToken, cli.fromStorageDeviceToken(t))
|
||||
}
|
||||
|
||||
func (cli *client) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) {
|
||||
var token DeviceToken
|
||||
if err := cli.get(resourceDeviceToken, deviceCode, &token); err != nil {
|
||||
return storage.DeviceToken{}, err
|
||||
}
|
||||
return toStorageDeviceToken(token), nil
|
||||
}
|
||||
|
||||
func (cli *client) getDeviceToken(deviceCode string) (t DeviceToken, err error) {
|
||||
err = cli.get(resourceDeviceToken, deviceCode, &t)
|
||||
return
|
||||
}
|
||||
|
||||
func (cli *client) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
|
||||
r, err := cli.getDeviceToken(deviceCode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
updated, err := updater(toStorageDeviceToken(r))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
updated.DeviceCode = deviceCode
|
||||
|
||||
newToken := cli.fromStorageDeviceToken(updated)
|
||||
newToken.ObjectMeta = r.ObjectMeta
|
||||
return cli.put(resourceDeviceToken, r.ObjectMeta.Name, newToken)
|
||||
}
|
||||
|
|
|
@ -85,6 +85,8 @@ func (s *StorageTestSuite) TestStorage() {
|
|||
for _, resource := range []string{
|
||||
resourceAuthCode,
|
||||
resourceAuthRequest,
|
||||
resourceDeviceRequest,
|
||||
resourceDeviceToken,
|
||||
resourceClient,
|
||||
resourceRefreshToken,
|
||||
resourceKeys,
|
||||
|
|
|
@ -143,6 +143,36 @@ var customResourceDefinitions = []k8sapi.CustomResourceDefinition{
|
|||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ObjectMeta: k8sapi.ObjectMeta{
|
||||
Name: "devicerequests.dex.coreos.com",
|
||||
},
|
||||
TypeMeta: crdMeta,
|
||||
Spec: k8sapi.CustomResourceDefinitionSpec{
|
||||
Group: apiGroup,
|
||||
Version: "v1",
|
||||
Names: k8sapi.CustomResourceDefinitionNames{
|
||||
Plural: "devicerequests",
|
||||
Singular: "devicerequest",
|
||||
Kind: "DeviceRequest",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ObjectMeta: k8sapi.ObjectMeta{
|
||||
Name: "devicetokens.dex.coreos.com",
|
||||
},
|
||||
TypeMeta: crdMeta,
|
||||
Spec: k8sapi.CustomResourceDefinitionSpec{
|
||||
Group: apiGroup,
|
||||
Version: "v1",
|
||||
Names: k8sapi.CustomResourceDefinitionNames{
|
||||
Plural: "devicetokens",
|
||||
Singular: "devicetoken",
|
||||
Kind: "DeviceToken",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// There will only ever be a single keys resource. Maintain this by setting a
|
||||
|
@ -635,3 +665,103 @@ type ConnectorList struct {
|
|||
k8sapi.ListMeta `json:"metadata,omitempty"`
|
||||
Connectors []Connector `json:"items"`
|
||||
}
|
||||
|
||||
// DeviceRequest is a mirrored struct from storage with JSON struct tags and
|
||||
// Kubernetes type metadata.
|
||||
type DeviceRequest struct {
|
||||
k8sapi.TypeMeta `json:",inline"`
|
||||
k8sapi.ObjectMeta `json:"metadata,omitempty"`
|
||||
|
||||
DeviceCode string `json:"device_code,omitempty"`
|
||||
ClientID string `json:"client_id,omitempty"`
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
Scopes []string `json:"scopes,omitempty"`
|
||||
Expiry time.Time `json:"expiry"`
|
||||
}
|
||||
|
||||
// AuthRequestList is a list of AuthRequests.
|
||||
type DeviceRequestList struct {
|
||||
k8sapi.TypeMeta `json:",inline"`
|
||||
k8sapi.ListMeta `json:"metadata,omitempty"`
|
||||
DeviceRequests []DeviceRequest `json:"items"`
|
||||
}
|
||||
|
||||
func (cli *client) fromStorageDeviceRequest(a storage.DeviceRequest) DeviceRequest {
|
||||
req := DeviceRequest{
|
||||
TypeMeta: k8sapi.TypeMeta{
|
||||
Kind: kindDeviceRequest,
|
||||
APIVersion: cli.apiVersion,
|
||||
},
|
||||
ObjectMeta: k8sapi.ObjectMeta{
|
||||
Name: strings.ToLower(a.UserCode),
|
||||
Namespace: cli.namespace,
|
||||
},
|
||||
DeviceCode: a.DeviceCode,
|
||||
ClientID: a.ClientID,
|
||||
ClientSecret: a.ClientSecret,
|
||||
Scopes: a.Scopes,
|
||||
Expiry: a.Expiry,
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
func toStorageDeviceRequest(req DeviceRequest) storage.DeviceRequest {
|
||||
return storage.DeviceRequest{
|
||||
UserCode: strings.ToUpper(req.ObjectMeta.Name),
|
||||
DeviceCode: req.DeviceCode,
|
||||
ClientID: req.ClientID,
|
||||
ClientSecret: req.ClientSecret,
|
||||
Scopes: req.Scopes,
|
||||
Expiry: req.Expiry,
|
||||
}
|
||||
}
|
||||
|
||||
// DeviceToken is a mirrored struct from storage with JSON struct tags and
|
||||
// Kubernetes type metadata.
|
||||
type DeviceToken struct {
|
||||
k8sapi.TypeMeta `json:",inline"`
|
||||
k8sapi.ObjectMeta `json:"metadata,omitempty"`
|
||||
|
||||
Status string `json:"status,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
Expiry time.Time `json:"expiry"`
|
||||
LastRequestTime time.Time `json:"last_request"`
|
||||
PollIntervalSeconds int `json:"poll_interval"`
|
||||
}
|
||||
|
||||
// DeviceTokenList is a list of DeviceTokens.
|
||||
type DeviceTokenList struct {
|
||||
k8sapi.TypeMeta `json:",inline"`
|
||||
k8sapi.ListMeta `json:"metadata,omitempty"`
|
||||
DeviceTokens []DeviceToken `json:"items"`
|
||||
}
|
||||
|
||||
func (cli *client) fromStorageDeviceToken(t storage.DeviceToken) DeviceToken {
|
||||
req := DeviceToken{
|
||||
TypeMeta: k8sapi.TypeMeta{
|
||||
Kind: kindDeviceToken,
|
||||
APIVersion: cli.apiVersion,
|
||||
},
|
||||
ObjectMeta: k8sapi.ObjectMeta{
|
||||
Name: t.DeviceCode,
|
||||
Namespace: cli.namespace,
|
||||
},
|
||||
Status: t.Status,
|
||||
Token: t.Token,
|
||||
Expiry: t.Expiry,
|
||||
LastRequestTime: t.LastRequestTime,
|
||||
PollIntervalSeconds: t.PollIntervalSeconds,
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
func toStorageDeviceToken(t DeviceToken) storage.DeviceToken {
|
||||
return storage.DeviceToken{
|
||||
DeviceCode: t.ObjectMeta.Name,
|
||||
Status: t.Status,
|
||||
Token: t.Token,
|
||||
Expiry: t.Expiry,
|
||||
LastRequestTime: t.LastRequestTime,
|
||||
PollIntervalSeconds: t.PollIntervalSeconds,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,6 +20,8 @@ func New(logger log.Logger) storage.Storage {
|
|||
passwords: make(map[string]storage.Password),
|
||||
offlineSessions: make(map[offlineSessionID]storage.OfflineSessions),
|
||||
connectors: make(map[string]storage.Connector),
|
||||
deviceRequests: make(map[string]storage.DeviceRequest),
|
||||
deviceTokens: make(map[string]storage.DeviceToken),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
@ -46,6 +48,8 @@ type memStorage struct {
|
|||
passwords map[string]storage.Password
|
||||
offlineSessions map[offlineSessionID]storage.OfflineSessions
|
||||
connectors map[string]storage.Connector
|
||||
deviceRequests map[string]storage.DeviceRequest
|
||||
deviceTokens map[string]storage.DeviceToken
|
||||
|
||||
keys storage.Keys
|
||||
|
||||
|
@ -79,6 +83,18 @@ func (s *memStorage) GarbageCollect(now time.Time) (result storage.GCResult, err
|
|||
result.AuthRequests++
|
||||
}
|
||||
}
|
||||
for id, a := range s.deviceRequests {
|
||||
if now.After(a.Expiry) {
|
||||
delete(s.deviceRequests, id)
|
||||
result.DeviceRequests++
|
||||
}
|
||||
}
|
||||
for id, a := range s.deviceTokens {
|
||||
if now.After(a.Expiry) {
|
||||
delete(s.deviceTokens, id)
|
||||
result.DeviceTokens++
|
||||
}
|
||||
}
|
||||
})
|
||||
return result, nil
|
||||
}
|
||||
|
@ -465,3 +481,61 @@ func (s *memStorage) UpdateConnector(id string, updater func(c storage.Connector
|
|||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (s *memStorage) CreateDeviceRequest(d storage.DeviceRequest) (err error) {
|
||||
s.tx(func() {
|
||||
if _, ok := s.deviceRequests[d.UserCode]; ok {
|
||||
err = storage.ErrAlreadyExists
|
||||
} else {
|
||||
s.deviceRequests[d.UserCode] = d
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (s *memStorage) GetDeviceRequest(userCode string) (req storage.DeviceRequest, err error) {
|
||||
s.tx(func() {
|
||||
var ok bool
|
||||
if req, ok = s.deviceRequests[userCode]; !ok {
|
||||
err = storage.ErrNotFound
|
||||
return
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (s *memStorage) CreateDeviceToken(t storage.DeviceToken) (err error) {
|
||||
s.tx(func() {
|
||||
if _, ok := s.deviceTokens[t.DeviceCode]; ok {
|
||||
err = storage.ErrAlreadyExists
|
||||
} else {
|
||||
s.deviceTokens[t.DeviceCode] = t
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (s *memStorage) GetDeviceToken(deviceCode string) (t storage.DeviceToken, err error) {
|
||||
s.tx(func() {
|
||||
var ok bool
|
||||
if t, ok = s.deviceTokens[deviceCode]; !ok {
|
||||
err = storage.ErrNotFound
|
||||
return
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (s *memStorage) UpdateDeviceToken(deviceCode string, updater func(p storage.DeviceToken) (storage.DeviceToken, error)) (err error) {
|
||||
s.tx(func() {
|
||||
r, ok := s.deviceTokens[deviceCode]
|
||||
if !ok {
|
||||
err = storage.ErrNotFound
|
||||
return
|
||||
}
|
||||
if r, err = updater(r); err == nil {
|
||||
s.deviceTokens[deviceCode] = r
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
|
|
|
@ -100,6 +100,23 @@ func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error
|
|||
if n, err := r.RowsAffected(); err == nil {
|
||||
result.AuthCodes = n
|
||||
}
|
||||
|
||||
r, err = c.Exec(`delete from device_request where expiry < $1`, now)
|
||||
if err != nil {
|
||||
return result, fmt.Errorf("gc device_request: %v", err)
|
||||
}
|
||||
if n, err := r.RowsAffected(); err == nil {
|
||||
result.DeviceRequests = n
|
||||
}
|
||||
|
||||
r, err = c.Exec(`delete from device_token where expiry < $1`, now)
|
||||
if err != nil {
|
||||
return result, fmt.Errorf("gc device_token: %v", err)
|
||||
}
|
||||
if n, err := r.RowsAffected(); err == nil {
|
||||
result.DeviceTokens = n
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -867,3 +884,113 @@ func (c *conn) delete(table, field, id string) error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error {
|
||||
_, err := c.Exec(`
|
||||
insert into device_request (
|
||||
user_code, device_code, client_id, client_secret, scopes, expiry
|
||||
)
|
||||
values (
|
||||
$1, $2, $3, $4, $5, $6
|
||||
);`,
|
||||
d.UserCode, d.DeviceCode, d.ClientID, d.ClientSecret, encoder(d.Scopes), d.Expiry,
|
||||
)
|
||||
if err != nil {
|
||||
if c.alreadyExistsCheck(err) {
|
||||
return storage.ErrAlreadyExists
|
||||
}
|
||||
return fmt.Errorf("insert device request: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) CreateDeviceToken(t storage.DeviceToken) error {
|
||||
_, err := c.Exec(`
|
||||
insert into device_token (
|
||||
device_code, status, token, expiry, last_request, poll_interval
|
||||
)
|
||||
values (
|
||||
$1, $2, $3, $4, $5, $6
|
||||
);`,
|
||||
t.DeviceCode, t.Status, t.Token, t.Expiry, t.LastRequestTime, t.PollIntervalSeconds,
|
||||
)
|
||||
if err != nil {
|
||||
if c.alreadyExistsCheck(err) {
|
||||
return storage.ErrAlreadyExists
|
||||
}
|
||||
return fmt.Errorf("insert device token: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) GetDeviceRequest(userCode string) (storage.DeviceRequest, error) {
|
||||
return getDeviceRequest(c, userCode)
|
||||
}
|
||||
|
||||
func getDeviceRequest(q querier, userCode string) (d storage.DeviceRequest, err error) {
|
||||
err = q.QueryRow(`
|
||||
select
|
||||
device_code, client_id, client_secret, scopes, expiry
|
||||
from device_request where user_code = $1;
|
||||
`, userCode).Scan(
|
||||
&d.DeviceCode, &d.ClientID, &d.ClientSecret, decoder(&d.Scopes), &d.Expiry,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return d, storage.ErrNotFound
|
||||
}
|
||||
return d, fmt.Errorf("select device token: %v", err)
|
||||
}
|
||||
d.UserCode = userCode
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func (c *conn) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) {
|
||||
return getDeviceToken(c, deviceCode)
|
||||
}
|
||||
|
||||
func getDeviceToken(q querier, deviceCode string) (a storage.DeviceToken, err error) {
|
||||
err = q.QueryRow(`
|
||||
select
|
||||
status, token, expiry, last_request, poll_interval
|
||||
from device_token where device_code = $1;
|
||||
`, deviceCode).Scan(
|
||||
&a.Status, &a.Token, &a.Expiry, &a.LastRequestTime, &a.PollIntervalSeconds,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return a, storage.ErrNotFound
|
||||
}
|
||||
return a, fmt.Errorf("select device token: %v", err)
|
||||
}
|
||||
a.DeviceCode = deviceCode
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func (c *conn) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
|
||||
return c.ExecTx(func(tx *trans) error {
|
||||
r, err := getDeviceToken(tx, deviceCode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if r, err = updater(r); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec(`
|
||||
update device_token
|
||||
set
|
||||
status = $1,
|
||||
token = $2,
|
||||
last_request = $3,
|
||||
poll_interval = $4
|
||||
where
|
||||
device_code = $5
|
||||
`,
|
||||
r.Status, r.Token, r.LastRequestTime, r.PollIntervalSeconds, r.DeviceCode,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update device token: %v", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
|
|
@ -229,4 +229,25 @@ var migrations = []migration{
|
|||
},
|
||||
flavor: &flavorMySQL,
|
||||
},
|
||||
{
|
||||
stmts: []string{`
|
||||
create table device_request (
|
||||
user_code text not null primary key,
|
||||
device_code text not null,
|
||||
client_id text not null,
|
||||
client_secret text ,
|
||||
scopes bytea not null, -- JSON array of strings
|
||||
expiry timestamptz not null
|
||||
);`,
|
||||
`
|
||||
create table device_token (
|
||||
device_code text not null primary key,
|
||||
status text not null,
|
||||
token bytea,
|
||||
expiry timestamptz not null,
|
||||
last_request timestamptz not null,
|
||||
poll_interval integer not null
|
||||
);`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"encoding/base32"
|
||||
"errors"
|
||||
"io"
|
||||
"math/big"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -24,9 +25,21 @@ var (
|
|||
// TODO(ericchiang): refactor ID creation onto the storage.
|
||||
var encoding = base32.NewEncoding("abcdefghijklmnopqrstuvwxyz234567")
|
||||
|
||||
//Valid characters for user codes
|
||||
const validUserCharacters = "BCDFGHJKLMNPQRSTVWXZ"
|
||||
|
||||
// NewDeviceCode returns a 32 char alphanumeric cryptographically secure string
|
||||
func NewDeviceCode() string {
|
||||
return newSecureID(32)
|
||||
}
|
||||
|
||||
// NewID returns a random string which can be used as an ID for objects.
|
||||
func NewID() string {
|
||||
buff := make([]byte, 16) // 128 bit random ID.
|
||||
return newSecureID(16)
|
||||
}
|
||||
|
||||
func newSecureID(len int) string {
|
||||
buff := make([]byte, len) // random ID.
|
||||
if _, err := io.ReadFull(rand.Reader, buff); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
@ -38,6 +51,8 @@ func NewID() string {
|
|||
type GCResult struct {
|
||||
AuthRequests int64
|
||||
AuthCodes int64
|
||||
DeviceRequests int64
|
||||
DeviceTokens int64
|
||||
}
|
||||
|
||||
// Storage is the storage interface used by the server. Implementations are
|
||||
|
@ -54,6 +69,8 @@ type Storage interface {
|
|||
CreatePassword(p Password) error
|
||||
CreateOfflineSessions(s OfflineSessions) error
|
||||
CreateConnector(c Connector) error
|
||||
CreateDeviceRequest(d DeviceRequest) error
|
||||
CreateDeviceToken(d DeviceToken) error
|
||||
|
||||
// TODO(ericchiang): return (T, bool, error) so we can indicate not found
|
||||
// requests that way instead of using ErrNotFound.
|
||||
|
@ -65,6 +82,8 @@ type Storage interface {
|
|||
GetPassword(email string) (Password, error)
|
||||
GetOfflineSessions(userID string, connID string) (OfflineSessions, error)
|
||||
GetConnector(id string) (Connector, error)
|
||||
GetDeviceRequest(userCode string) (DeviceRequest, error)
|
||||
GetDeviceToken(deviceCode string) (DeviceToken, error)
|
||||
|
||||
ListClients() ([]Client, error)
|
||||
ListRefreshTokens() ([]RefreshToken, error)
|
||||
|
@ -101,8 +120,10 @@ type Storage interface {
|
|||
UpdatePassword(email string, updater func(p Password) (Password, error)) error
|
||||
UpdateOfflineSessions(userID string, connID string, updater func(s OfflineSessions) (OfflineSessions, error)) error
|
||||
UpdateConnector(id string, updater func(c Connector) (Connector, error)) error
|
||||
UpdateDeviceToken(deviceCode string, updater func(t DeviceToken) (DeviceToken, error)) error
|
||||
|
||||
// GarbageCollect deletes all expired AuthCodes and AuthRequests.
|
||||
// GarbageCollect deletes all expired AuthCodes,
|
||||
// AuthRequests, DeviceRequests, and DeviceTokens.
|
||||
GarbageCollect(now time.Time) (GCResult, error)
|
||||
}
|
||||
|
||||
|
@ -342,3 +363,49 @@ type Keys struct {
|
|||
// For caching purposes, implementations MUST NOT update keys before this time.
|
||||
NextRotation time.Time
|
||||
}
|
||||
|
||||
// NewUserCode returns a randomized 8 character user code for the device flow.
|
||||
// No vowels are included to prevent accidental generation of words
|
||||
func NewUserCode() (string, error) {
|
||||
code, err := randomString(8)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return code[:4] + "-" + code[4:], nil
|
||||
}
|
||||
|
||||
func randomString(n int) (string, error) {
|
||||
v := big.NewInt(int64(len(validUserCharacters)))
|
||||
bytes := make([]byte, n)
|
||||
for i := 0; i < n; i++ {
|
||||
c, _ := rand.Int(rand.Reader, v)
|
||||
bytes[i] = validUserCharacters[c.Int64()]
|
||||
}
|
||||
return string(bytes), nil
|
||||
}
|
||||
|
||||
//DeviceRequest represents an OIDC device authorization request. It holds the state of a device request until the user
|
||||
//authenticates using their user code or the expiry time passes.
|
||||
type DeviceRequest struct {
|
||||
//The code the user will enter in a browser
|
||||
UserCode string
|
||||
//The unique device code for device authentication
|
||||
DeviceCode string
|
||||
//The client ID the code is for
|
||||
ClientID string
|
||||
//The Client Secret
|
||||
ClientSecret string
|
||||
//The scopes the device requests
|
||||
Scopes []string
|
||||
//The expire time
|
||||
Expiry time.Time
|
||||
}
|
||||
|
||||
type DeviceToken struct {
|
||||
DeviceCode string
|
||||
Status string
|
||||
Token string
|
||||
Expiry time.Time
|
||||
LastRequestTime time.Time
|
||||
PollIntervalSeconds int
|
||||
}
|
||||
|
|
23
web/templates/device.html
Normal file
23
web/templates/device.html
Normal file
|
@ -0,0 +1,23 @@
|
|||
{{ template "header.html" . }}
|
||||
|
||||
<div class="theme-panel">
|
||||
<h2 class="theme-heading">Enter User Code</h2>
|
||||
<form method="post" action="{{ .PostURL }}" method="get">
|
||||
<div class="theme-form-row">
|
||||
{{ if( .UserCode )}}
|
||||
<input tabindex="2" required id="user_code" name="user_code" type="text" class="theme-form-input" autocomplete="off" value="{{.UserCode}}" {{ if .Invalid }} autofocus {{ end }}/>
|
||||
{{ else }}
|
||||
<input tabindex="2" required id="user_code" name="user_code" type="text" class="theme-form-input" placeholder="XXXX-XXXX" autocomplete="off" {{ if .Invalid }} autofocus {{ end }}/>
|
||||
{{ end }}
|
||||
</div>
|
||||
|
||||
{{ if .Invalid }}
|
||||
<div id="login-error" class="dex-error-box">
|
||||
Invalid or Expired User Code
|
||||
</div>
|
||||
{{ end }}
|
||||
<button tabindex="3" id="submit-login" type="submit" class="dex-btn theme-btn--primary">Submit</button>
|
||||
</form>
|
||||
</div>
|
||||
|
||||
{{ template "footer.html" . }}
|
8
web/templates/device_success.html
Normal file
8
web/templates/device_success.html
Normal file
|
@ -0,0 +1,8 @@
|
|||
{{ template "header.html" . }}
|
||||
|
||||
<div class="theme-panel">
|
||||
<h2 class="theme-heading">Login Successful for {{ .ClientName }}</h2>
|
||||
<p>Return to your device to continue</p>
|
||||
</div>
|
||||
|
||||
{{ template "footer.html" . }}
|
Loading…
Reference in a new issue