Generates/Stores the device request and returns the device and user codes.

Signed-off-by: justin-slowik <justin.slowik@thermofisher.com>
This commit is contained in:
Justin Slowik 2020-01-16 10:55:07 -05:00 committed by justin-slowik
parent 11fc8568cb
commit 6d343e059b
14 changed files with 690 additions and 8 deletions

View file

@ -0,0 +1,12 @@
apiVersion: apiextensions.k8s.io/v1beta1
kind: CustomResourceDefinition
metadata:
name: devicerequests.dex.coreos.com
spec:
group: dex.coreos.com
names:
kind: DeviceRequest
listKind: DeviceRequestList
plural: devicerequests
singular: devicerequest
version: v1

View file

@ -0,0 +1,12 @@
apiVersion: apiextensions.k8s.io/v1beta1
kind: CustomResourceDefinition
metadata:
name: devicetokens.dex.coreos.com
spec:
group: dex.coreos.com
names:
kind: DeviceToken
listKind: DeviceTokenList
plural: devicetokens
singular: devicetoken
version: v1

View file

@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"net/http" "net/http"
"net/url" "net/url"
"path" "path"
@ -15,12 +16,11 @@ import (
"time" "time"
oidc "github.com/coreos/go-oidc" oidc "github.com/coreos/go-oidc"
"github.com/gorilla/mux"
jose "gopkg.in/square/go-jose.v2"
"github.com/dexidp/dex/connector" "github.com/dexidp/dex/connector"
"github.com/dexidp/dex/server/internal" "github.com/dexidp/dex/server/internal"
"github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage"
"github.com/gorilla/mux"
jose "gopkg.in/square/go-jose.v2"
) )
// newHealthChecker returns the healthz handler. The handler runs until the // newHealthChecker returns the healthz handler. The handler runs until the
@ -1415,3 +1415,112 @@ func usernamePrompt(conn connector.PasswordConnector) string {
} }
return "Username" return "Username"
} }
type deviceCodeResponse struct {
//The unique device code for device authentication
DeviceCode string `json:"device_code"`
//The code the user will exchange via a browser and log in
UserCode string `json:"user_code"`
//The url to verify the user code.
VerificationURI string `json:"verification_uri"`
//The lifetime of the device code
ExpireTime int `json:"expires_in"`
//How often the device is allowed to poll to verify that the user login occurred
PollInterval int `json:"interval"`
}
func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
//TODO replace with configurable values
expireIntervalSeconds := 300
requestsPerMinute := 5
switch r.Method {
case http.MethodPost:
err := r.ParseForm()
if err != nil {
message := "Could not parse Device Request body"
s.logger.Errorf("%s : %v", message, err)
respondWithError(w, message, err)
return
}
//Get the client id and scopes from the post
clientID := r.Form.Get("client_id")
scopes := r.Form["scope"]
s.logger.Infof("Received device request for client %v with scopes %v", clientID, scopes)
//Make device code
deviceCode := storage.NewDeviceCode()
//make user code
userCode := storage.NewUserCode()
//make a pkce verification code
pkceCode := storage.NewID()
//Generate the expire time
expireTime := time.Now().Add(time.Second * time.Duration(expireIntervalSeconds))
//Store the Device Request
deviceReq := storage.DeviceRequest{
UserCode: userCode,
DeviceCode: deviceCode,
ClientID: clientID,
Scopes: scopes,
PkceVerifier: pkceCode,
Expiry: expireTime,
}
if err := s.storage.CreateDeviceRequest(deviceReq); err != nil {
message := fmt.Sprintf("Failed to store device request %v", err)
s.logger.Errorf(message)
respondWithError(w, message, err)
return
}
//Store the device token
deviceToken := storage.DeviceToken{
DeviceCode: deviceCode,
Status: "pending",
Token: "",
Expiry: expireTime,
}
if err := s.storage.CreateDeviceToken(deviceToken); err != nil {
message := fmt.Sprintf("Failed to store device token %v", err)
s.logger.Errorf(message)
respondWithError(w, message, err)
return
}
code := deviceCodeResponse{
DeviceCode: deviceCode,
UserCode: userCode,
VerificationURI: path.Join(s.issuerURL.String(), "/device"),
ExpireTime: expireIntervalSeconds,
PollInterval: requestsPerMinute,
}
enc := json.NewEncoder(w)
enc.SetIndent("", " ")
enc.Encode(code)
default:
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
}
}
func respondWithError(w io.Writer, errorMessage string, err error) {
resp := struct {
Error string `json:"error"`
ErrorMessage string `json:"message"`
}{
Error: err.Error(),
ErrorMessage: errorMessage,
}
enc := json.NewEncoder(w)
enc.SetIndent("", " ")
enc.Encode(resp)
}

