This repository has been archived on 2022-08-17. You can view files and clone it, but cannot push or open issues or pull requests.
dex/vendor/github.com/russellhaering/goxmldsig/validate.go
2017-01-09 18:30:58 -08:00

397 lines
11 KiB
Go

package dsig
import (
"bytes"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"errors"
"fmt"
"regexp"
"github.com/beevik/etree"
)
var uriRegexp = regexp.MustCompile("^#[a-zA-Z_][\\w.-]*$")
type ValidationContext struct {
CertificateStore X509CertificateStore
IdAttribute string
Clock *Clock
}
func NewDefaultValidationContext(certificateStore X509CertificateStore) *ValidationContext {
return &ValidationContext{
CertificateStore: certificateStore,
IdAttribute: DefaultIdAttr,
}
}
// TODO(russell_h): More flexible namespace support. This might barely work.
func inNamespace(el *etree.Element, ns string) bool {
for _, attr := range el.Attr {
if attr.Value == ns {
if attr.Space == "" && attr.Key == "xmlns" {
return el.Space == ""
} else if attr.Space == "xmlns" {
return el.Space == attr.Key
}
}
}
return false
}
func childPath(space, tag string) string {
if space == "" {
return "./" + tag
} else {
return "./" + space + ":" + tag
}
}
// The RemoveElement method on etree.Element isn't recursive...
func recursivelyRemoveElement(tree, el *etree.Element) bool {
if tree.RemoveChild(el) != nil {
return true
}
for _, child := range tree.Child {
if childElement, ok := child.(*etree.Element); ok {
if recursivelyRemoveElement(childElement, el) {
return true
}
}
}
return false
}
// transform applies the passed set of transforms to the specified root element.
//
// The functionality of transform is currently very limited and purpose-specific.
//
// NOTE(russell_h): Ideally this wouldn't mutate the root passed to it, and would
// instead return a copy. Unfortunately copying the tree makes it difficult to
// correctly locate the signature. I'm opting, for now, to simply mutate the root
// parameter.
func (ctx *ValidationContext) transform(root, sig *etree.Element, transforms []*etree.Element) (*etree.Element, Canonicalizer, error) {
if len(transforms) != 2 {
return nil, nil, errors.New("Expected Enveloped and C14N transforms")
}
var canonicalizer Canonicalizer
for _, transform := range transforms {
algo := transform.SelectAttr(AlgorithmAttr)
if algo == nil {
return nil, nil, errors.New("Missing Algorithm attribute")
}
switch AlgorithmID(algo.Value) {
case EnvelopedSignatureAltorithmId:
if !recursivelyRemoveElement(root, sig) {
return nil, nil, errors.New("Error applying canonicalization transform: Signature not found")
}
case CanonicalXML10ExclusiveAlgorithmId:
var prefixList string
ins := transform.FindElement(childPath("", InclusiveNamespacesTag))
if ins != nil {
prefixListEl := ins.SelectAttr(PrefixListAttr)
if prefixListEl != nil {
prefixList = prefixListEl.Value
}
}
canonicalizer = MakeC14N10ExclusiveCanonicalizerWithPrefixList(prefixList)
case CanonicalXML11AlgorithmId:
canonicalizer = MakeC14N11Canonicalizer()
default:
return nil, nil, errors.New("Unknown Transform Algorithm: " + algo.Value)
}
}
if canonicalizer == nil {
return nil, nil, errors.New("Expected canonicalization transform")
}
return root, canonicalizer, nil
}
func (ctx *ValidationContext) digest(el *etree.Element, digestAlgorithmId string, canonicalizer Canonicalizer) ([]byte, error) {
data, err := canonicalizer.Canonicalize(el)
if err != nil {
return nil, err
}
digestAlgorithm, ok := digestAlgorithmsByIdentifier[digestAlgorithmId]
if !ok {
return nil, errors.New("Unknown digest algorithm: " + digestAlgorithmId)
}
hash := digestAlgorithm.New()
_, err = hash.Write(data)
if err != nil {
return nil, err
}
return hash.Sum(nil), nil
}
func (ctx *ValidationContext) verifySignedInfo(signatureElement *etree.Element, canonicalizer Canonicalizer, signatureMethodId string, cert *x509.Certificate, sig []byte) error {
signedInfo := signatureElement.FindElement(childPath(signatureElement.Space, SignedInfoTag))
if signedInfo == nil {
return errors.New("Missing SignedInfo")
}
// Any attributes from the 'Signature' element must be pushed down into the 'SignedInfo' element before it is canonicalized
for _, attr := range signatureElement.Attr {
signedInfo.CreateAttr(attr.Space+":"+attr.Key, attr.Value)
}
// Canonicalize the xml
canonical, err := canonicalizer.Canonicalize(signedInfo)
if err != nil {
return err
}
signatureAlgorithm, ok := signatureMethodsByIdentifier[signatureMethodId]
if !ok {
return errors.New("Unknown signature method: " + signatureMethodId)
}
hash := signatureAlgorithm.New()
_, err = hash.Write(canonical)
if err != nil {
return err
}
hashed := hash.Sum(nil)
pubKey, ok := cert.PublicKey.(*rsa.PublicKey)
if !ok {
return errors.New("Invalid public key")
}
// Verify that the private key matching the public key from the cert was what was used to sign the 'SignedInfo' and produce the 'SignatureValue'
err = rsa.VerifyPKCS1v15(pubKey, signatureAlgorithm, hashed[:], sig)
if err != nil {
return err
}
return nil
}
func (ctx *ValidationContext) validateSignature(el *etree.Element, cert *x509.Certificate) (*etree.Element, error) {
el = el.Copy()
// Verify the document minus the signedInfo against the 'DigestValue'
// Find the 'Signature' element
sig := el.FindElement(SignatureTag)
if sig == nil {
return nil, errors.New("Missing Signature")
}
if !inNamespace(sig, Namespace) {
return nil, errors.New("Signature element is in the wrong namespace")
}
// Get the 'SignedInfo' element
signedInfo := sig.FindElement(childPath(sig.Space, SignedInfoTag))
if signedInfo == nil {
return nil, errors.New("Missing SignedInfo")
}
reference := signedInfo.FindElement(childPath(sig.Space, ReferenceTag))
if reference == nil {
return nil, errors.New("Missing Reference")
}
transforms := reference.FindElement(childPath(sig.Space, TransformsTag))
if transforms == nil {
return nil, errors.New("Missing Transforms")
}
uri := reference.SelectAttr("URI")
if uri == nil {
// TODO(russell_h): It is permissible to leave this out. We should be
// able to fall back to finding the referenced element some other way.
return nil, errors.New("Reference is missing URI attribute")
}
if !uriRegexp.MatchString(uri.Value) {
return nil, errors.New("Invalid URI: " + uri.Value)
}
// Get the element referenced in the 'SignedInfo'
referencedElement := el.FindElement(fmt.Sprintf("//[@%s='%s']", ctx.IdAttribute, uri.Value[1:]))
if referencedElement == nil {
return nil, errors.New("Unable to find referenced element: " + uri.Value)
}
// Perform all transformations listed in the 'SignedInfo'
// Basically, this means removing the 'SignedInfo'
transformed, canonicalizer, err := ctx.transform(referencedElement, sig, transforms.ChildElements())
if err != nil {
return nil, err
}
digestMethod := reference.FindElement(childPath(sig.Space, DigestMethodTag))
if digestMethod == nil {
return nil, errors.New("Missing DigestMethod")
}
digestValue := reference.FindElement(childPath(sig.Space, DigestValueTag))
if digestValue == nil {
return nil, errors.New("Missing DigestValue")
}
digestAlgorithmAttr := digestMethod.SelectAttr(AlgorithmAttr)
if digestAlgorithmAttr == nil {
return nil, errors.New("Missing DigestMethod Algorithm attribute")
}
// Digest the transformed XML and compare it to the 'DigestValue' from the 'SignedInfo'
digest, err := ctx.digest(transformed, digestAlgorithmAttr.Value, canonicalizer)
if err != nil {
return nil, err
}
decodedDigestValue, err := base64.StdEncoding.DecodeString(digestValue.Text())
if err != nil {
return nil, err
}
if !bytes.Equal(digest, decodedDigestValue) {
return nil, errors.New("Signature could not be verified")
}
//Verify the signed info
signatureMethod := signedInfo.FindElement(childPath(sig.Space, SignatureMethodTag))
if signatureMethod == nil {
return nil, errors.New("Missing SignatureMethod")
}
signatureMethodAlgorithmAttr := signatureMethod.SelectAttr(AlgorithmAttr)
if digestAlgorithmAttr == nil {
return nil, errors.New("Missing SignatureMethod Algorithm attribute")
}
// Decode the 'SignatureValue' so we can compare against it
signatureValue := sig.FindElement(childPath(sig.Space, SignatureValueTag))
if signatureValue == nil {
return nil, errors.New("Missing SignatureValue")
}
decodedSignature, err := base64.StdEncoding.DecodeString(signatureValue.Text())
if err != nil {
return nil, errors.New("Could not decode signature")
}
// Actually verify the 'SignedInfo' was signed by a trusted source
err = ctx.verifySignedInfo(sig, canonicalizer, signatureMethodAlgorithmAttr.Value, cert, decodedSignature)
if err != nil {
return nil, err
}
return transformed, nil
}
func contains(roots []*x509.Certificate, cert *x509.Certificate) bool {
for _, root := range roots {
if root.Equal(cert) {
return true
}
}
return false
}
func (ctx *ValidationContext) verifyCertificate(el *etree.Element) (*x509.Certificate, error) {
now := ctx.Clock.Now()
el = el.Copy()
idAttr := el.SelectAttr(DefaultIdAttr)
if idAttr == nil || idAttr.Value == "" {
return nil, errors.New("Missing ID attribute")
}
signatureElements := el.FindElements("//" + SignatureTag)
var signatureElement *etree.Element
// Find the Signature element that references the whole Response element
for _, e := range signatureElements {
e2 := e.Copy()
signedInfo := e2.FindElement(childPath(e2.Space, SignedInfoTag))
if signedInfo == nil {
return nil, errors.New("Missing SignedInfo")
}
referenceElement := signedInfo.FindElement(childPath(e2.Space, ReferenceTag))
if referenceElement == nil {
return nil, errors.New("Missing Reference Element")
}
uriAttr := referenceElement.SelectAttr(URIAttr)
if uriAttr == nil || uriAttr.Value == "" {
return nil, errors.New("Missing URI attribute")
}
if uriAttr.Value[1:] == idAttr.Value {
signatureElement = e
break
}
}
if signatureElement == nil {
return nil, errors.New("Missing signature referencing the top-level element")
}
// Get the x509 element from the signature
x509Element := signatureElement.FindElement("//" + childPath(signatureElement.Space, X509CertificateTag))
if x509Element == nil {
return nil, errors.New("Missing x509 Element")
}
x509Text := "-----BEGIN CERTIFICATE-----\n" + x509Element.Text() + "\n-----END CERTIFICATE-----"
block, _ := pem.Decode([]byte(x509Text))
if block == nil {
return nil, errors.New("Failed to parse certificate PEM")
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, err
}
roots, err := ctx.CertificateStore.Certificates()
if err != nil {
return nil, err
}
// Verify that the certificate is one we trust
if !contains(roots, cert) {
return nil, errors.New("Could not verify certificate against trusted certs")
}
if now.Before(cert.NotBefore) || now.After(cert.NotAfter) {
return nil, errors.New("Cert is not valid at this time")
}
return cert, nil
}
func (ctx *ValidationContext) Validate(el *etree.Element) (*etree.Element, error) {
cert, err := ctx.verifyCertificate(el)
if err != nil {
return nil, err
}
return ctx.validateSignature(el, cert)
}