forked from mystiq/dex
73 lines
1.2 KiB
Go
73 lines
1.2 KiB
Go
|
package main
|
||
|
|
||
|
import (
|
||
|
"crypto/rand"
|
||
|
"encoding/base64"
|
||
|
"errors"
|
||
|
"io"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
"golang.org/x/net/context"
|
||
|
)
|
||
|
|
||
|
var (
|
||
|
gcInterval = time.Minute
|
||
|
expiresIn = time.Minute * 10
|
||
|
)
|
||
|
|
||
|
type memNonceSource struct {
|
||
|
mu sync.Mutex
|
||
|
nonces map[string]time.Time
|
||
|
}
|
||
|
|
||
|
func newNonceSource(ctx context.Context) *memNonceSource {
|
||
|
s := &memNonceSource{nonces: make(map[string]time.Time)}
|
||
|
go s.garbageCollect(ctx)
|
||
|
return s
|
||
|
}
|
||
|
|
||
|
func (s *memNonceSource) Nonce() (string, error) {
|
||
|
buff := make([]byte, 32)
|
||
|
if _, err := io.ReadFull(rand.Reader, buff); err != nil {
|
||
|
return "", err
|
||
|
}
|
||
|
nonce := base64.RawURLEncoding.EncodeToString(buff)
|
||
|
|
||
|
s.mu.Lock()
|
||
|
defer s.mu.Unlock()
|
||
|
s.nonces[nonce] = time.Now().Add(expiresIn)
|
||
|
|
||
|
return nonce, nil
|
||
|
}
|
||
|
|
||
|
func (s *memNonceSource) ClaimNonce(nonce string) error {
|
||
|
s.mu.Lock()
|
||
|
defer s.mu.Unlock()
|
||
|
|
||
|
if _, ok := s.nonces[nonce]; ok {
|
||
|
delete(s.nonces, nonce)
|
||
|
return nil
|
||
|
}
|
||
|
return errors.New("invalid nonce")
|
||
|
}
|
||
|
|
||
|
func (s *memNonceSource) garbageCollect(ctx context.Context) {
|
||
|
for {
|
||
|
select {
|
||
|
case <-ctx.Done():
|
||
|
case <-time.After(gcInterval):
|
||
|
s.mu.Lock()
|
||
|
now := time.Now()
|
||
|
|
||
|
for nonce, exp := range s.nonces {
|
||
|
if now.After(exp) {
|
||
|
delete(s.nonces, nonce)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
s.mu.Unlock()
|
||
|
}
|
||
|
}
|
||
|
}
|