Device token api endpoint (#1)

* Added /device/token handler with associated business logic and storage tests.

* Use crypto rand for user code

Signed-off-by: justin-slowik <justin.slowik@thermofisher.com>
This commit is contained in:
Justin Slowik 2020-01-27 10:35:37 -05:00 committed by justin-slowik
parent 6d343e059b
commit 0d1a0e4129
10 changed files with 163 additions and 49 deletions

View file

@ -5,7 +5,6 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"net/http" "net/http"
"net/url" "net/url"
"path" "path"
@ -1438,9 +1437,8 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
case http.MethodPost: case http.MethodPost:
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
message := "Could not parse Device Request body" s.logger.Errorf("Could not parse Device Request body: %v", err)
s.logger.Errorf("%s : %v", message, err) s.tokenErrHelper(w, errInvalidRequest, "", http.StatusNotFound)
respondWithError(w, message, err)
return return
} }
@ -1454,7 +1452,11 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
deviceCode := storage.NewDeviceCode() deviceCode := storage.NewDeviceCode()
//make user code //make user code
userCode := storage.NewUserCode() userCode, err := storage.NewUserCode()
if err != nil {
s.logger.Errorf("Error generating user code: %v", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
}
//make a pkce verification code //make a pkce verification code
pkceCode := storage.NewID() pkceCode := storage.NewID()
@ -1473,24 +1475,21 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
} }
if err := s.storage.CreateDeviceRequest(deviceReq); err != nil { if err := s.storage.CreateDeviceRequest(deviceReq); err != nil {
message := fmt.Sprintf("Failed to store device request %v", err) s.logger.Errorf("Failed to store device request; %v", err)
s.logger.Errorf(message) s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
respondWithError(w, message, err)
return return
} }
//Store the device token //Store the device token
deviceToken := storage.DeviceToken{ deviceToken := storage.DeviceToken{
DeviceCode: deviceCode, DeviceCode: deviceCode,
Status: "pending", Status: deviceTokenPending,
Token: "",
Expiry: expireTime, Expiry: expireTime,
} }
if err := s.storage.CreateDeviceToken(deviceToken); err != nil { if err := s.storage.CreateDeviceToken(deviceToken); err != nil {
message := fmt.Sprintf("Failed to store device token %v", err) s.logger.Errorf("Failed to store device token %v", err)
s.logger.Errorf(message) s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
respondWithError(w, message, err)
return return
} }
@ -1506,21 +1505,54 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
enc.SetIndent("", " ") enc.SetIndent("", " ")
enc.Encode(code) enc.Encode(code)
default:
s.renderError(r, w, http.StatusBadRequest, "Invalid device code request type")
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
}
}
func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
switch r.Method {
case http.MethodPost:
err := r.ParseForm()
if err != nil {
message := "Could not parse Device Token Request body"
s.logger.Warnf("%s : %v", message, err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
return
}
deviceCode := r.Form.Get("device_code")
if deviceCode == "" {
message := "No device code received"
s.tokenErrHelper(w, errInvalidRequest, message, http.StatusBadRequest)
return
}
grantType := r.PostFormValue("grant_type")
if grantType != grantTypeDeviceCode {
s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest)
return
}
//Grab the device token from the db
deviceToken, err := s.storage.GetDeviceToken(deviceCode)
if err != nil || s.now().After(deviceToken.Expiry) {
if err != storage.ErrNotFound {
s.logger.Errorf("failed to get device code: %v", err)
}
s.tokenErrHelper(w, errInvalidRequest, "Invalid or expired device code.", http.StatusBadRequest)
return
}
switch deviceToken.Status {
case deviceTokenPending:
s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized)
case deviceTokenComplete:
w.Write([]byte(deviceToken.Token))
}
default: default:
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.") s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
} }
} }
func respondWithError(w io.Writer, errorMessage string, err error) {
resp := struct {
Error string `json:"error"`
ErrorMessage string `json:"message"`
}{
Error: err.Error(),
ErrorMessage: errorMessage,
}
enc := json.NewEncoder(w)
enc.SetIndent("", " ")
enc.Encode(resp)
}

View file

