Merge pull request #197 from ericchiang/oauth2_connector

connector: github and bitbucket oauth2 connectors added
This commit is contained in:
Eric Chiang 2015-12-10 08:56:09 -08:00
commit f63ec158a5
39 changed files with 1038 additions and 3710 deletions

View file

@ -75,6 +75,65 @@ Here's what a `oidc` connector looks like configured for authenticating with Goo
} }
``` ```
### `github` connector
This connector config lets users authenticate through [GitHub](https://github.com/). In addition to `id` and `type`, the `github` connector takes the following additional fields:
* clientID: a `string`. The GitHub OAuth application client ID.
* clientSecret: a `string`. The GitHub OAuth application client secret.
To begin, register an OAuth application with GitHub through your, or your organization's [account settings](ttps://github.com/settings/applications/new). To register dex as a client of your GitHub application, enter dex's redirect URL under 'Authorization callback URL':
```
https://$DEX_HOST:$DEX_PORT/auth/$CONNECTOR_ID/callback
```
`$DEX_HOST` and `$DEX_PORT` are the host and port of your dex installation. `$CONNECTOR_ID` is the `id` field of the connector.
Here's an example of a `github` connector; the clientID and clientSecret should be replaced by values provided by GitHub.
```
{
"type": "github",
"id": "github",
"clientID": "$DEX_GITHUB_CLIENT_ID",
"clientSecret": "$DEX_GITHUB_CLIENT_SECRET"
}
```
The `github` connector requests read only access to user's email through the [`user:email` scope](https://developer.github.com/v3/oauth/#scopes).
### `bitbucket` connector
This connector config lets users authenticate through [Bitbucket](https://bitbucket.org/). In addition to `id` and `type`, the `bitbucket` connector takes the following additional fields:
* clientID: a `string`. The Bitbucket OAuth consumer client ID.
* clientSecret: a `string`. The Bitbucket OAuth consumer client secret.
To begin, register an OAuth consumer with Bitbucket through your, or your teams's management page. Follow the documentation at their [developer site](https://confluence.atlassian.com/bitbucket/oauth-on-bitbucket-cloud-238027431.html).
__NOTE:__ When configuring a consumer through Bitbucket you _must_ configure read email permissions.
To register dex as a client of your Bitbucket consumer, enter dex's redirect URL under 'Callback URL':
```
https://$DEX_HOST:$DEX_PORT/auth/$CONNECTOR_ID/callback
```
`$DEX_HOST` and `$DEX_PORT` are the host and port of your dex installation. `$CONNECTOR_ID` is the `id` field of the connector.
Here's an example of a `bitbucket` connector; the clientID and clientSecret should be replaced by values provided by Bitbucket.
```
{
"type": "bitbucket",
"id": "bitbucket",
"clientID": "$DEX_BITBUCKET_CLIENT_ID",
"clientSecret": "$DEX_BITBUCKET_CLIENT_SECRET"
}
```
## Setting the Configuration ## Setting the Configuration
To set a connectors configuration in dex, put it in some temporary file, then use the dexctl command to upload it to dex: To set a connectors configuration in dex, put it in some temporary file, then use the dexctl command to upload it to dex:

10
Godeps/Godeps.json generated
View file

@ -21,23 +21,23 @@
}, },
{ {
"ImportPath": "github.com/coreos/go-oidc/http", "ImportPath": "github.com/coreos/go-oidc/http",
"Rev": "ee7cb1fb480df22f7d8c4c90199e438e454ca3b6" "Rev": "faf70c34f9c411f234eb96d23c518c087cd96d79"
}, },
{ {
"ImportPath": "github.com/coreos/go-oidc/jose", "ImportPath": "github.com/coreos/go-oidc/jose",
"Rev": "ee7cb1fb480df22f7d8c4c90199e438e454ca3b6" "Rev": "faf70c34f9c411f234eb96d23c518c087cd96d79"
}, },
{ {
"ImportPath": "github.com/coreos/go-oidc/key", "ImportPath": "github.com/coreos/go-oidc/key",
"Rev": "ee7cb1fb480df22f7d8c4c90199e438e454ca3b6" "Rev": "faf70c34f9c411f234eb96d23c518c087cd96d79"
}, },
{ {
"ImportPath": "github.com/coreos/go-oidc/oauth2", "ImportPath": "github.com/coreos/go-oidc/oauth2",
"Rev": "ee7cb1fb480df22f7d8c4c90199e438e454ca3b6" "Rev": "faf70c34f9c411f234eb96d23c518c087cd96d79"
}, },
{ {
"ImportPath": "github.com/coreos/go-oidc/oidc", "ImportPath": "github.com/coreos/go-oidc/oidc",
"Rev": "ee7cb1fb480df22f7d8c4c90199e438e454ca3b6" "Rev": "faf70c34f9c411f234eb96d23c518c087cd96d79"
}, },
{ {
"ImportPath": "github.com/coreos/pkg/capnslog", "ImportPath": "github.com/coreos/pkg/capnslog",

202
Godeps/_workspace/src/github.com/coreos/go-oidc/LICENSE generated vendored Normal file
View file

@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "{}"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright {yyyy} {name of copyright owner}
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View file

@ -0,0 +1,5 @@
CoreOS Project
Copyright 2014 CoreOS, Inc
This product includes software developed at CoreOS, Inc.
(http://www.coreos.com/).

View file

@ -1,380 +0,0 @@
package http
import (
"net/http"
"net/url"
"reflect"
"strings"
"testing"
"time"
)
func TestCacheControlMaxAgeSuccess(t *testing.T) {
tests := []struct {
hdr string
wantAge time.Duration
wantOK bool
}{
{"max-age=12", 12 * time.Second, true},
{"max-age=-12", 0, false},
{"max-age=0", 0, false},
{"public, max-age=12", 12 * time.Second, true},
{"public, max-age=40192, must-revalidate", 40192 * time.Second, true},
{"public, not-max-age=12, must-revalidate", time.Duration(0), false},
}
for i, tt := range tests {
maxAge, ok, err := cacheControlMaxAge(tt.hdr)
if err != nil {
t.Errorf("case %d: err=%v", i, err)
}
if tt.wantAge != maxAge {
t.Errorf("case %d: want=%d got=%d", i, tt.wantAge, maxAge)
}
if tt.wantOK != ok {
t.Errorf("case %d: incorrect ok value: want=%t got=%t", i, tt.wantOK, ok)
}
}
}
func TestCacheControlMaxAgeFail(t *testing.T) {
tests := []string{
"max-age=aasdf",
"max-age=",
"max-age",
}
for i, tt := range tests {
_, ok, err := cacheControlMaxAge(tt)
if ok {
t.Errorf("case %d: want ok=false, got true", i)
}
if err == nil {
t.Errorf("case %d: want non-nil err", i)
}
}
}
func TestMergeQuery(t *testing.T) {
tests := []struct {
u string
q url.Values
w string
}{
// No values
{
u: "http://example.com",
q: nil,
w: "http://example.com",
},
// No additional values
{
u: "http://example.com?foo=bar",
q: nil,
w: "http://example.com?foo=bar",
},
// Simple addition
{
u: "http://example.com",
q: url.Values{
"foo": []string{"bar"},
},
w: "http://example.com?foo=bar",
},
// Addition with existing values
{
u: "http://example.com?dog=boo",
q: url.Values{
"foo": []string{"bar"},
},
w: "http://example.com?dog=boo&foo=bar",
},
// Merge
{
u: "http://example.com?dog=boo",
q: url.Values{
"dog": []string{"elroy"},
},
w: "http://example.com?dog=boo&dog=elroy",
},
// Add and merge
{
u: "http://example.com?dog=boo",
q: url.Values{
"dog": []string{"elroy"},
"foo": []string{"bar"},
},
w: "http://example.com?dog=boo&dog=elroy&foo=bar",
},
// Multivalue merge
{
u: "http://example.com?dog=boo",
q: url.Values{
"dog": []string{"elroy", "penny"},
},
w: "http://example.com?dog=boo&dog=elroy&dog=penny",
},
}
for i, tt := range tests {
ur, err := url.Parse(tt.u)
if err != nil {
t.Errorf("case %d: failed parsing test url: %v, error: %v", i, tt.u, err)
}
got := MergeQuery(*ur, tt.q)
want, err := url.Parse(tt.w)
if err != nil {
t.Errorf("case %d: failed parsing want url: %v, error: %v", i, tt.w, err)
}
if !reflect.DeepEqual(*want, got) {
t.Errorf("case %d: want: %v, got: %v", i, *want, got)
}
}
}
func TestExpiresPass(t *testing.T) {
tests := []struct {
date string
exp string
wantTTL time.Duration
wantOK bool
}{
// Expires and Date properly set
{
date: "Thu, 01 Dec 1983 22:00:00 GMT",
exp: "Fri, 02 Dec 1983 01:00:00 GMT",
wantTTL: 10800 * time.Second,
wantOK: true,
},
// empty headers
{
date: "",
exp: "",
wantOK: false,
},
// lack of Expirs short-ciruits Date parsing
{
date: "foo",
exp: "",
wantOK: false,
},
// lack of Date short-ciruits Expires parsing
{
date: "",
exp: "foo",
wantOK: false,
},
// no Date
{
exp: "Thu, 01 Dec 1983 22:00:00 GMT",
wantTTL: 0,
wantOK: false,
},
// no Expires
{
date: "Thu, 01 Dec 1983 22:00:00 GMT",
wantTTL: 0,
wantOK: false,
},
// Expires < Date
{
date: "Fri, 02 Dec 1983 01:00:00 GMT",
exp: "Thu, 01 Dec 1983 22:00:00 GMT",
wantTTL: 0,
wantOK: false,
},
}
for i, tt := range tests {
ttl, ok, err := expires(tt.date, tt.exp)
if err != nil {
t.Errorf("case %d: err=%v", i, err)
}
if tt.wantTTL != ttl {
t.Errorf("case %d: want=%d got=%d", i, tt.wantTTL, ttl)
}
if tt.wantOK != ok {
t.Errorf("case %d: incorrect ok value: want=%t got=%t", i, tt.wantOK, ok)
}
}
}
func TestExpiresFail(t *testing.T) {
tests := []struct {
date string
exp string
}{
// malformed Date header
{
date: "foo",
exp: "Fri, 02 Dec 1983 01:00:00 GMT",
},
// malformed exp header
{
date: "Fri, 02 Dec 1983 01:00:00 GMT",
exp: "bar",
},
}
for i, tt := range tests {
_, _, err := expires(tt.date, tt.exp)
if err == nil {
t.Errorf("case %d: expected non-nil error", i)
}
}
}
func TestCacheablePass(t *testing.T) {
tests := []struct {
headers http.Header
wantTTL time.Duration
wantOK bool
}{
// valid Cache-Control
{
headers: http.Header{
"Cache-Control": []string{"max-age=100"},
},
wantTTL: 100 * time.Second,
wantOK: true,
},
// valid Date/Expires
{
headers: http.Header{
"Date": []string{"Thu, 01 Dec 1983 22:00:00 GMT"},
"Expires": []string{"Fri, 02 Dec 1983 01:00:00 GMT"},
},
wantTTL: 10800 * time.Second,
wantOK: true,
},
// Cache-Control supersedes Date/Expires
{
headers: http.Header{
"Cache-Control": []string{"max-age=100"},
"Date": []string{"Thu, 01 Dec 1983 22:00:00 GMT"},
"Expires": []string{"Fri, 02 Dec 1983 01:00:00 GMT"},
},
wantTTL: 100 * time.Second,
wantOK: true,
},
// no caching headers
{
headers: http.Header{},
wantOK: false,
},
}
for i, tt := range tests {
ttl, ok, err := Cacheable(tt.headers)
if err != nil {
t.Errorf("case %d: err=%v", i, err)
continue
}
if tt.wantTTL != ttl {
t.Errorf("case %d: want=%d got=%d", i, tt.wantTTL, ttl)
}
if tt.wantOK != ok {
t.Errorf("case %d: incorrect ok value: want=%t got=%t", i, tt.wantOK, ok)
}
}
}
func TestCacheableFail(t *testing.T) {
tests := []http.Header{
// invalid Cache-Control short-circuits
http.Header{
"Cache-Control": []string{"max-age"},
"Date": []string{"Thu, 01 Dec 1983 22:00:00 GMT"},
"Expires": []string{"Fri, 02 Dec 1983 01:00:00 GMT"},
},
// no Cache-Control, invalid Expires
http.Header{
"Date": []string{"Thu, 01 Dec 1983 22:00:00 GMT"},
"Expires": []string{"boo"},
},
}
for i, tt := range tests {
_, _, err := Cacheable(tt)
if err == nil {
t.Errorf("case %d: want non-nil err", i)
}
}
}
func TestNewResourceLocation(t *testing.T) {
tests := []struct {
ru *url.URL
id string
want string
}{
{
ru: &url.URL{
Scheme: "http",
Host: "example.com",
},
id: "foo",
want: "http://example.com/foo",
},
// https
{
ru: &url.URL{
Scheme: "https",
Host: "example.com",
},
id: "foo",
want: "https://example.com/foo",
},
// with path
{
ru: &url.URL{
Scheme: "http",
Host: "example.com",
Path: "one/two/three",
},
id: "foo",
want: "http://example.com/one/two/three/foo",
},
// with fragment
{
ru: &url.URL{
Scheme: "http",
Host: "example.com",
Fragment: "frag",
},
id: "foo",
want: "http://example.com/foo",
},
// with query
{
ru: &url.URL{
Scheme: "http",
Host: "example.com",
RawQuery: "dog=elroy",
},
id: "foo",
want: "http://example.com/foo",
},
}
for i, tt := range tests {
got := NewResourceLocation(tt.ru, tt.id)
if tt.want != got {
t.Errorf("case %d: want=%s, got=%s", i, tt.want, got)
}
}
}
func TestCopyRequest(t *testing.T) {
r1, err := http.NewRequest("GET", "http://example.com", strings.NewReader("foo"))
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
r2 := CopyRequest(r1)
if !reflect.DeepEqual(r1, r2) {
t.Fatalf("Result of CopyRequest incorrect: %#v != %#v", r1, r2)
}
}

View file

@ -1,49 +0,0 @@
package http
import (
"net/url"
"testing"
)
func TestParseNonEmptyURL(t *testing.T) {
tests := []struct {
u string
ok bool
}{
{"", false},
{"http://", false},
{"example.com", false},
{"example", false},
{"http://example", true},
{"http://example:1234", true},
{"http://example.com", true},
{"http://example.com:1234", true},
}
for i, tt := range tests {
u, err := ParseNonEmptyURL(tt.u)
if err != nil {
t.Logf("err: %v", err)
if tt.ok {
t.Errorf("case %d: unexpected error: %v", i, err)
} else {
continue
}
}
if !tt.ok {
t.Errorf("case %d: expected error but got none", i)
continue
}
uu, err := url.Parse(tt.u)
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
continue
}
if uu.String() != u.String() {
t.Errorf("case %d: incorrect url value, want: %q, got: %q", i, uu.String(), u.String())
}
}
}

View file

@ -26,6 +26,32 @@ func (c Claims) StringClaim(name string) (string, bool, error) {
return v, true, nil return v, true, nil
} }
func (c Claims) StringsClaim(name string) ([]string, bool, error) {
cl, ok := c[name]
if !ok {
return nil, false, nil
}
if v, ok := cl.([]string); ok {
return v, true, nil
}
// When unmarshaled, []string will become []interface{}.
if v, ok := cl.([]interface{}); ok {
var ret []string
for _, vv := range v {
str, ok := vv.(string)
if !ok {
return nil, false, fmt.Errorf("unable to parse claim as string array: %v", name)
}
ret = append(ret, str)
}
return ret, true, nil
}
return nil, false, fmt.Errorf("unable to parse claim as string array: %v", name)
}
func (c Claims) Int64Claim(name string) (int64, bool, error) { func (c Claims) Int64Claim(name string) (int64, bool, error) {
cl, ok := c[name] cl, ok := c[name]
if !ok { if !ok {

View file

@ -1,240 +0,0 @@
package jose
import (
"testing"
"time"
)
func TestString(t *testing.T) {
tests := []struct {
cl Claims
key string
ok bool
err bool
val string
}{
// ok, no err, claim exists
{
cl: Claims{
"foo": "bar",
},
key: "foo",
val: "bar",
ok: true,
err: false,
},
// no claims
{
cl: Claims{},
key: "foo",
val: "",
ok: false,
err: false,
},
// missing claim
{
cl: Claims{
"foo": "bar",
},
key: "xxx",
val: "",
ok: false,
err: false,
},
// unparsable: type
{
cl: Claims{
"foo": struct{}{},
},
key: "foo",
val: "",
ok: false,
err: true,
},
// unparsable: nil value
{
cl: Claims{
"foo": nil,
},
key: "foo",
val: "",
ok: false,
err: true,
},
}
for i, tt := range tests {
val, ok, err := tt.cl.StringClaim(tt.key)
if tt.err && err == nil {
t.Errorf("case %d: want err=non-nil, got err=nil", i)
} else if !tt.err && err != nil {
t.Errorf("case %d: want err=nil, got err=%v", i, err)
}
if tt.ok != ok {
t.Errorf("case %d: want ok=%v, got ok=%v", i, tt.ok, ok)
}
if tt.val != val {
t.Errorf("case %d: want val=%v, got val=%v", i, tt.val, val)
}
}
}
func TestInt64(t *testing.T) {
tests := []struct {
cl Claims
key string
ok bool
err bool
val int64
}{
// ok, no err, claim exists
{
cl: Claims{
"foo": int64(100),
},
key: "foo",
val: int64(100),
ok: true,
err: false,
},
// no claims
{
cl: Claims{},
key: "foo",
val: 0,
ok: false,
err: false,
},
// missing claim
{
cl: Claims{
"foo": "bar",
},
key: "xxx",
val: 0,
ok: false,
err: false,
},
// unparsable: type
{
cl: Claims{
"foo": struct{}{},
},
key: "foo",
val: 0,
ok: false,
err: true,
},
// unparsable: nil value
{
cl: Claims{
"foo": nil,
},
key: "foo",
val: 0,
ok: false,
err: true,
},
}
for i, tt := range tests {
val, ok, err := tt.cl.Int64Claim(tt.key)
if tt.err && err == nil {
t.Errorf("case %d: want err=non-nil, got err=nil", i)
} else if !tt.err && err != nil {
t.Errorf("case %d: want err=nil, got err=%v", i, err)
}
if tt.ok != ok {
t.Errorf("case %d: want ok=%v, got ok=%v", i, tt.ok, ok)
}
if tt.val != val {
t.Errorf("case %d: want val=%v, got val=%v", i, tt.val, val)
}
}
}
func TestTime(t *testing.T) {
now := time.Now().UTC()
unixNow := now.Unix()
tests := []struct {
cl Claims
key string
ok bool
err bool
val time.Time
}{
// ok, no err, claim exists
{
cl: Claims{
"foo": unixNow,
},
key: "foo",
val: time.Unix(now.Unix(), 0).UTC(),
ok: true,
err: false,
},
// no claims
{
cl: Claims{},
key: "foo",
val: time.Time{},
ok: false,
err: false,
},
// missing claim
{
cl: Claims{
"foo": "bar",
},
key: "xxx",
val: time.Time{},
ok: false,
err: false,
},
// unparsable: type
{
cl: Claims{
"foo": struct{}{},
},
key: "foo",
val: time.Time{},
ok: false,
err: true,
},
// unparsable: nil value
{
cl: Claims{
"foo": nil,
},
key: "foo",
val: time.Time{},
ok: false,
err: true,
},
}
for i, tt := range tests {
val, ok, err := tt.cl.TimeClaim(tt.key)
if tt.err && err == nil {
t.Errorf("case %d: want err=non-nil, got err=nil", i)
} else if !tt.err && err != nil {
t.Errorf("case %d: want err=nil, got err=%v", i, err)
}
if tt.ok != ok {
t.Errorf("case %d: want ok=%v, got ok=%v", i, tt.ok, ok)
}
if tt.val != val {
t.Errorf("case %d: want val=%v, got val=%v", i, tt.val, val)
}
}
}

View file

@ -1,64 +0,0 @@
package jose
import (
"testing"
)
func TestDecodeBase64URLPaddingOptional(t *testing.T) {
tests := []struct {
encoded string
decoded string
err bool
}{
{
// With padding
encoded: "VGVjdG9uaWM=",
decoded: "Tectonic",
},
{
// Without padding
encoded: "VGVjdG9uaWM",
decoded: "Tectonic",
},
{
// Even More padding
encoded: "VGVjdG9uaQ==",
decoded: "Tectoni",
},
{
// And take it away!
encoded: "VGVjdG9uaQ",
decoded: "Tectoni",
},
{
// Too much padding.
encoded: "VGVjdG9uaWNh=",
decoded: "",
err: true,
},
{
// Too much padding.
encoded: "VGVjdG9uaWNh=",
decoded: "",
err: true,
},
}
for i, tt := range tests {
got, err := decodeBase64URLPaddingOptional(tt.encoded)
if tt.err {
if err == nil {
t.Errorf("case %d: expected non-nil err", i)
}
continue
}
if err != nil {
t.Errorf("case %d: want nil err, got: %v", i, err)
}
if string(got) != tt.decoded {
t.Errorf("case %d: want=%q, got=%q", i, tt.decoded, got)
}
}
}

View file

@ -1,74 +0,0 @@
package jose
import (
"strings"
"testing"
)
type testCase struct{ t string }
var validInput []testCase
var invalidInput []testCase
func init() {
validInput = []testCase{
{
"eyJ0eXAiOiJKV1QiLA0KICJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJqb2UiLA0KICJleHAiOjEzMDA4MTkzODAsDQogImh0dHA6Ly9leGFtcGxlLmNvbS9pc19yb290Ijp0cnVlfQ.dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk",
},
}
invalidInput = []testCase{
// empty
{
"",
},
// undecodeable
{
"aaa.bbb.ccc",
},
// missing parts
{
"aaa",
},
// missing parts
{
"aaa.bbb",
},
// too many parts
{
"aaa.bbb.ccc.ddd",
},
// invalid header
// EncodeHeader(map[string]string{"foo": "bar"})
{
"eyJmb28iOiJiYXIifQ.bbb.ccc",
},
}
}
func TestParseJWS(t *testing.T) {
for i, tt := range validInput {
jws, err := ParseJWS(tt.t)
if err != nil {
t.Errorf("test: %d. expected: valid, actual: invalid", i)
}
expectedHeader := strings.Split(tt.t, ".")[0]
if jws.RawHeader != expectedHeader {
t.Errorf("test: %d. expected: %s, actual: %s", i, expectedHeader, jws.RawHeader)
}
expectedPayload := strings.Split(tt.t, ".")[1]
if jws.RawPayload != expectedPayload {
t.Errorf("test: %d. expected: %s, actual: %s", i, expectedPayload, jws.RawPayload)
}
}
for i, tt := range invalidInput {
_, err := ParseJWS(tt.t)
if err == nil {
t.Errorf("test: %d. expected: invalid, actual: valid", i)
}
}
}

View file

@ -1,8 +1,6 @@
package jose package jose
import ( import "strings"
"strings"
)
type JWT JWS type JWT JWS
@ -63,13 +61,13 @@ func (j *JWT) Encode() string {
return strings.Join([]string{d, s}, ".") return strings.Join([]string{d, s}, ".")
} }
func NewSignedJWT(claims map[string]interface{}, s Signer) (*JWT, error) { func NewSignedJWT(claims Claims, s Signer) (*JWT, error) {
header := JOSEHeader{ header := JOSEHeader{
HeaderKeyAlgorithm: s.Alg(), HeaderKeyAlgorithm: s.Alg(),
HeaderKeyID: s.ID(), HeaderKeyID: s.ID(),
} }
jwt, err := NewJWT(header, Claims(claims)) jwt, err := NewJWT(header, claims)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -1,94 +0,0 @@
package jose
import (
"reflect"
"testing"
)
func TestParseJWT(t *testing.T) {
tests := []struct {
r string
h JOSEHeader
c Claims
}{
{
// Example from JWT spec:
// http://self-issued.info/docs/draft-ietf-oauth-json-web-token.html#ExampleJWT
"eyJ0eXAiOiJKV1QiLA0KICJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJqb2UiLA0KICJleHAiOjEzMDA4MTkzODAsDQogImh0dHA6Ly9leGFtcGxlLmNvbS9pc19yb290Ijp0cnVlfQ.dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk",
JOSEHeader{
HeaderMediaType: "JWT",
HeaderKeyAlgorithm: "HS256",
},
Claims{
"iss": "joe",
// NOTE: test numbers must be floats for equality checks to work since values are converted form interface{} to float64 by default.
"exp": 1300819380.0,
"http://example.com/is_root": true,
},
},
}
for i, tt := range tests {
jwt, err := ParseJWT(tt.r)
if err != nil {
t.Errorf("raw token should parse. test: %d. expected: valid, actual: invalid. err=%v", i, err)
}
if !reflect.DeepEqual(tt.h, jwt.Header) {
t.Errorf("JOSE headers should match. test: %d. expected: %v, actual: %v", i, tt.h, jwt.Header)
}
claims, err := jwt.Claims()
if err != nil {
t.Errorf("test: %d. expected: valid claim parsing. err=%v", i, err)
}
if !reflect.DeepEqual(tt.c, claims) {
t.Errorf("claims should match. test: %d. expected: %v, actual: %v", i, tt.c, claims)
}
enc := jwt.Encode()
if enc != tt.r {
t.Errorf("encoded jwt should match raw jwt. test: %d. expected: %v, actual: %v", i, tt.r, enc)
}
}
}
func TestNewJWTHeaderTyp(t *testing.T) {
jwt, err := NewJWT(JOSEHeader{}, Claims{})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
want := "JWT"
got := jwt.Header[HeaderMediaType]
if want != got {
t.Fatalf("Header %q incorrect: want=%s got=%s", HeaderMediaType, want, got)
}
}
func TestNewJWTHeaderKeyID(t *testing.T) {
jwt, err := NewJWT(JOSEHeader{HeaderKeyID: "foo"}, Claims{})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
want := "foo"
got, ok := jwt.KeyID()
if !ok {
t.Fatalf("KeyID not set")
} else if want != got {
t.Fatalf("KeyID incorrect: want=%s got=%s", want, got)
}
}
func TestNewJWTHeaderKeyIDNotSet(t *testing.T) {
jwt, err := NewJWT(JOSEHeader{}, Claims{})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if _, ok := jwt.KeyID(); ok {
t.Fatalf("KeyID set, but should not be")
}
}

View file

@ -1,85 +0,0 @@
package jose
import (
"bytes"
"encoding/base64"
"testing"
)
var hmacTestCases = []struct {
data string
sig string
jwk JWK
valid bool
desc string
}{
{
"test",
"Aymga2LNFrM-tnkr6MYLFY2Jou46h2_Omogeu0iMCRQ=",
JWK{
ID: "fake-key",
Alg: "HS256",
Secret: []byte("secret"),
},
true,
"valid case",
},
{
"test",
"Aymga2LNFrM-tnkr6MYLFY2Jou46h2_Omogeu0iMCRQ=",
JWK{
ID: "different-key",
Alg: "HS256",
Secret: []byte("secret"),
},
true,
"invalid: different key, should not match",
},
{
"test sig and non-matching data",
"Aymga2LNFrM-tnkr6MYLFY2Jou46h2_Omogeu0iMCRQ=",
JWK{
ID: "fake-key",
Alg: "HS256",
Secret: []byte("secret"),
},
false,
"invalid: sig and data should not match",
},
}
func TestVerify(t *testing.T) {
for _, tt := range hmacTestCases {
v, err := NewVerifierHMAC(tt.jwk)
if err != nil {
t.Errorf("should construct hmac verifier. test: %s. err=%v", tt.desc, err)
}
decSig, _ := base64.URLEncoding.DecodeString(tt.sig)
err = v.Verify(decSig, []byte(tt.data))
if err == nil && !tt.valid {
t.Errorf("verify failure. test: %s. expected: invalid, actual: valid.", tt.desc)
}
if err != nil && tt.valid {
t.Errorf("verify failure. test: %s. expected: valid, actual: invalid. err=%v", tt.desc, err)
}
}
}
func TestSign(t *testing.T) {
for _, tt := range hmacTestCases {
s := NewSignerHMAC("test", tt.jwk.Secret)
sig, err := s.Sign([]byte(tt.data))
if err != nil {
t.Errorf("sign failure. test: %s. err=%v", tt.desc, err)
}
expSig, _ := base64.URLEncoding.DecodeString(tt.sig)
if tt.valid && !bytes.Equal(sig, expSig) {
t.Errorf("sign failure. test: %s. expected: %s, actual: %s.", tt.desc, tt.sig, base64.URLEncoding.EncodeToString(sig))
}
if !tt.valid && bytes.Equal(sig, expSig) {
t.Errorf("sign failure. test: %s. expected: invalid signature.", tt.desc)
}
}
}

View file

@ -4,6 +4,7 @@ import (
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"encoding/base64" "encoding/base64"
"encoding/json"
"math/big" "math/big"
"time" "time"
@ -18,6 +19,19 @@ type PublicKey struct {
jwk jose.JWK jwk jose.JWK
} }
func (k *PublicKey) MarshalJSON() ([]byte, error) {
return json.Marshal(k.jwk)
}
func (k *PublicKey) UnmarshalJSON(data []byte) error {
var jwk jose.JWK
if err := json.Unmarshal(data, &jwk); err != nil {
return err
}
k.jwk = jwk
return nil
}
func (k *PublicKey) ID() string { func (k *PublicKey) ID() string {
return k.jwk.ID return k.jwk.ID
} }

View file

@ -1,68 +0,0 @@
package key
import (
"crypto/rsa"
"math/big"
"reflect"
"testing"
"time"
"github.com/coreos/go-oidc/jose"
)
func TestPrivateRSAKeyJWK(t *testing.T) {
n := big.NewInt(int64(17))
if n == nil {
panic("NewInt returned nil")
}
k := &PrivateKey{
KeyID: "foo",
PrivateKey: &rsa.PrivateKey{
PublicKey: rsa.PublicKey{N: n, E: 65537},
},
}
want := jose.JWK{
ID: "foo",
Type: "RSA",
Alg: "RS256",
Use: "sig",
Modulus: n,
Exponent: 65537,
}
got := k.JWK()
if !reflect.DeepEqual(want, got) {
t.Fatalf("JWK mismatch: want=%#v got=%#v", want, got)
}
}
func TestPublicKeySetKey(t *testing.T) {
n := big.NewInt(int64(17))
if n == nil {
panic("NewInt returned nil")
}
k := jose.JWK{
ID: "foo",
Type: "RSA",
Alg: "RS256",
Use: "sig",
Modulus: n,
Exponent: 65537,
}
now := time.Now().UTC()
ks := NewPublicKeySet([]jose.JWK{k}, now)
want := &PublicKey{jwk: k}
got := ks.Key("foo")
if !reflect.DeepEqual(want, got) {
t.Errorf("Unexpected response from PublicKeySet.Key: want=%#v got=%#v", want, got)
}
got = ks.Key("bar")
if got != nil {
t.Errorf("Expected nil response from PublicKeySet.Key, got %#v", got)
}
}

View file

@ -1,225 +0,0 @@
package key
import (
"crypto/rsa"
"math/big"
"reflect"
"strconv"
"testing"
"time"
"github.com/jonboulle/clockwork"
"github.com/coreos/go-oidc/jose"
)
var (
jwk1 jose.JWK
jwk2 jose.JWK
jwk3 jose.JWK
)
func init() {
jwk1 = jose.JWK{
ID: "1",
Type: "RSA",
Alg: "RS256",
Use: "sig",
Modulus: big.NewInt(1),
Exponent: 65537,
}
jwk2 = jose.JWK{
ID: "2",
Type: "RSA",
Alg: "RS256",
Use: "sig",
Modulus: big.NewInt(2),
Exponent: 65537,
}
jwk3 = jose.JWK{
ID: "3",
Type: "RSA",
Alg: "RS256",
Use: "sig",
Modulus: big.NewInt(3),
Exponent: 65537,
}
}
func generatePrivateKeyStatic(t *testing.T, idAndN int) *PrivateKey {
n := big.NewInt(int64(idAndN))
if n == nil {
t.Fatalf("Call to NewInt(%d) failed", idAndN)
}
pk := &rsa.PrivateKey{
PublicKey: rsa.PublicKey{N: n, E: 65537},
}
return &PrivateKey{
KeyID: strconv.Itoa(idAndN),
PrivateKey: pk,
}
}
func TestPrivateKeyManagerJWKsRotate(t *testing.T) {
k1 := generatePrivateKeyStatic(t, 1)
k2 := generatePrivateKeyStatic(t, 2)
k3 := generatePrivateKeyStatic(t, 3)
km := NewPrivateKeyManager()
err := km.Set(&PrivateKeySet{
keys: []*PrivateKey{k1, k2, k3},
ActiveKeyID: k1.KeyID,
expiresAt: time.Now().Add(time.Minute),
})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
want := []jose.JWK{jwk1, jwk2, jwk3}
got, err := km.JWKs()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if !reflect.DeepEqual(want, got) {
t.Fatalf("JWK mismatch: want=%#v got=%#v", want, got)
}
}
func TestPrivateKeyManagerSigner(t *testing.T) {
k := generatePrivateKeyStatic(t, 13)
km := NewPrivateKeyManager()
err := km.Set(&PrivateKeySet{
keys: []*PrivateKey{k},
ActiveKeyID: k.KeyID,
expiresAt: time.Now().Add(time.Minute),
})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
signer, err := km.Signer()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
wantID := "13"
gotID := signer.ID()
if wantID != gotID {
t.Fatalf("Signer has incorrect ID: want=%s got=%s", wantID, gotID)
}
}
func TestPrivateKeyManagerHealthyFail(t *testing.T) {
keyFixture := generatePrivateKeyStatic(t, 1)
tests := []*privateKeyManager{
// keySet nil
&privateKeyManager{
keySet: nil,
clock: clockwork.NewRealClock(),
},
// zero keys
&privateKeyManager{
keySet: &PrivateKeySet{
keys: []*PrivateKey{},
expiresAt: time.Now().Add(time.Minute),
},
clock: clockwork.NewRealClock(),
},
// key set expired
&privateKeyManager{
keySet: &PrivateKeySet{
keys: []*PrivateKey{keyFixture},
expiresAt: time.Now().Add(-1 * time.Minute),
},
clock: clockwork.NewRealClock(),
},
}
for i, tt := range tests {
if err := tt.Healthy(); err == nil {
t.Errorf("case %d: nil error", i)
}
}
}
func TestPrivateKeyManagerHealthyFailsOtherMethods(t *testing.T) {
km := NewPrivateKeyManager()
if _, err := km.JWKs(); err == nil {
t.Fatalf("Expected non-nil error")
}
if _, err := km.Signer(); err == nil {
t.Fatalf("Expected non-nil error")
}
}
func TestPrivateKeyManagerExpiresAt(t *testing.T) {
fc := clockwork.NewFakeClock()
now := fc.Now().UTC()
k := generatePrivateKeyStatic(t, 17)
km := &privateKeyManager{
clock: fc,
}
want := fc.Now().UTC()
got := km.ExpiresAt()
if want != got {
t.Fatalf("Incorrect expiration time: want=%v got=%v", want, got)
}
err := km.Set(&PrivateKeySet{
keys: []*PrivateKey{k},
ActiveKeyID: k.KeyID,
expiresAt: now.Add(2 * time.Minute),
})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
want = fc.Now().UTC().Add(2 * time.Minute)
got = km.ExpiresAt()
if want != got {
t.Fatalf("Incorrect expiration time: want=%v got=%v", want, got)
}
}
func TestPublicKeys(t *testing.T) {
km := NewPrivateKeyManager()
k1 := generatePrivateKeyStatic(t, 1)
k2 := generatePrivateKeyStatic(t, 2)
k3 := generatePrivateKeyStatic(t, 3)
tests := [][]*PrivateKey{
[]*PrivateKey{k1},
[]*PrivateKey{k1, k2},
[]*PrivateKey{k1, k2, k3},
}
for i, tt := range tests {
ks := &PrivateKeySet{
keys: tt,
expiresAt: time.Now().Add(time.Hour),
}
km.Set(ks)
jwks, err := km.JWKs()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
pks := NewPublicKeySet(jwks, time.Now().Add(time.Hour))
want := pks.Keys()
got, err := km.PublicKeys()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if !reflect.DeepEqual(want, got) {
t.Errorf("case %d: Invalid public keys: want=%v got=%v", i, want, got)
}
}
}

View file

@ -1,311 +0,0 @@
package key
import (
"reflect"
"testing"
"time"
"github.com/jonboulle/clockwork"
)
func generatePrivateKeySerialFunc(t *testing.T) GeneratePrivateKeyFunc {
var n int
return func() (*PrivateKey, error) {
n++
return generatePrivateKeyStatic(t, n), nil
}
}
func TestRotate(t *testing.T) {
now := time.Now()
k1 := generatePrivateKeyStatic(t, 1)
k2 := generatePrivateKeyStatic(t, 2)
k3 := generatePrivateKeyStatic(t, 3)
tests := []struct {
start *PrivateKeySet
key *PrivateKey
keep int
exp time.Time
want *PrivateKeySet
}{
// start with nil keys
{
start: nil,
key: k1,
keep: 2,
exp: now.Add(time.Second),
want: &PrivateKeySet{
keys: []*PrivateKey{k1},
ActiveKeyID: k1.KeyID,
expiresAt: now.Add(time.Second),
},
},
// start with zero keys
{
start: &PrivateKeySet{},
key: k1,
keep: 2,
exp: now.Add(time.Second),
want: &PrivateKeySet{
keys: []*PrivateKey{k1},
ActiveKeyID: k1.KeyID,
expiresAt: now.Add(time.Second),
},
},
// add second key
{
start: &PrivateKeySet{
keys: []*PrivateKey{k1},
ActiveKeyID: k1.KeyID,
expiresAt: now,
},
key: k2,
keep: 2,
exp: now.Add(time.Second),
want: &PrivateKeySet{
keys: []*PrivateKey{k2, k1},
ActiveKeyID: k2.KeyID,
expiresAt: now.Add(time.Second),
},
},
// rotate in third key
{
start: &PrivateKeySet{
keys: []*PrivateKey{k2, k1},
ActiveKeyID: k2.KeyID,
expiresAt: now,
},
key: k3,
keep: 2,
exp: now.Add(time.Second),
want: &PrivateKeySet{
keys: []*PrivateKey{k3, k2},
ActiveKeyID: k3.KeyID,
expiresAt: now.Add(time.Second),
},
},
}
for i, tt := range tests {
repo := NewPrivateKeySetRepo()
if tt.start != nil {
err := repo.Set(tt.start)
if err != nil {
log.Fatalf("case %d: unexpected error: %v", i, err)
}
}
rotatePrivateKeys(repo, tt.key, tt.keep, tt.exp)
got, err := repo.Get()
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
continue
}
if !reflect.DeepEqual(tt.want, got) {
t.Errorf("case %d: unexpected result: want=%#v got=%#v", i, tt.want, got)
}
}
}
func TestPrivateKeyRotatorRun(t *testing.T) {
fc := clockwork.NewFakeClock()
now := fc.Now().UTC()
k1 := generatePrivateKeyStatic(t, 1)
k2 := generatePrivateKeyStatic(t, 2)
k3 := generatePrivateKeyStatic(t, 3)
k4 := generatePrivateKeyStatic(t, 4)
kRepo := NewPrivateKeySetRepo()
krot := NewPrivateKeyRotator(kRepo, 4*time.Second)
krot.clock = fc
krot.generateKey = generatePrivateKeySerialFunc(t)
steps := []*PrivateKeySet{
&PrivateKeySet{
keys: []*PrivateKey{k1},
ActiveKeyID: k1.KeyID,
expiresAt: now.Add(4 * time.Second),
},
&PrivateKeySet{
keys: []*PrivateKey{k2, k1},
ActiveKeyID: k2.KeyID,
expiresAt: now.Add(6 * time.Second),
},
&PrivateKeySet{
keys: []*PrivateKey{k3, k2},
ActiveKeyID: k3.KeyID,
expiresAt: now.Add(8 * time.Second),
},
&PrivateKeySet{
keys: []*PrivateKey{k4, k3},
ActiveKeyID: k4.KeyID,
expiresAt: now.Add(10 * time.Second),
},
}
stop := krot.Run()
defer close(stop)
for i, st := range steps {
// wait for the rotater to get sleepy
fc.BlockUntil(1)
got, err := kRepo.Get()
if err != nil {
t.Fatalf("step %d: unexpected error: %v", i, err)
}
if !reflect.DeepEqual(st, got) {
t.Fatalf("step %d: unexpected state: want=%#v got=%#v", i, st, got)
}
fc.Advance(2 * time.Second)
}
}
func TestPrivateKeyRotatorExpiresAt(t *testing.T) {
fc := clockwork.NewFakeClock()
krot := &PrivateKeyRotator{
clock: fc,
ttl: time.Minute,
}
got := krot.expiresAt()
want := fc.Now().UTC().Add(time.Minute)
if !reflect.DeepEqual(want, got) {
t.Errorf("Incorrect expiration time: want=%v got=%v", want, got)
}
}
func TestNextRotation(t *testing.T) {
fc := clockwork.NewFakeClock()
now := fc.Now().UTC()
tests := []struct {
expiresAt time.Time
ttl time.Duration
numKeys int
expected time.Duration
}{
{
// closest to prod
expiresAt: now.Add(time.Hour * 24),
ttl: time.Hour * 24,
numKeys: 2,
expected: time.Hour * 12,
},
{
expiresAt: now.Add(time.Hour * 2),
ttl: time.Hour * 4,
numKeys: 2,
expected: 0,
},
{
// No keys.
expiresAt: now.Add(time.Hour * 2),
ttl: time.Hour * 4,
numKeys: 0,
expected: 0,
},
{
// Nil keyset.
expiresAt: now.Add(time.Hour * 2),
ttl: time.Hour * 4,
numKeys: -1,
expected: 0,
},
{
// KeySet expired.
expiresAt: now.Add(time.Hour * -2),
ttl: time.Hour * 4,
numKeys: 2,
expected: 0,
},
{
// Expiry past now + TTL
expiresAt: now.Add(time.Hour * 5),
ttl: time.Hour * 4,
numKeys: 2,
expected: 3 * time.Hour,
},
}
for i, tt := range tests {
kRepo := NewPrivateKeySetRepo()
krot := NewPrivateKeyRotator(kRepo, tt.ttl)
krot.clock = fc
pks := &PrivateKeySet{
expiresAt: tt.expiresAt,
}
if tt.numKeys != -1 {
for n := 0; n < tt.numKeys; n++ {
pks.keys = append(pks.keys, generatePrivateKeyStatic(t, n))
}
err := kRepo.Set(pks)
if err != nil {
log.Fatalf("case %d: unexpected error: %v", i, err)
}
}
actual, err := krot.nextRotation()
if err != nil {
t.Errorf("case %d: error calling shouldRotate(): %v", i, err)
}
if actual != tt.expected {
t.Errorf("case %d: actual == %v, want %v", i, actual, tt.expected)
}
}
}
func TestHealthy(t *testing.T) {
fc := clockwork.NewFakeClock()
now := fc.Now().UTC()
tests := []struct {
expiresAt time.Time
numKeys int
expected error
}{
{
expiresAt: now.Add(time.Hour),
numKeys: 2,
expected: nil,
},
{
expiresAt: now.Add(time.Hour),
numKeys: -1,
expected: ErrorNoKeys,
},
{
expiresAt: now.Add(time.Hour),
numKeys: 0,
expected: ErrorNoKeys,
},
{
expiresAt: now.Add(-time.Hour),
numKeys: 2,
expected: ErrorPrivateKeysExpired,
},
}
for i, tt := range tests {
kRepo := NewPrivateKeySetRepo()
krot := NewPrivateKeyRotator(kRepo, time.Hour)
krot.clock = fc
pks := &PrivateKeySet{
expiresAt: tt.expiresAt,
}
if tt.numKeys != -1 {
for n := 0; n < tt.numKeys; n++ {
pks.keys = append(pks.keys, generatePrivateKeyStatic(t, n))
}
err := kRepo.Set(pks)
if err != nil {
log.Fatalf("case %d: unexpected error: %v", i, err)
}
}
if err := krot.Healthy(); err != tt.expected {
t.Errorf("case %d: got==%q, want==%q", i, err, tt.expected)
}
}
}

View file

@ -1,205 +0,0 @@
package key
import (
"errors"
"reflect"
"testing"
"time"
"github.com/jonboulle/clockwork"
)
type staticReadableKeySetRepo struct {
ks KeySet
err error
}
func (r *staticReadableKeySetRepo) Get() (KeySet, error) {
return r.ks, r.err
}
func TestKeySyncerSync(t *testing.T) {
fc := clockwork.NewFakeClock()
now := fc.Now().UTC()
k1 := generatePrivateKeyStatic(t, 1)
k2 := generatePrivateKeyStatic(t, 2)
k3 := generatePrivateKeyStatic(t, 3)
steps := []struct {
fromKS KeySet
fromErr error
advance time.Duration
want *PrivateKeySet
}{
// on startup, first sync should trigger within a second
{
fromKS: &PrivateKeySet{
keys: []*PrivateKey{k1},
ActiveKeyID: k1.KeyID,
expiresAt: now.Add(10 * time.Second),
},
advance: time.Second,
want: &PrivateKeySet{
keys: []*PrivateKey{k1},
ActiveKeyID: k1.KeyID,
expiresAt: now.Add(10 * time.Second),
},
},
// advance halfway into TTL, triggering sync
{
fromKS: &PrivateKeySet{
keys: []*PrivateKey{k2, k1},
ActiveKeyID: k2.KeyID,
expiresAt: now.Add(15 * time.Second),
},
advance: 5 * time.Second,
want: &PrivateKeySet{
keys: []*PrivateKey{k2, k1},
ActiveKeyID: k2.KeyID,
expiresAt: now.Add(15 * time.Second),
},
},
// advance halfway into TTL, triggering sync that fails
{
fromErr: errors.New("fail!"),
advance: 10 * time.Second,
want: &PrivateKeySet{
keys: []*PrivateKey{k2, k1},
ActiveKeyID: k2.KeyID,
expiresAt: now.Add(15 * time.Second),
},
},
// sync retries quickly, and succeeds with fixed data
{
fromKS: &PrivateKeySet{
keys: []*PrivateKey{k3, k2, k1},
ActiveKeyID: k3.KeyID,
expiresAt: now.Add(25 * time.Second),
},
advance: 3 * time.Second,
want: &PrivateKeySet{
keys: []*PrivateKey{k3, k2, k1},
ActiveKeyID: k3.KeyID,
expiresAt: now.Add(25 * time.Second),
},
},
}
from := &staticReadableKeySetRepo{}
to := NewPrivateKeySetRepo()
syncer := NewKeySetSyncer(from, to)
syncer.clock = fc
stop := syncer.Run()
defer close(stop)
for i, st := range steps {
from.ks = st.fromKS
from.err = st.fromErr
fc.Advance(st.advance)
fc.BlockUntil(1)
ks, err := to.Get()
if err != nil {
t.Fatalf("step %d: unable to get keys: %v", i, err)
}
if !reflect.DeepEqual(st.want, ks) {
t.Fatalf("step %d: incorrect state: want=%#v got=%#v", i, st.want, ks)
}
}
}
func TestSync(t *testing.T) {
fc := clockwork.NewFakeClock()
now := fc.Now().UTC()
k1 := generatePrivateKeyStatic(t, 1)
k2 := generatePrivateKeyStatic(t, 2)
k3 := generatePrivateKeyStatic(t, 3)
tests := []struct {
keySet *PrivateKeySet
want time.Duration
}{
{
keySet: &PrivateKeySet{
keys: []*PrivateKey{k1},
ActiveKeyID: k1.KeyID,
expiresAt: now.Add(time.Minute),
},
want: time.Minute,
},
{
keySet: &PrivateKeySet{
keys: []*PrivateKey{k2, k1},
ActiveKeyID: k2.KeyID,
expiresAt: now.Add(time.Minute),
},
want: time.Minute,
},
{
keySet: &PrivateKeySet{
keys: []*PrivateKey{k3, k2, k1},
ActiveKeyID: k2.KeyID,
expiresAt: now.Add(time.Minute),
},
want: time.Minute,
},
{
keySet: &PrivateKeySet{
keys: []*PrivateKey{k2, k1},
ActiveKeyID: k2.KeyID,
expiresAt: now.Add(time.Hour),
},
want: time.Hour,
},
{
keySet: &PrivateKeySet{
keys: []*PrivateKey{k1},
ActiveKeyID: k1.KeyID,
expiresAt: now.Add(-time.Hour),
},
want: 0,
},
}
for i, tt := range tests {
from := NewPrivateKeySetRepo()
to := NewPrivateKeySetRepo()
err := from.Set(tt.keySet)
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
continue
}
exp, err := sync(from, to, fc)
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
continue
}
if tt.want != exp {
t.Errorf("case %d: want=%v got=%v", i, tt.want, exp)
}
}
}
func TestSyncFail(t *testing.T) {
tests := []error{
nil,
errors.New("fail!"),
}
for i, tt := range tests {
from := &staticReadableKeySetRepo{ks: nil, err: tt}
to := NewPrivateKeySetRepo()
if _, err := sync(from, to, clockwork.NewFakeClock()); err == nil {
t.Errorf("case %d: expected non-nil error", i)
}
}
}

View file

@ -1,10 +1,5 @@
package oauth2 package oauth2
import (
"encoding/json"
"fmt"
)
const ( const (
ErrorAccessDenied = "access_denied" ErrorAccessDenied = "access_denied"
ErrorInvalidClient = "invalid_client" ErrorInvalidClient = "invalid_client"
@ -17,23 +12,18 @@ const (
) )
type Error struct { type Error struct {
Type string `json:"error"` Type string `json:"error"`
State string `json:"state,omitempty"` Description string `json:"error_description,omitempty"`
State string `json:"state,omitempty"`
} }
func (e *Error) Error() string { func (e *Error) Error() string {
if e.Description != "" {
return e.Type + ": " + e.Description
}
return e.Type return e.Type
} }
func NewError(typ string) *Error { func NewError(typ string) *Error {
return &Error{Type: typ} return &Error{Type: typ}
} }
func unmarshalError(b []byte) error {
var oerr Error
err := json.Unmarshal(b, &oerr)
if err != nil {
return fmt.Errorf("unrecognized error: %s", string(b))
}
return &oerr
}

View file

@ -1,79 +0,0 @@
package oauth2
import (
"fmt"
"reflect"
"testing"
)
func TestUnmarshalError(t *testing.T) {
tests := []struct {
b []byte
e *Error
o bool
}{
{
b: []byte("{ \"error\": \"invalid_client\", \"state\": \"foo\" }"),
e: &Error{Type: ErrorInvalidClient, State: "foo"},
o: true,
},
{
b: []byte("{ \"error\": \"invalid_grant\", \"state\": \"bar\" }"),
e: &Error{Type: ErrorInvalidGrant, State: "bar"},
o: true,
},
{
b: []byte("{ \"error\": \"invalid_request\", \"state\": \"\" }"),
e: &Error{Type: ErrorInvalidRequest, State: ""},
o: true,
},
{
b: []byte("{ \"error\": \"server_error\", \"state\": \"elroy\" }"),
e: &Error{Type: ErrorServerError, State: "elroy"},
o: true,
},
{
b: []byte("{ \"error\": \"unsupported_grant_type\", \"state\": \"\" }"),
e: &Error{Type: ErrorUnsupportedGrantType, State: ""},
o: true,
},
{
b: []byte("{ \"error\": \"unsupported_response_type\", \"state\": \"\" }"),
e: &Error{Type: ErrorUnsupportedResponseType, State: ""},
o: true,
},
// Should fail json unmarshal
{
b: nil,
e: nil,
o: false,
},
{
b: []byte("random string"),
e: nil,
o: false,
},
}
for i, tt := range tests {
err := unmarshalError(tt.b)
oerr, ok := err.(*Error)
if ok != tt.o {
t.Errorf("%v != %v, %v", ok, tt.o, oerr)
t.Errorf("case %d: want=%+v, got=%+v", i, tt.e, oerr)
}
if ok && !reflect.DeepEqual(tt.e, oerr) {
t.Errorf("case %d: want=%+v, got=%+v", i, tt.e, oerr)
}
if !ok && tt.e != nil {
want := fmt.Sprintf("unrecognized error: %s", string(tt.b))
got := tt.e.Error()
if want != got {
t.Errorf("case %d: want=%+v, got=%+v", i, want, got)
}
}
}
}

View file

@ -220,11 +220,7 @@ func parseTokenResponse(resp *http.Response) (result TokenResponse, err error) {
if err != nil { if err != nil {
return return
} }
badStatusCode := resp.StatusCode < 200 || resp.StatusCode > 299
if resp.StatusCode < 200 || resp.StatusCode > 299 {
err = unmarshalError(body)
return
}
contentType, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type")) contentType, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type"))
if err != nil { if err != nil {
@ -235,42 +231,69 @@ func parseTokenResponse(resp *http.Response) (result TokenResponse, err error) {
RawBody: body, RawBody: body,
} }
newError := func(typ, desc, state string) error {
if typ == "" {
return fmt.Errorf("unrecognized error %s", body)
}
return &Error{typ, desc, state}
}
if contentType == "application/x-www-form-urlencoded" || contentType == "text/plain" { if contentType == "application/x-www-form-urlencoded" || contentType == "text/plain" {
var vals url.Values var vals url.Values
vals, err = url.ParseQuery(string(body)) vals, err = url.ParseQuery(string(body))
if err != nil { if err != nil {
return return
} }
if error := vals.Get("error"); error != "" || badStatusCode {
err = newError(error, vals.Get("error_description"), vals.Get("state"))
return
}
e := vals.Get("expires_in")
if e == "" {
e = vals.Get("expires")
}
if e != "" {
result.Expires, err = strconv.Atoi(e)
if err != nil {
return
}
}
result.AccessToken = vals.Get("access_token") result.AccessToken = vals.Get("access_token")
result.TokenType = vals.Get("token_type") result.TokenType = vals.Get("token_type")
result.IDToken = vals.Get("id_token") result.IDToken = vals.Get("id_token")
result.RefreshToken = vals.Get("refresh_token") result.RefreshToken = vals.Get("refresh_token")
result.Scope = vals.Get("scope") result.Scope = vals.Get("scope")
e := vals.Get("expires_in")
if e == "" {
e = vals.Get("expires")
}
result.Expires, err = strconv.Atoi(e)
if err != nil {
return
}
} else { } else {
b := make(map[string]interface{}) var r struct {
if err = json.Unmarshal(body, &b); err != nil { AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
Scope string `json:"scope"`
State string `json:"state"`
ExpiresIn int `json:"expires_in"`
Expires int `json:"expires"`
Error string `json:"error"`
Desc string `json:"error_description"`
}
if err = json.Unmarshal(body, &r); err != nil {
return return
} }
result.AccessToken, _ = b["access_token"].(string) if r.Error != "" || badStatusCode {
result.TokenType, _ = b["token_type"].(string) err = newError(r.Error, r.Desc, r.State)
result.IDToken, _ = b["id_token"].(string) return
result.RefreshToken, _ = b["refresh_token"].(string) }
result.Scope, _ = b["scope"].(string) result.AccessToken = r.AccessToken
e, ok := b["expires_in"].(int) result.TokenType = r.TokenType
if !ok { result.IDToken = r.IDToken
e, _ = b["expires"].(int) result.RefreshToken = r.RefreshToken
result.Scope = r.Scope
if r.ExpiresIn == 0 {
result.Expires = r.Expires
} else {
result.Expires = r.ExpiresIn
} }
result.Expires = e
} }
return return
} }

View file

@ -1,262 +0,0 @@
package oauth2
import (
"errors"
"net/url"
"reflect"
"strings"
"testing"
phttp "github.com/coreos/go-oidc/http"
)
func TestParseAuthCodeRequest(t *testing.T) {
tests := []struct {
query url.Values
wantACR AuthCodeRequest
wantErr error
}{
// no redirect_uri
{
query: url.Values{
"response_type": []string{"code"},
"scope": []string{"foo bar baz"},
"client_id": []string{"XXX"},
"state": []string{"pants"},
},
wantACR: AuthCodeRequest{
ResponseType: "code",
ClientID: "XXX",
Scope: []string{"foo", "bar", "baz"},
State: "pants",
RedirectURL: nil,
},
},
// with redirect_uri
{
query: url.Values{
"response_type": []string{"code"},
"redirect_uri": []string{"https://127.0.0.1:5555/callback?foo=bar"},
"scope": []string{"foo bar baz"},
"client_id": []string{"XXX"},
"state": []string{"pants"},
},
wantACR: AuthCodeRequest{
ResponseType: "code",
ClientID: "XXX",
Scope: []string{"foo", "bar", "baz"},
State: "pants",
RedirectURL: &url.URL{
Scheme: "https",
Host: "127.0.0.1:5555",
Path: "/callback",
RawQuery: "foo=bar",
},
},
},
// unsupported response_type doesn't trigger error
{
query: url.Values{
"response_type": []string{"token"},
"redirect_uri": []string{"https://127.0.0.1:5555/callback?foo=bar"},
"scope": []string{"foo bar baz"},
"client_id": []string{"XXX"},
"state": []string{"pants"},
},
wantACR: AuthCodeRequest{
ResponseType: "token",
ClientID: "XXX",
Scope: []string{"foo", "bar", "baz"},
State: "pants",
RedirectURL: &url.URL{
Scheme: "https",
Host: "127.0.0.1:5555",
Path: "/callback",
RawQuery: "foo=bar",
},
},
},
// unparseable redirect_uri
{
query: url.Values{
"response_type": []string{"code"},
"redirect_uri": []string{":"},
"scope": []string{"foo bar baz"},
"client_id": []string{"XXX"},
"state": []string{"pants"},
},
wantACR: AuthCodeRequest{
ResponseType: "code",
ClientID: "XXX",
Scope: []string{"foo", "bar", "baz"},
State: "pants",
},
wantErr: NewError(ErrorInvalidRequest),
},
// no client_id, redirect_uri not parsed
{
query: url.Values{
"response_type": []string{"code"},
"redirect_uri": []string{"https://127.0.0.1:5555/callback?foo=bar"},
"scope": []string{"foo bar baz"},
"client_id": []string{},
"state": []string{"pants"},
},
wantACR: AuthCodeRequest{
ResponseType: "code",
ClientID: "",
Scope: []string{"foo", "bar", "baz"},
State: "pants",
RedirectURL: nil,
},
wantErr: NewError(ErrorInvalidRequest),
},
}
for i, tt := range tests {
got, err := ParseAuthCodeRequest(tt.query)
if !reflect.DeepEqual(tt.wantErr, err) {
t.Errorf("case %d: incorrect error value: want=%q got=%q", i, tt.wantErr, err)
}
if !reflect.DeepEqual(tt.wantACR, got) {
t.Errorf("case %d: incorrect AuthCodeRequest value: want=%#v got=%#v", i, tt.wantACR, got)
}
}
}
func TestClientCredsToken(t *testing.T) {
hc := &phttp.RequestRecorder{Error: errors.New("error")}
cfg := Config{
Credentials: ClientCredentials{ID: "cid", Secret: "csecret"},
Scope: []string{"foo-scope", "bar-scope"},
TokenURL: "http://example.com/token",
AuthMethod: AuthMethodClientSecretBasic,
RedirectURL: "http://example.com/redirect",
AuthURL: "http://example.com/auth",
}
c, err := NewClient(hc, cfg)
if err != nil {
t.Errorf("unexpected error %v", err)
}
scope := []string{"openid"}
c.ClientCredsToken(scope)
if hc.Request == nil {
t.Error("request is empty")
}
tu := hc.Request.URL.String()
if cfg.TokenURL != tu {
t.Errorf("wrong token url, want=%v, got=%v", cfg.TokenURL, tu)
}
ct := hc.Request.Header.Get("Content-Type")
if ct != "application/x-www-form-urlencoded" {
t.Errorf("wrong content-type, want=application/x-www-form-urlencoded, got=%v", ct)
}
cid, secret, ok := phttp.BasicAuth(hc.Request)
if !ok {
t.Error("unexpected error parsing basic auth")
}
if cfg.Credentials.ID != cid {
t.Errorf("wrong client ID, want=%v, got=%v", cfg.Credentials.ID, cid)
}
if cfg.Credentials.Secret != secret {
t.Errorf("wrong client secret, want=%v, got=%v", cfg.Credentials.Secret, secret)
}
err = hc.Request.ParseForm()
if err != nil {
t.Error("unexpected error parsing form")
}
gt := hc.Request.PostForm.Get("grant_type")
if gt != GrantTypeClientCreds {
t.Errorf("wrong grant_type, want=client_credentials, got=%v", gt)
}
sc := strings.Split(hc.Request.PostForm.Get("scope"), " ")
if !reflect.DeepEqual(scope, sc) {
t.Errorf("wrong scope, want=%v, got=%v", scope, sc)
}
}
func TestNewAuthenticatedRequest(t *testing.T) {
tests := []struct {
authMethod string
url string
values url.Values
}{
{
authMethod: AuthMethodClientSecretBasic,
url: "http://example.com/token",
values: url.Values{},
},
{
authMethod: AuthMethodClientSecretPost,
url: "http://example.com/token",
values: url.Values{},
},
}
for i, tt := range tests {
hc := &phttp.HandlerClient{}
cfg := Config{
Credentials: ClientCredentials{ID: "cid", Secret: "csecret"},
Scope: []string{"foo-scope", "bar-scope"},
TokenURL: "http://example.com/token",
AuthURL: "http://example.com/auth",
RedirectURL: "http://example.com/redirect",
AuthMethod: tt.authMethod,
}
c, err := NewClient(hc, cfg)
req, err := c.newAuthenticatedRequest(tt.url, tt.values)
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
continue
}
err = req.ParseForm()
if err != nil {
t.Errorf("case %d: want nil err, got %v", i, err)
}
if tt.authMethod == AuthMethodClientSecretBasic {
cid, secret, ok := phttp.BasicAuth(req)
if !ok {
t.Errorf("case %d: !ok parsing Basic Auth headers", i)
continue
}
if cid != cfg.Credentials.ID {
t.Errorf("case %d: want CID == %q, got CID == %q", i, cfg.Credentials.ID, cid)
}
if secret != cfg.Credentials.Secret {
t.Errorf("case %d: want secret == %q, got secret == %q", i, cfg.Credentials.Secret, secret)
}
} else if tt.authMethod == AuthMethodClientSecretPost {
if req.PostFormValue("client_secret") != cfg.Credentials.Secret {
t.Errorf("case %d: want client_secret == %q, got client_secret == %q",
i, cfg.Credentials.Secret, req.PostFormValue("client_secret"))
}
}
for k, v := range tt.values {
if !reflect.DeepEqual(v, req.PostForm[k]) {
t.Errorf("case %d: key:%q want==%q, got==%q", i, k, v, req.PostForm[k])
}
}
if req.URL.String() != tt.url {
t.Errorf("case %d: want URL==%q, got URL==%q", i, tt.url, req.URL.String())
}
}
}

View file

@ -1,367 +0,0 @@
package oidc
import (
"net/url"
"reflect"
"testing"
"time"
"github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/key"
"github.com/coreos/go-oidc/oauth2"
)
func TestNewClientScopeDefault(t *testing.T) {
tests := []struct {
c ClientConfig
e []string
}{
{
// No scope
c: ClientConfig{RedirectURL: "http://example.com/redirect"},
e: DefaultScope,
},
{
// Nil scope
c: ClientConfig{RedirectURL: "http://example.com/redirect", Scope: nil},
e: DefaultScope,
},
{
// Empty scope
c: ClientConfig{RedirectURL: "http://example.com/redirect", Scope: []string{}},
e: []string{},
},
{
// Custom scope equal to default
c: ClientConfig{RedirectURL: "http://example.com/redirect", Scope: []string{"openid", "email", "profile"}},
e: DefaultScope,
},
{
// Custom scope not including defaults
c: ClientConfig{RedirectURL: "http://example.com/redirect", Scope: []string{"foo", "bar"}},
e: []string{"foo", "bar"},
},
{
// Custom scopes overlapping with defaults
c: ClientConfig{RedirectURL: "http://example.com/redirect", Scope: []string{"openid", "foo"}},
e: []string{"openid", "foo"},
},
}
for i, tt := range tests {
c, err := NewClient(tt.c)
if err != nil {
t.Errorf("case %d: unexpected error from NewClient: %v", i, err)
continue
}
if !reflect.DeepEqual(tt.e, c.scope) {
t.Errorf("case %d: want: %v, got: %v", i, tt.e, c.scope)
}
}
}
func TestHealthy(t *testing.T) {
now := time.Now().UTC()
tests := []struct {
c *Client
h bool
}{
// all ok
{
c: &Client{
providerConfig: ProviderConfig{
Issuer: "http://example.com",
ExpiresAt: now.Add(time.Hour),
},
},
h: true,
},
// zero-value ProviderConfig.ExpiresAt
{
c: &Client{
providerConfig: ProviderConfig{
Issuer: "http://example.com",
},
},
h: true,
},
// expired ProviderConfig
{
c: &Client{
providerConfig: ProviderConfig{
Issuer: "http://example.com",
ExpiresAt: now.Add(time.Hour * -1),
},
},
h: false,
},
// empty ProviderConfig
{
c: &Client{},
h: false,
},
}
for i, tt := range tests {
err := tt.c.Healthy()
want := tt.h
got := (err == nil)
if want != got {
t.Errorf("case %d: want: healthy=%v, got: healhty=%v, err: %v", i, want, got, err)
}
}
}
func TestClientKeysFuncAll(t *testing.T) {
priv1, err := key.GeneratePrivateKey()
if err != nil {
t.Fatalf("failed to generate private key, error=%v", err)
}
priv2, err := key.GeneratePrivateKey()
if err != nil {
t.Fatalf("failed to generate private key, error=%v", err)
}
now := time.Now()
future := now.Add(time.Hour)
past := now.Add(-1 * time.Hour)
tests := []struct {
keySet *key.PublicKeySet
want []key.PublicKey
}{
// two keys, non-expired set
{
keySet: key.NewPublicKeySet([]jose.JWK{priv2.JWK(), priv1.JWK()}, future),
want: []key.PublicKey{*key.NewPublicKey(priv2.JWK()), *key.NewPublicKey(priv1.JWK())},
},
// no keys, non-expired set
{
keySet: key.NewPublicKeySet([]jose.JWK{}, future),
want: []key.PublicKey{},
},
// two keys, expired set
{
keySet: key.NewPublicKeySet([]jose.JWK{priv2.JWK(), priv1.JWK()}, past),
want: []key.PublicKey{},
},
// no keys, expired set
{
keySet: key.NewPublicKeySet([]jose.JWK{}, past),
want: []key.PublicKey{},
},
}
for i, tt := range tests {
var c Client
c.keySet = *tt.keySet
keysFunc := c.keysFuncAll()
got := keysFunc()
if !reflect.DeepEqual(tt.want, got) {
t.Errorf("case %d: want=%#v got=%#v", i, tt.want, got)
}
}
}
func TestClientKeysFuncWithID(t *testing.T) {
priv1, err := key.GeneratePrivateKey()
if err != nil {
t.Fatalf("failed to generate private key, error=%v", err)
}
priv2, err := key.GeneratePrivateKey()
if err != nil {
t.Fatalf("failed to generate private key, error=%v", err)
}
now := time.Now()
future := now.Add(time.Hour)
past := now.Add(-1 * time.Hour)
tests := []struct {
keySet *key.PublicKeySet
argID string
want []key.PublicKey
}{
// two keys, match, non-expired set
{
keySet: key.NewPublicKeySet([]jose.JWK{priv2.JWK(), priv1.JWK()}, future),
argID: priv2.ID(),
want: []key.PublicKey{*key.NewPublicKey(priv2.JWK())},
},
// two keys, no match, non-expired set
{
keySet: key.NewPublicKeySet([]jose.JWK{priv2.JWK(), priv1.JWK()}, future),
argID: "XXX",
want: []key.PublicKey{},
},
// no keys, no match, non-expired set
{
keySet: key.NewPublicKeySet([]jose.JWK{}, future),
argID: priv2.ID(),
want: []key.PublicKey{},
},
// two keys, match, expired set
{
keySet: key.NewPublicKeySet([]jose.JWK{priv2.JWK(), priv1.JWK()}, past),
argID: priv2.ID(),
want: []key.PublicKey{},
},
// no keys, no match, expired set
{
keySet: key.NewPublicKeySet([]jose.JWK{}, past),
argID: priv2.ID(),
want: []key.PublicKey{},
},
}
for i, tt := range tests {
var c Client
c.keySet = *tt.keySet
keysFunc := c.keysFuncWithID(tt.argID)
got := keysFunc()
if !reflect.DeepEqual(tt.want, got) {
t.Errorf("case %d: want=%#v got=%#v", i, tt.want, got)
}
}
}
func TestClientMetadataValid(t *testing.T) {
tests := []ClientMetadata{
// one RedirectURL
ClientMetadata{
RedirectURLs: []url.URL{url.URL{Scheme: "http", Host: "example.com"}},
},
// one RedirectURL w/ nonempty path
ClientMetadata{
RedirectURLs: []url.URL{url.URL{Scheme: "http", Host: "example.com", Path: "/foo"}},
},
// two RedirectURLs
ClientMetadata{
RedirectURLs: []url.URL{
url.URL{Scheme: "http", Host: "foo.example.com"},
url.URL{Scheme: "http", Host: "bar.example.com"},
},
},
}
for i, tt := range tests {
if err := tt.Valid(); err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
}
}
}
func TestClientMetadataInvalid(t *testing.T) {
tests := []ClientMetadata{
// nil RedirectURls slice
ClientMetadata{
RedirectURLs: nil,
},
// empty RedirectURLs slice
ClientMetadata{
RedirectURLs: []url.URL{},
},
// empty url.URL
ClientMetadata{
RedirectURLs: []url.URL{url.URL{}},
},
// empty url.URL following OK item
ClientMetadata{
RedirectURLs: []url.URL{url.URL{Scheme: "http", Host: "example.com"}, url.URL{}},
},
// url.URL with empty Host
ClientMetadata{
RedirectURLs: []url.URL{url.URL{Scheme: "http", Host: ""}},
},
// url.URL with empty Scheme
ClientMetadata{
RedirectURLs: []url.URL{url.URL{Scheme: "", Host: "example.com"}},
},
// url.URL with non-HTTP(S) Scheme
ClientMetadata{
RedirectURLs: []url.URL{url.URL{Scheme: "tcp", Host: "127.0.0.1"}},
},
}
for i, tt := range tests {
if err := tt.Valid(); err == nil {
t.Errorf("case %d: expected non-nil error", i)
}
}
}
func TestChooseAuthMethod(t *testing.T) {
tests := []struct {
supported []string
chosen string
err bool
}{
{
supported: []string{},
chosen: oauth2.AuthMethodClientSecretBasic,
},
{
supported: []string{oauth2.AuthMethodClientSecretBasic},
chosen: oauth2.AuthMethodClientSecretBasic,
},
{
supported: []string{oauth2.AuthMethodClientSecretPost},
chosen: oauth2.AuthMethodClientSecretPost,
},
{
supported: []string{oauth2.AuthMethodClientSecretPost, oauth2.AuthMethodClientSecretBasic},
chosen: oauth2.AuthMethodClientSecretPost,
},
{
supported: []string{oauth2.AuthMethodClientSecretBasic, oauth2.AuthMethodClientSecretPost},
chosen: oauth2.AuthMethodClientSecretBasic,
},
{
supported: []string{oauth2.AuthMethodClientSecretJWT, oauth2.AuthMethodClientSecretPost},
chosen: oauth2.AuthMethodClientSecretPost,
},
{
supported: []string{oauth2.AuthMethodClientSecretJWT},
chosen: "",
err: true,
},
}
for i, tt := range tests {
client := Client{
providerConfig: ProviderConfig{
TokenEndpointAuthMethodsSupported: tt.supported,
},
}
got, err := client.chooseAuthMethod()
if tt.err {
if err == nil {
t.Errorf("case %d: expected non-nil err", i)
}
continue
}
if got != tt.chosen {
t.Errorf("case %d: want=%q, got=%q", i, tt.chosen, got)
}
}
}

