diff --git a/oauth2.go b/oauth2.go index de34feb84..4628f97f4 100644 --- a/oauth2.go +++ b/oauth2.go @@ -242,14 +242,14 @@ func (c *Config) Client(ctx context.Context, t *Token) *http.Client { return NewClient(ctx, c.TokenSource(ctx, t)) } -// TokenSource returns a [TokenSource] that returns t until t expires, -// automatically refreshing it as necessary using the provided context. +// TokenSourceWithOptions returns a [TokenSource] that returns t until t expires, // -// Most users will use [Config.Client] instead. -func (c *Config) TokenSource(ctx context.Context, t *Token) TokenSource { +// This method provides a way to pass options to the token source. +func (c *Config) TokenSourceWithOptions(ctx context.Context, t *Token, opts ...AuthCodeOption) TokenSource { tkr := &tokenRefresher{ ctx: ctx, conf: c, + opts: opts, } if t != nil { tkr.refreshToken = t.RefreshToken @@ -260,12 +260,21 @@ func (c *Config) TokenSource(ctx context.Context, t *Token) TokenSource { } } +// TokenSource returns a [TokenSource] that returns t until t expires, +// automatically refreshing it as necessary using the provided context. +// +// Most users will use [Config.Client] instead. +func (c *Config) TokenSource(ctx context.Context, t *Token) TokenSource { + return c.TokenSourceWithOptions(ctx, t) +} + // tokenRefresher is a TokenSource that makes "grant_type=refresh_token" // HTTP requests to renew a token using a RefreshToken. type tokenRefresher struct { ctx context.Context // used to get HTTP requests conf *Config refreshToken string + opts []AuthCodeOption } // WARNING: Token is not safe for concurrent access, as it @@ -277,10 +286,15 @@ func (tf *tokenRefresher) Token() (*Token, error) { return nil, errors.New("oauth2: token expired and refresh token is not set") } - tk, err := retrieveToken(tf.ctx, tf.conf, url.Values{ + v := url.Values{ "grant_type": {"refresh_token"}, "refresh_token": {tf.refreshToken}, - }) + } + for _, opt := range tf.opts { + opt.setValue(v) + } + + tk, err := retrieveToken(tf.ctx, tf.conf, v) if err != nil { return nil, err diff --git a/oauth2_test.go b/oauth2_test.go index 1cc14c644..d32f4ee69 100644 --- a/oauth2_test.go +++ b/oauth2_test.go @@ -551,6 +551,32 @@ func TestRefreshToken_RefreshTokenReplacement(t *testing.T) { } } +func TestRefreshToken_RefreshWithOpts(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("Failed reading request body: %s.", err) + } + if string(body) != "foo=bar&grant_type=refresh_token&refresh_token=OLD_REFRESH_TOKEN" { + t.Errorf("Unexpected exchange payload; got %q", body) + } + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"access_token":"ACCESS_TOKEN", "scope": "user", "token_type": "bearer", "refresh_token": "NEW_REFRESH_TOKEN"}`)) + return + })) + defer ts.Close() + conf := newConf(ts.URL) + tkr := conf.TokenSourceWithOptions(context.Background(), &Token{RefreshToken: "OLD_REFRESH_TOKEN"}, SetAuthURLParam("foo", "bar")) + tk, err := tkr.Token() + if err != nil { + t.Errorf("got err = %v; want none", err) + return + } + if want := "NEW_REFRESH_TOKEN"; tk.RefreshToken != want { + t.Errorf("RefreshToken = %q; want %q", tk.RefreshToken, want) + } +} + func TestRefreshToken_RefreshTokenPreservation(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json")