diff --git a/connector/bitbucketcloud/bitbucketcloud.go b/connector/bitbucketcloud/bitbucketcloud.go index 6c391be7..7d58221a 100644 --- a/connector/bitbucketcloud/bitbucketcloud.go +++ b/connector/bitbucketcloud/bitbucketcloud.go @@ -6,7 +6,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/dexidp/dex/pkg/log" "io/ioutil" "net/http" "sync" @@ -16,6 +15,8 @@ import ( "golang.org/x/oauth2/bitbucket" "github.com/dexidp/dex/connector" + "github.com/dexidp/dex/pkg/groups" + "github.com/dexidp/dex/pkg/log" ) const ( @@ -350,7 +351,7 @@ func (b *bitbucketConnector) getGroups(ctx context.Context, client *http.Client, } if len(b.teams) > 0 { - filteredTeams := filterTeams(bitbucketTeams, b.teams) + filteredTeams := groups.Filter(bitbucketTeams, b.teams) if len(filteredTeams) == 0 { return nil, fmt.Errorf("bitbucket: user %q is not in any of the required teams", userLogin) } @@ -362,21 +363,6 @@ func (b *bitbucketConnector) getGroups(ctx context.Context, client *http.Client, return nil, nil } -// Filter the users' team memberships by 'teams' from config. -func filterTeams(userTeams, configTeams []string) []string { - teams := []string{} - teamFilter := make(map[string]struct{}) - for _, team := range configTeams { - teamFilter[team] = struct{}{} - } - for _, team := range userTeams { - if _, ok := teamFilter[team]; ok { - teams = append(teams, team) - } - } - return teams -} - type team struct { Name string `json:"username"` // The "username" from Bitbucket Cloud is actually the team name here } diff --git a/connector/github/github.go b/connector/github/github.go index 8658f8cd..9093c6ad 100644 --- a/connector/github/github.go +++ b/connector/github/github.go @@ -20,6 +20,7 @@ import ( "golang.org/x/oauth2/github" "github.com/dexidp/dex/connector" + groups_pkg "github.com/dexidp/dex/pkg/groups" "github.com/dexidp/dex/pkg/log" ) @@ -375,7 +376,7 @@ func (c *githubConnector) groupsForOrgs(ctx context.Context, client *http.Client // 'teams' list in config. if len(org.Teams) == 0 { inOrgNoTeams = true - } else if teams = filterTeams(teams, org.Teams); len(teams) == 0 { + } else if teams = groups_pkg.Filter(teams, org.Teams); len(teams) == 0 { c.logger.Infof("github: user %q in org %q but no teams", userName, org.Name) } @@ -466,22 +467,6 @@ func (c *githubConnector) userOrgTeams(ctx context.Context, client *http.Client) return groups, nil } -// Filter the users' team memberships by 'teams' from config. -func filterTeams(userTeams, configTeams []string) (teams []string) { - teamFilter := make(map[string]struct{}) - for _, team := range configTeams { - if _, ok := teamFilter[team]; !ok { - teamFilter[team] = struct{}{} - } - } - for _, team := range userTeams { - if _, ok := teamFilter[team]; ok { - teams = append(teams, team) - } - } - return -} - // get creates a "GET `apiURL`" request with context, sends the request using // the client, and decodes the resulting response body into v. A pagination URL // is returned if one exists. Any errors encountered when building requests, diff --git a/connector/gitlab/gitlab.go b/connector/gitlab/gitlab.go index 2c031712..fde519a0 100644 --- a/connector/gitlab/gitlab.go +++ b/connector/gitlab/gitlab.go @@ -13,6 +13,7 @@ import ( "golang.org/x/oauth2" "github.com/dexidp/dex/connector" + "github.com/dexidp/dex/pkg/groups" "github.com/dexidp/dex/pkg/log" ) @@ -273,7 +274,7 @@ func (c *gitlabConnector) getGroups(ctx context.Context, client *http.Client, gr } if len(c.groups) > 0 { - filteredGroups := filterGroups(gitlabGroups, c.groups) + filteredGroups := groups.Filter(gitlabGroups, c.groups) if len(filteredGroups) == 0 { return nil, fmt.Errorf("gitlab: user %q is not in any of the required groups", userLogin) } @@ -284,18 +285,3 @@ func (c *gitlabConnector) getGroups(ctx context.Context, client *http.Client, gr return nil, nil } - -// Filter the users' group memberships by 'groups' from config. -func filterGroups(userGroups, configGroups []string) []string { - groups := []string{} - groupFilter := make(map[string]struct{}) - for _, group := range configGroups { - groupFilter[group] = struct{}{} - } - for _, group := range userGroups { - if _, ok := groupFilter[group]; ok { - groups = append(groups, group) - } - } - return groups -} diff --git a/connector/microsoft/microsoft.go b/connector/microsoft/microsoft.go index 730c77e3..b31cfa55 100644 --- a/connector/microsoft/microsoft.go +++ b/connector/microsoft/microsoft.go @@ -15,6 +15,7 @@ import ( "golang.org/x/oauth2" "github.com/dexidp/dex/connector" + groups_pkg "github.com/dexidp/dex/pkg/groups" "github.com/dexidp/dex/pkg/log" ) @@ -311,22 +312,9 @@ func (c *microsoftConnector) getGroups(ctx context.Context, client *http.Client, } // ensure that the user is in at least one required group - isInGroups := false - if len(c.groups) > 0 { - gs := make(map[string]struct{}) - for _, g := range c.groups { - gs[g] = struct{}{} - } - - for _, g := range groups { - if _, ok := gs[g]; ok { - isInGroups = true - break - } - } - } - if len(c.groups) > 0 && !isInGroups { - return nil, fmt.Errorf("microsoft: user %v not in required groups", userID) + filteredGroups := groups_pkg.Filter(groups, c.groups) + if len(c.groups) > 0 && len(filteredGroups) == 0 { + return nil, fmt.Errorf("microsoft: user %v not in any of the required groups", userID) } return diff --git a/pkg/groups/groups.go b/pkg/groups/groups.go new file mode 100644 index 00000000..5dde65ab --- /dev/null +++ b/pkg/groups/groups.go @@ -0,0 +1,18 @@ +// Package groups contains helper functions related to groups +package groups + +// Filter filters out any groups of given that are not in required. Thus it may +// happen that the resulting slice is empty. +func Filter(given, required []string) []string { + groups := []string{} + groupFilter := make(map[string]struct{}) + for _, group := range required { + groupFilter[group] = struct{}{} + } + for _, group := range given { + if _, ok := groupFilter[group]; ok { + groups = append(groups, group) + } + } + return groups +} diff --git a/pkg/groups/groups_test.go b/pkg/groups/groups_test.go new file mode 100644 index 00000000..0be62fb4 --- /dev/null +++ b/pkg/groups/groups_test.go @@ -0,0 +1,26 @@ +package groups_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/dexidp/dex/pkg/groups" +) + +func TestFilter(t *testing.T) { + cases := map[string]struct { + given, required, expected []string + }{ + "nothing given": {given: []string{}, required: []string{"ops"}, expected: []string{}}, + "exactly one match": {given: []string{"foo"}, required: []string{"foo"}, expected: []string{"foo"}}, + "no group of the required ones": {given: []string{"foo", "bar"}, required: []string{"baz"}, expected: []string{}}, + "subset matching": {given: []string{"foo", "bar", "baz"}, required: []string{"bar", "baz"}, expected: []string{"bar", "baz"}}, + } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + actual := groups.Filter(tc.given, tc.required) + assert.ElementsMatch(t, tc.expected, actual) + }) + } +}