Skip to content

fix(API): validate JSON input for APIError.__init__() #597

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ remove_pytest_asyncio_from_sync:
sed -i 's/@pytest.mark.asyncio//g' tests/_sync/test_client.py
sed -i 's/_async/_sync/g' tests/_sync/test_client.py
sed -i 's/Async/Sync/g' tests/_sync/test_client.py
sed -i 's/_client\.SyncClient/_client\.Client/g' tests/_sync/test_client.py

sleep:
sleep 2
13 changes: 5 additions & 8 deletions postgrest/_async/request_builder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

from json import JSONDecodeError
from typing import Any, Generic, Optional, TypeVar, Union

from httpx import Headers, QueryParams
Expand All @@ -19,7 +18,7 @@
pre_update,
pre_upsert,
)
from ..exceptions import APIError, generate_default_error_message
from ..exceptions import APIError, APIErrorFromJSON, generate_default_error_message
from ..types import ReturnMethod
from ..utils import AsyncClient, get_origin_and_cast

Expand Down Expand Up @@ -75,10 +74,9 @@ async def execute(self) -> APIResponse[_ReturnT]:
return body
return APIResponse[_ReturnT].from_http_request_response(r)
else:
raise APIError(r.json())
json_obj = APIErrorFromJSON.model_validate_json(r.content)
raise APIError(dict(json_obj))
except ValidationError as e:
raise APIError(r.json()) from e
except JSONDecodeError:
raise APIError(generate_default_error_message(r))


Expand Down Expand Up @@ -124,10 +122,9 @@ async def execute(self) -> SingleAPIResponse[_ReturnT]:
): # Response.ok from JS (https://developer.mozilla.org/en-US/docs/Web/API/Response/ok)
return SingleAPIResponse[_ReturnT].from_http_request_response(r)
else:
raise APIError(r.json())
json_obj = APIErrorFromJSON.model_validate_json(r.content)
raise APIError(dict(json_obj))
except ValidationError as e:
raise APIError(r.json()) from e
except JSONDecodeError:
raise APIError(generate_default_error_message(r))


Expand Down
23 changes: 10 additions & 13 deletions postgrest/_sync/request_builder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

from json import JSONDecodeError
from typing import Any, Generic, Optional, TypeVar, Union

from httpx import Headers, QueryParams
Expand All @@ -19,7 +18,7 @@
pre_update,
pre_upsert,
)
from ..exceptions import APIError, generate_default_error_message
from ..exceptions import APIError, APIErrorFromJSON, generate_default_error_message
from ..types import ReturnMethod
from ..utils import SyncClient, get_origin_and_cast

Expand Down Expand Up @@ -75,10 +74,9 @@ def execute(self) -> APIResponse[_ReturnT]:
return body
return APIResponse[_ReturnT].from_http_request_response(r)
else:
raise APIError(r.json())
json_obj = APIErrorFromJSON.model_validate_json(r.content)
raise APIError(dict(json_obj))
except ValidationError as e:
raise APIError(r.json()) from e
except JSONDecodeError:
raise APIError(generate_default_error_message(r))


Expand Down Expand Up @@ -124,10 +122,9 @@ def execute(self) -> SingleAPIResponse[_ReturnT]:
): # Response.ok from JS (https://developer.mozilla.org/en-US/docs/Web/API/Response/ok)
return SingleAPIResponse[_ReturnT].from_http_request_response(r)
else:
raise APIError(r.json())
json_obj = APIErrorFromJSON.model_validate_json(r.content)
raise APIError(dict(json_obj))
except ValidationError as e:
raise APIError(r.json()) from e
except JSONDecodeError:
raise APIError(generate_default_error_message(r))


