diff --git a/connector/saml/saml.go b/connector/saml/saml.go index 8b55252a..a24df1c8 100644 --- a/connector/saml/saml.go +++ b/connector/saml/saml.go @@ -130,9 +130,7 @@ func (c *Config) Open(logger logrus.FieldLogger) (connector.Connector, error) { return c.openConnector(logger) } -func (c *Config) openConnector(logger logrus.FieldLogger) (interface { - connector.SAMLConnector -}, error) { +func (c *Config) openConnector(logger logrus.FieldLogger) (*provider, error) { requiredFields := []struct { name, val string }{ @@ -274,8 +272,11 @@ func (p *provider) HandlePOST(s connector.Scopes, samlResponse, inResponseTo str if err != nil { return ident, fmt.Errorf("decode response: %v", err) } + + rootElementSigned := true if p.validator != nil { - if rawResp, err = verify(p.validator, rawResp); err != nil { + rawResp, rootElementSigned, err = verifyResponseSig(p.validator, rawResp) + if err != nil { return ident, fmt.Errorf("verify signature: %v", err) } } @@ -285,24 +286,26 @@ func (p *provider) HandlePOST(s connector.Scopes, samlResponse, inResponseTo str return ident, fmt.Errorf("unmarshal response: %v", err) } - if p.issuer != "" && resp.Issuer != nil && resp.Issuer.Issuer != p.issuer { - return ident, fmt.Errorf("expected Issuer value %s, got %s", p.issuer, resp.Issuer.Issuer) - } + if rootElementSigned { + if p.issuer != "" && resp.Issuer != nil && resp.Issuer.Issuer != p.issuer { + return ident, fmt.Errorf("expected Issuer value %s, got %s", p.issuer, resp.Issuer.Issuer) + } - // Verify InResponseTo value matches the expected ID associated with - // the RelayState. - if resp.InResponseTo != inResponseTo { - return ident, fmt.Errorf("expected InResponseTo value %s, got %s", inResponseTo, resp.InResponseTo) - } + // Verify InResponseTo value matches the expected ID associated with + // the RelayState. + if resp.InResponseTo != inResponseTo { + return ident, fmt.Errorf("expected InResponseTo value %s, got %s", inResponseTo, resp.InResponseTo) + } - // Destination is optional. - if resp.Destination != "" && resp.Destination != p.redirectURI { - return ident, fmt.Errorf("expected destination %q got %q", p.redirectURI, resp.Destination) + // Destination is optional. + if resp.Destination != "" && resp.Destination != p.redirectURI { + return ident, fmt.Errorf("expected destination %q got %q", p.redirectURI, resp.Destination) - } + } - if err = p.validateStatus(&resp); err != nil { - return ident, err + if err = p.validateStatus(&resp); err != nil { + return ident, err + } } assertion := resp.Assertion @@ -481,41 +484,57 @@ func (p *provider) validateConditions(assertion *assertion) error { return nil } -// verify checks the signature info of a XML document and returns -// the signed elements. -// The Validate function of the goxmldsig library only looks for -// signatures on the root element level. But a saml Response is valid -// if the complete message is signed, or only the Assertion is signed, -// or but elements are signed. Therefore we first check a possible -// signature of the Response than of the Assertion. If one of these -// is successful the Response is considered as valid. -func verify(validator *dsig.ValidationContext, data []byte) (signed []byte, err error) { +// verifyResponseSig attempts to verify the signature of a SAML response or +// the assertion. +// +// If the root element is properly signed, this method returns it. +// +// The SAML spec requires supporting responses where the root element is +// unverified, but the sub elements are signed. In these cases, +// this method returns rootVerified=false to indicate that the +// elements should be trusted, but all other elements MUST be ignored. +// +// Note: we still don't support multiple tags. If there are +// multiple present this code will only process the first. +func verifyResponseSig(validator *dsig.ValidationContext, data []byte) (signed []byte, rootVerified bool, err error) { doc := etree.NewDocument() if err = doc.ReadFromBytes(data); err != nil { - return nil, fmt.Errorf("parse document: %v", err) + return nil, false, fmt.Errorf("parse document: %v", err) } - verified := false + response := doc.Root() transformedResponse, err := validator.Validate(response) if err == nil { - verified = true + // Root element is verified, return it. doc.SetRoot(transformedResponse) + signed, err = doc.WriteToBytes() + return signed, true, err } + // Ensures xmlns are copied down to the assertion element when they are defined in the root + // + // TODO: Only select from child elements of the root. assertion, err := etreeutils.NSSelectOne(response, "urn:oasis:names:tc:SAML:2.0:assertion", "Assertion") if err != nil { - return nil, fmt.Errorf("response does not contain an Assertion element") + return nil, false, fmt.Errorf("response does not contain an Assertion element") } transformedAssertion, err := validator.Validate(assertion) - if err == nil { - verified = true - response.RemoveChild(assertion) - response.AddChild(transformedAssertion) + if err != nil { + return nil, false, fmt.Errorf("response does not contain a valid signature element: %v", err) } - if verified != true { - return nil, fmt.Errorf("response does not contain a valid Signature element") + + // Verified an assertion but not the response. Can't trust any child elements, + // except the assertion. Remove them all. + for _, el := range response.ChildElements() { + response.RemoveChild(el) } - return doc.WriteToBytes() + + // We still return the full element, even though it's unverified + // because the element is not a valid XML document on its own. + // It still requires the root element to define things like namespaces. + response.AddChild(transformedAssertion) + signed, err = doc.WriteToBytes() + return signed, false, err } func uuidv4() string {