Add domainHint parameter to Microsoft Connector (#2586)

Signed-off-by: Joe Knight <josephtknight@users.noreply.github.com>
This commit is contained in:
Joe Knight 2022-07-25 13:12:55 -06:00 committed by GitHub
parent 367487d7c5
commit 27c25d00be
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 63 additions and 1 deletions

View file

@ -57,6 +57,7 @@ type Config struct {
// PromptType is used for the prompt query parameter. // PromptType is used for the prompt query parameter.
// For valid values, see https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-auth-code-flow#request-an-authorization-code. // For valid values, see https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-auth-code-flow#request-an-authorization-code.
PromptType string `json:"promptType"` PromptType string `json:"promptType"`
DomainHint string `json:"domainHint"`
} }
// Open returns a strategy for logging in through Microsoft. // Open returns a strategy for logging in through Microsoft.
@ -75,6 +76,7 @@ func (c *Config) Open(id string, logger log.Logger) (connector.Connector, error)
logger: logger, logger: logger,
emailToLowercase: c.EmailToLowercase, emailToLowercase: c.EmailToLowercase,
promptType: c.PromptType, promptType: c.PromptType,
domainHint: c.DomainHint,
} }
// By default allow logins from both personal and business/school // By default allow logins from both personal and business/school
// accounts. // accounts.
@ -119,6 +121,7 @@ type microsoftConnector struct {
logger log.Logger logger log.Logger
emailToLowercase bool emailToLowercase bool
promptType string promptType string
domainHint string
} }
func (c *microsoftConnector) isOrgTenant() bool { func (c *microsoftConnector) isOrgTenant() bool {
@ -160,6 +163,9 @@ func (c *microsoftConnector) LoginURL(scopes connector.Scopes, callbackURL, stat
if c.promptType != "" { if c.promptType != "" {
options = append(options, oauth2.SetAuthURLParam("prompt", c.promptType)) options = append(options, oauth2.SetAuthURLParam("prompt", c.promptType))
} }
if c.domainHint != "" {
options = append(options, oauth2.SetAuthURLParam("domain_hint", c.domainHint))
}
return c.oauth2Config(scopes).AuthCodeURL(state, options...), nil return c.oauth2Config(scopes).AuthCodeURL(state, options...), nil
} }

View file

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"os" "os"
"reflect" "reflect"
"testing" "testing"
@ -16,13 +17,68 @@ type testResponse struct {
data interface{} data interface{}
} }
const tenant = "9b1c3439-a67e-4e92-bb0d-0571d44ca965" const (
tenant = "9b1c3439-a67e-4e92-bb0d-0571d44ca965"
clientID = "a115ebf3-6020-4384-8eb1-c0c42e667b6f"
)
var dummyToken = testResponse{data: map[string]interface{}{ var dummyToken = testResponse{data: map[string]interface{}{
"access_token": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9", "access_token": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9",
"expires_in": "30", "expires_in": "30",
}} }}
func TestLoginURL(t *testing.T) {
testURL := "https://test.com"
testState := "some-state"
conn := microsoftConnector{
apiURL: testURL,
graphURL: testURL,
redirectURI: testURL,
clientID: clientID,
tenant: tenant,
}
loginURL, _ := conn.LoginURL(connector.Scopes{}, conn.redirectURI, testState)
parsedLoginURL, _ := url.Parse(loginURL)
queryParams := parsedLoginURL.Query()
expectEquals(t, parsedLoginURL.Path, "/"+tenant+"/oauth2/v2.0/authorize")
expectEquals(t, queryParams.Get("client_id"), clientID)
expectEquals(t, queryParams.Get("redirect_uri"), testURL)
expectEquals(t, queryParams.Get("response_type"), "code")
expectEquals(t, queryParams.Get("scope"), "user.read")
expectEquals(t, queryParams.Get("state"), testState)
expectEquals(t, queryParams.Get("prompt"), "")
expectEquals(t, queryParams.Get("domain_hint"), "")
}
func TestLoginURLWithOptions(t *testing.T) {
testURL := "https://test.com"
promptType := "consent"
domainHint := "domain.hint"
conn := microsoftConnector{
apiURL: testURL,
graphURL: testURL,
redirectURI: testURL,
clientID: clientID,
tenant: tenant,
promptType: promptType,
domainHint: domainHint,
}
loginURL, _ := conn.LoginURL(connector.Scopes{}, conn.redirectURI, "some-state")
parsedLoginURL, _ := url.Parse(loginURL)
queryParams := parsedLoginURL.Query()
expectEquals(t, queryParams.Get("prompt"), promptType)
expectEquals(t, queryParams.Get("domain_hint"), domainHint)
}
func TestUserIdentityFromGraphAPI(t *testing.T) { func TestUserIdentityFromGraphAPI(t *testing.T) {
s := newTestServer(map[string]testResponse{ s := newTestServer(map[string]testResponse{
"/v1.0/me?$select=id,displayName,userPrincipalName": { "/v1.0/me?$select=id,displayName,userPrincipalName": {