diff --git a/Makefile b/Makefile index 652e0e6..fc3c5ee 100644 --- a/Makefile +++ b/Makefile @@ -37,6 +37,7 @@ remove_pytest_asyncio_from_sync: sed -i 's/@pytest.mark.asyncio//g' tests/_sync/test_client.py sed -i 's/_async/_sync/g' tests/_sync/test_client.py sed -i 's/Async/Sync/g' tests/_sync/test_client.py + sed -i 's/_client\.SyncClient/_client\.Client/g' tests/_sync/test_client.py sleep: sleep 2 diff --git a/postgrest/_async/request_builder.py b/postgrest/_async/request_builder.py index 2892fa3..fa5b856 100644 --- a/postgrest/_async/request_builder.py +++ b/postgrest/_async/request_builder.py @@ -1,6 +1,5 @@ from __future__ import annotations -from json import JSONDecodeError from typing import Any, Generic, Optional, TypeVar, Union from httpx import Headers, QueryParams @@ -19,7 +18,7 @@ pre_update, pre_upsert, ) -from ..exceptions import APIError, generate_default_error_message +from ..exceptions import APIError, APIErrorFromJSON, generate_default_error_message from ..types import ReturnMethod from ..utils import AsyncClient, get_origin_and_cast @@ -75,10 +74,9 @@ async def execute(self) -> APIResponse[_ReturnT]: return body return APIResponse[_ReturnT].from_http_request_response(r) else: - raise APIError(r.json()) + json_obj = APIErrorFromJSON.model_validate_json(r.content) + raise APIError(dict(json_obj)) except ValidationError as e: - raise APIError(r.json()) from e - except JSONDecodeError: raise APIError(generate_default_error_message(r)) @@ -124,10 +122,9 @@ async def execute(self) -> SingleAPIResponse[_ReturnT]: ): # Response.ok from JS (https://developer.mozilla.org/en-US/docs/Web/API/Response/ok) return SingleAPIResponse[_ReturnT].from_http_request_response(r) else: - raise APIError(r.json()) + json_obj = APIErrorFromJSON.model_validate_json(r.content) + raise APIError(dict(json_obj)) except ValidationError as e: - raise APIError(r.json()) from e - except JSONDecodeError: raise APIError(generate_default_error_message(r)) diff --git a/postgrest/_sync/request_builder.py b/postgrest/_sync/request_builder.py index 9683b1a..4f8af88 100644 --- a/postgrest/_sync/request_builder.py +++ b/postgrest/_sync/request_builder.py @@ -1,6 +1,5 @@ from __future__ import annotations -from json import JSONDecodeError from typing import Any, Generic, Optional, TypeVar, Union from httpx import Headers, QueryParams @@ -19,7 +18,7 @@ pre_update, pre_upsert, ) -from ..exceptions import APIError, generate_default_error_message +from ..exceptions import APIError, APIErrorFromJSON, generate_default_error_message from ..types import ReturnMethod from ..utils import SyncClient, get_origin_and_cast @@ -75,10 +74,9 @@ def execute(self) -> APIResponse[_ReturnT]: return body return APIResponse[_ReturnT].from_http_request_response(r) else: - raise APIError(r.json()) + json_obj = APIErrorFromJSON.model_validate_json(r.content) + raise APIError(dict(json_obj)) except ValidationError as e: - raise APIError(r.json()) from e - except JSONDecodeError: raise APIError(generate_default_error_message(r)) @@ -124,10 +122,9 @@ def execute(self) -> SingleAPIResponse[_ReturnT]: ): # Response.ok from JS (https://developer.mozilla.org/en-US/docs/Web/API/Response/ok) return SingleAPIResponse[_ReturnT].from_http_request_response(r) else: - raise APIError(r.json()) + json_obj = APIErrorFromJSON.model_validate_json(r.content) + raise APIError(dict(json_obj)) except ValidationError as e: - raise APIError(r.json()) from e - except JSONDecodeError: raise APIError(generate_default_error_message(r)) @@ -290,7 +287,7 @@ def select( *columns: The names of the columns to fetch. count: The method to use to get the count of rows returned. Returns: - :class:`SyncSelectRequestBuilder` + :class:`AsyncSelectRequestBuilder` """ method, params, headers, json = pre_select(*columns, count=count, head=head) return SyncSelectRequestBuilder[_ReturnT]( @@ -317,7 +314,7 @@ def insert( Otherwise, use the default value for the column. Only applies for bulk inserts. Returns: - :class:`SyncQueryRequestBuilder` + :class:`AsyncQueryRequestBuilder` """ method, params, headers, json = pre_insert( json, @@ -353,7 +350,7 @@ def upsert( not when merging with existing rows under `ignoreDuplicates: false`. This also only applies when doing bulk upserts. Returns: - :class:`SyncQueryRequestBuilder` + :class:`AsyncQueryRequestBuilder` """ method, params, headers, json = pre_upsert( json, @@ -381,7 +378,7 @@ def update( count: The method to use to get the count of rows returned. returning: Either 'minimal' or 'representation' Returns: - :class:`SyncFilterRequestBuilder` + :class:`AsyncFilterRequestBuilder` """ method, params, headers, json = pre_update( json, @@ -404,7 +401,7 @@ def delete( count: The method to use to get the count of rows returned. returning: Either 'minimal' or 'representation' Returns: - :class:`SyncFilterRequestBuilder` + :class:`AsyncFilterRequestBuilder` """ method, params, headers, json = pre_delete( count=count, diff --git a/postgrest/exceptions.py b/postgrest/exceptions.py index 303c570..203153e 100644 --- a/postgrest/exceptions.py +++ b/postgrest/exceptions.py @@ -1,5 +1,23 @@ from typing import Dict, Optional +from pydantic import BaseModel + + +class APIErrorFromJSON(BaseModel): + """ + A pydantic object to validate an error info object + from a json string. + """ + + message: Optional[str] + """The error message.""" + code: Optional[str] + """The error code.""" + hint: Optional[str] + """The error hint.""" + details: Optional[str] + """The error details.""" + class APIError(Exception): """ diff --git a/tests/_async/test_client.py b/tests/_async/test_client.py index fb7881c..fa32fca 100644 --- a/tests/_async/test_client.py +++ b/tests/_async/test_client.py @@ -1,7 +1,7 @@ from unittest.mock import patch import pytest -from httpx import BasicAuth, Headers +from httpx import BasicAuth, Headers, Request, Response from postgrest import AsyncPostgrestClient from postgrest.exceptions import APIError @@ -127,3 +127,28 @@ async def test_response_maybe_single(postgrest_client: AsyncPostgrestClient): exc_response = exc_info.value.json() assert isinstance(exc_response.get("message"), str) assert "code" in exc_response and int(exc_response["code"]) == 204 + + +# https://github.com/supabase/postgrest-py/issues/595 +@pytest.mark.asyncio +async def test_response_client_invalid_response_but_valid_json( + postgrest_client: AsyncPostgrestClient, +): + with patch( + "httpx._client.AsyncClient.request", + return_value=Response( + status_code=502, + text='"gateway error: Error: Network connection lost."', # quotes makes this text a valid non-dict JSON object + request=Request(method="GET", url="http://example.com"), + ), + ): + client = postgrest_client.from_("test").select("a", "b").eq("c", "d").single() + assert "Accept" in client.headers + assert client.headers.get("Accept") == "application/vnd.pgrst.object+json" + with pytest.raises(APIError) as exc_info: + await client.execute() + assert isinstance(exc_info, pytest.ExceptionInfo) + exc_response = exc_info.value.json() + assert isinstance(exc_response.get("message"), str) + assert exc_response.get("message") == "JSON could not be generated" + assert "code" in exc_response and int(exc_response["code"]) == 502 diff --git a/tests/_sync/test_client.py b/tests/_sync/test_client.py index 930fb7b..0e70a63 100644 --- a/tests/_sync/test_client.py +++ b/tests/_sync/test_client.py @@ -1,7 +1,7 @@ from unittest.mock import patch import pytest -from httpx import BasicAuth, Headers +from httpx import BasicAuth, Headers, Request, Response from postgrest import SyncPostgrestClient from postgrest.exceptions import APIError @@ -123,3 +123,29 @@ def test_response_maybe_single(postgrest_client: SyncPostgrestClient): exc_response = exc_info.value.json() assert isinstance(exc_response.get("message"), str) assert "code" in exc_response and int(exc_response["code"]) == 204 + + +# https://github.com/supabase/postgrest-py/issues/595 + + +def test_response_client_invalid_response_but_valid_json( + postgrest_client: SyncPostgrestClient, +): + with patch( + "httpx._client.Client.request", + return_value=Response( + status_code=502, + text='"gateway error: Error: Network connection lost."', # quotes makes this text a valid non-dict JSON object + request=Request(method="GET", url="http://example.com"), + ), + ): + client = postgrest_client.from_("test").select("a", "b").eq("c", "d").single() + assert "Accept" in client.headers + assert client.headers.get("Accept") == "application/vnd.pgrst.object+json" + with pytest.raises(APIError) as exc_info: + client.execute() + assert isinstance(exc_info, pytest.ExceptionInfo) + exc_response = exc_info.value.json() + assert isinstance(exc_response.get("message"), str) + assert exc_response.get("message") == "JSON could not be generated" + assert "code" in exc_response and int(exc_response["code"]) == 502