272 lines
6.2 KiB
Go
272 lines
6.2 KiB
Go
// Copyright 2017 The etcd Authors
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package transport
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"net"
|
|
"strings"
|
|
"sync"
|
|
)
|
|
|
|
type (
	// tlsListener overrides a TLS listener so it will reject client
	// certificates with insufficient SAN credentials or CRL revoked
	// certificates. Connections are handshaken and verified in the
	// background; Accept only hands out connections that passed.
	tlsListener struct {
		// Listener is the wrapped TLS listener whose raw connections are
		// handshaken by the listener's accept loop.
		net.Listener

		// connc carries connections that completed the handshake and passed check.
		connc chan net.Conn
		// donec is closed once the accept loop has fully shut down.
		donec chan struct{}
		// err is the error that stopped the accept loop; it is read only
		// after donec has been closed.
		err error

		// handshakeFailure is invoked when a handshake or check fails.
		handshakeFailure func(*tls.Conn, error)
		// check performs additional verification after a successful handshake.
		check tlsCheckFunc
	}

	// tlsCheckFunc inspects a handshaken connection; a non-nil error rejects it.
	tlsCheckFunc func(context.Context, *tls.Conn) error
)
|
|
|
|
// NewTLSListener handshakes TLS connections and performs optional CRL checking.
|
|
func NewTLSListener(l net.Listener, tlsinfo *TLSInfo) (net.Listener, error) {
|
|
check := func(context.Context, *tls.Conn) error { return nil }
|
|
return newTLSListener(l, tlsinfo, check)
|
|
}
|
|
|
|
func newTLSListener(l net.Listener, tlsinfo *TLSInfo, check tlsCheckFunc) (net.Listener, error) {
|
|
if tlsinfo == nil || tlsinfo.Empty() {
|
|
l.Close()
|
|
return nil, fmt.Errorf("cannot listen on TLS for %s: KeyFile and CertFile are not presented", l.Addr().String())
|
|
}
|
|
tlscfg, err := tlsinfo.ServerConfig()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
hf := tlsinfo.HandshakeFailure
|
|
if hf == nil {
|
|
hf = func(*tls.Conn, error) {}
|
|
}
|
|
|
|
if len(tlsinfo.CRLFile) > 0 {
|
|
prevCheck := check
|
|
check = func(ctx context.Context, tlsConn *tls.Conn) error {
|
|
if err := prevCheck(ctx, tlsConn); err != nil {
|
|
return err
|
|
}
|
|
st := tlsConn.ConnectionState()
|
|
if certs := st.PeerCertificates; len(certs) > 0 {
|
|
return checkCRL(tlsinfo.CRLFile, certs)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
tlsl := &tlsListener{
|
|
Listener: tls.NewListener(l, tlscfg),
|
|
connc: make(chan net.Conn),
|
|
donec: make(chan struct{}),
|
|
handshakeFailure: hf,
|
|
check: check,
|
|
}
|
|
go tlsl.acceptLoop()
|
|
return tlsl, nil
|
|
}
|
|
|
|
func (l *tlsListener) Accept() (net.Conn, error) {
|
|
select {
|
|
case conn := <-l.connc:
|
|
return conn, nil
|
|
case <-l.donec:
|
|
return nil, l.err
|
|
}
|
|
}
|
|
|
|
func checkSAN(ctx context.Context, tlsConn *tls.Conn) error {
|
|
st := tlsConn.ConnectionState()
|
|
if certs := st.PeerCertificates; len(certs) > 0 {
|
|
addr := tlsConn.RemoteAddr().String()
|
|
return checkCertSAN(ctx, certs[0], addr)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// acceptLoop launches each TLS handshake in a separate goroutine
// to prevent a hanging TLS connection from blocking other connections.
//
// It runs until the wrapped listener's Accept returns an error, which is
// stored in l.err. On exit it performs these steps in order:
//   - cancel the shared context;
//   - close every connection still mid-handshake;
//   - wait for all per-connection goroutines to finish;
//   - close l.donec.
//
// Closing l.donec last means Accept and Close observe shutdown only after all
// cleanup is complete.
func (l *tlsListener) acceptLoop() {
	var wg sync.WaitGroup
	var pendingMu sync.Mutex

	// pending tracks connections whose handshake has not finished, so they can
	// be closed on shutdown (closing them lets a blocked Handshake return).
	pending := make(map[net.Conn]struct{})
	// ctx is passed to l.check and is cancelled on shutdown. That also
	// releases goroutines waiting to hand a connection to Accept.
	ctx, cancel := context.WithCancel(context.Background())
	defer func() {
		cancel()
		pendingMu.Lock()
		for c := range pending {
			c.Close()
		}
		pendingMu.Unlock()
		wg.Wait()
		// Closing donec publishes l.err to Accept and unblocks Close.
		close(l.donec)
	}()

	for {
		// conn is declared inside the loop body, so each goroutine below
		// captures its own variable.
		conn, err := l.Listener.Accept()
		if err != nil {
			l.err = err
			return
		}

		pendingMu.Lock()
		pending[conn] = struct{}{}
		pendingMu.Unlock()

		wg.Add(1)
		go func() {
			// conn is set to nil once it has been handed to Accept's caller;
			// in every other outcome it is closed when this goroutine exits.
			defer func() {
				if conn != nil {
					conn.Close()
				}
				wg.Done()
			}()

			// l.Listener is created with tls.NewListener in newTLSListener, so
			// accepted connections are *tls.Conn.
			tlsConn := conn.(*tls.Conn)
			herr := tlsConn.Handshake()
			pendingMu.Lock()
			delete(pending, conn)
			pendingMu.Unlock()

			// A handshake interrupted by shutdown (its conn was closed above)
			// is also reported here as a failure.
			if herr != nil {
				l.handshakeFailure(tlsConn, herr)
				return
			}
			if err := l.check(ctx, tlsConn); err != nil {
				l.handshakeFailure(tlsConn, err)
				return
			}

			// Hand the connection to Accept, unless the loop is shutting down.
			select {
			case l.connc <- tlsConn:
				conn = nil
			case <-ctx.Done():
			}
		}()
	}
}
|
|
|
|
// checkCRL returns an error if any certificate in cert appears in the
// certificate revocation list stored at crlPath.
//
// Behaviour:
//   - The CRL file is read and parsed on every call. x509.ParseCRL accepts
//     PEM or DER.
//   - The CRL's signature is not verified here; the configured file is trusted.
//   - A certificate counts as revoked when its serial number equals the serial
//     of a revoked entry.
//   - Nil certificates, and certificates or entries without a serial number,
//     are skipped.
//   - Errors from reading or parsing the file are returned unchanged.
func checkCRL(crlPath string, cert []*x509.Certificate) error {
	// TODO: cache
	crlBytes, err := ioutil.ReadFile(crlPath)
	if err != nil {
		return err
	}
	certList, err := x509.ParseCRL(crlBytes)
	if err != nil {
		return err
	}

	// Key serials by their signed decimal form rather than big.Int.Bytes().
	// Bytes() drops the sign, so a revoked serial -N would also match
	// certificate N.
	revoked := certList.TBSCertList.RevokedCertificates
	revokedSerials := make(map[string]struct{}, len(revoked))
	for _, rc := range revoked {
		if rc.SerialNumber == nil {
			continue
		}
		revokedSerials[rc.SerialNumber.String()] = struct{}{}
	}
	for _, c := range cert {
		if c == nil || c.SerialNumber == nil {
			continue
		}
		if _, ok := revokedSerials[c.SerialNumber.String()]; ok {
			// The message format matches the historical "%x of the serial bytes".
			return fmt.Errorf("transport: certificate serial %x revoked", c.SerialNumber.Bytes())
		}
	}
	return nil
}
|
|
|
|
func checkCertSAN(ctx context.Context, cert *x509.Certificate, remoteAddr string) error {
|
|
if len(cert.IPAddresses) == 0 && len(cert.DNSNames) == 0 {
|
|
return nil
|
|
}
|
|
h, _, herr := net.SplitHostPort(remoteAddr)
|
|
if herr != nil {
|
|
return herr
|
|
}
|
|
if len(cert.IPAddresses) > 0 {
|
|
cerr := cert.VerifyHostname(h)
|
|
if cerr == nil {
|
|
return nil
|
|
}
|
|
if len(cert.DNSNames) == 0 {
|
|
return cerr
|
|
}
|
|
}
|
|
if len(cert.DNSNames) > 0 {
|
|
ok, err := isHostInDNS(ctx, h, cert.DNSNames)
|
|
if ok {
|
|
return nil
|
|
}
|
|
errStr := ""
|
|
if err != nil {
|
|
errStr = " (" + err.Error() + ")"
|
|
}
|
|
return fmt.Errorf("tls: %q does not match any of DNSNames %q"+errStr, h, cert.DNSNames)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// isHostInDNS reports whether host (typically the peer's IP address) can be
// tied to one of the certificate DNS names in dnsNames.
//
// Two strategies are tried in order:
//   - Reverse: host is resolved to names with a PTR lookup. Each name is
//     compared with the exact entries of dnsNames, and suffix-matched against
//     the "*.suffix" wildcard entries.
//   - Forward: each non-wildcard entry is resolved to addresses, which are
//     compared with host.
//
// Name comparisons ignore case, as DNS names are case-insensitive. Address
// comparisons also parse both sides as IPs, so equivalent textual forms
// (e.g. "::1" and "0:0:0:0:0:0:0:1") match.
//
// When nothing matches, false is returned together with the last lookup error
// encountered, or nil if none occurred.
func isHostInDNS(ctx context.Context, host string, dnsNames []string) (ok bool, err error) {
	// Split SANs into lowercased wildcard suffixes ("*.example.com" ->
	// ".example.com") and exact names.
	var wildcards, names []string
	for _, dns := range dnsNames {
		if strings.HasPrefix(dns, "*.") {
			wildcards = append(wildcards, strings.ToLower(dns[1:]))
		} else {
			names = append(names, dns)
		}
	}

	// Reverse lookup.
	lnames, lerr := net.DefaultResolver.LookupAddr(ctx, host)
	for _, name := range lnames {
		// Strip the trailing '.' from the PTR record. TrimSuffix also copes
		// with an empty name, which indexing name[len(name)-1] would not.
		name = strings.TrimSuffix(name, ".")
		lower := strings.ToLower(name)
		for _, wc := range wildcards {
			if strings.HasSuffix(lower, wc) {
				return true, nil
			}
		}
		for _, n := range names {
			if strings.EqualFold(n, name) {
				return true, nil
			}
		}
	}
	err = lerr

	// Forward lookup.
	hostIP := net.ParseIP(host)
	for _, dns := range names {
		addrs, lerr := net.DefaultResolver.LookupHost(ctx, dns)
		if lerr != nil {
			err = lerr
			continue
		}
		for _, addr := range addrs {
			if addr == host || (hostIP != nil && hostIP.Equal(net.ParseIP(addr))) {
				return true, nil
			}
		}
	}
	return false, err
}
|
|
|
|
func (l *tlsListener) Close() error {
|
|
err := l.Listener.Close()
|
|
<-l.donec
|
|
return err
|
|
}
|