View file

@ -1,113 +0,0 @@
package oidc
import (
"reflect"
"testing"
"time"
"github.com/coreos/go-oidc/jose"
)
func TestIdentityFromClaims(t *testing.T) {
tests := []struct {
claims jose.Claims
want Identity
}{
{
claims: jose.Claims{
"sub": "123850281",
"name": "Elroy",
"email": "elroy@example.com",
"exp": float64(1.416935146e+09),
},
want: Identity{
ID: "123850281",
Name: "",
Email: "elroy@example.com",
ExpiresAt: time.Date(2014, time.November, 25, 17, 05, 46, 0, time.UTC),
},
},
{
claims: jose.Claims{
"sub": "123850281",
"name": "Elroy",
"exp": float64(1.416935146e+09),
},
want: Identity{
ID: "123850281",
Name: "",
Email: "",
ExpiresAt: time.Date(2014, time.November, 25, 17, 05, 46, 0, time.UTC),
},
},
{
claims: jose.Claims{
"sub": "123850281",
"name": "Elroy",
"email": "elroy@example.com",
"exp": int64(1416935146),
},
want: Identity{
ID: "123850281",
Name: "",
Email: "elroy@example.com",
ExpiresAt: time.Date(2014, time.November, 25, 17, 05, 46, 0, time.UTC),
},
},
{
claims: jose.Claims{
"sub": "123850281",
"name": "Elroy",
"email": "elroy@example.com",
},
want: Identity{
ID: "123850281",
Name: "",
Email: "elroy@example.com",
ExpiresAt: time.Time{},
},
},
}
for i, tt := range tests {
got, err := IdentityFromClaims(tt.claims)
if err != nil {
t.Errorf("case %d: unexpected error: %v", i, err)
continue
}
if !reflect.DeepEqual(tt.want, *got) {
t.Errorf("case %d: want=%#v got=%#v", i, tt.want, *got)
}
}
}
func TestIdentityFromClaimsFail(t *testing.T) {
tests := []jose.Claims{
// sub incorrect type
jose.Claims{
"sub": 123,
"name": "foo",
"email": "elroy@example.com",
},
// email incorrect type
jose.Claims{
"sub": "123850281",
"name": "Elroy",
"email": false,
},
// exp incorrect type
jose.Claims{
"sub": "123850281",
"name": "Elroy",
"email": "elroy@example.com",
"exp": "2014-11-25 18:05:46 +0000 UTC",
},
}
for i, tt := range tests {
_, err := IdentityFromClaims(tt)
if err == nil {
t.Errorf("case %d: expected non-nil error", i)
}
}
}

