package integration

import (
	"encoding/base64"
	"fmt"
	"html/template"
	"net/http"
	"net/url"
	"testing"
	"time"

	"github.com/coreos/dex/client"
	clientmanager "github.com/coreos/dex/client/manager"
	"github.com/coreos/dex/connector"
	"github.com/coreos/dex/db"
	phttp "github.com/coreos/dex/pkg/http"
	"github.com/coreos/dex/refresh/refreshtest"
	"github.com/coreos/dex/server"
	"github.com/coreos/dex/session/manager"
	"github.com/coreos/dex/user"
	"github.com/coreos/go-oidc/jose"
	"github.com/coreos/go-oidc/key"
	"github.com/coreos/go-oidc/oauth2"
	"github.com/coreos/go-oidc/oidc"
)

// mockServer builds an in-memory dex server whose issuer is
// http://server.example.com.
//
// The clients in cis are registered through a client manager that uses the
// client's host:port as its ID and the fixed bytes "secret" as its secret.
// Tokens are signed with a freshly generated private key whose key set
// expires one minute from now.
func mockServer(cis []client.Client) (*server.Server, error) {
	dbMap := db.NewMemDB()

	k, err := key.GeneratePrivateKey()
	if err != nil {
		return nil, fmt.Errorf("Unable to generate private key: %v", err)
	}
	km := key.NewPrivateKeyManager()
	err = km.Set(key.NewPrivateKeySet([]*key.PrivateKey{k}, time.Now().Add(time.Minute)))
	if err != nil {
		return nil, err
	}

	// Deterministic generators so tests can predict client IDs/secrets.
	clientIDGenerator := func(hostport string) (string, error) {
		return hostport, nil
	}
	secGen := func() ([]byte, error) {
		return []byte("secret"), nil
	}

	clientRepo := db.NewClientRepo(dbMap)
	clientManager, err := clientmanager.NewClientManagerFromClients(
		clientRepo,
		db.TransactionFactory(dbMap),
		cis,
		clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen},
	)
	if err != nil {
		return nil, err
	}

	sm := manager.NewSessionManager(db.NewSessionRepo(dbMap), db.NewSessionKeyRepo(dbMap))

	srv := &server.Server{
		IssuerURL:      url.URL{Scheme: "http", Host: "server.example.com"},
		KeyManager:     km,
		ClientRepo:     clientRepo,
		ClientManager:  clientManager,
		SessionManager: sm,
	}
	return srv, nil
}

// mockClient returns an oidc.Client for ci that talks to srv in-process
// through a phttp.HandlerClient wrapping srv.HTTPHandler().
//
// The client's key set is built from srv.KeyManager's current JWKs and is
// valid for one hour.
func mockClient(srv *server.Server, ci client.Client) (*oidc.Client, error) {
	hdlr := srv.HTTPHandler()
	sClient := &phttp.HandlerClient{Handler: hdlr}

	cfg, err := oidc.FetchProviderConfig(sClient, srv.IssuerURL.String())
	if err != nil {
		return nil, fmt.Errorf("failed to fetch provider config: %v", err)
	}

	jwks, err := srv.KeyManager.JWKs()
	if err != nil {
		return nil, fmt.Errorf("failed to generate JWKs: %v", err)
	}

	ks := key.NewPublicKeySet(jwks, time.Now().Add(1*time.Hour))
	ccfg := oidc.ClientConfig{
		HTTPClient:     sClient,
		ProviderConfig: cfg,
		Credentials:    ci.Credentials,
		KeySet:         *ks,
	}

	return oidc.NewClient(ccfg)
}

// verifyUserClaims checks the aud, sub, name and iss claims of an ID token.
//
// Expected values:
//   - aud must equal ci's client ID.
//   - iss must equal issuerURL.String().
//   - When u is nil (as for a client-credentials token), sub and name must
//     also equal the client ID.
//   - Otherwise sub and name must equal u.ID and u.DisplayName.
//
// It returns an error describing the first claim that is missing, is not a
// string, or has an unexpected value. Claims are checked in the order aud,
// sub, name, iss.
func verifyUserClaims(claims jose.Claims, ci *client.Client, u *user.User, issuerURL url.URL) error {
	expectedSub, expectedName := ci.Credentials.ID, ci.Credentials.ID
	if u != nil {
		expectedSub, expectedName = u.ID, u.DisplayName
	}

	// checkString returns an error unless claims[name] is a string equal to
	// want. The comma-ok type assertion keeps a missing or mistyped claim
	// from panicking the test helper.
	checkString := func(name, want string) error {
		raw, present := claims[name]
		if !present {
			return fmt.Errorf("unexpected claim value for %s, got=nil, want=%v", name, want)
		}
		got, ok := raw.(string)
		if !ok {
			return fmt.Errorf("unexpected claim type for %s, got=%T, want=string", name, raw)
		}
		if got != want {
			return fmt.Errorf("unexpected claim value for %s, got=%v, want=%v", name, got, want)
		}
		return nil
	}

	checks := []struct{ name, want string }{
		{"aud", ci.Credentials.ID},
		{"sub", expectedSub},
		{"name", expectedName},
		{"iss", issuerURL.String()},
	}
	for _, c := range checks {
		if err := checkString(c.name, c.want); err != nil {
			return err
		}
	}
	return nil
}

