diff --git a/connector/connector_ldap.go b/connector/connector_ldap.go index 5acb33b6..8ba45aed 100644 --- a/connector/connector_ldap.go +++ b/connector/connector_ldap.go @@ -107,11 +107,12 @@ type LDAPConnector struct { nameAttribute string emailAttribute string - searchBeforeAuth bool - searchFilter string - searchScope int - searchBindDN string - searchBindPw string + searchBeforeAuth bool + searchFilter string + searchScope int + searchBindDN string + searchBindPw string + searchGroupFilter string bindTemplate string @@ -203,19 +204,20 @@ func (cfg *LDAPConnectorConfig) Connector(ns url.URL, lf oidc.LoginFunc, tpls *t } idpc := &LDAPConnector{ - id: cfg.ID, - namespace: ns, - loginFunc: lf, - loginTpl: tpl, - baseDN: cfg.BaseDN, - nameAttribute: cfg.NameAttribute, - emailAttribute: cfg.EmailAttribute, - searchBeforeAuth: cfg.SearchBeforeAuth, - searchFilter: cfg.SearchFilter, - searchScope: searchScope, - searchBindDN: cfg.SearchBindDN, - searchBindPw: cfg.SearchBindPw, - bindTemplate: cfg.BindTemplate, + id: cfg.ID, + namespace: ns, + loginFunc: lf, + loginTpl: tpl, + baseDN: cfg.BaseDN, + nameAttribute: cfg.NameAttribute, + emailAttribute: cfg.EmailAttribute, + searchBeforeAuth: cfg.SearchBeforeAuth, + searchFilter: cfg.SearchFilter, + searchGroupFilter: cfg.SearchGroupFilter, + searchScope: searchScope, + searchBindDN: cfg.SearchBindDN, + searchBindPw: cfg.SearchBindPw, + bindTemplate: cfg.BindTemplate, ldapPool: &LDAPPool{ MaxIdleConn: cfg.MaxIdleConn, PoolCheckTimer: defaultPoolCheckTimer, @@ -433,12 +435,47 @@ func invalidBindCredentials(err error) bool { func (c *LDAPConnector) formatDN(template, username string) string { result := template - result = strings.Replace(result, "%u", username, -1) + result = strings.Replace(result, "%u", ldap.EscapeFilter(username), -1) result = strings.Replace(result, "%b", c.baseDN, -1) return result } +func (c *LDAPConnector) Groups(fullUserID string) ([]string, error) { + if !c.searchBeforeAuth { + return nil, fmt.Errorf("cannot search without service account") + } + if c.searchGroupFilter == "" { + return nil, fmt.Errorf("no group filter specified") + } + + var groups []string + err := c.ldapPool.Do(func(conn *ldap.Conn) error { + if err := conn.Bind(c.searchBindDN, c.searchBindPw); err != nil { + if !invalidBindCredentials(err) { + log.Errorf("failed to connect to LDAP for search bind: %v", err) + } + return fmt.Errorf("failed to bind: %v", err) + } + + req := &ldap.SearchRequest{ + BaseDN: c.baseDN, + Scope: c.searchScope, + Filter: c.formatDN(c.searchGroupFilter, fullUserID), + } + resp, err := conn.Search(req) + if err != nil { + return fmt.Errorf("search failed: %v", err) + } + groups = make([]string, len(resp.Entries)) + for i, entry := range resp.Entries { + groups[i] = entry.DN + } + return nil + }) + return groups, err +} + func (c *LDAPConnector) Identity(username, password string) (*oidc.Identity, error) { var ( identity *oidc.Identity @@ -447,8 +484,10 @@ func (c *LDAPConnector) Identity(username, password string) (*oidc.Identity, err if c.searchBeforeAuth { err = c.ldapPool.Do(func(conn *ldap.Conn) error { if err := conn.Bind(c.searchBindDN, c.searchBindPw); err != nil { - // Don't wrap error as it may be a specific LDAP error. - return err + if !invalidBindCredentials(err) { + log.Errorf("failed to connect to LDAP for search bind: %v", err) + } + return fmt.Errorf("failed to bind: %v", err) } filter := c.formatDN(c.searchFilter, username) @@ -491,8 +530,10 @@ func (c *LDAPConnector) Identity(username, password string) (*oidc.Identity, err err = c.ldapPool.Do(func(conn *ldap.Conn) error { userBindDN := c.formatDN(c.bindTemplate, username) if err := conn.Bind(userBindDN, password); err != nil { - // Don't wrap error as it may be a specific LDAP error. - return err + if !invalidBindCredentials(err) { + log.Errorf("failed to connect to LDAP for search bind: %v", err) + } + return fmt.Errorf("failed to bind: %v", err) } req := &ldap.SearchRequest{ @@ -522,11 +563,7 @@ func (c *LDAPConnector) Identity(username, password string) (*oidc.Identity, err return nil }) } - if err != nil { - if !invalidBindCredentials(err) { - log.Errorf("failed to connect to LDAP for search bind: %v", err) - } return nil, err } return identity, nil diff --git a/connector/interface.go b/connector/interface.go index 3216d8a6..6c79ffe1 100644 --- a/connector/interface.go +++ b/connector/interface.go @@ -60,6 +60,12 @@ type ConnectorConfig interface { Connector(ns url.URL, loginFunc oidc.LoginFunc, tpls *template.Template) (Connector, error) } +// GroupsConnector is a strategy for mapping a user to a set of groups. This is optionally +// implemented by some connectors. +type GroupsConnector interface { + Groups(fullUserID string) ([]string, error) +} + type ConnectorConfigRepo interface { All() ([]ConnectorConfig, error) GetConnectorByID(repo.Transaction, string) (ConnectorConfig, error)