diff --git a/oauth2.go b/oauth2.go index 291df5c83..bf56b9cde 100644 --- a/oauth2.go +++ b/oauth2.go @@ -244,10 +244,7 @@ func (c *Config) TokenSource(ctx context.Context, t *Token) TokenSource { if t != nil { tkr.refreshToken = t.RefreshToken } - return &reuseTokenSource{ - t: t, - new: tkr, - } + return ReuseTokenSource(t, tkr) } // tokenRefresher is a TokenSource that makes "grant_type"=="refresh_token" @@ -281,24 +278,64 @@ func (tf *tokenRefresher) Token() (*Token, error) { return tk, err } -// reuseTokenSource is a TokenSource that holds a single token in memory -// and validates its expiry before each call to retrieve it with -// Token. If it's expired, it will be auto-refreshed using the -// new TokenSource. -type reuseTokenSource struct { - new TokenSource // called when t is expired. +// StaticTokenSource returns a TokenSource that always returns the same token. +// Because the provided token t is never refreshed, StaticTokenSource is only +// useful for tokens that never expire. +func StaticTokenSource(t *Token) TokenSource { + return staticTokenSource{t} +} + +// staticTokenSource is a TokenSource that always returns the same Token. +type staticTokenSource struct { + t *Token +} + +func (s staticTokenSource) Token() (*Token, error) { + return s.t, nil +} + +// ValidFunc should return false when the passed token is invalid and true when +// the token is valid. ValidFunc should NOT modify the token passed to it. +type ValidFunc func(t *Token) bool + +// CustomTokenSource returns a TokenSource which repeatedly returns the +// same token as long as ValidFunc returns true, starting with t. +// When ValidFunc returns false (the cached token is invalid), a new token +// is obtained from src. +func CustomTokenSource(t *Token, src TokenSource, validFunc ValidFunc) TokenSource { + // Don't wrap a customTokenSource in itself. That would work, + // but cause an unnecessary number of mutex operations. + // Just build the equivalent one. + if rt, ok := src.(*customTokenSource); ok { + if t == nil { + // Just use it directly. + return rt + } + src = rt.new + } + return &customTokenSource{ + t: t, + new: src, + validFunc: validFunc, + } +} + +type customTokenSource struct { + validFunc ValidFunc // used for determining whether the token should be refreshed + new TokenSource // called when validFunc returns invalid mu sync.Mutex // guards t t *Token } -// Token returns the current token if it's still valid, else will -// refresh the current token (using r.Context for HTTP client -// information) and return the new one. -func (s *reuseTokenSource) Token() (*Token, error) { +// Token returns a TokenSource that will return the current token so long as +// ValidFunc returns that the token is valid, otherwise it will refresh the +// current token (using r.Context for HTTP client information) and return the +// new one. +func (s *customTokenSource) Token() (*Token, error) { s.mu.Lock() defer s.mu.Unlock() - if s.t.Valid() { + if s.validFunc(s.t) { return s.t, nil } t, err := s.new.Token() @@ -309,22 +346,6 @@ func (s *reuseTokenSource) Token() (*Token, error) { return t, nil } -// StaticTokenSource returns a TokenSource that always returns the same token. -// Because the provided token t is never refreshed, StaticTokenSource is only -// useful for tokens that never expire. -func StaticTokenSource(t *Token) TokenSource { - return staticTokenSource{t} -} - -// staticTokenSource is a TokenSource that always returns the same Token. -type staticTokenSource struct { - t *Token -} - -func (s staticTokenSource) Token() (*Token, error) { - return s.t, nil -} - // HTTPClient is the context key to use with golang.org/x/net/context's // WithValue function to associate an *http.Client value with a context. var HTTPClient internal.ContextKey @@ -364,18 +385,7 @@ func NewClient(ctx context.Context, src TokenSource) *http.Client { // means it's always safe to wrap ReuseTokenSource around any other // TokenSource without adverse effects. func ReuseTokenSource(t *Token, src TokenSource) TokenSource { - // Don't wrap a reuseTokenSource in itself. That would work, - // but cause an unnecessary number of mutex operations. - // Just build the equivalent one. - if rt, ok := src.(*reuseTokenSource); ok { - if t == nil { - // Just use it directly. - return rt - } - src = rt.new - } - return &reuseTokenSource{ - t: t, - new: src, - } + return CustomTokenSource(t, src, func(t *Token) bool { + return t.Valid() + }) } diff --git a/oauth2_test.go b/oauth2_test.go index b7975e166..9603aadc3 100644 --- a/oauth2_test.go +++ b/oauth2_test.go @@ -565,3 +565,59 @@ func TestConfigClientWithToken(t *testing.T) { t.Error(err) } } + +type mockTokenSource struct { + nextToken *Token +} + +func (s mockTokenSource) Token() (*Token, error) { + return s.nextToken, nil +} + +func TestCustomTokenSource(t *testing.T) { + foobarToken := &Token{AccessToken: "foobar"} + barbazToken := &Token{AccessToken: "barbaz"} + + testCases := []struct { + name string + t *Token + src TokenSource + validToken bool + expectedToken *Token + }{ + { + name: "invalid token", + t: foobarToken, + src: mockTokenSource{nextToken: barbazToken}, + validToken: false, + expectedToken: barbazToken, + }, + { + name: "valid token", + t: foobarToken, + src: mockTokenSource{nextToken: barbazToken}, + validToken: true, + expectedToken: foobarToken, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + validFunc := func(t *Token) bool { return tt.validToken } + ts := CustomTokenSource(tt.t, tt.src, validFunc) + + // the same expected token should always be returned no matter how many iterations + // we go through since the validfunc returns a constant value + for i := 0; i < 3; i++ { + newToken, err := ts.Token() + if err != nil { + t.Errorf("did not expect an error but got: %v", err) + } + + if tt.expectedToken != newToken { + t.Errorf("expected token %v, but got %v", tt.expectedToken, newToken) + } + } + }) + } +}