245 lines
6.8 KiB
Go
245 lines
6.8 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/pem"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"path"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"github.com/stretchr/testify/require"
|
|
"gitlab.com/gitlab-org/labkit/log"
|
|
|
|
"gitlab.com/gitlab-org/gitlab/workhorse/internal/api"
|
|
)
|
|
|
|
var (
|
|
envTerminalPath = fmt.Sprintf("%s/-/environments/1/terminal.ws", testProject)
|
|
jobTerminalPath = fmt.Sprintf("%s/-/jobs/1/terminal.ws", testProject)
|
|
servicesProxyWSPath = fmt.Sprintf("%s/-/jobs/1/proxy.ws", testProject)
|
|
)
|
|
|
|
type connWithReq struct {
|
|
conn *websocket.Conn
|
|
req *http.Request
|
|
}
|
|
|
|
func TestChannelHappyPath(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
channelPath string
|
|
}{
|
|
{"environments", envTerminalPath},
|
|
{"jobs", jobTerminalPath},
|
|
{"services", servicesProxyWSPath},
|
|
}
|
|
for _, test := range tests {
|
|
t.Run(test.name, func(t *testing.T) {
|
|
serverConns, clientURL, close := wireupChannel(t, test.channelPath, nil, "channel.k8s.io")
|
|
defer close()
|
|
|
|
client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
|
|
require.NoError(t, err)
|
|
|
|
server := (<-serverConns).conn
|
|
defer server.Close()
|
|
|
|
message := "test message"
|
|
|
|
// channel.k8s.io: server writes to channel 1, STDOUT
|
|
require.NoError(t, say(server, "\x01"+message))
|
|
requireReadMessage(t, client, websocket.BinaryMessage, message)
|
|
|
|
require.NoError(t, say(client, message))
|
|
|
|
// channel.k8s.io: client writes get put on channel 0, STDIN
|
|
requireReadMessage(t, server, websocket.BinaryMessage, "\x00"+message)
|
|
|
|
// Closing the client should send an EOT signal to the server's STDIN
|
|
client.Close()
|
|
requireReadMessage(t, server, websocket.BinaryMessage, "\x00\x04")
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestChannelBadTLS(t *testing.T) {
|
|
_, clientURL, close := wireupChannel(t, envTerminalPath, badCA, "channel.k8s.io")
|
|
defer close()
|
|
|
|
_, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
|
|
require.Equal(t, websocket.ErrBadHandshake, err, "unexpected error %v", err)
|
|
}
|
|
|
|
func TestChannelSessionTimeout(t *testing.T) {
|
|
serverConns, clientURL, close := wireupChannel(t, envTerminalPath, timeout, "channel.k8s.io")
|
|
defer close()
|
|
|
|
client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
|
|
require.NoError(t, err)
|
|
|
|
sc := <-serverConns
|
|
defer sc.conn.Close()
|
|
|
|
client.SetReadDeadline(time.Now().Add(time.Duration(2) * time.Second))
|
|
_, _, err = client.ReadMessage()
|
|
|
|
require.True(t, websocket.IsCloseError(err, websocket.CloseAbnormalClosure), "Client connection was not closed, got %v", err)
|
|
}
|
|
|
|
func TestChannelProxyForwardsHeadersFromUpstream(t *testing.T) {
|
|
hdr := make(http.Header)
|
|
hdr.Set("Random-Header", "Value")
|
|
serverConns, clientURL, close := wireupChannel(t, envTerminalPath, setHeader(hdr), "channel.k8s.io")
|
|
defer close()
|
|
|
|
client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
|
|
require.NoError(t, err)
|
|
defer client.Close()
|
|
|
|
sc := <-serverConns
|
|
defer sc.conn.Close()
|
|
require.Equal(t, "Value", sc.req.Header.Get("Random-Header"), "Header specified by upstream not sent to remote")
|
|
}
|
|
|
|
func TestChannelProxyForwardsXForwardedForFromClient(t *testing.T) {
|
|
serverConns, clientURL, close := wireupChannel(t, envTerminalPath, nil, "channel.k8s.io")
|
|
defer close()
|
|
|
|
hdr := make(http.Header)
|
|
hdr.Set("X-Forwarded-For", "127.0.0.2")
|
|
client, _, err := dialWebsocket(clientURL, hdr, "terminal.gitlab.com")
|
|
require.NoError(t, err)
|
|
defer client.Close()
|
|
|
|
clientIP, _, err := net.SplitHostPort(client.LocalAddr().String())
|
|
require.NoError(t, err)
|
|
|
|
sc := <-serverConns
|
|
defer sc.conn.Close()
|
|
|
|
require.Equal(t, "127.0.0.2, "+clientIP, sc.req.Header.Get("X-Forwarded-For"), "X-Forwarded-For from client not sent to remote")
|
|
}
|
|
|
|
func wireupChannel(t *testing.T, channelPath string, modifier func(*api.Response), subprotocols ...string) (chan connWithReq, string, func()) {
|
|
serverConns, remote := startWebsocketServer(subprotocols...)
|
|
authResponse := channelOkBody(remote, nil, subprotocols...)
|
|
if modifier != nil {
|
|
modifier(authResponse)
|
|
}
|
|
upstream := testAuthServer(t, nil, nil, 200, authResponse)
|
|
workhorse := startWorkhorseServer(upstream.URL)
|
|
|
|
return serverConns, websocketURL(workhorse.URL, channelPath), func() {
|
|
workhorse.Close()
|
|
upstream.Close()
|
|
remote.Close()
|
|
}
|
|
}
|
|
|
|
func startWebsocketServer(subprotocols ...string) (chan connWithReq, *httptest.Server) {
|
|
upgrader := &websocket.Upgrader{Subprotocols: subprotocols}
|
|
|
|
connCh := make(chan connWithReq, 1)
|
|
server := httptest.NewTLSServer(webSocketHandler(upgrader, connCh))
|
|
|
|
return connCh, server
|
|
}
|
|
|
|
func webSocketHandler(upgrader *websocket.Upgrader, connCh chan connWithReq) http.HandlerFunc {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
logEntry := log.WithFields(log.Fields{
|
|
"method": r.Method,
|
|
"url": r.URL,
|
|
"headers": r.Header,
|
|
})
|
|
|
|
logEntry.Info("WEBSOCKET")
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
logEntry.WithError(err).Error("WEBSOCKET Upgrade failed")
|
|
return
|
|
}
|
|
connCh <- connWithReq{conn, r}
|
|
// The connection has been hijacked so it's OK to end here
|
|
})
|
|
}
|
|
|
|
func channelOkBody(remote *httptest.Server, header http.Header, subprotocols ...string) *api.Response {
|
|
out := &api.Response{
|
|
Channel: &api.ChannelSettings{
|
|
Url: websocketURL(remote.URL),
|
|
Header: header,
|
|
Subprotocols: subprotocols,
|
|
MaxSessionTime: 0,
|
|
},
|
|
}
|
|
|
|
if len(remote.TLS.Certificates) > 0 {
|
|
data := bytes.NewBuffer(nil)
|
|
pem.Encode(data, &pem.Block{Type: "CERTIFICATE", Bytes: remote.TLS.Certificates[0].Certificate[0]})
|
|
out.Channel.CAPem = data.String()
|
|
}
|
|
|
|
return out
|
|
}
|
|
|
|
func badCA(authResponse *api.Response) {
|
|
authResponse.Channel.CAPem = "Bad CA"
|
|
}
|
|
|
|
func timeout(authResponse *api.Response) {
|
|
authResponse.Channel.MaxSessionTime = 1
|
|
}
|
|
|
|
func setHeader(hdr http.Header) func(*api.Response) {
|
|
return func(authResponse *api.Response) {
|
|
authResponse.Channel.Header = hdr
|
|
}
|
|
}
|
|
|
|
func dialWebsocket(url string, header http.Header, subprotocols ...string) (*websocket.Conn, *http.Response, error) {
|
|
dialer := &websocket.Dialer{
|
|
Subprotocols: subprotocols,
|
|
}
|
|
|
|
return dialer.Dial(url, header)
|
|
}
|
|
|
|
// websocketURL converts an http(s) URL into the matching ws(s) URL and appends
// the suffix elements, joined with "/", to its path.
//
// "http" maps to "ws" and "https" maps to "wss". The resulting path is cleaned
// by path.Join, so duplicate and trailing slashes are removed.
//
// It panics if httpURL cannot be parsed or uses any other scheme; it is meant
// for test fixtures, where such input is a programming error.
func websocketURL(httpURL string, suffix ...string) string {
	// Named u rather than url so the net/url package is not shadowed.
	u, err := url.Parse(httpURL)
	if err != nil {
		panic(err)
	}

	switch u.Scheme {
	case "http":
		u.Scheme = "ws"
	case "https":
		u.Scheme = "wss"
	default:
		panic("Unknown scheme: " + u.Scheme)
	}

	u.Path = path.Join(u.Path, strings.Join(suffix, "/"))

	return u.String()
}
|
|
|
|
func say(conn *websocket.Conn, message string) error {
|
|
return conn.WriteMessage(websocket.TextMessage, []byte(message))
|
|
}
|
|
|
|
func requireReadMessage(t *testing.T, conn *websocket.Conn, expectedMessageType int, expectedData string) {
|
|
messageType, data, err := conn.ReadMessage()
|
|
require.NoError(t, err)
|
|
|
|
require.Equal(t, expectedMessageType, messageType, "message type")
|
|
require.Equal(t, expectedData, string(data), "message data")
|
|
}
|