262 lines
6.7 KiB
Go
262 lines
6.7 KiB
Go
|
package server
|
||
|
|
||
|
import (
|
||
|
"errors"
|
||
|
"net/http"
|
||
|
"net/http/httptest"
|
||
|
"net/url"
|
||
|
"reflect"
|
||
|
"testing"
|
||
|
|
||
|
"github.com/coreos/go-oidc/oauth2"
|
||
|
)
|
||
|
|
||
|
func TestWriteAPIError(t *testing.T) {
|
||
|
tests := []struct {
|
||
|
err error
|
||
|
code int
|
||
|
wantCode int
|
||
|
wantBody string
|
||
|
}{
|
||
|
// standard
|
||
|
{
|
||
|
err: newAPIError(errorInvalidRequest, "foo"),
|
||
|
code: http.StatusBadRequest,
|
||
|
wantCode: http.StatusBadRequest,
|
||
|
wantBody: `{"error":"invalid_request","error_description":"foo"}`,
|
||
|
},
|
||
|
// no description
|
||
|
{
|
||
|
err: newAPIError(errorInvalidRequest, ""),
|
||
|
code: http.StatusBadRequest,
|
||
|
wantCode: http.StatusBadRequest,
|
||
|
wantBody: `{"error":"invalid_request"}`,
|
||
|
},
|
||
|
// no type
|
||
|
{
|
||
|
err: newAPIError("", ""),
|
||
|
code: http.StatusBadRequest,
|
||
|
wantCode: http.StatusBadRequest,
|
||
|
wantBody: `{"error":"server_error"}`,
|
||
|
},
|
||
|
// generic error
|
||
|
{
|
||
|
err: errors.New("generic failure"),
|
||
|
code: http.StatusTeapot,
|
||
|
wantCode: http.StatusTeapot,
|
||
|
wantBody: `{"error":"server_error"}`,
|
||
|
},
|
||
|
// nil error
|
||
|
{
|
||
|
err: nil,
|
||
|
code: http.StatusTeapot,
|
||
|
wantCode: http.StatusTeapot,
|
||
|
wantBody: `{"error":"server_error"}`,
|
||
|
},
|
||
|
// empty code
|
||
|
{
|
||
|
err: nil,
|
||
|
code: 0,
|
||
|
wantCode: http.StatusInternalServerError,
|
||
|
wantBody: `{"error":"server_error"}`,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for i, tt := range tests {
|
||
|
w := httptest.NewRecorder()
|
||
|
writeAPIError(w, tt.code, tt.err)
|
||
|
|
||
|
if tt.wantCode != w.Code {
|
||
|
t.Errorf("case %d: incorrect HTTP status: want=%d got=%d", i, tt.wantCode, w.Code)
|
||
|
}
|
||
|
|
||
|
gotBody := w.Body.String()
|
||
|
if tt.wantBody != gotBody {
|
||
|
t.Errorf("case %d: incorrect HTTP body: want=%q got=%q", i, tt.wantBody, gotBody)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestWriteTokenError(t *testing.T) {
|
||
|
tests := []struct {
|
||
|
err error
|
||
|
state string
|
||
|
wantCode int
|
||
|
wantHeader http.Header
|
||
|
wantBody string
|
||
|
}{
|
||
|
{
|
||
|
err: oauth2.NewError(oauth2.ErrorInvalidRequest),
|
||
|
state: "bazinga",
|
||
|
wantCode: http.StatusBadRequest,
|
||
|
wantHeader: http.Header{
|
||
|
"Content-Type": []string{"application/json"},
|
||
|
},
|
||
|
wantBody: `{"error":"invalid_request","state":"bazinga"}`,
|
||
|
},
|
||
|
{
|
||
|
err: oauth2.NewError(oauth2.ErrorInvalidRequest),
|
||
|
wantCode: http.StatusBadRequest,
|
||
|
wantHeader: http.Header{
|
||
|
"Content-Type": []string{"application/json"},
|
||
|
},
|
||
|
wantBody: `{"error":"invalid_request"}`,
|
||
|
},
|
||
|
{
|
||
|
err: oauth2.NewError(oauth2.ErrorInvalidGrant),
|
||
|
wantCode: http.StatusBadRequest,
|
||
|
wantHeader: http.Header{
|
||
|
"Content-Type": []string{"application/json"},
|
||
|
},
|
||
|
wantBody: `{"error":"invalid_grant"}`,
|
||
|
},
|
||
|
{
|
||
|
err: oauth2.NewError(oauth2.ErrorInvalidClient),
|
||
|
wantCode: http.StatusUnauthorized,
|
||
|
wantHeader: http.Header{
|
||
|
"Content-Type": []string{"application/json"},
|
||
|
"Www-Authenticate": []string{"Basic"},
|
||
|
},
|
||
|
wantBody: `{"error":"invalid_client"}`,
|
||
|
},
|
||
|
{
|
||
|
err: oauth2.NewError(oauth2.ErrorServerError),
|
||
|
wantCode: http.StatusBadRequest,
|
||
|
wantHeader: http.Header{
|
||
|
"Content-Type": []string{"application/json"},
|
||
|
},
|
||
|
wantBody: `{"error":"server_error"}`,
|
||
|
},
|
||
|
{
|
||
|
err: oauth2.NewError(oauth2.ErrorUnsupportedGrantType),
|
||
|
wantCode: http.StatusBadRequest,
|
||
|
wantHeader: http.Header{
|
||
|
"Content-Type": []string{"application/json"},
|
||
|
},
|
||
|
wantBody: `{"error":"unsupported_grant_type"}`,
|
||
|
},
|
||
|
{
|
||
|
err: errors.New("generic failure"),
|
||
|
wantCode: http.StatusBadRequest,
|
||
|
wantHeader: http.Header{
|
||
|
"Content-Type": []string{"application/json"},
|
||
|
},
|
||
|
wantBody: `{"error":"server_error"}`,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for i, tt := range tests {
|
||
|
w := httptest.NewRecorder()
|
||
|
writeTokenError(w, tt.err, tt.state)
|
||
|
|
||
|
if tt.wantCode != w.Code {
|
||
|
t.Errorf("case %d: incorrect HTTP status: want=%d got=%d", i, tt.wantCode, w.Code)
|
||
|
}
|
||
|
|
||
|
gotHeader := w.Header()
|
||
|
if !reflect.DeepEqual(tt.wantHeader, gotHeader) {
|
||
|
t.Errorf("case %d: incorrect HTTP headers: want=%#v got=%#v", i, tt.wantHeader, gotHeader)
|
||
|
}
|
||
|
|
||
|
gotBody := w.Body.String()
|
||
|
if tt.wantBody != gotBody {
|
||
|
t.Errorf("case %d: incorrect HTTP body: want=%q got=%q", i, tt.wantBody, gotBody)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestWriteAuthError(t *testing.T) {
|
||
|
wantCode := http.StatusBadRequest
|
||
|
wantHeader := http.Header{"Content-Type": []string{"application/json"}}
|
||
|
tests := []struct {
|
||
|
err error
|
||
|
state string
|
||
|
wantBody string
|
||
|
}{
|
||
|
{
|
||
|
err: errors.New("foobar"),
|
||
|
state: "bazinga",
|
||
|
wantBody: `{"error":"server_error","state":"bazinga"}`,
|
||
|
},
|
||
|
{
|
||
|
err: oauth2.NewError(oauth2.ErrorInvalidRequest),
|
||
|
state: "foo",
|
||
|
wantBody: `{"error":"invalid_request","state":"foo"}`,
|
||
|
},
|
||
|
{
|
||
|
err: oauth2.NewError(oauth2.ErrorUnsupportedResponseType),
|
||
|
state: "bar",
|
||
|
wantBody: `{"error":"unsupported_response_type","state":"bar"}`,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for i, tt := range tests {
|
||
|
w := httptest.NewRecorder()
|
||
|
writeAuthError(w, tt.err, tt.state)
|
||
|
|
||
|
if wantCode != w.Code {
|
||
|
t.Errorf("case %d: incorrect HTTP status: want=%d got=%d", i, wantCode, w.Code)
|
||
|
}
|
||
|
|
||
|
gotHeader := w.Header()
|
||
|
if !reflect.DeepEqual(wantHeader, gotHeader) {
|
||
|
t.Errorf("case %d: incorrect HTTP headers: want=%#v got=%#v", i, wantHeader, gotHeader)
|
||
|
}
|
||
|
|
||
|
gotBody := w.Body.String()
|
||
|
if tt.wantBody != gotBody {
|
||
|
t.Errorf("case %d: incorrect HTTP body: want=%q got=%q", i, tt.wantBody, gotBody)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestRedirectAuthError(t *testing.T) {
|
||
|
wantCode := http.StatusTemporaryRedirect
|
||
|
|
||
|
tests := []struct {
|
||
|
err error
|
||
|
state string
|
||
|
redirectURL url.URL
|
||
|
wantLoc string
|
||
|
}{
|
||
|
{
|
||
|
err: errors.New("foobar"),
|
||
|
state: "bazinga",
|
||
|
redirectURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||
|
wantLoc: "http://server.example.com?error=server_error&state=bazinga",
|
||
|
},
|
||
|
{
|
||
|
err: oauth2.NewError(oauth2.ErrorInvalidRequest),
|
||
|
state: "foo",
|
||
|
redirectURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||
|
wantLoc: "http://server.example.com?error=invalid_request&state=foo",
|
||
|
},
|
||
|
{
|
||
|
err: oauth2.NewError(oauth2.ErrorUnsupportedResponseType),
|
||
|
state: "bar",
|
||
|
redirectURL: url.URL{Scheme: "http", Host: "server.example.com"},
|
||
|
wantLoc: "http://server.example.com?error=unsupported_response_type&state=bar",
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for i, tt := range tests {
|
||
|
w := httptest.NewRecorder()
|
||
|
redirectAuthError(w, tt.err, tt.state, tt.redirectURL)
|
||
|
|
||
|
if wantCode != w.Code {
|
||
|
t.Errorf("case %d: incorrect HTTP status: want=%d got=%d", i, wantCode, w.Code)
|
||
|
}
|
||
|
|
||
|
wantHeader := http.Header{"Location": []string{tt.wantLoc}}
|
||
|
gotHeader := w.Header()
|
||
|
if !reflect.DeepEqual(wantHeader, gotHeader) {
|
||
|
t.Errorf("case %d: incorrect HTTP headers: want=%#v got=%#v", i, wantHeader, gotHeader)
|
||
|
}
|
||
|
|
||
|
gotBody := w.Body.String()
|
||
|
if gotBody != "" {
|
||
|
t.Errorf("case %d: incorrect empty HTTP body, got=%q", i, gotBody)
|
||
|
}
|
||
|
}
|
||
|
}
|