diff --git a/requests_oauthlib/oauth2_session.py b/requests_oauthlib/oauth2_session.py index 93cc4d7..8094bfa 100644 --- a/requests_oauthlib/oauth2_session.py +++ b/requests_oauthlib/oauth2_session.py @@ -4,6 +4,7 @@ from oauthlib.oauth2 import WebApplicationClient, InsecureTransportError from oauthlib.oauth2 import LegacyApplicationClient from oauthlib.oauth2 import TokenExpiredError, is_secure_transport +from oauthlib.oauth2.rfc6749.errors import CustomOAuth2Error import requests log = logging.getLogger(__name__) @@ -199,6 +200,17 @@ def authorization_url(self, url, state=None, **kwargs): state, ) + def validate_token_response(self, r): + message = "" + try: + r.raise_for_status() + except requests.HTTPError as e: + message = str(e) + if r.text: + message += f"\nBody: {r.text}" + if message: + raise CustomOAuth2Error('Response error', message, uri=r.request.url, status_code=r.status_code) + def fetch_token( self, token_url, @@ -403,6 +415,7 @@ def fetch_token( log.debug("Invoking hook %s.", hook) r = hook(r) + self.validate_token_response(r) self._client.parse_request_body_response(r.text, scope=self.scope) self.token = self._client.token log.debug("Obtained token %s.", self.token) @@ -493,6 +506,7 @@ def refresh_token( log.debug("Invoking hook %s.", hook) r = hook(r) + self.validate_token_response(r) self.token = self._client.parse_request_body_response(r.text, scope=self.scope) if "refresh_token" not in self.token: log.debug("No new refresh token given. Re-using old.") diff --git a/tests/test_oauth2_session.py b/tests/test_oauth2_session.py index 7e3e63c..50834fc 100644 --- a/tests/test_oauth2_session.py +++ b/tests/test_oauth2_session.py @@ -12,6 +12,7 @@ from oauthlib.common import urlencode from oauthlib.oauth2 import TokenExpiredError, OAuth2Error from oauthlib.oauth2 import MismatchingStateError +from oauthlib.oauth2.rfc6749.errors import CustomOAuth2Error from oauthlib.oauth2 import WebApplicationClient, MobileApplicationClient from oauthlib.oauth2 import LegacyApplicationClient, BackendApplicationClient from requests_oauthlib import OAuth2Session, TokenUpdated @@ -524,6 +525,63 @@ def fake_send(r, **kwargs): sess.fetch_token(url) self.assertTrue(sess.authorized) + def test_fetch_token_http_error_handling(self): + """Test that HTTP errors are properly raised instead of parsing error responses as tokens.""" + url = "https://example.com/token" + + # Test 400 error response (like the original issue) + error_400 = {"messages": [{"logLevel": "Error", "text": "User not Found"}], "status": "Forbidden"} + + def fake_400_error(r, **kwargs): + resp = mock.MagicMock() + resp.text = json.dumps(error_400) + resp.status_code = 400 + resp.request = mock.MagicMock() + resp.request.url = url + resp.raise_for_status.side_effect = requests.exceptions.HTTPError( + "400 Client Error", response=resp + ) + return resp + + for client in self.clients: + sess = OAuth2Session(client=client) + sess.send = fake_400_error + + if isinstance(client, LegacyApplicationClient): + self.assertRaises( + CustomOAuth2Error, + sess.fetch_token, + url, + username="username1", + password="password1", + ) + else: + self.assertRaises(CustomOAuth2Error, sess.fetch_token, url) + + def test_refresh_token_http_error_handling(self): + """Test that HTTP errors are properly raised instead of parsing error responses as tokens.""" + url = "https://example.com/refresh" + + # Test 400 error response + error_400 = {"error": "invalid_grant", + "error_description": "Refresh token is invalid"} + + def fake_400_error(r, **kwargs): + resp = mock.MagicMock() + resp.text = json.dumps(error_400) + resp.status_code = 400 + resp.request = mock.MagicMock() + resp.request.url = url + resp.raise_for_status.side_effect = requests.exceptions.HTTPError( + "400 Client Error", response=resp + ) + return resp + + for client in self.clients: + sess = OAuth2Session(client=client, token=self.token) + sess.send = fake_400_error + self.assertRaises(CustomOAuth2Error, sess.refresh_token, url) + class OAuth2SessionNetrcTest(OAuth2SessionTest): """Ensure that there is no magic auth handling.