View file

@ -302,6 +302,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
handleWithCORS("/userinfo", s.handleUserInfo) handleWithCORS("/userinfo", s.handleUserInfo)
handleFunc("/auth", s.handleAuthorization) handleFunc("/auth", s.handleAuthorization)
handleFunc("/auth/{connector}", s.handleConnectorLogin) handleFunc("/auth/{connector}", s.handleConnectorLogin)
handleFunc("/device/code", s.handleDeviceCode)
r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) { r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) {
// Strip the X-Remote-* headers to prevent security issues on // Strip the X-Remote-* headers to prevent security issues on
// misconfigured authproxy connector setups. // misconfigured authproxy connector setups.
@ -450,7 +451,8 @@ func (s *Server) startGarbageCollection(ctx context.Context, frequency time.Dura
if r, err := s.storage.GarbageCollect(now()); err != nil { if r, err := s.storage.GarbageCollect(now()); err != nil {
s.logger.Errorf("garbage collection failed: %v", err) s.logger.Errorf("garbage collection failed: %v", err)
} else if r.AuthRequests > 0 || r.AuthCodes > 0 { } else if r.AuthRequests > 0 || r.AuthCodes > 0 {
s.logger.Infof("garbage collection run, delete auth requests=%d, auth codes=%d", r.AuthRequests, r.AuthCodes) s.logger.Infof("garbage collection run, delete auth requests=%d, auth codes=%d, device requests =%d, device tokens=%d",
r.AuthRequests, r.AuthCodes, r.DeviceRequests, r.DeviceTokens)
} }
} }
} }

View file