View file

@ -1,466 +0,0 @@
package oidc
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"reflect"
"testing"
"time"
"github.com/jonboulle/clockwork"
phttp "github.com/coreos/go-oidc/http"
"github.com/coreos/go-oidc/oauth2"
)
type fakeProviderConfigGetterSetter struct {
cfg *ProviderConfig
getCount int
setCount int
}
func (g *fakeProviderConfigGetterSetter) Get() (ProviderConfig, error) {
g.getCount++
return *g.cfg, nil
}
func (g *fakeProviderConfigGetterSetter) Set(cfg ProviderConfig) error {
g.cfg = &cfg
g.setCount++
return nil
}
type fakeProviderConfigHandler struct {
cfg ProviderConfig
maxAge time.Duration
}
func (s *fakeProviderConfigHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
b, _ := json.Marshal(s.cfg)
if s.maxAge.Seconds() >= 0 {
w.Header().Set("Cache-Control", fmt.Sprintf("public, max-age=%d", int(s.maxAge.Seconds())))
}
w.Header().Set("Content-Type", "application/json")
w.Write(b)
}
func TestHTTPProviderConfigGetter(t *testing.T) {
svr := &fakeProviderConfigHandler{}
hc := &phttp.HandlerClient{Handler: svr}
fc := clockwork.NewFakeClock()
now := fc.Now().UTC()
tests := []struct {
dsc string
age time.Duration
cfg ProviderConfig
ok bool
}{
// everything is good
{
dsc: "https://example.com",
age: time.Minute,
cfg: ProviderConfig{
Issuer: "https://example.com",
ExpiresAt: now.Add(time.Minute),
},
ok: true,
},
// iss and disco url differ by scheme only (how google works)
{
dsc: "https://example.com",
age: time.Minute,
cfg: ProviderConfig{
Issuer: "example.com",
ExpiresAt: now.Add(time.Minute),
},
ok: true,
},
// issuer and discovery URL mismatch
{
dsc: "https://foo.com",
age: time.Minute,
cfg: ProviderConfig{
Issuer: "https://example.com",
ExpiresAt: now.Add(time.Minute),
},
ok: false,
},
// missing cache header results in zero ExpiresAt
{
dsc: "https://example.com",
age: -1,
cfg: ProviderConfig{
Issuer: "https://example.com",
},
ok: true,
},
}
for i, tt := range tests {
svr.cfg = tt.cfg
svr.maxAge = tt.age
getter := NewHTTPProviderConfigGetter(hc, tt.dsc)
getter.clock = fc
got, err := getter.Get()
if err != nil {
if tt.ok {
t.Fatalf("test %d: unexpected error: %v", i, err)
}
continue
}
if !tt.ok {
t.Fatalf("test %d: expected error", i)
continue
}
if !reflect.DeepEqual(tt.cfg, got) {
t.Fatalf("test %d: want: %#v, got: %#v", i, tt.cfg, got)
}
}
}
func TestProviderConfigSyncerRun(t *testing.T) {
c1 := &ProviderConfig{
Issuer: "http://first.example.com",
}
c2 := &ProviderConfig{
Issuer: "http://second.example.com",
}
tests := []struct {
first *ProviderConfig
advance time.Duration
second *ProviderConfig
firstExp time.Duration
secondExp time.Duration
count int
}{
// exp is 10m, should have same config after 1s
{
first: c1,
firstExp: time.Duration(10 * time.Minute),
advance: time.Minute,
second: c1,
secondExp: time.Duration(10 * time.Minute),
count: 1,
},
// exp is 10m, should have new config after 10/2 = 5m
{
first: c1,
firstExp: time.Duration(10 * time.Minute),
advance: time.Duration(5 * time.Minute),
second: c2,
secondExp: time.Duration(10 * time.Minute),
count: 2,
},
// exp is 20m, should have new config after 20/2 = 10m
{
first: c1,
firstExp: time.Duration(20 * time.Minute),
advance: time.Duration(10 * time.Minute),
second: c2,
secondExp: time.Duration(30 * time.Minute),
count: 2,
},
}
assertCfg := func(i int, to *fakeProviderConfigGetterSetter, want ProviderConfig) {
got, err := to.Get()
if err != nil {
t.Fatalf("test %d: unable to get config: %v", i, err)
}
if !reflect.DeepEqual(want, got) {
t.Fatalf("test %d: incorrect state:\nwant=%#v\ngot=%#v", i, want, got)
}
}
for i, tt := range tests {
from := &fakeProviderConfigGetterSetter{}
to := &fakeProviderConfigGetterSetter{}
fc := clockwork.NewFakeClock()
now := fc.Now().UTC()
syncer := NewProviderConfigSyncer(from, to)
syncer.clock = fc
tt.first.ExpiresAt = now.Add(tt.firstExp)
tt.second.ExpiresAt = now.Add(tt.secondExp)
if err := from.Set(*tt.first); err != nil {
t.Fatalf("test %d: unexpected error: %v", i, err)
}
stop := syncer.Run()
defer close(stop)
fc.BlockUntil(1)
// first sync
assertCfg(i, to, *tt.first)
if err := from.Set(*tt.second); err != nil {
t.Fatalf("test %d: unexpected error: %v", i, err)
}
fc.Advance(tt.advance)
fc.BlockUntil(1)
// second sync
assertCfg(i, to, *tt.second)
if tt.count != from.getCount {
t.Fatalf("test %d: want: %v, got: %v", i, tt.count, from.getCount)
}
}
}
type staticProviderConfigGetter struct {
cfg ProviderConfig
err error
}
func (g *staticProviderConfigGetter) Get() (ProviderConfig, error) {
return g.cfg, g.err
}
type staticProviderConfigSetter struct {
cfg *ProviderConfig
err error
}
func (s *staticProviderConfigSetter) Set(cfg ProviderConfig) error {
s.cfg = &cfg
return s.err
}
func TestProviderConfigSyncerSyncFailure(t *testing.T) {
fc := clockwork.NewFakeClock()
tests := []struct {
from *staticProviderConfigGetter
to *staticProviderConfigSetter
// want indicates what ProviderConfig should be passed to Set.
// If nil, the Set should not be called.
want *ProviderConfig
}{
// generic Get failure
{
from: &staticProviderConfigGetter{err: errors.New("fail")},
to: &staticProviderConfigSetter{},
want: nil,
},
// generic Set failure
{
from: &staticProviderConfigGetter{cfg: ProviderConfig{ExpiresAt: fc.Now().Add(time.Minute)}},
to: &staticProviderConfigSetter{err: errors.New("fail")},
want: &ProviderConfig{ExpiresAt: fc.Now().Add(time.Minute)},
},
}
for i, tt := range tests {
pcs := &ProviderConfigSyncer{
from: tt.from,
to: tt.to,
clock: fc,
}
_, err := pcs.sync()
if err == nil {
t.Errorf("case %d: expected non-nil error", i)
}
if !reflect.DeepEqual(tt.want, tt.to.cfg) {
t.Errorf("case %d: Set mismatch: want=%#v got=%#v", i, tt.want, tt.to.cfg)
}
}
}
func TestNextSyncAfter(t *testing.T) {
fc := clockwork.NewFakeClock()
tests := []struct {
exp time.Time
want time.Duration
}{
{
exp: fc.Now().Add(time.Hour),
want: 30 * time.Minute,
},
// override large values with the maximum
{
exp: fc.Now().Add(168 * time.Hour), // one week
want: 24 * time.Hour,
},
// override "now" values with the minimum
{
exp: fc.Now(),
want: time.Minute,
},
// override negative values with the minimum
{
exp: fc.Now().Add(-1 * time.Minute),
want: time.Minute,
},
// zero-value Time results in maximum sync interval
{
exp: time.Time{},
want: 24 * time.Hour,
},
}
for i, tt := range tests {
got := nextSyncAfter(tt.exp, fc)
if tt.want != got {
t.Errorf("case %d: want=%v got=%v", i, tt.want, got)
}
}
}
func TestProviderConfigEmpty(t *testing.T) {
cfg := ProviderConfig{}
if !cfg.Empty() {
t.Fatalf("Empty provider config reports non-empty")
}
cfg = ProviderConfig{Issuer: "http://example.com"}
if cfg.Empty() {
t.Fatalf("Non-empty provider config reports empty")
}
}
func TestPCSStepAfter(t *testing.T) {
pass := func() (time.Duration, error) { return 7 * time.Second, nil }
fail := func() (time.Duration, error) { return 0, errors.New("fail") }
tests := []struct {
stepper pcsStepper
stepFunc pcsStepFunc
want pcsStepper
}{
// good step results in retry at TTL
{
stepper: &pcsStepNext{},
stepFunc: pass,
want: &pcsStepNext{aft: 7 * time.Second},
},
// good step after failed step results results in retry at TTL
{
stepper: &pcsStepRetry{aft: 2 * time.Second},
stepFunc: pass,
want: &pcsStepNext{aft: 7 * time.Second},
},
// failed step results in a retry in 1s
{
stepper: &pcsStepNext{},
stepFunc: fail,
want: &pcsStepRetry{aft: time.Second},
},
// failed retry backs off by a factor of 2
{
stepper: &pcsStepRetry{aft: time.Second},
stepFunc: fail,
want: &pcsStepRetry{aft: 2 * time.Second},
},
// failed retry backs off by a factor of 2, up to 1m
{
stepper: &pcsStepRetry{aft: 32 * time.Second},
stepFunc: fail,
want: &pcsStepRetry{aft: 60 * time.Second},
},
}
for i, tt := range tests {
got := tt.stepper.step(tt.stepFunc)
if !reflect.DeepEqual(tt.want, got) {
t.Errorf("case %d: want=%#v got=%#v", i, tt.want, got)
}
}
}
func TestProviderConfigSupportsGrantType(t *testing.T) {
tests := []struct {
types []string
typ string
want bool
}{
// explicitly supported
{
types: []string{"foo_type"},
typ: "foo_type",
want: true,
},
// explicitly unsupported
{
types: []string{"bar_type"},
typ: "foo_type",
want: false,
},
// default type explicitly unsupported
{
types: []string{oauth2.GrantTypeImplicit},
typ: oauth2.GrantTypeAuthCode,
want: false,
},
// type not found in default set
{
types: []string{},
typ: "foo_type",
want: false,
},
// type found in default set
{
types: []string{},
typ: oauth2.GrantTypeAuthCode,
want: true,
},
}
for i, tt := range tests {
cfg := ProviderConfig{
GrantTypesSupported: tt.types,
}
got := cfg.SupportsGrantType(tt.typ)
if tt.want != got {
t.Errorf("case %d: assert %v supports %v: want=%t got=%t", i, tt.types, tt.typ, tt.want, got)
}
}
}
func TestWaitForProviderConfigImmediateSuccess(t *testing.T) {
cfg := ProviderConfig{Issuer: "http://example.com"}
b, err := json.Marshal(cfg)
if err != nil {
t.Fatalf("Failed marshaling provider config")
}
resp := http.Response{Body: ioutil.NopCloser(bytes.NewBuffer(b))}
hc := &phttp.RequestRecorder{Response: &resp}
fc := clockwork.NewFakeClock()
reschan := make(chan ProviderConfig)
go func() {
reschan <- waitForProviderConfig(hc, cfg.Issuer, fc)
}()
var got ProviderConfig
select {
case got = <-reschan:
case <-time.After(time.Second):
t.Fatalf("Did not receive result within 1s")
}
if !reflect.DeepEqual(cfg, got) {
t.Fatalf("Received incorrect provider config: want=%#v got=%#v", cfg, got)
}
}

