93 lines
2.3 KiB
Go
93 lines
2.3 KiB
Go
package roundtripper
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"strconv"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestMustParseAddress(t *testing.T) {
|
|
successExamples := []struct{ address, scheme, expected string }{
|
|
{"1.2.3.4:56", "http", "1.2.3.4:56"},
|
|
{"[::1]:23", "http", "::1:23"},
|
|
{"4.5.6.7", "http", "4.5.6.7:http"},
|
|
{"4.5.6.7", "https", "4.5.6.7:https"},
|
|
}
|
|
for i, example := range successExamples {
|
|
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
|
require.Equal(t, example.expected, mustParseAddress(example.address, example.scheme))
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestMustParseAddressPanic(t *testing.T) {
|
|
panicExamples := []struct{ address, scheme string }{
|
|
{"1.2.3.4", ""},
|
|
}
|
|
|
|
for i, panicExample := range panicExamples {
|
|
t.Run(strconv.Itoa(i), func(t *testing.T) {
|
|
defer func() {
|
|
if r := recover(); r == nil {
|
|
t.Fatal("expected panic")
|
|
}
|
|
}()
|
|
mustParseAddress(panicExample.address, panicExample.scheme)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSupportsHTTPBackend(t *testing.T) {
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(200)
|
|
fmt.Fprint(w, "successful response")
|
|
}))
|
|
defer ts.Close()
|
|
|
|
testNewBackendRoundTripper(t, ts, nil, "successful response")
|
|
}
|
|
|
|
func TestSupportsHTTPSBackend(t *testing.T) {
|
|
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(200)
|
|
fmt.Fprint(w, "successful response")
|
|
}))
|
|
defer ts.Close()
|
|
|
|
certpool := x509.NewCertPool()
|
|
certpool.AddCert(ts.Certificate())
|
|
tlsClientConfig := &tls.Config{
|
|
RootCAs: certpool,
|
|
}
|
|
|
|
testNewBackendRoundTripper(t, ts, tlsClientConfig, "successful response")
|
|
}
|
|
|
|
func testNewBackendRoundTripper(t *testing.T, ts *httptest.Server, tlsClientConfig *tls.Config, expectedResponseBody string) {
|
|
t.Helper()
|
|
|
|
backend, err := url.Parse(ts.URL)
|
|
require.NoError(t, err, "parse url")
|
|
|
|
rt := newBackendRoundTripper(backend, "", 0, true, tlsClientConfig)
|
|
|
|
req, err := http.NewRequest("GET", ts.URL+"/", nil)
|
|
require.NoError(t, err, "build request")
|
|
|
|
response, err := rt.RoundTrip(req)
|
|
require.NoError(t, err, "perform roundtrip")
|
|
defer response.Body.Close()
|
|
|
|
body, err := io.ReadAll(response.Body)
|
|
require.NoError(t, err)
|
|
|
|
require.Equal(t, expectedResponseBody, string(body))
|
|
}
|