From e531dd6be57ac12f35c254d4f5cbaa7cf5aadfc5 Mon Sep 17 00:00:00 2001 From: Frode Nordahl Date: Mon, 15 Feb 2016 12:53:19 +0100 Subject: [PATCH] Implement connection pooling for LDAP connections Fixes #309 --- Documentation/connectors-configuration.md | 2 + connector/connector_ldap.go | 197 ++++++++++++++++++---- 2 files changed, 169 insertions(+), 30 deletions(-) diff --git a/Documentation/connectors-configuration.md b/Documentation/connectors-configuration.md index 3637f0a6..d2c6f4be 100644 --- a/Documentation/connectors-configuration.md +++ b/Documentation/connectors-configuration.md @@ -167,6 +167,8 @@ In addition to `id` and `type`, the `ldap` connector takes the following additio * skipCertVerification: a `boolean`. Skip server certificate chain verification. +* maxIdleConn: a `integer`. Maximum number of idle LDAP Connections to keep in connection pool. Default: `5` + * baseDN: a `string`. Base DN from which Bind DN is built and searches are based. * nameAttribute: a `string`. Attribute to map to Name. Default: `cn` diff --git a/connector/connector_ldap.go b/connector/connector_ldap.go index 0e3d05aa..af2fb91f 100644 --- a/connector/connector_ldap.go +++ b/connector/connector_ldap.go @@ -12,6 +12,7 @@ import ( "net/url" "path" "strings" + "sync" "time" "github.com/coreos/dex/pkg/log" @@ -40,6 +41,7 @@ type LDAPConnectorConfig struct { KeyFile string `json:"keyFile"` CaFile string `json:"caFile"` SkipCertVerification bool `json:"skipCertVerification"` + MaxIdleConn int `json:"maxIdleConn"` BaseDN string `json:"baseDN"` NameAttribute string `json:"nameAttribute"` EmailAttribute string `json:"emailAttribute"` @@ -81,6 +83,8 @@ func (cfg *LDAPConnectorConfig) Connector(ns url.URL, lf oidc.LoginFunc, tpls *t const defaultEmailAttribute = "mail" const defaultBindTemplate = "uid=%u,%b" const defaultSearchScope = ldap.ScopeWholeSubtree + const defaultMaxIdleConns = 5 + const defaultPoolCheckTimer = 7200 * time.Second if cfg.UseTLS && cfg.UseSSL { return nil, fmt.Errorf("Invalid configuration. useTLS and useSSL are mutual exclusive.") @@ -154,11 +158,22 @@ func (cfg *LDAPConnectorConfig) Connector(ns url.URL, lf oidc.LoginFunc, tpls *t tlsConfig.Certificates = []tls.Certificate{cert} } + maxIdleConn := defaultMaxIdleConns + if cfg.MaxIdleConn > 0 { + maxIdleConn = cfg.MaxIdleConn + } + + ldapPool := &LDAPPool{ + MaxIdleConn: maxIdleConn, + PoolCheckTimer: defaultPoolCheckTimer, + ServerHost: cfg.ServerHost, + ServerPort: cfg.ServerPort, + UseTLS: cfg.UseTLS, + UseSSL: cfg.UseSSL, + TLSConfig: tlsConfig, + } + idp := &LDAPIdentityProvider{ - serverHost: cfg.ServerHost, - serverPort: cfg.ServerPort, - useTLS: cfg.UseTLS, - useSSL: cfg.UseSSL, baseDN: cfg.BaseDN, nameAttribute: nameAttribute, emailAttribute: emailAttribute, @@ -168,7 +183,7 @@ func (cfg *LDAPConnectorConfig) Connector(ns url.URL, lf oidc.LoginFunc, tpls *t searchBindDN: cfg.SearchBindDN, searchBindPw: cfg.SearchBindPw, bindTemplate: bindTemplate, - tlsConfig: tlsConfig, + ldapPool: ldapPool, } idpc := &LDAPConnector{ @@ -188,9 +203,9 @@ func (c *LDAPConnector) ID() string { } func (c *LDAPConnector) Healthy() error { - ldapConn, err := c.idp.LDAPConnect() + ldapConn, err := c.idp.ldapPool.Acquire() if err == nil { - ldapConn.Close() + c.idp.ldapPool.Put(ldapConn) } return err } @@ -210,18 +225,145 @@ func (c *LDAPConnector) Register(mux *http.ServeMux, errorURL url.URL) { } func (c *LDAPConnector) Sync() chan struct{} { - return make(chan struct{}) + stop := make(chan struct{}) + + go func() { + for { + select { + case <-time.After(c.idp.ldapPool.PoolCheckTimer): + alive, killed := c.idp.ldapPool.CheckConnections() + if alive > 0 { + log.Infof("Connector ID=%v idle_conns=%v", c.id, alive) + } + if killed > 0 { + log.Warningf("Connector ID=%v closed %v dead connections.", c.id, killed) + } + case <-stop: + return + } + } + }() + return stop } func (c *LDAPConnector) TrustedEmailProvider() bool { return c.trustedEmailProvider } +// A LDAPPool is a Connection Pool for LDAP connections +// Initialize exported fields and use Acquire() to get a connection. +// Use Put() to put it back into the pool. +type LDAPPool struct { + m sync.Mutex + conns map[*ldap.Conn]struct{} + MaxIdleConn int + PoolCheckTimer time.Duration + ServerHost string + ServerPort uint16 + UseTLS bool + UseSSL bool + TLSConfig *tls.Config +} + +// Acquire removes and returns a random connection from the pool. A new connection is returned +// if there are no connections available in the pool. +func (p *LDAPPool) Acquire() (*ldap.Conn, error) { + conn := p.removeRandomConn() + if conn != nil { + return conn, nil + } + return p.ldapConnect() +} + +// Put makes a connection ready for re-use and puts it back into the pool. If the connection +// cannot be reused it is discarded. If there already are MaxIdleConn connections in the pool +// the connection is discarded. +func (p *LDAPPool) Put(c *ldap.Conn) { + p.m.Lock() + if p.conns == nil { + // First call to Put, initialize map + p.conns = make(map[*ldap.Conn]struct{}) + } + if len(p.conns)+1 > p.MaxIdleConn { + p.m.Unlock() + c.Close() + return + } + p.m.Unlock() + // drop to anonymous bind + err := c.Bind("", "") + if err != nil { + // unsupported or disallowed, throw away connection + log.Warningf("Unable to re-use LDAP Connection after failure to bind anonymously: %v", err) + c.Close() + return + } + p.m.Lock() + p.conns[c] = struct{}{} + p.m.Unlock() +} + +// removeConn attempts to remove the provided connection from the pool. If removeConn returns false +// another routine is using the connection and the caller should discard the pointer. +func (p *LDAPPool) removeConn(conn *ldap.Conn) bool { + p.m.Lock() + _, ok := p.conns[conn] + delete(p.conns, conn) + p.m.Unlock() + return ok +} + +// removeRandomConn attempts to remove a random connection from the pool. If removeRandomConn +// returns nil the pool is empty. +func (p *LDAPPool) removeRandomConn() *ldap.Conn { + p.m.Lock() + defer p.m.Unlock() + for conn := range p.conns { + delete(p.conns, conn) + return conn + } + return nil +} + +// CheckConnections attempts to iterate over all the connections in the pool and check wheter +// they are alive or not. Live connections are put back into the pool, dead ones are discarded. +func (p *LDAPPool) CheckConnections() (int, int) { + var conns []*ldap.Conn + var alive, killed int + + // Get snapshot of connection-map while holding Lock + p.m.Lock() + for conn := range p.conns { + conns = append(conns, conn) + } + p.m.Unlock() + + // Iterate over snapshot, Get and ping connections. + // Put live connections back into pool, Close dead ones. + for _, conn := range conns { + ok := p.removeConn(conn) + if ok { + err := ldapPing(conn) + if err == nil { + p.Put(conn) + alive++ + } else { + conn.Close() + killed++ + } + } + } + return alive, killed +} + +func ldapPing(conn *ldap.Conn) error { + // Query root DSE + s := ldap.NewSearchRequest("", ldap.ScopeBaseObject, ldap.NeverDerefAliases, 0, 0, false, "(objectClass=*)", []string{}, nil) + _, err := conn.Search(s) + return err +} + type LDAPIdentityProvider struct { - serverHost string - serverPort uint16 - useTLS bool - useSSL bool baseDN string nameAttribute string emailAttribute string @@ -231,26 +373,26 @@ type LDAPIdentityProvider struct { searchBindDN string searchBindPw string bindTemplate string - tlsConfig *tls.Config + ldapPool *LDAPPool } -func (m *LDAPIdentityProvider) LDAPConnect() (*ldap.Conn, error) { +func (p *LDAPPool) ldapConnect() (*ldap.Conn, error) { var err error var ldapConn *ldap.Conn log.Debugf("LDAPConnect()") - if m.useSSL { - ldapConn, err = ldap.DialTLS("tcp", fmt.Sprintf("%s:%d", m.serverHost, m.serverPort), m.tlsConfig) + if p.UseSSL { + ldapConn, err = ldap.DialTLS("tcp", fmt.Sprintf("%s:%d", p.ServerHost, p.ServerPort), p.TLSConfig) if err != nil { return nil, err } } else { - ldapConn, err = ldap.Dial("tcp", fmt.Sprintf("%s:%d", m.serverHost, m.serverPort)) + ldapConn, err = ldap.Dial("tcp", fmt.Sprintf("%s:%d", p.ServerHost, p.ServerPort)) if err != nil { return nil, err } - if m.useTLS { - err = ldapConn.StartTLS(m.tlsConfig) + if p.UseTLS { + err = ldapConn.StartTLS(p.TLSConfig) if err != nil { return nil, err } @@ -273,11 +415,11 @@ func (m *LDAPIdentityProvider) Identity(username, password string) (*oidc.Identi var bindDN, ldapUid, ldapName, ldapEmail string var ldapConn *ldap.Conn - ldapConn, err = m.LDAPConnect() + ldapConn, err = m.ldapPool.Acquire() if err != nil { return nil, err } - defer ldapConn.Close() + defer m.ldapPool.Put(ldapConn) if m.searchBeforeAuth { err = ldapConn.Bind(m.searchBindDN, m.searchBindPw) @@ -307,16 +449,11 @@ func (m *LDAPIdentityProvider) Identity(username, password string) (*oidc.Identi ldapName = sr.Entries[0].GetAttributeValue(m.nameAttribute) ldapEmail = sr.Entries[0].GetAttributeValue(m.emailAttribute) - // drop to anonymous bind, prepare for bind as user - err = ldapConn.Bind("", "") + // prepare LDAP connection for bind as user + m.ldapPool.Put(ldapConn) + ldapConn, err = m.ldapPool.Acquire() if err != nil { - // unsupported or disallowed, reconnect - log.Warningf("Re-connecting to LDAP Server after failure to bind anonymously: %v", err) - ldapConn.Close() - ldapConn, err = m.LDAPConnect() - if err != nil { - return nil, err - } + return nil, err } } else { bindDN = m.ParseString(m.bindTemplate, username)