@ -49,6 +49,8 @@ func RunTests(t *testing.T, newStorage func() storage.Storage) {
{"ConnectorCRUD", testConnectorCRUD}, {"ConnectorCRUD", testConnectorCRUD},
{"GarbageCollection", testGC}, {"GarbageCollection", testGC},
{"TimezoneSupport", testTimezones}, {"TimezoneSupport", testTimezones},
{"DeviceRequestCRUD", testDeviceRequestCRUD},
{"DeviceTokenCRUD", testDeviceTokenCRUD},
}) })
} }
@ -834,6 +836,82 @@ func testGC(t *testing.T, s storage.Storage) {
} else if err != storage.ErrNotFound { } else if err != storage.ErrNotFound {
t.Errorf("expected storage.ErrNotFound, got %v", err) t.Errorf("expected storage.ErrNotFound, got %v", err)
} }
d := storage.DeviceRequest{
UserCode: storage.NewUserCode(),
DeviceCode: storage.NewID(),
ClientID: "client1",
Scopes: []string{"openid", "email"},
PkceVerifier: storage.NewID(),
Expiry: expiry,
}
if err := s.CreateDeviceRequest(d); err != nil {
t.Fatalf("failed creating device request: %v", err)
}
for _, tz := range []*time.Location{time.UTC, est, pst} {
result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz))
if err != nil {
t.Errorf("garbage collection failed: %v", err)
} else {
if result.DeviceRequests != 0 {
t.Errorf("expected no device garbage collection results, got %#v", result)
}
}
//if _, err := s.GetDeviceRequest(d.UserCode); err != nil {
// t.Errorf("expected to be able to get auth request after GC: %v", err)
//}
}
if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
t.Errorf("garbage collection failed: %v", err)
} else if r.DeviceRequests != 1 {
t.Errorf("expected to garbage collect 1 device request, got %d", r.DeviceRequests)
}
//TODO add this code back once Getters are written for device requests
//if _, err := s.GetDeviceRequest(d.UserCode); err == nil {
// t.Errorf("expected device request to be GC'd")
//} else if err != storage.ErrNotFound {
// t.Errorf("expected storage.ErrNotFound, got %v", err)
//}
dt := storage.DeviceToken{
DeviceCode: storage.NewID(),
Status: "pending",
Token: "foo",
Expiry: expiry,
}
if err := s.CreateDeviceToken(dt); err != nil {
t.Fatalf("failed creating device token: %v", err)
}
for _, tz := range []*time.Location{time.UTC, est, pst} {
result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz))
if err != nil {
t.Errorf("garbage collection failed: %v", err)
} else {
if result.DeviceTokens != 0 {
t.Errorf("expected no device token garbage collection results, got %#v", result)
}
}
//if _, err := s.GetDeviceRequest(d.UserCode); err != nil {
// t.Errorf("expected to be able to get auth request after GC: %v", err)
//}
}
if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
t.Errorf("garbage collection failed: %v", err)
} else if r.DeviceTokens != 1 {
t.Errorf("expected to garbage collect 1 device token, got %d", r.DeviceTokens)
}
//TODO add this code back once Getters are written for device tokens
//if _, err := s.GetDeviceRequest(d.UserCode); err == nil {
// t.Errorf("expected device request to be GC'd")
//} else if err != storage.ErrNotFound {
// t.Errorf("expected storage.ErrNotFound, got %v", err)
//}
} }
// testTimezones tests that backends either fully support timezones or // testTimezones tests that backends either fully support timezones or
@ -881,3 +959,44 @@ func testTimezones(t *testing.T, s storage.Storage) {
t.Fatalf("expected expiry %v got %v", wantTime, gotTime) t.Fatalf("expected expiry %v got %v", wantTime, gotTime)
} }
} }
func testDeviceRequestCRUD(t *testing.T, s storage.Storage) {
d1 := storage.DeviceRequest{
UserCode: storage.NewUserCode(),
DeviceCode: storage.NewID(),
ClientID: "client1",
Scopes: []string{"openid", "email"},
PkceVerifier: storage.NewID(),
Expiry: neverExpire,
}
if err := s.CreateDeviceRequest(d1); err != nil {
t.Fatalf("failed creating device request: %v", err)
}
// Attempt to create same DeviceRequest twice.
err := s.CreateDeviceRequest(d1)
mustBeErrAlreadyExists(t, "device request", err)
//No manual deletes for device requests, will be handled by garbage collection routines
//see testGC
}
func testDeviceTokenCRUD(t *testing.T, s storage.Storage) {
d1 := storage.DeviceToken{
DeviceCode: storage.NewID(),
Status: "pending",
Token: storage.NewID(),
Expiry: neverExpire,
}
if err := s.CreateDeviceToken(d1); err != nil {
t.Fatalf("failed creating device token: %v", err)
}
// Attempt to create same DeviceRequest twice.
err := s.CreateDeviceToken(d1)
mustBeErrAlreadyExists(t, "device token", err)
//TODO Add update / delete tests as functionality is put into main code
}

View file

