156 lines
4.2 KiB
Go
156 lines
4.2 KiB
Go
|
package channel
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"errors"
|
||
|
"net"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"github.com/gorilla/websocket"
|
||
|
)
|
||
|
|
||
|
type testcase struct {
|
||
|
input *fakeConn
|
||
|
expected *fakeConn
|
||
|
}
|
||
|
|
||
|
type fakeConn struct {
|
||
|
// WebSocket message type
|
||
|
mt int
|
||
|
data []byte
|
||
|
err error
|
||
|
}
|
||
|
|
||
|
func (f *fakeConn) ReadMessage() (int, []byte, error) {
|
||
|
return f.mt, f.data, f.err
|
||
|
}
|
||
|
|
||
|
func (f *fakeConn) WriteMessage(mt int, data []byte) error {
|
||
|
f.mt = mt
|
||
|
f.data = data
|
||
|
return f.err
|
||
|
}
|
||
|
|
||
|
func (f *fakeConn) WriteControl(mt int, data []byte, _ time.Time) error {
|
||
|
f.mt = mt
|
||
|
f.data = data
|
||
|
return f.err
|
||
|
}
|
||
|
|
||
|
func (f *fakeConn) UnderlyingConn() net.Conn {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func fake(mt int, data []byte, err error) *fakeConn {
|
||
|
return &fakeConn{mt: mt, data: []byte(data), err: err}
|
||
|
}
|
||
|
|
||
|
var (
|
||
|
msg = []byte("foo bar")
|
||
|
msgBase64 = []byte("Zm9vIGJhcg==")
|
||
|
kubeMsg = append([]byte{0}, msg...)
|
||
|
kubeMsgBase64 = append([]byte{'0'}, msgBase64...)
|
||
|
|
||
|
errFake = errors.New("fake error")
|
||
|
|
||
|
text = websocket.TextMessage
|
||
|
binary = websocket.BinaryMessage
|
||
|
other = 999
|
||
|
|
||
|
fakeOther = fake(other, []byte("foo"), nil)
|
||
|
)
|
||
|
|
||
|
func requireEqualConn(t *testing.T, expected, actual *fakeConn, msg string, args ...interface{}) {
|
||
|
if expected.mt != actual.mt {
|
||
|
t.Logf("messageType expected to be %v but was %v", expected.mt, actual.mt)
|
||
|
t.Fatalf(msg, args...)
|
||
|
}
|
||
|
|
||
|
if !bytes.Equal(expected.data, actual.data) {
|
||
|
t.Logf("data expected to be %q but was %q: ", expected.data, actual.data)
|
||
|
t.Fatalf(msg, args...)
|
||
|
}
|
||
|
|
||
|
if expected.err != actual.err {
|
||
|
t.Logf("error expected to be %v but was %v", expected.err, actual.err)
|
||
|
t.Fatalf(msg, args...)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestReadMessage(t *testing.T) {
|
||
|
testCases := map[string][]testcase{
|
||
|
"channel.k8s.io": {
|
||
|
{fake(binary, kubeMsg, errFake), fake(binary, kubeMsg, errFake)},
|
||
|
{fake(binary, kubeMsg, nil), fake(binary, msg, nil)},
|
||
|
{fake(text, kubeMsg, nil), fake(binary, msg, nil)},
|
||
|
{fakeOther, fakeOther},
|
||
|
},
|
||
|
"base64.channel.k8s.io": {
|
||
|
{fake(text, kubeMsgBase64, errFake), fake(text, kubeMsgBase64, errFake)},
|
||
|
{fake(text, kubeMsgBase64, nil), fake(binary, msg, nil)},
|
||
|
{fake(binary, kubeMsgBase64, nil), fake(binary, msg, nil)},
|
||
|
{fakeOther, fakeOther},
|
||
|
},
|
||
|
"terminal.gitlab.com": {
|
||
|
{fake(binary, msg, errFake), fake(binary, msg, errFake)},
|
||
|
{fake(binary, msg, nil), fake(binary, msg, nil)},
|
||
|
{fake(text, msg, nil), fake(binary, msg, nil)},
|
||
|
{fakeOther, fakeOther},
|
||
|
},
|
||
|
"base64.terminal.gitlab.com": {
|
||
|
{fake(text, msgBase64, errFake), fake(text, msgBase64, errFake)},
|
||
|
{fake(text, msgBase64, nil), fake(binary, msg, nil)},
|
||
|
{fake(binary, msgBase64, nil), fake(binary, msg, nil)},
|
||
|
{fakeOther, fakeOther},
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for subprotocol, cases := range testCases {
|
||
|
for i, tc := range cases {
|
||
|
conn := Wrap(tc.input, subprotocol)
|
||
|
mt, data, err := conn.ReadMessage()
|
||
|
actual := fake(mt, data, err)
|
||
|
requireEqualConn(t, tc.expected, actual, "%s test case %v", subprotocol, i)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestWriteMessage(t *testing.T) {
|
||
|
testCases := map[string][]testcase{
|
||
|
"channel.k8s.io": {
|
||
|
{fake(binary, msg, errFake), fake(binary, kubeMsg, errFake)},
|
||
|
{fake(binary, msg, nil), fake(binary, kubeMsg, nil)},
|
||
|
{fake(text, msg, nil), fake(binary, kubeMsg, nil)},
|
||
|
{fakeOther, fakeOther},
|
||
|
},
|
||
|
"base64.channel.k8s.io": {
|
||
|
{fake(binary, msg, errFake), fake(text, kubeMsgBase64, errFake)},
|
||
|
{fake(binary, msg, nil), fake(text, kubeMsgBase64, nil)},
|
||
|
{fake(text, msg, nil), fake(text, kubeMsgBase64, nil)},
|
||
|
{fakeOther, fakeOther},
|
||
|
},
|
||
|
"terminal.gitlab.com": {
|
||
|
{fake(binary, msg, errFake), fake(binary, msg, errFake)},
|
||
|
{fake(binary, msg, nil), fake(binary, msg, nil)},
|
||
|
{fake(text, msg, nil), fake(binary, msg, nil)},
|
||
|
{fakeOther, fakeOther},
|
||
|
},
|
||
|
"base64.terminal.gitlab.com": {
|
||
|
{fake(binary, msg, errFake), fake(text, msgBase64, errFake)},
|
||
|
{fake(binary, msg, nil), fake(text, msgBase64, nil)},
|
||
|
{fake(text, msg, nil), fake(text, msgBase64, nil)},
|
||
|
{fakeOther, fakeOther},
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for subprotocol, cases := range testCases {
|
||
|
for i, tc := range cases {
|
||
|
actual := fake(0, nil, tc.input.err)
|
||
|
conn := Wrap(actual, subprotocol)
|
||
|
actual.err = conn.WriteMessage(tc.input.mt, tc.input.data)
|
||
|
requireEqualConn(t, tc.expected, actual, "%s test case %v", subprotocol, i)
|
||
|
}
|
||
|
}
|
||
|
}
|