View file

@ -1,167 +0,0 @@
package oidc
import (
"errors"
"net/http"
"reflect"
"testing"
phttp "github.com/coreos/go-oidc/http"
"github.com/coreos/go-oidc/jose"
)
type staticTokenRefresher struct {
verify func(jose.JWT) error
refresh func() (jose.JWT, error)
}
func (s *staticTokenRefresher) Verify(jwt jose.JWT) error {
return s.verify(jwt)
}
func (s *staticTokenRefresher) Refresh() (jose.JWT, error) {
return s.refresh()
}
func TestAuthenticatedTransportVerifiedJWT(t *testing.T) {
tests := []struct {
refresher TokenRefresher
startJWT jose.JWT
wantJWT jose.JWT
wantError error
}{
// verification succeeds, so refresh is not called
{
refresher: &staticTokenRefresher{
verify: func(jose.JWT) error { return nil },
refresh: func() (jose.JWT, error) { return jose.JWT{RawPayload: "2"}, nil },
},
startJWT: jose.JWT{RawPayload: "1"},
wantJWT: jose.JWT{RawPayload: "1"},
},
// verification fails, refresh succeeds so cached JWT changes
{
refresher: &staticTokenRefresher{
verify: func(jose.JWT) error { return errors.New("fail!") },
refresh: func() (jose.JWT, error) { return jose.JWT{RawPayload: "2"}, nil },
},
startJWT: jose.JWT{RawPayload: "1"},
wantJWT: jose.JWT{RawPayload: "2"},
},
// verification succeeds, so failing refresh isn't attempted
{
refresher: &staticTokenRefresher{
verify: func(jose.JWT) error { return nil },
refresh: func() (jose.JWT, error) { return jose.JWT{}, errors.New("fail!") },
},
startJWT: jose.JWT{RawPayload: "1"},
wantJWT: jose.JWT{RawPayload: "1"},
},
// verification fails, but refresh fails, too
{
refresher: &staticTokenRefresher{
verify: func(jose.JWT) error { return errors.New("fail!") },
refresh: func() (jose.JWT, error) { return jose.JWT{}, errors.New("fail!") },
},
startJWT: jose.JWT{RawPayload: "1"},
wantJWT: jose.JWT{},
wantError: errors.New("unable to acquire valid JWT: fail!"),
},
}
for i, tt := range tests {
at := &AuthenticatedTransport{
TokenRefresher: tt.refresher,
jwt: tt.startJWT,
}
gotJWT, err := at.verifiedJWT()
if !reflect.DeepEqual(tt.wantError, err) {
t.Errorf("#%d: unexpected error: want=%#v got=%#v", i, tt.wantError, err)
}
if !reflect.DeepEqual(tt.wantJWT, gotJWT) {
t.Errorf("#%d: incorrect JWT returned from verifiedJWT: want=%#v got=%#v", i, tt.wantJWT, gotJWT)
}
}
}
func TestAuthenticatedTransportJWTCaching(t *testing.T) {
at := &AuthenticatedTransport{
TokenRefresher: &staticTokenRefresher{
verify: func(jose.JWT) error { return errors.New("fail!") },
refresh: func() (jose.JWT, error) { return jose.JWT{RawPayload: "2"}, nil },
},
jwt: jose.JWT{RawPayload: "1"},
}
wantJWT := jose.JWT{RawPayload: "2"}
gotJWT, err := at.verifiedJWT()
if err != nil {
t.Fatalf("got non-nil error: %#v", err)
}
if !reflect.DeepEqual(wantJWT, gotJWT) {
t.Fatalf("incorrect JWT returned from verifiedJWT: want=%#v got=%#v", wantJWT, gotJWT)
}
at.TokenRefresher = &staticTokenRefresher{
verify: func(jose.JWT) error { return nil },
refresh: func() (jose.JWT, error) { return jose.JWT{RawPayload: "3"}, nil },
}
// the previous JWT should still be cached on the AuthenticatedTransport since
// it is still valid, even though there's a new token ready to refresh
gotJWT, err = at.verifiedJWT()
if err != nil {
t.Fatalf("got non-nil error: %#v", err)
}
if !reflect.DeepEqual(wantJWT, gotJWT) {
t.Fatalf("incorrect JWT returned from verifiedJWT: want=%#v got=%#v", wantJWT, gotJWT)
}
}
func TestAuthenticatedTransportRoundTrip(t *testing.T) {
rr := &phttp.RequestRecorder{Response: &http.Response{StatusCode: http.StatusOK}}
at := &AuthenticatedTransport{
TokenRefresher: &staticTokenRefresher{
verify: func(jose.JWT) error { return nil },
},
RoundTripper: rr,
jwt: jose.JWT{RawPayload: "1"},
}
req := http.Request{}
_, err := at.RoundTrip(&req)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if !reflect.DeepEqual(req, http.Request{}) {
t.Errorf("http.Request object was modified")
}
want := []string{"Bearer .1."}
got := rr.Request.Header["Authorization"]
if !reflect.DeepEqual(want, got) {
t.Errorf("incorrect Authorization header: want=%#v got=%#v", want, got)
}
}
func TestAuthenticatedTransportRoundTripRefreshFail(t *testing.T) {
rr := &phttp.RequestRecorder{Response: &http.Response{StatusCode: http.StatusOK}}
at := &AuthenticatedTransport{
TokenRefresher: &staticTokenRefresher{
verify: func(jose.JWT) error { return errors.New("fail!") },
refresh: func() (jose.JWT, error) { return jose.JWT{}, errors.New("fail!") },
},
RoundTripper: rr,
jwt: jose.JWT{RawPayload: "1"},
}
_, err := at.RoundTrip(&http.Request{})
if err == nil {
t.Errorf("expected non-nil error")
}
}