@ -22,6 +22,8 @@ const (
offlineSessionPrefix = "offline_session/" offlineSessionPrefix = "offline_session/"
connectorPrefix = "connector/" connectorPrefix = "connector/"
keysName = "openid-connect-keys" keysName = "openid-connect-keys"
deviceRequestPrefix = "device_req/"
deviceTokenPrefix = "device_token/"
// defaultStorageTimeout will be applied to all storage's operations. // defaultStorageTimeout will be applied to all storage's operations.
defaultStorageTimeout = 5 * time.Second defaultStorageTimeout = 5 * time.Second
@ -72,6 +74,36 @@ func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error
result.AuthCodes++ result.AuthCodes++
} }
} }
deviceRequests, err := c.listDeviceRequests(ctx)
if err != nil {
return result, err
}
for _, deviceRequest := range deviceRequests {
if now.After(deviceRequest.Expiry) {
if err := c.deleteKey(ctx, keyID(deviceRequestPrefix, deviceRequest.UserCode)); err != nil {
c.logger.Errorf("failed to delete device request %v", err)
delErr = fmt.Errorf("failed to delete device request: %v", err)
}
result.DeviceRequests++
}
}
deviceTokens, err := c.listDeviceTokens(ctx)
if err != nil {
return result, err
}
for _, deviceToken := range deviceTokens {
if now.After(deviceToken.Expiry) {
if err := c.deleteKey(ctx, keyID(deviceTokenPrefix, deviceToken.DeviceCode)); err != nil {
c.logger.Errorf("failed to delete device token %v", err)
delErr = fmt.Errorf("failed to delete device token: %v", err)
}
result.DeviceTokens++
}
}
return result, delErr return result, delErr
} }
@ -531,3 +563,45 @@ func keyEmail(prefix, email string) string { return prefix + strings.ToLower(ema
func keySession(prefix, userID, connID string) string { func keySession(prefix, userID, connID string) string {
return prefix + strings.ToLower(userID+"|"+connID) return prefix + strings.ToLower(userID+"|"+connID)
} }
func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
defer cancel()
return c.txnCreate(ctx, keyID(deviceRequestPrefix, d.UserCode), fromStorageDeviceRequest(d))
}
func (c *conn) listDeviceRequests(ctx context.Context) (requests []DeviceRequest, err error) {
res, err := c.db.Get(ctx, deviceRequestPrefix, clientv3.WithPrefix())
if err != nil {
return requests, err
}
for _, v := range res.Kvs {
var r DeviceRequest
if err = json.Unmarshal(v.Value, &r); err != nil {
return requests, err
}
requests = append(requests, r)
}
return requests, nil
}
func (c *conn) CreateDeviceToken(t storage.DeviceToken) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
defer cancel()
return c.txnCreate(ctx, keyID(deviceRequestPrefix, t.DeviceCode), fromStorageDeviceToken(t))
}
func (c *conn) listDeviceTokens(ctx context.Context) (deviceTokens []DeviceToken, err error) {
res, err := c.db.Get(ctx, deviceTokenPrefix, clientv3.WithPrefix())
if err != nil {
return deviceTokens, err
}
for _, v := range res.Kvs {
var dt DeviceToken
if err = json.Unmarshal(v.Value, &dt); err != nil {
return deviceTokens, err
}
deviceTokens = append(deviceTokens, dt)
}
return deviceTokens, nil
}

View file

@ -216,3 +216,41 @@ func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions {
} }
return s return s
} }
// DeviceRequest is a mirrored struct from storage with JSON struct tags
type DeviceRequest struct {
UserCode string `json:"user_code"`
DeviceCode string `json:"device_code"`
ClientID string `json:"client_id"`
Scopes []string `json:"scopes"`
PkceVerifier string `json:"pkce_verifier"`
Expiry time.Time `json:"expiry"`
}
func fromStorageDeviceRequest(d storage.DeviceRequest) DeviceRequest {
return DeviceRequest{
UserCode: d.UserCode,
DeviceCode: d.DeviceCode,
ClientID: d.ClientID,
Scopes: d.Scopes,
PkceVerifier: d.PkceVerifier,
Expiry: d.Expiry,
}
}
// DeviceToken is a mirrored struct from storage with JSON struct tags
type DeviceToken struct {
DeviceCode string `json:"device_code"`
Status string `json:"status"`
Token string `json:"token"`
Expiry time.Time `json:"expiry"`
}
func fromStorageDeviceToken(t storage.DeviceToken) DeviceToken {
return DeviceToken{
DeviceCode: t.DeviceCode,
Status: t.Status,
Token: t.Token,
Expiry: t.Expiry,
}
}

