Skip to content

Add HTTP status check before parsing token responses #562

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions requests_oauthlib/oauth2_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from oauthlib.oauth2 import WebApplicationClient, InsecureTransportError
from oauthlib.oauth2 import LegacyApplicationClient
from oauthlib.oauth2 import TokenExpiredError, is_secure_transport
from oauthlib.oauth2.rfc6749.errors import CustomOAuth2Error
import requests

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -199,6 +200,17 @@ def authorization_url(self, url, state=None, **kwargs):
state,
)

def validate_token_response(self, r):
message = ""
try:
r.raise_for_status()
except requests.HTTPError as e:
message = str(e)
if r.text:
message += f"\nBody: {r.text}"
if message:
raise CustomOAuth2Error('Response error', message, uri=r.request.url, status_code=r.status_code)

def fetch_token(
self,
token_url,
Expand Down Expand Up @@ -403,6 +415,7 @@ def fetch_token(
log.debug("Invoking hook %s.", hook)
r = hook(r)

self.validate_token_response(r)
self._client.parse_request_body_response(r.text, scope=self.scope)
self.token = self._client.token
log.debug("Obtained token %s.", self.token)
Expand Down Expand Up @@ -493,6 +506,7 @@ def refresh_token(
log.debug("Invoking hook %s.", hook)
r = hook(r)

self.validate_token_response(r)
self.token = self._client.parse_request_body_response(r.text, scope=self.scope)
if "refresh_token" not in self.token:
log.debug("No new refresh token given. Re-using old.")
Expand Down
58 changes: 58 additions & 0 deletions tests/test_oauth2_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from oauthlib.common import urlencode
from oauthlib.oauth2 import TokenExpiredError, OAuth2Error
from oauthlib.oauth2 import MismatchingStateError
from oauthlib.oauth2.rfc6749.errors import CustomOAuth2Error
from oauthlib.oauth2 import WebApplicationClient, MobileApplicationClient
from oauthlib.oauth2 import LegacyApplicationClient, BackendApplicationClient
from requests_oauthlib import OAuth2Session, TokenUpdated
Expand Down Expand Up @@ -524,6 +525,63 @@ def fake_send(r, **kwargs):
sess.fetch_token(url)
self.assertTrue(sess.authorized)

def test_fetch_token_http_error_handling(self):
"""Test that HTTP errors are properly raised instead of parsing error responses as tokens."""
url = "https://example.com/token"

# Test 400 error response (like the original issue)
error_400 = {"messages": [{"logLevel": "Error", "text": "User not Found"}], "status": "Forbidden"}

def fake_400_error(r, **kwargs):
resp = mock.MagicMock()
resp.text = json.dumps(error_400)
resp.status_code = 400
resp.request = mock.MagicMock()
resp.request.url = url
resp.raise_for_status.side_effect = requests.exceptions.HTTPError(
"400 Client Error", response=resp
)
return resp

for client in self.clients:
sess = OAuth2Session(client=client)
sess.send = fake_400_error

if isinstance(client, LegacyApplicationClient):
self.assertRaises(
CustomOAuth2Error,
sess.fetch_token,
url,
username="username1",
password="password1",
)
else:
self.assertRaises(CustomOAuth2Error, sess.fetch_token, url)

def test_refresh_token_http_error_handling(self):
"""Test that HTTP errors are properly raised instead of parsing error responses as tokens."""
url = "https://example.com/refresh"

# Test 400 error response
error_400 = {"error": "invalid_grant",
"error_description": "Refresh token is invalid"}

def fake_400_error(r, **kwargs):
resp = mock.MagicMock()
resp.text = json.dumps(error_400)
resp.status_code = 400
resp.request = mock.MagicMock()
resp.request.url = url
resp.raise_for_status.side_effect = requests.exceptions.HTTPError(
"400 Client Error", response=resp
)
return resp

for client in self.clients:
sess = OAuth2Session(client=client, token=self.token)
sess.send = fake_400_error
self.assertRaises(CustomOAuth2Error, sess.refresh_token, url)


class OAuth2SessionNetrcTest(OAuth2SessionTest):
"""Ensure that there is no magic auth handling.
Expand Down