Skip to content

Commit 576a5b8

Browse files
fix: pass params as query params for get/head requests (#593)
Co-authored-by: Andrew Smith <[email protected]>
1 parent 2fa8a82 commit 576a5b8

File tree

5 files changed

+71
-1
lines changed

5 files changed

+71
-1
lines changed

infra/init.sql

+6
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,12 @@ as $function$
7777
select * from countries;
7878
$function$;
7979

80+
create or replace function public.search_countries_by_name(search_name text)
81+
returns setof countries
82+
language sql
83+
as $function$
84+
select * from countries where nicename ilike '%' || search_name || '%';
85+
$function$;
8086

8187
create table
8288
orchestral_sections (id int8 primary key, name text);

postgrest/_async/client.py

+9
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,15 @@ def rpc(
133133

134134
headers = Headers({"Prefer": f"count={count}"}) if count else Headers()
135135

136+
if method in ("HEAD", "GET"):
137+
return AsyncRPCFilterRequestBuilder[Any](
138+
self.session,
139+
f"/rpc/{func}",
140+
method,
141+
headers,
142+
QueryParams(params),
143+
json={},
144+
)
136145
# the params here are params to be sent to the RPC and not the queryparams!
137146
return AsyncRPCFilterRequestBuilder[Any](
138147
self.session, f"/rpc/{func}", method, headers, QueryParams(), json=params

postgrest/_sync/client.py

+9
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,15 @@ def rpc(
133133

134134
headers = Headers({"Prefer": f"count={count}"}) if count else Headers()
135135

136+
if method in ("HEAD", "GET"):
137+
return SyncRPCFilterRequestBuilder[Any](
138+
self.session,
139+
f"/rpc/{func}",
140+
method,
141+
headers,
142+
QueryParams(params),
143+
json={},
144+
)
136145
# the params here are params to be sent to the RPC and not the queryparams!
137146
return SyncRPCFilterRequestBuilder[Any](
138147
self.session, f"/rpc/{func}", method, headers, QueryParams(), json=params

postgrest/base_request_builder.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -660,7 +660,11 @@ def select(
660660
"""
661661
method, params, headers, json = pre_select(*columns, count=None)
662662
self.params = self.params.add("select", params.get("select"))
663-
self.headers["Prefer"] = "return=representation"
663+
if self.headers.get("Prefer"):
664+
self.headers["Prefer"] += ",return=representation"
665+
else:
666+
self.headers["Prefer"] = "return=representation"
667+
664668
return self
665669

666670
def single(self) -> Self:

tests/_sync/test_filter_request_builder_integration.py

+42
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,48 @@ def test_rpc_with_range():
476476
]
477477

478478

479+
def test_rpc_post_with_args():
480+
res = (
481+
rest_client()
482+
.rpc("search_countries_by_name", {"search_name": "Alban"})
483+
.select("nicename, iso")
484+
.execute()
485+
)
486+
assert res.data == [{"nicename": "Albania", "iso": "AL"}]
487+
488+
489+
def test_rpc_get_with_args():
490+
res = (
491+
rest_client()
492+
.rpc("search_countries_by_name", {"search_name": "Alger"}, get=True)
493+
.select("nicename, iso")
494+
.execute()
495+
)
496+
assert res.data == [{"nicename": "Algeria", "iso": "DZ"}]
497+
498+
499+
def test_rpc_get_with_count():
500+
res = (
501+
rest_client()
502+
.rpc("search_countries_by_name", {"search_name": "Al"}, get=True, count="exact")
503+
.select("nicename")
504+
.execute()
505+
)
506+
assert res.count == 2
507+
assert res.data == [{"nicename": "Albania"}, {"nicename": "Algeria"}]
508+
509+
510+
def test_rpc_head_count():
511+
res = (
512+
rest_client()
513+
.rpc("search_countries_by_name", {"search_name": "Al"}, head=True, count="exact")
514+
.execute()
515+
)
516+
517+
assert res.count == 2
518+
assert res.data == []
519+
520+
479521
def test_order():
480522
res = (
481523
rest_client()

0 commit comments

Comments
 (0)