View file

@ -21,6 +21,8 @@ const (
kindPassword = "Password" kindPassword = "Password"
kindOfflineSessions = "OfflineSessions" kindOfflineSessions = "OfflineSessions"
kindConnector = "Connector" kindConnector = "Connector"
kindDeviceRequest = "DeviceRequest"
kindDeviceToken = "DeviceToken"
) )
const ( const (
@ -32,6 +34,8 @@ const (
resourcePassword = "passwords" resourcePassword = "passwords"
resourceOfflineSessions = "offlinesessionses" // Again attempts to pluralize. resourceOfflineSessions = "offlinesessionses" // Again attempts to pluralize.
resourceConnector = "connectors" resourceConnector = "connectors"
resourceDeviceRequest = "devicerequests"
resourceDeviceToken = "devicetokens"
) )
// Config values for the Kubernetes storage type. // Config values for the Kubernetes storage type.
@ -593,5 +597,47 @@ func (cli *client) GarbageCollect(now time.Time) (result storage.GCResult, err e
result.AuthCodes++ result.AuthCodes++
} }
} }
var deviceRequests DeviceRequestList
if err := cli.list(resourceDeviceRequest, &deviceRequests); err != nil {
return result, fmt.Errorf("failed to list device requests: %v", err)
}
for _, deviceRequest := range deviceRequests.DeviceRequests {
if now.After(deviceRequest.Expiry) {
if err := cli.delete(resourceDeviceRequest, deviceRequest.ObjectMeta.Name); err != nil {
cli.logger.Errorf("failed to delete device request: %v", err)
delErr = fmt.Errorf("failed to delete device request: %v", err)
}
result.DeviceRequests++
}
}
var deviceTokens DeviceTokenList
if err := cli.list(resourceDeviceToken, &deviceTokens); err != nil {
return result, fmt.Errorf("failed to list device tokens: %v", err)
}
for _, deviceToken := range deviceTokens.DeviceTokens {
if now.After(deviceToken.Expiry) {
if err := cli.delete(resourceDeviceToken, deviceToken.ObjectMeta.Name); err != nil {
cli.logger.Errorf("failed to delete device token: %v", err)
delErr = fmt.Errorf("failed to delete device token: %v", err)
}
result.DeviceTokens++
}
}
if delErr != nil {
return result, delErr
}
return result, delErr return result, delErr
} }
func (cli *client) CreateDeviceRequest(d storage.DeviceRequest) error {
return cli.post(resourceDeviceRequest, cli.fromStorageDeviceRequest(d))
}
func (cli *client) CreateDeviceToken(t storage.DeviceToken) error {
return cli.post(resourceDeviceToken, cli.fromStorageDeviceToken(t))
}

View file

