forked from mystiq/dex
Device token api endpoint (#1)
* Added /device/token handler with associated business logic and storage tests. * Use crypto rand for user code Signed-off-by: justin-slowik <justin.slowik@thermofisher.com>
This commit is contained in:
parent
6d343e059b
commit
0d1a0e4129
10 changed files with 163 additions and 49 deletions
|
@ -5,7 +5,6 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"path"
|
"path"
|
||||||
|
@ -1438,9 +1437,8 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
|
||||||
case http.MethodPost:
|
case http.MethodPost:
|
||||||
err := r.ParseForm()
|
err := r.ParseForm()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
message := "Could not parse Device Request body"
|
s.logger.Errorf("Could not parse Device Request body: %v", err)
|
||||||
s.logger.Errorf("%s : %v", message, err)
|
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusNotFound)
|
||||||
respondWithError(w, message, err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1454,7 +1452,11 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
|
||||||
deviceCode := storage.NewDeviceCode()
|
deviceCode := storage.NewDeviceCode()
|
||||||
|
|
||||||
//make user code
|
//make user code
|
||||||
userCode := storage.NewUserCode()
|
userCode, err := storage.NewUserCode()
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Errorf("Error generating user code: %v", err)
|
||||||
|
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
//make a pkce verification code
|
//make a pkce verification code
|
||||||
pkceCode := storage.NewID()
|
pkceCode := storage.NewID()
|
||||||
|
@ -1473,24 +1475,21 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.storage.CreateDeviceRequest(deviceReq); err != nil {
|
if err := s.storage.CreateDeviceRequest(deviceReq); err != nil {
|
||||||
message := fmt.Sprintf("Failed to store device request %v", err)
|
s.logger.Errorf("Failed to store device request; %v", err)
|
||||||
s.logger.Errorf(message)
|
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
||||||
respondWithError(w, message, err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
//Store the device token
|
//Store the device token
|
||||||
deviceToken := storage.DeviceToken{
|
deviceToken := storage.DeviceToken{
|
||||||
DeviceCode: deviceCode,
|
DeviceCode: deviceCode,
|
||||||
Status: "pending",
|
Status: deviceTokenPending,
|
||||||
Token: "",
|
|
||||||
Expiry: expireTime,
|
Expiry: expireTime,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.storage.CreateDeviceToken(deviceToken); err != nil {
|
if err := s.storage.CreateDeviceToken(deviceToken); err != nil {
|
||||||
message := fmt.Sprintf("Failed to store device token %v", err)
|
s.logger.Errorf("Failed to store device token %v", err)
|
||||||
s.logger.Errorf(message)
|
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
||||||
respondWithError(w, message, err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1506,21 +1505,54 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
|
||||||
enc.SetIndent("", " ")
|
enc.SetIndent("", " ")
|
||||||
enc.Encode(code)
|
enc.Encode(code)
|
||||||
|
|
||||||
|
default:
|
||||||
|
s.renderError(r, w, http.StatusBadRequest, "Invalid device code request type")
|
||||||
|
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodPost:
|
||||||
|
err := r.ParseForm()
|
||||||
|
if err != nil {
|
||||||
|
message := "Could not parse Device Token Request body"
|
||||||
|
s.logger.Warnf("%s : %v", message, err)
|
||||||
|
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
deviceCode := r.Form.Get("device_code")
|
||||||
|
if deviceCode == "" {
|
||||||
|
message := "No device code received"
|
||||||
|
s.tokenErrHelper(w, errInvalidRequest, message, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
grantType := r.PostFormValue("grant_type")
|
||||||
|
if grantType != grantTypeDeviceCode {
|
||||||
|
s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
//Grab the device token from the db
|
||||||
|
deviceToken, err := s.storage.GetDeviceToken(deviceCode)
|
||||||
|
if err != nil || s.now().After(deviceToken.Expiry) {
|
||||||
|
if err != storage.ErrNotFound {
|
||||||
|
s.logger.Errorf("failed to get device code: %v", err)
|
||||||
|
}
|
||||||
|
s.tokenErrHelper(w, errInvalidRequest, "Invalid or expired device code.", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch deviceToken.Status {
|
||||||
|
case deviceTokenPending:
|
||||||
|
s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized)
|
||||||
|
case deviceTokenComplete:
|
||||||
|
w.Write([]byte(deviceToken.Token))
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
|
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func respondWithError(w io.Writer, errorMessage string, err error) {
|
|
||||||
resp := struct {
|
|
||||||
Error string `json:"error"`
|
|
||||||
ErrorMessage string `json:"message"`
|
|
||||||
}{
|
|
||||||
Error: err.Error(),
|
|
||||||
ErrorMessage: errorMessage,
|
|
||||||
}
|
|
||||||
|
|
||||||
enc := json.NewEncoder(w)
|
|
||||||
enc.SetIndent("", " ")
|
|
||||||
enc.Encode(resp)
|
|
||||||
}
|
|
||||||
|
|
|
@ -122,6 +122,7 @@ const (
|
||||||
grantTypeAuthorizationCode = "authorization_code"
|
grantTypeAuthorizationCode = "authorization_code"
|
||||||
grantTypeRefreshToken = "refresh_token"
|
grantTypeRefreshToken = "refresh_token"
|
||||||
grantTypePassword = "password"
|
grantTypePassword = "password"
|
||||||
|
grantTypeDeviceCode = "device_code"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -130,6 +131,11 @@ const (
|
||||||
responseTypeIDToken = "id_token" // ID Token in url fragment
|
responseTypeIDToken = "id_token" // ID Token in url fragment
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
deviceTokenPending = "authorization_pending"
|
||||||
|
deviceTokenComplete = "complete"
|
||||||
|
)
|
||||||
|
|
||||||
func parseScopes(scopes []string) connector.Scopes {
|
func parseScopes(scopes []string) connector.Scopes {
|
||||||
var s connector.Scopes
|
var s connector.Scopes
|
||||||
for _, scope := range scopes {
|
for _, scope := range scopes {
|
||||||
|
|
|
@ -303,6 +303,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
|
||||||
handleFunc("/auth", s.handleAuthorization)
|
handleFunc("/auth", s.handleAuthorization)
|
||||||
handleFunc("/auth/{connector}", s.handleConnectorLogin)
|
handleFunc("/auth/{connector}", s.handleConnectorLogin)
|
||||||
handleFunc("/device/code", s.handleDeviceCode)
|
handleFunc("/device/code", s.handleDeviceCode)
|
||||||
|
handleFunc("/device/token", s.handleDeviceToken)
|
||||||
r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) {
|
r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) {
|
||||||
// Strip the X-Remote-* headers to prevent security issues on
|
// Strip the X-Remote-* headers to prevent security issues on
|
||||||
// misconfigured authproxy connector setups.
|
// misconfigured authproxy connector setups.
|
||||||
|
|
|
@ -837,8 +837,13 @@ func testGC(t *testing.T, s storage.Storage) {
|
||||||
t.Errorf("expected storage.ErrNotFound, got %v", err)
|
t.Errorf("expected storage.ErrNotFound, got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
userCode, err := storage.NewUserCode()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected Error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
d := storage.DeviceRequest{
|
d := storage.DeviceRequest{
|
||||||
UserCode: storage.NewUserCode(),
|
UserCode: userCode,
|
||||||
DeviceCode: storage.NewID(),
|
DeviceCode: storage.NewID(),
|
||||||
ClientID: "client1",
|
ClientID: "client1",
|
||||||
Scopes: []string{"openid", "email"},
|
Scopes: []string{"openid", "email"},
|
||||||
|
@ -896,9 +901,9 @@ func testGC(t *testing.T, s storage.Storage) {
|
||||||
t.Errorf("expected no device token garbage collection results, got %#v", result)
|
t.Errorf("expected no device token garbage collection results, got %#v", result)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
//if _, err := s.GetDeviceRequest(d.UserCode); err != nil {
|
if _, err := s.GetDeviceToken(dt.DeviceCode); err != nil {
|
||||||
// t.Errorf("expected to be able to get auth request after GC: %v", err)
|
t.Errorf("expected to be able to get device token after GC: %v", err)
|
||||||
//}
|
}
|
||||||
}
|
}
|
||||||
if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
|
if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
|
||||||
t.Errorf("garbage collection failed: %v", err)
|
t.Errorf("garbage collection failed: %v", err)
|
||||||
|
@ -906,12 +911,11 @@ func testGC(t *testing.T, s storage.Storage) {
|
||||||
t.Errorf("expected to garbage collect 1 device token, got %d", r.DeviceTokens)
|
t.Errorf("expected to garbage collect 1 device token, got %d", r.DeviceTokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO add this code back once Getters are written for device tokens
|
if _, err := s.GetDeviceToken(dt.DeviceCode); err == nil {
|
||||||
//if _, err := s.GetDeviceRequest(d.UserCode); err == nil {
|
t.Errorf("expected device token to be GC'd")
|
||||||
// t.Errorf("expected device request to be GC'd")
|
} else if err != storage.ErrNotFound {
|
||||||
//} else if err != storage.ErrNotFound {
|
t.Errorf("expected storage.ErrNotFound, got %v", err)
|
||||||
// t.Errorf("expected storage.ErrNotFound, got %v", err)
|
}
|
||||||
//}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// testTimezones tests that backends either fully support timezones or
|
// testTimezones tests that backends either fully support timezones or
|
||||||
|
@ -961,8 +965,12 @@ func testTimezones(t *testing.T, s storage.Storage) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func testDeviceRequestCRUD(t *testing.T, s storage.Storage) {
|
func testDeviceRequestCRUD(t *testing.T, s storage.Storage) {
|
||||||
|
userCode, err := storage.NewUserCode()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
d1 := storage.DeviceRequest{
|
d1 := storage.DeviceRequest{
|
||||||
UserCode: storage.NewUserCode(),
|
UserCode: userCode,
|
||||||
DeviceCode: storage.NewID(),
|
DeviceCode: storage.NewID(),
|
||||||
ClientID: "client1",
|
ClientID: "client1",
|
||||||
Scopes: []string{"openid", "email"},
|
Scopes: []string{"openid", "email"},
|
||||||
|
@ -975,7 +983,7 @@ func testDeviceRequestCRUD(t *testing.T, s storage.Storage) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Attempt to create same DeviceRequest twice.
|
// Attempt to create same DeviceRequest twice.
|
||||||
err := s.CreateDeviceRequest(d1)
|
err = s.CreateDeviceRequest(d1)
|
||||||
mustBeErrAlreadyExists(t, "device request", err)
|
mustBeErrAlreadyExists(t, "device request", err)
|
||||||
|
|
||||||
//No manual deletes for device requests, will be handled by garbage collection routines
|
//No manual deletes for device requests, will be handled by garbage collection routines
|
||||||
|
|
|
@ -591,6 +591,13 @@ func (c *conn) CreateDeviceToken(t storage.DeviceToken) error {
|
||||||
return c.txnCreate(ctx, keyID(deviceRequestPrefix, t.DeviceCode), fromStorageDeviceToken(t))
|
return c.txnCreate(ctx, keyID(deviceRequestPrefix, t.DeviceCode), fromStorageDeviceToken(t))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *conn) GetDeviceToken(deviceCode string) (t storage.DeviceToken, err error) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||||
|
defer cancel()
|
||||||
|
err = c.getKey(ctx, keyID(deviceTokenPrefix, deviceCode), &t)
|
||||||
|
return t, err
|
||||||
|
}
|
||||||
|
|
||||||
func (c *conn) listDeviceTokens(ctx context.Context) (deviceTokens []DeviceToken, err error) {
|
func (c *conn) listDeviceTokens(ctx context.Context) (deviceTokens []DeviceToken, err error) {
|
||||||
res, err := c.db.Get(ctx, deviceTokenPrefix, clientv3.WithPrefix())
|
res, err := c.db.Get(ctx, deviceTokenPrefix, clientv3.WithPrefix())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -641,3 +641,11 @@ func (cli *client) CreateDeviceRequest(d storage.DeviceRequest) error {
|
||||||
func (cli *client) CreateDeviceToken(t storage.DeviceToken) error {
|
func (cli *client) CreateDeviceToken(t storage.DeviceToken) error {
|
||||||
return cli.post(resourceDeviceToken, cli.fromStorageDeviceToken(t))
|
return cli.post(resourceDeviceToken, cli.fromStorageDeviceToken(t))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (cli *client) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) {
|
||||||
|
var token DeviceToken
|
||||||
|
if err := cli.get(resourceDeviceToken, deviceCode, &token); err != nil {
|
||||||
|
return storage.DeviceToken{}, err
|
||||||
|
}
|
||||||
|
return toStorageDeviceToken(token), nil
|
||||||
|
}
|
||||||
|
|
|
@ -739,3 +739,12 @@ func (cli *client) fromStorageDeviceToken(t storage.DeviceToken) DeviceToken {
|
||||||
}
|
}
|
||||||
return req
|
return req
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toStorageDeviceToken(t DeviceToken) storage.DeviceToken {
|
||||||
|
return storage.DeviceToken{
|
||||||
|
DeviceCode: t.ObjectMeta.Name,
|
||||||
|
Status: t.Status,
|
||||||
|
Token: t.Token,
|
||||||
|
Expiry: t.Expiry,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -503,3 +503,14 @@ func (s *memStorage) CreateDeviceToken(t storage.DeviceToken) (err error) {
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *memStorage) GetDeviceToken(deviceCode string) (t storage.DeviceToken, err error) {
|
||||||
|
s.tx(func() {
|
||||||
|
var ok bool
|
||||||
|
if t, ok = s.deviceTokens[deviceCode]; !ok {
|
||||||
|
err = storage.ErrNotFound
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
@ -922,3 +922,25 @@ func (c *conn) CreateDeviceToken(t storage.DeviceToken) error {
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *conn) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) {
|
||||||
|
return getDeviceToken(c, deviceCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getDeviceToken(q querier, deviceCode string) (a storage.DeviceToken, err error) {
|
||||||
|
err = q.QueryRow(`
|
||||||
|
select
|
||||||
|
status, token, expiry
|
||||||
|
from device_token where device_code = $1;
|
||||||
|
`, deviceCode).Scan(
|
||||||
|
&a.Status, &a.Token, &a.Expiry,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return a, storage.ErrNotFound
|
||||||
|
}
|
||||||
|
return a, fmt.Errorf("select device token: %v", err)
|
||||||
|
}
|
||||||
|
a.DeviceCode = deviceCode
|
||||||
|
return a, nil
|
||||||
|
}
|
||||||
|
|
|
@ -5,7 +5,7 @@ import (
|
||||||
"encoding/base32"
|
"encoding/base32"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
mrand "math/rand"
|
"math/big"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -25,6 +25,9 @@ var (
|
||||||
// TODO(ericchiang): refactor ID creation onto the storage.
|
// TODO(ericchiang): refactor ID creation onto the storage.
|
||||||
var encoding = base32.NewEncoding("abcdefghijklmnopqrstuvwxyz234567")
|
var encoding = base32.NewEncoding("abcdefghijklmnopqrstuvwxyz234567")
|
||||||
|
|
||||||
|
//Valid characters for user codes
|
||||||
|
const validUserCharacters = "BCDFGHJKLMNPQRSTVWXZ"
|
||||||
|
|
||||||
// NewDeviceCode returns a 32 char alphanumeric cryptographically secure string
|
// NewDeviceCode returns a 32 char alphanumeric cryptographically secure string
|
||||||
func NewDeviceCode() string {
|
func NewDeviceCode() string {
|
||||||
return newSecureID(32)
|
return newSecureID(32)
|
||||||
|
@ -79,6 +82,7 @@ type Storage interface {
|
||||||
GetPassword(email string) (Password, error)
|
GetPassword(email string) (Password, error)
|
||||||
GetOfflineSessions(userID string, connID string) (OfflineSessions, error)
|
GetOfflineSessions(userID string, connID string) (OfflineSessions, error)
|
||||||
GetConnector(id string) (Connector, error)
|
GetConnector(id string) (Connector, error)
|
||||||
|
GetDeviceToken(deviceCode string) (DeviceToken, error)
|
||||||
|
|
||||||
ListClients() ([]Client, error)
|
ListClients() ([]Client, error)
|
||||||
ListRefreshTokens() ([]RefreshToken, error)
|
ListRefreshTokens() ([]RefreshToken, error)
|
||||||
|
@ -357,18 +361,24 @@ type Keys struct {
|
||||||
NextRotation time.Time
|
NextRotation time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUserCode() string {
|
// NewUserCode returns a randomized 8 character user code for the device flow.
|
||||||
mrand.Seed(time.Now().UnixNano())
|
// No vowels are included to prevent accidental generation of words
|
||||||
return randomString(4) + "-" + randomString(4)
|
func NewUserCode() (string, error) {
|
||||||
|
code, err := randomString(8)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return code[:4] + "-" + code[4:], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func randomString(n int) string {
|
func randomString(n int) (string, error) {
|
||||||
var letter = []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
|
v := big.NewInt(int64(len(validUserCharacters)))
|
||||||
b := make([]rune, n)
|
bytes := make([]byte, n)
|
||||||
for i := range b {
|
for i := 0; i < n; i++ {
|
||||||
b[i] = letter[mrand.Intn(len(letter))]
|
c, _ := rand.Int(rand.Reader, v)
|
||||||
|
bytes[i] = validUserCharacters[c.Int64()]
|
||||||
}
|
}
|
||||||
return string(b)
|
return string(bytes), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
//DeviceRequest represents an OIDC device authorization request. It holds the state of a device request until the user
|
//DeviceRequest represents an OIDC device authorization request. It holds the state of a device request until the user
|
||||||
|
|
Loading…
Reference in a new issue