View file

@ -53,7 +53,7 @@ func CookieTokenExtractor(cookieName string) RequestTokenExtractor {
} }
} }
func NewClaims(iss, sub, aud string, iat, exp time.Time) jose.Claims { func NewClaims(iss, sub string, aud interface{}, iat, exp time.Time) jose.Claims {
return jose.Claims{ return jose.Claims{
// required // required
"iss": iss, "iss": iss,

View file

@ -1,95 +0,0 @@
package oidc
import (
"fmt"
"net/http"
"reflect"
"testing"
"time"
"github.com/coreos/go-oidc/jose"
)
func TestCookieTokenExtractorInvalid(t *testing.T) {
ckName := "tokenCookie"
tests := []*http.Cookie{
&http.Cookie{},
&http.Cookie{Name: ckName},
&http.Cookie{Name: ckName, Value: ""},
}
for i, tt := range tests {
r, _ := http.NewRequest("", "", nil)
r.AddCookie(tt)
_, err := CookieTokenExtractor(ckName)(r)
if err == nil {
t.Errorf("case %d: want: error for invalid cookie token, got: no error.", i)
}
}
}
func TestCookieTokenExtractorValid(t *testing.T) {
validToken := "eyJ0eXAiOiJKV1QiLA0KICJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJqb2UiLA0KICJleHAiOjEzMDA4MTkzODAsDQogImh0dHA6Ly9leGFtcGxlLmNvbS9pc19yb290Ijp0cnVlfQ.dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
ckName := "tokenCookie"
tests := []*http.Cookie{
&http.Cookie{Name: ckName, Value: "some non-empty value"},
&http.Cookie{Name: ckName, Value: validToken},
}
for i, tt := range tests {
r, _ := http.NewRequest("", "", nil)
r.AddCookie(tt)
_, err := CookieTokenExtractor(ckName)(r)
if err != nil {
t.Errorf("case %d: want: valid cookie with no error, got: %v", i, err)
}
}
}
func TestExtractBearerTokenInvalid(t *testing.T) {
tests := []string{"", "x", "Bearer", "xxxxxxx", "Bearer "}
for i, tt := range tests {
r, _ := http.NewRequest("", "", nil)
r.Header.Add("Authorization", tt)
_, err := ExtractBearerToken(r)
if err == nil {
t.Errorf("case %d: want: invalid Authorization header, got: valid Authorization header.", i)
}
}
}
func TestExtractBearerTokenValid(t *testing.T) {
validToken := "eyJ0eXAiOiJKV1QiLA0KICJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJqb2UiLA0KICJleHAiOjEzMDA4MTkzODAsDQogImh0dHA6Ly9leGFtcGxlLmNvbS9pc19yb290Ijp0cnVlfQ.dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
tests := []string{
fmt.Sprintf("Bearer %s", validToken),
}
for i, tt := range tests {
r, _ := http.NewRequest("", "", nil)
r.Header.Add("Authorization", tt)
_, err := ExtractBearerToken(r)
if err != nil {
t.Errorf("case %d: want: valid Authorization header, got: invalid Authorization header: %v.", i, err)
}
}
}
func TestNewClaims(t *testing.T) {
issAt := time.Date(2, time.January, 1, 0, 0, 0, 0, time.UTC)
expAt := time.Date(2, time.January, 1, 1, 0, 0, 0, time.UTC)
want := jose.Claims{
"iss": "https://example.com",
"sub": "user-123",
"aud": "client-abc",
"iat": float64(issAt.Unix()),
"exp": float64(expAt.Unix()),
}
got := NewClaims("https://example.com", "user-123", "client-abc", issAt, expAt)
if !reflect.DeepEqual(want, got) {
t.Fatalf("want=%#v got=%#v", want, got)
}
}

View file

@ -25,6 +25,17 @@ func VerifySignature(jwt jose.JWT, keys []key.PublicKey) (bool, error) {
return false, nil return false, nil
} }
// containsString returns true if the given string(needle) is found
// in the string array(haystack).
func containsString(needle string, haystack []string) bool {
for _, v := range haystack {
if v == needle {
return true
}
}
return false
}
// Verify claims in accordance with OIDC spec // Verify claims in accordance with OIDC spec
// http://openid.net/specs/openid-connect-basic-1_0.html#IDTokenValidation // http://openid.net/specs/openid-connect-basic-1_0.html#IDTokenValidation
func VerifyClaims(jwt jose.JWT, issuer, clientID string) error { func VerifyClaims(jwt jose.JWT, issuer, clientID string) error {
@ -45,7 +56,8 @@ func VerifyClaims(jwt jose.JWT, issuer, clientID string) error {
} }
// iss REQUIRED. Issuer Identifier for the Issuer of the response. // iss REQUIRED. Issuer Identifier for the Issuer of the response.
// The iss value is a case sensitive URL using the https scheme that contains scheme, host, and optionally, port number and path components and no query or fragment components. // The iss value is a case sensitive URL using the https scheme that contains scheme,
// host, and optionally, port number and path components and no query or fragment components.
if iss, exists := claims["iss"].(string); exists { if iss, exists := claims["iss"].(string); exists {
if !urlEqual(iss, issuer) { if !urlEqual(iss, issuer) {
return fmt.Errorf("invalid claim value: 'iss'. expected=%s, found=%s.", issuer, iss) return fmt.Errorf("invalid claim value: 'iss'. expected=%s, found=%s.", issuer, iss)
@ -55,19 +67,27 @@ func VerifyClaims(jwt jose.JWT, issuer, clientID string) error {
} }
// iat REQUIRED. Time at which the JWT was issued. // iat REQUIRED. Time at which the JWT was issued.
// Its value is a JSON number representing the number of seconds from 1970-01-01T0:0:0Z as measured in UTC until the date/time. // Its value is a JSON number representing the number of seconds from 1970-01-01T0:0:0Z
// as measured in UTC until the date/time.
if _, exists := claims["iat"].(float64); !exists { if _, exists := claims["iat"].(float64); !exists {
return errors.New("missing claim: 'iat'") return errors.New("missing claim: 'iat'")
} }
// aud REQUIRED. Audience(s) that this ID Token is intended for. // aud REQUIRED. Audience(s) that this ID Token is intended for.
// It MUST contain the OAuth 2.0 client_id of the Relying Party as an audience value. It MAY also contain identifiers for other audiences. In the general case, the aud value is an array of case sensitive strings. In the common special case when there is one audience, the aud value MAY be a single case sensitive string. // It MUST contain the OAuth 2.0 client_id of the Relying Party as an audience value.
if aud, exists := claims["aud"].(string); exists { // It MAY also contain identifiers for other audiences. In the general case, the aud
// value is an array of case sensitive strings. In the common special case when there
// is one audience, the aud value MAY be a single case sensitive string.
if aud, ok, err := claims.StringClaim("aud"); err == nil && ok {
if aud != clientID { if aud != clientID {
return errors.New("invalid claim value: 'aud'") return fmt.Errorf("invalid claims, 'aud' claim and 'client_id' do not match, aud=%s, client_id=%s", aud, clientID)
}
} else if aud, ok, err := claims.StringsClaim("aud"); err == nil && ok {
if !containsString(clientID, aud) {
return fmt.Errorf("invalid claims, cannot find 'client_id' in 'aud' claim, aud=%v, client_id=%s", aud, clientID)
} }
} else { } else {
return errors.New("missing claim: 'aud'") return errors.New("invalid claim value: 'aud' is required, and should be either string or string array")
} }
return nil return nil
@ -97,15 +117,16 @@ func VerifyClientClaims(jwt jose.JWT, issuer string) (string, error) {
return "", errors.New("missing required 'sub' claim") return "", errors.New("missing required 'sub' claim")
} }
aud, ok, err := claims.StringClaim("aud") if aud, ok, err := claims.StringClaim("aud"); err == nil && ok {
if err != nil { if aud != sub {
return "", fmt.Errorf("failed to parse 'aud' claim: %v", err) return "", fmt.Errorf("invalid claims, 'aud' claim and 'sub' claim do not match, aud=%s, sub=%s", aud, sub)
} else if !ok { }
return "", errors.New("missing required 'aud' claim") } else if aud, ok, err := claims.StringsClaim("aud"); err == nil && ok {
} if !containsString(sub, aud) {
return "", fmt.Errorf("invalid claims, cannot find 'sud' in 'aud' claim, aud=%v, sub=%s", aud, sub)
if sub != aud { }
return "", fmt.Errorf("invalid claims, 'aud' claim and 'sub' claim do not match, aud=%s, sub=%s", aud, sub) } else {
return "", errors.New("invalid claim value: 'aud' is required, and should be either string or string array")
} }
now := time.Now().UTC() now := time.Now().UTC()

View file

@ -1,297 +0,0 @@
package oidc
import (
"testing"
"time"
"github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/key"
)
func TestVerifyClientClaims(t *testing.T) {
validIss := "https://example.com"
validClientID := "valid-client"
now := time.Now()
tomorrow := now.Add(24 * time.Hour)
header := jose.JOSEHeader{
jose.HeaderKeyAlgorithm: "test-alg",
jose.HeaderKeyID: "1",
}
tests := []struct {
claims jose.Claims
ok bool
}{
// valid token
{
claims: jose.Claims{
"iss": validIss,
"sub": validClientID,
"aud": validClientID,
"iat": float64(now.Unix()),
"exp": float64(tomorrow.Unix()),
},
ok: true,
},
// missing 'iss' claim
{
claims: jose.Claims{
"sub": validClientID,
"aud": validClientID,
"iat": float64(now.Unix()),
"exp": float64(tomorrow.Unix()),
},
ok: false,
},
// invalid 'iss' claim
{
claims: jose.Claims{
"iss": "INVALID",
"sub": validClientID,
"aud": validClientID,
"iat": float64(now.Unix()),
"exp": float64(tomorrow.Unix()),
},
ok: false,
},
// missing 'sub' claim
{
claims: jose.Claims{
"iss": validIss,
"aud": validClientID,
"iat": float64(now.Unix()),
"exp": float64(tomorrow.Unix()),
},
ok: false,
},
// invalid 'sub' claim
{
claims: jose.Claims{
"iss": validIss,
"sub": "INVALID",
"aud": validClientID,
"iat": float64(now.Unix()),
"exp": float64(tomorrow.Unix()),
},
ok: false,
},
// missing 'aud' claim
{
claims: jose.Claims{
"iss": validIss,
"sub": validClientID,
"iat": float64(now.Unix()),
"exp": float64(tomorrow.Unix()),
},
ok: false,
},
// invalid 'aud' claim
{
claims: jose.Claims{
"iss": validIss,
"sub": validClientID,
"aud": "INVALID",
"iat": float64(now.Unix()),
"exp": float64(tomorrow.Unix()),
},
ok: false,
},
// expired
{
claims: jose.Claims{
"iss": validIss,
"sub": validClientID,
"aud": validClientID,
"iat": float64(now.Unix()),
"exp": float64(now.Unix()),
},
ok: false,
},
}
for i, tt := range tests {
jwt, err := jose.NewJWT(header, tt.claims)
if err != nil {
t.Fatalf("case %d: Failed to generate JWT, error=%v", i, err)
}
got, err := VerifyClientClaims(jwt, validIss)
if tt.ok {
if err != nil {
t.Errorf("case %d: unexpected error, err=%v", i, err)
}
if got != validClientID {
t.Errorf("case %d: incorrect client ID, want=%s, got=%s", i, validClientID, got)
}
} else if err == nil {
t.Errorf("case %d: expected error but err is nil", i)
}
}
}
func TestJWTVerifier(t *testing.T) {
iss := "http://example.com"
now := time.Now()
future12 := now.Add(12 * time.Hour)
past36 := now.Add(-36 * time.Hour)
past12 := now.Add(-12 * time.Hour)
priv1, err := key.GeneratePrivateKey()
if err != nil {
t.Fatalf("failed to generate private key, error=%v", err)
}
pk1 := *key.NewPublicKey(priv1.JWK())
priv2, err := key.GeneratePrivateKey()
if err != nil {
t.Fatalf("failed to generate private key, error=%v", err)
}
pk2 := *key.NewPublicKey(priv2.JWK())
jwtPK1, err := jose.NewSignedJWT(NewClaims(iss, "XXX", "XXX", past12, future12), priv1.Signer())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
jwtPK1BadClaims, err := jose.NewSignedJWT(NewClaims(iss, "XXX", "YYY", past12, future12), priv1.Signer())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
jwtExpired, err := jose.NewSignedJWT(NewClaims(iss, "XXX", "XXX", past36, past12), priv1.Signer())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
jwtPK2, err := jose.NewSignedJWT(NewClaims(iss, "XXX", "XXX", past12, future12), priv2.Signer())
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
tests := []struct {
verifier JWTVerifier
jwt jose.JWT
wantErr bool
}{
// JWT signed with available key
{
verifier: JWTVerifier{
issuer: "example.com",
clientID: "XXX",
syncFunc: func() error { return nil },
keysFunc: func() []key.PublicKey {
return []key.PublicKey{pk1}
},
},
jwt: *jwtPK1,
wantErr: false,
},
// JWT signed with available key, with bad claims
{
verifier: JWTVerifier{
issuer: "example.com",
clientID: "XXX",
syncFunc: func() error { return nil },
keysFunc: func() []key.PublicKey {
return []key.PublicKey{pk1}
},
},
jwt: *jwtPK1BadClaims,
wantErr: true,
},
// expired JWT signed with available key
{
verifier: JWTVerifier{
issuer: "example.com",
clientID: "XXX",
syncFunc: func() error { return nil },
keysFunc: func() []key.PublicKey {
return []key.PublicKey{pk1}
},
},
jwt: *jwtExpired,
wantErr: true,
},
// JWT signed with unrecognized key, verifiable after sync
{
verifier: JWTVerifier{
issuer: "example.com",
clientID: "XXX",
syncFunc: func() error { return nil },
keysFunc: func() func() []key.PublicKey {
var i int
return func() []key.PublicKey {
defer func() { i++ }()
return [][]key.PublicKey{
[]key.PublicKey{pk1},
[]key.PublicKey{pk2},
}[i]
}
}(),
},
jwt: *jwtPK2,
wantErr: false,
},
// JWT signed with unrecognized key, not verifiable after sync
{
verifier: JWTVerifier{
issuer: "example.com",
clientID: "XXX",
syncFunc: func() error { return nil },
keysFunc: func() []key.PublicKey {
return []key.PublicKey{pk1}
},
},
jwt: *jwtPK2,
wantErr: true,
},
// verifier gets no keys from keysFunc, still not verifiable after sync
{
verifier: JWTVerifier{
issuer: "example.com",
clientID: "XXX",
syncFunc: func() error { return nil },
keysFunc: func() []key.PublicKey {
return []key.PublicKey{}
},
},
jwt: *jwtPK1,
wantErr: true,
},
// verifier gets no keys from keysFunc, verifiable after sync
{
verifier: JWTVerifier{
issuer: "example.com",
clientID: "XXX",
syncFunc: func() error { return nil },
keysFunc: func() func() []key.PublicKey {
var i int
return func() []key.PublicKey {
defer func() { i++ }()
return [][]key.PublicKey{
[]key.PublicKey{},
[]key.PublicKey{pk2},
}[i]
}
}(),
},
jwt: *jwtPK2,
wantErr: false,
},
}
for i, tt := range tests {
err := tt.verifier.Verify(tt.jwt)
if tt.wantErr && (err == nil) {
t.Errorf("case %d: wanted non-nil error", i)
} else if !tt.wantErr && (err != nil) {
t.Errorf("case %d: wanted nil error, got %v", i, err)
}
}
}

View file

@ -0,0 +1,161 @@
package connector
import (
"encoding/json"
"fmt"
"html/template"
"net/http"
"net/url"
"path"
chttp "github.com/coreos/go-oidc/http"
"github.com/coreos/go-oidc/oauth2"
"github.com/coreos/go-oidc/oidc"
)
const (
BitbucketConnectorType = "bitbucket"
bitbucketAuthURL = "https://bitbucket.org/site/oauth2/authorize"
bitbucketTokenURL = "https://bitbucket.org/site/oauth2/access_token"
bitbucketAPIUserURL = "https://bitbucket.org/api/2.0/user"
bitbucketAPIEmailURL = "https://api.bitbucket.org/2.0/user/emails"
)
func init() {
RegisterConnectorConfigType(BitbucketConnectorType, func() ConnectorConfig { return &BitbucketConnectorConfig{} })
}
type BitbucketConnectorConfig struct {
ID string `json:"id"`
ClientID string `json:"clientID"`
ClientSecret string `json:"clientSecret"`
}
func (cfg *BitbucketConnectorConfig) ConnectorID() string {
return cfg.ID
}
func (cfg *BitbucketConnectorConfig) ConnectorType() string {
return BitbucketConnectorType
}
func (cfg *BitbucketConnectorConfig) Connector(ns url.URL, lf oidc.LoginFunc, tpls *template.Template) (Connector, error) {
ns.Path = path.Join(ns.Path, httpPathCallback)
oauth2Conn, err := newBitbucketConnector(cfg.ClientID, cfg.ClientSecret, ns.String())
if err != nil {
return nil, err
}
return &OAuth2Connector{
id: cfg.ID,
loginFunc: lf,
cbURL: ns,
conn: oauth2Conn,
}, nil
}
type bitbucketOAuth2Connector struct {
clientID string
clientSecret string
client *oauth2.Client
}
func newBitbucketConnector(clientID, clientSecret, cbURL string) (oauth2Connector, error) {
config := oauth2.Config{
Credentials: oauth2.ClientCredentials{ID: clientID, Secret: clientSecret},
AuthURL: bitbucketAuthURL,
TokenURL: bitbucketTokenURL,
AuthMethod: oauth2.AuthMethodClientSecretPost,
RedirectURL: cbURL,
}
cli, err := oauth2.NewClient(http.DefaultClient, config)
if err != nil {
return nil, err
}
return &bitbucketOAuth2Connector{
clientID: clientID,
clientSecret: clientSecret,
client: cli,
}, nil
}
func (c *bitbucketOAuth2Connector) Client() *oauth2.Client {
return c.client
}
func (c *bitbucketOAuth2Connector) Identity(cli chttp.Client) (oidc.Identity, error) {
var user struct {
UUID string `json:"uuid"`
Username string `json:"username"`
DisplayName string `json:"display_name"`
}
if err := getAndDecode(cli, bitbucketAPIUserURL, &user); err != nil {
return oidc.Identity{}, fmt.Errorf("getting user info: %v", err)
}
name := user.DisplayName
if name == "" {
name = user.Username
}
var emails struct {
Values []struct {
Email string `json:"email"`
Confirmed bool `json:"is_confirmed"`
Primary bool `json:"is_primary"`
} `json:"values"`
}
if err := getAndDecode(cli, bitbucketAPIEmailURL, &emails); err != nil {
return oidc.Identity{}, fmt.Errorf("getting user email: %v", err)
}
email := ""
for _, val := range emails.Values {
if !val.Confirmed {
continue
}
if email == "" || val.Primary {
email = val.Email
}
if val.Primary {
break
}
}
return oidc.Identity{
ID: user.UUID,
Name: name,
Email: email,
}, nil
}
func getAndDecode(cli chttp.Client, url string, v interface{}) error {
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return err
}
resp, err := cli.Do(req)
if err != nil {
return fmt.Errorf("get: %v", err)
}
defer resp.Body.Close()
switch {
case resp.StatusCode >= 400 && resp.StatusCode < 500:
return oauth2.NewError(oauth2.ErrorAccessDenied)
case resp.StatusCode == http.StatusOK:
default:
return fmt.Errorf("unexpected status from providor %s", resp.Status)
}
if err := json.NewDecoder(resp.Body).Decode(v); err != nil {
return fmt.Errorf("decode body: %v", err)
}
return nil
}
func (c *bitbucketOAuth2Connector) Healthy() error {
return nil
}
func (c *bitbucketOAuth2Connector) TrustedEmailProvider() bool {
return false
}

View file

@ -0,0 +1,59 @@
package connector
import (
"net/http"
"testing"
"github.com/coreos/go-oidc/oidc"
)
var bitbucketExampleUser1 = `{
"display_name": "tutorials account",
"username": "tutorials",
"uuid": "{c788b2da-b7a2-404c-9e26-d3f077557007}"
}`
var bitbucketExampleUser2 = `{
"username": "tutorials",
"uuid": "{c788b2da-b7a2-404c-9e26-d3f077557007}"
}`
var bitbucketExampleEmail = `{
"values": [
{"email": "tutorials1@bitbucket.org","is_confirmed": false,"is_primary": false},
{"email": "tutorials2@bitbucket.org","is_confirmed": true,"is_primary": false},
{"email": "tutorials3@bitbucket.org","is_confirmed": true,"is_primary": true}
]
}`
func TestBitBucketIdentity(t *testing.T) {
tests := []oauth2IdentityTest{
{
urlResps: map[string]response{
bitbucketAPIUserURL: {http.StatusOK, bitbucketExampleUser1},
bitbucketAPIEmailURL: {http.StatusOK, bitbucketExampleEmail},
},
want: oidc.Identity{
Name: "tutorials account",
ID: "{c788b2da-b7a2-404c-9e26-d3f077557007}",
Email: "tutorials3@bitbucket.org",
},
},
{
urlResps: map[string]response{
bitbucketAPIUserURL: {http.StatusOK, bitbucketExampleUser2},
bitbucketAPIEmailURL: {http.StatusOK, bitbucketExampleEmail},
},
want: oidc.Identity{
Name: "tutorials",
ID: "{c788b2da-b7a2-404c-9e26-d3f077557007}",
Email: "tutorials3@bitbucket.org",
},
},
}
conn, err := newBitbucketConnector("fakeclientid", "fakeclientsecret", "http://example.com/auth/bitbucket/callback")
if err != nil {
t.Fatal(err)
}
runOAuth2IdentityTests(t, conn, tests)
}

View file

@ -0,0 +1,145 @@
package connector
import (
"encoding/json"
"fmt"
"html/template"
"net/http"
"net/url"
"path"
"strconv"
chttp "github.com/coreos/go-oidc/http"
"github.com/coreos/go-oidc/oauth2"
"github.com/coreos/go-oidc/oidc"
)
const (
GitHubConnectorType = "github"
githubAuthURL = "https://github.com/login/oauth/authorize"
githubTokenURL = "https://github.com/login/oauth/access_token"
githubAPIUserURL = "https://api.github.com/user"
)
func init() {
RegisterConnectorConfigType(GitHubConnectorType, func() ConnectorConfig { return &GitHubConnectorConfig{} })
}
type GitHubConnectorConfig struct {
ID string `json:"id"`
ClientID string `json:"clientID"`
ClientSecret string `json:"clientSecret"`
}
func (cfg *GitHubConnectorConfig) ConnectorID() string {
return cfg.ID
}
func (cfg *GitHubConnectorConfig) ConnectorType() string {
return GitHubConnectorType
}
func (cfg *GitHubConnectorConfig) Connector(ns url.URL, lf oidc.LoginFunc, tpls *template.Template) (Connector, error) {
ns.Path = path.Join(ns.Path, httpPathCallback)
oauth2Conn, err := newGitHubConnector(cfg.ClientID, cfg.ClientSecret, ns.String())
if err != nil {
return nil, err
}
return &OAuth2Connector{
id: cfg.ID,
loginFunc: lf,
cbURL: ns,
conn: oauth2Conn,
}, nil
}
type githubOAuth2Connector struct {
clientID string
clientSecret string
client *oauth2.Client
}
func newGitHubConnector(clientID, clientSecret, cbURL string) (oauth2Connector, error) {
config := oauth2.Config{
Credentials: oauth2.ClientCredentials{ID: clientID, Secret: clientSecret},
AuthURL: githubAuthURL,
TokenURL: githubTokenURL,
Scope: []string{"user:email"},
AuthMethod: oauth2.AuthMethodClientSecretPost,
RedirectURL: cbURL,
}
cli, err := oauth2.NewClient(http.DefaultClient, config)
if err != nil {
return nil, err
}
return &githubOAuth2Connector{
clientID: clientID,
clientSecret: clientSecret,
client: cli,
}, nil
}
// standard error form returned by github
type githubError struct {
Message string `json:"message"`
}
func (err githubError) Error() string {
return fmt.Sprintf("github: %s", err.Message)
}
func (c *githubOAuth2Connector) Client() *oauth2.Client {
return c.client
}
func (c *githubOAuth2Connector) Identity(cli chttp.Client) (oidc.Identity, error) {
req, err := http.NewRequest("GET", githubAPIUserURL, nil)
if err != nil {
return oidc.Identity{}, err
}
resp, err := cli.Do(req)
if err != nil {
return oidc.Identity{}, fmt.Errorf("get: %v", err)
}
defer resp.Body.Close()
switch {
case resp.StatusCode >= 400 && resp.StatusCode < 600:
// attempt to decode error from github
var authErr githubError
if err := json.NewDecoder(resp.Body).Decode(&authErr); err != nil {
return oidc.Identity{}, oauth2.NewError(oauth2.ErrorAccessDenied)
}
return oidc.Identity{}, authErr
case resp.StatusCode == http.StatusOK:
default:
return oidc.Identity{}, fmt.Errorf("unexpected status from providor %s", resp.Status)
}
var user struct {
Login string `json:"login"`
ID int64 `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
}
if err := json.NewDecoder(resp.Body).Decode(&user); err != nil {
return oidc.Identity{}, fmt.Errorf("getting user info: %v", err)
}
name := user.Name
if name == "" {
name = user.Login
}
return oidc.Identity{
ID: strconv.FormatInt(user.ID, 10),
Name: name,
Email: user.Email,
}, nil
}
func (c *githubOAuth2Connector) Healthy() error {
return nil
}
func (c *githubOAuth2Connector) TrustedEmailProvider() bool {
return false
}

View file

@ -0,0 +1,41 @@
package connector
import (
"net/http"
"testing"
"github.com/coreos/go-oidc/oidc"
)
var (
githubExampleUser = `{"login":"octocat","id":1,"name": "monalisa octocat","email": "octocat@github.com"}`
githubExampleError = `{"message":"Bad credentials","documentation_url":"https://developer.github.com/v3"}`
)
func TestGitHubIdentity(t *testing.T) {
tests := []oauth2IdentityTest{
{
urlResps: map[string]response{
githubAPIUserURL: {http.StatusOK, githubExampleUser},
},
want: oidc.Identity{
Name: "monalisa octocat",
ID: "1",
Email: "octocat@github.com",
},
},
{
urlResps: map[string]response{
githubAPIUserURL: {http.StatusUnauthorized, githubExampleError},
},
wantErr: githubError{
Message: "Bad credentials",
},
},
}
conn, err := newGitHubConnector("fakeclientid", "fakeclientsecret", "http://examle.com/auth/github/callback")
if err != nil {
t.Fatal(err)
}
runOAuth2IdentityTests(t, conn, tests)
}

View file

@ -0,0 +1,140 @@
package connector
import (
"net/http"
"net/url"
"strings"
"github.com/coreos/dex/pkg/log"
chttp "github.com/coreos/go-oidc/http"
"github.com/coreos/go-oidc/oauth2"
"github.com/coreos/go-oidc/oidc"
)
type oauth2Connector interface {
Client() *oauth2.Client
// Identity uses a HTTP client authenticated as the end user to construct
// an OIDC identity for that user.
Identity(cli chttp.Client) (oidc.Identity, error)
// Healthy it should attempt to determine if the connector's credientials
// are valid.
Healthy() error
TrustedEmailProvider() bool
}
type OAuth2Connector struct {
id string
loginFunc oidc.LoginFunc
cbURL url.URL
conn oauth2Connector
}
func (c *OAuth2Connector) ID() string {
return c.id
}
func (c *OAuth2Connector) Healthy() error {
return c.conn.Healthy()
}
func (c *OAuth2Connector) Sync() chan struct{} {
stop := make(chan struct{}, 1)
return stop
}
func (c *OAuth2Connector) TrustedEmailProvider() bool {
return c.conn.TrustedEmailProvider()
}
func (c *OAuth2Connector) LoginURL(sessionKey, prompt string) (string, error) {
return c.conn.Client().AuthCodeURL(sessionKey, oauth2.GrantTypeAuthCode, prompt), nil
}
func (c *OAuth2Connector) Register(mux *http.ServeMux, errorURL url.URL) {
mux.Handle(c.cbURL.Path, c.handleCallbackFunc(c.loginFunc, errorURL))
}
func (c *OAuth2Connector) handleCallbackFunc(lf oidc.LoginFunc, errorURL url.URL) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
e := q.Get("error")
if e != "" {
redirectError(w, errorURL, q)
return
}
code := q.Get("code")
if code == "" {
q.Set("error", oauth2.ErrorInvalidRequest)
q.Set("error_description", "code query param must be set")
redirectError(w, errorURL, q)
return
}
sessionKey := q.Get("state")
token, err := c.conn.Client().RequestToken(oauth2.GrantTypeAuthCode, code)
if err != nil {
log.Errorf("Unable to verify auth code with issuer: %v", err)
q.Set("error", oauth2.ErrorUnsupportedResponseType)
q.Set("error_description", "unable to verify auth code with issuer")
redirectError(w, errorURL, q)
return
}
ident, err := c.conn.Identity(newAuthenticatedClient(token, http.DefaultClient))
if err != nil {
log.Errorf("Unable to retrieve identity: %v", err)
q.Set("error", oauth2.ErrorUnsupportedResponseType)
q.Set("error_description", "unable to retrieve identity from issuer")
redirectError(w, errorURL, q)
return
}
redirectURL, err := lf(ident, sessionKey)
if err != nil {
log.Errorf("Unable to log in %#v: %v", ident, err)
q.Set("error", oauth2.ErrorAccessDenied)
q.Set("error_description", "login failed")
redirectError(w, errorURL, q)
return
}
w.Header().Set("Location", redirectURL)
w.WriteHeader(http.StatusTemporaryRedirect)
return
}
}
// authedClient authenticates all requests as the end user.
type authedClient struct {
token oauth2.TokenResponse
cli chttp.Client
}
func newAuthenticatedClient(token oauth2.TokenResponse, cli chttp.Client) chttp.Client {
return &authedClient{token, cli}
}
func (c *authedClient) Do(req *http.Request) (*http.Response, error) {
req.Header.Set("Authorization", tokenType(c.token)+" "+c.token.AccessToken)
return c.cli.Do(req)
}
// Return the canonical name of the token type if non-empty, else "Bearer".
// Take from golang.org/x/oauth2
func tokenType(token oauth2.TokenResponse) string {
if strings.EqualFold(token.TokenType, "bearer") {
return "Bearer"
}
if strings.EqualFold(token.TokenType, "mac") {
return "MAC"
}
if strings.EqualFold(token.TokenType, "basic") {
return "Basic"
}
if token.TokenType != "" {
return token.TokenType
}
return "Bearer"
}

View file

@ -0,0 +1,63 @@
package connector
import (
"fmt"
"io/ioutil"
"net/http"
"strings"
"testing"
"github.com/coreos/go-oidc/oidc"
"github.com/kylelemons/godebug/pretty"
)
type response struct {
statusCode int
body string
}
type oauth2IdentityTest struct {
urlResps map[string]response
want oidc.Identity
wantErr error
}
type fakeClient func(*http.Request) (*http.Response, error)
// implement github.com/coreos/go-oidc/oauth2.Client
func (f fakeClient) Do(r *http.Request) (*http.Response, error) {
return f(r)
}
func runOAuth2IdentityTests(t *testing.T, conn oauth2Connector, tests []oauth2IdentityTest) {
for i, tt := range tests {
f := func(req *http.Request) (*http.Response, error) {
resp, ok := tt.urlResps[req.URL.String()]
if !ok {
return nil, fmt.Errorf("unexpected request URL: %s", req.URL.String())
}
return &http.Response{
StatusCode: resp.statusCode,
Body: ioutil.NopCloser(strings.NewReader(resp.body)),
}, nil
}
got, err := conn.Identity(fakeClient(f))
if tt.wantErr == nil {
if err != nil {
t.Errorf("case %d: failed to get identity=%v", i, err)
continue
}
if diff := pretty.Compare(tt.want, got); diff != "" {
t.Errorf("case %d: Compare(want, got) = %v", i, diff)
}
} else {
if err == nil {
t.Errorf("case %d: want error=%v, got=<nil>", i, tt.wantErr)
continue
}
if diff := pretty.Compare(tt.wantErr, err); diff != "" {
t.Errorf("case %d: Compare(wantErr, gotErr) = %v", i, diff)
}
}
}
}

View file

@ -123,8 +123,10 @@ type templateData struct {
// TODO(sym3tri): store this with the connector config // TODO(sym3tri): store this with the connector config
var connectorDisplayNameMap = map[string]string{ var connectorDisplayNameMap = map[string]string{
"google": "Google", "google": "Google",
"local": "Email", "local": "Email",
"github": "GitHub",
"bitbucket": "Bitbucket",
} }
func execTemplate(w http.ResponseWriter, tpl *template.Template, data interface{}) { func execTemplate(w http.ResponseWriter, tpl *template.Template, data interface{}) {

View file

@ -19,5 +19,17 @@
"issuerURL": "https://accounts.google.com", "issuerURL": "https://accounts.google.com",
"clientID": "${CLIENT_ID}", "clientID": "${CLIENT_ID}",
"clientSecret": "${CLIENT_SECRET}" "clientSecret": "${CLIENT_SECRET}"
},
{
"type": "github",
"id": "github",
"clientID": "${CLIENT_ID}",
"clientSecret": "${CLIENT_SECRET}"
},
{
"type": "bitbucket",
"id": "bitbucket",
"clientID": "${CLIENT_ID}",
"clientSecret": "${CLIENT_SECRET}"
} }
] ]

View file

@ -144,6 +144,14 @@
/* B&W CoreOS SVG logo */ /* B&W CoreOS SVG logo */
background-image: url(); background-image: url();
} }
.btn-icon-github {
background-color: #F5F5F5;
background-image: url();
}
.btn-icon-bitbucket {
background-color: #205081;
background-image: url();
}
.btn-text { .btn-text {
line-height: 36px; line-height: 36px;
padding: 6px 12px; padding: 6px 12px;