@ -85,6 +85,8 @@ func (s *StorageTestSuite) TestStorage() {
for _, resource := range []string{ for _, resource := range []string{
resourceAuthCode, resourceAuthCode,
resourceAuthRequest, resourceAuthRequest,
resourceDeviceRequest,
resourceDeviceToken,
resourceClient, resourceClient,
resourceRefreshToken, resourceRefreshToken,
resourceKeys, resourceKeys,

View file

@ -143,6 +143,36 @@ var customResourceDefinitions = []k8sapi.CustomResourceDefinition{
}, },
}, },
}, },
{
ObjectMeta: k8sapi.ObjectMeta{
Name: "devicerequests.dex.coreos.com",
},
TypeMeta: crdMeta,
Spec: k8sapi.CustomResourceDefinitionSpec{
Group: apiGroup,
Version: "v1",
Names: k8sapi.CustomResourceDefinitionNames{
Plural: "devicerequests",
Singular: "devicerequest",
Kind: "DeviceRequest",
},
},
},
{
ObjectMeta: k8sapi.ObjectMeta{
Name: "devicetokens.dex.coreos.com",
},
TypeMeta: crdMeta,
Spec: k8sapi.CustomResourceDefinitionSpec{
Group: apiGroup,
Version: "v1",
Names: k8sapi.CustomResourceDefinitionNames{
Plural: "devicetokens",
Singular: "devicetoken",
Kind: "DeviceToken",
},
},
},
} }
// There will only ever be a single keys resource. Maintain this by setting a // There will only ever be a single keys resource. Maintain this by setting a
@ -635,3 +665,77 @@ type ConnectorList struct {
k8sapi.ListMeta `json:"metadata,omitempty"` k8sapi.ListMeta `json:"metadata,omitempty"`
Connectors []Connector `json:"items"` Connectors []Connector `json:"items"`
} }
// DeviceRequest is a mirrored struct from storage with JSON struct tags and
// Kubernetes type metadata.
type DeviceRequest struct {
k8sapi.TypeMeta `json:",inline"`
k8sapi.ObjectMeta `json:"metadata,omitempty"`
DeviceCode string `json:"device_code,omitempty"`
CLientID string `json:"client_id,omitempty"`
Scopes []string `json:"scopes,omitempty"`
PkceVerifier string `json:"pkce_verifier,omitempty"`
Expiry time.Time `json:"expiry"`
}
// AuthRequestList is a list of AuthRequests.
type DeviceRequestList struct {
k8sapi.TypeMeta `json:",inline"`
k8sapi.ListMeta `json:"metadata,omitempty"`
DeviceRequests []DeviceRequest `json:"items"`
}
func (cli *client) fromStorageDeviceRequest(a storage.DeviceRequest) DeviceRequest {
req := DeviceRequest{
TypeMeta: k8sapi.TypeMeta{
Kind: kindDeviceRequest,
APIVersion: cli.apiVersion,
},
ObjectMeta: k8sapi.ObjectMeta{
Name: strings.ToLower(a.UserCode),
Namespace: cli.namespace,
},
DeviceCode: a.DeviceCode,
CLientID: a.ClientID,
Scopes: a.Scopes,
PkceVerifier: a.PkceVerifier,
Expiry: a.Expiry,
}
return req
}
// DeviceToken is a mirrored struct from storage with JSON struct tags and
// Kubernetes type metadata.
type DeviceToken struct {
k8sapi.TypeMeta `json:",inline"`
k8sapi.ObjectMeta `json:"metadata,omitempty"`
Status string `json:"status,omitempty"`
Token string `json:"token,omitempty"`
Expiry time.Time `json:"expiry"`
}
// DeviceTokenList is a list of DeviceTokens.
type DeviceTokenList struct {
k8sapi.TypeMeta `json:",inline"`
k8sapi.ListMeta `json:"metadata,omitempty"`
DeviceTokens []DeviceToken `json:"items"`
}
func (cli *client) fromStorageDeviceToken(t storage.DeviceToken) DeviceToken {
req := DeviceToken{
TypeMeta: k8sapi.TypeMeta{
Kind: kindDeviceToken,
APIVersion: cli.apiVersion,
},
ObjectMeta: k8sapi.ObjectMeta{
Name: t.DeviceCode,
Namespace: cli.namespace,
},
Status: t.Status,
Token: t.Token,
Expiry: t.Expiry,
}
return req
}

View file

