Skip to content

Commit

Permalink
oauth2: close request body if errors occur before base RoundTripper i…
Browse files Browse the repository at this point in the history
…s invoked

Fixes golang/oauth#269

Change-Id: I25eb3273a0868a999a2e98961ae5e4040e44ad7a
Reviewed-on: https://go-review.googlesource.com/114956
Reviewed-by: Brad Fitzpatrick <[email protected]>
  • Loading branch information
Tim Cooper authored and bradfitz committed May 29, 2018
1 parent bee4e0a commit 30b72df
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 0 deletions.
13 changes: 13 additions & 0 deletions transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ type Transport struct {
// access token. If no token exists or token is expired,
// tries to refresh/fetch a new token.
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
reqBodyClosed := false
if req.Body != nil {
defer func() {
if !reqBodyClosed {
req.Body.Close()
}
}()
}

if t.Source == nil {
return nil, errors.New("oauth2: Transport's Source is nil")
}
Expand All @@ -46,6 +55,10 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
token.SetAuthHeader(req2)
t.setModReq(req, req2)
res, err := t.base().RoundTrip(req2)

// req.Body is assumed to have been closed by the base RoundTripper.
reqBodyClosed = true

if err != nil {
t.setModReq(req, nil)
return nil, err
Expand Down
60 changes: 60 additions & 0 deletions transport_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package oauth2

import (
"errors"
"io"
"net/http"
"net/http/httptest"
"testing"
Expand All @@ -27,6 +29,64 @@ func TestTransportNilTokenSource(t *testing.T) {
}
}

type readCloseCounter struct {
CloseCount int
ReadErr error
}

func (r *readCloseCounter) Read(b []byte) (int, error) {
return 0, r.ReadErr
}

func (r *readCloseCounter) Close() error {
r.CloseCount++
return nil
}

func TestTransportCloseRequestBody(t *testing.T) {
tr := &Transport{}
server := newMockServer(func(w http.ResponseWriter, r *http.Request) {})
defer server.Close()
client := &http.Client{Transport: tr}
body := &readCloseCounter{
ReadErr: errors.New("readCloseCounter.Read not implemented"),
}
resp, err := client.Post(server.URL, "application/json", body)
if err == nil {
t.Errorf("got no errors, want an error with nil token source")
}
if resp != nil {
t.Errorf("Response = %v; want nil", resp)
}
if expected := 1; body.CloseCount != expected {
t.Errorf("Body was closed %d times, expected %d", body.CloseCount, expected)
}
}

func TestTransportCloseRequestBodySuccess(t *testing.T) {
tr := &Transport{
Source: StaticTokenSource(&Token{
AccessToken: "abc",
}),
}
server := newMockServer(func(w http.ResponseWriter, r *http.Request) {})
defer server.Close()
client := &http.Client{Transport: tr}
body := &readCloseCounter{
ReadErr: io.EOF,
}
resp, err := client.Post(server.URL, "application/json", body)
if err != nil {
t.Errorf("got error %v; expected none", err)
}
if resp == nil {
t.Errorf("Response is nil; expected non-nil")
}
if expected := 1; body.CloseCount != expected {
t.Errorf("Body was closed %d times, expected %d", body.CloseCount, expected)
}
}

func TestTransportTokenSource(t *testing.T) {
ts := &tokenSource{
token: &Token{
Expand Down

0 comments on commit 30b72df

Please sign in to comment.