From 83e2df821e5d128be68bd2088f4e9d816fbb50b5 Mon Sep 17 00:00:00 2001 From: Bob Callaway Date: Wed, 27 Jul 2022 09:02:18 -0700 Subject: [PATCH] add PKCE support to device code flow (#2575) Signed-off-by: Bob Callaway --- server/deviceflowhandlers.go | 39 ++++ server/deviceflowhandlers_test.go | 126 ++++++++++++ storage/conformance/conformance.go | 13 ++ storage/ent/client/devicetoken.go | 4 + storage/ent/client/types.go | 4 + storage/ent/db/devicetoken.go | 22 +- storage/ent/db/devicetoken/devicetoken.go | 10 + storage/ent/db/devicetoken/where.go | 236 ++++++++++++++++++++++ storage/ent/db/devicetoken_create.go | 64 ++++++ storage/ent/db/devicetoken_update.go | 84 ++++++++ storage/ent/db/migrate/schema.go | 2 + storage/ent/db/mutation.go | 138 +++++++++++-- storage/ent/db/runtime.go | 8 + storage/ent/schema/devicetoken.go | 20 +- storage/etcd/etcd.go | 7 +- storage/etcd/types.go | 8 + storage/kubernetes/types.go | 8 + storage/sql/crud.go | 18 +- storage/sql/migrate.go | 10 + storage/storage.go | 1 + 20 files changed, 790 insertions(+), 32 deletions(-) diff --git a/server/deviceflowhandlers.go b/server/deviceflowhandlers.go index 0efe3b2b..95fed3b3 100644 --- a/server/deviceflowhandlers.go +++ b/server/deviceflowhandlers.go @@ -73,6 +73,17 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) { clientID := r.Form.Get("client_id") clientSecret := r.Form.Get("client_secret") scopes := strings.Fields(r.Form.Get("scope")) + codeChallenge := r.Form.Get("code_challenge") + codeChallengeMethod := r.Form.Get("code_challenge_method") + + if codeChallengeMethod == "" { + codeChallengeMethod = codeChallengeMethodPlain + } + if codeChallengeMethod != codeChallengeMethodS256 && codeChallengeMethod != codeChallengeMethodPlain { + description := fmt.Sprintf("Unsupported PKCE challenge method (%q).", codeChallengeMethod) + s.tokenErrHelper(w, errInvalidRequest, description, http.StatusBadRequest) + return + } s.logger.Infof("Received device request for client %v with scopes %v", clientID, scopes) @@ -108,6 +119,10 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) { Expiry: expireTime, LastRequestTime: s.now(), PollIntervalSeconds: 0, + PKCE: storage.PKCE{ + CodeChallenge: codeChallenge, + CodeChallengeMethod: codeChallengeMethod, + }, } if err := s.storage.CreateDeviceToken(deviceToken); err != nil { @@ -236,6 +251,30 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) { s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized) } case deviceTokenComplete: + codeChallengeFromStorage := deviceToken.PKCE.CodeChallenge + providedCodeVerifier := r.Form.Get("code_verifier") + + switch { + case providedCodeVerifier != "" && codeChallengeFromStorage != "": + calculatedCodeChallenge, err := s.calculateCodeChallenge(providedCodeVerifier, deviceToken.PKCE.CodeChallengeMethod) + if err != nil { + s.logger.Error(err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + return + } + if codeChallengeFromStorage != calculatedCodeChallenge { + s.tokenErrHelper(w, errInvalidGrant, "Invalid code_verifier.", http.StatusBadRequest) + return + } + case providedCodeVerifier != "": + // Received no code_challenge on /auth, but a code_verifier on /token + s.tokenErrHelper(w, errInvalidRequest, "No PKCE flow started. Cannot check code_verifier.", http.StatusBadRequest) + return + case codeChallengeFromStorage != "": + // Received PKCE request on /auth, but no code_verifier on /token + s.tokenErrHelper(w, errInvalidGrant, "Expecting parameter code_verifier in PKCE flow.", http.StatusBadRequest) + return + } w.Write([]byte(deviceToken.Token)) } } diff --git a/server/deviceflowhandlers_test.go b/server/deviceflowhandlers_test.go index 225703a4..9a9f2858 100644 --- a/server/deviceflowhandlers_test.go +++ b/server/deviceflowhandlers_test.go @@ -49,6 +49,7 @@ func TestHandleDeviceCode(t *testing.T) { tests := []struct { testName string clientID string + codeChallengeMethod string requestType string scopes []string expectedResponseCode int @@ -71,6 +72,24 @@ func TestHandleDeviceCode(t *testing.T) { expectedResponseCode: http.StatusBadRequest, expectedContentType: "application/json", }, + { + testName: "New Code with valid PKCE", + clientID: "test", + requestType: "POST", + scopes: []string{"openid", "profile", "email"}, + codeChallengeMethod: "S256", + expectedResponseCode: http.StatusOK, + expectedContentType: "application/json", + }, + { + testName: "Invalid code challenge method", + clientID: "test", + requestType: "POST", + codeChallengeMethod: "invalid", + scopes: []string{"openid", "profile", "email"}, + expectedResponseCode: http.StatusBadRequest, + expectedContentType: "application/json", + }, } for _, tc := range tests { t.Run(tc.testName, func(t *testing.T) { @@ -92,6 +111,7 @@ func TestHandleDeviceCode(t *testing.T) { data := url.Values{} data.Set("client_id", tc.clientID) + data.Set("code_challenge_method", tc.codeChallengeMethod) for _, scope := range tc.scopes { data.Add("scope", scope) } @@ -401,6 +421,13 @@ func TestDeviceTokenResponse(t *testing.T) { now := func() time.Time { return t0 } + // Base PKCE values + // base64-urlencoded, sha256 digest of code_verifier + codeChallenge := "L7ZqsT_zNwvrH8E7J0CqPHx1wgBaFiaE-fAZcKUUAbc" + codeChallengeMethod := "S256" + // "random" string between 43 & 128 ASCII characters + codeVerifier := "66114650f56cc45dee7ee03c49f048ddf9aa53cbf5b09985832fa4f790ff2604" + baseDeviceRequest := storage.DeviceRequest{ UserCode: "ABCD-WXYZ", DeviceCode: "foo", @@ -415,6 +442,7 @@ func TestDeviceTokenResponse(t *testing.T) { testDeviceToken storage.DeviceToken testGrantType string testDeviceCode string + testCodeVerifier string expectedServerResponse string expectedResponseCode int }{ @@ -524,6 +552,101 @@ func TestDeviceTokenResponse(t *testing.T) { expectedServerResponse: "{\"access_token\": \"foobar\"}", expectedResponseCode: http.StatusOK, }, + { + testName: "Successful Exchange with PKCE", + testDeviceToken: storage.DeviceToken{ + DeviceCode: "foo", + Status: deviceTokenComplete, + Token: "{\"access_token\": \"foobar\"}", + Expiry: now().Add(5 * time.Minute), + LastRequestTime: time.Time{}, + PollIntervalSeconds: 0, + PKCE: storage.PKCE{ + CodeChallenge: codeChallenge, + CodeChallengeMethod: codeChallengeMethod, + }, + }, + testDeviceCode: "foo", + testCodeVerifier: codeVerifier, + testDeviceRequest: baseDeviceRequest, + expectedServerResponse: "{\"access_token\": \"foobar\"}", + expectedResponseCode: http.StatusOK, + }, + { + testName: "Test Exchange started with PKCE but without verifier provided", + testDeviceToken: storage.DeviceToken{ + DeviceCode: "foo", + Status: deviceTokenComplete, + Token: "{\"access_token\": \"foobar\"}", + Expiry: now().Add(5 * time.Minute), + LastRequestTime: time.Time{}, + PollIntervalSeconds: 0, + PKCE: storage.PKCE{ + CodeChallenge: codeChallenge, + CodeChallengeMethod: codeChallengeMethod, + }, + }, + testDeviceCode: "foo", + testDeviceRequest: baseDeviceRequest, + expectedServerResponse: errInvalidGrant, + expectedResponseCode: http.StatusBadRequest, + }, + { + testName: "Test Exchange not started with PKCE but verifier provided", + testDeviceToken: storage.DeviceToken{ + DeviceCode: "foo", + Status: deviceTokenComplete, + Token: "{\"access_token\": \"foobar\"}", + Expiry: now().Add(5 * time.Minute), + LastRequestTime: time.Time{}, + PollIntervalSeconds: 0, + }, + testDeviceCode: "foo", + testCodeVerifier: codeVerifier, + testDeviceRequest: baseDeviceRequest, + expectedServerResponse: errInvalidRequest, + expectedResponseCode: http.StatusBadRequest, + }, + { + testName: "Test with PKCE but incorrect verifier provided", + testDeviceToken: storage.DeviceToken{ + DeviceCode: "foo", + Status: deviceTokenComplete, + Token: "{\"access_token\": \"foobar\"}", + Expiry: now().Add(5 * time.Minute), + LastRequestTime: time.Time{}, + PollIntervalSeconds: 0, + PKCE: storage.PKCE{ + CodeChallenge: codeChallenge, + CodeChallengeMethod: codeChallengeMethod, + }, + }, + testDeviceCode: "foo", + testCodeVerifier: "invalid", + testDeviceRequest: baseDeviceRequest, + expectedServerResponse: errInvalidGrant, + expectedResponseCode: http.StatusBadRequest, + }, + { + testName: "Test with PKCE but incorrect challenge provided", + testDeviceToken: storage.DeviceToken{ + DeviceCode: "foo", + Status: deviceTokenComplete, + Token: "{\"access_token\": \"foobar\"}", + Expiry: now().Add(5 * time.Minute), + LastRequestTime: time.Time{}, + PollIntervalSeconds: 0, + PKCE: storage.PKCE{ + CodeChallenge: "invalid", + CodeChallengeMethod: codeChallengeMethod, + }, + }, + testDeviceCode: "foo", + testCodeVerifier: codeVerifier, + testDeviceRequest: baseDeviceRequest, + expectedServerResponse: errInvalidGrant, + expectedResponseCode: http.StatusBadRequest, + }, } for _, tc := range tests { t.Run(tc.testName, func(t *testing.T) { @@ -558,6 +681,9 @@ func TestDeviceTokenResponse(t *testing.T) { } data.Set("grant_type", grantType) data.Set("device_code", tc.testDeviceCode) + if tc.testCodeVerifier != "" { + data.Set("code_verifier", tc.testCodeVerifier) + } req, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(data.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index dde369c4..9d9766eb 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -890,6 +890,10 @@ func testGC(t *testing.T, s storage.Storage) { Expiry: expiry, LastRequestTime: time.Now(), PollIntervalSeconds: 0, + PKCE: storage.PKCE{ + CodeChallenge: "challenge", + CodeChallengeMethod: "S256", + }, } if err := s.CreateDeviceToken(dt); err != nil { @@ -989,6 +993,11 @@ func testDeviceRequestCRUD(t *testing.T, s storage.Storage) { } func testDeviceTokenCRUD(t *testing.T, s storage.Storage) { + codeChallenge := storage.PKCE{ + CodeChallenge: "code_challenge_test", + CodeChallengeMethod: "plain", + } + // Create a Token d1 := storage.DeviceToken{ DeviceCode: storage.NewID(), @@ -997,6 +1006,7 @@ func testDeviceTokenCRUD(t *testing.T, s storage.Storage) { Expiry: neverExpire, LastRequestTime: time.Now(), PollIntervalSeconds: 0, + PKCE: codeChallenge, } if err := s.CreateDeviceToken(d1); err != nil { @@ -1029,4 +1039,7 @@ func testDeviceTokenCRUD(t *testing.T, s storage.Storage) { if got.Token != "token data" { t.Fatalf("update failed, wanted token %v got %v", "token data", got.Token) } + if !reflect.DeepEqual(got.PKCE, codeChallenge) { + t.Fatalf("storage does not support PKCE, wanted challenge=%#v got %#v", codeChallenge, got.PKCE) + } } diff --git a/storage/ent/client/devicetoken.go b/storage/ent/client/devicetoken.go index d8870787..99cf077d 100644 --- a/storage/ent/client/devicetoken.go +++ b/storage/ent/client/devicetoken.go @@ -17,6 +17,8 @@ func (d *Database) CreateDeviceToken(token storage.DeviceToken) error { SetExpiry(token.Expiry.UTC()). SetLastRequest(token.LastRequestTime.UTC()). SetStatus(token.Status). + SetCodeChallenge(token.PKCE.CodeChallenge). + SetCodeChallengeMethod(token.PKCE.CodeChallengeMethod). Save(context.TODO()) if err != nil { return convertDBError("create device token: %w", err) @@ -63,6 +65,8 @@ func (d *Database) UpdateDeviceToken(deviceCode string, updater func(old storage SetExpiry(newToken.Expiry.UTC()). SetLastRequest(newToken.LastRequestTime.UTC()). SetStatus(newToken.Status). + SetCodeChallenge(newToken.PKCE.CodeChallenge). + SetCodeChallengeMethod(newToken.PKCE.CodeChallengeMethod). Save(context.TODO()) if err != nil { return rollback(tx, "update device token uploading: %w", err) diff --git a/storage/ent/client/types.go b/storage/ent/client/types.go index 57f1c0a7..256bb73d 100644 --- a/storage/ent/client/types.go +++ b/storage/ent/client/types.go @@ -164,5 +164,9 @@ func toStorageDeviceToken(t *db.DeviceToken) storage.DeviceToken { Expiry: t.Expiry, LastRequestTime: t.LastRequest, PollIntervalSeconds: t.PollInterval, + PKCE: storage.PKCE{ + CodeChallenge: t.CodeChallenge, + CodeChallengeMethod: t.CodeChallengeMethod, + }, } } diff --git a/storage/ent/db/devicetoken.go b/storage/ent/db/devicetoken.go index 1731d1f0..5155df9a 100644 --- a/storage/ent/db/devicetoken.go +++ b/storage/ent/db/devicetoken.go @@ -28,6 +28,10 @@ type DeviceToken struct { LastRequest time.Time `json:"last_request,omitempty"` // PollInterval holds the value of the "poll_interval" field. PollInterval int `json:"poll_interval,omitempty"` + // CodeChallenge holds the value of the "code_challenge" field. + CodeChallenge string `json:"code_challenge,omitempty"` + // CodeChallengeMethod holds the value of the "code_challenge_method" field. + CodeChallengeMethod string `json:"code_challenge_method,omitempty"` } // scanValues returns the types for scanning values from sql.Rows. @@ -39,7 +43,7 @@ func (*DeviceToken) scanValues(columns []string) ([]interface{}, error) { values[i] = new([]byte) case devicetoken.FieldID, devicetoken.FieldPollInterval: values[i] = new(sql.NullInt64) - case devicetoken.FieldDeviceCode, devicetoken.FieldStatus: + case devicetoken.FieldDeviceCode, devicetoken.FieldStatus, devicetoken.FieldCodeChallenge, devicetoken.FieldCodeChallengeMethod: values[i] = new(sql.NullString) case devicetoken.FieldExpiry, devicetoken.FieldLastRequest: values[i] = new(sql.NullTime) @@ -100,6 +104,18 @@ func (dt *DeviceToken) assignValues(columns []string, values []interface{}) erro } else if value.Valid { dt.PollInterval = int(value.Int64) } + case devicetoken.FieldCodeChallenge: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field code_challenge", values[i]) + } else if value.Valid { + dt.CodeChallenge = value.String + } + case devicetoken.FieldCodeChallengeMethod: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field code_challenge_method", values[i]) + } else if value.Valid { + dt.CodeChallengeMethod = value.String + } } } return nil @@ -142,6 +158,10 @@ func (dt *DeviceToken) String() string { builder.WriteString(dt.LastRequest.Format(time.ANSIC)) builder.WriteString(", poll_interval=") builder.WriteString(fmt.Sprintf("%v", dt.PollInterval)) + builder.WriteString(", code_challenge=") + builder.WriteString(dt.CodeChallenge) + builder.WriteString(", code_challenge_method=") + builder.WriteString(dt.CodeChallengeMethod) builder.WriteByte(')') return builder.String() } diff --git a/storage/ent/db/devicetoken/devicetoken.go b/storage/ent/db/devicetoken/devicetoken.go index 7af65799..8f3ae361 100644 --- a/storage/ent/db/devicetoken/devicetoken.go +++ b/storage/ent/db/devicetoken/devicetoken.go @@ -19,6 +19,10 @@ const ( FieldLastRequest = "last_request" // FieldPollInterval holds the string denoting the poll_interval field in the database. FieldPollInterval = "poll_interval" + // FieldCodeChallenge holds the string denoting the code_challenge field in the database. + FieldCodeChallenge = "code_challenge" + // FieldCodeChallengeMethod holds the string denoting the code_challenge_method field in the database. + FieldCodeChallengeMethod = "code_challenge_method" // Table holds the table name of the devicetoken in the database. Table = "device_tokens" ) @@ -32,6 +36,8 @@ var Columns = []string{ FieldExpiry, FieldLastRequest, FieldPollInterval, + FieldCodeChallenge, + FieldCodeChallengeMethod, } // ValidColumn reports if the column name is valid (part of the table columns). @@ -49,4 +55,8 @@ var ( DeviceCodeValidator func(string) error // StatusValidator is a validator for the "status" field. It is called by the builders before save. StatusValidator func(string) error + // DefaultCodeChallenge holds the default value on creation for the "code_challenge" field. + DefaultCodeChallenge string + // DefaultCodeChallengeMethod holds the default value on creation for the "code_challenge_method" field. + DefaultCodeChallengeMethod string ) diff --git a/storage/ent/db/devicetoken/where.go b/storage/ent/db/devicetoken/where.go index b22879a6..779badbe 100644 --- a/storage/ent/db/devicetoken/where.go +++ b/storage/ent/db/devicetoken/where.go @@ -134,6 +134,20 @@ func PollInterval(v int) predicate.DeviceToken { }) } +// CodeChallenge applies equality check predicate on the "code_challenge" field. It's identical to CodeChallengeEQ. +func CodeChallenge(v string) predicate.DeviceToken { + return predicate.DeviceToken(func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldCodeChallenge), v)) + }) +} + +// CodeChallengeMethod applies equality check predicate on the "code_challenge_method" field. It's identical to CodeChallengeMethodEQ. +func CodeChallengeMethod(v string) predicate.DeviceToken { + return predicate.DeviceToken(func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldCodeChallengeMethod), v)) + }) +} + // DeviceCodeEQ applies the EQ predicate on the "device_code" field. func DeviceCodeEQ(v string) predicate.DeviceToken { return predicate.DeviceToken(func(s *sql.Selector) { @@ -674,6 +688,228 @@ func PollIntervalLTE(v int) predicate.DeviceToken { }) } +// CodeChallengeEQ applies the EQ predicate on the "code_challenge" field. +func CodeChallengeEQ(v string) predicate.DeviceToken { + return predicate.DeviceToken(func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldCodeChallenge), v)) + }) +} + +// CodeChallengeNEQ applies the NEQ predicate on the "code_challenge" field. +func CodeChallengeNEQ(v string) predicate.DeviceToken { + return predicate.DeviceToken(func(s *sql.Selector) { + s.Where(sql.NEQ(s.C(FieldCodeChallenge), v)) + }) +} + +// CodeChallengeIn applies the In predicate on the "code_challenge" field. +func CodeChallengeIn(vs ...string) predicate.DeviceToken { + v := make([]interface{}, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.DeviceToken(func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(v) == 0 { + s.Where(sql.False()) + return + } + s.Where(sql.In(s.C(FieldCodeChallenge), v...)) + }) +} + +// CodeChallengeNotIn applies the NotIn predicate on the "code_challenge" field. +func CodeChallengeNotIn(vs ...string) predicate.DeviceToken { + v := make([]interface{}, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.DeviceToken(func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(v) == 0 { + s.Where(sql.False()) + return + } + s.Where(sql.NotIn(s.C(FieldCodeChallenge), v...)) + }) +} + +// CodeChallengeGT applies the GT predicate on the "code_challenge" field. +func CodeChallengeGT(v string) predicate.DeviceToken { + return predicate.DeviceToken(func(s *sql.Selector) { + s.Where(sql.GT(s.C(FieldCodeChallenge), v)) + }) +} + +// CodeChallengeGTE applies the GTE predicate on the "code_challenge" field. +func CodeChallengeGTE(v string) predicate.DeviceToken { + return predicate.DeviceToken(func(s *sql.Selector) { + s.Where(sql.GTE(s.C(FieldCodeChallenge), v)) + }) +} + +// CodeChallengeLT applies the LT predicate on the "code_challenge" field. +func CodeChallengeLT(v string) predicate.DeviceToken { + return predicate.DeviceToken(func(s *sql.Selector) { + s.Where(sql.LT(s.C(FieldCodeChallenge), v)) + }) +} + +// CodeChallengeLTE applies the LTE predicate on the "code_challenge" field. +func CodeChallengeLTE(v string) predicate.DeviceToken { + return predicate.DeviceToken(func(s *sql.Selector) { + s.Where(sql.LTE(s.C(FieldCodeChallenge), v)) + }) +} + +// CodeChallengeContains applies the Contains predicate on the "code_challenge" field. +func CodeChallengeContains(v string) predicate.DeviceToken { + return predicate.DeviceToken(func(s *sql.Selector) { + s.Where(sql.Contains(s.C(FieldCodeChallenge), v)) + }) +} + +// CodeChallengeHasPrefix applies the HasPrefix predicate on the "code_challenge" field. +func CodeChallengeHasPrefix(v string) predicate.DeviceToken { + return predicate.DeviceToken(func(s *sql.Selector) { + s.Where(sql.HasPrefix(s.C(FieldCodeChallenge), v)) + }) +} + +// CodeChallengeHasSuffix applies the HasSuffix predicate on the "code_challenge" field. +func CodeChallengeHasSuffix(v string) predicate.DeviceToken { + return predicate.DeviceToken(func(s *sql.Selector) { + s.Where(sql.HasSuffix(s.C(FieldCodeChallenge), v)) + }) +} + +// CodeChallengeEqualFold applies the EqualFold predicate on the "code_challenge" field. +func CodeChallengeEqualFold(v string) predicate.DeviceToken { + return predicate.DeviceToken(func(s *sql.Selector) { + s.Where(sql.EqualFold(s.C(FieldCodeChallenge), v)) + }) +} + +// CodeChallengeContainsFold applies the ContainsFold predicate on the "code_challenge" field. +func CodeChallengeContainsFold(v string) predicate.DeviceToken { + return predicate.DeviceToken(func(s *sql.Selector) { + s.Where(sql.ContainsFold(s.C(FieldCodeChallenge), v)) + }) +} + +// CodeChallengeMethodEQ applies the EQ predicate on the "code_challenge_method" field. +func CodeChallengeMethodEQ(v string) predicate.DeviceToken { + return predicate.DeviceToken(func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldCodeChallengeMethod), v)) + }) +} + +// CodeChallengeMethodNEQ applies the NEQ predicate on the "code_challenge_method" field. +func CodeChallengeMethodNEQ(v string) predicate.DeviceToken { + return predicate.DeviceToken(func(s *sql.Selector) { + s.Where(sql.NEQ(s.C(FieldCodeChallengeMethod), v)) + }) +} + +// CodeChallengeMethodIn applies the In predicate on the "code_challenge_method" field. +func CodeChallengeMethodIn(vs ...string) predicate.DeviceToken { + v := make([]interface{}, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.DeviceToken(func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(v) == 0 { + s.Where(sql.False()) + return + } + s.Where(sql.In(s.C(FieldCodeChallengeMethod), v...)) + }) +} + +// CodeChallengeMethodNotIn applies the NotIn predicate on the "code_challenge_method" field. +func CodeChallengeMethodNotIn(vs ...string) predicate.DeviceToken { + v := make([]interface{}, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.DeviceToken(func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(v) == 0 { + s.Where(sql.False()) + return + } + s.Where(sql.NotIn(s.C(FieldCodeChallengeMethod), v...)) + }) +} + +// CodeChallengeMethodGT applies the GT predicate on the "code_challenge_method" field. +func CodeChallengeMethodGT(v string) predicate.DeviceToken { + return predicate.DeviceToken(func(s *sql.Selector) { + s.Where(sql.GT(s.C(FieldCodeChallengeMethod), v)) + }) +} + +// CodeChallengeMethodGTE applies the GTE predicate on the "code_challenge_method" field. +func CodeChallengeMethodGTE(v string) predicate.DeviceToken { + return predicate.DeviceToken(func(s *sql.Selector) { + s.Where(sql.GTE(s.C(FieldCodeChallengeMethod), v)) + }) +} + +// CodeChallengeMethodLT applies the LT predicate on the "code_challenge_method" field. +func CodeChallengeMethodLT(v string) predicate.DeviceToken { + return predicate.DeviceToken(func(s *sql.Selector) { + s.Where(sql.LT(s.C(FieldCodeChallengeMethod), v)) + }) +} + +// CodeChallengeMethodLTE applies the LTE predicate on the "code_challenge_method" field. +func CodeChallengeMethodLTE(v string) predicate.DeviceToken { + return predicate.DeviceToken(func(s *sql.Selector) { + s.Where(sql.LTE(s.C(FieldCodeChallengeMethod), v)) + }) +} + +// CodeChallengeMethodContains applies the Contains predicate on the "code_challenge_method" field. +func CodeChallengeMethodContains(v string) predicate.DeviceToken { + return predicate.DeviceToken(func(s *sql.Selector) { + s.Where(sql.Contains(s.C(FieldCodeChallengeMethod), v)) + }) +} + +// CodeChallengeMethodHasPrefix applies the HasPrefix predicate on the "code_challenge_method" field. +func CodeChallengeMethodHasPrefix(v string) predicate.DeviceToken { + return predicate.DeviceToken(func(s *sql.Selector) { + s.Where(sql.HasPrefix(s.C(FieldCodeChallengeMethod), v)) + }) +} + +// CodeChallengeMethodHasSuffix applies the HasSuffix predicate on the "code_challenge_method" field. +func CodeChallengeMethodHasSuffix(v string) predicate.DeviceToken { + return predicate.DeviceToken(func(s *sql.Selector) { + s.Where(sql.HasSuffix(s.C(FieldCodeChallengeMethod), v)) + }) +} + +// CodeChallengeMethodEqualFold applies the EqualFold predicate on the "code_challenge_method" field. +func CodeChallengeMethodEqualFold(v string) predicate.DeviceToken { + return predicate.DeviceToken(func(s *sql.Selector) { + s.Where(sql.EqualFold(s.C(FieldCodeChallengeMethod), v)) + }) +} + +// CodeChallengeMethodContainsFold applies the ContainsFold predicate on the "code_challenge_method" field. +func CodeChallengeMethodContainsFold(v string) predicate.DeviceToken { + return predicate.DeviceToken(func(s *sql.Selector) { + s.Where(sql.ContainsFold(s.C(FieldCodeChallengeMethod), v)) + }) +} + // And groups predicates with the AND operator between them. func And(predicates ...predicate.DeviceToken) predicate.DeviceToken { return predicate.DeviceToken(func(s *sql.Selector) { diff --git a/storage/ent/db/devicetoken_create.go b/storage/ent/db/devicetoken_create.go index 42c86755..023f7185 100644 --- a/storage/ent/db/devicetoken_create.go +++ b/storage/ent/db/devicetoken_create.go @@ -56,6 +56,34 @@ func (dtc *DeviceTokenCreate) SetPollInterval(i int) *DeviceTokenCreate { return dtc } +// SetCodeChallenge sets the "code_challenge" field. +func (dtc *DeviceTokenCreate) SetCodeChallenge(s string) *DeviceTokenCreate { + dtc.mutation.SetCodeChallenge(s) + return dtc +} + +// SetNillableCodeChallenge sets the "code_challenge" field if the given value is not nil. +func (dtc *DeviceTokenCreate) SetNillableCodeChallenge(s *string) *DeviceTokenCreate { + if s != nil { + dtc.SetCodeChallenge(*s) + } + return dtc +} + +// SetCodeChallengeMethod sets the "code_challenge_method" field. +func (dtc *DeviceTokenCreate) SetCodeChallengeMethod(s string) *DeviceTokenCreate { + dtc.mutation.SetCodeChallengeMethod(s) + return dtc +} + +// SetNillableCodeChallengeMethod sets the "code_challenge_method" field if the given value is not nil. +func (dtc *DeviceTokenCreate) SetNillableCodeChallengeMethod(s *string) *DeviceTokenCreate { + if s != nil { + dtc.SetCodeChallengeMethod(*s) + } + return dtc +} + // Mutation returns the DeviceTokenMutation object of the builder. func (dtc *DeviceTokenCreate) Mutation() *DeviceTokenMutation { return dtc.mutation @@ -67,6 +95,7 @@ func (dtc *DeviceTokenCreate) Save(ctx context.Context) (*DeviceToken, error) { err error node *DeviceToken ) + dtc.defaults() if len(dtc.hooks) == 0 { if err = dtc.check(); err != nil { return nil, err @@ -124,6 +153,18 @@ func (dtc *DeviceTokenCreate) ExecX(ctx context.Context) { } } +// defaults sets the default values of the builder before save. +func (dtc *DeviceTokenCreate) defaults() { + if _, ok := dtc.mutation.CodeChallenge(); !ok { + v := devicetoken.DefaultCodeChallenge + dtc.mutation.SetCodeChallenge(v) + } + if _, ok := dtc.mutation.CodeChallengeMethod(); !ok { + v := devicetoken.DefaultCodeChallengeMethod + dtc.mutation.SetCodeChallengeMethod(v) + } +} + // check runs all checks and user-defined validators on the builder. func (dtc *DeviceTokenCreate) check() error { if _, ok := dtc.mutation.DeviceCode(); !ok { @@ -151,6 +192,12 @@ func (dtc *DeviceTokenCreate) check() error { if _, ok := dtc.mutation.PollInterval(); !ok { return &ValidationError{Name: "poll_interval", err: errors.New(`db: missing required field "DeviceToken.poll_interval"`)} } + if _, ok := dtc.mutation.CodeChallenge(); !ok { + return &ValidationError{Name: "code_challenge", err: errors.New(`db: missing required field "DeviceToken.code_challenge"`)} + } + if _, ok := dtc.mutation.CodeChallengeMethod(); !ok { + return &ValidationError{Name: "code_challenge_method", err: errors.New(`db: missing required field "DeviceToken.code_challenge_method"`)} + } return nil } @@ -226,6 +273,22 @@ func (dtc *DeviceTokenCreate) createSpec() (*DeviceToken, *sqlgraph.CreateSpec) }) _node.PollInterval = value } + if value, ok := dtc.mutation.CodeChallenge(); ok { + _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ + Type: field.TypeString, + Value: value, + Column: devicetoken.FieldCodeChallenge, + }) + _node.CodeChallenge = value + } + if value, ok := dtc.mutation.CodeChallengeMethod(); ok { + _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ + Type: field.TypeString, + Value: value, + Column: devicetoken.FieldCodeChallengeMethod, + }) + _node.CodeChallengeMethod = value + } return _node, _spec } @@ -243,6 +306,7 @@ func (dtcb *DeviceTokenCreateBulk) Save(ctx context.Context) ([]*DeviceToken, er for i := range dtcb.builders { func(i int, root context.Context) { builder := dtcb.builders[i] + builder.defaults() var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { mutation, ok := m.(*DeviceTokenMutation) if !ok { diff --git a/storage/ent/db/devicetoken_update.go b/storage/ent/db/devicetoken_update.go index bf9d2993..8f95e7a0 100644 --- a/storage/ent/db/devicetoken_update.go +++ b/storage/ent/db/devicetoken_update.go @@ -77,6 +77,34 @@ func (dtu *DeviceTokenUpdate) AddPollInterval(i int) *DeviceTokenUpdate { return dtu } +// SetCodeChallenge sets the "code_challenge" field. +func (dtu *DeviceTokenUpdate) SetCodeChallenge(s string) *DeviceTokenUpdate { + dtu.mutation.SetCodeChallenge(s) + return dtu +} + +// SetNillableCodeChallenge sets the "code_challenge" field if the given value is not nil. +func (dtu *DeviceTokenUpdate) SetNillableCodeChallenge(s *string) *DeviceTokenUpdate { + if s != nil { + dtu.SetCodeChallenge(*s) + } + return dtu +} + +// SetCodeChallengeMethod sets the "code_challenge_method" field. +func (dtu *DeviceTokenUpdate) SetCodeChallengeMethod(s string) *DeviceTokenUpdate { + dtu.mutation.SetCodeChallengeMethod(s) + return dtu +} + +// SetNillableCodeChallengeMethod sets the "code_challenge_method" field if the given value is not nil. +func (dtu *DeviceTokenUpdate) SetNillableCodeChallengeMethod(s *string) *DeviceTokenUpdate { + if s != nil { + dtu.SetCodeChallengeMethod(*s) + } + return dtu +} + // Mutation returns the DeviceTokenMutation object of the builder. func (dtu *DeviceTokenUpdate) Mutation() *DeviceTokenMutation { return dtu.mutation @@ -230,6 +258,20 @@ func (dtu *DeviceTokenUpdate) sqlSave(ctx context.Context) (n int, err error) { Column: devicetoken.FieldPollInterval, }) } + if value, ok := dtu.mutation.CodeChallenge(); ok { + _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ + Type: field.TypeString, + Value: value, + Column: devicetoken.FieldCodeChallenge, + }) + } + if value, ok := dtu.mutation.CodeChallengeMethod(); ok { + _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ + Type: field.TypeString, + Value: value, + Column: devicetoken.FieldCodeChallengeMethod, + }) + } if n, err = sqlgraph.UpdateNodes(ctx, dtu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{devicetoken.Label} @@ -298,6 +340,34 @@ func (dtuo *DeviceTokenUpdateOne) AddPollInterval(i int) *DeviceTokenUpdateOne { return dtuo } +// SetCodeChallenge sets the "code_challenge" field. +func (dtuo *DeviceTokenUpdateOne) SetCodeChallenge(s string) *DeviceTokenUpdateOne { + dtuo.mutation.SetCodeChallenge(s) + return dtuo +} + +// SetNillableCodeChallenge sets the "code_challenge" field if the given value is not nil. +func (dtuo *DeviceTokenUpdateOne) SetNillableCodeChallenge(s *string) *DeviceTokenUpdateOne { + if s != nil { + dtuo.SetCodeChallenge(*s) + } + return dtuo +} + +// SetCodeChallengeMethod sets the "code_challenge_method" field. +func (dtuo *DeviceTokenUpdateOne) SetCodeChallengeMethod(s string) *DeviceTokenUpdateOne { + dtuo.mutation.SetCodeChallengeMethod(s) + return dtuo +} + +// SetNillableCodeChallengeMethod sets the "code_challenge_method" field if the given value is not nil. +func (dtuo *DeviceTokenUpdateOne) SetNillableCodeChallengeMethod(s *string) *DeviceTokenUpdateOne { + if s != nil { + dtuo.SetCodeChallengeMethod(*s) + } + return dtuo +} + // Mutation returns the DeviceTokenMutation object of the builder. func (dtuo *DeviceTokenUpdateOne) Mutation() *DeviceTokenMutation { return dtuo.mutation @@ -475,6 +545,20 @@ func (dtuo *DeviceTokenUpdateOne) sqlSave(ctx context.Context) (_node *DeviceTok Column: devicetoken.FieldPollInterval, }) } + if value, ok := dtuo.mutation.CodeChallenge(); ok { + _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ + Type: field.TypeString, + Value: value, + Column: devicetoken.FieldCodeChallenge, + }) + } + if value, ok := dtuo.mutation.CodeChallengeMethod(); ok { + _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ + Type: field.TypeString, + Value: value, + Column: devicetoken.FieldCodeChallengeMethod, + }) + } _node = &DeviceToken{config: dtuo.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/storage/ent/db/migrate/schema.go b/storage/ent/db/migrate/schema.go index d8b8b62c..c31e0a1e 100644 --- a/storage/ent/db/migrate/schema.go +++ b/storage/ent/db/migrate/schema.go @@ -101,6 +101,8 @@ var ( {Name: "expiry", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}}, {Name: "last_request", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}}, {Name: "poll_interval", Type: field.TypeInt}, + {Name: "code_challenge", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "code_challenge_method", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, } // DeviceTokensTable holds the schema information for the "device_tokens" table. DeviceTokensTable = &schema.Table{ diff --git a/storage/ent/db/mutation.go b/storage/ent/db/mutation.go index d7d4e423..53a4a9fd 100644 --- a/storage/ent/db/mutation.go +++ b/storage/ent/db/mutation.go @@ -3633,20 +3633,22 @@ func (m *DeviceRequestMutation) ResetEdge(name string) error { // DeviceTokenMutation represents an operation that mutates the DeviceToken nodes in the graph. type DeviceTokenMutation struct { config - op Op - typ string - id *int - device_code *string - status *string - token *[]byte - expiry *time.Time - last_request *time.Time - poll_interval *int - addpoll_interval *int - clearedFields map[string]struct{} - done bool - oldValue func(context.Context) (*DeviceToken, error) - predicates []predicate.DeviceToken + op Op + typ string + id *int + device_code *string + status *string + token *[]byte + expiry *time.Time + last_request *time.Time + poll_interval *int + addpoll_interval *int + code_challenge *string + code_challenge_method *string + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*DeviceToken, error) + predicates []predicate.DeviceToken } var _ ent.Mutation = (*DeviceTokenMutation)(nil) @@ -3996,6 +3998,78 @@ func (m *DeviceTokenMutation) ResetPollInterval() { m.addpoll_interval = nil } +// SetCodeChallenge sets the "code_challenge" field. +func (m *DeviceTokenMutation) SetCodeChallenge(s string) { + m.code_challenge = &s +} + +// CodeChallenge returns the value of the "code_challenge" field in the mutation. +func (m *DeviceTokenMutation) CodeChallenge() (r string, exists bool) { + v := m.code_challenge + if v == nil { + return + } + return *v, true +} + +// OldCodeChallenge returns the old "code_challenge" field's value of the DeviceToken entity. +// If the DeviceToken object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *DeviceTokenMutation) OldCodeChallenge(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCodeChallenge is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCodeChallenge requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCodeChallenge: %w", err) + } + return oldValue.CodeChallenge, nil +} + +// ResetCodeChallenge resets all changes to the "code_challenge" field. +func (m *DeviceTokenMutation) ResetCodeChallenge() { + m.code_challenge = nil +} + +// SetCodeChallengeMethod sets the "code_challenge_method" field. +func (m *DeviceTokenMutation) SetCodeChallengeMethod(s string) { + m.code_challenge_method = &s +} + +// CodeChallengeMethod returns the value of the "code_challenge_method" field in the mutation. +func (m *DeviceTokenMutation) CodeChallengeMethod() (r string, exists bool) { + v := m.code_challenge_method + if v == nil { + return + } + return *v, true +} + +// OldCodeChallengeMethod returns the old "code_challenge_method" field's value of the DeviceToken entity. +// If the DeviceToken object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *DeviceTokenMutation) OldCodeChallengeMethod(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCodeChallengeMethod is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCodeChallengeMethod requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCodeChallengeMethod: %w", err) + } + return oldValue.CodeChallengeMethod, nil +} + +// ResetCodeChallengeMethod resets all changes to the "code_challenge_method" field. +func (m *DeviceTokenMutation) ResetCodeChallengeMethod() { + m.code_challenge_method = nil +} + // Where appends a list predicates to the DeviceTokenMutation builder. func (m *DeviceTokenMutation) Where(ps ...predicate.DeviceToken) { m.predicates = append(m.predicates, ps...) @@ -4015,7 +4089,7 @@ func (m *DeviceTokenMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *DeviceTokenMutation) Fields() []string { - fields := make([]string, 0, 6) + fields := make([]string, 0, 8) if m.device_code != nil { fields = append(fields, devicetoken.FieldDeviceCode) } @@ -4034,6 +4108,12 @@ func (m *DeviceTokenMutation) Fields() []string { if m.poll_interval != nil { fields = append(fields, devicetoken.FieldPollInterval) } + if m.code_challenge != nil { + fields = append(fields, devicetoken.FieldCodeChallenge) + } + if m.code_challenge_method != nil { + fields = append(fields, devicetoken.FieldCodeChallengeMethod) + } return fields } @@ -4054,6 +4134,10 @@ func (m *DeviceTokenMutation) Field(name string) (ent.Value, bool) { return m.LastRequest() case devicetoken.FieldPollInterval: return m.PollInterval() + case devicetoken.FieldCodeChallenge: + return m.CodeChallenge() + case devicetoken.FieldCodeChallengeMethod: + return m.CodeChallengeMethod() } return nil, false } @@ -4075,6 +4159,10 @@ func (m *DeviceTokenMutation) OldField(ctx context.Context, name string) (ent.Va return m.OldLastRequest(ctx) case devicetoken.FieldPollInterval: return m.OldPollInterval(ctx) + case devicetoken.FieldCodeChallenge: + return m.OldCodeChallenge(ctx) + case devicetoken.FieldCodeChallengeMethod: + return m.OldCodeChallengeMethod(ctx) } return nil, fmt.Errorf("unknown DeviceToken field %s", name) } @@ -4126,6 +4214,20 @@ func (m *DeviceTokenMutation) SetField(name string, value ent.Value) error { } m.SetPollInterval(v) return nil + case devicetoken.FieldCodeChallenge: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCodeChallenge(v) + return nil + case devicetoken.FieldCodeChallengeMethod: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCodeChallengeMethod(v) + return nil } return fmt.Errorf("unknown DeviceToken field %s", name) } @@ -4217,6 +4319,12 @@ func (m *DeviceTokenMutation) ResetField(name string) error { case devicetoken.FieldPollInterval: m.ResetPollInterval() return nil + case devicetoken.FieldCodeChallenge: + m.ResetCodeChallenge() + return nil + case devicetoken.FieldCodeChallengeMethod: + m.ResetCodeChallengeMethod() + return nil } return fmt.Errorf("unknown DeviceToken field %s", name) } diff --git a/storage/ent/db/runtime.go b/storage/ent/db/runtime.go index d3123b3f..4be5fbfb 100644 --- a/storage/ent/db/runtime.go +++ b/storage/ent/db/runtime.go @@ -142,6 +142,14 @@ func init() { devicetokenDescStatus := devicetokenFields[1].Descriptor() // devicetoken.StatusValidator is a validator for the "status" field. It is called by the builders before save. devicetoken.StatusValidator = devicetokenDescStatus.Validators[0].(func(string) error) + // devicetokenDescCodeChallenge is the schema descriptor for code_challenge field. + devicetokenDescCodeChallenge := devicetokenFields[6].Descriptor() + // devicetoken.DefaultCodeChallenge holds the default value on creation for the code_challenge field. + devicetoken.DefaultCodeChallenge = devicetokenDescCodeChallenge.Default.(string) + // devicetokenDescCodeChallengeMethod is the schema descriptor for code_challenge_method field. + devicetokenDescCodeChallengeMethod := devicetokenFields[7].Descriptor() + // devicetoken.DefaultCodeChallengeMethod holds the default value on creation for the code_challenge_method field. + devicetoken.DefaultCodeChallengeMethod = devicetokenDescCodeChallengeMethod.Default.(string) keysFields := schema.Keys{}.Fields() _ = keysFields // keysDescID is the schema descriptor for id field. diff --git a/storage/ent/schema/devicetoken.go b/storage/ent/schema/devicetoken.go index 29927e2b..dc0e7b8e 100644 --- a/storage/ent/schema/devicetoken.go +++ b/storage/ent/schema/devicetoken.go @@ -8,12 +8,14 @@ import ( /* Original SQL table: create table device_token ( - device_code text not null primary key, - status text not null, - token blob, - expiry timestamp not null, - last_request timestamp not null, - poll_interval integer not null + device_code text not null primary key, + status text not null, + token blob, + expiry timestamp not null, + last_request timestamp not null, + poll_interval integer not null, + code_challenge text default '' not null, + code_challenge_method text default '' not null ); */ @@ -38,6 +40,12 @@ func (DeviceToken) Fields() []ent.Field { field.Time("last_request"). SchemaType(timeSchema), field.Int("poll_interval"), + field.Text("code_challenge"). + SchemaType(textSchema). + Default(""), + field.Text("code_challenge_method"). + SchemaType(textSchema). + Default(""), } } diff --git a/storage/etcd/etcd.go b/storage/etcd/etcd.go index 63fa7bc2..13e815ec 100644 --- a/storage/etcd/etcd.go +++ b/storage/etcd/etcd.go @@ -605,8 +605,11 @@ func (c *conn) CreateDeviceToken(t storage.DeviceToken) error { func (c *conn) GetDeviceToken(deviceCode string) (t storage.DeviceToken, err error) { ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) defer cancel() - err = c.getKey(ctx, keyID(deviceTokenPrefix, deviceCode), &t) - return t, err + var dt DeviceToken + if err = c.getKey(ctx, keyID(deviceTokenPrefix, deviceCode), &dt); err == nil { + t = toStorageDeviceToken(dt) + } + return } func (c *conn) listDeviceTokens(ctx context.Context) (deviceTokens []DeviceToken, err error) { diff --git a/storage/etcd/types.go b/storage/etcd/types.go index 9390608a..1174a2d2 100644 --- a/storage/etcd/types.go +++ b/storage/etcd/types.go @@ -281,6 +281,8 @@ type DeviceToken struct { Expiry time.Time `json:"expiry"` LastRequestTime time.Time `json:"last_request"` PollIntervalSeconds int `json:"poll_interval"` + CodeChallenge string `json:"code_challenge,omitempty"` + CodeChallengeMethod string `json:"code_challenge_method,omitempty"` } func fromStorageDeviceToken(t storage.DeviceToken) DeviceToken { @@ -291,6 +293,8 @@ func fromStorageDeviceToken(t storage.DeviceToken) DeviceToken { Expiry: t.Expiry, LastRequestTime: t.LastRequestTime, PollIntervalSeconds: t.PollIntervalSeconds, + CodeChallenge: t.PKCE.CodeChallenge, + CodeChallengeMethod: t.PKCE.CodeChallengeMethod, } } @@ -302,5 +306,9 @@ func toStorageDeviceToken(t DeviceToken) storage.DeviceToken { Expiry: t.Expiry, LastRequestTime: t.LastRequestTime, PollIntervalSeconds: t.PollIntervalSeconds, + PKCE: storage.PKCE{ + CodeChallenge: t.CodeChallenge, + CodeChallengeMethod: t.CodeChallengeMethod, + }, } } diff --git a/storage/kubernetes/types.go b/storage/kubernetes/types.go index faf4ac57..5149e3ee 100644 --- a/storage/kubernetes/types.go +++ b/storage/kubernetes/types.go @@ -802,6 +802,8 @@ type DeviceToken struct { Expiry time.Time `json:"expiry"` LastRequestTime time.Time `json:"last_request"` PollIntervalSeconds int `json:"poll_interval"` + CodeChallenge string `json:"code_challenge,omitempty"` + CodeChallengeMethod string `json:"code_challenge_method,omitempty"` } // DeviceTokenList is a list of DeviceTokens. @@ -826,6 +828,8 @@ func (cli *client) fromStorageDeviceToken(t storage.DeviceToken) DeviceToken { Expiry: t.Expiry, LastRequestTime: t.LastRequestTime, PollIntervalSeconds: t.PollIntervalSeconds, + CodeChallenge: t.PKCE.CodeChallenge, + CodeChallengeMethod: t.PKCE.CodeChallengeMethod, } return req } @@ -838,5 +842,9 @@ func toStorageDeviceToken(t DeviceToken) storage.DeviceToken { Expiry: t.Expiry, LastRequestTime: t.LastRequestTime, PollIntervalSeconds: t.PollIntervalSeconds, + PKCE: storage.PKCE{ + CodeChallenge: t.CodeChallenge, + CodeChallengeMethod: t.CodeChallengeMethod, + }, } } diff --git a/storage/sql/crud.go b/storage/sql/crud.go index 5a234f9d..ac67bf28 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -927,12 +927,12 @@ func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error { func (c *conn) CreateDeviceToken(t storage.DeviceToken) error { _, err := c.Exec(` insert into device_token ( - device_code, status, token, expiry, last_request, poll_interval + device_code, status, token, expiry, last_request, poll_interval, code_challenge, code_challenge_method ) values ( - $1, $2, $3, $4, $5, $6 + $1, $2, $3, $4, $5, $6, $7, $8 );`, - t.DeviceCode, t.Status, t.Token, t.Expiry, t.LastRequestTime, t.PollIntervalSeconds, + t.DeviceCode, t.Status, t.Token, t.Expiry, t.LastRequestTime, t.PollIntervalSeconds, t.PKCE.CodeChallenge, t.PKCE.CodeChallengeMethod, ) if err != nil { if c.alreadyExistsCheck(err) { @@ -972,10 +972,10 @@ func (c *conn) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) { func getDeviceToken(q querier, deviceCode string) (a storage.DeviceToken, err error) { err = q.QueryRow(` select - status, token, expiry, last_request, poll_interval + status, token, expiry, last_request, poll_interval, code_challenge, code_challenge_method from device_token where device_code = $1; `, deviceCode).Scan( - &a.Status, &a.Token, &a.Expiry, &a.LastRequestTime, &a.PollIntervalSeconds, + &a.Status, &a.Token, &a.Expiry, &a.LastRequestTime, &a.PollIntervalSeconds, &a.PKCE.CodeChallenge, &a.PKCE.CodeChallengeMethod, ) if err != nil { if err == sql.ErrNoRows { @@ -1002,11 +1002,13 @@ func (c *conn) UpdateDeviceToken(deviceCode string, updater func(old storage.Dev status = $1, token = $2, last_request = $3, - poll_interval = $4 + poll_interval = $4, + code_challenge = $5, + code_challenge_method = $6 where - device_code = $5 + device_code = $7 `, - r.Status, r.Token, r.LastRequestTime, r.PollIntervalSeconds, r.DeviceCode, + r.Status, r.Token, r.LastRequestTime, r.PollIntervalSeconds, r.PKCE.CodeChallenge, r.PKCE.CodeChallengeMethod, r.DeviceCode, ) if err != nil { return fmt.Errorf("update device token: %v", err) diff --git a/storage/sql/migrate.go b/storage/sql/migrate.go index 498db252..57720e17 100644 --- a/storage/sql/migrate.go +++ b/storage/sql/migrate.go @@ -281,4 +281,14 @@ var migrations = []migration{ add column obsolete_token text default '';`, }, }, + { + stmts: []string{ + ` + alter table device_token + add column code_challenge text not null default '';`, + ` + alter table device_token + add column code_challenge_method text not null default '';`, + }, + }, } diff --git a/storage/storage.go b/storage/storage.go index cdd83ca6..af39228a 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -427,4 +427,5 @@ type DeviceToken struct { Expiry time.Time LastRequestTime time.Time PollIntervalSeconds int + PKCE PKCE }