@ -122,6 +122,7 @@ const (
grantTypeAuthorizationCode = "authorization_code" grantTypeAuthorizationCode = "authorization_code"
grantTypeRefreshToken = "refresh_token" grantTypeRefreshToken = "refresh_token"
grantTypePassword = "password" grantTypePassword = "password"
grantTypeDeviceCode = "device_code"
) )
const ( const (
@ -130,6 +131,11 @@ const (
responseTypeIDToken = "id_token" // ID Token in url fragment responseTypeIDToken = "id_token" // ID Token in url fragment
) )
const (
deviceTokenPending = "authorization_pending"
deviceTokenComplete = "complete"
)
func parseScopes(scopes []string) connector.Scopes { func parseScopes(scopes []string) connector.Scopes {
var s connector.Scopes var s connector.Scopes
for _, scope := range scopes { for _, scope := range scopes {

View file

@ -303,6 +303,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
handleFunc("/auth", s.handleAuthorization) handleFunc("/auth", s.handleAuthorization)
handleFunc("/auth/{connector}", s.handleConnectorLogin) handleFunc("/auth/{connector}", s.handleConnectorLogin)
handleFunc("/device/code", s.handleDeviceCode) handleFunc("/device/code", s.handleDeviceCode)
handleFunc("/device/token", s.handleDeviceToken)
r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) { r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) {
// Strip the X-Remote-* headers to prevent security issues on // Strip the X-Remote-* headers to prevent security issues on
// misconfigured authproxy connector setups. // misconfigured authproxy connector setups.

View file

@ -837,8 +837,13 @@ func testGC(t *testing.T, s storage.Storage) {
t.Errorf("expected storage.ErrNotFound, got %v", err) t.Errorf("expected storage.ErrNotFound, got %v", err)
} }
userCode, err := storage.NewUserCode()
if err != nil {
t.Errorf("Unexpected Error: %v", err)
}
d := storage.DeviceRequest{ d := storage.DeviceRequest{
UserCode: storage.NewUserCode(), UserCode: userCode,
DeviceCode: storage.NewID(), DeviceCode: storage.NewID(),
ClientID: "client1", ClientID: "client1",
Scopes: []string{"openid", "email"}, Scopes: []string{"openid", "email"},
@ -896,9 +901,9 @@ func testGC(t *testing.T, s storage.Storage) {
t.Errorf("expected no device token garbage collection results, got %#v", result) t.Errorf("expected no device token garbage collection results, got %#v", result)
} }
} }
//if _, err := s.GetDeviceRequest(d.UserCode); err != nil { if _, err := s.GetDeviceToken(dt.DeviceCode); err != nil {
// t.Errorf("expected to be able to get auth request after GC: %v", err) t.Errorf("expected to be able to get device token after GC: %v", err)
//} }
} }
if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil { if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
t.Errorf("garbage collection failed: %v", err) t.Errorf("garbage collection failed: %v", err)
@ -906,12 +911,11 @@ func testGC(t *testing.T, s storage.Storage) {
t.Errorf("expected to garbage collect 1 device token, got %d", r.DeviceTokens) t.Errorf("expected to garbage collect 1 device token, got %d", r.DeviceTokens)
} }
//TODO add this code back once Getters are written for device tokens if _, err := s.GetDeviceToken(dt.DeviceCode); err == nil {
//if _, err := s.GetDeviceRequest(d.UserCode); err == nil { t.Errorf("expected device token to be GC'd")
// t.Errorf("expected device request to be GC'd") } else if err != storage.ErrNotFound {
//} else if err != storage.ErrNotFound { t.Errorf("expected storage.ErrNotFound, got %v", err)
// t.Errorf("expected storage.ErrNotFound, got %v", err) }
//}
} }
// testTimezones tests that backends either fully support timezones or // testTimezones tests that backends either fully support timezones or
@ -961,8 +965,12 @@ func testTimezones(t *testing.T, s storage.Storage) {
} }
func testDeviceRequestCRUD(t *testing.T, s storage.Storage) { func testDeviceRequestCRUD(t *testing.T, s storage.Storage) {
userCode, err := storage.NewUserCode()
if err != nil {
panic(err)
}
d1 := storage.DeviceRequest{ d1 := storage.DeviceRequest{
UserCode: storage.NewUserCode(), UserCode: userCode,
DeviceCode: storage.NewID(), DeviceCode: storage.NewID(),
ClientID: "client1", ClientID: "client1",
Scopes: []string{"openid", "email"}, Scopes: []string{"openid", "email"},
@ -975,7 +983,7 @@ func testDeviceRequestCRUD(t *testing.T, s storage.Storage) {
} }
// Attempt to create same DeviceRequest twice. // Attempt to create same DeviceRequest twice.
err := s.CreateDeviceRequest(d1) err = s.CreateDeviceRequest(d1)
mustBeErrAlreadyExists(t, "device request", err) mustBeErrAlreadyExists(t, "device request", err)
//No manual deletes for device requests, will be handled by garbage collection routines //No manual deletes for device requests, will be handled by garbage collection routines

View file

@ -591,6 +591,13 @@ func (c *conn) CreateDeviceToken(t storage.DeviceToken) error {
return c.txnCreate(ctx, keyID(deviceRequestPrefix, t.DeviceCode), fromStorageDeviceToken(t)) return c.txnCreate(ctx, keyID(deviceRequestPrefix, t.DeviceCode), fromStorageDeviceToken(t))
} }
func (c *conn) GetDeviceToken(deviceCode string) (t storage.DeviceToken, err error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
defer cancel()
err = c.getKey(ctx, keyID(deviceTokenPrefix, deviceCode), &t)
return t, err
}
func (c *conn) listDeviceTokens(ctx context.Context) (deviceTokens []DeviceToken, err error) { func (c *conn) listDeviceTokens(ctx context.Context) (deviceTokens []DeviceToken, err error) {
res, err := c.db.Get(ctx, deviceTokenPrefix, clientv3.WithPrefix()) res, err := c.db.Get(ctx, deviceTokenPrefix, clientv3.WithPrefix())
if err != nil { if err != nil {

View file

@ -641,3 +641,11 @@ func (cli *client) CreateDeviceRequest(d storage.DeviceRequest) error {
func (cli *client) CreateDeviceToken(t storage.DeviceToken) error { func (cli *client) CreateDeviceToken(t storage.DeviceToken) error {
return cli.post(resourceDeviceToken, cli.fromStorageDeviceToken(t)) return cli.post(resourceDeviceToken, cli.fromStorageDeviceToken(t))
} }
func (cli *client) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) {
var token DeviceToken
if err := cli.get(resourceDeviceToken, deviceCode, &token); err != nil {
return storage.DeviceToken{}, err
}
return toStorageDeviceToken(token), nil
}

View file

@ -739,3 +739,12 @@ func (cli *client) fromStorageDeviceToken(t storage.DeviceToken) DeviceToken {
} }
return req return req
} }
func toStorageDeviceToken(t DeviceToken) storage.DeviceToken {
return storage.DeviceToken{
DeviceCode: t.ObjectMeta.Name,
Status: t.Status,
Token: t.Token,
Expiry: t.Expiry,
}
}

View file

@ -503,3 +503,14 @@ func (s *memStorage) CreateDeviceToken(t storage.DeviceToken) (err error) {
}) })
return return
} }
func (s *memStorage) GetDeviceToken(deviceCode string) (t storage.DeviceToken, err error) {
s.tx(func() {
var ok bool
if t, ok = s.deviceTokens[deviceCode]; !ok {
err = storage.ErrNotFound
return
}
})
return
}

