2016-04-09 00:26:29 +05:30
|
|
|
package oidc
|
|
|
|
|
|
|
|
import (
|
|
|
|
"errors"
|
|
|
|
"net/http"
|
|
|
|
"reflect"
|
|
|
|
"testing"
|
|
|
|
|
|
|
|
"github.com/coreos/go-oidc/jose"
|
|
|
|
)
|
|
|
|
|
|
|
|
type staticTokenRefresher struct {
|
|
|
|
verify func(jose.JWT) error
|
|
|
|
refresh func() (jose.JWT, error)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *staticTokenRefresher) Verify(jwt jose.JWT) error {
|
|
|
|
return s.verify(jwt)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *staticTokenRefresher) Refresh() (jose.JWT, error) {
|
|
|
|
return s.refresh()
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestAuthenticatedTransportVerifiedJWT(t *testing.T) {
|
|
|
|
tests := []struct {
|
|
|
|
refresher TokenRefresher
|
|
|
|
startJWT jose.JWT
|
|
|
|
wantJWT jose.JWT
|
|
|
|
wantError error
|
|
|
|
}{
|
|
|
|
// verification succeeds, so refresh is not called
|
|
|
|
{
|
|
|
|
refresher: &staticTokenRefresher{
|
|
|
|
verify: func(jose.JWT) error { return nil },
|
|
|
|
refresh: func() (jose.JWT, error) { return jose.JWT{RawPayload: "2"}, nil },
|
|
|
|
},
|
|
|
|
startJWT: jose.JWT{RawPayload: "1"},
|
|
|
|
wantJWT: jose.JWT{RawPayload: "1"},
|
|
|
|
},
|
|
|
|
|
|
|
|
// verification fails, refresh succeeds so cached JWT changes
|
|
|
|
{
|
|
|
|
refresher: &staticTokenRefresher{
|
|
|
|
verify: func(jose.JWT) error { return errors.New("fail!") },
|
|
|
|
refresh: func() (jose.JWT, error) { return jose.JWT{RawPayload: "2"}, nil },
|
|
|
|
},
|
|
|
|
startJWT: jose.JWT{RawPayload: "1"},
|
|
|
|
wantJWT: jose.JWT{RawPayload: "2"},
|
|
|
|
},
|
|
|
|
|
|
|
|
// verification succeeds, so failing refresh isn't attempted
|
|
|
|
{
|
|
|
|
refresher: &staticTokenRefresher{
|
|
|
|
verify: func(jose.JWT) error { return nil },
|
|
|
|
refresh: func() (jose.JWT, error) { return jose.JWT{}, errors.New("fail!") },
|
|
|
|
},
|
|
|
|
startJWT: jose.JWT{RawPayload: "1"},
|
|
|
|
wantJWT: jose.JWT{RawPayload: "1"},
|
|
|
|
},
|
|
|
|
|
|
|
|
// verification fails, but refresh fails, too
|
|
|
|
{
|
|
|
|
refresher: &staticTokenRefresher{
|
|
|
|
verify: func(jose.JWT) error { return errors.New("fail!") },
|
|
|
|
refresh: func() (jose.JWT, error) { return jose.JWT{}, errors.New("fail!") },
|
|
|
|
},
|
|
|
|
startJWT: jose.JWT{RawPayload: "1"},
|
|
|
|
wantJWT: jose.JWT{},
|
|
|
|
wantError: errors.New("unable to acquire valid JWT: fail!"),
|
|
|
|
},
|
|
|
|
}
|
|
|
|
|
|
|
|
for i, tt := range tests {
|
|
|
|
at := &AuthenticatedTransport{
|
|
|
|
TokenRefresher: tt.refresher,
|
|
|
|
}
|
2016-05-10 02:59:14 +05:30
|
|
|
at.SetJWT(tt.startJWT)
|
2016-04-09 00:26:29 +05:30
|
|
|
|
|
|
|
gotJWT, err := at.verifiedJWT()
|
|
|
|
if !reflect.DeepEqual(tt.wantError, err) {
|
|
|
|
t.Errorf("#%d: unexpected error: want=%#v got=%#v", i, tt.wantError, err)
|
|
|
|
}
|
|
|
|
if !reflect.DeepEqual(tt.wantJWT, gotJWT) {
|
|
|
|
t.Errorf("#%d: incorrect JWT returned from verifiedJWT: want=%#v got=%#v", i, tt.wantJWT, gotJWT)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestAuthenticatedTransportJWTCaching(t *testing.T) {
|
|
|
|
at := &AuthenticatedTransport{
|
|
|
|
TokenRefresher: &staticTokenRefresher{
|
|
|
|
verify: func(jose.JWT) error { return errors.New("fail!") },
|
|
|
|
refresh: func() (jose.JWT, error) { return jose.JWT{RawPayload: "2"}, nil },
|
|
|
|
},
|
|
|
|
jwt: jose.JWT{RawPayload: "1"},
|
|
|
|
}
|
|
|
|
|
|
|
|
wantJWT := jose.JWT{RawPayload: "2"}
|
|
|
|
gotJWT, err := at.verifiedJWT()
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("got non-nil error: %#v", err)
|
|
|
|
}
|
|
|
|
if !reflect.DeepEqual(wantJWT, gotJWT) {
|
|
|
|
t.Fatalf("incorrect JWT returned from verifiedJWT: want=%#v got=%#v", wantJWT, gotJWT)
|
|
|
|
}
|
|
|
|
|
|
|
|
at.TokenRefresher = &staticTokenRefresher{
|
|
|
|
verify: func(jose.JWT) error { return nil },
|
|
|
|
refresh: func() (jose.JWT, error) { return jose.JWT{RawPayload: "3"}, nil },
|
|
|
|
}
|
|
|
|
|
|
|
|
// the previous JWT should still be cached on the AuthenticatedTransport since
|
|
|
|
// it is still valid, even though there's a new token ready to refresh
|
|
|
|
gotJWT, err = at.verifiedJWT()
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("got non-nil error: %#v", err)
|
|
|
|
}
|
|
|
|
if !reflect.DeepEqual(wantJWT, gotJWT) {
|
|
|
|
t.Fatalf("incorrect JWT returned from verifiedJWT: want=%#v got=%#v", wantJWT, gotJWT)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2016-05-10 02:59:14 +05:30
|
|
|
// fakeRoundTripper is a stub http.RoundTripper: it records the request handed
// to RoundTrip and replies with a canned response.
type fakeRoundTripper struct {
	// Request holds the most recent request passed to RoundTrip; it stays nil
	// until RoundTrip is first called.
	Request *http.Request

	// resp is returned, along with a nil error, by every RoundTrip call.
	resp *http.Response
}
|
|
|
|
|
|
|
|
func (r *fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
|
|
r.Request = req
|
|
|
|
return r.resp, nil
|
|
|
|
}
|
|
|
|
|
2016-04-09 00:26:29 +05:30
|
|
|
func TestAuthenticatedTransportRoundTrip(t *testing.T) {
|
2016-05-10 02:59:14 +05:30
|
|
|
rr := &fakeRoundTripper{nil, &http.Response{StatusCode: http.StatusOK}}
|
2016-04-09 00:26:29 +05:30
|
|
|
at := &AuthenticatedTransport{
|
|
|
|
TokenRefresher: &staticTokenRefresher{
|
|
|
|
verify: func(jose.JWT) error { return nil },
|
|
|
|
},
|
|
|
|
RoundTripper: rr,
|
|
|
|
jwt: jose.JWT{RawPayload: "1"},
|
|
|
|
}
|
|
|
|
|
|
|
|
req := http.Request{}
|
|
|
|
_, err := at.RoundTrip(&req)
|
|
|
|
if err != nil {
|
|
|
|
t.Errorf("unexpected error: %v", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
if !reflect.DeepEqual(req, http.Request{}) {
|
|
|
|
t.Errorf("http.Request object was modified")
|
|
|
|
}
|
|
|
|
|
|
|
|
want := []string{"Bearer .1."}
|
|
|
|
got := rr.Request.Header["Authorization"]
|
|
|
|
if !reflect.DeepEqual(want, got) {
|
|
|
|
t.Errorf("incorrect Authorization header: want=%#v got=%#v", want, got)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestAuthenticatedTransportRoundTripRefreshFail(t *testing.T) {
|
2016-05-10 02:59:14 +05:30
|
|
|
rr := &fakeRoundTripper{nil, &http.Response{StatusCode: http.StatusOK}}
|
2016-04-09 00:26:29 +05:30
|
|
|
at := &AuthenticatedTransport{
|
|
|
|
TokenRefresher: &staticTokenRefresher{
|
|
|
|
verify: func(jose.JWT) error { return errors.New("fail!") },
|
|
|
|
refresh: func() (jose.JWT, error) { return jose.JWT{}, errors.New("fail!") },
|
|
|
|
},
|
|
|
|
RoundTripper: rr,
|
|
|
|
jwt: jose.JWT{RawPayload: "1"},
|
|
|
|
}
|
|
|
|
|
|
|
|
_, err := at.RoundTrip(&http.Request{})
|
|
|
|
if err == nil {
|
|
|
|
t.Errorf("expected non-nil error")
|
|
|
|
}
|
|
|
|
}
|