// TestHTTPExchangeTokenRefreshToken runs the authorization-code exchange
// against an in-process server, verifies the resulting ID token claims, and
// then uses the returned refresh token to obtain a new ID token whose claims
// must verify the same way.
func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
	password, err := user.NewPasswordFromPlaintext("woof")
	if err != nil {
		t.Fatalf("unexpected error: %q", err)
	}

	passwordInfo := user.PasswordInfo{
		UserID:   "elroy77",
		Password: password,
	}

	cfg := &connector.LocalConnectorConfig{
		ID: "local",
	}

	validRedirURL := url.URL{
		Scheme: "http",
		Host:   "client.example.com",
		Path:   "/callback",
	}
	ci := client.Client{
		Credentials: oidc.ClientCredentials{
			ID:     validRedirURL.Host,
			Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
		},
		Metadata: oidc.ClientMetadata{
			RedirectURIs: []url.URL{
				validRedirURL,
			},
		},
	}

	// Deterministic generators so the registered client keeps the ID and
	// secret declared above.
	clientIDGenerator := func(hostport string) (string, error) {
		return hostport, nil
	}
	secGen := func() ([]byte, error) {
		return []byte("secret"), nil
	}

	dbMap := db.NewMemDB()
	clientRepo := db.NewClientRepo(dbMap)
	clientManager, err := clientmanager.NewClientManagerFromClients(
		clientRepo,
		db.TransactionFactory(dbMap),
		[]client.Client{ci},
		clientmanager.ManagerOptions{ClientIDGenerator: clientIDGenerator, SecretGenerator: secGen},
	)
	if err != nil {
		// Pass the error as an argument: concatenating it into the format
		// string would misinterpret any '%' it contains (and trips go vet).
		t.Fatalf("Failed to create client identity manager: %v", err)
	}

	passwordInfoRepo, err := db.NewPasswordInfoRepoFromPasswordInfos(db.NewMemDB(), []user.PasswordInfo{passwordInfo})
	if err != nil {
		t.Fatalf("Failed to create password info repo: %v", err)
	}

	issuerURL := url.URL{Scheme: "http", Host: "server.example.com"}
	sm := manager.NewSessionManager(db.NewSessionRepo(dbMap), db.NewSessionKeyRepo(dbMap))

	k, err := key.GeneratePrivateKey()
	if err != nil {
		t.Fatalf("Unable to generate RSA key: %v", err)
	}

	km := key.NewPrivateKeyManager()
	err = km.Set(key.NewPrivateKeySet([]*key.PrivateKey{k}, time.Now().Add(time.Minute)))
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	usr := user.User{
		ID:          "ID-test",
		Email:       "testemail@example.com",
		DisplayName: "displayname",
	}
	userRepo := db.NewUserRepo(db.NewMemDB())
	if err := userRepo.Create(nil, usr); err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	refreshTokenRepo := refreshtest.NewTestRefreshTokenRepo()

	srv := &server.Server{
		IssuerURL:        issuerURL,
		KeyManager:       km,
		SessionManager:   sm,
		ClientRepo:       clientRepo,
		ClientManager:    clientManager,
		Templates:        template.New(connector.LoginPageTemplateName),
		Connectors:       []connector.Connector{},
		UserRepo:         userRepo,
		PasswordInfoRepo: passwordInfoRepo,
		RefreshTokenRepo: refreshTokenRepo,
	}

	if err = srv.AddConnector(cfg); err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	sClient := &phttp.HandlerClient{Handler: srv.HTTPHandler()}
	pcfg, err := oidc.FetchProviderConfig(sClient, issuerURL.String())
	if err != nil {
		t.Fatalf("Failed to fetch provider config: %v", err)
	}

	ks := key.NewPublicKeySet([]jose.JWK{k.JWK()}, time.Now().Add(1*time.Hour))

	ccfg := oidc.ClientConfig{
		HTTPClient:     sClient,
		ProviderConfig: pcfg,
		Credentials:    ci.Credentials,
		RedirectURL:    validRedirURL.String(),
		KeySet:         *ks,
	}

	cl, err := oidc.NewClient(ccfg)
	if err != nil {
		t.Fatalf("Failed creating oidc.Client: %v", err)
	}

	// The callback handler stores the verified claims and refresh token
	// into these variables.
	m := http.NewServeMux()
	var claims jose.Claims
	var refresh string
	m.HandleFunc("/callback", handleCallbackFunc(cl, &claims, &refresh))
	cClient := &phttp.HandlerClient{Handler: m}

	// this will actually happen due to some interaction between the
	// end-user and a remote identity provider
	sessionID, err := sm.NewSession("bogus_idpc", ci.Credentials.ID, "bogus", url.URL{}, "", false, []string{"openid", "offline_access"})
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}
	if _, err = sm.AttachRemoteIdentity(sessionID, passwordInfo.Identity()); err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}
	if _, err = sm.AttachUser(sessionID, usr.ID); err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	// Named "code" rather than "key" so the imported key package is not
	// shadowed.
	code, err := sm.NewSessionKey(sessionID)
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	// Build the callback URL with proper query escaping of the code.
	callbackURL := validRedirURL
	callbackURL.RawQuery = url.Values{"code": {code}}.Encode()
	req, err := http.NewRequest("GET", callbackURL.String(), nil)
	if err != nil {
		t.Fatalf("Failed creating HTTP request: %v", err)
	}

	resp, err := cClient.Do(req)
	if err != nil {
		t.Fatalf("Failed resolving HTTP requests against /callback: %v", err)
	}
	defer resp.Body.Close()

	// Check the status first: if the callback failed, claims were never
	// populated and a claim mismatch would hide the real cause.
	if resp.StatusCode != http.StatusOK {
		t.Fatalf("Received status code %d, want %d", resp.StatusCode, http.StatusOK)
	}

	if err := verifyUserClaims(claims, &ci, &usr, issuerURL); err != nil {
		t.Fatalf("Failed to verify claims: %v", err)
	}

	if refresh == "" {
		t.Fatalf("No refresh token")
	}

	// Use refresh token to get a new ID token.
	token, err := cl.RefreshToken(refresh)
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}

	claims, err = token.Claims()
	if err != nil {
		t.Fatalf("Failed parsing claims from client token: %v", err)
	}

	if err := verifyUserClaims(claims, &ci, &usr, issuerURL); err != nil {
		t.Fatalf("Failed to verify claims: %v", err)
	}
}

