dex/server/auth_middleware_test.go
2015-08-18 11:26:57 -07:00

200 lines
4.6 KiB
Go

package server
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/coreos/dex/client"
"github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/key"
"github.com/coreos/go-oidc/oidc"
)
// staticHandler is a stub http.Handler that answers every request with
// 200 OK. The tests use it as the terminal handler behind the middleware
// under test.
type staticHandler struct{}

// ServeHTTP writes an http.StatusOK status and nothing else; the incoming
// request is not inspected.
func (staticHandler) ServeHTTP(rw http.ResponseWriter, _ *http.Request) {
	rw.WriteHeader(http.StatusOK)
}
func TestClientToken(t *testing.T) {
now := time.Now()
tomorrow := now.Add(24 * time.Hour)
validClientID := "valid-client"
ci := oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{
ID: validClientID,
},
}
repo := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci})
privKey, err := key.GeneratePrivateKey()
if err != nil {
t.Fatalf("Failed to generate private key, error=%v", err)
}
signer := privKey.Signer()
pubKey := *key.NewPublicKey(privKey.JWK())
validIss := "https://example.com"
makeToken := func(iss, sub, aud string, iat, exp time.Time) string {
claims := oidc.NewClaims(iss, sub, aud, iat, exp)
jwt, err := jose.NewSignedJWT(claims, signer)
if err != nil {
t.Fatalf("Failed to generate JWT, error=%v", err)
}
return jwt.Encode()
}
validJWT := makeToken(validIss, validClientID, validClientID, now, tomorrow)
invalidJWT := makeToken("", "", "", now, tomorrow)
tests := []struct {
keys []key.PublicKey
repo client.ClientIdentityRepo
header string
wantCode int
}{
// valid token
{
keys: []key.PublicKey{pubKey},
repo: repo,
header: fmt.Sprintf("BEARER %s", validJWT),
wantCode: http.StatusOK,
},
// invalid token
{
keys: []key.PublicKey{pubKey},
repo: repo,
header: fmt.Sprintf("BEARER %s", invalidJWT),
wantCode: http.StatusUnauthorized,
},
// empty header
{
keys: []key.PublicKey{pubKey},
repo: repo,
header: "",
wantCode: http.StatusUnauthorized,
},
// unparsable token
{
keys: []key.PublicKey{pubKey},
repo: repo,
header: "BEARER xxx",
wantCode: http.StatusUnauthorized,
},
// no verification keys
{
keys: []key.PublicKey{},
repo: repo,
header: fmt.Sprintf("BEARER %s", validJWT),
wantCode: http.StatusUnauthorized,
},
// nil repo
{
keys: []key.PublicKey{pubKey},
repo: nil,
header: fmt.Sprintf("BEARER %s", validJWT),
wantCode: http.StatusUnauthorized,
},
// empty repo
{
keys: []key.PublicKey{pubKey},
repo: client.NewClientIdentityRepo(nil),
header: fmt.Sprintf("BEARER %s", validJWT),
wantCode: http.StatusUnauthorized,
},
// client not in repo
{
keys: []key.PublicKey{pubKey},
repo: repo,
header: fmt.Sprintf("BEARER %s", makeToken(validIss, "DOESNT-EXIST", "DOESNT-EXIST", now, tomorrow)),
wantCode: http.StatusUnauthorized,
},
}
for i, tt := range tests {
w := httptest.NewRecorder()
mw := &clientTokenMiddleware{
issuerURL: validIss,
ciRepo: tt.repo,
keysFunc: func() ([]key.PublicKey, error) {
return tt.keys, nil
},
next: staticHandler{},
}
req := &http.Request{
Header: http.Header{
"Authorization": []string{tt.header},
},
}
mw.ServeHTTP(w, req)
if tt.wantCode != w.Code {
t.Errorf("case %d: invalid response code, want=%d, got=%d", i, tt.wantCode, w.Code)
}
}
}
// TestGetClientIDFromAuthorizedRequest checks getClientIDFromAuthorizedRequest
// against bearer tokens built here: a token with sub "CLIENT_ID" is expected to
// yield "CLIENT_ID", and a token with an empty sub is expected to yield an
// error.
func TestGetClientIDFromAuthorizedRequest(t *testing.T) {
	issuedAt := time.Now()
	expiresAt := issuedAt.Add(24 * time.Hour)

	priv, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("Failed to generate private key, error=%v", err)
	}
	signer := priv.Signer()

	// headerFor builds an Authorization header value whose token has issuer
	// "iss", an empty audience and the given subject.
	headerFor := func(sub string) string {
		claims := oidc.NewClaims("iss", sub, "", issuedAt, expiresAt)
		signed, err := jose.NewSignedJWT(claims, signer)
		if err != nil {
			t.Fatalf("Failed to generate JWT, error=%v", err)
		}
		return fmt.Sprintf("BEARER %s", signed.Encode())
	}

	cases := []struct {
		authz      string
		wantClient string
		wantErr    bool
	}{
		{authz: headerFor("CLIENT_ID"), wantClient: "CLIENT_ID", wantErr: false},
		{authz: headerFor(""), wantErr: true},
	}

	for idx, tc := range cases {
		hdr := make(http.Header)
		hdr.Set("Authorization", tc.authz)
		req := &http.Request{Header: hdr}

		got, err := getClientIDFromAuthorizedRequest(req)
		switch {
		case tc.wantErr && err == nil:
			t.Errorf("case %d: want non-nil err", idx)
		case tc.wantErr:
			// Failure was expected and reported; the returned ID is not checked.
		case err != nil:
			t.Errorf("case %d: got err: %q", idx, err)
		case got != tc.wantClient:
			t.Errorf("case %d: want=%v, got=%v", idx, tc.wantClient, got)
		}
	}
}