refreshtoken: return base64 encoded token for in-memory backend.

Previously if we use the in-memory backend, it will return a raw
binary token for refresh token. This fixes the case.
This commit is contained in:
Yifan Gu 2015-10-12 14:38:02 -07:00
parent 2a1d32e6e8
commit 7282dd5187
2 changed files with 20 additions and 14 deletions

View file

@ -3,6 +3,7 @@ package refresh
import ( import (
"bytes" "bytes"
"crypto/rand" "crypto/rand"
"encoding/base64"
"errors" "errors"
"fmt" "fmt"
"strconv" "strconv"
@ -66,7 +67,7 @@ type memRefreshTokenRepo struct {
// buildToken combines the token ID and token payload to create a new token. // buildToken combines the token ID and token payload to create a new token.
func buildToken(tokenID int, tokenPayload []byte) string { func buildToken(tokenID int, tokenPayload []byte) string {
return fmt.Sprintf("%d%s%s", tokenID, TokenDelimer, tokenPayload) return fmt.Sprintf("%d%s%s", tokenID, TokenDelimer, base64.URLEncoding.EncodeToString(tokenPayload))
} }
// parseToken parses a token and returns the token ID and token payload. // parseToken parses a token and returns the token ID and token payload.
@ -79,7 +80,11 @@ func parseToken(token string) (int, []byte, error) {
if err != nil { if err != nil {
return -1, nil, ErrorInvalidToken return -1, nil, ErrorInvalidToken
} }
return id, []byte(parts[1]), nil tokenPayload, err := base64.URLEncoding.DecodeString(parts[1])
if err != nil {
return -1, nil, ErrorInvalidToken
}
return id, tokenPayload, nil
} }
// NewRefreshTokenRepo returns an in-memory RefreshTokenRepo useful for development. // NewRefreshTokenRepo returns an in-memory RefreshTokenRepo useful for development.

View file

@ -1,6 +1,7 @@
package server package server
import ( import (
"encoding/base64"
"errors" "errors"
"fmt" "fmt"
"net/url" "net/url"
@ -373,7 +374,7 @@ func TestServerCodeToken(t *testing.T) {
// Have 'offline_access' in scope, should get non-empty refresh token. // Have 'offline_access' in scope, should get non-empty refresh token.
{ {
scope: []string{"openid", "offline_access"}, scope: []string{"openid", "offline_access"},
refreshToken: "0/refresh-1", refreshToken: fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
}, },
} }
@ -475,7 +476,7 @@ func TestServerTokenFail(t *testing.T) {
argCC: ccFixture, argCC: ccFixture,
argKey: keyFixture, argKey: keyFixture,
scope: []string{"openid", "offline_access"}, scope: []string{"openid", "offline_access"},
refreshToken: "0/refresh-1", refreshToken: fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
}, },
// no 'offline_access' in 'scope', should get empty refresh token // no 'offline_access' in 'scope', should get empty refresh token
@ -605,7 +606,7 @@ func TestServerRefreshToken(t *testing.T) {
}{ }{
// Everything is good. // Everything is good.
{ {
"0/refresh-1", fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX", "XXX",
credXXX, credXXX,
signerFixture, signerFixture,
@ -621,7 +622,7 @@ func TestServerRefreshToken(t *testing.T) {
}, },
// Invalid refresh token(invalid payload content). // Invalid refresh token(invalid payload content).
{ {
"0/refresh-2", fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-2"))),
"XXX", "XXX",
credXXX, credXXX,
signerFixture, signerFixture,
@ -629,7 +630,7 @@ func TestServerRefreshToken(t *testing.T) {
}, },
// Invalid refresh token(invalid ID content). // Invalid refresh token(invalid ID content).
{ {
"1/refresh-2", fmt.Sprintf("1/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX", "XXX",
credXXX, credXXX,
signerFixture, signerFixture,
@ -637,7 +638,7 @@ func TestServerRefreshToken(t *testing.T) {
}, },
// Invalid client(client is not associated with the token). // Invalid client(client is not associated with the token).
{ {
"0/refresh-1", fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX", "XXX",
credYYY, credYYY,
signerFixture, signerFixture,
@ -645,7 +646,7 @@ func TestServerRefreshToken(t *testing.T) {
}, },
// Invalid client(no client ID). // Invalid client(no client ID).
{ {
"0/refresh-1", fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX", "XXX",
oidc.ClientCredentials{ID: "", Secret: "aaa"}, oidc.ClientCredentials{ID: "", Secret: "aaa"},
signerFixture, signerFixture,
@ -653,7 +654,7 @@ func TestServerRefreshToken(t *testing.T) {
}, },
// Invalid client(no such client). // Invalid client(no such client).
{ {
"0/refresh-1", fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX", "XXX",
oidc.ClientCredentials{ID: "AAA", Secret: "aaa"}, oidc.ClientCredentials{ID: "AAA", Secret: "aaa"},
signerFixture, signerFixture,
@ -661,7 +662,7 @@ func TestServerRefreshToken(t *testing.T) {
}, },
// Invalid client(no secrets). // Invalid client(no secrets).
{ {
"0/refresh-1", fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX", "XXX",
oidc.ClientCredentials{ID: "XXX"}, oidc.ClientCredentials{ID: "XXX"},
signerFixture, signerFixture,
@ -669,7 +670,7 @@ func TestServerRefreshToken(t *testing.T) {
}, },
// Invalid client(invalid secret). // Invalid client(invalid secret).
{ {
"0/refresh-1", fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX", "XXX",
oidc.ClientCredentials{ID: "XXX", Secret: "bad-secret"}, oidc.ClientCredentials{ID: "XXX", Secret: "bad-secret"},
signerFixture, signerFixture,
@ -677,7 +678,7 @@ func TestServerRefreshToken(t *testing.T) {
}, },
// Signing operation fails. // Signing operation fails.
{ {
"0/refresh-1", fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))),
"XXX", "XXX",
credXXX, credXXX,
&StaticSigner{sig: nil, err: errors.New("fail")}, &StaticSigner{sig: nil, err: errors.New("fail")},
@ -784,7 +785,7 @@ func TestServerRefreshToken(t *testing.T) {
} }
srv.UserRepo = userRepo srv.UserRepo = userRepo
_, err = srv.RefreshToken(credXXX, "0/refresh-1") _, err = srv.RefreshToken(credXXX, fmt.Sprintf("0/%s", base64.URLEncoding.EncodeToString([]byte("refresh-1"))))
if !reflect.DeepEqual(err, oauth2.NewError(oauth2.ErrorServerError)) { if !reflect.DeepEqual(err, oauth2.NewError(oauth2.ErrorServerError)) {
t.Errorf("Expect: %v, got: %v", oauth2.NewError(oauth2.ErrorServerError), err) t.Errorf("Expect: %v, got: %v", oauth2.NewError(oauth2.ErrorServerError), err)
} }