package connector import ( "crypto/tls" "crypto/x509" "fmt" "html/template" "io/ioutil" "net/http" "net/url" "path" "strings" "sync" "time" "github.com/coreos/dex/pkg/log" "github.com/coreos/go-oidc/oidc" "gopkg.in/ldap.v2" ) const ( LDAPConnectorType = "ldap" LDAPLoginPageTemplateName = "ldap-login.html" ) func init() { RegisterConnectorConfigType(LDAPConnectorType, func() ConnectorConfig { return &LDAPConnectorConfig{} }) } type LDAPConnectorConfig struct { ID string `json:"id"` ServerHost string `json:"serverHost"` ServerPort uint16 `json:"serverPort"` Timeout time.Duration `json:"timeout"` UseTLS bool `json:"useTLS"` UseSSL bool `json:"useSSL"` CertFile string `json:"certFile"` KeyFile string `json:"keyFile"` CaFile string `json:"caFile"` SkipCertVerification bool `json:"skipCertVerification"` MaxIdleConn int `json:"maxIdleConn"` BaseDN string `json:"baseDN"` NameAttribute string `json:"nameAttribute"` EmailAttribute string `json:"emailAttribute"` SearchBeforeAuth bool `json:"searchBeforeAuth"` SearchFilter string `json:"searchFilter"` SearchScope string `json:"searchScope"` SearchBindDN string `json:"searchBindDN"` SearchBindPw string `json:"searchBindPw"` BindTemplate string `json:"bindTemplate"` TrustedEmailProvider bool `json:"trustedEmailProvider"` } func (cfg *LDAPConnectorConfig) ConnectorID() string { return cfg.ID } func (cfg *LDAPConnectorConfig) ConnectorType() string { return LDAPConnectorType } type LDAPConnector struct { id string idp *LDAPIdentityProvider namespace url.URL trustedEmailProvider bool loginFunc oidc.LoginFunc loginTpl *template.Template } func (cfg *LDAPConnectorConfig) Connector(ns url.URL, lf oidc.LoginFunc, tpls *template.Template) (Connector, error) { ns.Path = path.Join(ns.Path, httpPathCallback) tpl := tpls.Lookup(LDAPLoginPageTemplateName) if tpl == nil { return nil, fmt.Errorf("unable to find necessary HTML template") } // defaults const defaultNameAttribute = "cn" const defaultEmailAttribute = "mail" const defaultBindTemplate = "uid=%u,%b" const defaultSearchScope = ldap.ScopeWholeSubtree const defaultMaxIdleConns = 5 const defaultPoolCheckTimer = 7200 * time.Second if cfg.UseTLS && cfg.UseSSL { return nil, fmt.Errorf("Invalid configuration. useTLS and useSSL are mutual exclusive.") } if len(cfg.CertFile) > 0 && len(cfg.KeyFile) == 0 { return nil, fmt.Errorf("Invalid configuration. Both certFile and keyFile must be specified.") } nameAttribute := defaultNameAttribute if len(cfg.NameAttribute) > 0 { nameAttribute = cfg.NameAttribute } emailAttribute := defaultEmailAttribute if len(cfg.EmailAttribute) > 0 { emailAttribute = cfg.EmailAttribute } bindTemplate := defaultBindTemplate if len(cfg.BindTemplate) > 0 { if cfg.SearchBeforeAuth { log.Warningf("bindTemplate not used when searchBeforeAuth specified.") } bindTemplate = cfg.BindTemplate } searchScope := defaultSearchScope if len(cfg.SearchScope) > 0 { switch { case strings.EqualFold(cfg.SearchScope, "BASE"): searchScope = ldap.ScopeBaseObject case strings.EqualFold(cfg.SearchScope, "ONE"): searchScope = ldap.ScopeSingleLevel case strings.EqualFold(cfg.SearchScope, "SUB"): searchScope = ldap.ScopeWholeSubtree default: return nil, fmt.Errorf("Invalid value for searchScope: '%v'. Must be one of 'base', 'one' or 'sub'.", cfg.SearchScope) } } if cfg.Timeout != 0 { ldap.DefaultTimeout = cfg.Timeout * time.Millisecond } tlsConfig := &tls.Config{ ServerName: cfg.ServerHost, InsecureSkipVerify: cfg.SkipCertVerification, } if (cfg.UseTLS || cfg.UseSSL) && len(cfg.CaFile) > 0 { buf, err := ioutil.ReadFile(cfg.CaFile) if err != nil { return nil, err } rootCertPool := x509.NewCertPool() ok := rootCertPool.AppendCertsFromPEM(buf) if ok { tlsConfig.RootCAs = rootCertPool } else { return nil, fmt.Errorf("%v: Unable to parse certificate data.", cfg.CaFile) } } if (cfg.UseTLS || cfg.UseSSL) && len(cfg.CertFile) > 0 && len(cfg.KeyFile) > 0 { cert, err := tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile) if err != nil { return nil, err } tlsConfig.Certificates = []tls.Certificate{cert} } maxIdleConn := defaultMaxIdleConns if cfg.MaxIdleConn > 0 { maxIdleConn = cfg.MaxIdleConn } ldapPool := &LDAPPool{ MaxIdleConn: maxIdleConn, PoolCheckTimer: defaultPoolCheckTimer, ServerHost: cfg.ServerHost, ServerPort: cfg.ServerPort, UseTLS: cfg.UseTLS, UseSSL: cfg.UseSSL, TLSConfig: tlsConfig, } idp := &LDAPIdentityProvider{ baseDN: cfg.BaseDN, nameAttribute: nameAttribute, emailAttribute: emailAttribute, searchBeforeAuth: cfg.SearchBeforeAuth, searchFilter: cfg.SearchFilter, searchScope: searchScope, searchBindDN: cfg.SearchBindDN, searchBindPw: cfg.SearchBindPw, bindTemplate: bindTemplate, ldapPool: ldapPool, } idpc := &LDAPConnector{ id: cfg.ID, idp: idp, namespace: ns, trustedEmailProvider: cfg.TrustedEmailProvider, loginFunc: lf, loginTpl: tpl, } return idpc, nil } func (c *LDAPConnector) ID() string { return c.id } func (c *LDAPConnector) Healthy() error { ldapConn, err := c.idp.ldapPool.Acquire() if err == nil { c.idp.ldapPool.Put(ldapConn) } return err } func (c *LDAPConnector) LoginURL(sessionKey, prompt string) (string, error) { q := url.Values{} q.Set("session_key", sessionKey) q.Set("prompt", prompt) enc := q.Encode() return path.Join(c.namespace.Path, "login") + "?" + enc, nil } func (c *LDAPConnector) Register(mux *http.ServeMux, errorURL url.URL) { route := path.Join(c.namespace.Path, "login") mux.Handle(route, handleLoginFunc(c.loginFunc, c.loginTpl, c.idp, route, errorURL)) } func (c *LDAPConnector) Sync() chan struct{} { stop := make(chan struct{}) go func() { for { select { case <-time.After(c.idp.ldapPool.PoolCheckTimer): alive, killed := c.idp.ldapPool.CheckConnections() if alive > 0 { log.Infof("Connector ID=%v idle_conns=%v", c.id, alive) } if killed > 0 { log.Warningf("Connector ID=%v closed %v dead connections.", c.id, killed) } case <-stop: return } } }() return stop } func (c *LDAPConnector) TrustedEmailProvider() bool { return c.trustedEmailProvider } // A LDAPPool is a Connection Pool for LDAP connections // Initialize exported fields and use Acquire() to get a connection. // Use Put() to put it back into the pool. type LDAPPool struct { m sync.Mutex conns map[*ldap.Conn]struct{} MaxIdleConn int PoolCheckTimer time.Duration ServerHost string ServerPort uint16 UseTLS bool UseSSL bool TLSConfig *tls.Config } // Acquire removes and returns a random connection from the pool. A new connection is returned // if there are no connections available in the pool. func (p *LDAPPool) Acquire() (*ldap.Conn, error) { conn := p.removeRandomConn() if conn != nil { return conn, nil } return p.ldapConnect() } // Put makes a connection ready for re-use and puts it back into the pool. If the connection // cannot be reused it is discarded. If there already are MaxIdleConn connections in the pool // the connection is discarded. func (p *LDAPPool) Put(c *ldap.Conn) { p.m.Lock() if p.conns == nil { // First call to Put, initialize map p.conns = make(map[*ldap.Conn]struct{}) } if len(p.conns)+1 > p.MaxIdleConn { p.m.Unlock() c.Close() return } p.m.Unlock() // drop to anonymous bind err := c.Bind("", "") if err != nil { // unsupported or disallowed, throw away connection log.Warningf("Unable to re-use LDAP Connection after failure to bind anonymously: %v", err) c.Close() return } p.m.Lock() p.conns[c] = struct{}{} p.m.Unlock() } // removeConn attempts to remove the provided connection from the pool. If removeConn returns false // another routine is using the connection and the caller should discard the pointer. func (p *LDAPPool) removeConn(conn *ldap.Conn) bool { p.m.Lock() _, ok := p.conns[conn] delete(p.conns, conn) p.m.Unlock() return ok } // removeRandomConn attempts to remove a random connection from the pool. If removeRandomConn // returns nil the pool is empty. func (p *LDAPPool) removeRandomConn() *ldap.Conn { p.m.Lock() defer p.m.Unlock() for conn := range p.conns { delete(p.conns, conn) return conn } return nil } // CheckConnections attempts to iterate over all the connections in the pool and check wheter // they are alive or not. Live connections are put back into the pool, dead ones are discarded. func (p *LDAPPool) CheckConnections() (int, int) { var conns []*ldap.Conn var alive, killed int // Get snapshot of connection-map while holding Lock p.m.Lock() for conn := range p.conns { conns = append(conns, conn) } p.m.Unlock() // Iterate over snapshot, Get and ping connections. // Put live connections back into pool, Close dead ones. for _, conn := range conns { ok := p.removeConn(conn) if ok { err := ldapPing(conn) if err == nil { p.Put(conn) alive++ } else { conn.Close() killed++ } } } return alive, killed } func ldapPing(conn *ldap.Conn) error { // Query root DSE s := ldap.NewSearchRequest("", ldap.ScopeBaseObject, ldap.NeverDerefAliases, 0, 0, false, "(objectClass=*)", []string{}, nil) _, err := conn.Search(s) return err } type LDAPIdentityProvider struct { baseDN string nameAttribute string emailAttribute string searchBeforeAuth bool searchFilter string searchScope int searchBindDN string searchBindPw string bindTemplate string ldapPool *LDAPPool } func (p *LDAPPool) ldapConnect() (*ldap.Conn, error) { var err error var ldapConn *ldap.Conn log.Debugf("LDAPConnect()") if p.UseSSL { ldapConn, err = ldap.DialTLS("tcp", fmt.Sprintf("%s:%d", p.ServerHost, p.ServerPort), p.TLSConfig) if err != nil { return nil, err } } else { ldapConn, err = ldap.Dial("tcp", fmt.Sprintf("%s:%d", p.ServerHost, p.ServerPort)) if err != nil { return nil, err } if p.UseTLS { err = ldapConn.StartTLS(p.TLSConfig) if err != nil { return nil, err } } } return ldapConn, err } func (m *LDAPIdentityProvider) ParseString(template, username string) string { result := template result = strings.Replace(result, "%u", username, -1) result = strings.Replace(result, "%b", m.baseDN, -1) return result } func (m *LDAPIdentityProvider) Identity(username, password string) (*oidc.Identity, error) { var err error var bindDN, ldapUid, ldapName, ldapEmail string var ldapConn *ldap.Conn ldapConn, err = m.ldapPool.Acquire() if err != nil { return nil, err } defer m.ldapPool.Put(ldapConn) if m.searchBeforeAuth { err = ldapConn.Bind(m.searchBindDN, m.searchBindPw) if err != nil { return nil, err } filter := m.ParseString(m.searchFilter, username) attributes := []string{ m.nameAttribute, m.emailAttribute, } s := ldap.NewSearchRequest(m.baseDN, m.searchScope, ldap.NeverDerefAliases, 0, 0, false, filter, attributes, nil) sr, err := ldapConn.Search(s) if err != nil { return nil, err } if len(sr.Entries) == 0 { err = fmt.Errorf("Search returned no match. filter='%v' base='%v'", filter, m.baseDN) return nil, err } bindDN = sr.Entries[0].DN ldapName = sr.Entries[0].GetAttributeValue(m.nameAttribute) ldapEmail = sr.Entries[0].GetAttributeValue(m.emailAttribute) // prepare LDAP connection for bind as user m.ldapPool.Put(ldapConn) ldapConn, err = m.ldapPool.Acquire() if err != nil { return nil, err } } else { bindDN = m.ParseString(m.bindTemplate, username) } // authenticate user err = ldapConn.Bind(bindDN, password) if err != nil { return nil, err } ldapUid = bindDN return &oidc.Identity{ ID: ldapUid, Name: ldapName, Email: ldapEmail, }, nil }