View file

@ -922,3 +922,25 @@ func (c *conn) CreateDeviceToken(t storage.DeviceToken) error {
} }
return nil return nil
} }
func (c *conn) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) {
return getDeviceToken(c, deviceCode)
}
func getDeviceToken(q querier, deviceCode string) (a storage.DeviceToken, err error) {
err = q.QueryRow(`
select
status, token, expiry
from device_token where device_code = $1;
`, deviceCode).Scan(
&a.Status, &a.Token, &a.Expiry,
)
if err != nil {
if err == sql.ErrNoRows {
return a, storage.ErrNotFound
}
return a, fmt.Errorf("select device token: %v", err)
}
a.DeviceCode = deviceCode
return a, nil
}

View file

@ -5,7 +5,7 @@ import (
"encoding/base32" "encoding/base32"
"errors" "errors"
"io" "io"
mrand "math/rand" "math/big"
"strings" "strings"
"time" "time"
@ -25,6 +25,9 @@ var (
// TODO(ericchiang): refactor ID creation onto the storage. // TODO(ericchiang): refactor ID creation onto the storage.
var encoding = base32.NewEncoding("abcdefghijklmnopqrstuvwxyz234567") var encoding = base32.NewEncoding("abcdefghijklmnopqrstuvwxyz234567")
//Valid characters for user codes
const validUserCharacters = "BCDFGHJKLMNPQRSTVWXZ"
// NewDeviceCode returns a 32 char alphanumeric cryptographically secure string // NewDeviceCode returns a 32 char alphanumeric cryptographically secure string
func NewDeviceCode() string { func NewDeviceCode() string {
return newSecureID(32) return newSecureID(32)
@ -79,6 +82,7 @@ type Storage interface {
GetPassword(email string) (Password, error) GetPassword(email string) (Password, error)
GetOfflineSessions(userID string, connID string) (OfflineSessions, error) GetOfflineSessions(userID string, connID string) (OfflineSessions, error)
GetConnector(id string) (Connector, error) GetConnector(id string) (Connector, error)
GetDeviceToken(deviceCode string) (DeviceToken, error)
ListClients() ([]Client, error) ListClients() ([]Client, error)
ListRefreshTokens() ([]RefreshToken, error) ListRefreshTokens() ([]RefreshToken, error)
@ -357,18 +361,24 @@ type Keys struct {
NextRotation time.Time NextRotation time.Time
} }
func NewUserCode() string { // NewUserCode returns a randomized 8 character user code for the device flow.
mrand.Seed(time.Now().UnixNano()) // No vowels are included to prevent accidental generation of words
return randomString(4) + "-" + randomString(4) func NewUserCode() (string, error) {
code, err := randomString(8)
if err != nil {
return "", err
}
return code[:4] + "-" + code[4:], nil
} }
func randomString(n int) string { func randomString(n int) (string, error) {
var letter = []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZ") v := big.NewInt(int64(len(validUserCharacters)))
b := make([]rune, n) bytes := make([]byte, n)
for i := range b { for i := 0; i < n; i++ {
b[i] = letter[mrand.Intn(len(letter))] c, _ := rand.Int(rand.Reader, v)
bytes[i] = validUserCharacters[c.Int64()]
} }
return string(b) return string(bytes), nil
} }
//DeviceRequest represents an OIDC device authorization request. It holds the state of a device request until the user //DeviceRequest represents an OIDC device authorization request. It holds the state of a device request until the user