From 6d343e059bb21dac8b1ffff99691b43d6e26629d Mon Sep 17 00:00:00 2001 From: Justin Slowik Date: Thu, 16 Jan 2020 10:55:07 -0500 Subject: [PATCH] Generates/Stores the device request and returns the device and user codes. Signed-off-by: justin-slowik --- scripts/manifests/crds/devicerequests.yaml | 12 +++ scripts/manifests/crds/devicetokens.yaml | 12 +++ server/handlers.go | 115 +++++++++++++++++++- server/server.go | 4 +- storage/conformance/conformance.go | 119 +++++++++++++++++++++ storage/etcd/etcd.go | 74 +++++++++++++ storage/etcd/types.go | 38 +++++++ storage/kubernetes/storage.go | 46 ++++++++ storage/kubernetes/storage_test.go | 2 + storage/kubernetes/types.go | 104 ++++++++++++++++++ storage/memory/memory.go | 38 +++++++ storage/sql/crud.go | 55 ++++++++++ storage/sql/migrate.go | 19 ++++ storage/storage.go | 60 ++++++++++- 14 files changed, 690 insertions(+), 8 deletions(-) create mode 100644 scripts/manifests/crds/devicerequests.yaml create mode 100644 scripts/manifests/crds/devicetokens.yaml diff --git a/scripts/manifests/crds/devicerequests.yaml b/scripts/manifests/crds/devicerequests.yaml new file mode 100644 index 00000000..9b5b4200 --- /dev/null +++ b/scripts/manifests/crds/devicerequests.yaml @@ -0,0 +1,12 @@ +apiVersion: apiextensions.k8s.io/v1beta1 +kind: CustomResourceDefinition +metadata: + name: devicerequests.dex.coreos.com +spec: + group: dex.coreos.com + names: + kind: DeviceRequest + listKind: DeviceRequestList + plural: devicerequests + singular: devicerequest + version: v1 diff --git a/scripts/manifests/crds/devicetokens.yaml b/scripts/manifests/crds/devicetokens.yaml new file mode 100644 index 00000000..b6ce78dc --- /dev/null +++ b/scripts/manifests/crds/devicetokens.yaml @@ -0,0 +1,12 @@ +apiVersion: apiextensions.k8s.io/v1beta1 +kind: CustomResourceDefinition +metadata: + name: devicetokens.dex.coreos.com +spec: + group: dex.coreos.com + names: + kind: DeviceToken + listKind: DeviceTokenList + plural: devicetokens + singular: devicetoken + version: v1 diff --git a/server/handlers.go b/server/handlers.go index 5512d87f..5756f652 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "net/http" "net/url" "path" @@ -15,12 +16,11 @@ import ( "time" oidc "github.com/coreos/go-oidc" - "github.com/gorilla/mux" - jose "gopkg.in/square/go-jose.v2" - "github.com/dexidp/dex/connector" "github.com/dexidp/dex/server/internal" "github.com/dexidp/dex/storage" + "github.com/gorilla/mux" + jose "gopkg.in/square/go-jose.v2" ) // newHealthChecker returns the healthz handler. The handler runs until the @@ -1415,3 +1415,112 @@ func usernamePrompt(conn connector.PasswordConnector) string { } return "Username" } + +type deviceCodeResponse struct { + //The unique device code for device authentication + DeviceCode string `json:"device_code"` + //The code the user will exchange via a browser and log in + UserCode string `json:"user_code"` + //The url to verify the user code. + VerificationURI string `json:"verification_uri"` + //The lifetime of the device code + ExpireTime int `json:"expires_in"` + //How often the device is allowed to poll to verify that the user login occurred + PollInterval int `json:"interval"` +} + +func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) { + //TODO replace with configurable values + expireIntervalSeconds := 300 + requestsPerMinute := 5 + + switch r.Method { + case http.MethodPost: + err := r.ParseForm() + if err != nil { + message := "Could not parse Device Request body" + s.logger.Errorf("%s : %v", message, err) + respondWithError(w, message, err) + return + } + + //Get the client id and scopes from the post + clientID := r.Form.Get("client_id") + scopes := r.Form["scope"] + + s.logger.Infof("Received device request for client %v with scopes %v", clientID, scopes) + + //Make device code + deviceCode := storage.NewDeviceCode() + + //make user code + userCode := storage.NewUserCode() + + //make a pkce verification code + pkceCode := storage.NewID() + + //Generate the expire time + expireTime := time.Now().Add(time.Second * time.Duration(expireIntervalSeconds)) + + //Store the Device Request + deviceReq := storage.DeviceRequest{ + UserCode: userCode, + DeviceCode: deviceCode, + ClientID: clientID, + Scopes: scopes, + PkceVerifier: pkceCode, + Expiry: expireTime, + } + + if err := s.storage.CreateDeviceRequest(deviceReq); err != nil { + message := fmt.Sprintf("Failed to store device request %v", err) + s.logger.Errorf(message) + respondWithError(w, message, err) + return + } + + //Store the device token + deviceToken := storage.DeviceToken{ + DeviceCode: deviceCode, + Status: "pending", + Token: "", + Expiry: expireTime, + } + + if err := s.storage.CreateDeviceToken(deviceToken); err != nil { + message := fmt.Sprintf("Failed to store device token %v", err) + s.logger.Errorf(message) + respondWithError(w, message, err) + return + } + + code := deviceCodeResponse{ + DeviceCode: deviceCode, + UserCode: userCode, + VerificationURI: path.Join(s.issuerURL.String(), "/device"), + ExpireTime: expireIntervalSeconds, + PollInterval: requestsPerMinute, + } + + enc := json.NewEncoder(w) + enc.SetIndent("", " ") + enc.Encode(code) + + default: + s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.") + } +} + +func respondWithError(w io.Writer, errorMessage string, err error) { + resp := struct { + Error string `json:"error"` + ErrorMessage string `json:"message"` + }{ + Error: err.Error(), + ErrorMessage: errorMessage, + } + + enc := json.NewEncoder(w) + enc.SetIndent("", " ") + enc.Encode(resp) +} diff --git a/server/server.go b/server/server.go index a0a075fb..95f5359b 100644 --- a/server/server.go +++ b/server/server.go @@ -302,6 +302,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) handleWithCORS("/userinfo", s.handleUserInfo) handleFunc("/auth", s.handleAuthorization) handleFunc("/auth/{connector}", s.handleConnectorLogin) + handleFunc("/device/code", s.handleDeviceCode) r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) { // Strip the X-Remote-* headers to prevent security issues on // misconfigured authproxy connector setups. @@ -450,7 +451,8 @@ func (s *Server) startGarbageCollection(ctx context.Context, frequency time.Dura if r, err := s.storage.GarbageCollect(now()); err != nil { s.logger.Errorf("garbage collection failed: %v", err) } else if r.AuthRequests > 0 || r.AuthCodes > 0 { - s.logger.Infof("garbage collection run, delete auth requests=%d, auth codes=%d", r.AuthRequests, r.AuthCodes) + s.logger.Infof("garbage collection run, delete auth requests=%d, auth codes=%d, device requests =%d, device tokens=%d", + r.AuthRequests, r.AuthCodes, r.DeviceRequests, r.DeviceTokens) } } } diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index 1ac51fc8..c1bd318f 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -49,6 +49,8 @@ func RunTests(t *testing.T, newStorage func() storage.Storage) { {"ConnectorCRUD", testConnectorCRUD}, {"GarbageCollection", testGC}, {"TimezoneSupport", testTimezones}, + {"DeviceRequestCRUD", testDeviceRequestCRUD}, + {"DeviceTokenCRUD", testDeviceTokenCRUD}, }) } @@ -834,6 +836,82 @@ func testGC(t *testing.T, s storage.Storage) { } else if err != storage.ErrNotFound { t.Errorf("expected storage.ErrNotFound, got %v", err) } + + d := storage.DeviceRequest{ + UserCode: storage.NewUserCode(), + DeviceCode: storage.NewID(), + ClientID: "client1", + Scopes: []string{"openid", "email"}, + PkceVerifier: storage.NewID(), + Expiry: expiry, + } + + if err := s.CreateDeviceRequest(d); err != nil { + t.Fatalf("failed creating device request: %v", err) + } + + for _, tz := range []*time.Location{time.UTC, est, pst} { + result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz)) + if err != nil { + t.Errorf("garbage collection failed: %v", err) + } else { + if result.DeviceRequests != 0 { + t.Errorf("expected no device garbage collection results, got %#v", result) + } + } + //if _, err := s.GetDeviceRequest(d.UserCode); err != nil { + // t.Errorf("expected to be able to get auth request after GC: %v", err) + //} + } + if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil { + t.Errorf("garbage collection failed: %v", err) + } else if r.DeviceRequests != 1 { + t.Errorf("expected to garbage collect 1 device request, got %d", r.DeviceRequests) + } + + //TODO add this code back once Getters are written for device requests + //if _, err := s.GetDeviceRequest(d.UserCode); err == nil { + // t.Errorf("expected device request to be GC'd") + //} else if err != storage.ErrNotFound { + // t.Errorf("expected storage.ErrNotFound, got %v", err) + //} + + dt := storage.DeviceToken{ + DeviceCode: storage.NewID(), + Status: "pending", + Token: "foo", + Expiry: expiry, + } + + if err := s.CreateDeviceToken(dt); err != nil { + t.Fatalf("failed creating device token: %v", err) + } + + for _, tz := range []*time.Location{time.UTC, est, pst} { + result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz)) + if err != nil { + t.Errorf("garbage collection failed: %v", err) + } else { + if result.DeviceTokens != 0 { + t.Errorf("expected no device token garbage collection results, got %#v", result) + } + } + //if _, err := s.GetDeviceRequest(d.UserCode); err != nil { + // t.Errorf("expected to be able to get auth request after GC: %v", err) + //} + } + if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil { + t.Errorf("garbage collection failed: %v", err) + } else if r.DeviceTokens != 1 { + t.Errorf("expected to garbage collect 1 device token, got %d", r.DeviceTokens) + } + + //TODO add this code back once Getters are written for device tokens + //if _, err := s.GetDeviceRequest(d.UserCode); err == nil { + // t.Errorf("expected device request to be GC'd") + //} else if err != storage.ErrNotFound { + // t.Errorf("expected storage.ErrNotFound, got %v", err) + //} } // testTimezones tests that backends either fully support timezones or @@ -881,3 +959,44 @@ func testTimezones(t *testing.T, s storage.Storage) { t.Fatalf("expected expiry %v got %v", wantTime, gotTime) } } + +func testDeviceRequestCRUD(t *testing.T, s storage.Storage) { + d1 := storage.DeviceRequest{ + UserCode: storage.NewUserCode(), + DeviceCode: storage.NewID(), + ClientID: "client1", + Scopes: []string{"openid", "email"}, + PkceVerifier: storage.NewID(), + Expiry: neverExpire, + } + + if err := s.CreateDeviceRequest(d1); err != nil { + t.Fatalf("failed creating device request: %v", err) + } + + // Attempt to create same DeviceRequest twice. + err := s.CreateDeviceRequest(d1) + mustBeErrAlreadyExists(t, "device request", err) + + //No manual deletes for device requests, will be handled by garbage collection routines + //see testGC +} + +func testDeviceTokenCRUD(t *testing.T, s storage.Storage) { + d1 := storage.DeviceToken{ + DeviceCode: storage.NewID(), + Status: "pending", + Token: storage.NewID(), + Expiry: neverExpire, + } + + if err := s.CreateDeviceToken(d1); err != nil { + t.Fatalf("failed creating device token: %v", err) + } + + // Attempt to create same DeviceRequest twice. + err := s.CreateDeviceToken(d1) + mustBeErrAlreadyExists(t, "device token", err) + + //TODO Add update / delete tests as functionality is put into main code +} diff --git a/storage/etcd/etcd.go b/storage/etcd/etcd.go index e26ce760..27e337a4 100644 --- a/storage/etcd/etcd.go +++ b/storage/etcd/etcd.go @@ -22,6 +22,8 @@ const ( offlineSessionPrefix = "offline_session/" connectorPrefix = "connector/" keysName = "openid-connect-keys" + deviceRequestPrefix = "device_req/" + deviceTokenPrefix = "device_token/" // defaultStorageTimeout will be applied to all storage's operations. defaultStorageTimeout = 5 * time.Second @@ -72,6 +74,36 @@ func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error result.AuthCodes++ } } + + deviceRequests, err := c.listDeviceRequests(ctx) + if err != nil { + return result, err + } + + for _, deviceRequest := range deviceRequests { + if now.After(deviceRequest.Expiry) { + if err := c.deleteKey(ctx, keyID(deviceRequestPrefix, deviceRequest.UserCode)); err != nil { + c.logger.Errorf("failed to delete device request %v", err) + delErr = fmt.Errorf("failed to delete device request: %v", err) + } + result.DeviceRequests++ + } + } + + deviceTokens, err := c.listDeviceTokens(ctx) + if err != nil { + return result, err + } + + for _, deviceToken := range deviceTokens { + if now.After(deviceToken.Expiry) { + if err := c.deleteKey(ctx, keyID(deviceTokenPrefix, deviceToken.DeviceCode)); err != nil { + c.logger.Errorf("failed to delete device token %v", err) + delErr = fmt.Errorf("failed to delete device token: %v", err) + } + result.DeviceTokens++ + } + } return result, delErr } @@ -531,3 +563,45 @@ func keyEmail(prefix, email string) string { return prefix + strings.ToLower(ema func keySession(prefix, userID, connID string) string { return prefix + strings.ToLower(userID+"|"+connID) } + +func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + return c.txnCreate(ctx, keyID(deviceRequestPrefix, d.UserCode), fromStorageDeviceRequest(d)) +} + +func (c *conn) listDeviceRequests(ctx context.Context) (requests []DeviceRequest, err error) { + res, err := c.db.Get(ctx, deviceRequestPrefix, clientv3.WithPrefix()) + if err != nil { + return requests, err + } + for _, v := range res.Kvs { + var r DeviceRequest + if err = json.Unmarshal(v.Value, &r); err != nil { + return requests, err + } + requests = append(requests, r) + } + return requests, nil +} + +func (c *conn) CreateDeviceToken(t storage.DeviceToken) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + return c.txnCreate(ctx, keyID(deviceRequestPrefix, t.DeviceCode), fromStorageDeviceToken(t)) +} + +func (c *conn) listDeviceTokens(ctx context.Context) (deviceTokens []DeviceToken, err error) { + res, err := c.db.Get(ctx, deviceTokenPrefix, clientv3.WithPrefix()) + if err != nil { + return deviceTokens, err + } + for _, v := range res.Kvs { + var dt DeviceToken + if err = json.Unmarshal(v.Value, &dt); err != nil { + return deviceTokens, err + } + deviceTokens = append(deviceTokens, dt) + } + return deviceTokens, nil +} diff --git a/storage/etcd/types.go b/storage/etcd/types.go index a16eae8e..ab7bce4c 100644 --- a/storage/etcd/types.go +++ b/storage/etcd/types.go @@ -216,3 +216,41 @@ func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions { } return s } + +// DeviceRequest is a mirrored struct from storage with JSON struct tags +type DeviceRequest struct { + UserCode string `json:"user_code"` + DeviceCode string `json:"device_code"` + ClientID string `json:"client_id"` + Scopes []string `json:"scopes"` + PkceVerifier string `json:"pkce_verifier"` + Expiry time.Time `json:"expiry"` +} + +func fromStorageDeviceRequest(d storage.DeviceRequest) DeviceRequest { + return DeviceRequest{ + UserCode: d.UserCode, + DeviceCode: d.DeviceCode, + ClientID: d.ClientID, + Scopes: d.Scopes, + PkceVerifier: d.PkceVerifier, + Expiry: d.Expiry, + } +} + +// DeviceToken is a mirrored struct from storage with JSON struct tags +type DeviceToken struct { + DeviceCode string `json:"device_code"` + Status string `json:"status"` + Token string `json:"token"` + Expiry time.Time `json:"expiry"` +} + +func fromStorageDeviceToken(t storage.DeviceToken) DeviceToken { + return DeviceToken{ + DeviceCode: t.DeviceCode, + Status: t.Status, + Token: t.Token, + Expiry: t.Expiry, + } +} diff --git a/storage/kubernetes/storage.go b/storage/kubernetes/storage.go index 4bdf3dd6..e87b9c01 100644 --- a/storage/kubernetes/storage.go +++ b/storage/kubernetes/storage.go @@ -21,6 +21,8 @@ const ( kindPassword = "Password" kindOfflineSessions = "OfflineSessions" kindConnector = "Connector" + kindDeviceRequest = "DeviceRequest" + kindDeviceToken = "DeviceToken" ) const ( @@ -32,6 +34,8 @@ const ( resourcePassword = "passwords" resourceOfflineSessions = "offlinesessionses" // Again attempts to pluralize. resourceConnector = "connectors" + resourceDeviceRequest = "devicerequests" + resourceDeviceToken = "devicetokens" ) // Config values for the Kubernetes storage type. @@ -593,5 +597,47 @@ func (cli *client) GarbageCollect(now time.Time) (result storage.GCResult, err e result.AuthCodes++ } } + + var deviceRequests DeviceRequestList + if err := cli.list(resourceDeviceRequest, &deviceRequests); err != nil { + return result, fmt.Errorf("failed to list device requests: %v", err) + } + + for _, deviceRequest := range deviceRequests.DeviceRequests { + if now.After(deviceRequest.Expiry) { + if err := cli.delete(resourceDeviceRequest, deviceRequest.ObjectMeta.Name); err != nil { + cli.logger.Errorf("failed to delete device request: %v", err) + delErr = fmt.Errorf("failed to delete device request: %v", err) + } + result.DeviceRequests++ + } + } + + var deviceTokens DeviceTokenList + if err := cli.list(resourceDeviceToken, &deviceTokens); err != nil { + return result, fmt.Errorf("failed to list device tokens: %v", err) + } + + for _, deviceToken := range deviceTokens.DeviceTokens { + if now.After(deviceToken.Expiry) { + if err := cli.delete(resourceDeviceToken, deviceToken.ObjectMeta.Name); err != nil { + cli.logger.Errorf("failed to delete device token: %v", err) + delErr = fmt.Errorf("failed to delete device token: %v", err) + } + result.DeviceTokens++ + } + } + + if delErr != nil { + return result, delErr + } return result, delErr } + +func (cli *client) CreateDeviceRequest(d storage.DeviceRequest) error { + return cli.post(resourceDeviceRequest, cli.fromStorageDeviceRequest(d)) +} + +func (cli *client) CreateDeviceToken(t storage.DeviceToken) error { + return cli.post(resourceDeviceToken, cli.fromStorageDeviceToken(t)) +} diff --git a/storage/kubernetes/storage_test.go b/storage/kubernetes/storage_test.go index ea471427..2c9deeb2 100644 --- a/storage/kubernetes/storage_test.go +++ b/storage/kubernetes/storage_test.go @@ -85,6 +85,8 @@ func (s *StorageTestSuite) TestStorage() { for _, resource := range []string{ resourceAuthCode, resourceAuthRequest, + resourceDeviceRequest, + resourceDeviceToken, resourceClient, resourceRefreshToken, resourceKeys, diff --git a/storage/kubernetes/types.go b/storage/kubernetes/types.go index 0fbb2907..5a61b92e 100644 --- a/storage/kubernetes/types.go +++ b/storage/kubernetes/types.go @@ -143,6 +143,36 @@ var customResourceDefinitions = []k8sapi.CustomResourceDefinition{ }, }, }, + { + ObjectMeta: k8sapi.ObjectMeta{ + Name: "devicerequests.dex.coreos.com", + }, + TypeMeta: crdMeta, + Spec: k8sapi.CustomResourceDefinitionSpec{ + Group: apiGroup, + Version: "v1", + Names: k8sapi.CustomResourceDefinitionNames{ + Plural: "devicerequests", + Singular: "devicerequest", + Kind: "DeviceRequest", + }, + }, + }, + { + ObjectMeta: k8sapi.ObjectMeta{ + Name: "devicetokens.dex.coreos.com", + }, + TypeMeta: crdMeta, + Spec: k8sapi.CustomResourceDefinitionSpec{ + Group: apiGroup, + Version: "v1", + Names: k8sapi.CustomResourceDefinitionNames{ + Plural: "devicetokens", + Singular: "devicetoken", + Kind: "DeviceToken", + }, + }, + }, } // There will only ever be a single keys resource. Maintain this by setting a @@ -635,3 +665,77 @@ type ConnectorList struct { k8sapi.ListMeta `json:"metadata,omitempty"` Connectors []Connector `json:"items"` } + +// DeviceRequest is a mirrored struct from storage with JSON struct tags and +// Kubernetes type metadata. +type DeviceRequest struct { + k8sapi.TypeMeta `json:",inline"` + k8sapi.ObjectMeta `json:"metadata,omitempty"` + + DeviceCode string `json:"device_code,omitempty"` + CLientID string `json:"client_id,omitempty"` + Scopes []string `json:"scopes,omitempty"` + PkceVerifier string `json:"pkce_verifier,omitempty"` + Expiry time.Time `json:"expiry"` +} + +// AuthRequestList is a list of AuthRequests. +type DeviceRequestList struct { + k8sapi.TypeMeta `json:",inline"` + k8sapi.ListMeta `json:"metadata,omitempty"` + DeviceRequests []DeviceRequest `json:"items"` +} + +func (cli *client) fromStorageDeviceRequest(a storage.DeviceRequest) DeviceRequest { + req := DeviceRequest{ + TypeMeta: k8sapi.TypeMeta{ + Kind: kindDeviceRequest, + APIVersion: cli.apiVersion, + }, + ObjectMeta: k8sapi.ObjectMeta{ + Name: strings.ToLower(a.UserCode), + Namespace: cli.namespace, + }, + DeviceCode: a.DeviceCode, + CLientID: a.ClientID, + Scopes: a.Scopes, + PkceVerifier: a.PkceVerifier, + Expiry: a.Expiry, + } + return req +} + +// DeviceToken is a mirrored struct from storage with JSON struct tags and +// Kubernetes type metadata. +type DeviceToken struct { + k8sapi.TypeMeta `json:",inline"` + k8sapi.ObjectMeta `json:"metadata,omitempty"` + + Status string `json:"status,omitempty"` + Token string `json:"token,omitempty"` + Expiry time.Time `json:"expiry"` +} + +// DeviceTokenList is a list of DeviceTokens. +type DeviceTokenList struct { + k8sapi.TypeMeta `json:",inline"` + k8sapi.ListMeta `json:"metadata,omitempty"` + DeviceTokens []DeviceToken `json:"items"` +} + +func (cli *client) fromStorageDeviceToken(t storage.DeviceToken) DeviceToken { + req := DeviceToken{ + TypeMeta: k8sapi.TypeMeta{ + Kind: kindDeviceToken, + APIVersion: cli.apiVersion, + }, + ObjectMeta: k8sapi.ObjectMeta{ + Name: t.DeviceCode, + Namespace: cli.namespace, + }, + Status: t.Status, + Token: t.Token, + Expiry: t.Expiry, + } + return req +} diff --git a/storage/memory/memory.go b/storage/memory/memory.go index 681d204e..29d4af27 100644 --- a/storage/memory/memory.go +++ b/storage/memory/memory.go @@ -20,6 +20,8 @@ func New(logger log.Logger) storage.Storage { passwords: make(map[string]storage.Password), offlineSessions: make(map[offlineSessionID]storage.OfflineSessions), connectors: make(map[string]storage.Connector), + deviceRequests: make(map[string]storage.DeviceRequest), + deviceTokens: make(map[string]storage.DeviceToken), logger: logger, } } @@ -46,6 +48,8 @@ type memStorage struct { passwords map[string]storage.Password offlineSessions map[offlineSessionID]storage.OfflineSessions connectors map[string]storage.Connector + deviceRequests map[string]storage.DeviceRequest + deviceTokens map[string]storage.DeviceToken keys storage.Keys @@ -79,6 +83,18 @@ func (s *memStorage) GarbageCollect(now time.Time) (result storage.GCResult, err result.AuthRequests++ } } + for id, a := range s.deviceRequests { + if now.After(a.Expiry) { + delete(s.deviceRequests, id) + result.DeviceRequests++ + } + } + for id, a := range s.deviceTokens { + if now.After(a.Expiry) { + delete(s.deviceTokens, id) + result.DeviceTokens++ + } + } }) return result, nil } @@ -465,3 +481,25 @@ func (s *memStorage) UpdateConnector(id string, updater func(c storage.Connector }) return } + +func (s *memStorage) CreateDeviceRequest(d storage.DeviceRequest) (err error) { + s.tx(func() { + if _, ok := s.deviceRequests[d.UserCode]; ok { + err = storage.ErrAlreadyExists + } else { + s.deviceRequests[d.UserCode] = d + } + }) + return +} + +func (s *memStorage) CreateDeviceToken(t storage.DeviceToken) (err error) { + s.tx(func() { + if _, ok := s.deviceTokens[t.DeviceCode]; ok { + err = storage.ErrAlreadyExists + } else { + s.deviceTokens[t.DeviceCode] = t + } + }) + return +} diff --git a/storage/sql/crud.go b/storage/sql/crud.go index e87dc56a..989d2db0 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -100,6 +100,23 @@ func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error if n, err := r.RowsAffected(); err == nil { result.AuthCodes = n } + + r, err = c.Exec(`delete from device_request where expiry < $1`, now) + if err != nil { + return result, fmt.Errorf("gc device_request: %v", err) + } + if n, err := r.RowsAffected(); err == nil { + result.DeviceRequests = n + } + + r, err = c.Exec(`delete from device_token where expiry < $1`, now) + if err != nil { + return result, fmt.Errorf("gc device_token: %v", err) + } + if n, err := r.RowsAffected(); err == nil { + result.DeviceTokens = n + } + return } @@ -867,3 +884,41 @@ func (c *conn) delete(table, field, id string) error { } return nil } + +func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error { + _, err := c.Exec(` + insert into device_request ( + user_code, device_code, client_id, scopes, pkce_verifier, expiry + ) + values ( + $1, $2, $3, $4, $5, $6 + );`, + d.UserCode, d.DeviceCode, d.ClientID, encoder(d.Scopes), d.PkceVerifier, d.Expiry, + ) + if err != nil { + if c.alreadyExistsCheck(err) { + return storage.ErrAlreadyExists + } + return fmt.Errorf("insert device request: %v", err) + } + return nil +} + +func (c *conn) CreateDeviceToken(t storage.DeviceToken) error { + _, err := c.Exec(` + insert into device_token ( + device_code, status, token, expiry + ) + values ( + $1, $2, $3, $4 + );`, + t.DeviceCode, t.Status, t.Token, t.Expiry, + ) + if err != nil { + if c.alreadyExistsCheck(err) { + return storage.ErrAlreadyExists + } + return fmt.Errorf("insert device token: %v", err) + } + return nil +} diff --git a/storage/sql/migrate.go b/storage/sql/migrate.go index dc727535..96cd6c0a 100644 --- a/storage/sql/migrate.go +++ b/storage/sql/migrate.go @@ -229,4 +229,23 @@ var migrations = []migration{ }, flavor: &flavorMySQL, }, + { + stmts: []string{` + create table device_request ( + user_code text not null primary key, + device_code text not null, + client_id text not null, + scopes bytea not null, -- JSON array of strings + pkce_verifier text not null, + expiry timestamptz not null + );`, + ` + create table device_token ( + device_code text not null primary key, + status text not null, + token text, + expiry timestamptz not null + );`, + }, + }, } diff --git a/storage/storage.go b/storage/storage.go index 5bbb2b3f..7078ccf5 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -5,6 +5,7 @@ import ( "encoding/base32" "errors" "io" + mrand "math/rand" "strings" "time" @@ -24,9 +25,18 @@ var ( // TODO(ericchiang): refactor ID creation onto the storage. var encoding = base32.NewEncoding("abcdefghijklmnopqrstuvwxyz234567") +// NewDeviceCode returns a 32 char alphanumeric cryptographically secure string +func NewDeviceCode() string { + return newSecureID(32) +} + // NewID returns a random string which can be used as an ID for objects. func NewID() string { - buff := make([]byte, 16) // 128 bit random ID. + return newSecureID(16) +} + +func newSecureID(len int) string { + buff := make([]byte, len) // 128 bit random ID. if _, err := io.ReadFull(rand.Reader, buff); err != nil { panic(err) } @@ -36,8 +46,10 @@ func NewID() string { // GCResult returns the number of objects deleted by garbage collection. type GCResult struct { - AuthRequests int64 - AuthCodes int64 + AuthRequests int64 + AuthCodes int64 + DeviceRequests int64 + DeviceTokens int64 } // Storage is the storage interface used by the server. Implementations are @@ -54,6 +66,8 @@ type Storage interface { CreatePassword(p Password) error CreateOfflineSessions(s OfflineSessions) error CreateConnector(c Connector) error + CreateDeviceRequest(d DeviceRequest) error + CreateDeviceToken(d DeviceToken) error // TODO(ericchiang): return (T, bool, error) so we can indicate not found // requests that way instead of using ErrNotFound. @@ -102,7 +116,7 @@ type Storage interface { UpdateOfflineSessions(userID string, connID string, updater func(s OfflineSessions) (OfflineSessions, error)) error UpdateConnector(id string, updater func(c Connector) (Connector, error)) error - // GarbageCollect deletes all expired AuthCodes and AuthRequests. + // GarbageCollect deletes all expired AuthCodes,AuthRequests, DeviceRequests, and DeviceTokens. GarbageCollect(now time.Time) (GCResult, error) } @@ -342,3 +356,41 @@ type Keys struct { // For caching purposes, implementations MUST NOT update keys before this time. NextRotation time.Time } + +func NewUserCode() string { + mrand.Seed(time.Now().UnixNano()) + return randomString(4) + "-" + randomString(4) +} + +func randomString(n int) string { + var letter = []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZ") + b := make([]rune, n) + for i := range b { + b[i] = letter[mrand.Intn(len(letter))] + } + return string(b) +} + +//DeviceRequest represents an OIDC device authorization request. It holds the state of a device request until the user +//authenticates using their user code or the expiry time passes. +type DeviceRequest struct { + //The code the user will enter in a browser + UserCode string + //The unique device code for device authentication + DeviceCode string + //The client ID the code is for + ClientID string + //The scopes the device requests + Scopes []string + //PKCE Verification + PkceVerifier string + //The expire time + Expiry time.Time +} + +type DeviceToken struct { + DeviceCode string + Status string + Token string + Expiry time.Time +}