Implement connection pooling for LDAP connections

Fixes #309
This commit is contained in:
Frode Nordahl 2016-02-15 12:53:19 +01:00
parent f976fa1d3b
commit e531dd6be5
2 changed files with 169 additions and 30 deletions

View file

@ -167,6 +167,8 @@ In addition to `id` and `type`, the `ldap` connector takes the following additio
* skipCertVerification: a `boolean`. Skip server certificate chain verification. * skipCertVerification: a `boolean`. Skip server certificate chain verification.
* maxIdleConn: a `integer`. Maximum number of idle LDAP Connections to keep in connection pool. Default: `5`
* baseDN: a `string`. Base DN from which Bind DN is built and searches are based. * baseDN: a `string`. Base DN from which Bind DN is built and searches are based.
* nameAttribute: a `string`. Attribute to map to Name. Default: `cn` * nameAttribute: a `string`. Attribute to map to Name. Default: `cn`

View file

@ -12,6 +12,7 @@ import (
"net/url" "net/url"
"path" "path"
"strings" "strings"
"sync"
"time" "time"
"github.com/coreos/dex/pkg/log" "github.com/coreos/dex/pkg/log"
@ -40,6 +41,7 @@ type LDAPConnectorConfig struct {
KeyFile string `json:"keyFile"` KeyFile string `json:"keyFile"`
CaFile string `json:"caFile"` CaFile string `json:"caFile"`
SkipCertVerification bool `json:"skipCertVerification"` SkipCertVerification bool `json:"skipCertVerification"`
MaxIdleConn int `json:"maxIdleConn"`
BaseDN string `json:"baseDN"` BaseDN string `json:"baseDN"`
NameAttribute string `json:"nameAttribute"` NameAttribute string `json:"nameAttribute"`
EmailAttribute string `json:"emailAttribute"` EmailAttribute string `json:"emailAttribute"`
@ -81,6 +83,8 @@ func (cfg *LDAPConnectorConfig) Connector(ns url.URL, lf oidc.LoginFunc, tpls *t
const defaultEmailAttribute = "mail" const defaultEmailAttribute = "mail"
const defaultBindTemplate = "uid=%u,%b" const defaultBindTemplate = "uid=%u,%b"
const defaultSearchScope = ldap.ScopeWholeSubtree const defaultSearchScope = ldap.ScopeWholeSubtree
const defaultMaxIdleConns = 5
const defaultPoolCheckTimer = 7200 * time.Second
if cfg.UseTLS && cfg.UseSSL { if cfg.UseTLS && cfg.UseSSL {
return nil, fmt.Errorf("Invalid configuration. useTLS and useSSL are mutual exclusive.") return nil, fmt.Errorf("Invalid configuration. useTLS and useSSL are mutual exclusive.")
@ -154,11 +158,22 @@ func (cfg *LDAPConnectorConfig) Connector(ns url.URL, lf oidc.LoginFunc, tpls *t
tlsConfig.Certificates = []tls.Certificate{cert} tlsConfig.Certificates = []tls.Certificate{cert}
} }
maxIdleConn := defaultMaxIdleConns
if cfg.MaxIdleConn > 0 {
maxIdleConn = cfg.MaxIdleConn
}
ldapPool := &LDAPPool{
MaxIdleConn: maxIdleConn,
PoolCheckTimer: defaultPoolCheckTimer,
ServerHost: cfg.ServerHost,
ServerPort: cfg.ServerPort,
UseTLS: cfg.UseTLS,
UseSSL: cfg.UseSSL,
TLSConfig: tlsConfig,
}
idp := &LDAPIdentityProvider{ idp := &LDAPIdentityProvider{
serverHost: cfg.ServerHost,
serverPort: cfg.ServerPort,
useTLS: cfg.UseTLS,
useSSL: cfg.UseSSL,
baseDN: cfg.BaseDN, baseDN: cfg.BaseDN,
nameAttribute: nameAttribute, nameAttribute: nameAttribute,
emailAttribute: emailAttribute, emailAttribute: emailAttribute,
@ -168,7 +183,7 @@ func (cfg *LDAPConnectorConfig) Connector(ns url.URL, lf oidc.LoginFunc, tpls *t
searchBindDN: cfg.SearchBindDN, searchBindDN: cfg.SearchBindDN,
searchBindPw: cfg.SearchBindPw, searchBindPw: cfg.SearchBindPw,
bindTemplate: bindTemplate, bindTemplate: bindTemplate,
tlsConfig: tlsConfig, ldapPool: ldapPool,
} }
idpc := &LDAPConnector{ idpc := &LDAPConnector{
@ -188,9 +203,9 @@ func (c *LDAPConnector) ID() string {
} }
func (c *LDAPConnector) Healthy() error { func (c *LDAPConnector) Healthy() error {
ldapConn, err := c.idp.LDAPConnect() ldapConn, err := c.idp.ldapPool.Acquire()
if err == nil { if err == nil {
ldapConn.Close() c.idp.ldapPool.Put(ldapConn)
} }
return err return err
} }
@ -210,18 +225,145 @@ func (c *LDAPConnector) Register(mux *http.ServeMux, errorURL url.URL) {
} }
func (c *LDAPConnector) Sync() chan struct{} { func (c *LDAPConnector) Sync() chan struct{} {
return make(chan struct{}) stop := make(chan struct{})
go func() {
for {
select {
case <-time.After(c.idp.ldapPool.PoolCheckTimer):
alive, killed := c.idp.ldapPool.CheckConnections()
if alive > 0 {
log.Infof("Connector ID=%v idle_conns=%v", c.id, alive)
}
if killed > 0 {
log.Warningf("Connector ID=%v closed %v dead connections.", c.id, killed)
}
case <-stop:
return
}
}
}()
return stop
} }
func (c *LDAPConnector) TrustedEmailProvider() bool { func (c *LDAPConnector) TrustedEmailProvider() bool {
return c.trustedEmailProvider return c.trustedEmailProvider
} }
// A LDAPPool is a Connection Pool for LDAP connections
// Initialize exported fields and use Acquire() to get a connection.
// Use Put() to put it back into the pool.
type LDAPPool struct {
m sync.Mutex
conns map[*ldap.Conn]struct{}
MaxIdleConn int
PoolCheckTimer time.Duration
ServerHost string
ServerPort uint16
UseTLS bool
UseSSL bool
TLSConfig *tls.Config
}
// Acquire removes and returns a random connection from the pool. A new connection is returned
// if there are no connections available in the pool.
func (p *LDAPPool) Acquire() (*ldap.Conn, error) {
conn := p.removeRandomConn()
if conn != nil {
return conn, nil
}
return p.ldapConnect()
}
// Put makes a connection ready for re-use and puts it back into the pool. If the connection
// cannot be reused it is discarded. If there already are MaxIdleConn connections in the pool
// the connection is discarded.
func (p *LDAPPool) Put(c *ldap.Conn) {
p.m.Lock()
if p.conns == nil {
// First call to Put, initialize map
p.conns = make(map[*ldap.Conn]struct{})
}
if len(p.conns)+1 > p.MaxIdleConn {
p.m.Unlock()
c.Close()
return
}
p.m.Unlock()
// drop to anonymous bind
err := c.Bind("", "")
if err != nil {
// unsupported or disallowed, throw away connection
log.Warningf("Unable to re-use LDAP Connection after failure to bind anonymously: %v", err)
c.Close()
return
}
p.m.Lock()
p.conns[c] = struct{}{}
p.m.Unlock()
}
// removeConn attempts to remove the provided connection from the pool. If removeConn returns false
// another routine is using the connection and the caller should discard the pointer.
func (p *LDAPPool) removeConn(conn *ldap.Conn) bool {
p.m.Lock()
_, ok := p.conns[conn]
delete(p.conns, conn)
p.m.Unlock()
return ok
}
// removeRandomConn attempts to remove a random connection from the pool. If removeRandomConn
// returns nil the pool is empty.
func (p *LDAPPool) removeRandomConn() *ldap.Conn {
p.m.Lock()
defer p.m.Unlock()
for conn := range p.conns {
delete(p.conns, conn)
return conn
}
return nil
}
// CheckConnections attempts to iterate over all the connections in the pool and check wheter
// they are alive or not. Live connections are put back into the pool, dead ones are discarded.
func (p *LDAPPool) CheckConnections() (int, int) {
var conns []*ldap.Conn
var alive, killed int
// Get snapshot of connection-map while holding Lock
p.m.Lock()
for conn := range p.conns {
conns = append(conns, conn)
}
p.m.Unlock()
// Iterate over snapshot, Get and ping connections.
// Put live connections back into pool, Close dead ones.
for _, conn := range conns {
ok := p.removeConn(conn)
if ok {
err := ldapPing(conn)
if err == nil {
p.Put(conn)
alive++
} else {
conn.Close()
killed++
}
}
}
return alive, killed
}
func ldapPing(conn *ldap.Conn) error {
// Query root DSE
s := ldap.NewSearchRequest("", ldap.ScopeBaseObject, ldap.NeverDerefAliases, 0, 0, false, "(objectClass=*)", []string{}, nil)
_, err := conn.Search(s)
return err
}
type LDAPIdentityProvider struct { type LDAPIdentityProvider struct {
serverHost string
serverPort uint16
useTLS bool
useSSL bool
baseDN string baseDN string
nameAttribute string nameAttribute string
emailAttribute string emailAttribute string
@ -231,26 +373,26 @@ type LDAPIdentityProvider struct {
searchBindDN string searchBindDN string
searchBindPw string searchBindPw string
bindTemplate string bindTemplate string
tlsConfig *tls.Config ldapPool *LDAPPool
} }
func (m *LDAPIdentityProvider) LDAPConnect() (*ldap.Conn, error) { func (p *LDAPPool) ldapConnect() (*ldap.Conn, error) {
var err error var err error
var ldapConn *ldap.Conn var ldapConn *ldap.Conn
log.Debugf("LDAPConnect()") log.Debugf("LDAPConnect()")
if m.useSSL { if p.UseSSL {
ldapConn, err = ldap.DialTLS("tcp", fmt.Sprintf("%s:%d", m.serverHost, m.serverPort), m.tlsConfig) ldapConn, err = ldap.DialTLS("tcp", fmt.Sprintf("%s:%d", p.ServerHost, p.ServerPort), p.TLSConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} else { } else {
ldapConn, err = ldap.Dial("tcp", fmt.Sprintf("%s:%d", m.serverHost, m.serverPort)) ldapConn, err = ldap.Dial("tcp", fmt.Sprintf("%s:%d", p.ServerHost, p.ServerPort))
if err != nil { if err != nil {
return nil, err return nil, err
} }
if m.useTLS { if p.UseTLS {
err = ldapConn.StartTLS(m.tlsConfig) err = ldapConn.StartTLS(p.TLSConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -273,11 +415,11 @@ func (m *LDAPIdentityProvider) Identity(username, password string) (*oidc.Identi
var bindDN, ldapUid, ldapName, ldapEmail string var bindDN, ldapUid, ldapName, ldapEmail string
var ldapConn *ldap.Conn var ldapConn *ldap.Conn
ldapConn, err = m.LDAPConnect() ldapConn, err = m.ldapPool.Acquire()
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer ldapConn.Close() defer m.ldapPool.Put(ldapConn)
if m.searchBeforeAuth { if m.searchBeforeAuth {
err = ldapConn.Bind(m.searchBindDN, m.searchBindPw) err = ldapConn.Bind(m.searchBindDN, m.searchBindPw)
@ -307,16 +449,11 @@ func (m *LDAPIdentityProvider) Identity(username, password string) (*oidc.Identi
ldapName = sr.Entries[0].GetAttributeValue(m.nameAttribute) ldapName = sr.Entries[0].GetAttributeValue(m.nameAttribute)
ldapEmail = sr.Entries[0].GetAttributeValue(m.emailAttribute) ldapEmail = sr.Entries[0].GetAttributeValue(m.emailAttribute)
// drop to anonymous bind, prepare for bind as user // prepare LDAP connection for bind as user
err = ldapConn.Bind("", "") m.ldapPool.Put(ldapConn)
ldapConn, err = m.ldapPool.Acquire()
if err != nil { if err != nil {
// unsupported or disallowed, reconnect return nil, err
log.Warningf("Re-connecting to LDAP Server after failure to bind anonymously: %v", err)
ldapConn.Close()
ldapConn, err = m.LDAPConnect()
if err != nil {
return nil, err
}
} }
} else { } else {
bindDN = m.ParseString(m.bindTemplate, username) bindDN = m.ParseString(m.bindTemplate, username)