connectors: refactor filter code into a helper package
I hope I didn't miss any :D Signed-off-by: Stephan Renatus <srenatus@chef.io>
This commit is contained in:
parent
39dc5dcfb7
commit
51f50fcad8
6 changed files with 55 additions and 66 deletions
|
@ -6,7 +6,6 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/dexidp/dex/pkg/log"
|
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -16,6 +15,8 @@ import (
|
||||||
"golang.org/x/oauth2/bitbucket"
|
"golang.org/x/oauth2/bitbucket"
|
||||||
|
|
||||||
"github.com/dexidp/dex/connector"
|
"github.com/dexidp/dex/connector"
|
||||||
|
"github.com/dexidp/dex/pkg/groups"
|
||||||
|
"github.com/dexidp/dex/pkg/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -350,7 +351,7 @@ func (b *bitbucketConnector) getGroups(ctx context.Context, client *http.Client,
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(b.teams) > 0 {
|
if len(b.teams) > 0 {
|
||||||
filteredTeams := filterTeams(bitbucketTeams, b.teams)
|
filteredTeams := groups.Filter(bitbucketTeams, b.teams)
|
||||||
if len(filteredTeams) == 0 {
|
if len(filteredTeams) == 0 {
|
||||||
return nil, fmt.Errorf("bitbucket: user %q is not in any of the required teams", userLogin)
|
return nil, fmt.Errorf("bitbucket: user %q is not in any of the required teams", userLogin)
|
||||||
}
|
}
|
||||||
|
@ -362,21 +363,6 @@ func (b *bitbucketConnector) getGroups(ctx context.Context, client *http.Client,
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Filter the users' team memberships by 'teams' from config.
|
|
||||||
func filterTeams(userTeams, configTeams []string) []string {
|
|
||||||
teams := []string{}
|
|
||||||
teamFilter := make(map[string]struct{})
|
|
||||||
for _, team := range configTeams {
|
|
||||||
teamFilter[team] = struct{}{}
|
|
||||||
}
|
|
||||||
for _, team := range userTeams {
|
|
||||||
if _, ok := teamFilter[team]; ok {
|
|
||||||
teams = append(teams, team)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return teams
|
|
||||||
}
|
|
||||||
|
|
||||||
type team struct {
|
type team struct {
|
||||||
Name string `json:"username"` // The "username" from Bitbucket Cloud is actually the team name here
|
Name string `json:"username"` // The "username" from Bitbucket Cloud is actually the team name here
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,6 +20,7 @@ import (
|
||||||
"golang.org/x/oauth2/github"
|
"golang.org/x/oauth2/github"
|
||||||
|
|
||||||
"github.com/dexidp/dex/connector"
|
"github.com/dexidp/dex/connector"
|
||||||
|
groups_pkg "github.com/dexidp/dex/pkg/groups"
|
||||||
"github.com/dexidp/dex/pkg/log"
|
"github.com/dexidp/dex/pkg/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -375,7 +376,7 @@ func (c *githubConnector) groupsForOrgs(ctx context.Context, client *http.Client
|
||||||
// 'teams' list in config.
|
// 'teams' list in config.
|
||||||
if len(org.Teams) == 0 {
|
if len(org.Teams) == 0 {
|
||||||
inOrgNoTeams = true
|
inOrgNoTeams = true
|
||||||
} else if teams = filterTeams(teams, org.Teams); len(teams) == 0 {
|
} else if teams = groups_pkg.Filter(teams, org.Teams); len(teams) == 0 {
|
||||||
c.logger.Infof("github: user %q in org %q but no teams", userName, org.Name)
|
c.logger.Infof("github: user %q in org %q but no teams", userName, org.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -466,22 +467,6 @@ func (c *githubConnector) userOrgTeams(ctx context.Context, client *http.Client)
|
||||||
return groups, nil
|
return groups, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Filter the users' team memberships by 'teams' from config.
|
|
||||||
func filterTeams(userTeams, configTeams []string) (teams []string) {
|
|
||||||
teamFilter := make(map[string]struct{})
|
|
||||||
for _, team := range configTeams {
|
|
||||||
if _, ok := teamFilter[team]; !ok {
|
|
||||||
teamFilter[team] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, team := range userTeams {
|
|
||||||
if _, ok := teamFilter[team]; ok {
|
|
||||||
teams = append(teams, team)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// get creates a "GET `apiURL`" request with context, sends the request using
|
// get creates a "GET `apiURL`" request with context, sends the request using
|
||||||
// the client, and decodes the resulting response body into v. A pagination URL
|
// the client, and decodes the resulting response body into v. A pagination URL
|
||||||
// is returned if one exists. Any errors encountered when building requests,
|
// is returned if one exists. Any errors encountered when building requests,
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
"github.com/dexidp/dex/connector"
|
"github.com/dexidp/dex/connector"
|
||||||
|
"github.com/dexidp/dex/pkg/groups"
|
||||||
"github.com/dexidp/dex/pkg/log"
|
"github.com/dexidp/dex/pkg/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -273,7 +274,7 @@ func (c *gitlabConnector) getGroups(ctx context.Context, client *http.Client, gr
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(c.groups) > 0 {
|
if len(c.groups) > 0 {
|
||||||
filteredGroups := filterGroups(gitlabGroups, c.groups)
|
filteredGroups := groups.Filter(gitlabGroups, c.groups)
|
||||||
if len(filteredGroups) == 0 {
|
if len(filteredGroups) == 0 {
|
||||||
return nil, fmt.Errorf("gitlab: user %q is not in any of the required groups", userLogin)
|
return nil, fmt.Errorf("gitlab: user %q is not in any of the required groups", userLogin)
|
||||||
}
|
}
|
||||||
|
@ -284,18 +285,3 @@ func (c *gitlabConnector) getGroups(ctx context.Context, client *http.Client, gr
|
||||||
|
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Filter the users' group memberships by 'groups' from config.
|
|
||||||
func filterGroups(userGroups, configGroups []string) []string {
|
|
||||||
groups := []string{}
|
|
||||||
groupFilter := make(map[string]struct{})
|
|
||||||
for _, group := range configGroups {
|
|
||||||
groupFilter[group] = struct{}{}
|
|
||||||
}
|
|
||||||
for _, group := range userGroups {
|
|
||||||
if _, ok := groupFilter[group]; ok {
|
|
||||||
groups = append(groups, group)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return groups
|
|
||||||
}
|
|
||||||
|
|
|
@ -15,6 +15,7 @@ import (
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
"github.com/dexidp/dex/connector"
|
"github.com/dexidp/dex/connector"
|
||||||
|
groups_pkg "github.com/dexidp/dex/pkg/groups"
|
||||||
"github.com/dexidp/dex/pkg/log"
|
"github.com/dexidp/dex/pkg/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -311,22 +312,9 @@ func (c *microsoftConnector) getGroups(ctx context.Context, client *http.Client,
|
||||||
}
|
}
|
||||||
|
|
||||||
// ensure that the user is in at least one required group
|
// ensure that the user is in at least one required group
|
||||||
isInGroups := false
|
filteredGroups := groups_pkg.Filter(groups, c.groups)
|
||||||
if len(c.groups) > 0 {
|
if len(c.groups) > 0 && len(filteredGroups) == 0 {
|
||||||
gs := make(map[string]struct{})
|
return nil, fmt.Errorf("microsoft: user %v not in any of the required groups", userID)
|
||||||
for _, g := range c.groups {
|
|
||||||
gs[g] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, g := range groups {
|
|
||||||
if _, ok := gs[g]; ok {
|
|
||||||
isInGroups = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(c.groups) > 0 && !isInGroups {
|
|
||||||
return nil, fmt.Errorf("microsoft: user %v not in required groups", userID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
18
pkg/groups/groups.go
Normal file
18
pkg/groups/groups.go
Normal file
|
@ -0,0 +1,18 @@
|
||||||
|
// Package groups contains helper functions related to groups
|
||||||
|
package groups
|
||||||
|
|
||||||
|
// Filter filters out any groups of given that are not in required. Thus it may
|
||||||
|
// happen that the resulting slice is empty.
|
||||||
|
func Filter(given, required []string) []string {
|
||||||
|
groups := []string{}
|
||||||
|
groupFilter := make(map[string]struct{})
|
||||||
|
for _, group := range required {
|
||||||
|
groupFilter[group] = struct{}{}
|
||||||
|
}
|
||||||
|
for _, group := range given {
|
||||||
|
if _, ok := groupFilter[group]; ok {
|
||||||
|
groups = append(groups, group)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return groups
|
||||||
|
}
|
26
pkg/groups/groups_test.go
Normal file
26
pkg/groups/groups_test.go
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
package groups_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/dexidp/dex/pkg/groups"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFilter(t *testing.T) {
|
||||||
|
cases := map[string]struct {
|
||||||
|
given, required, expected []string
|
||||||
|
}{
|
||||||
|
"nothing given": {given: []string{}, required: []string{"ops"}, expected: []string{}},
|
||||||
|
"exactly one match": {given: []string{"foo"}, required: []string{"foo"}, expected: []string{"foo"}},
|
||||||
|
"no group of the required ones": {given: []string{"foo", "bar"}, required: []string{"baz"}, expected: []string{}},
|
||||||
|
"subset matching": {given: []string{"foo", "bar", "baz"}, required: []string{"bar", "baz"}, expected: []string{"bar", "baz"}},
|
||||||
|
}
|
||||||
|
for name, tc := range cases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
actual := groups.Filter(tc.given, tc.required)
|
||||||
|
assert.ElementsMatch(t, tc.expected, actual)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Reference in a new issue