package api

import (
	"net/http"
	"testing"
)

// channel builds a ChannelSettings fixture for url that offers the given
// subprotocols, with MaxSessionTime explicitly set to zero and every other
// field left at its zero value.
func channel(url string, subprotocols ...string) *ChannelSettings {
	settings := new(ChannelSettings)
	settings.Url = url
	settings.Subprotocols = subprotocols
	settings.MaxSessionTime = 0
	return settings
}

// ca returns base.Clone() with CAPem set to placeholder certificate data.
// The assignment is made on the clone, not on base.
func ca(base *ChannelSettings) *ChannelSettings {
	withCA := base.Clone()
	withCA.CAPem = "Valid CA data"
	return withCA
}

// timeout returns base.Clone() with MaxSessionTime set to 600.
// NOTE(review): the unit of MaxSessionTime is defined by ChannelSettings,
// presumably seconds — confirm.
func timeout(base *ChannelSettings) *ChannelSettings {
	limited := base.Clone()
	limited.MaxSessionTime = 600
	return limited
}

// header returns base.Clone() whose Header is replaced by a fresh
// http.Header holding a single "Header" entry. That entry carries values,
// or "Dummy Value" when no values are supplied.
func header(base *ChannelSettings, values ...string) *ChannelSettings {
	hdrValues := values
	if len(hdrValues) == 0 {
		hdrValues = []string{"Dummy Value"}
	}

	withHeader := base.Clone()
	withHeader.Header = http.Header{"Header": hdrValues}
	return withHeader
}

// TestClone checks that Clone returns a distinct ChannelSettings equal to
// the original, whose reference-typed fields (the Subprotocols slice and
// the Header map) do not share storage with the original.
func TestClone(t *testing.T) {
	a := ca(header(channel("ws:", "", "")))
	b := a.Clone()

	if a == b {
		t.Fatalf("Address of cloned channel didn't change")
	}

	if !a.IsEqual(b) {
		t.Fatalf("Cloned channel differs from original: %+v vs. %+v", a, b)
	}

	// Comparing &a.Subprotocols with &b.Subprotocols would only compare the
	// addresses of the two struct fields, which always differ once a != b.
	// Compare the addresses of the slices' backing arrays instead.
	if len(a.Subprotocols) == 0 || len(b.Subprotocols) != len(a.Subprotocols) {
		t.Fatalf("Cloned subprotocols have unexpected length: %+v vs. %+v", a.Subprotocols, b.Subprotocols)
	}
	if &a.Subprotocols[0] == &b.Subprotocols[0] {
		t.Fatalf("Address of cloned subprotocols didn't change")
	}

	// Maps are references as well. Mutate the clone's header and make sure
	// the change does not show up in the original.
	if b.Header == nil {
		t.Fatalf("Cloned header is nil, original was %+v", a.Header)
	}
	b.Header.Set("Clone-Probe", "x")
	if a.Header.Get("Clone-Probe") != "" {
		t.Fatalf("Address of cloned header didn't change")
	}
}

// TestValidate runs ChannelSettings.Validate over a table of settings and
// checks that each one is accepted (nil error) or rejected as expected.
func TestValidate(t *testing.T) {
	for i, tc := range []struct {
		channel *ChannelSettings
		valid   bool
		msg     string
	}{
		{nil, false, "nil channel"},
		{channel("", ""), false, "empty URL"},
		{channel("ws:"), false, "empty subprotocols"},
		{channel("ws:", "foo"), true, "any subprotocol"},
		{channel("ws:", "foo", "bar"), true, "multiple subprotocols"},
		{channel("ws:", ""), true, "websocket URL"},
		{channel("wss:", ""), true, "secure websocket URL"},
		{channel("http:", ""), false, "HTTP URL"},
		{channel("https:", ""), false, "HTTPS URL"},
		{ca(channel("ws:", "")), true, "any CA pem"},
		{header(channel("ws:", "")), true, "any headers"},
		{ca(header(channel("ws:", ""))), true, "PEM and headers"},
	} {
		if err := tc.channel.Validate(); (err != nil) == tc.valid {
			// tc.msg is passed as an argument rather than concatenated into
			// the format string, so a '%' in it cannot corrupt the output.
			// %v prints a nil error as "<nil>" instead of "%!s(<nil>)".
			// Errorf (not Fatalf) reports every failing row, not just the first.
			t.Errorf("test case %d: %s: valid=%v: %v: %+v", i, tc.msg, tc.valid, err, tc.channel)
		}
	}
}

