package refresh import ( "crypto/rand" "encoding/base64" "errors" "fmt" "strconv" "strings" ) const ( DefaultRefreshTokenPayloadLength = 64 TokenDelimer = "/" ) var ( ErrorInvalidUserID = errors.New("invalid user ID") ErrorInvalidClientID = errors.New("invalid client ID") ErrorInvalidToken = errors.New("invalid token") ) type RefreshTokenGenerator func() (string, error) func (g RefreshTokenGenerator) Generate() (string, error) { return g() } func DefaultRefreshTokenGenerator() (string, error) { // TODO(yifan) Remove this duplicated token generate function. b := make([]byte, DefaultRefreshTokenPayloadLength) n, err := rand.Read(b) if err != nil { return "", err } else if n != DefaultRefreshTokenPayloadLength { return "", errors.New("unable to read enough random bytes") } return base64.URLEncoding.EncodeToString(b), nil } type RefreshTokenRepo interface { // Create generates and returns a new refresh token for the given client-user pair. // On success the token will be return. Create(userID, clientID string) (string, error) // Verify verifies that a token belongs to the client, and returns the corresponding user ID. // Note that this assumes the client validation is currently done in the application layer, Verify(clientID, token string) (string, error) // Revoke deletes the refresh token if the token belongs to the given userID. Revoke(userID, token string) error } type refreshToken struct { payload string userID string clientID string } type memRefreshTokenRepo struct { store map[int]refreshToken tokenGenerator RefreshTokenGenerator } // buildToken combines the token ID and token payload to create a new token. func buildToken(tokenID int, tokenPayload string) string { return fmt.Sprintf("%d%s%s", tokenID, TokenDelimer, tokenPayload) } // parseToken parses a token and returns the token ID and token payload. func parseToken(token string) (int, string, error) { parts := strings.SplitN(token, TokenDelimer, 2) if len(parts) != 2 { return -1, "", ErrorInvalidToken } id, err := strconv.Atoi(parts[0]) if err != nil { return -1, "", ErrorInvalidToken } return id, parts[1], nil } // NewRefreshTokenRepo returns an in-memory RefreshTokenRepo useful for development. func NewRefreshTokenRepo() RefreshTokenRepo { return NewRefreshTokenRepoWithTokenGenerator(DefaultRefreshTokenGenerator) } func NewRefreshTokenRepoWithTokenGenerator(tokenGenerator RefreshTokenGenerator) RefreshTokenRepo { repo := &memRefreshTokenRepo{} repo.store = make(map[int]refreshToken) repo.tokenGenerator = tokenGenerator return repo } func (r *memRefreshTokenRepo) Create(userID, clientID string) (string, error) { // Validate userID. if userID == "" { return "", ErrorInvalidUserID } // Validate clientID. if clientID == "" { return "", ErrorInvalidClientID } // Generate and store token. tokenPayload, err := r.tokenGenerator.Generate() if err != nil { return "", err } tokenID := len(r.store) // Should only be used in single threaded tests. // No limits on the number of tokens per user/client for this in-memory repo. r.store[tokenID] = refreshToken{ payload: tokenPayload, userID: userID, clientID: clientID, } return buildToken(tokenID, tokenPayload), nil } func (r *memRefreshTokenRepo) Verify(clientID, token string) (string, error) { tokenID, tokenPayload, err := parseToken(token) if err != nil { return "", err } record, ok := r.store[tokenID] if !ok { return "", ErrorInvalidToken } if record.payload != tokenPayload { return "", ErrorInvalidToken } if record.clientID != clientID { return "", ErrorInvalidClientID } return record.userID, nil } func (r *memRefreshTokenRepo) Revoke(userID, token string) error { tokenID, tokenPayload, err := parseToken(token) if err != nil { return err } record, ok := r.store[tokenID] if !ok { return ErrorInvalidToken } if record.payload != tokenPayload { return ErrorInvalidToken } if record.userID != userID { return ErrorInvalidUserID } delete(r.store, tokenID) return nil }