@ -20,6 +20,8 @@ func New(logger log.Logger) storage.Storage {
passwords: make(map[string]storage.Password), passwords: make(map[string]storage.Password),
offlineSessions: make(map[offlineSessionID]storage.OfflineSessions), offlineSessions: make(map[offlineSessionID]storage.OfflineSessions),
connectors: make(map[string]storage.Connector), connectors: make(map[string]storage.Connector),
deviceRequests: make(map[string]storage.DeviceRequest),
deviceTokens: make(map[string]storage.DeviceToken),
logger: logger, logger: logger,
} }
} }
@ -46,6 +48,8 @@ type memStorage struct {
passwords map[string]storage.Password passwords map[string]storage.Password
offlineSessions map[offlineSessionID]storage.OfflineSessions offlineSessions map[offlineSessionID]storage.OfflineSessions
connectors map[string]storage.Connector connectors map[string]storage.Connector
deviceRequests map[string]storage.DeviceRequest
deviceTokens map[string]storage.DeviceToken
keys storage.Keys keys storage.Keys
@ -79,6 +83,18 @@ func (s *memStorage) GarbageCollect(now time.Time) (result storage.GCResult, err
result.AuthRequests++ result.AuthRequests++
} }
} }
for id, a := range s.deviceRequests {
if now.After(a.Expiry) {
delete(s.deviceRequests, id)
result.DeviceRequests++
}
}
for id, a := range s.deviceTokens {
if now.After(a.Expiry) {
delete(s.deviceTokens, id)
result.DeviceTokens++
}
}
}) })
return result, nil return result, nil
} }
@ -465,3 +481,25 @@ func (s *memStorage) UpdateConnector(id string, updater func(c storage.Connector
}) })
return return
} }
func (s *memStorage) CreateDeviceRequest(d storage.DeviceRequest) (err error) {
s.tx(func() {
if _, ok := s.deviceRequests[d.UserCode]; ok {
err = storage.ErrAlreadyExists
} else {
s.deviceRequests[d.UserCode] = d
}
})
return
}
func (s *memStorage) CreateDeviceToken(t storage.DeviceToken) (err error) {
s.tx(func() {
if _, ok := s.deviceTokens[t.DeviceCode]; ok {
err = storage.ErrAlreadyExists
} else {
s.deviceTokens[t.DeviceCode] = t
}
})
return
}

View file

@ -100,6 +100,23 @@ func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error
if n, err := r.RowsAffected(); err == nil { if n, err := r.RowsAffected(); err == nil {
result.AuthCodes = n result.AuthCodes = n
} }
r, err = c.Exec(`delete from device_request where expiry < $1`, now)
if err != nil {
return result, fmt.Errorf("gc device_request: %v", err)
}
if n, err := r.RowsAffected(); err == nil {
result.DeviceRequests = n
}
r, err = c.Exec(`delete from device_token where expiry < $1`, now)
if err != nil {
return result, fmt.Errorf("gc device_token: %v", err)
}
if n, err := r.RowsAffected(); err == nil {
result.DeviceTokens = n
}
return return
} }
@ -867,3 +884,41 @@ func (c *conn) delete(table, field, id string) error {
} }
return nil return nil
} }
func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error {
_, err := c.Exec(`
insert into device_request (
user_code, device_code, client_id, scopes, pkce_verifier, expiry
)
values (
$1, $2, $3, $4, $5, $6
);`,
d.UserCode, d.DeviceCode, d.ClientID, encoder(d.Scopes), d.PkceVerifier, d.Expiry,
)
if err != nil {
if c.alreadyExistsCheck(err) {
return storage.ErrAlreadyExists
}
return fmt.Errorf("insert device request: %v", err)
}
return nil
}
func (c *conn) CreateDeviceToken(t storage.DeviceToken) error {
_, err := c.Exec(`
insert into device_token (
device_code, status, token, expiry
)
values (
$1, $2, $3, $4
);`,
t.DeviceCode, t.Status, t.Token, t.Expiry,
)
if err != nil {
if c.alreadyExistsCheck(err) {
return storage.ErrAlreadyExists
}
return fmt.Errorf("insert device token: %v", err)
}
return nil
}

View file

@ -229,4 +229,23 @@ var migrations = []migration{
}, },
flavor: &flavorMySQL, flavor: &flavorMySQL,
}, },
{
stmts: []string{`
create table device_request (
user_code text not null primary key,
device_code text not null,
client_id text not null,
scopes bytea not null, -- JSON array of strings
pkce_verifier text not null,
expiry timestamptz not null
);`,
`
create table device_token (
device_code text not null primary key,
status text not null,
token text,
expiry timestamptz not null
);`,
},
},
} }

View file

