|
22 | 22 | import requests |
23 | 23 |
|
24 | 24 | from requests.auth import _basic_auth_str |
| 25 | +from requests.exceptions import HTTPError |
25 | 26 |
|
26 | 27 |
|
27 | 28 | fake_time = time.time() |
28 | 29 | CODE = "asdf345xdf" |
29 | 30 |
|
30 | 31 |
|
31 | | -def fake_token(token): |
| 32 | +def fake_token(token, status_code: int = 200): |
32 | 33 | def fake_send(r, **kwargs): |
33 | 34 | resp = mock.MagicMock() |
34 | | - resp.status_code = 200 |
| 35 | + resp.status_code = status_code |
35 | 36 | resp.text = json.dumps(token) |
36 | 37 | return resp |
37 | 38 |
|
@@ -133,11 +134,11 @@ def test_refresh_token_request(self): |
133 | 134 | self.expired_token["expires_in"] = "-1" |
134 | 135 | del self.expired_token["expires_at"] |
135 | 136 |
|
136 | | - def fake_refresh(r, **kwargs): |
| 137 | + def fake_refresh(r, status_code: int = 200, **kwargs): |
137 | 138 | if "/refresh" in r.url: |
138 | 139 | self.assertNotIn("Authorization", r.headers) |
139 | 140 | resp = mock.MagicMock() |
140 | | - resp.status_code = 200 |
| 141 | + resp.status_code = status_code |
141 | 142 | resp.text = json.dumps(self.token) |
142 | 143 | return resp |
143 | 144 |
|
@@ -170,6 +171,19 @@ def token_updater(token): |
170 | 171 | sess.send = fake_refresh |
171 | 172 | sess.get("https://i.b") |
172 | 173 |
|
| 174 | + # test 5xx error handler |
| 175 | + for client in self.clients: |
| 176 | + sess = OAuth2Session( |
| 177 | + client=client, |
| 178 | + token=self.expired_token, |
| 179 | + auto_refresh_url="https://i.b/refresh", |
| 180 | + token_updater=token_updater, |
| 181 | + ) |
| 182 | + sess.send = lambda r, **kwargs: fake_refresh( |
| 183 | + r=r, status_code=503, kwargs=kwargs, |
| 184 | + ) |
| 185 | + self.assertRaises(HTTPError, sess.get, "https://i.b") |
| 186 | + |
173 | 187 | def fake_refresh_with_auth(r, **kwargs): |
174 | 188 | if "/refresh" in r.url: |
175 | 189 | self.assertIn("Authorization", r.headers) |
@@ -256,6 +270,23 @@ def test_fetch_token(self): |
256 | 270 | else: |
257 | 271 | self.assertRaises(OAuth2Error, sess.fetch_token, url) |
258 | 272 |
|
| 273 | + # test 5xx error responses |
| 274 | + error = {"error": "server error!"} |
| 275 | + for client in self.clients: |
| 276 | + sess = OAuth2Session(client=client, token=self.token) |
| 277 | + sess.send = fake_token(error, status_code=500) |
| 278 | + if isinstance(client, LegacyApplicationClient): |
| 279 | + # this client requires a username+password |
| 280 | + self.assertRaises( |
| 281 | + HTTPError, |
| 282 | + sess.fetch_token, |
| 283 | + url, |
| 284 | + username="username1", |
| 285 | + password="password1", |
| 286 | + ) |
| 287 | + else: |
| 288 | + self.assertRaises(HTTPError, sess.fetch_token, url) |
| 289 | + |
259 | 290 | # there are different scenarios in which the `client_id` can be specified |
260 | 291 | # reference `oauthlib.tests.oauth2.rfc6749.clients.test_web_application.WebApplicationClientTest.test_prepare_request_body` |
261 | 292 | # this only needs to test WebApplicationClient |
|
0 commit comments