diff --git a/Makefile b/Makefile index 8006982a..9519a186 100644 --- a/Makefile +++ b/Makefile @@ -55,7 +55,7 @@ fmt: @go fmt $(shell go list ./... | grep -v '/vendor/') lint: - @for package in $(shell go list ./... | grep -v '/vendor/' | grep -v '/api'); do \ + @for package in $(shell go list ./... | grep -v '/vendor/' | grep -v '/api' | grep -v '/server/internal'); do \ golint -set_exit_status $$package $$i || exit 1; \ done @@ -81,12 +81,15 @@ aci: clean-release _output/bin/dex _output/images/library-alpine-3.4.aci docker-image: clean-release _output/bin/dex @sudo docker build -t $(DOCKER_IMAGE) . -.PHONY: grpc -grpc: api/api.pb.go +.PHONY: proto +proto: api/api.pb.go server/internal/types.pb.go api/api.pb.go: api/api.proto bin/protoc bin/protoc-gen-go @protoc --go_out=plugins=grpc:. api/*.proto +server/internal/types.pb.go: server/internal/types.proto bin/protoc bin/protoc-gen-go + @protoc --go_out=. server/internal/*.proto + bin/protoc: scripts/get-protoc @./scripts/get-protoc bin/protoc diff --git a/cmd/example-app/main.go b/cmd/example-app/main.go index ffa21c29..3ec34e38 100644 --- a/cmd/example-app/main.go +++ b/cmd/example-app/main.go @@ -241,7 +241,7 @@ func (a *app) handleLogin(w http.ResponseWriter, r *http.Request) { authCodeURL := "" scopes = append(scopes, "openid", "profile", "email") - if r.FormValue("offline_acecss") != "yes" { + if r.FormValue("offline_access") != "yes" { authCodeURL = a.oauth2Config(scopes).AuthCodeURL(exampleAppState) } else if a.offlineAsScope { scopes = append(scopes, "offline_access") @@ -254,34 +254,42 @@ func (a *app) handleLogin(w http.ResponseWriter, r *http.Request) { } func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) { - if errMsg := r.FormValue("error"); errMsg != "" { - http.Error(w, errMsg+": "+r.FormValue("error_description"), http.StatusBadRequest) - return - } - - if state := r.FormValue("state"); state != exampleAppState { - http.Error(w, fmt.Sprintf("expected state %q got %q", exampleAppState, state), http.StatusBadRequest) - return - } - - code := r.FormValue("code") - refresh := r.FormValue("refresh_token") var ( err error token *oauth2.Token ) oauth2Config := a.oauth2Config(nil) - switch { - case code != "": + switch r.Method { + case "GET": + // Authorization redirect callback from OAuth2 auth flow. + if errMsg := r.FormValue("error"); errMsg != "" { + http.Error(w, errMsg+": "+r.FormValue("error_description"), http.StatusBadRequest) + return + } + code := r.FormValue("code") + if code == "" { + http.Error(w, fmt.Sprintf("no code in request: %q", r.Form), http.StatusBadRequest) + return + } + if state := r.FormValue("state"); state != exampleAppState { + http.Error(w, fmt.Sprintf("expected state %q got %q", exampleAppState, state), http.StatusBadRequest) + return + } token, err = oauth2Config.Exchange(a.ctx, code) - case refresh != "": + case "POST": + // Form request from frontend to refresh a token. + refresh := r.FormValue("refresh_token") + if refresh == "" { + http.Error(w, fmt.Sprintf("no refresh_token in request: %q", r.Form), http.StatusBadRequest) + return + } t := &oauth2.Token{ RefreshToken: refresh, Expiry: time.Now().Add(-time.Hour), } token, err = oauth2Config.TokenSource(r.Context(), t).Token() default: - http.Error(w, fmt.Sprintf("no code in request: %q", r.Form), http.StatusBadRequest) + http.Error(w, fmt.Sprintf("method not implemented: %s", r.Method), http.StatusBadRequest) return } diff --git a/cmd/example-app/templates.go b/cmd/example-app/templates.go index c0f9dfbd..a870d0f0 100644 --- a/cmd/example-app/templates.go +++ b/cmd/example-app/templates.go @@ -8,7 +8,7 @@ import ( var indexTmpl = template.Must(template.New("index.html").Parse(`
- + {{ end }} `)) diff --git a/server/handlers.go b/server/handlers.go index 808f8031..ef264dfe 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -2,6 +2,7 @@ package server import ( "encoding/json" + "errors" "fmt" "net/http" "net/url" @@ -16,6 +17,7 @@ import ( jose "gopkg.in/square/go-jose.v2" "github.com/coreos/dex/connector" + "github.com/coreos/dex/server/internal" "github.com/coreos/dex/storage" ) @@ -645,20 +647,32 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s var refreshToken string if reqRefresh { refresh := storage.RefreshToken{ - RefreshToken: storage.NewID(), + ID: storage.NewID(), + Token: storage.NewID(), ClientID: authCode.ClientID, ConnectorID: authCode.ConnectorID, Scopes: authCode.Scopes, Claims: authCode.Claims, Nonce: authCode.Nonce, ConnectorData: authCode.ConnectorData, + CreatedAt: s.now(), + LastUsed: s.now(), } + token := &internal.RefreshToken{ + RefreshId: refresh.ID, + Token: refresh.Token, + } + if refreshToken, err = internal.Marshal(token); err != nil { + s.logger.Errorf("failed to marshal refresh token: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + return + } + if err := s.storage.CreateRefresh(refresh); err != nil { s.logger.Errorf("failed to create refresh token: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return } - refreshToken = refresh.RefreshToken } s.writeAccessToken(w, idToken, refreshToken, expiry) } @@ -672,16 +686,37 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie return } - refresh, err := s.storage.GetRefresh(code) - if err != nil || refresh.ClientID != client.ID { - if err != storage.ErrNotFound { - s.logger.Errorf("failed to get auth code: %v", err) - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - } else { + token := new(internal.RefreshToken) + if err := internal.Unmarshal(code, token); err != nil { + // For backward compatibility, assume the refresh_token is a raw refresh token ID + // if it fails to decode. + // + // Because refresh_token values that aren't unmarshable were generated by servers + // that don't have a Token value, we'll still reject any attempts to claim a + // refresh_token twice. + token = &internal.RefreshToken{RefreshId: code, Token: ""} + } + + refresh, err := s.storage.GetRefresh(token.RefreshId) + if err != nil { + s.logger.Errorf("failed to get refresh token: %v", err) + if err == storage.ErrNotFound { s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest) + } else { + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) } return } + if refresh.ClientID != client.ID { + s.logger.Errorf("client %s trying to claim token for client %s", client.ID, refresh.ClientID) + s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest) + return + } + if refresh.Token != token.Token { + s.logger.Errorf("refresh token with id %s claimed twice", refresh.ID) + s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest) + return + } // Per the OAuth2 spec, if the client has omitted the scopes, default to the original // authorized scopes. @@ -720,6 +755,14 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return } + ident := connector.Identity{ + UserID: refresh.Claims.UserID, + Username: refresh.Claims.Username, + Email: refresh.Claims.Email, + EmailVerified: refresh.Claims.EmailVerified, + Groups: refresh.Claims.Groups, + ConnectorData: refresh.ConnectorData, + } // Can the connector refresh the identity? If so, attempt to refresh the data // in the connector. @@ -727,52 +770,63 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie // TODO(ericchiang): We may want a strict mode where connectors that don't implement // this interface can't perform refreshing. if refreshConn, ok := conn.Connector.(connector.RefreshConnector); ok { - ident := connector.Identity{ - UserID: refresh.Claims.UserID, - Username: refresh.Claims.Username, - Email: refresh.Claims.Email, - EmailVerified: refresh.Claims.EmailVerified, - Groups: refresh.Claims.Groups, - ConnectorData: refresh.ConnectorData, - } - ident, err := refreshConn.Refresh(r.Context(), parseScopes(scopes), ident) + newIdent, err := refreshConn.Refresh(r.Context(), parseScopes(scopes), ident) if err != nil { s.logger.Errorf("failed to refresh identity: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return } - - // Update the claims of the refresh token. - // - // UserID intentionally ignored for now. - refresh.Claims.Username = ident.Username - refresh.Claims.Email = ident.Email - refresh.Claims.EmailVerified = ident.EmailVerified - refresh.Claims.Groups = ident.Groups - refresh.ConnectorData = ident.ConnectorData + ident = newIdent } - idToken, expiry, err := s.newIDToken(client.ID, refresh.Claims, scopes, refresh.Nonce) + claims := storage.Claims{ + UserID: ident.UserID, + Username: ident.Username, + Email: ident.Email, + EmailVerified: ident.EmailVerified, + Groups: ident.Groups, + } + + idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce) if err != nil { s.logger.Errorf("failed to create ID token: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return } - // Refresh tokens are claimed exactly once. Delete the current token and - // create a new one. - if err := s.storage.DeleteRefresh(code); err != nil { - s.logger.Errorf("failed to delete auth code: %v", err) + newToken := &internal.RefreshToken{ + RefreshId: refresh.ID, + Token: storage.NewID(), + } + rawNewToken, err := internal.Marshal(newToken) + if err != nil { + s.logger.Errorf("failed to marshal refresh token: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return } - refresh.RefreshToken = storage.NewID() - if err := s.storage.CreateRefresh(refresh); err != nil { - s.logger.Errorf("failed to create refresh token: %v", err) + + updater := func(old storage.RefreshToken) (storage.RefreshToken, error) { + if old.Token != refresh.Token { + return old, errors.New("refresh token claimed twice") + } + old.Token = newToken.Token + // Update the claims of the refresh token. + // + // UserID intentionally ignored for now. + old.Claims.Username = ident.Username + old.Claims.Email = ident.Email + old.Claims.EmailVerified = ident.EmailVerified + old.Claims.Groups = ident.Groups + old.ConnectorData = ident.ConnectorData + old.LastUsed = s.now() + return old, nil + } + if err := s.storage.UpdateRefreshToken(refresh.ID, updater); err != nil { + s.logger.Errorf("failed to update refresh token: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return } - s.writeAccessToken(w, idToken, refresh.RefreshToken, expiry) + s.writeAccessToken(w, idToken, rawNewToken, expiry) } func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, refreshToken string, expiry time.Time) { diff --git a/server/internal/codec.go b/server/internal/codec.go new file mode 100644 index 00000000..a92c26f9 --- /dev/null +++ b/server/internal/codec.go @@ -0,0 +1,25 @@ +package internal + +import ( + "encoding/base64" + + "github.com/golang/protobuf/proto" +) + +// Marshal converts a protobuf message to a URL legal string. +func Marshal(message proto.Message) (string, error) { + data, err := proto.Marshal(message) + if err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(data), nil +} + +// Unmarshal decodes a protobuf message. +func Unmarshal(s string, message proto.Message) error { + data, err := base64.RawURLEncoding.DecodeString(s) + if err != nil { + return err + } + return proto.Unmarshal(data, message) +} diff --git a/server/internal/types.pb.go b/server/internal/types.pb.go new file mode 100644 index 00000000..791944f5 --- /dev/null +++ b/server/internal/types.pb.go @@ -0,0 +1,59 @@ +// Code generated by protoc-gen-go. +// source: server/internal/types.proto +// DO NOT EDIT! + +/* +Package internal is a generated protocol buffer package. + +Package internal holds protobuf types used by the server + +It is generated from these files: + server/internal/types.proto + +It has these top-level messages: + RefreshToken +*/ +package internal + +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +// RefreshToken is a message that holds refresh token data used by dex. +type RefreshToken struct { + RefreshId string `protobuf:"bytes,1,opt,name=refresh_id,json=refreshId" json:"refresh_id,omitempty"` + Token string `protobuf:"bytes,2,opt,name=token" json:"token,omitempty"` +} + +func (m *RefreshToken) Reset() { *m = RefreshToken{} } +func (m *RefreshToken) String() string { return proto.CompactTextString(m) } +func (*RefreshToken) ProtoMessage() {} +func (*RefreshToken) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} } + +func init() { + proto.RegisterType((*RefreshToken)(nil), "internal.RefreshToken") +} + +func init() { proto.RegisterFile("server/internal/types.proto", fileDescriptor0) } + +var fileDescriptor0 = []byte{ + // 112 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0x92, 0x2e, 0x4e, 0x2d, 0x2a, + 0x4b, 0x2d, 0xd2, 0xcf, 0xcc, 0x2b, 0x49, 0x2d, 0xca, 0x4b, 0xcc, 0xd1, 0x2f, 0xa9, 0x2c, 0x48, + 0x2d, 0xd6, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x80, 0x89, 0x2a, 0x39, 0x73, 0xf1, 0x04, + 0xa5, 0xa6, 0x15, 0xa5, 0x16, 0x67, 0x84, 0xe4, 0x67, 0xa7, 0xe6, 0x09, 0xc9, 0x72, 0x71, 0x15, + 0x41, 0xf8, 0xf1, 0x99, 0x29, 0x12, 0x8c, 0x0a, 0x8c, 0x1a, 0x9c, 0x41, 0x9c, 0x50, 0x11, 0xcf, + 0x14, 0x21, 0x11, 0x2e, 0xd6, 0x12, 0x90, 0x3a, 0x09, 0x26, 0xb0, 0x0c, 0x84, 0x93, 0xc4, 0x06, + 0x36, 0xd5, 0x18, 0x10, 0x00, 0x00, 0xff, 0xff, 0x9b, 0xd0, 0x5a, 0x1d, 0x74, 0x00, 0x00, 0x00, +} diff --git a/server/internal/types.proto b/server/internal/types.proto new file mode 100644 index 00000000..442dbd95 --- /dev/null +++ b/server/internal/types.proto @@ -0,0 +1,10 @@ +syntax = "proto3"; + +// Package internal holds protobuf types used by the server +package internal; + +// RefreshToken is a message that holds refresh token data used by dex. +message RefreshToken { + string refresh_id = 1; + string token = 2; +} diff --git a/server/server_test.go b/server/server_test.go index 7c499c15..d848076f 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -237,6 +237,10 @@ func TestOAuth2CodeFlow(t *testing.T) { if token.RefreshToken == newToken.RefreshToken { return fmt.Errorf("old refresh token was the same as the new token %q", token.RefreshToken) } + + if _, err := config.TokenSource(ctx, token).Token(); err == nil { + return errors.New("was able to redeem the same refresh token twice") + } return nil }, }, diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index 8cb911aa..0a6fe1c9 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -208,10 +208,14 @@ func testClientCRUD(t *testing.T, s storage.Storage) { func testRefreshTokenCRUD(t *testing.T, s storage.Storage) { id := storage.NewID() refresh := storage.RefreshToken{ - RefreshToken: id, - ClientID: "client_id", - ConnectorID: "client_secret", - Scopes: []string{"openid", "email", "profile"}, + ID: id, + Token: "bar", + Nonce: "foo", + ClientID: "client_id", + ConnectorID: "client_secret", + Scopes: []string{"openid", "email", "profile"}, + CreatedAt: time.Now().UTC().Round(time.Millisecond), + LastUsed: time.Now().UTC().Round(time.Millisecond), Claims: storage.Claims{ UserID: "1", Username: "jane", @@ -238,6 +242,20 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) { getAndCompare(id, refresh) + updatedAt := time.Now().UTC().Round(time.Millisecond) + + updater := func(r storage.RefreshToken) (storage.RefreshToken, error) { + r.Token = "spam" + r.LastUsed = updatedAt + return r, nil + } + if err := s.UpdateRefreshToken(id, updater); err != nil { + t.Errorf("failed to udpate refresh token: %v", err) + } + refresh.Token = "spam" + refresh.LastUsed = updatedAt + getAndCompare(id, refresh) + if err := s.DeleteRefresh(id); err != nil { t.Fatalf("failed to delete refresh request: %v", err) } diff --git a/storage/kubernetes/storage.go b/storage/kubernetes/storage.go index e744ab2d..102a7494 100644 --- a/storage/kubernetes/storage.go +++ b/storage/kubernetes/storage.go @@ -153,23 +153,7 @@ func (cli *client) CreatePassword(p storage.Password) error { } func (cli *client) CreateRefresh(r storage.RefreshToken) error { - refresh := RefreshToken{ - TypeMeta: k8sapi.TypeMeta{ - Kind: kindRefreshToken, - APIVersion: cli.apiVersion, - }, - ObjectMeta: k8sapi.ObjectMeta{ - Name: r.RefreshToken, - Namespace: cli.namespace, - }, - ClientID: r.ClientID, - ConnectorID: r.ConnectorID, - Scopes: r.Scopes, - Nonce: r.Nonce, - Claims: fromStorageClaims(r.Claims), - ConnectorData: r.ConnectorData, - } - return cli.post(resourceRefreshToken, refresh) + return cli.post(resourceRefreshToken, cli.fromStorageRefreshToken(r)) } func (cli *client) GetAuthRequest(id string) (storage.AuthRequest, error) { @@ -239,19 +223,16 @@ func (cli *client) GetKeys() (storage.Keys, error) { } func (cli *client) GetRefresh(id string) (storage.RefreshToken, error) { - var r RefreshToken - if err := cli.get(resourceRefreshToken, id, &r); err != nil { + r, err := cli.getRefreshToken(id) + if err != nil { return storage.RefreshToken{}, err } - return storage.RefreshToken{ - RefreshToken: r.ObjectMeta.Name, - ClientID: r.ClientID, - ConnectorID: r.ConnectorID, - Scopes: r.Scopes, - Nonce: r.Nonce, - Claims: toStorageClaims(r.Claims), - ConnectorData: r.ConnectorData, - }, nil + return toStorageRefreshToken(r), nil +} + +func (cli *client) getRefreshToken(id string) (r RefreshToken, err error) { + err = cli.get(resourceRefreshToken, id, &r) + return } func (cli *client) ListClients() ([]storage.Client, error) { @@ -311,6 +292,22 @@ func (cli *client) DeletePassword(email string) error { return cli.delete(resourcePassword, p.ObjectMeta.Name) } +func (cli *client) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error { + r, err := cli.getRefreshToken(id) + if err != nil { + return err + } + updated, err := updater(toStorageRefreshToken(r)) + if err != nil { + return err + } + updated.ID = id + + newToken := cli.fromStorageRefreshToken(updated) + newToken.ObjectMeta = r.ObjectMeta + return cli.put(resourceRefreshToken, r.ObjectMeta.Name, newToken) +} + func (cli *client) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error { c, err := cli.getClient(id) if err != nil { diff --git a/storage/kubernetes/types.go b/storage/kubernetes/types.go index 9009c800..660f86d8 100644 --- a/storage/kubernetes/types.go +++ b/storage/kubernetes/types.go @@ -362,9 +362,14 @@ type RefreshToken struct { k8sapi.TypeMeta `json:",inline"` k8sapi.ObjectMeta `json:"metadata,omitempty"` + CreatedAt time.Time + LastUsed time.Time + ClientID string `json:"clientID"` Scopes []string `json:"scopes,omitempty"` + Token string `json:"token,omitempty"` + Nonce string `json:"nonce,omitempty"` Claims Claims `json:"claims,omitempty"` @@ -379,6 +384,43 @@ type RefreshList struct { RefreshTokens []RefreshToken `json:"items"` } +func toStorageRefreshToken(r RefreshToken) storage.RefreshToken { + return storage.RefreshToken{ + ID: r.ObjectMeta.Name, + Token: r.Token, + CreatedAt: r.CreatedAt, + LastUsed: r.LastUsed, + ClientID: r.ClientID, + ConnectorID: r.ConnectorID, + ConnectorData: r.ConnectorData, + Scopes: r.Scopes, + Nonce: r.Nonce, + Claims: toStorageClaims(r.Claims), + } +} + +func (cli *client) fromStorageRefreshToken(r storage.RefreshToken) RefreshToken { + return RefreshToken{ + TypeMeta: k8sapi.TypeMeta{ + Kind: kindRefreshToken, + APIVersion: cli.apiVersion, + }, + ObjectMeta: k8sapi.ObjectMeta{ + Name: r.ID, + Namespace: cli.namespace, + }, + Token: r.Token, + CreatedAt: r.CreatedAt, + LastUsed: r.LastUsed, + ClientID: r.ClientID, + ConnectorID: r.ConnectorID, + ConnectorData: r.ConnectorData, + Scopes: r.Scopes, + Nonce: r.Nonce, + Claims: fromStorageClaims(r.Claims), + } +} + // Keys is a mirrored struct from storage with JSON struct tags and Kubernetes // type metadata. type Keys struct { diff --git a/storage/memory/memory.go b/storage/memory/memory.go index 6d609717..8bfbdce2 100644 --- a/storage/memory/memory.go +++ b/storage/memory/memory.go @@ -98,10 +98,10 @@ func (s *memStorage) CreateAuthCode(c storage.AuthCode) (err error) { func (s *memStorage) CreateRefresh(r storage.RefreshToken) (err error) { s.tx(func() { - if _, ok := s.refreshTokens[r.RefreshToken]; ok { + if _, ok := s.refreshTokens[r.ID]; ok { err = storage.ErrAlreadyExists } else { - s.refreshTokens[r.RefreshToken] = r + s.refreshTokens[r.ID] = r } }) return @@ -324,3 +324,17 @@ func (s *memStorage) UpdatePassword(email string, updater func(p storage.Passwor }) return } + +func (s *memStorage) UpdateRefreshToken(id string, updater func(p storage.RefreshToken) (storage.RefreshToken, error)) (err error) { + s.tx(func() { + r, ok := s.refreshTokens[id] + if !ok { + err = storage.ErrNotFound + return + } + if r, err = updater(r); err == nil { + s.refreshTokens[id] = r + } + }) + return +} diff --git a/storage/sql/crud.go b/storage/sql/crud.go index e3270363..494f1c20 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -244,14 +244,16 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error { id, client_id, scopes, nonce, claims_user_id, claims_username, claims_email, claims_email_verified, claims_groups, - connector_id, connector_data + connector_id, connector_data, + token, created_at, last_used ) - values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11); + values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14); `, - r.RefreshToken, r.ClientID, encoder(r.Scopes), r.Nonce, + r.ID, r.ClientID, encoder(r.Scopes), r.Nonce, r.Claims.UserID, r.Claims.Username, r.Claims.Email, r.Claims.EmailVerified, encoder(r.Claims.Groups), r.ConnectorID, r.ConnectorData, + r.Token, r.CreatedAt, r.LastUsed, ) if err != nil { return fmt.Errorf("insert refresh_token: %v", err) @@ -259,13 +261,57 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error { return nil } +func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error { + return c.ExecTx(func(tx *trans) error { + r, err := getRefresh(tx, id) + if err != nil { + return err + } + if r, err = updater(r); err != nil { + return err + } + _, err = tx.Exec(` + update refresh_token + set + client_id = $1, + scopes = $2, + nonce = $3, + claims_user_id = $4, + claims_username = $5, + claims_email = $6, + claims_email_verified = $7, + claims_groups = $8, + connector_id = $9, + connector_data = $10, + token = $11, + created_at = $12, + last_used = $13 + `, + r.ClientID, encoder(r.Scopes), r.Nonce, + r.Claims.UserID, r.Claims.Username, r.Claims.Email, r.Claims.EmailVerified, + encoder(r.Claims.Groups), + r.ConnectorID, r.ConnectorData, + r.Token, r.CreatedAt, r.LastUsed, + ) + if err != nil { + return fmt.Errorf("update refresh token: %v", err) + } + return nil + }) +} + func (c *conn) GetRefresh(id string) (storage.RefreshToken, error) { - return scanRefresh(c.QueryRow(` + return getRefresh(c, id) +} + +func getRefresh(q querier, id string) (storage.RefreshToken, error) { + return scanRefresh(q.QueryRow(` select id, client_id, scopes, nonce, claims_user_id, claims_username, claims_email, claims_email_verified, claims_groups, - connector_id, connector_data + connector_id, connector_data, + token, created_at, last_used from refresh_token where id = $1; `, id)) } @@ -276,7 +322,8 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) { id, client_id, scopes, nonce, claims_user_id, claims_username, claims_email, claims_email_verified, claims_groups, - connector_id, connector_data + connector_id, connector_data, + token, created_at, last_used from refresh_token; `) if err != nil { @@ -298,10 +345,11 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) { func scanRefresh(s scanner) (r storage.RefreshToken, err error) { err = s.Scan( - &r.RefreshToken, &r.ClientID, decoder(&r.Scopes), &r.Nonce, + &r.ID, &r.ClientID, decoder(&r.Scopes), &r.Nonce, &r.Claims.UserID, &r.Claims.Username, &r.Claims.Email, &r.Claims.EmailVerified, decoder(&r.Claims.Groups), &r.ConnectorID, &r.ConnectorData, + &r.Token, &r.CreatedAt, &r.LastUsed, ) if err != nil { if err == sql.ErrNoRows { diff --git a/storage/sql/migrate.go b/storage/sql/migrate.go index 3bb410aa..b2b66d39 100644 --- a/storage/sql/migrate.go +++ b/storage/sql/migrate.go @@ -155,4 +155,14 @@ var migrations = []migration{ ); `, }, + { + stmt: ` + alter table refresh_token + add column token text not null default ''; + alter table refresh_token + add column created_at timestamptz not null default '0001-01-01 00:00:00 UTC'; + alter table refresh_token + add column last_used timestamptz not null default '0001-01-01 00:00:00 UTC'; + `, + }, } diff --git a/storage/storage.go b/storage/storage.go index 22a9ea50..47f5dcc6 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -94,6 +94,7 @@ type Storage interface { UpdateClient(id string, updater func(old Client) (Client, error)) error UpdateKeys(updater func(old Keys) (Keys, error)) error UpdateAuthRequest(id string, updater func(a AuthRequest) (AuthRequest, error)) error + UpdateRefreshToken(id string, updater func(r RefreshToken) (RefreshToken, error)) error UpdatePassword(email string, updater func(p Password) (Password, error)) error // GarbageCollect deletes all expired AuthCodes and AuthRequests. @@ -216,8 +217,15 @@ type AuthCode struct { // RefreshToken is an OAuth2 refresh token which allows a client to request new // tokens on the end user's behalf. type RefreshToken struct { - // The actual refresh token. - RefreshToken string + ID string + + // A single token that's rotated every time the refresh token is refreshed. + // + // May be empty. + Token string + + CreatedAt time.Time + LastUsed time.Time // Client this refresh token is valid for. ClientID string