Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CustomTokenSource for custom token validation #396

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 55 additions & 45 deletions oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,10 +244,7 @@ func (c *Config) TokenSource(ctx context.Context, t *Token) TokenSource {
if t != nil {
tkr.refreshToken = t.RefreshToken
}
return &reuseTokenSource{
t: t,
new: tkr,
}
return ReuseTokenSource(t, tkr)
}

// tokenRefresher is a TokenSource that makes "grant_type"=="refresh_token"
Expand Down Expand Up @@ -281,24 +278,64 @@ func (tf *tokenRefresher) Token() (*Token, error) {
return tk, err
}

// reuseTokenSource is a TokenSource that holds a single token in memory
// and validates its expiry before each call to retrieve it with
// Token. If it's expired, it will be auto-refreshed using the
// new TokenSource.
type reuseTokenSource struct {
new TokenSource // called when t is expired.
// StaticTokenSource returns a TokenSource that always returns the same token.
// Because the provided token t is never refreshed, StaticTokenSource is only
// useful for tokens that never expire.
func StaticTokenSource(t *Token) TokenSource {
return staticTokenSource{t}
}

// staticTokenSource is a TokenSource that always returns the same Token.
type staticTokenSource struct {
t *Token
}

func (s staticTokenSource) Token() (*Token, error) {
return s.t, nil
}

// ValidFunc should return false when the passed token is invalid and true when
// the token is valid. ValidFunc should NOT modify the token passed to it.
type ValidFunc func(t *Token) bool

// CustomTokenSource returns a TokenSource which repeatedly returns the
// same token as long as ValidFunc returns true, starting with t.
// When ValidFunc returns false (the cached token is invalid), a new token
// is obtained from src.
func CustomTokenSource(t *Token, src TokenSource, validFunc ValidFunc) TokenSource {
// Don't wrap a customTokenSource in itself. That would work,
// but cause an unnecessary number of mutex operations.
// Just build the equivalent one.
if rt, ok := src.(*customTokenSource); ok {
if t == nil {
// Just use it directly.
return rt
}
src = rt.new
}
return &customTokenSource{
t: t,
new: src,
validFunc: validFunc,
}
}

type customTokenSource struct {
validFunc ValidFunc // used for determining whether the token should be refreshed
new TokenSource // called when validFunc returns invalid

mu sync.Mutex // guards t
t *Token
}

// Token returns the current token if it's still valid, else will
// refresh the current token (using r.Context for HTTP client
// information) and return the new one.
func (s *reuseTokenSource) Token() (*Token, error) {
// Token returns a TokenSource that will return the current token so long as
// ValidFunc returns that the token is valid, otherwise it will refresh the
// current token (using r.Context for HTTP client information) and return the
// new one.
func (s *customTokenSource) Token() (*Token, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.t.Valid() {
if s.validFunc(s.t) {
return s.t, nil
}
t, err := s.new.Token()
Expand All @@ -309,22 +346,6 @@ func (s *reuseTokenSource) Token() (*Token, error) {
return t, nil
}

// StaticTokenSource returns a TokenSource that always returns the same token.
// Because the provided token t is never refreshed, StaticTokenSource is only
// useful for tokens that never expire.
func StaticTokenSource(t *Token) TokenSource {
return staticTokenSource{t}
}

// staticTokenSource is a TokenSource that always returns the same Token.
type staticTokenSource struct {
t *Token
}

func (s staticTokenSource) Token() (*Token, error) {
return s.t, nil
}

// HTTPClient is the context key to use with golang.org/x/net/context's
// WithValue function to associate an *http.Client value with a context.
var HTTPClient internal.ContextKey
Expand Down Expand Up @@ -364,18 +385,7 @@ func NewClient(ctx context.Context, src TokenSource) *http.Client {
// means it's always safe to wrap ReuseTokenSource around any other
// TokenSource without adverse effects.
func ReuseTokenSource(t *Token, src TokenSource) TokenSource {
// Don't wrap a reuseTokenSource in itself. That would work,
// but cause an unnecessary number of mutex operations.
// Just build the equivalent one.
if rt, ok := src.(*reuseTokenSource); ok {
if t == nil {
// Just use it directly.
return rt
}
src = rt.new
}
return &reuseTokenSource{
t: t,
new: src,
}
return CustomTokenSource(t, src, func(t *Token) bool {
return t.Valid()
})
}
56 changes: 56 additions & 0 deletions oauth2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -565,3 +565,59 @@ func TestConfigClientWithToken(t *testing.T) {
t.Error(err)
}
}

type mockTokenSource struct {
nextToken *Token
}

func (s mockTokenSource) Token() (*Token, error) {
return s.nextToken, nil
}

func TestCustomTokenSource(t *testing.T) {
foobarToken := &Token{AccessToken: "foobar"}
barbazToken := &Token{AccessToken: "barbaz"}

testCases := []struct {
name string
t *Token
src TokenSource
validToken bool
expectedToken *Token
}{
{
name: "invalid token",
t: foobarToken,
src: mockTokenSource{nextToken: barbazToken},
validToken: false,
expectedToken: barbazToken,
},
{
name: "valid token",
t: foobarToken,
src: mockTokenSource{nextToken: barbazToken},
validToken: true,
expectedToken: foobarToken,
},
}

for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
validFunc := func(t *Token) bool { return tt.validToken }
ts := CustomTokenSource(tt.t, tt.src, validFunc)

// the same expected token should always be returned no matter how many iterations
// we go through since the validfunc returns a constant value
for i := 0; i < 3; i++ {
newToken, err := ts.Token()
if err != nil {
t.Errorf("did not expect an error but got: %v", err)
}

if tt.expectedToken != newToken {
t.Errorf("expected token %v, but got %v", tt.expectedToken, newToken)
}
}
})
}
}