diff --git a/cmd/dex/config.go b/cmd/dex/config.go index 3d07f2ff..255e1d95 100644 --- a/cmd/dex/config.go +++ b/cmd/dex/config.go @@ -279,6 +279,9 @@ type Expiry struct { // AuthRequests defines the duration of time for which the AuthRequests will be valid. AuthRequests string `json:"authRequests"` + + // DeviceRequests defines the duration of time for which the DeviceRequests will be valid. + DeviceRequests string `json:"deviceRequests"` } // Logger holds configuration required to customize logging for dex. diff --git a/cmd/dex/config_test.go b/cmd/dex/config_test.go index 12c7b218..bced93b5 100644 --- a/cmd/dex/config_test.go +++ b/cmd/dex/config_test.go @@ -119,6 +119,7 @@ expiry: signingKeys: "7h" idTokens: "25h" authRequests: "25h" + deviceRequests: "10m" logger: level: "debug" @@ -197,9 +198,10 @@ logger: }, }, Expiry: Expiry{ - SigningKeys: "7h", - IDTokens: "25h", - AuthRequests: "25h", + SigningKeys: "7h", + IDTokens: "25h", + AuthRequests: "25h", + DeviceRequests: "10m", }, Logger: Logger{ Level: "debug", diff --git a/cmd/dex/serve.go b/cmd/dex/serve.go index d0e8f9ac..ca740593 100644 --- a/cmd/dex/serve.go +++ b/cmd/dex/serve.go @@ -269,7 +269,14 @@ func serve(cmd *cobra.Command, args []string) error { logger.Infof("config auth requests valid for: %v", authRequests) serverConfig.AuthRequestsValidFor = authRequests } - + if c.Expiry.DeviceRequests != "" { + deviceRequests, err := time.ParseDuration(c.Expiry.DeviceRequests) + if err != nil { + return fmt.Errorf("invalid config value %q for device request expiry: %v", c.Expiry.AuthRequests, err) + } + logger.Infof("config device requests valid for: %v", deviceRequests) + serverConfig.DeviceRequestsValidFor = deviceRequests + } serv, err := server.NewServer(context.Background(), serverConfig) if err != nil { return fmt.Errorf("failed to initialize server: %v", err) diff --git a/examples/config-dev.yaml b/examples/config-dev.yaml index f7b011ba..75126692 100644 --- a/examples/config-dev.yaml +++ b/examples/config-dev.yaml @@ -64,6 +64,7 @@ telemetry: # Uncomment this block to enable configuration for the expiration time durations. # expiry: +# deviceRequests: "5m" # signingKeys: "6h" # idTokens: "24h" @@ -95,7 +96,11 @@ staticClients: - 'http://127.0.0.1:5555/callback' name: 'Example App' secret: ZXhhbXBsZS1hcHAtc2VjcmV0 - +# - id: example-device-client +# redirectURIs: +# - /device/callback +# name: 'Static Client for Device Flow' +# public: true connectors: - type: mockCallback id: mock diff --git a/scripts/manifests/crds/devicerequests.yaml b/scripts/manifests/crds/devicerequests.yaml new file mode 100644 index 00000000..9b5b4200 --- /dev/null +++ b/scripts/manifests/crds/devicerequests.yaml @@ -0,0 +1,12 @@ +apiVersion: apiextensions.k8s.io/v1beta1 +kind: CustomResourceDefinition +metadata: + name: devicerequests.dex.coreos.com +spec: + group: dex.coreos.com + names: + kind: DeviceRequest + listKind: DeviceRequestList + plural: devicerequests + singular: devicerequest + version: v1 diff --git a/scripts/manifests/crds/devicetokens.yaml b/scripts/manifests/crds/devicetokens.yaml new file mode 100644 index 00000000..b6ce78dc --- /dev/null +++ b/scripts/manifests/crds/devicetokens.yaml @@ -0,0 +1,12 @@ +apiVersion: apiextensions.k8s.io/v1beta1 +kind: CustomResourceDefinition +metadata: + name: devicetokens.dex.coreos.com +spec: + group: dex.coreos.com + names: + kind: DeviceToken + listKind: DeviceTokenList + plural: devicetokens + singular: devicetoken + version: v1 diff --git a/server/deviceflowhandlers.go b/server/deviceflowhandlers.go new file mode 100644 index 00000000..4a8b382d --- /dev/null +++ b/server/deviceflowhandlers.go @@ -0,0 +1,390 @@ +package server + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "path" + "strconv" + "strings" + "time" + + "github.com/dexidp/dex/storage" +) + +type deviceCodeResponse struct { + //The unique device code for device authentication + DeviceCode string `json:"device_code"` + //The code the user will exchange via a browser and log in + UserCode string `json:"user_code"` + //The url to verify the user code. + VerificationURI string `json:"verification_uri"` + //The verification uri with the user code appended for pre-filling form + VerificationURIComplete string `json:"verification_uri_complete"` + //The lifetime of the device code + ExpireTime int `json:"expires_in"` + //How often the device is allowed to poll to verify that the user login occurred + PollInterval int `json:"interval"` +} + +func (s *Server) getDeviceVerificationURI() string { + return path.Join(s.issuerURL.Path, "/device/auth/verify_code") +} + +func (s *Server) handleDeviceExchange(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + // Grab the parameter(s) from the query. + // If "user_code" is set, pre-populate the user code text field. + // If "invalid" is set, set the invalidAttempt boolean, which will display a message to the user that they + // attempted to redeem an invalid or expired user code. + userCode := r.URL.Query().Get("user_code") + invalidAttempt, err := strconv.ParseBool(r.URL.Query().Get("invalid")) + if err != nil { + invalidAttempt = false + } + if err := s.templates.device(r, w, s.getDeviceVerificationURI(), userCode, invalidAttempt); err != nil { + s.logger.Errorf("Server template error: %v", err) + s.renderError(r, w, http.StatusNotFound, "Page not found") + } + default: + s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.") + } +} + +func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) { + pollIntervalSeconds := 5 + + switch r.Method { + case http.MethodPost: + err := r.ParseForm() + if err != nil { + s.logger.Errorf("Could not parse Device Request body: %v", err) + s.tokenErrHelper(w, errInvalidRequest, "", http.StatusNotFound) + return + } + + //Get the client id and scopes from the post + clientID := r.Form.Get("client_id") + clientSecret := r.Form.Get("client_secret") + scopes := strings.Fields(r.Form.Get("scope")) + + s.logger.Infof("Received device request for client %v with scopes %v", clientID, scopes) + + //Make device code + deviceCode := storage.NewDeviceCode() + + //make user code + userCode, err := storage.NewUserCode() + if err != nil { + s.logger.Errorf("Error generating user code: %v", err) + s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError) + } + + //Generate the expire time + expireTime := time.Now().Add(s.deviceRequestsValidFor) + + //Store the Device Request + deviceReq := storage.DeviceRequest{ + UserCode: userCode, + DeviceCode: deviceCode, + ClientID: clientID, + ClientSecret: clientSecret, + Scopes: scopes, + Expiry: expireTime, + } + + if err := s.storage.CreateDeviceRequest(deviceReq); err != nil { + s.logger.Errorf("Failed to store device request; %v", err) + s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError) + return + } + + //Store the device token + deviceToken := storage.DeviceToken{ + DeviceCode: deviceCode, + Status: deviceTokenPending, + Expiry: expireTime, + LastRequestTime: s.now(), + PollIntervalSeconds: 0, + } + + if err := s.storage.CreateDeviceToken(deviceToken); err != nil { + s.logger.Errorf("Failed to store device token %v", err) + s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError) + return + } + + u, err := url.Parse(s.issuerURL.String()) + if err != nil { + s.logger.Errorf("Could not parse issuer URL %v", err) + s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError) + return + } + u.Path = path.Join(u.Path, "device") + vURI := u.String() + + q := u.Query() + q.Set("user_code", userCode) + u.RawQuery = q.Encode() + vURIComplete := u.String() + + code := deviceCodeResponse{ + DeviceCode: deviceCode, + UserCode: userCode, + VerificationURI: vURI, + VerificationURIComplete: vURIComplete, + ExpireTime: int(s.deviceRequestsValidFor.Seconds()), + PollInterval: pollIntervalSeconds, + } + + enc := json.NewEncoder(w) + enc.SetEscapeHTML(false) + enc.SetIndent("", " ") + enc.Encode(code) + + default: + s.renderError(r, w, http.StatusBadRequest, "Invalid device code request type") + s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest) + } +} + +func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch r.Method { + case http.MethodPost: + err := r.ParseForm() + if err != nil { + s.logger.Warnf("Could not parse Device Token Request body: %v", err) + s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest) + return + } + + deviceCode := r.Form.Get("device_code") + if deviceCode == "" { + s.tokenErrHelper(w, errInvalidRequest, "No device code received", http.StatusBadRequest) + return + } + + grantType := r.PostFormValue("grant_type") + if grantType != grantTypeDeviceCode { + s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest) + return + } + + now := s.now() + + //Grab the device token, check validity + deviceToken, err := s.storage.GetDeviceToken(deviceCode) + if err != nil { + if err != storage.ErrNotFound { + s.logger.Errorf("failed to get device code: %v", err) + } + s.tokenErrHelper(w, errInvalidRequest, "Invalid Device code.", http.StatusBadRequest) + return + } else if now.After(deviceToken.Expiry) { + s.tokenErrHelper(w, deviceTokenExpired, "", http.StatusBadRequest) + return + } + + //Rate Limiting check + slowDown := false + pollInterval := deviceToken.PollIntervalSeconds + minRequestTime := deviceToken.LastRequestTime.Add(time.Second * time.Duration(pollInterval)) + if now.Before(minRequestTime) { + slowDown = true + //Continually increase the poll interval until the user waits the proper time + pollInterval += 5 + } else { + pollInterval = 5 + } + + switch deviceToken.Status { + case deviceTokenPending: + updater := func(old storage.DeviceToken) (storage.DeviceToken, error) { + old.PollIntervalSeconds = pollInterval + old.LastRequestTime = now + return old, nil + } + // Update device token last request time in storage + if err := s.storage.UpdateDeviceToken(deviceCode, updater); err != nil { + s.logger.Errorf("failed to update device token: %v", err) + s.renderError(r, w, http.StatusInternalServerError, "") + return + } + if slowDown { + s.tokenErrHelper(w, deviceTokenSlowDown, "", http.StatusBadRequest) + } else { + s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized) + } + case deviceTokenComplete: + w.Write([]byte(deviceToken.Token)) + } + default: + s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.") + } +} + +func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + userCode := r.FormValue("state") + code := r.FormValue("code") + + if userCode == "" || code == "" { + s.renderError(r, w, http.StatusBadRequest, "Request was missing parameters") + return + } + + // Authorization redirect callback from OAuth2 auth flow. + if errMsg := r.FormValue("error"); errMsg != "" { + http.Error(w, errMsg+": "+r.FormValue("error_description"), http.StatusBadRequest) + return + } + + authCode, err := s.storage.GetAuthCode(code) + if err != nil || s.now().After(authCode.Expiry) { + errCode := http.StatusBadRequest + if err != nil && err != storage.ErrNotFound { + s.logger.Errorf("failed to get auth code: %v", err) + errCode = http.StatusInternalServerError + } + s.renderError(r, w, errCode, "Invalid or expired auth code.") + return + } + + //Grab the device request from storage + deviceReq, err := s.storage.GetDeviceRequest(userCode) + if err != nil || s.now().After(deviceReq.Expiry) { + errCode := http.StatusBadRequest + if err != nil && err != storage.ErrNotFound { + s.logger.Errorf("failed to get device code: %v", err) + errCode = http.StatusInternalServerError + } + s.renderError(r, w, errCode, "Invalid or expired user code.") + return + } + + client, err := s.storage.GetClient(deviceReq.ClientID) + if err != nil { + if err != storage.ErrNotFound { + s.logger.Errorf("failed to get client: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + } else { + s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized) + } + return + } + if client.Secret != deviceReq.ClientSecret { + s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized) + return + } + + resp, err := s.exchangeAuthCode(w, authCode, client) + if err != nil { + s.logger.Errorf("Could not exchange auth code for client %q: %v", deviceReq.ClientID, err) + s.renderError(r, w, http.StatusInternalServerError, "Failed to exchange auth code.") + return + } + + //Grab the device token from storage + old, err := s.storage.GetDeviceToken(deviceReq.DeviceCode) + if err != nil || s.now().After(old.Expiry) { + errCode := http.StatusBadRequest + if err != nil && err != storage.ErrNotFound { + s.logger.Errorf("failed to get device token: %v", err) + errCode = http.StatusInternalServerError + } + s.renderError(r, w, errCode, "Invalid or expired device code.") + return + } + + updater := func(old storage.DeviceToken) (storage.DeviceToken, error) { + if old.Status == deviceTokenComplete { + return old, errors.New("device token already complete") + } + respStr, err := json.MarshalIndent(resp, "", " ") + if err != nil { + s.logger.Errorf("failed to marshal device token response: %v", err) + s.renderError(r, w, http.StatusInternalServerError, "") + return old, err + } + + old.Token = string(respStr) + old.Status = deviceTokenComplete + return old, nil + } + + // Update refresh token in the storage, store the token and mark as complete + if err := s.storage.UpdateDeviceToken(deviceReq.DeviceCode, updater); err != nil { + s.logger.Errorf("failed to update device token: %v", err) + s.renderError(r, w, http.StatusBadRequest, "") + return + } + + if err := s.templates.deviceSuccess(r, w, client.Name); err != nil { + s.logger.Errorf("Server template error: %v", err) + s.renderError(r, w, http.StatusNotFound, "Page not found") + } + + default: + http.Error(w, fmt.Sprintf("method not implemented: %s", r.Method), http.StatusBadRequest) + return + } +} + +func (s *Server) verifyUserCode(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodPost: + err := r.ParseForm() + if err != nil { + s.logger.Warnf("Could not parse user code verification request body : %v", err) + s.renderError(r, w, http.StatusBadRequest, "") + return + } + + userCode := r.Form.Get("user_code") + if userCode == "" { + s.renderError(r, w, http.StatusBadRequest, "No user code received") + return + } + + userCode = strings.ToUpper(userCode) + + //Find the user code in the available requests + deviceRequest, err := s.storage.GetDeviceRequest(userCode) + if err != nil || s.now().After(deviceRequest.Expiry) { + if err != nil && err != storage.ErrNotFound { + s.logger.Errorf("failed to get device request: %v", err) + } + if err := s.templates.device(r, w, s.getDeviceVerificationURI(), userCode, true); err != nil { + s.logger.Errorf("Server template error: %v", err) + s.renderError(r, w, http.StatusNotFound, "Page not found") + } + return + } + + //Redirect to Dex Auth Endpoint + authURL := path.Join(s.issuerURL.Path, "/auth") + u, err := url.Parse(authURL) + if err != nil { + s.renderError(r, w, http.StatusInternalServerError, "Invalid auth URI.") + return + } + q := u.Query() + q.Set("client_id", deviceRequest.ClientID) + q.Set("client_secret", deviceRequest.ClientSecret) + q.Set("state", deviceRequest.UserCode) + q.Set("response_type", "code") + q.Set("redirect_uri", "/device/callback") + q.Set("scope", strings.Join(deviceRequest.Scopes, " ")) + u.RawQuery = q.Encode() + + http.Redirect(w, r, u.String(), http.StatusFound) + + default: + s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.") + } +} diff --git a/server/deviceflowhandlers_test.go b/server/deviceflowhandlers_test.go new file mode 100644 index 00000000..5ab3ddb6 --- /dev/null +++ b/server/deviceflowhandlers_test.go @@ -0,0 +1,678 @@ +package server + +import ( + "bytes" + "context" + "encoding/json" + "io/ioutil" + "net/http" + "net/http/httptest" + "net/url" + "path" + "strings" + "testing" + "time" + + "github.com/dexidp/dex/storage" +) + +func TestDeviceVerificationURI(t *testing.T) { + t0 := time.Now() + + now := func() time.Time { return t0 } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // Setup a dex server. + httpServer, s := newTestServer(ctx, t, func(c *Config) { + c.Issuer = c.Issuer + "/non-root-path" + c.Now = now + }) + defer httpServer.Close() + + u, err := url.Parse(s.issuerURL.String()) + if err != nil { + t.Fatalf("Could not parse issuer URL %v", err) + } + u.Path = path.Join(u.Path, "/device/auth/verify_code") + + uri := s.getDeviceVerificationURI() + if uri != u.Path { + t.Errorf("Invalid verification URI. Expected %v got %v", u.Path, uri) + } +} + +func TestHandleDeviceCode(t *testing.T) { + t0 := time.Now() + + now := func() time.Time { return t0 } + + tests := []struct { + testName string + clientID string + requestType string + scopes []string + expectedResponseCode int + expectedServerResponse string + }{ + { + testName: "New Code", + clientID: "test", + requestType: "POST", + scopes: []string{"openid", "profile", "email"}, + expectedResponseCode: http.StatusOK, + }, + { + testName: "Invalid request Type (GET)", + clientID: "test", + requestType: "GET", + scopes: []string{"openid", "profile", "email"}, + expectedResponseCode: http.StatusBadRequest, + }, + } + for _, tc := range tests { + t.Run(tc.testName, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Setup a dex server. + httpServer, s := newTestServer(ctx, t, func(c *Config) { + c.Issuer = c.Issuer + "/non-root-path" + c.Now = now + }) + defer httpServer.Close() + + u, err := url.Parse(s.issuerURL.String()) + if err != nil { + t.Fatalf("Could not parse issuer URL %v", err) + } + u.Path = path.Join(u.Path, "device/code") + + data := url.Values{} + data.Set("client_id", tc.clientID) + for _, scope := range tc.scopes { + data.Add("scope", scope) + } + req, _ := http.NewRequest(tc.requestType, u.String(), bytes.NewBufferString(data.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") + + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + if rr.Code != tc.expectedResponseCode { + t.Errorf("Unexpected Response Type. Expected %v got %v", tc.expectedResponseCode, rr.Code) + } + + body, err := ioutil.ReadAll(rr.Body) + if err != nil { + t.Errorf("Could read token response %v", err) + } + if tc.expectedResponseCode == http.StatusOK { + var resp deviceCodeResponse + if err := json.Unmarshal(body, &resp); err != nil { + t.Errorf("Unexpected Device Code Response Format %v", string(body)) + } + } + }) + } +} + +func TestDeviceCallback(t *testing.T) { + t0 := time.Now() + + now := func() time.Time { return t0 } + + type formValues struct { + state string + code string + error string + } + + // Base "Control" test values + baseFormValues := formValues{ + state: "XXXX-XXXX", + code: "somecode", + } + baseAuthCode := storage.AuthCode{ + ID: "somecode", + ClientID: "testclient", + RedirectURI: deviceCallbackURI, + Nonce: "", + Scopes: []string{"openid", "profile", "email"}, + ConnectorID: "mock", + ConnectorData: nil, + Claims: storage.Claims{}, + Expiry: now().Add(5 * time.Minute), + } + baseDeviceRequest := storage.DeviceRequest{ + UserCode: "XXXX-XXXX", + DeviceCode: "devicecode", + ClientID: "testclient", + ClientSecret: "", + Scopes: []string{"openid", "profile", "email"}, + Expiry: now().Add(5 * time.Minute), + } + baseDeviceToken := storage.DeviceToken{ + DeviceCode: "devicecode", + Status: deviceTokenPending, + Token: "", + Expiry: now().Add(5 * time.Minute), + LastRequestTime: time.Time{}, + PollIntervalSeconds: 0, + } + + tests := []struct { + testName string + expectedResponseCode int + values formValues + testAuthCode storage.AuthCode + testDeviceRequest storage.DeviceRequest + testDeviceToken storage.DeviceToken + }{ + { + testName: "Missing State", + values: formValues{ + state: "", + code: "somecode", + error: "", + }, + expectedResponseCode: http.StatusBadRequest, + }, + { + testName: "Missing Code", + values: formValues{ + state: "XXXX-XXXX", + code: "", + error: "", + }, + expectedResponseCode: http.StatusBadRequest, + }, + { + testName: "Error During Authorization", + values: formValues{ + state: "XXXX-XXXX", + code: "somecode", + error: "Error Condition", + }, + expectedResponseCode: http.StatusBadRequest, + }, + { + testName: "Expired Auth Code", + values: baseFormValues, + testAuthCode: storage.AuthCode{ + ID: "somecode", + ClientID: "testclient", + RedirectURI: deviceCallbackURI, + Nonce: "", + Scopes: []string{"openid", "profile", "email"}, + ConnectorID: "pic", + ConnectorData: nil, + Claims: storage.Claims{}, + Expiry: now().Add(-5 * time.Minute), + }, + expectedResponseCode: http.StatusBadRequest, + }, + { + testName: "Invalid Auth Code", + values: baseFormValues, + testAuthCode: storage.AuthCode{ + ID: "somecode", + ClientID: "testclient", + RedirectURI: deviceCallbackURI, + Nonce: "", + Scopes: []string{"openid", "profile", "email"}, + ConnectorID: "pic", + ConnectorData: nil, + Claims: storage.Claims{}, + Expiry: now().Add(5 * time.Minute), + }, + expectedResponseCode: http.StatusBadRequest, + }, + { + testName: "Expired Device Request", + values: baseFormValues, + testAuthCode: baseAuthCode, + testDeviceRequest: storage.DeviceRequest{ + UserCode: "XXXX-XXXX", + DeviceCode: "devicecode", + ClientID: "testclient", + Scopes: []string{"openid", "profile", "email"}, + Expiry: now().Add(-5 * time.Minute), + }, + expectedResponseCode: http.StatusBadRequest, + }, + { + testName: "Non-Existent User Code", + values: baseFormValues, + testAuthCode: baseAuthCode, + testDeviceRequest: storage.DeviceRequest{ + UserCode: "ZZZZ-ZZZZ", + DeviceCode: "devicecode", + Scopes: []string{"openid", "profile", "email"}, + Expiry: now().Add(5 * time.Minute), + }, + expectedResponseCode: http.StatusBadRequest, + }, + { + testName: "Bad Device Request Client", + values: baseFormValues, + testAuthCode: baseAuthCode, + testDeviceRequest: storage.DeviceRequest{ + UserCode: "XXXX-XXXX", + DeviceCode: "devicecode", + Scopes: []string{"openid", "profile", "email"}, + Expiry: now().Add(5 * time.Minute), + }, + expectedResponseCode: http.StatusUnauthorized, + }, + { + testName: "Bad Device Request Secret", + values: baseFormValues, + testAuthCode: baseAuthCode, + testDeviceRequest: storage.DeviceRequest{ + UserCode: "XXXX-XXXX", + DeviceCode: "devicecode", + ClientSecret: "foobar", + Scopes: []string{"openid", "profile", "email"}, + Expiry: now().Add(5 * time.Minute), + }, + expectedResponseCode: http.StatusUnauthorized, + }, + { + testName: "Expired Device Token", + values: baseFormValues, + testAuthCode: baseAuthCode, + testDeviceRequest: baseDeviceRequest, + testDeviceToken: storage.DeviceToken{ + DeviceCode: "devicecode", + Status: deviceTokenPending, + Token: "", + Expiry: now().Add(-5 * time.Minute), + LastRequestTime: time.Time{}, + PollIntervalSeconds: 0, + }, + expectedResponseCode: http.StatusBadRequest, + }, + { + testName: "Device Code Already Redeemed", + values: baseFormValues, + testAuthCode: baseAuthCode, + testDeviceRequest: baseDeviceRequest, + testDeviceToken: storage.DeviceToken{ + DeviceCode: "devicecode", + Status: deviceTokenComplete, + Token: "", + Expiry: now().Add(5 * time.Minute), + LastRequestTime: time.Time{}, + PollIntervalSeconds: 0, + }, + expectedResponseCode: http.StatusBadRequest, + }, + { + testName: "Successful Exchange", + values: baseFormValues, + testAuthCode: baseAuthCode, + testDeviceRequest: baseDeviceRequest, + testDeviceToken: baseDeviceToken, + expectedResponseCode: http.StatusOK, + }, + } + for _, tc := range tests { + t.Run(tc.testName, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Setup a dex server. + httpServer, s := newTestServer(ctx, t, func(c *Config) { + //c.Issuer = c.Issuer + "/non-root-path" + c.Now = now + }) + defer httpServer.Close() + + if err := s.storage.CreateAuthCode(tc.testAuthCode); err != nil { + t.Fatalf("failed to create auth code: %v", err) + } + + if err := s.storage.CreateDeviceRequest(tc.testDeviceRequest); err != nil { + t.Fatalf("failed to create device request: %v", err) + } + + if err := s.storage.CreateDeviceToken(tc.testDeviceToken); err != nil { + t.Fatalf("failed to create device token: %v", err) + } + + client := storage.Client{ + ID: "testclient", + Secret: "", + RedirectURIs: []string{deviceCallbackURI}, + } + if err := s.storage.CreateClient(client); err != nil { + t.Fatalf("failed to create client: %v", err) + } + + u, err := url.Parse(s.issuerURL.String()) + if err != nil { + t.Fatalf("Could not parse issuer URL %v", err) + } + u.Path = path.Join(u.Path, "device/callback") + q := u.Query() + q.Set("state", tc.values.state) + q.Set("code", tc.values.code) + q.Set("error", tc.values.error) + u.RawQuery = q.Encode() + req, _ := http.NewRequest("GET", u.String(), nil) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") + + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + if rr.Code != tc.expectedResponseCode { + t.Errorf("%s: Unexpected Response Type. Expected %v got %v", tc.testName, tc.expectedResponseCode, rr.Code) + } + }) + } +} + +func TestDeviceTokenResponse(t *testing.T) { + t0 := time.Now() + + now := func() time.Time { return t0 } + + baseDeviceRequest := storage.DeviceRequest{ + UserCode: "ABCD-WXYZ", + DeviceCode: "foo", + ClientID: "testclient", + Scopes: []string{"openid", "profile", "offline_access"}, + Expiry: now().Add(5 * time.Minute), + } + + tests := []struct { + testName string + testDeviceRequest storage.DeviceRequest + testDeviceToken storage.DeviceToken + testGrantType string + testDeviceCode string + expectedServerResponse string + expectedResponseCode int + }{ + { + testName: "Valid but pending token", + testDeviceRequest: baseDeviceRequest, + testDeviceToken: storage.DeviceToken{ + DeviceCode: "f00bar", + Status: deviceTokenPending, + Token: "", + Expiry: now().Add(5 * time.Minute), + LastRequestTime: time.Time{}, + PollIntervalSeconds: 0, + }, + testDeviceCode: "f00bar", + expectedServerResponse: deviceTokenPending, + expectedResponseCode: http.StatusUnauthorized, + }, + { + testName: "Invalid Grant Type", + testDeviceRequest: baseDeviceRequest, + testDeviceToken: storage.DeviceToken{ + DeviceCode: "f00bar", + Status: deviceTokenPending, + Token: "", + Expiry: now().Add(5 * time.Minute), + LastRequestTime: time.Time{}, + PollIntervalSeconds: 0, + }, + testDeviceCode: "f00bar", + testGrantType: grantTypeAuthorizationCode, + expectedServerResponse: errInvalidGrant, + expectedResponseCode: http.StatusBadRequest, + }, + { + testName: "Test Slow Down State", + testDeviceRequest: baseDeviceRequest, + testDeviceToken: storage.DeviceToken{ + DeviceCode: "f00bar", + Status: deviceTokenPending, + Token: "", + Expiry: now().Add(5 * time.Minute), + LastRequestTime: now(), + PollIntervalSeconds: 10, + }, + testDeviceCode: "f00bar", + expectedServerResponse: deviceTokenSlowDown, + expectedResponseCode: http.StatusBadRequest, + }, + { + testName: "Test Expired Device Token", + testDeviceRequest: baseDeviceRequest, + testDeviceToken: storage.DeviceToken{ + DeviceCode: "f00bar", + Status: deviceTokenPending, + Token: "", + Expiry: now().Add(-5 * time.Minute), + LastRequestTime: time.Time{}, + PollIntervalSeconds: 0, + }, + testDeviceCode: "f00bar", + expectedServerResponse: deviceTokenExpired, + expectedResponseCode: http.StatusBadRequest, + }, + { + testName: "Test Non-existent Device Code", + testDeviceRequest: baseDeviceRequest, + testDeviceToken: storage.DeviceToken{ + DeviceCode: "foo", + Status: deviceTokenPending, + Token: "", + Expiry: now().Add(-5 * time.Minute), + LastRequestTime: time.Time{}, + PollIntervalSeconds: 0, + }, + testDeviceCode: "bar", + expectedServerResponse: errInvalidRequest, + expectedResponseCode: http.StatusBadRequest, + }, + { + testName: "Empty Device Code in Request", + testDeviceRequest: baseDeviceRequest, + testDeviceToken: storage.DeviceToken{ + DeviceCode: "bar", + Status: deviceTokenPending, + Token: "", + Expiry: now().Add(-5 * time.Minute), + LastRequestTime: time.Time{}, + PollIntervalSeconds: 0, + }, + testDeviceCode: "", + expectedServerResponse: errInvalidRequest, + expectedResponseCode: http.StatusBadRequest, + }, + { + testName: "Claim validated token from Device Code", + testDeviceRequest: baseDeviceRequest, + testDeviceToken: storage.DeviceToken{ + DeviceCode: "foo", + Status: deviceTokenComplete, + Token: "{\"access_token\": \"foobar\"}", + Expiry: now().Add(5 * time.Minute), + LastRequestTime: time.Time{}, + PollIntervalSeconds: 0, + }, + testDeviceCode: "foo", + expectedServerResponse: "{\"access_token\": \"foobar\"}", + expectedResponseCode: http.StatusOK, + }, + } + for _, tc := range tests { + t.Run(tc.testName, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Setup a dex server. + httpServer, s := newTestServer(ctx, t, func(c *Config) { + c.Issuer = c.Issuer + "/non-root-path" + c.Now = now + }) + defer httpServer.Close() + + if err := s.storage.CreateDeviceRequest(tc.testDeviceRequest); err != nil { + t.Fatalf("Failed to store device token %v", err) + } + + if err := s.storage.CreateDeviceToken(tc.testDeviceToken); err != nil { + t.Fatalf("Failed to store device token %v", err) + } + + u, err := url.Parse(s.issuerURL.String()) + if err != nil { + t.Fatalf("Could not parse issuer URL %v", err) + } + u.Path = path.Join(u.Path, "device/token") + + data := url.Values{} + grantType := grantTypeDeviceCode + if tc.testGrantType != "" { + grantType = tc.testGrantType + } + data.Set("grant_type", grantType) + data.Set("device_code", tc.testDeviceCode) + req, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(data.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") + + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + if rr.Code != tc.expectedResponseCode { + t.Errorf("Unexpected Response Type. Expected %v got %v", tc.expectedResponseCode, rr.Code) + } + + body, err := ioutil.ReadAll(rr.Body) + if err != nil { + t.Errorf("Could read token response %v", err) + } + if tc.expectedResponseCode == http.StatusBadRequest || tc.expectedResponseCode == http.StatusUnauthorized { + expectJsonErrorResponse(tc.testName, body, tc.expectedServerResponse, t) + } else if string(body) != tc.expectedServerResponse { + t.Errorf("Unexpected Server Response. Expected %v got %v", tc.expectedServerResponse, string(body)) + } + }) + } +} + +func expectJsonErrorResponse(testCase string, body []byte, expectedError string, t *testing.T) { + jsonMap := make(map[string]interface{}) + err := json.Unmarshal(body, &jsonMap) + if err != nil { + t.Errorf("Unexpected error unmarshalling response: %v", err) + } + if jsonMap["error"] != expectedError { + t.Errorf("Test Case %s expected error %v, received %v", testCase, expectedError, jsonMap["error"]) + } +} + +func TestVerifyCodeResponse(t *testing.T) { + t0 := time.Now() + + now := func() time.Time { return t0 } + + tests := []struct { + testName string + testDeviceRequest storage.DeviceRequest + userCode string + expectedResponseCode int + expectedRedirectPath string + }{ + { + testName: "Unknown user code", + testDeviceRequest: storage.DeviceRequest{ + UserCode: "ABCD-WXYZ", + DeviceCode: "f00bar", + ClientID: "testclient", + Scopes: []string{"openid", "profile", "offline_access"}, + Expiry: now().Add(5 * time.Minute), + }, + userCode: "CODE-TEST", + expectedResponseCode: http.StatusBadRequest, + expectedRedirectPath: "", + }, + { + testName: "Expired user code", + testDeviceRequest: storage.DeviceRequest{ + UserCode: "ABCD-WXYZ", + DeviceCode: "f00bar", + ClientID: "testclient", + Scopes: []string{"openid", "profile", "offline_access"}, + Expiry: now().Add(-5 * time.Minute), + }, + userCode: "ABCD-WXYZ", + expectedResponseCode: http.StatusBadRequest, + expectedRedirectPath: "", + }, + { + testName: "No user code", + testDeviceRequest: storage.DeviceRequest{ + UserCode: "ABCD-WXYZ", + DeviceCode: "f00bar", + ClientID: "testclient", + Scopes: []string{"openid", "profile", "offline_access"}, + Expiry: now().Add(-5 * time.Minute), + }, + userCode: "", + expectedResponseCode: http.StatusBadRequest, + expectedRedirectPath: "", + }, + { + testName: "Valid user code, expect redirect to auth endpoint", + testDeviceRequest: storage.DeviceRequest{ + UserCode: "ABCD-WXYZ", + DeviceCode: "f00bar", + ClientID: "testclient", + Scopes: []string{"openid", "profile", "offline_access"}, + Expiry: now().Add(5 * time.Minute), + }, + userCode: "ABCD-WXYZ", + expectedResponseCode: http.StatusFound, + expectedRedirectPath: "/auth", + }, + } + for _, tc := range tests { + t.Run(tc.testName, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Setup a dex server. + httpServer, s := newTestServer(ctx, t, func(c *Config) { + c.Issuer = c.Issuer + "/non-root-path" + c.Now = now + }) + defer httpServer.Close() + + if err := s.storage.CreateDeviceRequest(tc.testDeviceRequest); err != nil { + t.Fatalf("Failed to store device token %v", err) + } + + u, err := url.Parse(s.issuerURL.String()) + if err != nil { + t.Fatalf("Could not parse issuer URL %v", err) + } + + u.Path = path.Join(u.Path, "device/auth/verify_code") + data := url.Values{} + data.Set("user_code", tc.userCode) + req, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(data.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") + + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + if rr.Code != tc.expectedResponseCode { + t.Errorf("Unexpected Response Type. Expected %v got %v", tc.expectedResponseCode, rr.Code) + } + + u, err = url.Parse(s.issuerURL.String()) + if err != nil { + t.Errorf("Could not parse issuer URL %v", err) + } + u.Path = path.Join(u.Path, tc.expectedRedirectPath) + + location := rr.Header().Get("Location") + if rr.Code == http.StatusFound && !strings.HasPrefix(location, u.Path) { + t.Errorf("Invalid Redirect. Expected %v got %v", u.Path, location) + } + }) + } +} diff --git a/server/handlers.go b/server/handlers.go index 5512d87f..babd5417 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -148,30 +148,34 @@ func (s *Server) handlePublicKeys(w http.ResponseWriter, r *http.Request) { } type discovery struct { - Issuer string `json:"issuer"` - Auth string `json:"authorization_endpoint"` - Token string `json:"token_endpoint"` - Keys string `json:"jwks_uri"` - UserInfo string `json:"userinfo_endpoint"` - ResponseTypes []string `json:"response_types_supported"` - Subjects []string `json:"subject_types_supported"` - IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"` - Scopes []string `json:"scopes_supported"` - AuthMethods []string `json:"token_endpoint_auth_methods_supported"` - Claims []string `json:"claims_supported"` + Issuer string `json:"issuer"` + Auth string `json:"authorization_endpoint"` + Token string `json:"token_endpoint"` + Keys string `json:"jwks_uri"` + UserInfo string `json:"userinfo_endpoint"` + DeviceEndpoint string `json:"device_authorization_endpoint"` + GrantTypes []string `json:"grant_types_supported"` + ResponseTypes []string `json:"response_types_supported"` + Subjects []string `json:"subject_types_supported"` + IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"` + Scopes []string `json:"scopes_supported"` + AuthMethods []string `json:"token_endpoint_auth_methods_supported"` + Claims []string `json:"claims_supported"` } func (s *Server) discoveryHandler() (http.HandlerFunc, error) { d := discovery{ - Issuer: s.issuerURL.String(), - Auth: s.absURL("/auth"), - Token: s.absURL("/token"), - Keys: s.absURL("/keys"), - UserInfo: s.absURL("/userinfo"), - Subjects: []string{"public"}, - IDTokenAlgs: []string{string(jose.RS256)}, - Scopes: []string{"openid", "email", "groups", "profile", "offline_access"}, - AuthMethods: []string{"client_secret_basic"}, + Issuer: s.issuerURL.String(), + Auth: s.absURL("/auth"), + Token: s.absURL("/token"), + Keys: s.absURL("/keys"), + UserInfo: s.absURL("/userinfo"), + DeviceEndpoint: s.absURL("/device/code"), + Subjects: []string{"public"}, + GrantTypes: []string{grantTypeAuthorizationCode, grantTypeRefreshToken, grantTypeDeviceCode}, + IDTokenAlgs: []string{string(jose.RS256)}, + Scopes: []string{"openid", "email", "groups", "profile", "offline_access"}, + AuthMethods: []string{"client_secret_basic"}, Claims: []string{ "aud", "email", "email_verified", "exp", "iat", "iss", "locale", "name", "sub", @@ -784,24 +788,33 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s return } + tokenResponse, err := s.exchangeAuthCode(w, authCode, client) + if err != nil { + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + return + } + s.writeAccessToken(w, tokenResponse) +} + +func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCode, client storage.Client) (*accessTokenReponse, error) { accessToken, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID) if err != nil { s.logger.Errorf("failed to create new access token: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - return + return nil, err } idToken, expiry, err := s.newIDToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, accessToken, authCode.ConnectorID) if err != nil { s.logger.Errorf("failed to create ID token: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - return + return nil, err } - if err := s.storage.DeleteAuthCode(code); err != nil { + if err := s.storage.DeleteAuthCode(authCode.ID); err != nil { s.logger.Errorf("failed to delete auth code: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - return + return nil, err } reqRefresh := func() bool { @@ -848,13 +861,13 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s if refreshToken, err = internal.Marshal(token); err != nil { s.logger.Errorf("failed to marshal refresh token: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - return + return nil, err } if err := s.storage.CreateRefresh(refresh); err != nil { s.logger.Errorf("failed to create refresh token: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - return + return nil, err } // deleteToken determines if we need to delete the newly created refresh token @@ -885,7 +898,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s s.logger.Errorf("failed to get offline session: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) deleteToken = true - return + return nil, err } offlineSessions := storage.OfflineSessions{ UserID: refresh.Claims.UserID, @@ -900,7 +913,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s s.logger.Errorf("failed to create offline session: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) deleteToken = true - return + return nil, err } } else { if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok { @@ -909,7 +922,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s s.logger.Errorf("failed to delete refresh token: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) deleteToken = true - return + return nil, err } } @@ -921,11 +934,11 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s s.logger.Errorf("failed to update offline session: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) deleteToken = true - return + return nil, err } } } - s.writeAccessToken(w, idToken, accessToken, refreshToken, expiry) + return s.toAccessTokenResponse(idToken, accessToken, refreshToken, expiry), nil } // handle a refresh token request https://tools.ietf.org/html/rfc6749#section-6 @@ -1121,7 +1134,8 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie return } - s.writeAccessToken(w, idToken, accessToken, rawNewToken, expiry) + resp := s.toAccessTokenResponse(idToken, accessToken, rawNewToken, expiry) + s.writeAccessToken(w, resp) } func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) { @@ -1368,23 +1382,29 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli } } - s.writeAccessToken(w, idToken, accessToken, refreshToken, expiry) + resp := s.toAccessTokenResponse(idToken, accessToken, refreshToken, expiry) + s.writeAccessToken(w, resp) } -func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, accessToken, refreshToken string, expiry time.Time) { - resp := struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type"` - ExpiresIn int `json:"expires_in"` - RefreshToken string `json:"refresh_token,omitempty"` - IDToken string `json:"id_token"` - }{ +type accessTokenReponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + RefreshToken string `json:"refresh_token,omitempty"` + IDToken string `json:"id_token"` +} + +func (s *Server) toAccessTokenResponse(idToken, accessToken, refreshToken string, expiry time.Time) *accessTokenReponse { + return &accessTokenReponse{ accessToken, "bearer", int(expiry.Sub(s.now()).Seconds()), refreshToken, idToken, } +} + +func (s *Server) writeAccessToken(w http.ResponseWriter, resp *accessTokenReponse) { data, err := json.Marshal(resp) if err != nil { s.logger.Errorf("failed to marshal access token response: %v", err) diff --git a/server/oauth2.go b/server/oauth2.go index 05dd25d2..2596fd4e 100644 --- a/server/oauth2.go +++ b/server/oauth2.go @@ -114,6 +114,10 @@ const ( scopeCrossClientPrefix = "audience:server:client_id:" ) +const ( + deviceCallbackURI = "/device/callback" +) + const ( redirectURIOOB = "urn:ietf:wg:oauth:2.0:oob" ) @@ -122,6 +126,7 @@ const ( grantTypeAuthorizationCode = "authorization_code" grantTypeRefreshToken = "refresh_token" grantTypePassword = "password" + grantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code" ) const ( @@ -130,6 +135,13 @@ const ( responseTypeIDToken = "id_token" // ID Token in url fragment ) +const ( + deviceTokenPending = "authorization_pending" + deviceTokenComplete = "complete" + deviceTokenSlowDown = "slow_down" + deviceTokenExpired = "expired_token" +) + func parseScopes(scopes []string) connector.Scopes { var s connector.Scopes for _, scope := range scopes { @@ -425,6 +437,9 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques description := fmt.Sprintf("Unregistered redirect_uri (%q).", redirectURI) return nil, &authErr{"", "", errInvalidRequest, description} } + if redirectURI == deviceCallbackURI && client.Public { + redirectURI = s.issuerURL.Path + deviceCallbackURI + } // From here on out, we want to redirect back to the client with an error. newErr := func(typ, format string, a ...interface{}) *authErr { @@ -566,7 +581,7 @@ func validateRedirectURI(client storage.Client, redirectURI string) bool { return false } - if redirectURI == redirectURIOOB { + if redirectURI == redirectURIOOB || redirectURI == deviceCallbackURI { return true } diff --git a/server/server.go b/server/server.go index a0a075fb..f4d139d1 100644 --- a/server/server.go +++ b/server/server.go @@ -75,9 +75,10 @@ type Config struct { // If enabled, the connectors selection page will always be shown even if there's only one AlwaysShowLoginScreen bool - RotateKeysAfter time.Duration // Defaults to 6 hours. - IDTokensValidFor time.Duration // Defaults to 24 hours - AuthRequestsValidFor time.Duration // Defaults to 24 hours + RotateKeysAfter time.Duration // Defaults to 6 hours. + IDTokensValidFor time.Duration // Defaults to 24 hours + AuthRequestsValidFor time.Duration // Defaults to 24 hours + DeviceRequestsValidFor time.Duration // Defaults to 5 minutes // If set, the server will use this connector to handle password grants PasswordConnector string @@ -156,8 +157,9 @@ type Server struct { now func() time.Time - idTokensValidFor time.Duration - authRequestsValidFor time.Duration + idTokensValidFor time.Duration + authRequestsValidFor time.Duration + deviceRequestsValidFor time.Duration logger log.Logger } @@ -219,6 +221,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) supportedResponseTypes: supported, idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour), authRequestsValidFor: value(c.AuthRequestsValidFor, 24*time.Hour), + deviceRequestsValidFor: value(c.DeviceRequestsValidFor, 5*time.Minute), skipApproval: c.SkipApprovalScreen, alwaysShowLogin: c.AlwaysShowLoginScreen, now: now, @@ -302,6 +305,11 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) handleWithCORS("/userinfo", s.handleUserInfo) handleFunc("/auth", s.handleAuthorization) handleFunc("/auth/{connector}", s.handleConnectorLogin) + handleFunc("/device", s.handleDeviceExchange) + handleFunc("/device/auth/verify_code", s.verifyUserCode) + handleFunc("/device/code", s.handleDeviceCode) + handleFunc("/device/token", s.handleDeviceToken) + handleFunc(deviceCallbackURI, s.handleDeviceCallback) r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) { // Strip the X-Remote-* headers to prevent security issues on // misconfigured authproxy connector setups. @@ -450,7 +458,8 @@ func (s *Server) startGarbageCollection(ctx context.Context, frequency time.Dura if r, err := s.storage.GarbageCollect(now()); err != nil { s.logger.Errorf("garbage collection failed: %v", err) } else if r.AuthRequests > 0 || r.AuthCodes > 0 { - s.logger.Infof("garbage collection run, delete auth requests=%d, auth codes=%d", r.AuthRequests, r.AuthCodes) + s.logger.Infof("garbage collection run, delete auth requests=%d, auth codes=%d, device requests =%d, device tokens=%d", + r.AuthRequests, r.AuthCodes, r.DeviceRequests, r.DeviceTokens) } } } diff --git a/server/server_test.go b/server/server_test.go index 8fe84c9a..114d1fc8 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -8,11 +8,13 @@ import ( "encoding/pem" "errors" "fmt" + "io/ioutil" "net/http" "net/http/httptest" "net/http/httputil" "net/url" "os" + "path" "reflect" "sort" "strings" @@ -203,6 +205,274 @@ func TestDiscovery(t *testing.T) { } } +type oauth2Tests struct { + clientID string + tests []test +} + +type test struct { + name string + // If specified these set of scopes will be used during the test case. + scopes []string + // handleToken provides the OAuth2 token response for the integration test. + handleToken func(context.Context, *oidc.Provider, *oauth2.Config, *oauth2.Token, *mock.Callback) error +} + +func makeOAuth2Tests(clientID string, clientSecret string, now func() time.Time) oauth2Tests { + requestedScopes := []string{oidc.ScopeOpenID, "email", "profile", "groups", "offline_access"} + + // Used later when configuring test servers to set how long id_tokens will be valid for. + // + // The actual value of 30s is completely arbitrary. We just need to set a value + // so tests can compute the expected "expires_in" field. + idTokensValidFor := time.Second * 30 + + oidcConfig := &oidc.Config{SkipClientIDCheck: true} + + return oauth2Tests{ + clientID: clientID, + tests: []test{ + { + name: "verify ID Token", + handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error { + idToken, ok := token.Extra("id_token").(string) + if !ok { + return fmt.Errorf("no id token found") + } + if _, err := p.Verifier(oidcConfig).Verify(ctx, idToken); err != nil { + return fmt.Errorf("failed to verify id token: %v", err) + } + return nil + }, + }, + { + name: "fetch userinfo", + handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error { + ui, err := p.UserInfo(ctx, config.TokenSource(ctx, token)) + if err != nil { + return fmt.Errorf("failed to fetch userinfo: %v", err) + } + if conn.Identity.Email != ui.Email { + return fmt.Errorf("expected email to be %v, got %v", conn.Identity.Email, ui.Email) + } + return nil + }, + }, + { + name: "verify id token and oauth2 token expiry", + handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error { + expectedExpiry := now().Add(idTokensValidFor) + + timeEq := func(t1, t2 time.Time, within time.Duration) bool { + return t1.Sub(t2) < within + } + + if !timeEq(token.Expiry, expectedExpiry, time.Second) { + return fmt.Errorf("expected expired_in to be %s, got %s", expectedExpiry, token.Expiry) + } + + rawIDToken, ok := token.Extra("id_token").(string) + if !ok { + return fmt.Errorf("no id token found") + } + idToken, err := p.Verifier(oidcConfig).Verify(ctx, rawIDToken) + if err != nil { + return fmt.Errorf("failed to verify id token: %v", err) + } + if !timeEq(idToken.Expiry, expectedExpiry, time.Second) { + return fmt.Errorf("expected id token expiry to be %s, got %s", expectedExpiry, token.Expiry) + } + return nil + }, + }, + { + name: "verify at_hash", + handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error { + rawIDToken, ok := token.Extra("id_token").(string) + if !ok { + return fmt.Errorf("no id token found") + } + idToken, err := p.Verifier(oidcConfig).Verify(ctx, rawIDToken) + if err != nil { + return fmt.Errorf("failed to verify id token: %v", err) + } + + var claims struct { + AtHash string `json:"at_hash"` + } + if err := idToken.Claims(&claims); err != nil { + return fmt.Errorf("failed to decode raw claims: %v", err) + } + if claims.AtHash == "" { + return errors.New("no at_hash value in id_token") + } + wantAtHash, err := accessTokenHash(jose.RS256, token.AccessToken) + if err != nil { + return fmt.Errorf("computed expected at hash: %v", err) + } + if wantAtHash != claims.AtHash { + return fmt.Errorf("expected at_hash=%q got=%q", wantAtHash, claims.AtHash) + } + + return nil + }, + }, + { + name: "refresh token", + handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error { + // have to use time.Now because the OAuth2 package uses it. + token.Expiry = time.Now().Add(time.Second * -10) + if token.Valid() { + return errors.New("token shouldn't be valid") + } + + newToken, err := config.TokenSource(ctx, token).Token() + if err != nil { + return fmt.Errorf("failed to refresh token: %v", err) + } + if token.RefreshToken == newToken.RefreshToken { + return fmt.Errorf("old refresh token was the same as the new token %q", token.RefreshToken) + } + + if _, err := config.TokenSource(ctx, token).Token(); err == nil { + return errors.New("was able to redeem the same refresh token twice") + } + return nil + }, + }, + { + name: "refresh with explicit scopes", + handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error { + v := url.Values{} + v.Add("client_id", clientID) + v.Add("client_secret", clientSecret) + v.Add("grant_type", "refresh_token") + v.Add("refresh_token", token.RefreshToken) + v.Add("scope", strings.Join(requestedScopes, " ")) + resp, err := http.PostForm(p.Endpoint().TokenURL, v) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + dump, err := httputil.DumpResponse(resp, true) + if err != nil { + panic(err) + } + return fmt.Errorf("unexpected response: %s", dump) + } + return nil + }, + }, + { + name: "refresh with extra spaces", + handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error { + v := url.Values{} + v.Add("client_id", clientID) + v.Add("client_secret", clientSecret) + v.Add("grant_type", "refresh_token") + v.Add("refresh_token", token.RefreshToken) + + // go-oidc adds an additional space before scopes when refreshing. + // Since we support that client we choose to be more relaxed about + // scope parsing, disregarding extra whitespace. + v.Add("scope", " "+strings.Join(requestedScopes, " ")) + resp, err := http.PostForm(p.Endpoint().TokenURL, v) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + dump, err := httputil.DumpResponse(resp, true) + if err != nil { + panic(err) + } + return fmt.Errorf("unexpected response: %s", dump) + } + return nil + }, + }, + { + name: "refresh with unauthorized scopes", + scopes: []string{"openid", "email"}, + handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error { + v := url.Values{} + v.Add("client_id", clientID) + v.Add("client_secret", clientSecret) + v.Add("grant_type", "refresh_token") + v.Add("refresh_token", token.RefreshToken) + // Request a scope that wasn't requestd initially. + v.Add("scope", "oidc email profile") + resp, err := http.PostForm(p.Endpoint().TokenURL, v) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode == http.StatusOK { + dump, err := httputil.DumpResponse(resp, true) + if err != nil { + panic(err) + } + return fmt.Errorf("unexpected response: %s", dump) + } + return nil + }, + }, + { + // This test ensures that the connector.RefreshConnector interface is being + // used when clients request a refresh token. + name: "refresh with identity changes", + handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error { + // have to use time.Now because the OAuth2 package uses it. + token.Expiry = time.Now().Add(time.Second * -10) + if token.Valid() { + return errors.New("token shouldn't be valid") + } + + ident := connector.Identity{ + UserID: "fooid", + Username: "foo", + Email: "foo@bar.com", + EmailVerified: true, + Groups: []string{"foo", "bar"}, + } + conn.Identity = ident + + type claims struct { + Username string `json:"name"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + Groups []string `json:"groups"` + } + want := claims{ident.Username, ident.Email, ident.EmailVerified, ident.Groups} + + newToken, err := config.TokenSource(ctx, token).Token() + if err != nil { + return fmt.Errorf("failed to refresh token: %v", err) + } + rawIDToken, ok := newToken.Extra("id_token").(string) + if !ok { + return fmt.Errorf("no id_token in refreshed token") + } + idToken, err := p.Verifier(oidcConfig).Verify(ctx, rawIDToken) + if err != nil { + return fmt.Errorf("failed to verify id token: %v", err) + } + var got claims + if err := idToken.Claims(&got); err != nil { + return fmt.Errorf("failed to unmarshal claims: %v", err) + } + + if diff := pretty.Compare(want, got); diff != "" { + return fmt.Errorf("got identity != want identity: %s", diff) + } + return nil + }, + }, + }, + } +} + // TestOAuth2CodeFlow runs integration tests against a test server. The tests stand up a server // which requires no interaction to login, logs in through a test client, then passes the client // and returned token to the test. @@ -226,255 +496,8 @@ func TestOAuth2CodeFlow(t *testing.T) { // Connector used by the tests. var conn *mock.Callback - oidcConfig := &oidc.Config{SkipClientIDCheck: true} - - tests := []struct { - name string - // If specified these set of scopes will be used during the test case. - scopes []string - // handleToken provides the OAuth2 token response for the integration test. - handleToken func(context.Context, *oidc.Provider, *oauth2.Config, *oauth2.Token) error - }{ - { - name: "verify ID Token", - handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error { - idToken, ok := token.Extra("id_token").(string) - if !ok { - return fmt.Errorf("no id token found") - } - if _, err := p.Verifier(oidcConfig).Verify(ctx, idToken); err != nil { - return fmt.Errorf("failed to verify id token: %v", err) - } - return nil - }, - }, - { - name: "fetch userinfo", - handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error { - ui, err := p.UserInfo(ctx, config.TokenSource(ctx, token)) - if err != nil { - return fmt.Errorf("failed to fetch userinfo: %v", err) - } - if conn.Identity.Email != ui.Email { - return fmt.Errorf("expected email to be %v, got %v", conn.Identity.Email, ui.Email) - } - return nil - }, - }, - { - name: "verify id token and oauth2 token expiry", - handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error { - expectedExpiry := now().Add(idTokensValidFor) - - timeEq := func(t1, t2 time.Time, within time.Duration) bool { - return t1.Sub(t2) < within - } - - if !timeEq(token.Expiry, expectedExpiry, time.Second) { - return fmt.Errorf("expected expired_in to be %s, got %s", expectedExpiry, token.Expiry) - } - - rawIDToken, ok := token.Extra("id_token").(string) - if !ok { - return fmt.Errorf("no id token found") - } - idToken, err := p.Verifier(oidcConfig).Verify(ctx, rawIDToken) - if err != nil { - return fmt.Errorf("failed to verify id token: %v", err) - } - if !timeEq(idToken.Expiry, expectedExpiry, time.Second) { - return fmt.Errorf("expected id token expiry to be %s, got %s", expectedExpiry, token.Expiry) - } - return nil - }, - }, - { - name: "verify at_hash", - handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error { - rawIDToken, ok := token.Extra("id_token").(string) - if !ok { - return fmt.Errorf("no id token found") - } - idToken, err := p.Verifier(oidcConfig).Verify(ctx, rawIDToken) - if err != nil { - return fmt.Errorf("failed to verify id token: %v", err) - } - - var claims struct { - AtHash string `json:"at_hash"` - } - if err := idToken.Claims(&claims); err != nil { - return fmt.Errorf("failed to decode raw claims: %v", err) - } - if claims.AtHash == "" { - return errors.New("no at_hash value in id_token") - } - wantAtHash, err := accessTokenHash(jose.RS256, token.AccessToken) - if err != nil { - return fmt.Errorf("computed expected at hash: %v", err) - } - if wantAtHash != claims.AtHash { - return fmt.Errorf("expected at_hash=%q got=%q", wantAtHash, claims.AtHash) - } - - return nil - }, - }, - { - name: "refresh token", - handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error { - // have to use time.Now because the OAuth2 package uses it. - token.Expiry = time.Now().Add(time.Second * -10) - if token.Valid() { - return errors.New("token shouldn't be valid") - } - - newToken, err := config.TokenSource(ctx, token).Token() - if err != nil { - return fmt.Errorf("failed to refresh token: %v", err) - } - if token.RefreshToken == newToken.RefreshToken { - return fmt.Errorf("old refresh token was the same as the new token %q", token.RefreshToken) - } - - if _, err := config.TokenSource(ctx, token).Token(); err == nil { - return errors.New("was able to redeem the same refresh token twice") - } - return nil - }, - }, - { - name: "refresh with explicit scopes", - handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error { - v := url.Values{} - v.Add("client_id", clientID) - v.Add("client_secret", clientSecret) - v.Add("grant_type", "refresh_token") - v.Add("refresh_token", token.RefreshToken) - v.Add("scope", strings.Join(requestedScopes, " ")) - resp, err := http.PostForm(p.Endpoint().TokenURL, v) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - dump, err := httputil.DumpResponse(resp, true) - if err != nil { - panic(err) - } - return fmt.Errorf("unexpected response: %s", dump) - } - return nil - }, - }, - { - name: "refresh with extra spaces", - handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error { - v := url.Values{} - v.Add("client_id", clientID) - v.Add("client_secret", clientSecret) - v.Add("grant_type", "refresh_token") - v.Add("refresh_token", token.RefreshToken) - - // go-oidc adds an additional space before scopes when refreshing. - // Since we support that client we choose to be more relaxed about - // scope parsing, disregarding extra whitespace. - v.Add("scope", " "+strings.Join(requestedScopes, " ")) - resp, err := http.PostForm(p.Endpoint().TokenURL, v) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - dump, err := httputil.DumpResponse(resp, true) - if err != nil { - panic(err) - } - return fmt.Errorf("unexpected response: %s", dump) - } - return nil - }, - }, - { - name: "refresh with unauthorized scopes", - scopes: []string{"openid", "email"}, - handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error { - v := url.Values{} - v.Add("client_id", clientID) - v.Add("client_secret", clientSecret) - v.Add("grant_type", "refresh_token") - v.Add("refresh_token", token.RefreshToken) - // Request a scope that wasn't requestd initially. - v.Add("scope", "oidc email profile") - resp, err := http.PostForm(p.Endpoint().TokenURL, v) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode == http.StatusOK { - dump, err := httputil.DumpResponse(resp, true) - if err != nil { - panic(err) - } - return fmt.Errorf("unexpected response: %s", dump) - } - return nil - }, - }, - { - // This test ensures that the connector.RefreshConnector interface is being - // used when clients request a refresh token. - name: "refresh with identity changes", - handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error { - // have to use time.Now because the OAuth2 package uses it. - token.Expiry = time.Now().Add(time.Second * -10) - if token.Valid() { - return errors.New("token shouldn't be valid") - } - - ident := connector.Identity{ - UserID: "fooid", - Username: "foo", - Email: "foo@bar.com", - EmailVerified: true, - Groups: []string{"foo", "bar"}, - } - conn.Identity = ident - - type claims struct { - Username string `json:"name"` - Email string `json:"email"` - EmailVerified bool `json:"email_verified"` - Groups []string `json:"groups"` - } - want := claims{ident.Username, ident.Email, ident.EmailVerified, ident.Groups} - - newToken, err := config.TokenSource(ctx, token).Token() - if err != nil { - return fmt.Errorf("failed to refresh token: %v", err) - } - rawIDToken, ok := newToken.Extra("id_token").(string) - if !ok { - return fmt.Errorf("no id_token in refreshed token") - } - idToken, err := p.Verifier(oidcConfig).Verify(ctx, rawIDToken) - if err != nil { - return fmt.Errorf("failed to verify id token: %v", err) - } - var got claims - if err := idToken.Claims(&got); err != nil { - return fmt.Errorf("failed to unmarshal claims: %v", err) - } - - if diff := pretty.Compare(want, got); diff != "" { - return fmt.Errorf("got identity != want identity: %s", diff) - } - return nil - }, - }, - } - - for _, tc := range tests { + tests := makeOAuth2Tests(clientID, clientSecret, now) + for _, tc := range tests.tests { func() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -540,7 +563,7 @@ func TestOAuth2CodeFlow(t *testing.T) { t.Errorf("failed to exchange code for token: %v", err) return } - err = tc.handleToken(ctx, p, oauth2Config, token) + err = tc.handleToken(ctx, p, oauth2Config, token, conn) if err != nil { t.Errorf("%s: %v", tc.name, err) } @@ -1253,3 +1276,157 @@ func TestRefreshTokenFlow(t *testing.T) { t.Errorf("Token refreshed with invalid refresh token, error expected.") } } + +// TestOAuth2DeviceFlow runs device flow integration tests against a test server +func TestOAuth2DeviceFlow(t *testing.T) { + clientID := "testclient" + clientSecret := "" + requestedScopes := []string{oidc.ScopeOpenID, "email", "profile", "groups", "offline_access"} + + t0 := time.Now() + + // Always have the time function used by the server return the same time so + // we can predict expected values of "expires_in" fields exactly. + now := func() time.Time { return t0 } + + // Connector used by the tests. + var conn *mock.Callback + idTokensValidFor := time.Second * 30 + + for _, tc := range makeOAuth2Tests(clientID, clientSecret, now).tests { + func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Setup a dex server. + httpServer, s := newTestServer(ctx, t, func(c *Config) { + c.Issuer = c.Issuer + "/non-root-path" + c.Now = now + c.IDTokensValidFor = idTokensValidFor + }) + defer httpServer.Close() + + mockConn := s.connectors["mock"] + conn = mockConn.Connector.(*mock.Callback) + + p, err := oidc.NewProvider(ctx, httpServer.URL) + if err != nil { + t.Fatalf("failed to get provider: %v", err) + } + + //Add the Clients to the test server + client := storage.Client{ + ID: clientID, + RedirectURIs: []string{deviceCallbackURI}, + Public: true, + } + if err := s.storage.CreateClient(client); err != nil { + t.Fatalf("failed to create client: %v", err) + } + + //Grab the issuer that we'll reuse for the different endpoints to hit + issuer, err := url.Parse(s.issuerURL.String()) + if err != nil { + t.Errorf("Could not parse issuer URL %v", err) + } + + //Send a new Device Request + codeURL, _ := url.Parse(issuer.String()) + codeURL.Path = path.Join(codeURL.Path, "device/code") + + data := url.Values{} + data.Set("client_id", clientID) + data.Add("scope", strings.Join(requestedScopes, " ")) + resp, err := http.PostForm(codeURL.String(), data) + if err != nil { + t.Errorf("Could not request device code: %v", err) + } + defer resp.Body.Close() + responseBody, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Errorf("Could read device code response %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("%v - Unexpected Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody)) + } + + //Parse the code response + var deviceCode deviceCodeResponse + if err := json.Unmarshal(responseBody, &deviceCode); err != nil { + t.Errorf("Unexpected Device Code Response Format %v", string(responseBody)) + } + + //Mock the user hitting the verification URI and posting the form + verifyURL, _ := url.Parse(issuer.String()) + verifyURL.Path = path.Join(verifyURL.Path, "/device/auth/verify_code") + urlData := url.Values{} + urlData.Set("user_code", deviceCode.UserCode) + resp, err = http.PostForm(verifyURL.String(), urlData) + if err != nil { + t.Errorf("Error Posting Form: %v", err) + } + defer resp.Body.Close() + responseBody, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Errorf("Could read verification response %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("%v - Unexpected Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody)) + } + + //Hit the Token Endpoint, and try and get an access token + tokenURL, _ := url.Parse(issuer.String()) + tokenURL.Path = path.Join(tokenURL.Path, "/device/token") + v := url.Values{} + v.Add("grant_type", grantTypeDeviceCode) + v.Add("device_code", deviceCode.DeviceCode) + resp, err = http.PostForm(tokenURL.String(), v) + if err != nil { + t.Errorf("Could not request device token: %v", err) + } + defer resp.Body.Close() + responseBody, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Errorf("Could read device token response %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("%v - Unexpected Token Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody)) + } + + //Parse the response + var tokenRes accessTokenReponse + if err := json.Unmarshal(responseBody, &tokenRes); err != nil { + t.Errorf("Unexpected Device Access Token Response Format %v", string(responseBody)) + } + + token := &oauth2.Token{ + AccessToken: tokenRes.AccessToken, + TokenType: tokenRes.TokenType, + RefreshToken: tokenRes.RefreshToken, + } + raw := make(map[string]interface{}) + json.Unmarshal(responseBody, &raw) // no error checks for optional fields + token = token.WithExtra(raw) + if secs := tokenRes.ExpiresIn; secs > 0 { + token.Expiry = time.Now().Add(time.Duration(secs) * time.Second) + } + + //Run token tests to validate info is correct + // Create the OAuth2 config. + oauth2Config := &oauth2.Config{ + ClientID: client.ID, + ClientSecret: client.Secret, + Endpoint: p.Endpoint(), + Scopes: requestedScopes, + RedirectURL: deviceCallbackURI, + } + if len(tc.scopes) != 0 { + oauth2Config.Scopes = tc.scopes + } + err = tc.handleToken(ctx, p, oauth2Config, token, conn) + if err != nil { + t.Errorf("%s: %v", tc.name, err) + } + }() + } +} diff --git a/server/templates.go b/server/templates.go index 4947a102..6681b851 100644 --- a/server/templates.go +++ b/server/templates.go @@ -15,11 +15,13 @@ import ( ) const ( - tmplApproval = "approval.html" - tmplLogin = "login.html" - tmplPassword = "password.html" - tmplOOB = "oob.html" - tmplError = "error.html" + tmplApproval = "approval.html" + tmplLogin = "login.html" + tmplPassword = "password.html" + tmplOOB = "oob.html" + tmplError = "error.html" + tmplDevice = "device.html" + tmplDeviceSuccess = "device_success.html" ) var requiredTmpls = []string{ @@ -28,14 +30,18 @@ var requiredTmpls = []string{ tmplPassword, tmplOOB, tmplError, + tmplDevice, + tmplDeviceSuccess, } type templates struct { - loginTmpl *template.Template - approvalTmpl *template.Template - passwordTmpl *template.Template - oobTmpl *template.Template - errorTmpl *template.Template + loginTmpl *template.Template + approvalTmpl *template.Template + passwordTmpl *template.Template + oobTmpl *template.Template + errorTmpl *template.Template + deviceTmpl *template.Template + deviceSuccessTmpl *template.Template } type webConfig struct { @@ -152,11 +158,13 @@ func loadTemplates(c webConfig, templatesDir string) (*templates, error) { return nil, fmt.Errorf("missing template(s): %s", missingTmpls) } return &templates{ - loginTmpl: tmpls.Lookup(tmplLogin), - approvalTmpl: tmpls.Lookup(tmplApproval), - passwordTmpl: tmpls.Lookup(tmplPassword), - oobTmpl: tmpls.Lookup(tmplOOB), - errorTmpl: tmpls.Lookup(tmplError), + loginTmpl: tmpls.Lookup(tmplLogin), + approvalTmpl: tmpls.Lookup(tmplApproval), + passwordTmpl: tmpls.Lookup(tmplPassword), + oobTmpl: tmpls.Lookup(tmplOOB), + errorTmpl: tmpls.Lookup(tmplError), + deviceTmpl: tmpls.Lookup(tmplDevice), + deviceSuccessTmpl: tmpls.Lookup(tmplDeviceSuccess), }, nil } @@ -242,6 +250,27 @@ func (n byName) Len() int { return len(n) } func (n byName) Less(i, j int) bool { return n[i].Name < n[j].Name } func (n byName) Swap(i, j int) { n[i], n[j] = n[j], n[i] } +func (t *templates) device(r *http.Request, w http.ResponseWriter, postURL string, userCode string, lastWasInvalid bool) error { + if lastWasInvalid { + w.WriteHeader(http.StatusBadRequest) + } + data := struct { + PostURL string + UserCode string + Invalid bool + ReqPath string + }{postURL, userCode, lastWasInvalid, r.URL.Path} + return renderTemplate(w, t.deviceTmpl, data) +} + +func (t *templates) deviceSuccess(r *http.Request, w http.ResponseWriter, clientName string) error { + data := struct { + ClientName string + ReqPath string + }{clientName, r.URL.Path} + return renderTemplate(w, t.deviceSuccessTmpl, data) +} + func (t *templates) login(r *http.Request, w http.ResponseWriter, connectors []connectorInfo, reqPath string) error { sort.Sort(byName(connectors)) data := struct { diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index 1ac51fc8..a550a530 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -49,6 +49,8 @@ func RunTests(t *testing.T, newStorage func() storage.Storage) { {"ConnectorCRUD", testConnectorCRUD}, {"GarbageCollection", testGC}, {"TimezoneSupport", testTimezones}, + {"DeviceRequestCRUD", testDeviceRequestCRUD}, + {"DeviceTokenCRUD", testDeviceTokenCRUD}, }) } @@ -834,6 +836,87 @@ func testGC(t *testing.T, s storage.Storage) { } else if err != storage.ErrNotFound { t.Errorf("expected storage.ErrNotFound, got %v", err) } + + userCode, err := storage.NewUserCode() + if err != nil { + t.Errorf("Unexpected Error: %v", err) + } + + d := storage.DeviceRequest{ + UserCode: userCode, + DeviceCode: storage.NewID(), + ClientID: "client1", + ClientSecret: "secret1", + Scopes: []string{"openid", "email"}, + Expiry: expiry, + } + + if err := s.CreateDeviceRequest(d); err != nil { + t.Fatalf("failed creating device request: %v", err) + } + + for _, tz := range []*time.Location{time.UTC, est, pst} { + result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz)) + if err != nil { + t.Errorf("garbage collection failed: %v", err) + } else { + if result.DeviceRequests != 0 { + t.Errorf("expected no device garbage collection results, got %#v", result) + } + } + if _, err := s.GetDeviceRequest(d.UserCode); err != nil { + t.Errorf("expected to be able to get auth request after GC: %v", err) + } + } + if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil { + t.Errorf("garbage collection failed: %v", err) + } else if r.DeviceRequests != 1 { + t.Errorf("expected to garbage collect 1 device request, got %d", r.DeviceRequests) + } + + if _, err := s.GetDeviceRequest(d.UserCode); err == nil { + t.Errorf("expected device request to be GC'd") + } else if err != storage.ErrNotFound { + t.Errorf("expected storage.ErrNotFound, got %v", err) + } + + dt := storage.DeviceToken{ + DeviceCode: storage.NewID(), + Status: "pending", + Token: "foo", + Expiry: expiry, + LastRequestTime: time.Now(), + PollIntervalSeconds: 0, + } + + if err := s.CreateDeviceToken(dt); err != nil { + t.Fatalf("failed creating device token: %v", err) + } + + for _, tz := range []*time.Location{time.UTC, est, pst} { + result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz)) + if err != nil { + t.Errorf("garbage collection failed: %v", err) + } else { + if result.DeviceTokens != 0 { + t.Errorf("expected no device token garbage collection results, got %#v", result) + } + } + if _, err := s.GetDeviceToken(dt.DeviceCode); err != nil { + t.Errorf("expected to be able to get device token after GC: %v", err) + } + } + if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil { + t.Errorf("garbage collection failed: %v", err) + } else if r.DeviceTokens != 1 { + t.Errorf("expected to garbage collect 1 device token, got %d", r.DeviceTokens) + } + + if _, err := s.GetDeviceToken(dt.DeviceCode); err == nil { + t.Errorf("expected device token to be GC'd") + } else if err != storage.ErrNotFound { + t.Errorf("expected storage.ErrNotFound, got %v", err) + } } // testTimezones tests that backends either fully support timezones or @@ -881,3 +964,72 @@ func testTimezones(t *testing.T, s storage.Storage) { t.Fatalf("expected expiry %v got %v", wantTime, gotTime) } } + +func testDeviceRequestCRUD(t *testing.T, s storage.Storage) { + userCode, err := storage.NewUserCode() + if err != nil { + panic(err) + } + d1 := storage.DeviceRequest{ + UserCode: userCode, + DeviceCode: storage.NewID(), + ClientID: "client1", + ClientSecret: "secret1", + Scopes: []string{"openid", "email"}, + Expiry: neverExpire, + } + + if err := s.CreateDeviceRequest(d1); err != nil { + t.Fatalf("failed creating device request: %v", err) + } + + // Attempt to create same DeviceRequest twice. + err = s.CreateDeviceRequest(d1) + mustBeErrAlreadyExists(t, "device request", err) + + //No manual deletes for device requests, will be handled by garbage collection routines + //see testGC +} + +func testDeviceTokenCRUD(t *testing.T, s storage.Storage) { + //Create a Token + d1 := storage.DeviceToken{ + DeviceCode: storage.NewID(), + Status: "pending", + Token: storage.NewID(), + Expiry: neverExpire, + LastRequestTime: time.Now(), + PollIntervalSeconds: 0, + } + + if err := s.CreateDeviceToken(d1); err != nil { + t.Fatalf("failed creating device token: %v", err) + } + + // Attempt to create same Device Token twice. + err := s.CreateDeviceToken(d1) + mustBeErrAlreadyExists(t, "device token", err) + + //Update the device token, simulate a redemption + if err := s.UpdateDeviceToken(d1.DeviceCode, func(old storage.DeviceToken) (storage.DeviceToken, error) { + old.Token = "token data" + old.Status = "complete" + return old, nil + }); err != nil { + t.Fatalf("failed to update device token: %v", err) + } + + //Retrieve the device token + got, err := s.GetDeviceToken(d1.DeviceCode) + if err != nil { + t.Fatalf("failed to get device token: %v", err) + } + + //Validate expected result set + if got.Status != "complete" { + t.Fatalf("update failed, wanted token status=%v got %v", "complete", got.Status) + } + if got.Token != "token data" { + t.Fatalf("update failed, wanted token %v got %v", "token data", got.Token) + } +} diff --git a/storage/etcd/etcd.go b/storage/etcd/etcd.go index e26ce760..e8abe3d0 100644 --- a/storage/etcd/etcd.go +++ b/storage/etcd/etcd.go @@ -22,6 +22,8 @@ const ( offlineSessionPrefix = "offline_session/" connectorPrefix = "connector/" keysName = "openid-connect-keys" + deviceRequestPrefix = "device_req/" + deviceTokenPrefix = "device_token/" // defaultStorageTimeout will be applied to all storage's operations. defaultStorageTimeout = 5 * time.Second @@ -72,6 +74,36 @@ func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error result.AuthCodes++ } } + + deviceRequests, err := c.listDeviceRequests(ctx) + if err != nil { + return result, err + } + + for _, deviceRequest := range deviceRequests { + if now.After(deviceRequest.Expiry) { + if err := c.deleteKey(ctx, keyID(deviceRequestPrefix, deviceRequest.UserCode)); err != nil { + c.logger.Errorf("failed to delete device request %v", err) + delErr = fmt.Errorf("failed to delete device request: %v", err) + } + result.DeviceRequests++ + } + } + + deviceTokens, err := c.listDeviceTokens(ctx) + if err != nil { + return result, err + } + + for _, deviceToken := range deviceTokens { + if now.After(deviceToken.Expiry) { + if err := c.deleteKey(ctx, keyID(deviceTokenPrefix, deviceToken.DeviceCode)); err != nil { + c.logger.Errorf("failed to delete device token %v", err) + delErr = fmt.Errorf("failed to delete device token: %v", err) + } + result.DeviceTokens++ + } + } return result, delErr } @@ -531,3 +563,77 @@ func keyEmail(prefix, email string) string { return prefix + strings.ToLower(ema func keySession(prefix, userID, connID string) string { return prefix + strings.ToLower(userID+"|"+connID) } + +func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + return c.txnCreate(ctx, keyID(deviceRequestPrefix, d.UserCode), fromStorageDeviceRequest(d)) +} + +func (c *conn) GetDeviceRequest(userCode string) (r storage.DeviceRequest, err error) { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + err = c.getKey(ctx, keyID(deviceRequestPrefix, userCode), &r) + return r, err +} + +func (c *conn) listDeviceRequests(ctx context.Context) (requests []DeviceRequest, err error) { + res, err := c.db.Get(ctx, deviceRequestPrefix, clientv3.WithPrefix()) + if err != nil { + return requests, err + } + for _, v := range res.Kvs { + var r DeviceRequest + if err = json.Unmarshal(v.Value, &r); err != nil { + return requests, err + } + requests = append(requests, r) + } + return requests, nil +} + +func (c *conn) CreateDeviceToken(t storage.DeviceToken) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + return c.txnCreate(ctx, keyID(deviceTokenPrefix, t.DeviceCode), fromStorageDeviceToken(t)) +} + +func (c *conn) GetDeviceToken(deviceCode string) (t storage.DeviceToken, err error) { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + err = c.getKey(ctx, keyID(deviceTokenPrefix, deviceCode), &t) + return t, err +} + +func (c *conn) listDeviceTokens(ctx context.Context) (deviceTokens []DeviceToken, err error) { + res, err := c.db.Get(ctx, deviceTokenPrefix, clientv3.WithPrefix()) + if err != nil { + return deviceTokens, err + } + for _, v := range res.Kvs { + var dt DeviceToken + if err = json.Unmarshal(v.Value, &dt); err != nil { + return deviceTokens, err + } + deviceTokens = append(deviceTokens, dt) + } + return deviceTokens, nil +} + +func (c *conn) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + return c.txnUpdate(ctx, keyID(deviceTokenPrefix, deviceCode), func(currentValue []byte) ([]byte, error) { + var current DeviceToken + if len(currentValue) > 0 { + if err := json.Unmarshal(currentValue, ¤t); err != nil { + return nil, err + } + } + updated, err := updater(toStorageDeviceToken(current)) + if err != nil { + return nil, err + } + return json.Marshal(fromStorageDeviceToken(updated)) + }) +} diff --git a/storage/etcd/etcd_test.go b/storage/etcd/etcd_test.go index 4c17fdf1..122d7dae 100644 --- a/storage/etcd/etcd_test.go +++ b/storage/etcd/etcd_test.go @@ -44,6 +44,8 @@ func cleanDB(c *conn) error { passwordPrefix, offlineSessionPrefix, connectorPrefix, + deviceRequestPrefix, + deviceTokenPrefix, } { _, err := c.db.Delete(ctx, prefix, clientv3.WithPrefix()) if err != nil { diff --git a/storage/etcd/types.go b/storage/etcd/types.go index a16eae8e..def95b55 100644 --- a/storage/etcd/types.go +++ b/storage/etcd/types.go @@ -216,3 +216,56 @@ func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions { } return s } + +// DeviceRequest is a mirrored struct from storage with JSON struct tags +type DeviceRequest struct { + UserCode string `json:"user_code"` + DeviceCode string `json:"device_code"` + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + Scopes []string `json:"scopes"` + Expiry time.Time `json:"expiry"` +} + +func fromStorageDeviceRequest(d storage.DeviceRequest) DeviceRequest { + return DeviceRequest{ + UserCode: d.UserCode, + DeviceCode: d.DeviceCode, + ClientID: d.ClientID, + ClientSecret: d.ClientSecret, + Scopes: d.Scopes, + Expiry: d.Expiry, + } +} + +// DeviceToken is a mirrored struct from storage with JSON struct tags +type DeviceToken struct { + DeviceCode string `json:"device_code"` + Status string `json:"status"` + Token string `json:"token"` + Expiry time.Time `json:"expiry"` + LastRequestTime time.Time `json:"last_request"` + PollIntervalSeconds int `json:"poll_interval"` +} + +func fromStorageDeviceToken(t storage.DeviceToken) DeviceToken { + return DeviceToken{ + DeviceCode: t.DeviceCode, + Status: t.Status, + Token: t.Token, + Expiry: t.Expiry, + LastRequestTime: t.LastRequestTime, + PollIntervalSeconds: t.PollIntervalSeconds, + } +} + +func toStorageDeviceToken(t DeviceToken) storage.DeviceToken { + return storage.DeviceToken{ + DeviceCode: t.DeviceCode, + Status: t.Status, + Token: t.Token, + Expiry: t.Expiry, + LastRequestTime: t.LastRequestTime, + PollIntervalSeconds: t.PollIntervalSeconds, + } +} diff --git a/storage/kubernetes/storage.go b/storage/kubernetes/storage.go index 4bdf3dd6..baf1d567 100644 --- a/storage/kubernetes/storage.go +++ b/storage/kubernetes/storage.go @@ -21,6 +21,8 @@ const ( kindPassword = "Password" kindOfflineSessions = "OfflineSessions" kindConnector = "Connector" + kindDeviceRequest = "DeviceRequest" + kindDeviceToken = "DeviceToken" ) const ( @@ -32,6 +34,8 @@ const ( resourcePassword = "passwords" resourceOfflineSessions = "offlinesessionses" // Again attempts to pluralize. resourceConnector = "connectors" + resourceDeviceRequest = "devicerequests" + resourceDeviceToken = "devicetokens" ) // Config values for the Kubernetes storage type. @@ -593,5 +597,84 @@ func (cli *client) GarbageCollect(now time.Time) (result storage.GCResult, err e result.AuthCodes++ } } + + var deviceRequests DeviceRequestList + if err := cli.list(resourceDeviceRequest, &deviceRequests); err != nil { + return result, fmt.Errorf("failed to list device requests: %v", err) + } + + for _, deviceRequest := range deviceRequests.DeviceRequests { + if now.After(deviceRequest.Expiry) { + if err := cli.delete(resourceDeviceRequest, deviceRequest.ObjectMeta.Name); err != nil { + cli.logger.Errorf("failed to delete device request: %v", err) + delErr = fmt.Errorf("failed to delete device request: %v", err) + } + result.DeviceRequests++ + } + } + + var deviceTokens DeviceTokenList + if err := cli.list(resourceDeviceToken, &deviceTokens); err != nil { + return result, fmt.Errorf("failed to list device tokens: %v", err) + } + + for _, deviceToken := range deviceTokens.DeviceTokens { + if now.After(deviceToken.Expiry) { + if err := cli.delete(resourceDeviceToken, deviceToken.ObjectMeta.Name); err != nil { + cli.logger.Errorf("failed to delete device token: %v", err) + delErr = fmt.Errorf("failed to delete device token: %v", err) + } + result.DeviceTokens++ + } + } + + if delErr != nil { + return result, delErr + } return result, delErr } + +func (cli *client) CreateDeviceRequest(d storage.DeviceRequest) error { + return cli.post(resourceDeviceRequest, cli.fromStorageDeviceRequest(d)) +} + +func (cli *client) GetDeviceRequest(userCode string) (storage.DeviceRequest, error) { + var req DeviceRequest + if err := cli.get(resourceDeviceRequest, strings.ToLower(userCode), &req); err != nil { + return storage.DeviceRequest{}, err + } + return toStorageDeviceRequest(req), nil +} + +func (cli *client) CreateDeviceToken(t storage.DeviceToken) error { + return cli.post(resourceDeviceToken, cli.fromStorageDeviceToken(t)) +} + +func (cli *client) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) { + var token DeviceToken + if err := cli.get(resourceDeviceToken, deviceCode, &token); err != nil { + return storage.DeviceToken{}, err + } + return toStorageDeviceToken(token), nil +} + +func (cli *client) getDeviceToken(deviceCode string) (t DeviceToken, err error) { + err = cli.get(resourceDeviceToken, deviceCode, &t) + return +} + +func (cli *client) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error { + r, err := cli.getDeviceToken(deviceCode) + if err != nil { + return err + } + updated, err := updater(toStorageDeviceToken(r)) + if err != nil { + return err + } + updated.DeviceCode = deviceCode + + newToken := cli.fromStorageDeviceToken(updated) + newToken.ObjectMeta = r.ObjectMeta + return cli.put(resourceDeviceToken, r.ObjectMeta.Name, newToken) +} diff --git a/storage/kubernetes/storage_test.go b/storage/kubernetes/storage_test.go index ea471427..2c9deeb2 100644 --- a/storage/kubernetes/storage_test.go +++ b/storage/kubernetes/storage_test.go @@ -85,6 +85,8 @@ func (s *StorageTestSuite) TestStorage() { for _, resource := range []string{ resourceAuthCode, resourceAuthRequest, + resourceDeviceRequest, + resourceDeviceToken, resourceClient, resourceRefreshToken, resourceKeys, diff --git a/storage/kubernetes/types.go b/storage/kubernetes/types.go index 0fbb2907..f856a731 100644 --- a/storage/kubernetes/types.go +++ b/storage/kubernetes/types.go @@ -143,6 +143,36 @@ var customResourceDefinitions = []k8sapi.CustomResourceDefinition{ }, }, }, + { + ObjectMeta: k8sapi.ObjectMeta{ + Name: "devicerequests.dex.coreos.com", + }, + TypeMeta: crdMeta, + Spec: k8sapi.CustomResourceDefinitionSpec{ + Group: apiGroup, + Version: "v1", + Names: k8sapi.CustomResourceDefinitionNames{ + Plural: "devicerequests", + Singular: "devicerequest", + Kind: "DeviceRequest", + }, + }, + }, + { + ObjectMeta: k8sapi.ObjectMeta{ + Name: "devicetokens.dex.coreos.com", + }, + TypeMeta: crdMeta, + Spec: k8sapi.CustomResourceDefinitionSpec{ + Group: apiGroup, + Version: "v1", + Names: k8sapi.CustomResourceDefinitionNames{ + Plural: "devicetokens", + Singular: "devicetoken", + Kind: "DeviceToken", + }, + }, + }, } // There will only ever be a single keys resource. Maintain this by setting a @@ -635,3 +665,103 @@ type ConnectorList struct { k8sapi.ListMeta `json:"metadata,omitempty"` Connectors []Connector `json:"items"` } + +// DeviceRequest is a mirrored struct from storage with JSON struct tags and +// Kubernetes type metadata. +type DeviceRequest struct { + k8sapi.TypeMeta `json:",inline"` + k8sapi.ObjectMeta `json:"metadata,omitempty"` + + DeviceCode string `json:"device_code,omitempty"` + ClientID string `json:"client_id,omitempty"` + ClientSecret string `json:"client_secret,omitempty"` + Scopes []string `json:"scopes,omitempty"` + Expiry time.Time `json:"expiry"` +} + +// AuthRequestList is a list of AuthRequests. +type DeviceRequestList struct { + k8sapi.TypeMeta `json:",inline"` + k8sapi.ListMeta `json:"metadata,omitempty"` + DeviceRequests []DeviceRequest `json:"items"` +} + +func (cli *client) fromStorageDeviceRequest(a storage.DeviceRequest) DeviceRequest { + req := DeviceRequest{ + TypeMeta: k8sapi.TypeMeta{ + Kind: kindDeviceRequest, + APIVersion: cli.apiVersion, + }, + ObjectMeta: k8sapi.ObjectMeta{ + Name: strings.ToLower(a.UserCode), + Namespace: cli.namespace, + }, + DeviceCode: a.DeviceCode, + ClientID: a.ClientID, + ClientSecret: a.ClientSecret, + Scopes: a.Scopes, + Expiry: a.Expiry, + } + return req +} + +func toStorageDeviceRequest(req DeviceRequest) storage.DeviceRequest { + return storage.DeviceRequest{ + UserCode: strings.ToUpper(req.ObjectMeta.Name), + DeviceCode: req.DeviceCode, + ClientID: req.ClientID, + ClientSecret: req.ClientSecret, + Scopes: req.Scopes, + Expiry: req.Expiry, + } +} + +// DeviceToken is a mirrored struct from storage with JSON struct tags and +// Kubernetes type metadata. +type DeviceToken struct { + k8sapi.TypeMeta `json:",inline"` + k8sapi.ObjectMeta `json:"metadata,omitempty"` + + Status string `json:"status,omitempty"` + Token string `json:"token,omitempty"` + Expiry time.Time `json:"expiry"` + LastRequestTime time.Time `json:"last_request"` + PollIntervalSeconds int `json:"poll_interval"` +} + +// DeviceTokenList is a list of DeviceTokens. +type DeviceTokenList struct { + k8sapi.TypeMeta `json:",inline"` + k8sapi.ListMeta `json:"metadata,omitempty"` + DeviceTokens []DeviceToken `json:"items"` +} + +func (cli *client) fromStorageDeviceToken(t storage.DeviceToken) DeviceToken { + req := DeviceToken{ + TypeMeta: k8sapi.TypeMeta{ + Kind: kindDeviceToken, + APIVersion: cli.apiVersion, + }, + ObjectMeta: k8sapi.ObjectMeta{ + Name: t.DeviceCode, + Namespace: cli.namespace, + }, + Status: t.Status, + Token: t.Token, + Expiry: t.Expiry, + LastRequestTime: t.LastRequestTime, + PollIntervalSeconds: t.PollIntervalSeconds, + } + return req +} + +func toStorageDeviceToken(t DeviceToken) storage.DeviceToken { + return storage.DeviceToken{ + DeviceCode: t.ObjectMeta.Name, + Status: t.Status, + Token: t.Token, + Expiry: t.Expiry, + LastRequestTime: t.LastRequestTime, + PollIntervalSeconds: t.PollIntervalSeconds, + } +} diff --git a/storage/memory/memory.go b/storage/memory/memory.go index 681d204e..82264205 100644 --- a/storage/memory/memory.go +++ b/storage/memory/memory.go @@ -20,6 +20,8 @@ func New(logger log.Logger) storage.Storage { passwords: make(map[string]storage.Password), offlineSessions: make(map[offlineSessionID]storage.OfflineSessions), connectors: make(map[string]storage.Connector), + deviceRequests: make(map[string]storage.DeviceRequest), + deviceTokens: make(map[string]storage.DeviceToken), logger: logger, } } @@ -46,6 +48,8 @@ type memStorage struct { passwords map[string]storage.Password offlineSessions map[offlineSessionID]storage.OfflineSessions connectors map[string]storage.Connector + deviceRequests map[string]storage.DeviceRequest + deviceTokens map[string]storage.DeviceToken keys storage.Keys @@ -79,6 +83,18 @@ func (s *memStorage) GarbageCollect(now time.Time) (result storage.GCResult, err result.AuthRequests++ } } + for id, a := range s.deviceRequests { + if now.After(a.Expiry) { + delete(s.deviceRequests, id) + result.DeviceRequests++ + } + } + for id, a := range s.deviceTokens { + if now.After(a.Expiry) { + delete(s.deviceTokens, id) + result.DeviceTokens++ + } + } }) return result, nil } @@ -465,3 +481,61 @@ func (s *memStorage) UpdateConnector(id string, updater func(c storage.Connector }) return } + +func (s *memStorage) CreateDeviceRequest(d storage.DeviceRequest) (err error) { + s.tx(func() { + if _, ok := s.deviceRequests[d.UserCode]; ok { + err = storage.ErrAlreadyExists + } else { + s.deviceRequests[d.UserCode] = d + } + }) + return +} + +func (s *memStorage) GetDeviceRequest(userCode string) (req storage.DeviceRequest, err error) { + s.tx(func() { + var ok bool + if req, ok = s.deviceRequests[userCode]; !ok { + err = storage.ErrNotFound + return + } + }) + return +} + +func (s *memStorage) CreateDeviceToken(t storage.DeviceToken) (err error) { + s.tx(func() { + if _, ok := s.deviceTokens[t.DeviceCode]; ok { + err = storage.ErrAlreadyExists + } else { + s.deviceTokens[t.DeviceCode] = t + } + }) + return +} + +func (s *memStorage) GetDeviceToken(deviceCode string) (t storage.DeviceToken, err error) { + s.tx(func() { + var ok bool + if t, ok = s.deviceTokens[deviceCode]; !ok { + err = storage.ErrNotFound + return + } + }) + return +} + +func (s *memStorage) UpdateDeviceToken(deviceCode string, updater func(p storage.DeviceToken) (storage.DeviceToken, error)) (err error) { + s.tx(func() { + r, ok := s.deviceTokens[deviceCode] + if !ok { + err = storage.ErrNotFound + return + } + if r, err = updater(r); err == nil { + s.deviceTokens[deviceCode] = r + } + }) + return +} diff --git a/storage/sql/crud.go b/storage/sql/crud.go index e87dc56a..b74b76e1 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -100,6 +100,23 @@ func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error if n, err := r.RowsAffected(); err == nil { result.AuthCodes = n } + + r, err = c.Exec(`delete from device_request where expiry < $1`, now) + if err != nil { + return result, fmt.Errorf("gc device_request: %v", err) + } + if n, err := r.RowsAffected(); err == nil { + result.DeviceRequests = n + } + + r, err = c.Exec(`delete from device_token where expiry < $1`, now) + if err != nil { + return result, fmt.Errorf("gc device_token: %v", err) + } + if n, err := r.RowsAffected(); err == nil { + result.DeviceTokens = n + } + return } @@ -867,3 +884,113 @@ func (c *conn) delete(table, field, id string) error { } return nil } + +func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error { + _, err := c.Exec(` + insert into device_request ( + user_code, device_code, client_id, client_secret, scopes, expiry + ) + values ( + $1, $2, $3, $4, $5, $6 + );`, + d.UserCode, d.DeviceCode, d.ClientID, d.ClientSecret, encoder(d.Scopes), d.Expiry, + ) + if err != nil { + if c.alreadyExistsCheck(err) { + return storage.ErrAlreadyExists + } + return fmt.Errorf("insert device request: %v", err) + } + return nil +} + +func (c *conn) CreateDeviceToken(t storage.DeviceToken) error { + _, err := c.Exec(` + insert into device_token ( + device_code, status, token, expiry, last_request, poll_interval + ) + values ( + $1, $2, $3, $4, $5, $6 + );`, + t.DeviceCode, t.Status, t.Token, t.Expiry, t.LastRequestTime, t.PollIntervalSeconds, + ) + if err != nil { + if c.alreadyExistsCheck(err) { + return storage.ErrAlreadyExists + } + return fmt.Errorf("insert device token: %v", err) + } + return nil +} + +func (c *conn) GetDeviceRequest(userCode string) (storage.DeviceRequest, error) { + return getDeviceRequest(c, userCode) +} + +func getDeviceRequest(q querier, userCode string) (d storage.DeviceRequest, err error) { + err = q.QueryRow(` + select + device_code, client_id, client_secret, scopes, expiry + from device_request where user_code = $1; + `, userCode).Scan( + &d.DeviceCode, &d.ClientID, &d.ClientSecret, decoder(&d.Scopes), &d.Expiry, + ) + if err != nil { + if err == sql.ErrNoRows { + return d, storage.ErrNotFound + } + return d, fmt.Errorf("select device token: %v", err) + } + d.UserCode = userCode + return d, nil +} + +func (c *conn) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) { + return getDeviceToken(c, deviceCode) +} + +func getDeviceToken(q querier, deviceCode string) (a storage.DeviceToken, err error) { + err = q.QueryRow(` + select + status, token, expiry, last_request, poll_interval + from device_token where device_code = $1; + `, deviceCode).Scan( + &a.Status, &a.Token, &a.Expiry, &a.LastRequestTime, &a.PollIntervalSeconds, + ) + if err != nil { + if err == sql.ErrNoRows { + return a, storage.ErrNotFound + } + return a, fmt.Errorf("select device token: %v", err) + } + a.DeviceCode = deviceCode + return a, nil +} + +func (c *conn) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error { + return c.ExecTx(func(tx *trans) error { + r, err := getDeviceToken(tx, deviceCode) + if err != nil { + return err + } + if r, err = updater(r); err != nil { + return err + } + _, err = tx.Exec(` + update device_token + set + status = $1, + token = $2, + last_request = $3, + poll_interval = $4 + where + device_code = $5 + `, + r.Status, r.Token, r.LastRequestTime, r.PollIntervalSeconds, r.DeviceCode, + ) + if err != nil { + return fmt.Errorf("update device token: %v", err) + } + return nil + }) +} diff --git a/storage/sql/migrate.go b/storage/sql/migrate.go index dc727535..73934b1b 100644 --- a/storage/sql/migrate.go +++ b/storage/sql/migrate.go @@ -229,4 +229,25 @@ var migrations = []migration{ }, flavor: &flavorMySQL, }, + { + stmts: []string{` + create table device_request ( + user_code text not null primary key, + device_code text not null, + client_id text not null, + client_secret text , + scopes bytea not null, -- JSON array of strings + expiry timestamptz not null + );`, + ` + create table device_token ( + device_code text not null primary key, + status text not null, + token bytea, + expiry timestamptz not null, + last_request timestamptz not null, + poll_interval integer not null + );`, + }, + }, } diff --git a/storage/storage.go b/storage/storage.go index 5bbb2b3f..d11305e2 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -5,6 +5,7 @@ import ( "encoding/base32" "errors" "io" + "math/big" "strings" "time" @@ -24,9 +25,21 @@ var ( // TODO(ericchiang): refactor ID creation onto the storage. var encoding = base32.NewEncoding("abcdefghijklmnopqrstuvwxyz234567") +//Valid characters for user codes +const validUserCharacters = "BCDFGHJKLMNPQRSTVWXZ" + +// NewDeviceCode returns a 32 char alphanumeric cryptographically secure string +func NewDeviceCode() string { + return newSecureID(32) +} + // NewID returns a random string which can be used as an ID for objects. func NewID() string { - buff := make([]byte, 16) // 128 bit random ID. + return newSecureID(16) +} + +func newSecureID(len int) string { + buff := make([]byte, len) // random ID. if _, err := io.ReadFull(rand.Reader, buff); err != nil { panic(err) } @@ -36,8 +49,10 @@ func NewID() string { // GCResult returns the number of objects deleted by garbage collection. type GCResult struct { - AuthRequests int64 - AuthCodes int64 + AuthRequests int64 + AuthCodes int64 + DeviceRequests int64 + DeviceTokens int64 } // Storage is the storage interface used by the server. Implementations are @@ -54,6 +69,8 @@ type Storage interface { CreatePassword(p Password) error CreateOfflineSessions(s OfflineSessions) error CreateConnector(c Connector) error + CreateDeviceRequest(d DeviceRequest) error + CreateDeviceToken(d DeviceToken) error // TODO(ericchiang): return (T, bool, error) so we can indicate not found // requests that way instead of using ErrNotFound. @@ -65,6 +82,8 @@ type Storage interface { GetPassword(email string) (Password, error) GetOfflineSessions(userID string, connID string) (OfflineSessions, error) GetConnector(id string) (Connector, error) + GetDeviceRequest(userCode string) (DeviceRequest, error) + GetDeviceToken(deviceCode string) (DeviceToken, error) ListClients() ([]Client, error) ListRefreshTokens() ([]RefreshToken, error) @@ -101,8 +120,10 @@ type Storage interface { UpdatePassword(email string, updater func(p Password) (Password, error)) error UpdateOfflineSessions(userID string, connID string, updater func(s OfflineSessions) (OfflineSessions, error)) error UpdateConnector(id string, updater func(c Connector) (Connector, error)) error + UpdateDeviceToken(deviceCode string, updater func(t DeviceToken) (DeviceToken, error)) error - // GarbageCollect deletes all expired AuthCodes and AuthRequests. + // GarbageCollect deletes all expired AuthCodes, + // AuthRequests, DeviceRequests, and DeviceTokens. GarbageCollect(now time.Time) (GCResult, error) } @@ -342,3 +363,49 @@ type Keys struct { // For caching purposes, implementations MUST NOT update keys before this time. NextRotation time.Time } + +// NewUserCode returns a randomized 8 character user code for the device flow. +// No vowels are included to prevent accidental generation of words +func NewUserCode() (string, error) { + code, err := randomString(8) + if err != nil { + return "", err + } + return code[:4] + "-" + code[4:], nil +} + +func randomString(n int) (string, error) { + v := big.NewInt(int64(len(validUserCharacters))) + bytes := make([]byte, n) + for i := 0; i < n; i++ { + c, _ := rand.Int(rand.Reader, v) + bytes[i] = validUserCharacters[c.Int64()] + } + return string(bytes), nil +} + +//DeviceRequest represents an OIDC device authorization request. It holds the state of a device request until the user +//authenticates using their user code or the expiry time passes. +type DeviceRequest struct { + //The code the user will enter in a browser + UserCode string + //The unique device code for device authentication + DeviceCode string + //The client ID the code is for + ClientID string + //The Client Secret + ClientSecret string + //The scopes the device requests + Scopes []string + //The expire time + Expiry time.Time +} + +type DeviceToken struct { + DeviceCode string + Status string + Token string + Expiry time.Time + LastRequestTime time.Time + PollIntervalSeconds int +} diff --git a/web/templates/device.html b/web/templates/device.html new file mode 100644 index 00000000..674cbdc3 --- /dev/null +++ b/web/templates/device.html @@ -0,0 +1,23 @@ +{{ template "header.html" . }} + +
Return to your device to continue
+