// TestDialer checks that Dialer copies the channel's subprotocols, and that
// a TLS client config with root CAs is present only after a CA PEM is set.
func TestDialer(t *testing.T) {
	// Named ch rather than channel so the local does not shadow the
	// channel() fixture helper for the rest of the function.
	ch := channel("ws:", "foo")
	dialer := ch.Dialer()

	if len(dialer.Subprotocols) != len(ch.Subprotocols) {
		t.Fatalf("Subprotocols don't match: %+v vs. %+v", ch.Subprotocols, dialer.Subprotocols)
	}

	for i, subprotocol := range ch.Subprotocols {
		if dialer.Subprotocols[i] != subprotocol {
			t.Fatalf("Subprotocols don't match: %+v vs. %+v", ch.Subprotocols, dialer.Subprotocols)
		}
	}

	if dialer.TLSClientConfig != nil {
		t.Fatalf("Unexpected TLSClientConfig: %+v", dialer)
	}

	ch = ca(ch)
	dialer = ch.Dialer()

	if dialer.TLSClientConfig == nil || dialer.TLSClientConfig.RootCAs == nil {
		t.Fatalf("Custom CA certificates not recognised!")
	}
}

// TestIsEqual checks ChannelSettings.IsEqual against a table of pairs. The
// table covers nil receivers and arguments, identical and cloned settings,
// and settings that differ in URL, subprotocols, header, CA PEM, or session
// timeout.
func TestIsEqual(t *testing.T) {
	base := channel("ws:", "foo")

	// Same "Header" key as header(base), but with a different value.
	hdrOtherValue := header(base, "extra")
	// Same "Header" entry as header(base), plus an additional "Extra" key.
	hdrExtraKey := header(base)
	hdrExtraKey.Header.Add("Extra", "extra")

	// Same as ca(base) except for the PEM contents.
	caOtherPem := ca(base)
	caOtherPem.CAPem = "other value"

	for i, tc := range []struct {
		channelA *ChannelSettings
		channelB *ChannelSettings
		expected bool
	}{
		{nil, nil, true},
		{base, nil, false},
		{nil, base, false},
		{base, base, true},
		{base.Clone(), base.Clone(), true},
		{base, channel("foo:"), false},
		{base, channel(base.Url), false},
		{header(base), header(base), true},
		{hdrOtherValue, hdrOtherValue, true},
		{hdrExtraKey, hdrExtraKey, true},
		{header(base), hdrOtherValue, false},
		{header(base), hdrExtraKey, false},
		{header(base), base, false},
		{base, header(base), false},
		{ca(base), ca(base), true},
		{ca(base), base, false},
		{base, ca(base), false},
		{ca(header(base)), ca(header(base)), true},
		{caOtherPem, ca(base), false},
		{base, timeout(base), false},
		{timeout(base), base, false},
		{timeout(base), timeout(base), true},
	} {
		if actual := tc.channelA.IsEqual(tc.channelB); tc.expected != actual {
			// Errorf (not Fatalf) so every mismatching row is reported.
			t.Errorf(
				"test case %d: Comparison:\n-%+v\n+%+v\nexpected=%v: actual=%v",
				i, tc.channelA, tc.channelB, tc.expected, actual,
			)
		}
	}
}