// TestHTTPClientCredsToken requests a client-credentials token with the
// "openid" scope from a mock server. It verifies that the token's claims
// identify the client itself (sub and name equal the client ID).
func TestHTTPClientCredsToken(t *testing.T) {
	validRedirURL := url.URL{
		Scheme: "http",
		Host:   "client.example.com",
		Path:   "/callback",
	}
	ci := client.Client{
		Credentials: oidc.ClientCredentials{
			ID:     validRedirURL.Host,
			Secret: base64.URLEncoding.EncodeToString([]byte("secret")),
		},
		Metadata: oidc.ClientMetadata{
			RedirectURIs: []url.URL{
				validRedirURL,
			},
		},
	}
	cis := []client.Client{ci}

	srv, err := mockServer(cis)
	if err != nil {
		t.Fatalf("Unexpected error setting up server: %v", err)
	}

	cl, err := mockClient(srv, ci)
	if err != nil {
		t.Fatalf("Unexpected error setting up OIDC client: %v", err)
	}

	tok, err := cl.ClientCredsToken([]string{"openid"})
	if err != nil {
		t.Fatalf("Failed getting client token: %v", err)
	}

	claims, err := tok.Claims()
	if err != nil {
		t.Fatalf("Failed parsing claims from client token: %v", err)
	}

	if err := verifyUserClaims(claims, &ci, nil, srv.IssuerURL); err != nil {
		t.Fatalf("Failed to verify claims: %v", err)
	}
}

// handleCallbackFunc returns an OAuth2 redirect handler for c.
//
// The handler does the following:
//   - Exchanges the "code" query parameter for tokens.
//   - Parses and verifies the returned ID token.
//   - Stores the token's claims in *claims and the refresh token in *refresh.
//   - Responds 200 OK on success.
//
// On failure it writes an error response via phttp.WriteError:
//   - 400 for a missing code or a failed exchange, parse, verification or
//     claims step.
//   - 500 if the OAuth client cannot be created.
func handleCallbackFunc(c *oidc.Client, claims *jose.Claims, refresh *string) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		code := r.URL.Query().Get("code")
		if code == "" {
			phttp.WriteError(w, http.StatusBadRequest, "code query param must be set")
			return
		}

		oac, err := c.OAuthClient()
		if err != nil {
			phttp.WriteError(w, http.StatusInternalServerError, fmt.Sprintf("unable to create oauth client: %v", err))
			return
		}

		tokResp, err := oac.RequestToken(oauth2.GrantTypeAuthCode, code)
		if err != nil {
			phttp.WriteError(w, http.StatusBadRequest, fmt.Sprintf("unable to verify auth code with issuer: %v", err))
			return
		}

		// Get id token and claims.
		tok, err := jose.ParseJWT(tokResp.IDToken)
		if err != nil {
			phttp.WriteError(w, http.StatusBadRequest, fmt.Sprintf("unable to parse id_token: %v", err))
			return
		}

		if err := c.VerifyJWT(tok); err != nil {
			phttp.WriteError(w, http.StatusBadRequest, fmt.Sprintf("unable to verify the JWT: %v", err))
			return
		}

		if *claims, err = tok.Claims(); err != nil {
			phttp.WriteError(w, http.StatusBadRequest, fmt.Sprintf("unable to construct claims: %v", err))
			return
		}

		// Get refresh token.
		*refresh = tokResp.RefreshToken

		w.WriteHeader(http.StatusOK)
	}
}