forked from mystiq/dex
Device token api endpoint (#1)
* Added /device/token handler with associated business logic and storage tests. * Use crypto rand for user code Signed-off-by: justin-slowik <justin.slowik@thermofisher.com>
This commit is contained in:
parent
6d343e059b
commit
0d1a0e4129
10 changed files with 163 additions and 49 deletions
|
@ -5,7 +5,6 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
|
@ -1438,9 +1437,8 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
|
|||
case http.MethodPost:
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
message := "Could not parse Device Request body"
|
||||
s.logger.Errorf("%s : %v", message, err)
|
||||
respondWithError(w, message, err)
|
||||
s.logger.Errorf("Could not parse Device Request body: %v", err)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -1454,7 +1452,11 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
|
|||
deviceCode := storage.NewDeviceCode()
|
||||
|
||||
//make user code
|
||||
userCode := storage.NewUserCode()
|
||||
userCode, err := storage.NewUserCode()
|
||||
if err != nil {
|
||||
s.logger.Errorf("Error generating user code: %v", err)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
//make a pkce verification code
|
||||
pkceCode := storage.NewID()
|
||||
|
@ -1473,24 +1475,21 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
if err := s.storage.CreateDeviceRequest(deviceReq); err != nil {
|
||||
message := fmt.Sprintf("Failed to store device request %v", err)
|
||||
s.logger.Errorf(message)
|
||||
respondWithError(w, message, err)
|
||||
s.logger.Errorf("Failed to store device request; %v", err)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
//Store the device token
|
||||
deviceToken := storage.DeviceToken{
|
||||
DeviceCode: deviceCode,
|
||||
Status: "pending",
|
||||
Token: "",
|
||||
Status: deviceTokenPending,
|
||||
Expiry: expireTime,
|
||||
}
|
||||
|
||||
if err := s.storage.CreateDeviceToken(deviceToken); err != nil {
|
||||
message := fmt.Sprintf("Failed to store device token %v", err)
|
||||
s.logger.Errorf(message)
|
||||
respondWithError(w, message, err)
|
||||
s.logger.Errorf("Failed to store device token %v", err)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -1506,21 +1505,54 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
|
|||
enc.SetIndent("", " ")
|
||||
enc.Encode(code)
|
||||
|
||||
default:
|
||||
s.renderError(r, w, http.StatusBadRequest, "Invalid device code request type")
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
switch r.Method {
|
||||
case http.MethodPost:
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
message := "Could not parse Device Token Request body"
|
||||
s.logger.Warnf("%s : %v", message, err)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
deviceCode := r.Form.Get("device_code")
|
||||
if deviceCode == "" {
|
||||
message := "No device code received"
|
||||
s.tokenErrHelper(w, errInvalidRequest, message, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
grantType := r.PostFormValue("grant_type")
|
||||
if grantType != grantTypeDeviceCode {
|
||||
s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
//Grab the device token from the db
|
||||
deviceToken, err := s.storage.GetDeviceToken(deviceCode)
|
||||
if err != nil || s.now().After(deviceToken.Expiry) {
|
||||
if err != storage.ErrNotFound {
|
||||
s.logger.Errorf("failed to get device code: %v", err)
|
||||
}
|
||||
s.tokenErrHelper(w, errInvalidRequest, "Invalid or expired device code.", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
switch deviceToken.Status {
|
||||
case deviceTokenPending:
|
||||
s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized)
|
||||
case deviceTokenComplete:
|
||||
w.Write([]byte(deviceToken.Token))
|
||||
}
|
||||
default:
|
||||
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
|
||||
}
|
||||
}
|
||||
|
||||
func respondWithError(w io.Writer, errorMessage string, err error) {
|
||||
resp := struct {
|
||||
Error string `json:"error"`
|
||||
ErrorMessage string `json:"message"`
|
||||
}{
|
||||
Error: err.Error(),
|
||||
ErrorMessage: errorMessage,
|
||||
}
|
||||
|
||||
enc := json.NewEncoder(w)
|
||||
enc.SetIndent("", " ")
|
||||
enc.Encode(resp)
|
||||
}
|
||||
|
|
|
@ -122,6 +122,7 @@ const (
|
|||
grantTypeAuthorizationCode = "authorization_code"
|
||||
grantTypeRefreshToken = "refresh_token"
|
||||
grantTypePassword = "password"
|
||||
grantTypeDeviceCode = "device_code"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -130,6 +131,11 @@ const (
|
|||
responseTypeIDToken = "id_token" // ID Token in url fragment
|
||||
)
|
||||
|
||||
const (
|
||||
deviceTokenPending = "authorization_pending"
|
||||
deviceTokenComplete = "complete"
|
||||
)
|
||||
|
||||
func parseScopes(scopes []string) connector.Scopes {
|
||||
var s connector.Scopes
|
||||
for _, scope := range scopes {
|
||||
|
|
|
@ -303,6 +303,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
|
|||
handleFunc("/auth", s.handleAuthorization)
|
||||
handleFunc("/auth/{connector}", s.handleConnectorLogin)
|
||||
handleFunc("/device/code", s.handleDeviceCode)
|
||||
handleFunc("/device/token", s.handleDeviceToken)
|
||||
r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) {
|
||||
// Strip the X-Remote-* headers to prevent security issues on
|
||||
// misconfigured authproxy connector setups.
|
||||
|
|
|
@ -837,8 +837,13 @@ func testGC(t *testing.T, s storage.Storage) {
|
|||
t.Errorf("expected storage.ErrNotFound, got %v", err)
|
||||
}
|
||||
|
||||
userCode, err := storage.NewUserCode()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected Error: %v", err)
|
||||
}
|
||||
|
||||
d := storage.DeviceRequest{
|
||||
UserCode: storage.NewUserCode(),
|
||||
UserCode: userCode,
|
||||
DeviceCode: storage.NewID(),
|
||||
ClientID: "client1",
|
||||
Scopes: []string{"openid", "email"},
|
||||
|
@ -896,9 +901,9 @@ func testGC(t *testing.T, s storage.Storage) {
|
|||
t.Errorf("expected no device token garbage collection results, got %#v", result)
|
||||
}
|
||||
}
|
||||
//if _, err := s.GetDeviceRequest(d.UserCode); err != nil {
|
||||
// t.Errorf("expected to be able to get auth request after GC: %v", err)
|
||||
//}
|
||||
if _, err := s.GetDeviceToken(dt.DeviceCode); err != nil {
|
||||
t.Errorf("expected to be able to get device token after GC: %v", err)
|
||||
}
|
||||
}
|
||||
if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
|
||||
t.Errorf("garbage collection failed: %v", err)
|
||||
|
@ -906,12 +911,11 @@ func testGC(t *testing.T, s storage.Storage) {
|
|||
t.Errorf("expected to garbage collect 1 device token, got %d", r.DeviceTokens)
|
||||
}
|
||||
|
||||
//TODO add this code back once Getters are written for device tokens
|
||||
//if _, err := s.GetDeviceRequest(d.UserCode); err == nil {
|
||||
// t.Errorf("expected device request to be GC'd")
|
||||
//} else if err != storage.ErrNotFound {
|
||||
// t.Errorf("expected storage.ErrNotFound, got %v", err)
|
||||
//}
|
||||
if _, err := s.GetDeviceToken(dt.DeviceCode); err == nil {
|
||||
t.Errorf("expected device token to be GC'd")
|
||||
} else if err != storage.ErrNotFound {
|
||||
t.Errorf("expected storage.ErrNotFound, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// testTimezones tests that backends either fully support timezones or
|
||||
|
@ -961,8 +965,12 @@ func testTimezones(t *testing.T, s storage.Storage) {
|
|||
}
|
||||
|
||||
func testDeviceRequestCRUD(t *testing.T, s storage.Storage) {
|
||||
userCode, err := storage.NewUserCode()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
d1 := storage.DeviceRequest{
|
||||
UserCode: storage.NewUserCode(),
|
||||
UserCode: userCode,
|
||||
DeviceCode: storage.NewID(),
|
||||
ClientID: "client1",
|
||||
Scopes: []string{"openid", "email"},
|
||||
|
@ -975,7 +983,7 @@ func testDeviceRequestCRUD(t *testing.T, s storage.Storage) {
|
|||
}
|
||||
|
||||
// Attempt to create same DeviceRequest twice.
|
||||
err := s.CreateDeviceRequest(d1)
|
||||
err = s.CreateDeviceRequest(d1)
|
||||
mustBeErrAlreadyExists(t, "device request", err)
|
||||
|
||||
//No manual deletes for device requests, will be handled by garbage collection routines
|
||||
|
|
|
@ -591,6 +591,13 @@ func (c *conn) CreateDeviceToken(t storage.DeviceToken) error {
|
|||
return c.txnCreate(ctx, keyID(deviceRequestPrefix, t.DeviceCode), fromStorageDeviceToken(t))
|
||||
}
|
||||
|
||||
func (c *conn) GetDeviceToken(deviceCode string) (t storage.DeviceToken, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
err = c.getKey(ctx, keyID(deviceTokenPrefix, deviceCode), &t)
|
||||
return t, err
|
||||
}
|
||||
|
||||
func (c *conn) listDeviceTokens(ctx context.Context) (deviceTokens []DeviceToken, err error) {
|
||||
res, err := c.db.Get(ctx, deviceTokenPrefix, clientv3.WithPrefix())
|
||||
if err != nil {
|
||||
|
|
|
@ -641,3 +641,11 @@ func (cli *client) CreateDeviceRequest(d storage.DeviceRequest) error {
|
|||
func (cli *client) CreateDeviceToken(t storage.DeviceToken) error {
|
||||
return cli.post(resourceDeviceToken, cli.fromStorageDeviceToken(t))
|
||||
}
|
||||
|
||||
func (cli *client) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) {
|
||||
var token DeviceToken
|
||||
if err := cli.get(resourceDeviceToken, deviceCode, &token); err != nil {
|
||||
return storage.DeviceToken{}, err
|
||||
}
|
||||
return toStorageDeviceToken(token), nil
|
||||
}
|
||||
|
|
|
@ -739,3 +739,12 @@ func (cli *client) fromStorageDeviceToken(t storage.DeviceToken) DeviceToken {
|
|||
}
|
||||
return req
|
||||
}
|
||||
|
||||
func toStorageDeviceToken(t DeviceToken) storage.DeviceToken {
|
||||
return storage.DeviceToken{
|
||||
DeviceCode: t.ObjectMeta.Name,
|
||||
Status: t.Status,
|
||||
Token: t.Token,
|
||||
Expiry: t.Expiry,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -503,3 +503,14 @@ func (s *memStorage) CreateDeviceToken(t storage.DeviceToken) (err error) {
|
|||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (s *memStorage) GetDeviceToken(deviceCode string) (t storage.DeviceToken, err error) {
|
||||
s.tx(func() {
|
||||
var ok bool
|
||||
if t, ok = s.deviceTokens[deviceCode]; !ok {
|
||||
err = storage.ErrNotFound
|
||||
return
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
|
|
|
@ -922,3 +922,25 @@ func (c *conn) CreateDeviceToken(t storage.DeviceToken) error {
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) {
|
||||
return getDeviceToken(c, deviceCode)
|
||||
}
|
||||
|
||||
func getDeviceToken(q querier, deviceCode string) (a storage.DeviceToken, err error) {
|
||||
err = q.QueryRow(`
|
||||
select
|
||||
status, token, expiry
|
||||
from device_token where device_code = $1;
|
||||
`, deviceCode).Scan(
|
||||
&a.Status, &a.Token, &a.Expiry,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return a, storage.ErrNotFound
|
||||
}
|
||||
return a, fmt.Errorf("select device token: %v", err)
|
||||
}
|
||||
a.DeviceCode = deviceCode
|
||||
return a, nil
|
||||
}
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
"encoding/base32"
|
||||
"errors"
|
||||
"io"
|
||||
mrand "math/rand"
|
||||
"math/big"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -25,6 +25,9 @@ var (
|
|||
// TODO(ericchiang): refactor ID creation onto the storage.
|
||||
var encoding = base32.NewEncoding("abcdefghijklmnopqrstuvwxyz234567")
|
||||
|
||||
//Valid characters for user codes
|
||||
const validUserCharacters = "BCDFGHJKLMNPQRSTVWXZ"
|
||||
|
||||
// NewDeviceCode returns a 32 char alphanumeric cryptographically secure string
|
||||
func NewDeviceCode() string {
|
||||
return newSecureID(32)
|
||||
|
@ -79,6 +82,7 @@ type Storage interface {
|
|||
GetPassword(email string) (Password, error)
|
||||
GetOfflineSessions(userID string, connID string) (OfflineSessions, error)
|
||||
GetConnector(id string) (Connector, error)
|
||||
GetDeviceToken(deviceCode string) (DeviceToken, error)
|
||||
|
||||
ListClients() ([]Client, error)
|
||||
ListRefreshTokens() ([]RefreshToken, error)
|
||||
|
@ -357,18 +361,24 @@ type Keys struct {
|
|||
NextRotation time.Time
|
||||
}
|
||||
|
||||
func NewUserCode() string {
|
||||
mrand.Seed(time.Now().UnixNano())
|
||||
return randomString(4) + "-" + randomString(4)
|
||||
// NewUserCode returns a randomized 8 character user code for the device flow.
|
||||
// No vowels are included to prevent accidental generation of words
|
||||
func NewUserCode() (string, error) {
|
||||
code, err := randomString(8)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return code[:4] + "-" + code[4:], nil
|
||||
}
|
||||
|
||||
func randomString(n int) string {
|
||||
var letter = []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
|
||||
b := make([]rune, n)
|
||||
for i := range b {
|
||||
b[i] = letter[mrand.Intn(len(letter))]
|
||||
func randomString(n int) (string, error) {
|
||||
v := big.NewInt(int64(len(validUserCharacters)))
|
||||
bytes := make([]byte, n)
|
||||
for i := 0; i < n; i++ {
|
||||
c, _ := rand.Int(rand.Reader, v)
|
||||
bytes[i] = validUserCharacters[c.Int64()]
|
||||
}
|
||||
return string(b)
|
||||
return string(bytes), nil
|
||||
}
|
||||
|
||||
//DeviceRequest represents an OIDC device authorization request. It holds the state of a device request until the user
|
||||
|
|
Loading…
Reference in a new issue