diff --git a/Makefile b/Makefile index fc3c5ee8..c95887b4 100644 --- a/Makefile +++ b/Makefile @@ -37,7 +37,10 @@ remove_pytest_asyncio_from_sync: sed -i 's/@pytest.mark.asyncio//g' tests/_sync/test_client.py sed -i 's/_async/_sync/g' tests/_sync/test_client.py sed -i 's/Async/Sync/g' tests/_sync/test_client.py + sed -i 's/Async/Sync/g' postgrest/_sync/request_builder.py sed -i 's/_client\.SyncClient/_client\.Client/g' tests/_sync/test_client.py + sed -i 's/SyncHTTPTransport/HTTPTransport/g' tests/_sync/test_client.py + sed -i 's/SyncHTTPTransport/HTTPTransport/g' tests/_sync/client.py sleep: sleep 2 diff --git a/postgrest/__init__.py b/postgrest/__init__.py index f060e684..17c5f1f0 100644 --- a/postgrest/__init__.py +++ b/postgrest/__init__.py @@ -27,4 +27,41 @@ from .deprecated_client import Client, PostgrestClient from .deprecated_get_request_builder import GetRequestBuilder from .exceptions import APIError +from .types import ( + CountMethod, + Filters, + RequestMethod, + ReturnMethod, +) from .version import __version__ + +__all__ = [ + "AsyncPostgrestClient", + "AsyncFilterRequestBuilder", + "AsyncQueryRequestBuilder", + "AsyncRequestBuilder", + "AsyncRPCFilterRequestBuilder", + "AsyncSelectRequestBuilder", + "AsyncSingleRequestBuilder", + "AsyncMaybeSingleRequestBuilder", + "SyncPostgrestClient", + "SyncFilterRequestBuilder", + "SyncMaybeSingleRequestBuilder", + "SyncQueryRequestBuilder", + "SyncRequestBuilder", + "SyncRPCFilterRequestBuilder", + "SyncSelectRequestBuilder", + "SyncSingleRequestBuilder", + "APIResponse", + "DEFAULT_POSTGREST_CLIENT_HEADERS", + "Client", + "PostgrestClient", + "GetRequestBuilder", + "APIError", + "CountMethod", + "Filters", + "RequestMethod", + "ReturnMethod", + "Timeout", + "__version__", +] diff --git a/postgrest/_async/client.py b/postgrest/_async/client.py index b2994d32..f3e1a73a 100644 --- a/postgrest/_async/client.py +++ b/postgrest/_async/client.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import Any, Dict, Optional, Union, cast +from warnings import warn from deprecation import deprecated from httpx import Headers, QueryParams, Timeout @@ -27,18 +28,50 @@ def __init__( *, schema: str = "public", headers: Dict[str, str] = DEFAULT_POSTGREST_CLIENT_HEADERS, - timeout: Union[int, float, Timeout] = DEFAULT_POSTGREST_CLIENT_TIMEOUT, - verify: bool = True, + timeout: Union[int, float, Timeout, None] = None, + verify: Optional[bool] = None, proxy: Optional[str] = None, + http_client: Optional[AsyncClient] = None, ) -> None: + if timeout is not None: + warn( + "The 'timeout' parameter is deprecated. Please configure it in the http client instead.", + DeprecationWarning, + stacklevel=2, + ) + if verify is not None: + warn( + "The 'verify' parameter is deprecated. Please configure it in the http client instead.", + DeprecationWarning, + stacklevel=2, + ) + if proxy is not None: + warn( + "The 'proxy' parameter is deprecated. Please configure it in the http client instead.", + DeprecationWarning, + stacklevel=2, + ) + + self.verify = bool(verify) if verify is not None else True + self.timeout = ( + timeout + if isinstance(timeout, Timeout) + else ( + int(abs(timeout)) + if timeout is not None + else DEFAULT_POSTGREST_CLIENT_TIMEOUT + ) + ) + BasePostgrestClient.__init__( self, base_url, schema=schema, headers=headers, - timeout=timeout, - verify=verify, + timeout=self.timeout, + verify=self.verify, proxy=proxy, + http_client=http_client, ) self.session = cast(AsyncClient, self.session) @@ -50,6 +83,15 @@ def create_session( verify: bool = True, proxy: Optional[str] = None, ) -> AsyncClient: + http_client = None + if isinstance(self.http_client, AsyncClient): + http_client = self.http_client + + if http_client is not None: + http_client.base_url = base_url + http_client.headers.update({**headers}) + return http_client + return AsyncClient( base_url=base_url, headers=headers, diff --git a/postgrest/_sync/client.py b/postgrest/_sync/client.py index 1a27cfb2..6a77a204 100644 --- a/postgrest/_sync/client.py +++ b/postgrest/_sync/client.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import Any, Dict, Optional, Union, cast +from warnings import warn from deprecation import deprecated from httpx import Headers, QueryParams, Timeout @@ -27,18 +28,50 @@ def __init__( *, schema: str = "public", headers: Dict[str, str] = DEFAULT_POSTGREST_CLIENT_HEADERS, - timeout: Union[int, float, Timeout] = DEFAULT_POSTGREST_CLIENT_TIMEOUT, - verify: bool = True, + timeout: Union[int, float, Timeout, None] = None, + verify: Optional[bool] = None, proxy: Optional[str] = None, + http_client: Optional[SyncClient] = None, ) -> None: + if timeout is not None: + warn( + "The 'timeout' parameter is deprecated. Please configure it in the http client instead.", + DeprecationWarning, + stacklevel=2, + ) + if verify is not None: + warn( + "The 'verify' parameter is deprecated. Please configure it in the http client instead.", + DeprecationWarning, + stacklevel=2, + ) + if proxy is not None: + warn( + "The 'proxy' parameter is deprecated. Please configure it in the http client instead.", + DeprecationWarning, + stacklevel=2, + ) + + self.verify = bool(verify) if verify is not None else True + self.timeout = ( + timeout + if isinstance(timeout, Timeout) + else ( + int(abs(timeout)) + if timeout is not None + else DEFAULT_POSTGREST_CLIENT_TIMEOUT + ) + ) + BasePostgrestClient.__init__( self, base_url, schema=schema, headers=headers, - timeout=timeout, - verify=verify, + timeout=self.timeout, + verify=self.verify, proxy=proxy, + http_client=http_client, ) self.session = cast(SyncClient, self.session) @@ -50,6 +83,15 @@ def create_session( verify: bool = True, proxy: Optional[str] = None, ) -> SyncClient: + http_client = None + if isinstance(self.http_client, SyncClient): + http_client = self.http_client + + if http_client is not None: + http_client.base_url = base_url + http_client.headers.update({**headers}) + return http_client + return SyncClient( base_url=base_url, headers=headers, diff --git a/postgrest/_sync/request_builder.py b/postgrest/_sync/request_builder.py index 4f8af88e..572a7df0 100644 --- a/postgrest/_sync/request_builder.py +++ b/postgrest/_sync/request_builder.py @@ -287,7 +287,7 @@ def select( *columns: The names of the columns to fetch. count: The method to use to get the count of rows returned. Returns: - :class:`AsyncSelectRequestBuilder` + :class:`SyncSelectRequestBuilder` """ method, params, headers, json = pre_select(*columns, count=count, head=head) return SyncSelectRequestBuilder[_ReturnT]( @@ -314,7 +314,7 @@ def insert( Otherwise, use the default value for the column. Only applies for bulk inserts. Returns: - :class:`AsyncQueryRequestBuilder` + :class:`SyncQueryRequestBuilder` """ method, params, headers, json = pre_insert( json, @@ -350,7 +350,7 @@ def upsert( not when merging with existing rows under `ignoreDuplicates: false`. This also only applies when doing bulk upserts. Returns: - :class:`AsyncQueryRequestBuilder` + :class:`SyncQueryRequestBuilder` """ method, params, headers, json = pre_upsert( json, @@ -378,7 +378,7 @@ def update( count: The method to use to get the count of rows returned. returning: Either 'minimal' or 'representation' Returns: - :class:`AsyncFilterRequestBuilder` + :class:`SyncFilterRequestBuilder` """ method, params, headers, json = pre_update( json, @@ -401,7 +401,7 @@ def delete( count: The method to use to get the count of rows returned. returning: Either 'minimal' or 'representation' Returns: - :class:`AsyncFilterRequestBuilder` + :class:`SyncFilterRequestBuilder` """ method, params, headers, json = pre_delete( count=count, diff --git a/postgrest/base_client.py b/postgrest/base_client.py index 2c9756ab..840e5a0e 100644 --- a/postgrest/base_client.py +++ b/postgrest/base_client.py @@ -20,6 +20,7 @@ def __init__( timeout: Union[int, float, Timeout], verify: bool = True, proxy: Optional[str] = None, + http_client: Union[SyncClient, AsyncClient, None] = None, ) -> None: if not is_http_url(base_url): ValueError("base_url must be a valid HTTP URL string") @@ -33,8 +34,13 @@ def __init__( self.timeout = timeout self.verify = verify self.proxy = proxy + self.http_client = http_client self.session = self.create_session( - self.base_url, self.headers, self.timeout, self.verify, self.proxy + self.base_url, + self.headers, + self.timeout, + self.verify, + self.proxy, ) @abstractmethod diff --git a/postgrest/exceptions.py b/postgrest/exceptions.py index 203153ea..d4ef668d 100644 --- a/postgrest/exceptions.py +++ b/postgrest/exceptions.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional +from typing import Any, Dict, Optional from pydantic import BaseModel @@ -34,7 +34,7 @@ class APIError(Exception): details: Optional[str] """The error details.""" - def __init__(self, error: Dict[str, str]) -> None: + def __init__(self, error: Dict[str, Any]) -> None: self._raw_error = error self.message = error.get("message") self.code = error.get("code") diff --git a/pyproject.toml b/pyproject.toml index 594811cb..c88b7572 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,9 @@ furo = ">=2023.9.10,<2025.0.0" [tool.pytest.ini_options] asyncio_mode = "auto" +filterwarnings = [ + "ignore::DeprecationWarning", # ignore deprecation warnings globally +] [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/tests/_async/client.py b/tests/_async/client.py index cb97e6d0..a65f2baa 100644 --- a/tests/_async/client.py +++ b/tests/_async/client.py @@ -1,4 +1,7 @@ +from httpx import AsyncHTTPTransport, Limits + from postgrest import AsyncPostgrestClient +from postgrest.utils import AsyncClient REST_URL = "http://127.0.0.1:3000" @@ -7,3 +10,20 @@ def rest_client(): return AsyncPostgrestClient( base_url=REST_URL, ) + + +def rest_client_httpx(): + transport = AsyncHTTPTransport( + retries=4, + limits=Limits( + max_connections=1, + max_keepalive_connections=1, + keepalive_expiry=None, + ), + ) + headers = {"x-user-agent": "my-app/0.0.1"} + http_client = AsyncClient(transport=transport, headers=headers) + return AsyncPostgrestClient( + base_url=REST_URL, + http_client=http_client, + ) diff --git a/tests/_async/test_client.py b/tests/_async/test_client.py index fa32fcaa..35b6204f 100644 --- a/tests/_async/test_client.py +++ b/tests/_async/test_client.py @@ -1,10 +1,19 @@ from unittest.mock import patch import pytest -from httpx import BasicAuth, Headers, Request, Response +from httpx import ( + AsyncHTTPTransport, + BasicAuth, + Headers, + Limits, + Request, + Response, + Timeout, +) from postgrest import AsyncPostgrestClient from postgrest.exceptions import APIError +from postgrest.utils import AsyncClient @pytest.fixture @@ -46,6 +55,32 @@ async def test_custom_headers(self): assert session.headers.items() >= headers.items() +class TestHttpxClientConstructor: + @pytest.mark.asyncio + async def test_custom_httpx_client(self): + transport = AsyncHTTPTransport( + retries=10, + limits=Limits( + max_connections=1, + max_keepalive_connections=1, + keepalive_expiry=None, + ), + ) + headers = {"x-user-agent": "my-app/0.0.1"} + http_client = AsyncClient(transport=transport, headers=headers) + async with AsyncPostgrestClient( + "https://example.com", http_client=http_client, timeout=20.0 + ) as client: + session = client.session + + assert session.base_url == "https://example.com" + assert session.timeout == Timeout( + timeout=5.0 + ) # Should be the default 5 since we use custom httpx client + assert session.headers.get("x-user-agent") == "my-app/0.0.1" + assert isinstance(session, AsyncClient) + + class TestAuth: def test_auth_token(self, postgrest_client: AsyncPostgrestClient): postgrest_client.auth("s3cr3t") diff --git a/tests/_async/test_filter_request_builder_integration.py b/tests/_async/test_filter_request_builder_integration.py index 26d5260b..b4ff44fc 100644 --- a/tests/_async/test_filter_request_builder_integration.py +++ b/tests/_async/test_filter_request_builder_integration.py @@ -1,11 +1,30 @@ -from .client import rest_client +from postgrest import CountMethod + +from .client import rest_client, rest_client_httpx + + +async def test_multivalued_param_httpx(): + res = ( + await rest_client_httpx() + .from_("countries") + .select("country_name, iso", count=CountMethod.exact) + .lte("numcode", 8) + .gte("numcode", 4) + .execute() + ) + + assert res.count == 2 + assert res.data == [ + {"country_name": "AFGHANISTAN", "iso": "AF"}, + {"country_name": "ALBANIA", "iso": "AL"}, + ] async def test_multivalued_param(): res = ( await rest_client() .from_("countries") - .select("country_name, iso", count="exact") + .select("country_name, iso", count=CountMethod.exact) .lte("numcode", 8) .gte("numcode", 4) .execute() @@ -506,7 +525,12 @@ async def test_rpc_get_with_args(): async def test_rpc_get_with_count(): res = ( await rest_client() - .rpc("search_countries_by_name", {"search_name": "Al"}, get=True, count="exact") + .rpc( + "search_countries_by_name", + {"search_name": "Al"}, + get=True, + count=CountMethod.exact, + ) .select("nicename") .execute() ) @@ -517,7 +541,12 @@ async def test_rpc_get_with_count(): async def test_rpc_head_count(): res = ( await rest_client() - .rpc("search_countries_by_name", {"search_name": "Al"}, head=True, count="exact") + .rpc( + "search_countries_by_name", + {"search_name": "Al"}, + head=True, + count=CountMethod.exact, + ) .execute() ) diff --git a/tests/_sync/client.py b/tests/_sync/client.py index 7b3f3e09..659582c4 100644 --- a/tests/_sync/client.py +++ b/tests/_sync/client.py @@ -1,4 +1,7 @@ +from httpx import HTTPTransport, Limits + from postgrest import SyncPostgrestClient +from postgrest.utils import SyncClient REST_URL = "http://127.0.0.1:3000" @@ -7,3 +10,20 @@ def rest_client(): return SyncPostgrestClient( base_url=REST_URL, ) + + +def rest_client_httpx(): + transport = HTTPTransport( + retries=4, + limits=Limits( + max_connections=1, + max_keepalive_connections=1, + keepalive_expiry=None, + ), + ) + headers = {"x-user-agent": "my-app/0.0.1"} + http_client = SyncClient(transport=transport, headers=headers) + return SyncPostgrestClient( + base_url=REST_URL, + http_client=http_client, + ) diff --git a/tests/_sync/test_client.py b/tests/_sync/test_client.py index 0e70a63e..57b5cce4 100644 --- a/tests/_sync/test_client.py +++ b/tests/_sync/test_client.py @@ -1,10 +1,19 @@ from unittest.mock import patch import pytest -from httpx import BasicAuth, Headers, Request, Response +from httpx import ( + BasicAuth, + Headers, + HTTPTransport, + Limits, + Request, + Response, + Timeout, +) from postgrest import SyncPostgrestClient from postgrest.exceptions import APIError +from postgrest.utils import SyncClient @pytest.fixture @@ -45,6 +54,32 @@ def test_custom_headers(self): assert session.headers.items() >= headers.items() +class TestHttpxClientConstructor: + + def test_custom_httpx_client(self): + transport = HTTPTransport( + retries=10, + limits=Limits( + max_connections=1, + max_keepalive_connections=1, + keepalive_expiry=None, + ), + ) + headers = {"x-user-agent": "my-app/0.0.1"} + http_client = SyncClient(transport=transport, headers=headers) + with SyncPostgrestClient( + "https://example.com", http_client=http_client, timeout=20.0 + ) as client: + session = client.session + + assert session.base_url == "https://example.com" + assert session.timeout == Timeout( + timeout=5.0 + ) # Should be the default 5 since we use custom httpx client + assert session.headers.get("x-user-agent") == "my-app/0.0.1" + assert isinstance(session, SyncClient) + + class TestAuth: def test_auth_token(self, postgrest_client: SyncPostgrestClient): postgrest_client.auth("s3cr3t") diff --git a/tests/_sync/test_filter_request_builder_integration.py b/tests/_sync/test_filter_request_builder_integration.py index d52fc2ca..896bd9ee 100644 --- a/tests/_sync/test_filter_request_builder_integration.py +++ b/tests/_sync/test_filter_request_builder_integration.py @@ -1,11 +1,30 @@ -from .client import rest_client +from postgrest import CountMethod + +from .client import rest_client, rest_client_httpx + + +def test_multivalued_param_httpx(): + res = ( + rest_client_httpx() + .from_("countries") + .select("country_name, iso", count=CountMethod.exact) + .lte("numcode", 8) + .gte("numcode", 4) + .execute() + ) + + assert res.count == 2 + assert res.data == [ + {"country_name": "AFGHANISTAN", "iso": "AF"}, + {"country_name": "ALBANIA", "iso": "AL"}, + ] def test_multivalued_param(): res = ( rest_client() .from_("countries") - .select("country_name, iso", count="exact") + .select("country_name, iso", count=CountMethod.exact) .lte("numcode", 8) .gte("numcode", 4) .execute() @@ -499,7 +518,12 @@ def test_rpc_get_with_args(): def test_rpc_get_with_count(): res = ( rest_client() - .rpc("search_countries_by_name", {"search_name": "Al"}, get=True, count="exact") + .rpc( + "search_countries_by_name", + {"search_name": "Al"}, + get=True, + count=CountMethod.exact, + ) .select("nicename") .execute() ) @@ -510,7 +534,12 @@ def test_rpc_get_with_count(): def test_rpc_head_count(): res = ( rest_client() - .rpc("search_countries_by_name", {"search_name": "Al"}, head=True, count="exact") + .rpc( + "search_countries_by_name", + {"search_name": "Al"}, + head=True, + count=CountMethod.exact, + ) .execute() )