Expand Down Expand Up @@ -290,7 +287,7 @@ def select(
*columns: The names of the columns to fetch.
count: The method to use to get the count of rows returned.
Returns:
:class:`SyncSelectRequestBuilder`
:class:`AsyncSelectRequestBuilder`
"""
method, params, headers, json = pre_select(*columns, count=count, head=head)
return SyncSelectRequestBuilder[_ReturnT](
Expand All @@ -317,7 +314,7 @@ def insert(
Otherwise, use the default value for the column.
Only applies for bulk inserts.
Returns:
:class:`SyncQueryRequestBuilder`
:class:`AsyncQueryRequestBuilder`
"""
method, params, headers, json = pre_insert(
json,
Expand Down Expand Up @@ -353,7 +350,7 @@ def upsert(
not when merging with existing rows under `ignoreDuplicates: false`.
This also only applies when doing bulk upserts.
Returns:
:class:`SyncQueryRequestBuilder`
:class:`AsyncQueryRequestBuilder`
"""
method, params, headers, json = pre_upsert(
json,
Expand Down Expand Up @@ -381,7 +378,7 @@ def update(
count: The method to use to get the count of rows returned.
returning: Either 'minimal' or 'representation'
Returns:
:class:`SyncFilterRequestBuilder`
:class:`AsyncFilterRequestBuilder`
"""
method, params, headers, json = pre_update(
json,
Expand All @@ -404,7 +401,7 @@ def delete(
count: The method to use to get the count of rows returned.
returning: Either 'minimal' or 'representation'
Returns:
:class:`SyncFilterRequestBuilder`
:class:`AsyncFilterRequestBuilder`
"""
method, params, headers, json = pre_delete(
count=count,
Expand Down
18 changes: 18 additions & 0 deletions postgrest/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,23 @@
from typing import Dict, Optional

from pydantic import BaseModel


class APIErrorFromJSON(BaseModel):
"""
A pydantic object to validate an error info object
from a json string.
"""

message: Optional[str]
"""The error message."""
code: Optional[str]
"""The error code."""
hint: Optional[str]
"""The error hint."""
details: Optional[str]
"""The error details."""


class APIError(Exception):
"""
Expand Down
27 changes: 26 additions & 1 deletion tests/_async/test_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from unittest.mock import patch

import pytest
from httpx import BasicAuth, Headers
from httpx import BasicAuth, Headers, Request, Response

from postgrest import AsyncPostgrestClient
from postgrest.exceptions import APIError
Expand Down Expand Up @@ -127,3 +127,28 @@ async def test_response_maybe_single(postgrest_client: AsyncPostgrestClient):
exc_response = exc_info.value.json()
assert isinstance(exc_response.get("message"), str)
assert "code" in exc_response and int(exc_response["code"]) == 204


# https://github.com/supabase/postgrest-py/issues/595
@pytest.mark.asyncio
async def test_response_client_invalid_response_but_valid_json(
postgrest_client: AsyncPostgrestClient,
):
with patch(
"httpx._client.AsyncClient.request",
return_value=Response(
status_code=502,
text='"gateway error: Error: Network connection lost."', # quotes makes this text a valid non-dict JSON object
request=Request(method="GET", url="http://example.com"),
),
):
client = postgrest_client.from_("test").select("a", "b").eq("c", "d").single()
assert "Accept" in client.headers
assert client.headers.get("Accept") == "application/vnd.pgrst.object+json"
with pytest.raises(APIError) as exc_info:
await client.execute()
assert isinstance(exc_info, pytest.ExceptionInfo)
exc_response = exc_info.value.json()
assert isinstance(exc_response.get("message"), str)
assert exc_response.get("message") == "JSON could not be generated"
assert "code" in exc_response and int(exc_response["code"]) == 502
28 changes: 27 additions & 1 deletion tests/_sync/test_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from unittest.mock import patch

import pytest
from httpx import BasicAuth, Headers
from httpx import BasicAuth, Headers, Request, Response

from postgrest import SyncPostgrestClient
from postgrest.exceptions import APIError
Expand Down Expand Up @@ -123,3 +123,29 @@ def test_response_maybe_single(postgrest_client: SyncPostgrestClient):
exc_response = exc_info.value.json()
assert isinstance(exc_response.get("message"), str)
assert "code" in exc_response and int(exc_response["code"]) == 204


# https://github.com/supabase/postgrest-py/issues/595


def test_response_client_invalid_response_but_valid_json(
postgrest_client: SyncPostgrestClient,
):
with patch(
"httpx._client.Client.request",
return_value=Response(
status_code=502,
text='"gateway error: Error: Network connection lost."', # quotes makes this text a valid non-dict JSON object
request=Request(method="GET", url="http://example.com"),
),
):
client = postgrest_client.from_("test").select("a", "b").eq("c", "d").single()
assert "Accept" in client.headers
assert client.headers.get("Accept") == "application/vnd.pgrst.object+json"
with pytest.raises(APIError) as exc_info:
client.execute()
assert isinstance(exc_info, pytest.ExceptionInfo)
exc_response = exc_info.value.json()
assert isinstance(exc_response.get("message"), str)
assert exc_response.get("message") == "JSON could not be generated"
assert "code" in exc_response and int(exc_response["code"]) == 502