forked from mystiq/dex
Merge pull request #757 from ericchiang/constant-refresh-tokens
*: update refresh tokens instead of deleting and creating another
This commit is contained in:
commit
3c247db00a
15 changed files with 405 additions and 100 deletions
9
Makefile
9
Makefile
|
@ -55,7 +55,7 @@ fmt:
|
|||
@go fmt $(shell go list ./... | grep -v '/vendor/')
|
||||
|
||||
lint:
|
||||
@for package in $(shell go list ./... | grep -v '/vendor/' | grep -v '/api'); do \
|
||||
@for package in $(shell go list ./... | grep -v '/vendor/' | grep -v '/api' | grep -v '/server/internal'); do \
|
||||
golint -set_exit_status $$package $$i || exit 1; \
|
||||
done
|
||||
|
||||
|
@ -81,12 +81,15 @@ aci: clean-release _output/bin/dex _output/images/library-alpine-3.4.aci
|
|||
docker-image: clean-release _output/bin/dex
|
||||
@sudo docker build -t $(DOCKER_IMAGE) .
|
||||
|
||||
.PHONY: grpc
|
||||
grpc: api/api.pb.go
|
||||
.PHONY: proto
|
||||
proto: api/api.pb.go server/internal/types.pb.go
|
||||
|
||||
api/api.pb.go: api/api.proto bin/protoc bin/protoc-gen-go
|
||||
@protoc --go_out=plugins=grpc:. api/*.proto
|
||||
|
||||
server/internal/types.pb.go: server/internal/types.proto bin/protoc bin/protoc-gen-go
|
||||
@protoc --go_out=. server/internal/*.proto
|
||||
|
||||
bin/protoc: scripts/get-protoc
|
||||
@./scripts/get-protoc bin/protoc
|
||||
|
||||
|
|
|
@ -241,7 +241,7 @@ func (a *app) handleLogin(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
authCodeURL := ""
|
||||
scopes = append(scopes, "openid", "profile", "email")
|
||||
if r.FormValue("offline_acecss") != "yes" {
|
||||
if r.FormValue("offline_access") != "yes" {
|
||||
authCodeURL = a.oauth2Config(scopes).AuthCodeURL(exampleAppState)
|
||||
} else if a.offlineAsScope {
|
||||
scopes = append(scopes, "offline_access")
|
||||
|
@ -254,34 +254,42 @@ func (a *app) handleLogin(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
if errMsg := r.FormValue("error"); errMsg != "" {
|
||||
http.Error(w, errMsg+": "+r.FormValue("error_description"), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if state := r.FormValue("state"); state != exampleAppState {
|
||||
http.Error(w, fmt.Sprintf("expected state %q got %q", exampleAppState, state), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
code := r.FormValue("code")
|
||||
refresh := r.FormValue("refresh_token")
|
||||
var (
|
||||
err error
|
||||
token *oauth2.Token
|
||||
)
|
||||
oauth2Config := a.oauth2Config(nil)
|
||||
switch {
|
||||
case code != "":
|
||||
switch r.Method {
|
||||
case "GET":
|
||||
// Authorization redirect callback from OAuth2 auth flow.
|
||||
if errMsg := r.FormValue("error"); errMsg != "" {
|
||||
http.Error(w, errMsg+": "+r.FormValue("error_description"), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
code := r.FormValue("code")
|
||||
if code == "" {
|
||||
http.Error(w, fmt.Sprintf("no code in request: %q", r.Form), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if state := r.FormValue("state"); state != exampleAppState {
|
||||
http.Error(w, fmt.Sprintf("expected state %q got %q", exampleAppState, state), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
token, err = oauth2Config.Exchange(a.ctx, code)
|
||||
case refresh != "":
|
||||
case "POST":
|
||||
// Form request from frontend to refresh a token.
|
||||
refresh := r.FormValue("refresh_token")
|
||||
if refresh == "" {
|
||||
http.Error(w, fmt.Sprintf("no refresh_token in request: %q", r.Form), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
t := &oauth2.Token{
|
||||
RefreshToken: refresh,
|
||||
Expiry: time.Now().Add(-time.Hour),
|
||||
}
|
||||
token, err = oauth2Config.TokenSource(r.Context(), t).Token()
|
||||
default:
|
||||
http.Error(w, fmt.Sprintf("no code in request: %q", r.Form), http.StatusBadRequest)
|
||||
http.Error(w, fmt.Sprintf("method not implemented: %s", r.Method), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ import (
|
|||
|
||||
var indexTmpl = template.Must(template.New("index.html").Parse(`<html>
|
||||
<body>
|
||||
<form action="/login">
|
||||
<form action="/login" method="post">
|
||||
<p>
|
||||
Authenticate for:<input type="text" name="cross_client" placeholder="list of client-ids">
|
||||
</p>
|
||||
|
@ -50,8 +50,13 @@ pre {
|
|||
<body>
|
||||
<p> Token: <pre><code>{{ .IDToken }}</code></pre></p>
|
||||
<p> Claims: <pre><code>{{ .Claims }}</code></pre></p>
|
||||
{{ if .RefreshToken }}
|
||||
<p> Refresh Token: <pre><code>{{ .RefreshToken }}</code></pre></p>
|
||||
<p><a href="{{ .RedirectURL }}?refresh_token={{ .RefreshToken }}">Redeem refresh token</a><p>
|
||||
<form action="{{ .RedirectURL }}" method="post">
|
||||
<input type="hidden" name="refresh_token" value="{{ .RefreshToken }}">
|
||||
<input type="submit" value="Redeem refresh token">
|
||||
</form>
|
||||
{{ end }}
|
||||
</body>
|
||||
</html>
|
||||
`))
|
||||
|
|
|
@ -2,6 +2,7 @@ package server
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
@ -16,6 +17,7 @@ import (
|
|||
jose "gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/coreos/dex/connector"
|
||||
"github.com/coreos/dex/server/internal"
|
||||
"github.com/coreos/dex/storage"
|
||||
)
|
||||
|
||||
|
@ -645,20 +647,32 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
|
|||
var refreshToken string
|
||||
if reqRefresh {
|
||||
refresh := storage.RefreshToken{
|
||||
RefreshToken: storage.NewID(),
|
||||
ID: storage.NewID(),
|
||||
Token: storage.NewID(),
|
||||
ClientID: authCode.ClientID,
|
||||
ConnectorID: authCode.ConnectorID,
|
||||
Scopes: authCode.Scopes,
|
||||
Claims: authCode.Claims,
|
||||
Nonce: authCode.Nonce,
|
||||
ConnectorData: authCode.ConnectorData,
|
||||
CreatedAt: s.now(),
|
||||
LastUsed: s.now(),
|
||||
}
|
||||
token := &internal.RefreshToken{
|
||||
RefreshId: refresh.ID,
|
||||
Token: refresh.Token,
|
||||
}
|
||||
if refreshToken, err = internal.Marshal(token); err != nil {
|
||||
s.logger.Errorf("failed to marshal refresh token: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.storage.CreateRefresh(refresh); err != nil {
|
||||
s.logger.Errorf("failed to create refresh token: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
refreshToken = refresh.RefreshToken
|
||||
}
|
||||
s.writeAccessToken(w, idToken, refreshToken, expiry)
|
||||
}
|
||||
|
@ -672,14 +686,35 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
|
|||
return
|
||||
}
|
||||
|
||||
refresh, err := s.storage.GetRefresh(code)
|
||||
if err != nil || refresh.ClientID != client.ID {
|
||||
if err != storage.ErrNotFound {
|
||||
s.logger.Errorf("failed to get auth code: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
} else {
|
||||
s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest)
|
||||
token := new(internal.RefreshToken)
|
||||
if err := internal.Unmarshal(code, token); err != nil {
|
||||
// For backward compatibility, assume the refresh_token is a raw refresh token ID
|
||||
// if it fails to decode.
|
||||
//
|
||||
// Because refresh_token values that aren't unmarshable were generated by servers
|
||||
// that don't have a Token value, we'll still reject any attempts to claim a
|
||||
// refresh_token twice.
|
||||
token = &internal.RefreshToken{RefreshId: code, Token: ""}
|
||||
}
|
||||
|
||||
refresh, err := s.storage.GetRefresh(token.RefreshId)
|
||||
if err != nil {
|
||||
s.logger.Errorf("failed to get refresh token: %v", err)
|
||||
if err == storage.ErrNotFound {
|
||||
s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest)
|
||||
} else {
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
if refresh.ClientID != client.ID {
|
||||
s.logger.Errorf("client %s trying to claim token for client %s", client.ID, refresh.ClientID)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if refresh.Token != token.Token {
|
||||
s.logger.Errorf("refresh token with id %s claimed twice", refresh.ID)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -720,13 +755,6 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
|
|||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Can the connector refresh the identity? If so, attempt to refresh the data
|
||||
// in the connector.
|
||||
//
|
||||
// TODO(ericchiang): We may want a strict mode where connectors that don't implement
|
||||
// this interface can't perform refreshing.
|
||||
if refreshConn, ok := conn.Connector.(connector.RefreshConnector); ok {
|
||||
ident := connector.Identity{
|
||||
UserID: refresh.Claims.UserID,
|
||||
Username: refresh.Claims.Username,
|
||||
|
@ -735,44 +763,70 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
|
|||
Groups: refresh.Claims.Groups,
|
||||
ConnectorData: refresh.ConnectorData,
|
||||
}
|
||||
ident, err := refreshConn.Refresh(r.Context(), parseScopes(scopes), ident)
|
||||
|
||||
// Can the connector refresh the identity? If so, attempt to refresh the data
|
||||
// in the connector.
|
||||
//
|
||||
// TODO(ericchiang): We may want a strict mode where connectors that don't implement
|
||||
// this interface can't perform refreshing.
|
||||
if refreshConn, ok := conn.Connector.(connector.RefreshConnector); ok {
|
||||
newIdent, err := refreshConn.Refresh(r.Context(), parseScopes(scopes), ident)
|
||||
if err != nil {
|
||||
s.logger.Errorf("failed to refresh identity: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Update the claims of the refresh token.
|
||||
//
|
||||
// UserID intentionally ignored for now.
|
||||
refresh.Claims.Username = ident.Username
|
||||
refresh.Claims.Email = ident.Email
|
||||
refresh.Claims.EmailVerified = ident.EmailVerified
|
||||
refresh.Claims.Groups = ident.Groups
|
||||
refresh.ConnectorData = ident.ConnectorData
|
||||
ident = newIdent
|
||||
}
|
||||
|
||||
idToken, expiry, err := s.newIDToken(client.ID, refresh.Claims, scopes, refresh.Nonce)
|
||||
claims := storage.Claims{
|
||||
UserID: ident.UserID,
|
||||
Username: ident.Username,
|
||||
Email: ident.Email,
|
||||
EmailVerified: ident.EmailVerified,
|
||||
Groups: ident.Groups,
|
||||
}
|
||||
|
||||
idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce)
|
||||
if err != nil {
|
||||
s.logger.Errorf("failed to create ID token: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Refresh tokens are claimed exactly once. Delete the current token and
|
||||
// create a new one.
|
||||
if err := s.storage.DeleteRefresh(code); err != nil {
|
||||
s.logger.Errorf("failed to delete auth code: %v", err)
|
||||
newToken := &internal.RefreshToken{
|
||||
RefreshId: refresh.ID,
|
||||
Token: storage.NewID(),
|
||||
}
|
||||
rawNewToken, err := internal.Marshal(newToken)
|
||||
if err != nil {
|
||||
s.logger.Errorf("failed to marshal refresh token: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
refresh.RefreshToken = storage.NewID()
|
||||
if err := s.storage.CreateRefresh(refresh); err != nil {
|
||||
s.logger.Errorf("failed to create refresh token: %v", err)
|
||||
|
||||
updater := func(old storage.RefreshToken) (storage.RefreshToken, error) {
|
||||
if old.Token != refresh.Token {
|
||||
return old, errors.New("refresh token claimed twice")
|
||||
}
|
||||
old.Token = newToken.Token
|
||||
// Update the claims of the refresh token.
|
||||
//
|
||||
// UserID intentionally ignored for now.
|
||||
old.Claims.Username = ident.Username
|
||||
old.Claims.Email = ident.Email
|
||||
old.Claims.EmailVerified = ident.EmailVerified
|
||||
old.Claims.Groups = ident.Groups
|
||||
old.ConnectorData = ident.ConnectorData
|
||||
old.LastUsed = s.now()
|
||||
return old, nil
|
||||
}
|
||||
if err := s.storage.UpdateRefreshToken(refresh.ID, updater); err != nil {
|
||||
s.logger.Errorf("failed to update refresh token: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
s.writeAccessToken(w, idToken, refresh.RefreshToken, expiry)
|
||||
s.writeAccessToken(w, idToken, rawNewToken, expiry)
|
||||
}
|
||||
|
||||
func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, refreshToken string, expiry time.Time) {
|
||||
|
|
25
server/internal/codec.go
Normal file
25
server/internal/codec.go
Normal file
|
@ -0,0 +1,25 @@
|
|||
package internal
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
)
|
||||
|
||||
// Marshal converts a protobuf message to a URL legal string.
|
||||
func Marshal(message proto.Message) (string, error) {
|
||||
data, err := proto.Marshal(message)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(data), nil
|
||||
}
|
||||
|
||||
// Unmarshal decodes a protobuf message.
|
||||
func Unmarshal(s string, message proto.Message) error {
|
||||
data, err := base64.RawURLEncoding.DecodeString(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return proto.Unmarshal(data, message)
|
||||
}
|
59
server/internal/types.pb.go
Normal file
59
server/internal/types.pb.go
Normal file
|
@ -0,0 +1,59 @@
|
|||
// Code generated by protoc-gen-go.
|
||||
// source: server/internal/types.proto
|
||||
// DO NOT EDIT!
|
||||
|
||||
/*
|
||||
Package internal is a generated protocol buffer package.
|
||||
|
||||
Package internal holds protobuf types used by the server
|
||||
|
||||
It is generated from these files:
|
||||
server/internal/types.proto
|
||||
|
||||
It has these top-level messages:
|
||||
RefreshToken
|
||||
*/
|
||||
package internal
|
||||
|
||||
import proto "github.com/golang/protobuf/proto"
|
||||
import fmt "fmt"
|
||||
import math "math"
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ = proto.Marshal
|
||||
var _ = fmt.Errorf
|
||||
var _ = math.Inf
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the proto package it is being compiled against.
|
||||
// A compilation error at this line likely means your copy of the
|
||||
// proto package needs to be updated.
|
||||
const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package
|
||||
|
||||
// RefreshToken is a message that holds refresh token data used by dex.
|
||||
type RefreshToken struct {
|
||||
RefreshId string `protobuf:"bytes,1,opt,name=refresh_id,json=refreshId" json:"refresh_id,omitempty"`
|
||||
Token string `protobuf:"bytes,2,opt,name=token" json:"token,omitempty"`
|
||||
}
|
||||
|
||||
func (m *RefreshToken) Reset() { *m = RefreshToken{} }
|
||||
func (m *RefreshToken) String() string { return proto.CompactTextString(m) }
|
||||
func (*RefreshToken) ProtoMessage() {}
|
||||
func (*RefreshToken) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
|
||||
|
||||
func init() {
|
||||
proto.RegisterType((*RefreshToken)(nil), "internal.RefreshToken")
|
||||
}
|
||||
|
||||
func init() { proto.RegisterFile("server/internal/types.proto", fileDescriptor0) }
|
||||
|
||||
var fileDescriptor0 = []byte{
|
||||
// 112 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0x92, 0x2e, 0x4e, 0x2d, 0x2a,
|
||||
0x4b, 0x2d, 0xd2, 0xcf, 0xcc, 0x2b, 0x49, 0x2d, 0xca, 0x4b, 0xcc, 0xd1, 0x2f, 0xa9, 0x2c, 0x48,
|
||||
0x2d, 0xd6, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x80, 0x89, 0x2a, 0x39, 0x73, 0xf1, 0x04,
|
||||
0xa5, 0xa6, 0x15, 0xa5, 0x16, 0x67, 0x84, 0xe4, 0x67, 0xa7, 0xe6, 0x09, 0xc9, 0x72, 0x71, 0x15,
|
||||
0x41, 0xf8, 0xf1, 0x99, 0x29, 0x12, 0x8c, 0x0a, 0x8c, 0x1a, 0x9c, 0x41, 0x9c, 0x50, 0x11, 0xcf,
|
||||
0x14, 0x21, 0x11, 0x2e, 0xd6, 0x12, 0x90, 0x3a, 0x09, 0x26, 0xb0, 0x0c, 0x84, 0x93, 0xc4, 0x06,
|
||||
0x36, 0xd5, 0x18, 0x10, 0x00, 0x00, 0xff, 0xff, 0x9b, 0xd0, 0x5a, 0x1d, 0x74, 0x00, 0x00, 0x00,
|
||||
}
|
10
server/internal/types.proto
Normal file
10
server/internal/types.proto
Normal file
|
@ -0,0 +1,10 @@
|
|||
syntax = "proto3";
|
||||
|
||||
// Package internal holds protobuf types used by the server
|
||||
package internal;
|
||||
|
||||
// RefreshToken is a message that holds refresh token data used by dex.
|
||||
message RefreshToken {
|
||||
string refresh_id = 1;
|
||||
string token = 2;
|
||||
}
|
|
@ -237,6 +237,10 @@ func TestOAuth2CodeFlow(t *testing.T) {
|
|||
if token.RefreshToken == newToken.RefreshToken {
|
||||
return fmt.Errorf("old refresh token was the same as the new token %q", token.RefreshToken)
|
||||
}
|
||||
|
||||
if _, err := config.TokenSource(ctx, token).Token(); err == nil {
|
||||
return errors.New("was able to redeem the same refresh token twice")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
|
|
|
@ -208,10 +208,14 @@ func testClientCRUD(t *testing.T, s storage.Storage) {
|
|||
func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
|
||||
id := storage.NewID()
|
||||
refresh := storage.RefreshToken{
|
||||
RefreshToken: id,
|
||||
ID: id,
|
||||
Token: "bar",
|
||||
Nonce: "foo",
|
||||
ClientID: "client_id",
|
||||
ConnectorID: "client_secret",
|
||||
Scopes: []string{"openid", "email", "profile"},
|
||||
CreatedAt: time.Now().UTC().Round(time.Millisecond),
|
||||
LastUsed: time.Now().UTC().Round(time.Millisecond),
|
||||
Claims: storage.Claims{
|
||||
UserID: "1",
|
||||
Username: "jane",
|
||||
|
@ -238,6 +242,20 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
|
|||
|
||||
getAndCompare(id, refresh)
|
||||
|
||||
updatedAt := time.Now().UTC().Round(time.Millisecond)
|
||||
|
||||
updater := func(r storage.RefreshToken) (storage.RefreshToken, error) {
|
||||
r.Token = "spam"
|
||||
r.LastUsed = updatedAt
|
||||
return r, nil
|
||||
}
|
||||
if err := s.UpdateRefreshToken(id, updater); err != nil {
|
||||
t.Errorf("failed to udpate refresh token: %v", err)
|
||||
}
|
||||
refresh.Token = "spam"
|
||||
refresh.LastUsed = updatedAt
|
||||
getAndCompare(id, refresh)
|
||||
|
||||
if err := s.DeleteRefresh(id); err != nil {
|
||||
t.Fatalf("failed to delete refresh request: %v", err)
|
||||
}
|
||||
|
|
|
@ -153,23 +153,7 @@ func (cli *client) CreatePassword(p storage.Password) error {
|
|||
}
|
||||
|
||||
func (cli *client) CreateRefresh(r storage.RefreshToken) error {
|
||||
refresh := RefreshToken{
|
||||
TypeMeta: k8sapi.TypeMeta{
|
||||
Kind: kindRefreshToken,
|
||||
APIVersion: cli.apiVersion,
|
||||
},
|
||||
ObjectMeta: k8sapi.ObjectMeta{
|
||||
Name: r.RefreshToken,
|
||||
Namespace: cli.namespace,
|
||||
},
|
||||
ClientID: r.ClientID,
|
||||
ConnectorID: r.ConnectorID,
|
||||
Scopes: r.Scopes,
|
||||
Nonce: r.Nonce,
|
||||
Claims: fromStorageClaims(r.Claims),
|
||||
ConnectorData: r.ConnectorData,
|
||||
}
|
||||
return cli.post(resourceRefreshToken, refresh)
|
||||
return cli.post(resourceRefreshToken, cli.fromStorageRefreshToken(r))
|
||||
}
|
||||
|
||||
func (cli *client) GetAuthRequest(id string) (storage.AuthRequest, error) {
|
||||
|
@ -239,19 +223,16 @@ func (cli *client) GetKeys() (storage.Keys, error) {
|
|||
}
|
||||
|
||||
func (cli *client) GetRefresh(id string) (storage.RefreshToken, error) {
|
||||
var r RefreshToken
|
||||
if err := cli.get(resourceRefreshToken, id, &r); err != nil {
|
||||
r, err := cli.getRefreshToken(id)
|
||||
if err != nil {
|
||||
return storage.RefreshToken{}, err
|
||||
}
|
||||
return storage.RefreshToken{
|
||||
RefreshToken: r.ObjectMeta.Name,
|
||||
ClientID: r.ClientID,
|
||||
ConnectorID: r.ConnectorID,
|
||||
Scopes: r.Scopes,
|
||||
Nonce: r.Nonce,
|
||||
Claims: toStorageClaims(r.Claims),
|
||||
ConnectorData: r.ConnectorData,
|
||||
}, nil
|
||||
return toStorageRefreshToken(r), nil
|
||||
}
|
||||
|
||||
func (cli *client) getRefreshToken(id string) (r RefreshToken, err error) {
|
||||
err = cli.get(resourceRefreshToken, id, &r)
|
||||
return
|
||||
}
|
||||
|
||||
func (cli *client) ListClients() ([]storage.Client, error) {
|
||||
|
@ -311,6 +292,22 @@ func (cli *client) DeletePassword(email string) error {
|
|||
return cli.delete(resourcePassword, p.ObjectMeta.Name)
|
||||
}
|
||||
|
||||
func (cli *client) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
|
||||
r, err := cli.getRefreshToken(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
updated, err := updater(toStorageRefreshToken(r))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
updated.ID = id
|
||||
|
||||
newToken := cli.fromStorageRefreshToken(updated)
|
||||
newToken.ObjectMeta = r.ObjectMeta
|
||||
return cli.put(resourceRefreshToken, r.ObjectMeta.Name, newToken)
|
||||
}
|
||||
|
||||
func (cli *client) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error {
|
||||
c, err := cli.getClient(id)
|
||||
if err != nil {
|
||||
|
|
|
@ -362,9 +362,14 @@ type RefreshToken struct {
|
|||
k8sapi.TypeMeta `json:",inline"`
|
||||
k8sapi.ObjectMeta `json:"metadata,omitempty"`
|
||||
|
||||
CreatedAt time.Time
|
||||
LastUsed time.Time
|
||||
|
||||
ClientID string `json:"clientID"`
|
||||
Scopes []string `json:"scopes,omitempty"`
|
||||
|
||||
Token string `json:"token,omitempty"`
|
||||
|
||||
Nonce string `json:"nonce,omitempty"`
|
||||
|
||||
Claims Claims `json:"claims,omitempty"`
|
||||
|
@ -379,6 +384,43 @@ type RefreshList struct {
|
|||
RefreshTokens []RefreshToken `json:"items"`
|
||||
}
|
||||
|
||||
func toStorageRefreshToken(r RefreshToken) storage.RefreshToken {
|
||||
return storage.RefreshToken{
|
||||
ID: r.ObjectMeta.Name,
|
||||
Token: r.Token,
|
||||
CreatedAt: r.CreatedAt,
|
||||
LastUsed: r.LastUsed,
|
||||
ClientID: r.ClientID,
|
||||
ConnectorID: r.ConnectorID,
|
||||
ConnectorData: r.ConnectorData,
|
||||
Scopes: r.Scopes,
|
||||
Nonce: r.Nonce,
|
||||
Claims: toStorageClaims(r.Claims),
|
||||
}
|
||||
}
|
||||
|
||||
func (cli *client) fromStorageRefreshToken(r storage.RefreshToken) RefreshToken {
|
||||
return RefreshToken{
|
||||
TypeMeta: k8sapi.TypeMeta{
|
||||
Kind: kindRefreshToken,
|
||||
APIVersion: cli.apiVersion,
|
||||
},
|
||||
ObjectMeta: k8sapi.ObjectMeta{
|
||||
Name: r.ID,
|
||||
Namespace: cli.namespace,
|
||||
},
|
||||
Token: r.Token,
|
||||
CreatedAt: r.CreatedAt,
|
||||
LastUsed: r.LastUsed,
|
||||
ClientID: r.ClientID,
|
||||
ConnectorID: r.ConnectorID,
|
||||
ConnectorData: r.ConnectorData,
|
||||
Scopes: r.Scopes,
|
||||
Nonce: r.Nonce,
|
||||
Claims: fromStorageClaims(r.Claims),
|
||||
}
|
||||
}
|
||||
|
||||
// Keys is a mirrored struct from storage with JSON struct tags and Kubernetes
|
||||
// type metadata.
|
||||
type Keys struct {
|
||||
|
|
|
@ -98,10 +98,10 @@ func (s *memStorage) CreateAuthCode(c storage.AuthCode) (err error) {
|
|||
|
||||
func (s *memStorage) CreateRefresh(r storage.RefreshToken) (err error) {
|
||||
s.tx(func() {
|
||||
if _, ok := s.refreshTokens[r.RefreshToken]; ok {
|
||||
if _, ok := s.refreshTokens[r.ID]; ok {
|
||||
err = storage.ErrAlreadyExists
|
||||
} else {
|
||||
s.refreshTokens[r.RefreshToken] = r
|
||||
s.refreshTokens[r.ID] = r
|
||||
}
|
||||
})
|
||||
return
|
||||
|
@ -324,3 +324,17 @@ func (s *memStorage) UpdatePassword(email string, updater func(p storage.Passwor
|
|||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (s *memStorage) UpdateRefreshToken(id string, updater func(p storage.RefreshToken) (storage.RefreshToken, error)) (err error) {
|
||||
s.tx(func() {
|
||||
r, ok := s.refreshTokens[id]
|
||||
if !ok {
|
||||
err = storage.ErrNotFound
|
||||
return
|
||||
}
|
||||
if r, err = updater(r); err == nil {
|
||||
s.refreshTokens[id] = r
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
|
|
|
@ -244,14 +244,16 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error {
|
|||
id, client_id, scopes, nonce,
|
||||
claims_user_id, claims_username, claims_email, claims_email_verified,
|
||||
claims_groups,
|
||||
connector_id, connector_data
|
||||
connector_id, connector_data,
|
||||
token, created_at, last_used
|
||||
)
|
||||
values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11);
|
||||
values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14);
|
||||
`,
|
||||
r.RefreshToken, r.ClientID, encoder(r.Scopes), r.Nonce,
|
||||
r.ID, r.ClientID, encoder(r.Scopes), r.Nonce,
|
||||
r.Claims.UserID, r.Claims.Username, r.Claims.Email, r.Claims.EmailVerified,
|
||||
encoder(r.Claims.Groups),
|
||||
r.ConnectorID, r.ConnectorData,
|
||||
r.Token, r.CreatedAt, r.LastUsed,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert refresh_token: %v", err)
|
||||
|
@ -259,13 +261,57 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
|
||||
return c.ExecTx(func(tx *trans) error {
|
||||
r, err := getRefresh(tx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if r, err = updater(r); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec(`
|
||||
update refresh_token
|
||||
set
|
||||
client_id = $1,
|
||||
scopes = $2,
|
||||
nonce = $3,
|
||||
claims_user_id = $4,
|
||||
claims_username = $5,
|
||||
claims_email = $6,
|
||||
claims_email_verified = $7,
|
||||
claims_groups = $8,
|
||||
connector_id = $9,
|
||||
connector_data = $10,
|
||||
token = $11,
|
||||
created_at = $12,
|
||||
last_used = $13
|
||||
`,
|
||||
r.ClientID, encoder(r.Scopes), r.Nonce,
|
||||
r.Claims.UserID, r.Claims.Username, r.Claims.Email, r.Claims.EmailVerified,
|
||||
encoder(r.Claims.Groups),
|
||||
r.ConnectorID, r.ConnectorData,
|
||||
r.Token, r.CreatedAt, r.LastUsed,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update refresh token: %v", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (c *conn) GetRefresh(id string) (storage.RefreshToken, error) {
|
||||
return scanRefresh(c.QueryRow(`
|
||||
return getRefresh(c, id)
|
||||
}
|
||||
|
||||
func getRefresh(q querier, id string) (storage.RefreshToken, error) {
|
||||
return scanRefresh(q.QueryRow(`
|
||||
select
|
||||
id, client_id, scopes, nonce,
|
||||
claims_user_id, claims_username, claims_email, claims_email_verified,
|
||||
claims_groups,
|
||||
connector_id, connector_data
|
||||
connector_id, connector_data,
|
||||
token, created_at, last_used
|
||||
from refresh_token where id = $1;
|
||||
`, id))
|
||||
}
|
||||
|
@ -276,7 +322,8 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) {
|
|||
id, client_id, scopes, nonce,
|
||||
claims_user_id, claims_username, claims_email, claims_email_verified,
|
||||
claims_groups,
|
||||
connector_id, connector_data
|
||||
connector_id, connector_data,
|
||||
token, created_at, last_used
|
||||
from refresh_token;
|
||||
`)
|
||||
if err != nil {
|
||||
|
@ -298,10 +345,11 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) {
|
|||
|
||||
func scanRefresh(s scanner) (r storage.RefreshToken, err error) {
|
||||
err = s.Scan(
|
||||
&r.RefreshToken, &r.ClientID, decoder(&r.Scopes), &r.Nonce,
|
||||
&r.ID, &r.ClientID, decoder(&r.Scopes), &r.Nonce,
|
||||
&r.Claims.UserID, &r.Claims.Username, &r.Claims.Email, &r.Claims.EmailVerified,
|
||||
decoder(&r.Claims.Groups),
|
||||
&r.ConnectorID, &r.ConnectorData,
|
||||
&r.Token, &r.CreatedAt, &r.LastUsed,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
|
|
|
@ -155,4 +155,14 @@ var migrations = []migration{
|
|||
);
|
||||
`,
|
||||
},
|
||||
{
|
||||
stmt: `
|
||||
alter table refresh_token
|
||||
add column token text not null default '';
|
||||
alter table refresh_token
|
||||
add column created_at timestamptz not null default '0001-01-01 00:00:00 UTC';
|
||||
alter table refresh_token
|
||||
add column last_used timestamptz not null default '0001-01-01 00:00:00 UTC';
|
||||
`,
|
||||
},
|
||||
}
|
||||
|
|
|
@ -94,6 +94,7 @@ type Storage interface {
|
|||
UpdateClient(id string, updater func(old Client) (Client, error)) error
|
||||
UpdateKeys(updater func(old Keys) (Keys, error)) error
|
||||
UpdateAuthRequest(id string, updater func(a AuthRequest) (AuthRequest, error)) error
|
||||
UpdateRefreshToken(id string, updater func(r RefreshToken) (RefreshToken, error)) error
|
||||
UpdatePassword(email string, updater func(p Password) (Password, error)) error
|
||||
|
||||
// GarbageCollect deletes all expired AuthCodes and AuthRequests.
|
||||
|
@ -216,8 +217,15 @@ type AuthCode struct {
|
|||
// RefreshToken is an OAuth2 refresh token which allows a client to request new
|
||||
// tokens on the end user's behalf.
|
||||
type RefreshToken struct {
|
||||
// The actual refresh token.
|
||||
RefreshToken string
|
||||
ID string
|
||||
|
||||
// A single token that's rotated every time the refresh token is refreshed.
|
||||
//
|
||||
// May be empty.
|
||||
Token string
|
||||
|
||||
CreatedAt time.Time
|
||||
LastUsed time.Time
|
||||
|
||||
// Client this refresh token is valid for.
|
||||
ClientID string
|
||||
|
|
Loading…
Reference in a new issue