diff --git a/Makefile b/Makefile index 8006982a..9519a186 100644 --- a/Makefile +++ b/Makefile @@ -55,7 +55,7 @@ fmt: @go fmt $(shell go list ./... | grep -v '/vendor/') lint: - @for package in $(shell go list ./... | grep -v '/vendor/' | grep -v '/api'); do \ + @for package in $(shell go list ./... | grep -v '/vendor/' | grep -v '/api' | grep -v '/server/internal'); do \ golint -set_exit_status $$package $$i || exit 1; \ done @@ -81,12 +81,15 @@ aci: clean-release _output/bin/dex _output/images/library-alpine-3.4.aci docker-image: clean-release _output/bin/dex @sudo docker build -t $(DOCKER_IMAGE) . -.PHONY: grpc -grpc: api/api.pb.go +.PHONY: proto +proto: api/api.pb.go server/internal/types.pb.go api/api.pb.go: api/api.proto bin/protoc bin/protoc-gen-go @protoc --go_out=plugins=grpc:. api/*.proto +server/internal/types.pb.go: server/internal/types.proto bin/protoc bin/protoc-gen-go + @protoc --go_out=. server/internal/*.proto + bin/protoc: scripts/get-protoc @./scripts/get-protoc bin/protoc diff --git a/server/handlers.go b/server/handlers.go index 808f8031..ef264dfe 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -2,6 +2,7 @@ package server import ( "encoding/json" + "errors" "fmt" "net/http" "net/url" @@ -16,6 +17,7 @@ import ( jose "gopkg.in/square/go-jose.v2" "github.com/coreos/dex/connector" + "github.com/coreos/dex/server/internal" "github.com/coreos/dex/storage" ) @@ -645,20 +647,32 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s var refreshToken string if reqRefresh { refresh := storage.RefreshToken{ - RefreshToken: storage.NewID(), + ID: storage.NewID(), + Token: storage.NewID(), ClientID: authCode.ClientID, ConnectorID: authCode.ConnectorID, Scopes: authCode.Scopes, Claims: authCode.Claims, Nonce: authCode.Nonce, ConnectorData: authCode.ConnectorData, + CreatedAt: s.now(), + LastUsed: s.now(), } + token := &internal.RefreshToken{ + RefreshId: refresh.ID, + Token: refresh.Token, + } + if refreshToken, err = internal.Marshal(token); err != nil { + s.logger.Errorf("failed to marshal refresh token: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + return + } + if err := s.storage.CreateRefresh(refresh); err != nil { s.logger.Errorf("failed to create refresh token: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return } - refreshToken = refresh.RefreshToken } s.writeAccessToken(w, idToken, refreshToken, expiry) } @@ -672,16 +686,37 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie return } - refresh, err := s.storage.GetRefresh(code) - if err != nil || refresh.ClientID != client.ID { - if err != storage.ErrNotFound { - s.logger.Errorf("failed to get auth code: %v", err) - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - } else { + token := new(internal.RefreshToken) + if err := internal.Unmarshal(code, token); err != nil { + // For backward compatibility, assume the refresh_token is a raw refresh token ID + // if it fails to decode. + // + // Because refresh_token values that aren't unmarshable were generated by servers + // that don't have a Token value, we'll still reject any attempts to claim a + // refresh_token twice. + token = &internal.RefreshToken{RefreshId: code, Token: ""} + } + + refresh, err := s.storage.GetRefresh(token.RefreshId) + if err != nil { + s.logger.Errorf("failed to get refresh token: %v", err) + if err == storage.ErrNotFound { s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest) + } else { + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) } return } + if refresh.ClientID != client.ID { + s.logger.Errorf("client %s trying to claim token for client %s", client.ID, refresh.ClientID) + s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest) + return + } + if refresh.Token != token.Token { + s.logger.Errorf("refresh token with id %s claimed twice", refresh.ID) + s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest) + return + } // Per the OAuth2 spec, if the client has omitted the scopes, default to the original // authorized scopes. @@ -720,6 +755,14 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return } + ident := connector.Identity{ + UserID: refresh.Claims.UserID, + Username: refresh.Claims.Username, + Email: refresh.Claims.Email, + EmailVerified: refresh.Claims.EmailVerified, + Groups: refresh.Claims.Groups, + ConnectorData: refresh.ConnectorData, + } // Can the connector refresh the identity? If so, attempt to refresh the data // in the connector. @@ -727,52 +770,63 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie // TODO(ericchiang): We may want a strict mode where connectors that don't implement // this interface can't perform refreshing. if refreshConn, ok := conn.Connector.(connector.RefreshConnector); ok { - ident := connector.Identity{ - UserID: refresh.Claims.UserID, - Username: refresh.Claims.Username, - Email: refresh.Claims.Email, - EmailVerified: refresh.Claims.EmailVerified, - Groups: refresh.Claims.Groups, - ConnectorData: refresh.ConnectorData, - } - ident, err := refreshConn.Refresh(r.Context(), parseScopes(scopes), ident) + newIdent, err := refreshConn.Refresh(r.Context(), parseScopes(scopes), ident) if err != nil { s.logger.Errorf("failed to refresh identity: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return } - - // Update the claims of the refresh token. - // - // UserID intentionally ignored for now. - refresh.Claims.Username = ident.Username - refresh.Claims.Email = ident.Email - refresh.Claims.EmailVerified = ident.EmailVerified - refresh.Claims.Groups = ident.Groups - refresh.ConnectorData = ident.ConnectorData + ident = newIdent } - idToken, expiry, err := s.newIDToken(client.ID, refresh.Claims, scopes, refresh.Nonce) + claims := storage.Claims{ + UserID: ident.UserID, + Username: ident.Username, + Email: ident.Email, + EmailVerified: ident.EmailVerified, + Groups: ident.Groups, + } + + idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce) if err != nil { s.logger.Errorf("failed to create ID token: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return } - // Refresh tokens are claimed exactly once. Delete the current token and - // create a new one. - if err := s.storage.DeleteRefresh(code); err != nil { - s.logger.Errorf("failed to delete auth code: %v", err) + newToken := &internal.RefreshToken{ + RefreshId: refresh.ID, + Token: storage.NewID(), + } + rawNewToken, err := internal.Marshal(newToken) + if err != nil { + s.logger.Errorf("failed to marshal refresh token: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return } - refresh.RefreshToken = storage.NewID() - if err := s.storage.CreateRefresh(refresh); err != nil { - s.logger.Errorf("failed to create refresh token: %v", err) + + updater := func(old storage.RefreshToken) (storage.RefreshToken, error) { + if old.Token != refresh.Token { + return old, errors.New("refresh token claimed twice") + } + old.Token = newToken.Token + // Update the claims of the refresh token. + // + // UserID intentionally ignored for now. + old.Claims.Username = ident.Username + old.Claims.Email = ident.Email + old.Claims.EmailVerified = ident.EmailVerified + old.Claims.Groups = ident.Groups + old.ConnectorData = ident.ConnectorData + old.LastUsed = s.now() + return old, nil + } + if err := s.storage.UpdateRefreshToken(refresh.ID, updater); err != nil { + s.logger.Errorf("failed to update refresh token: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return } - s.writeAccessToken(w, idToken, refresh.RefreshToken, expiry) + s.writeAccessToken(w, idToken, rawNewToken, expiry) } func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, refreshToken string, expiry time.Time) { diff --git a/server/internal/codec.go b/server/internal/codec.go new file mode 100644 index 00000000..a92c26f9 --- /dev/null +++ b/server/internal/codec.go @@ -0,0 +1,25 @@ +package internal + +import ( + "encoding/base64" + + "github.com/golang/protobuf/proto" +) + +// Marshal converts a protobuf message to a URL legal string. +func Marshal(message proto.Message) (string, error) { + data, err := proto.Marshal(message) + if err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(data), nil +} + +// Unmarshal decodes a protobuf message. +func Unmarshal(s string, message proto.Message) error { + data, err := base64.RawURLEncoding.DecodeString(s) + if err != nil { + return err + } + return proto.Unmarshal(data, message) +} diff --git a/server/internal/types.proto b/server/internal/types.proto new file mode 100644 index 00000000..442dbd95 --- /dev/null +++ b/server/internal/types.proto @@ -0,0 +1,10 @@ +syntax = "proto3"; + +// Package internal holds protobuf types used by the server +package internal; + +// RefreshToken is a message that holds refresh token data used by dex. +message RefreshToken { + string refresh_id = 1; + string token = 2; +} diff --git a/server/server_test.go b/server/server_test.go index 7c499c15..d848076f 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -237,6 +237,10 @@ func TestOAuth2CodeFlow(t *testing.T) { if token.RefreshToken == newToken.RefreshToken { return fmt.Errorf("old refresh token was the same as the new token %q", token.RefreshToken) } + + if _, err := config.TokenSource(ctx, token).Token(); err == nil { + return errors.New("was able to redeem the same refresh token twice") + } return nil }, },