This repository has been archived on 2022-08-17. You can view files and clone it, but cannot push or open issues or pull requests.
dex/integration/oidc_test.go
Eric Chiang ac73d3cdf2 *: load password infos from users file in no-db mode not connectors
In --no-db mode, load passwords from the users file instead of the
connectors file. This allows us to remove the password infos field
from the local connector and stop loading them during connector
registration, a case that was causing panics when using a real
database (see #286).

Fixes #286
Closes #340
2016-04-06 11:15:05 -07:00

340 lines
9.2 KiB
Go

package integration
import (
"encoding/base64"
"fmt"
"html/template"
"net/http"
"net/url"
"testing"
"time"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/db"
phttp "github.com/coreos/dex/pkg/http"
"github.com/coreos/dex/refresh/refreshtest"
"github.com/coreos/dex/server"
"github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user"
"github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/key"
"github.com/coreos/go-oidc/oauth2"
"github.com/coreos/go-oidc/oidc"
)
func mockServer(cis []oidc.ClientIdentity) (*server.Server, error) {
dbMap := db.NewMemDB()
k, err := key.GeneratePrivateKey()
if err != nil {
return nil, fmt.Errorf("Unable to generate private key: %v", err)
}
km := key.NewPrivateKeyManager()
err = km.Set(key.NewPrivateKeySet([]*key.PrivateKey{k}, time.Now().Add(time.Minute)))
if err != nil {
return nil, err
}
clientIdentityRepo, err := db.NewClientIdentityRepoFromClients(dbMap, cis)
if err != nil {
return nil, err
}
sm := manager.NewSessionManager(db.NewSessionRepo(dbMap), db.NewSessionKeyRepo(dbMap))
srv := &server.Server{
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
KeyManager: km,
ClientIdentityRepo: clientIdentityRepo,
SessionManager: sm,
}
return srv, nil
}
func mockClient(srv *server.Server, ci oidc.ClientIdentity) (*oidc.Client, error) {
hdlr := srv.HTTPHandler()
sClient := &phttp.HandlerClient{Handler: hdlr}
cfg, err := oidc.FetchProviderConfig(sClient, srv.IssuerURL.String())
if err != nil {
return nil, fmt.Errorf("failed to fetch provider config: %v", err)
}
jwks, err := srv.KeyManager.JWKs()
if err != nil {
return nil, fmt.Errorf("failed to generate JWKs: %v", err)
}
ks := key.NewPublicKeySet(jwks, time.Now().Add(1*time.Hour))
ccfg := oidc.ClientConfig{
HTTPClient: sClient,
ProviderConfig: cfg,
Credentials: ci.Credentials,
KeySet: *ks,
}
return oidc.NewClient(ccfg)
}
func verifyUserClaims(claims jose.Claims, ci *oidc.ClientIdentity, user *user.User, issuerURL url.URL) error {
expectedSub, expectedName := ci.Credentials.ID, ci.Credentials.ID
if user != nil {
expectedSub, expectedName = user.ID, user.DisplayName
}
if aud := claims["aud"].(string); aud != ci.Credentials.ID {
return fmt.Errorf("unexpected claim value for aud, got=%v, want=%v", aud, ci.Credentials.ID)
}
if sub := claims["sub"].(string); sub != expectedSub {
return fmt.Errorf("unexpected claim value for sub, got=%v, want=%v", sub, expectedSub)
}
if name := claims["name"].(string); name != expectedName {
return fmt.Errorf("unexpected claim value for name, got=%v, want=%v", name, expectedName)
}
wantIss := issuerURL.String()
if iss := claims["iss"].(string); iss != wantIss {
return fmt.Errorf("unexpected claim value for iss, got=%v, want=%v", iss, wantIss)
}
return nil
}
func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
password, err := user.NewPasswordFromPlaintext("woof")
if err != nil {
t.Fatalf("unexpectd error: %q", err)
}
passwordInfo := user.PasswordInfo{
UserID: "elroy77",
Password: password,
}
cfg := &connector.LocalConnectorConfig{
ID: "local",
}
ci := oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{
ID: "72de74a9",
Secret: base64.URLEncoding.EncodeToString([]byte("XXX")),
},
}
dbMap := db.NewMemDB()
cir, err := db.NewClientIdentityRepoFromClients(dbMap, []oidc.ClientIdentity{ci})
if err != nil {
t.Fatalf("Failed to create client identity repo: " + err.Error())
}
passwordInfoRepo, err := db.NewPasswordInfoRepoFromPasswordInfos(db.NewMemDB(), []user.PasswordInfo{passwordInfo})
if err != nil {
t.Fatalf("Failed to create password info repo: %v", err)
}
issuerURL := url.URL{Scheme: "http", Host: "server.example.com"}
sm := manager.NewSessionManager(db.NewSessionRepo(dbMap), db.NewSessionKeyRepo(dbMap))
k, err := key.GeneratePrivateKey()
if err != nil {
t.Fatalf("Unable to generate RSA key: %v", err)
}
km := key.NewPrivateKeyManager()
err = km.Set(key.NewPrivateKeySet([]*key.PrivateKey{k}, time.Now().Add(time.Minute)))
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
usr := user.User{
ID: "ID-test",
Email: "testemail@example.com",
DisplayName: "displayname",
}
userRepo := db.NewUserRepo(db.NewMemDB())
if err := userRepo.Create(nil, usr); err != nil {
t.Fatalf("Unexpected error: %v", err)
}
refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()
srv := &server.Server{
IssuerURL: issuerURL,
KeyManager: km,
SessionManager: sm,
ClientIdentityRepo: cir,
Templates: template.New(connector.LoginPageTemplateName),
Connectors: []connector.Connector{},
UserRepo: userRepo,
PasswordInfoRepo: passwordInfoRepo,
RefreshTokenRepo: refreshTokenRepo,
}
if err = srv.AddConnector(cfg); err != nil {
t.Fatalf("Unexpected error: %v", err)
}
sClient := &phttp.HandlerClient{Handler: srv.HTTPHandler()}
pcfg, err := oidc.FetchProviderConfig(sClient, issuerURL.String())
if err != nil {
t.Fatalf("Failed to fetch provider config: %v", err)
}
ks := key.NewPublicKeySet([]jose.JWK{k.JWK()}, time.Now().Add(1*time.Hour))
ccfg := oidc.ClientConfig{
HTTPClient: sClient,
ProviderConfig: pcfg,
Credentials: ci.Credentials,
RedirectURL: "http://client.example.com",
KeySet: *ks,
}
cl, err := oidc.NewClient(ccfg)
if err != nil {
t.Fatalf("Failed creating oidc.Client: %v", err)
}
m := http.NewServeMux()
var claims jose.Claims
var refresh string
m.HandleFunc("/callback", handleCallbackFunc(cl, &claims, &refresh))
cClient := &phttp.HandlerClient{Handler: m}
// this will actually happen due to some interaction between the
// end-user and a remote identity provider
sessionID, err := sm.NewSession("bogus_idpc", ci.Credentials.ID, "bogus", url.URL{}, "", false, []string{"openid", "offline_access"})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if _, err = sm.AttachRemoteIdentity(sessionID, passwordInfo.Identity()); err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if _, err = sm.AttachUser(sessionID, usr.ID); err != nil {
t.Fatalf("Unexpected error: %v", err)
}
key, err := sm.NewSessionKey(sessionID)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
req, err := http.NewRequest("GET", fmt.Sprintf("http://client.example.com/callback?code=%s", key), nil)
if err != nil {
t.Fatalf("Failed creating HTTP request: %v", err)
}
resp, err := cClient.Do(req)
if err != nil {
t.Fatalf("Failed resolving HTTP requests against /callback: %v", err)
}
if err := verifyUserClaims(claims, &ci, &usr, issuerURL); err != nil {
t.Fatalf("Failed to verify claims: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("Received status code %d, want %d", resp.StatusCode, http.StatusOK)
}
if refresh == "" {
t.Fatalf("No refresh token")
}
// Use refresh token to get a new ID token.
token, err := cl.RefreshToken(refresh)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
claims, err = token.Claims()
if err != nil {
t.Fatalf("Failed parsing claims from client token: %v", err)
}
if err := verifyUserClaims(claims, &ci, &usr, issuerURL); err != nil {
t.Fatalf("Failed to verify claims: %v", err)
}
}
func TestHTTPClientCredsToken(t *testing.T) {
ci := oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{
ID: "72de74a9",
Secret: base64.URLEncoding.EncodeToString([]byte("XXX")),
},
}
cis := []oidc.ClientIdentity{ci}
srv, err := mockServer(cis)
if err != nil {
t.Fatalf("Unexpected error setting up server: %v", err)
}
cl, err := mockClient(srv, ci)
if err != nil {
t.Fatalf("Unexpected error setting up OIDC client: %v", err)
}
tok, err := cl.ClientCredsToken([]string{"openid"})
if err != nil {
t.Fatalf("Failed getting client token: %v", err)
}
claims, err := tok.Claims()
if err != nil {
t.Fatalf("Failed parsing claims from client token: %v", err)
}
if err := verifyUserClaims(claims, &ci, nil, srv.IssuerURL); err != nil {
t.Fatalf("Failed to verify claims: %v", err)
}
}
func handleCallbackFunc(c *oidc.Client, claims *jose.Claims, refresh *string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
if code == "" {
phttp.WriteError(w, http.StatusBadRequest, "code query param must be set")
return
}
oac, err := c.OAuthClient()
if err != nil {
phttp.WriteError(w, http.StatusInternalServerError, fmt.Sprintf("unable to create oauth client: %v", err))
return
}
t, err := oac.RequestToken(oauth2.GrantTypeAuthCode, code)
if err != nil {
phttp.WriteError(w, http.StatusBadRequest, fmt.Sprintf("unable to verify auth code with issuer: %v", err))
return
}
// Get id token and claims.
tok, err := jose.ParseJWT(t.IDToken)
if err != nil {
phttp.WriteError(w, http.StatusBadRequest, fmt.Sprintf("unable to parse id_token: %v", err))
return
}
if err := c.VerifyJWT(tok); err != nil {
phttp.WriteError(w, http.StatusBadRequest, fmt.Sprintf("unable to verify the JWT: %v", err))
return
}
if *claims, err = tok.Claims(); err != nil {
phttp.WriteError(w, http.StatusBadRequest, fmt.Sprintf("unable to construct claims: %v", err))
return
}
// Get refresh token.
*refresh = t.RefreshToken
w.WriteHeader(http.StatusOK)
}
}