305 lines
7.4 KiB
Go
Vendored
305 lines
7.4 KiB
Go
Vendored
package winio
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"os"
|
|
"syscall"
|
|
"time"
|
|
"unsafe"
|
|
|
|
"github.com/Microsoft/go-winio/pkg/guid"
|
|
)
|
|
|
|
//sys bind(s syscall.Handle, name unsafe.Pointer, namelen int32) (err error) [failretval==socketError] = ws2_32.bind
|
|
|
|
const (
	// afHvSock is the address family used when creating Hyper-V sockets and
	// when filling in the Family field of a raw socket address.
	afHvSock = 34 // AF_HYPERV

	// socketError is the all-ones uintptr value that the //sys bind directive
	// treats as the failure return value.
	socketError = ^uintptr(0)
)
|
|
|
|
// An HvsockAddr is an address for an AF_HYPERV socket.
//
// It is made up of two GUIDs; String renders them as "VMID:ServiceID".
type HvsockAddr struct {
	// VMID is the first GUID component of the address.
	VMID guid.GUID
	// ServiceID is the second GUID component of the address; VsockServiceID
	// can derive one from an AF_VSOCK port number.
	ServiceID guid.GUID
}
|
|
|
|
// rawHvsockAddr is the in-memory form of an HvsockAddr that is handed to bind
// and read back out of the AcceptEx address buffer. The field order and the
// blank padding field are part of that layout and must not be rearranged.
type rawHvsockAddr struct {
	// Family is set to afHvSock by HvsockAddr.raw.
	Family uint16
	// Unnamed 16-bit slot between Family and the GUIDs; always left zero here.
	_ uint16
	// VMID mirrors HvsockAddr.VMID.
	VMID guid.GUID
	// ServiceID mirrors HvsockAddr.ServiceID.
	ServiceID guid.GUID
}
|
|
|
|
// Network returns the address's network name, "hvsock".
|
|
func (addr *HvsockAddr) Network() string {
|
|
return "hvsock"
|
|
}
|
|
|
|
func (addr *HvsockAddr) String() string {
|
|
return fmt.Sprintf("%s:%s", &addr.VMID, &addr.ServiceID)
|
|
}
|
|
|
|
// VsockServiceID returns an hvsock service ID corresponding to the specified AF_VSOCK port.
|
|
func VsockServiceID(port uint32) guid.GUID {
|
|
g, _ := guid.FromString("00000000-facb-11e6-bd58-64006a7986d3")
|
|
g.Data1 = port
|
|
return g
|
|
}
|
|
|
|
func (addr *HvsockAddr) raw() rawHvsockAddr {
|
|
return rawHvsockAddr{
|
|
Family: afHvSock,
|
|
VMID: addr.VMID,
|
|
ServiceID: addr.ServiceID,
|
|
}
|
|
}
|
|
|
|
func (addr *HvsockAddr) fromRaw(raw *rawHvsockAddr) {
|
|
addr.VMID = raw.VMID
|
|
addr.ServiceID = raw.ServiceID
|
|
}
|
|
|
|
// HvsockListener is a socket listener for the AF_HYPERV address family.
type HvsockListener struct {
	// sock is the bound, listening socket that Accept and Close operate on.
	sock *win32File
	// addr is a copy of the address supplied to ListenHvsock; Addr returns a
	// pointer to it.
	addr HvsockAddr
}
|
|
|
|
// HvsockConn is a connected socket of the AF_HYPERV address family.
type HvsockConn struct {
	// sock is the connected socket used for all reads, writes and shutdowns.
	sock *win32File
	// local and remote are the two endpoint addresses; Accept fills them in
	// from the AcceptEx address buffer.
	local, remote HvsockAddr
}
|
|
|
|
func newHvSocket() (*win32File, error) {
|
|
fd, err := syscall.Socket(afHvSock, syscall.SOCK_STREAM, 1)
|
|
if err != nil {
|
|
return nil, os.NewSyscallError("socket", err)
|
|
}
|
|
f, err := makeWin32File(fd)
|
|
if err != nil {
|
|
syscall.Close(fd)
|
|
return nil, err
|
|
}
|
|
f.socket = true
|
|
return f, nil
|
|
}
|
|
|
|
// ListenHvsock listens for connections on the specified hvsock address.
|
|
func ListenHvsock(addr *HvsockAddr) (_ *HvsockListener, err error) {
|
|
l := &HvsockListener{addr: *addr}
|
|
sock, err := newHvSocket()
|
|
if err != nil {
|
|
return nil, l.opErr("listen", err)
|
|
}
|
|
sa := addr.raw()
|
|
err = bind(sock.handle, unsafe.Pointer(&sa), int32(unsafe.Sizeof(sa)))
|
|
if err != nil {
|
|
return nil, l.opErr("listen", os.NewSyscallError("socket", err))
|
|
}
|
|
err = syscall.Listen(sock.handle, 16)
|
|
if err != nil {
|
|
return nil, l.opErr("listen", os.NewSyscallError("listen", err))
|
|
}
|
|
return &HvsockListener{sock: sock, addr: *addr}, nil
|
|
}
|
|
|
|
func (l *HvsockListener) opErr(op string, err error) error {
|
|
return &net.OpError{Op: op, Net: "hvsock", Addr: &l.addr, Err: err}
|
|
}
|
|
|
|
// Addr returns the listener's network address.
|
|
func (l *HvsockListener) Addr() net.Addr {
|
|
return &l.addr
|
|
}
|
|
|
|
// Accept waits for the next connection and returns it.
|
|
func (l *HvsockListener) Accept() (_ net.Conn, err error) {
|
|
sock, err := newHvSocket()
|
|
if err != nil {
|
|
return nil, l.opErr("accept", err)
|
|
}
|
|
defer func() {
|
|
if sock != nil {
|
|
sock.Close()
|
|
}
|
|
}()
|
|
c, err := l.sock.prepareIo()
|
|
if err != nil {
|
|
return nil, l.opErr("accept", err)
|
|
}
|
|
defer l.sock.wg.Done()
|
|
|
|
// AcceptEx, per documentation, requires an extra 16 bytes per address.
|
|
const addrlen = uint32(16 + unsafe.Sizeof(rawHvsockAddr{}))
|
|
var addrbuf [addrlen * 2]byte
|
|
|
|
var bytes uint32
|
|
err = syscall.AcceptEx(l.sock.handle, sock.handle, &addrbuf[0], 0, addrlen, addrlen, &bytes, &c.o)
|
|
_, err = l.sock.asyncIo(c, nil, bytes, err)
|
|
if err != nil {
|
|
return nil, l.opErr("accept", os.NewSyscallError("acceptex", err))
|
|
}
|
|
conn := &HvsockConn{
|
|
sock: sock,
|
|
}
|
|
conn.local.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[0])))
|
|
conn.remote.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[addrlen])))
|
|
sock = nil
|
|
return conn, nil
|
|
}
|
|
|
|
// Close closes the listener, causing any pending Accept calls to fail.
|
|
func (l *HvsockListener) Close() error {
|
|
return l.sock.Close()
|
|
}
|
|
|
|
/* Need to finish ConnectEx handling
|
|
func DialHvsock(ctx context.Context, addr *HvsockAddr) (*HvsockConn, error) {
|
|
sock, err := newHvSocket()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer func() {
|
|
if sock != nil {
|
|
sock.Close()
|
|
}
|
|
}()
|
|
c, err := sock.prepareIo()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer sock.wg.Done()
|
|
var bytes uint32
|
|
err = windows.ConnectEx(windows.Handle(sock.handle), sa, nil, 0, &bytes, &c.o)
|
|
_, err = sock.asyncIo(ctx, c, nil, bytes, err)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
conn := &HvsockConn{
|
|
sock: sock,
|
|
remote: *addr,
|
|
}
|
|
sock = nil
|
|
return conn, nil
|
|
}
|
|
*/
|
|
|
|
func (conn *HvsockConn) opErr(op string, err error) error {
|
|
return &net.OpError{Op: op, Net: "hvsock", Source: &conn.local, Addr: &conn.remote, Err: err}
|
|
}
|
|
|
|
func (conn *HvsockConn) Read(b []byte) (int, error) {
|
|
c, err := conn.sock.prepareIo()
|
|
if err != nil {
|
|
return 0, conn.opErr("read", err)
|
|
}
|
|
defer conn.sock.wg.Done()
|
|
buf := syscall.WSABuf{Buf: &b[0], Len: uint32(len(b))}
|
|
var flags, bytes uint32
|
|
err = syscall.WSARecv(conn.sock.handle, &buf, 1, &bytes, &flags, &c.o, nil)
|
|
n, err := conn.sock.asyncIo(c, &conn.sock.readDeadline, bytes, err)
|
|
if err != nil {
|
|
if _, ok := err.(syscall.Errno); ok {
|
|
err = os.NewSyscallError("wsarecv", err)
|
|
}
|
|
return 0, conn.opErr("read", err)
|
|
} else if n == 0 {
|
|
err = io.EOF
|
|
}
|
|
return n, err
|
|
}
|
|
|
|
func (conn *HvsockConn) Write(b []byte) (int, error) {
|
|
t := 0
|
|
for len(b) != 0 {
|
|
n, err := conn.write(b)
|
|
if err != nil {
|
|
return t + n, err
|
|
}
|
|
t += n
|
|
b = b[n:]
|
|
}
|
|
return t, nil
|
|
}
|
|
|
|
func (conn *HvsockConn) write(b []byte) (int, error) {
|
|
c, err := conn.sock.prepareIo()
|
|
if err != nil {
|
|
return 0, conn.opErr("write", err)
|
|
}
|
|
defer conn.sock.wg.Done()
|
|
buf := syscall.WSABuf{Buf: &b[0], Len: uint32(len(b))}
|
|
var bytes uint32
|
|
err = syscall.WSASend(conn.sock.handle, &buf, 1, &bytes, 0, &c.o, nil)
|
|
n, err := conn.sock.asyncIo(c, &conn.sock.writeDeadline, bytes, err)
|
|
if err != nil {
|
|
if _, ok := err.(syscall.Errno); ok {
|
|
err = os.NewSyscallError("wsasend", err)
|
|
}
|
|
return 0, conn.opErr("write", err)
|
|
}
|
|
return n, err
|
|
}
|
|
|
|
// Close closes the socket connection, failing any pending read or write calls.
|
|
func (conn *HvsockConn) Close() error {
|
|
return conn.sock.Close()
|
|
}
|
|
|
|
func (conn *HvsockConn) shutdown(how int) error {
|
|
err := syscall.Shutdown(conn.sock.handle, syscall.SHUT_RD)
|
|
if err != nil {
|
|
return os.NewSyscallError("shutdown", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// CloseRead shuts down the read end of the socket.
|
|
func (conn *HvsockConn) CloseRead() error {
|
|
err := conn.shutdown(syscall.SHUT_RD)
|
|
if err != nil {
|
|
return conn.opErr("close", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// CloseWrite shuts down the write end of the socket, notifying the other endpoint that
|
|
// no more data will be written.
|
|
func (conn *HvsockConn) CloseWrite() error {
|
|
err := conn.shutdown(syscall.SHUT_WR)
|
|
if err != nil {
|
|
return conn.opErr("close", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// LocalAddr returns the local address of the connection.
|
|
func (conn *HvsockConn) LocalAddr() net.Addr {
|
|
return &conn.local
|
|
}
|
|
|
|
// RemoteAddr returns the remote address of the connection.
|
|
func (conn *HvsockConn) RemoteAddr() net.Addr {
|
|
return &conn.remote
|
|
}
|
|
|
|
// SetDeadline implements the net.Conn SetDeadline method.
|
|
func (conn *HvsockConn) SetDeadline(t time.Time) error {
|
|
conn.SetReadDeadline(t)
|
|
conn.SetWriteDeadline(t)
|
|
return nil
|
|
}
|
|
|
|
// SetReadDeadline implements the net.Conn SetReadDeadline method.
|
|
func (conn *HvsockConn) SetReadDeadline(t time.Time) error {
|
|
return conn.sock.SetReadDeadline(t)
|
|
}
|
|
|
|
// SetWriteDeadline implements the net.Conn SetWriteDeadline method.
|
|
func (conn *HvsockConn) SetWriteDeadline(t time.Time) error {
|
|
return conn.sock.SetWriteDeadline(t)
|
|
}
|