@ -5,6 +5,7 @@ import (
"encoding/base32" "encoding/base32"
"errors" "errors"
"io" "io"
mrand "math/rand"
"strings" "strings"
"time" "time"
@ -24,9 +25,18 @@ var (
// TODO(ericchiang): refactor ID creation onto the storage. // TODO(ericchiang): refactor ID creation onto the storage.
var encoding = base32.NewEncoding("abcdefghijklmnopqrstuvwxyz234567") var encoding = base32.NewEncoding("abcdefghijklmnopqrstuvwxyz234567")
// NewDeviceCode returns a 32 char alphanumeric cryptographically secure string
func NewDeviceCode() string {
return newSecureID(32)
}
// NewID returns a random string which can be used as an ID for objects. // NewID returns a random string which can be used as an ID for objects.
func NewID() string { func NewID() string {
buff := make([]byte, 16) // 128 bit random ID. return newSecureID(16)
}
func newSecureID(len int) string {
buff := make([]byte, len) // 128 bit random ID.
if _, err := io.ReadFull(rand.Reader, buff); err != nil { if _, err := io.ReadFull(rand.Reader, buff); err != nil {
panic(err) panic(err)
} }
@ -38,6 +48,8 @@ func NewID() string {
type GCResult struct { type GCResult struct {
AuthRequests int64 AuthRequests int64
AuthCodes int64 AuthCodes int64
DeviceRequests int64
DeviceTokens int64
} }
// Storage is the storage interface used by the server. Implementations are // Storage is the storage interface used by the server. Implementations are
@ -54,6 +66,8 @@ type Storage interface {
CreatePassword(p Password) error CreatePassword(p Password) error
CreateOfflineSessions(s OfflineSessions) error CreateOfflineSessions(s OfflineSessions) error
CreateConnector(c Connector) error CreateConnector(c Connector) error
CreateDeviceRequest(d DeviceRequest) error
CreateDeviceToken(d DeviceToken) error
// TODO(ericchiang): return (T, bool, error) so we can indicate not found // TODO(ericchiang): return (T, bool, error) so we can indicate not found
// requests that way instead of using ErrNotFound. // requests that way instead of using ErrNotFound.
@ -102,7 +116,7 @@ type Storage interface {
UpdateOfflineSessions(userID string, connID string, updater func(s OfflineSessions) (OfflineSessions, error)) error UpdateOfflineSessions(userID string, connID string, updater func(s OfflineSessions) (OfflineSessions, error)) error
UpdateConnector(id string, updater func(c Connector) (Connector, error)) error UpdateConnector(id string, updater func(c Connector) (Connector, error)) error
// GarbageCollect deletes all expired AuthCodes and AuthRequests. // GarbageCollect deletes all expired AuthCodes,AuthRequests, DeviceRequests, and DeviceTokens.
GarbageCollect(now time.Time) (GCResult, error) GarbageCollect(now time.Time) (GCResult, error)
} }
@ -342,3 +356,41 @@ type Keys struct {
// For caching purposes, implementations MUST NOT update keys before this time. // For caching purposes, implementations MUST NOT update keys before this time.
NextRotation time.Time NextRotation time.Time
} }
func NewUserCode() string {
mrand.Seed(time.Now().UnixNano())
return randomString(4) + "-" + randomString(4)
}
func randomString(n int) string {
var letter = []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
b := make([]rune, n)
for i := range b {
b[i] = letter[mrand.Intn(len(letter))]
}
return string(b)
}
//DeviceRequest represents an OIDC device authorization request. It holds the state of a device request until the user
//authenticates using their user code or the expiry time passes.
type DeviceRequest struct {
//The code the user will enter in a browser
UserCode string
//The unique device code for device authentication
DeviceCode string
//The client ID the code is for
ClientID string
//The scopes the device requests
Scopes []string
//PKCE Verification
PkceVerifier string
//The expire time
Expiry time.Time
}
type DeviceToken struct {
DeviceCode string
Status string
Token string
Expiry time.Time
}