Skip to content

Commit 2e15e54

Browse files
authored
Merge pull request coreos#162 from gotwarlost/athash
add id token support to verify access token hashes, fixes coreos#126
2 parents 77e7f20 + d836fe7 commit 2e15e54

File tree

4 files changed

+182
-9
lines changed

4 files changed

+182
-9
lines changed

oidc.go

+46
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,13 @@ package oidc
33

44
import (
55
"context"
6+
"crypto/sha256"
7+
"crypto/sha512"
8+
"encoding/base64"
69
"encoding/json"
710
"errors"
811
"fmt"
12+
"hash"
913
"io/ioutil"
1014
"mime"
1115
"net/http"
@@ -31,6 +35,11 @@ const (
3135
ScopeOfflineAccess = "offline_access"
3236
)
3337

38+
var (
39+
errNoAtHash = errors.New("id token did not have an access token hash")
40+
errInvalidAtHash = errors.New("access token hash does not match value in ID token")
41+
)
42+
3443
// ClientContext returns a new Context that carries the provided HTTP client.
3544
//
3645
// This method sets the same context key used by the golang.org/x/oauth2 package,
@@ -242,6 +251,14 @@ type IDToken struct {
242251
// and it's the user's responsibility to ensure it contains a valid value.
243252
Nonce string
244253

254+
// at_hash claim, if set in the ID token. Callers can verify an access token
255+
// that corresponds to the ID token using the VerifyAccessToken method.
256+
AccessTokenHash string
257+
258+
// signature algorithm used for ID token, needed to compute a verification hash of an
259+
// access token
260+
sigAlgorithm string
261+
245262
// Raw payload of the id_token.
246263
claims []byte
247264
}
@@ -267,13 +284,42 @@ func (i *IDToken) Claims(v interface{}) error {
267284
return json.Unmarshal(i.claims, v)
268285
}
269286

287+
// VerifyAccessToken verifies that the hash of the access token that corresponds to the iD token
288+
// matches the hash in the id token. It returns an error if the hashes don't match.
289+
// It is the caller's responsibility to ensure that the optional access token hash is present for the ID token
290+
// before calling this method. See https://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken
291+
func (i *IDToken) VerifyAccessToken(accessToken string) error {
292+
if i.AccessTokenHash == "" {
293+
return errNoAtHash
294+
}
295+
var h hash.Hash
296+
switch i.sigAlgorithm {
297+
case RS256, ES256, PS256:
298+
h = sha256.New()
299+
case RS384, ES384, PS384:
300+
h = sha512.New384()
301+
case RS512, ES512, PS512:
302+
h = sha512.New()
303+
default:
304+
return fmt.Errorf("oidc: unsupported signing algorithm %q", i.sigAlgorithm)
305+
}
306+
h.Write([]byte(accessToken)) // hash documents that Write will never return an error
307+
sum := h.Sum(nil)[:h.Size()/2]
308+
actual := base64.RawURLEncoding.EncodeToString(sum)
309+
if actual != i.AccessTokenHash {
310+
return errInvalidAtHash
311+
}
312+
return nil
313+
}
314+
270315
type idToken struct {
271316
Issuer string `json:"iss"`
272317
Subject string `json:"sub"`
273318
Audience audience `json:"aud"`
274319
Expiry jsonTime `json:"exp"`
275320
IssuedAt jsonTime `json:"iat"`
276321
Nonce string `json:"nonce"`
322+
AtHash string `json:"at_hash"`
277323
}
278324

279325
type audience []string

oidc_test.go

+95
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
package oidc
2+
3+
import (
4+
"fmt"
5+
"testing"
6+
)
7+
8+
const (
9+
// at_hash value and access_token returned by Google.
10+
googleAccessTokenHash = "piwt8oCH-K2D9pXlaS1Y-w"
11+
googleAccessToken = "ya29.CjHSA1l5WUn8xZ6HanHFzzdHdbXm-14rxnC7JHch9eFIsZkQEGoWzaYG4o7k5f6BnPLj"
12+
googleSigningAlg = RS256
13+
// following values computed by own algo for regression testing
14+
computed384TokenHash = "_ILKVQjbEzFKNJjUKC2kz9eReYi0A9Of"
15+
computed512TokenHash = "Spa_APgwBrarSeQbxI-rbragXho6dqFpH5x9PqaPfUI"
16+
)
17+
18+
type accessTokenTest struct {
19+
name string
20+
tok *IDToken
21+
accessToken string
22+
verifier func(err error) error
23+
}
24+
25+
func (a accessTokenTest) run(t *testing.T) {
26+
err := a.tok.VerifyAccessToken(a.accessToken)
27+
result := a.verifier(err)
28+
if result != nil {
29+
t.Error(result)
30+
}
31+
}
32+
33+
func TestAccessTokenVerification(t *testing.T) {
34+
newToken := func(alg, atHash string) *IDToken {
35+
return &IDToken{sigAlgorithm: alg, AccessTokenHash: atHash}
36+
}
37+
assertNil := func(err error) error {
38+
if err != nil {
39+
return fmt.Errorf("want nil error, got %v", err)
40+
}
41+
return nil
42+
}
43+
assertMsg := func(msg string) func(err error) error {
44+
return func(err error) error {
45+
if err == nil {
46+
return fmt.Errorf("expected error, got success")
47+
}
48+
if err.Error() != msg {
49+
return fmt.Errorf("bad error message, %q, (want %q)", err.Error(), msg)
50+
}
51+
return nil
52+
}
53+
}
54+
tests := []accessTokenTest{
55+
{
56+
"goodRS256",
57+
newToken(googleSigningAlg, googleAccessTokenHash),
58+
googleAccessToken,
59+
assertNil,
60+
},
61+
{
62+
"goodES384",
63+
newToken("ES384", computed384TokenHash),
64+
googleAccessToken,
65+
assertNil,
66+
},
67+
{
68+
"goodPS512",
69+
newToken("PS512", computed512TokenHash),
70+
googleAccessToken,
71+
assertNil,
72+
},
73+
{
74+
"badRS256",
75+
newToken("RS256", computed512TokenHash),
76+
googleAccessToken,
77+
assertMsg("access token hash does not match value in ID token"),
78+
},
79+
{
80+
"nohash",
81+
newToken("RS256", ""),
82+
googleAccessToken,
83+
assertMsg("id token did not have an access token hash"),
84+
},
85+
{
86+
"badSignAlgo",
87+
newToken("none", "xxx"),
88+
googleAccessToken,
89+
assertMsg(`oidc: unsupported signing algorithm "none"`),
90+
},
91+
}
92+
for _, test := range tests {
93+
t.Run(test.name, test.run)
94+
}
95+
}

verify.go

+9-7
Original file line numberDiff line numberDiff line change
@@ -170,13 +170,14 @@ func (v *IDTokenVerifier) Verify(ctx context.Context, rawIDToken string) (*IDTok
170170
}
171171

172172
t := &IDToken{
173-
Issuer: token.Issuer,
174-
Subject: token.Subject,
175-
Audience: []string(token.Audience),
176-
Expiry: time.Time(token.Expiry),
177-
IssuedAt: time.Time(token.IssuedAt),
178-
Nonce: token.Nonce,
179-
claims: payload,
173+
Issuer: token.Issuer,
174+
Subject: token.Subject,
175+
Audience: []string(token.Audience),
176+
Expiry: time.Time(token.Expiry),
177+
IssuedAt: time.Time(token.IssuedAt),
178+
Nonce: token.Nonce,
179+
AccessTokenHash: token.AtHash,
180+
claims: payload,
180181
}
181182

182183
// Check issuer.
@@ -228,6 +229,7 @@ func (v *IDTokenVerifier) Verify(ctx context.Context, rawIDToken string) (*IDTok
228229
if len(v.config.SupportedSigningAlgs) != 0 && !contains(v.config.SupportedSigningAlgs, sig.Header.Algorithm) {
229230
return nil, fmt.Errorf("oidc: id token signed with unsupported algorithm, expected %q got %q", v.config.SupportedSigningAlgs, sig.Header.Algorithm)
230231
}
232+
t.sigAlgorithm = sig.Header.Algorithm
231233

232234
gotPayload, err := v.keySet.VerifySignature(ctx, rawIDToken)
233235
if err != nil {

verify_test.go

+32-2
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,30 @@ func TestVerifySigningAlg(t *testing.T) {
192192
}
193193
}
194194

195+
func TestAccessTokenHash(t *testing.T) {
196+
atHash := "piwt8oCH-K2D9pXlaS1Y-w"
197+
vt := verificationTest{
198+
name: "preserves token hash and sig algo",
199+
idToken: `{"iss":"https://foo","aud":"client1", "at_hash": "` + atHash + `"}`,
200+
config: Config{
201+
ClientID: "client1",
202+
SkipExpiryCheck: true,
203+
},
204+
signKey: newRSAKey(t),
205+
}
206+
t.Run("at_hash", func(t *testing.T) {
207+
tok := vt.runGetToken(t)
208+
if tok != nil {
209+
if tok.AccessTokenHash != atHash {
210+
t.Errorf("access token hash not preserved correctly, want %q got %q", atHash, tok.AccessTokenHash)
211+
}
212+
if tok.sigAlgorithm != RS256 {
213+
t.Errorf("invalid signature algo, want %q got %q", RS256, tok.sigAlgorithm)
214+
}
215+
}
216+
})
217+
}
218+
195219
type verificationTest struct {
196220
// Name of the subtest.
197221
name string
@@ -212,7 +236,7 @@ type verificationTest struct {
212236
wantErr bool
213237
}
214238

215-
func (v verificationTest) run(t *testing.T) {
239+
func (v verificationTest) runGetToken(t *testing.T) *IDToken {
216240
token := v.signKey.sign(t, []byte(v.idToken))
217241

218242
ctx, cancel := context.WithCancel(context.Background())
@@ -230,7 +254,8 @@ func (v verificationTest) run(t *testing.T) {
230254
}
231255
verifier := newVerifier(ks, &v.config, issuer)
232256

233-
if _, err := verifier.Verify(ctx, token); err != nil {
257+
idToken, err := verifier.Verify(ctx, token)
258+
if err != nil {
234259
if !v.wantErr {
235260
t.Errorf("%s: verify %v", v.name, err)
236261
}
@@ -239,4 +264,9 @@ func (v verificationTest) run(t *testing.T) {
239264
t.Errorf("%s: expected error", v.name)
240265
}
241266
}
267+
return idToken
268+
}
269+
270+
func (v verificationTest) run(t *testing.T) {
271+
v.runGetToken(t)
242272
}

0 commit comments

Comments
 (0)