Skip to content

Commit 9b60ead

Browse files
Fix return types in async client, as done w/ sync.
1 parent d3aa982 commit 9b60ead

File tree

3 files changed

+78
-44
lines changed

3 files changed

+78
-44
lines changed

db_wrapper/client/async_client.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44
from typing import (
5+
cast,
56
Any,
67
TypeVar,
78
Union,
@@ -10,8 +11,8 @@
1011
List,
1112
Dict)
1213

13-
import aiopg # type: ignore
14-
from psycopg2.extras import register_uuid
14+
import aiopg
15+
from psycopg2.extras import register_uuid, RealDictRow
1516
from psycopg2 import sql
1617

1718
from db_wrapper.connection import ConnectionParameters, connect
@@ -20,10 +21,6 @@
2021
register_uuid()
2122

2223

23-
# Generic doesn't need a more descriptive name
24-
# pylint: disable=invalid-name
25-
T = TypeVar('T')
26-
2724
Query = Union[str, sql.Composed]
2825

2926

@@ -57,6 +54,11 @@ async def _execute_query(
5754
query: Query,
5855
params: Optional[Dict[Hashable, Any]] = None,
5956
) -> None:
57+
# aiopg type is incorrect & thinks execute only takes str
58+
# when in the query is passed through to psycopg2's
59+
# cursor.execute which does accept sql.Composed objects.
60+
query = cast(str, query)
61+
6062
if params:
6163
await cursor.execute(query, params)
6264
else:
@@ -88,7 +90,7 @@ async def execute_and_return(
8890
self,
8991
query: Query,
9092
params: Optional[Dict[Hashable, Any]] = None,
91-
) -> List[T]:
93+
) -> List[RealDictRow]:
9294
"""Execute the given SQL query & return the result.
9395
9496
Arguments:
@@ -102,5 +104,5 @@ async def execute_and_return(
102104
async with self._connection.cursor() as cursor:
103105
await self._execute_query(cursor, query, params)
104106

105-
result: List[T] = await cursor.fetchall()
107+
result: List[RealDictRow] = await cursor.fetchall()
106108
return result

db_wrapper/model/async_model.py

Lines changed: 64 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
"""Asynchronous Model objects."""
22

3-
from typing import Any, Dict, List
3+
from typing import Any, Dict, List, Type
44
from uuid import UUID
55

6+
from psycopg2.extras import RealDictRow
7+
68
from db_wrapper.client import AsyncClient
79
from .base import (
810
ensure_exactly_one,
11+
sql,
912
T,
1013
CreateABC,
1114
ReadABC,
1215
UpdateABC,
1316
DeleteABC,
1417
ModelABC,
15-
sql,
1618
)
1719

1820

@@ -23,16 +25,22 @@ class AsyncCreate(CreateABC[T]):
2325

2426
_client: AsyncClient
2527

26-
def __init__(self, client: AsyncClient, table: sql.Composable) -> None:
27-
super().__init__(table)
28+
def __init__(
29+
self,
30+
client: AsyncClient,
31+
table: sql.Composable,
32+
return_constructor: Type[T]
33+
) -> None:
34+
super().__init__(table, return_constructor)
2835
self._client = client
2936

3037
async def one(self, item: T) -> T:
3138
"""Create one new record with a given item."""
32-
result: List[T] = await self._client.execute_and_return(
33-
self._query_one(item))
39+
query_result: List[RealDictRow] = \
40+
await self._client.execute_and_return(self._query_one(item))
41+
result: T = self._return_constructor(**query_result[0])
3442

35-
return result[0]
43+
return result
3644

3745

3846
class AsyncRead(ReadABC[T]):
@@ -42,19 +50,27 @@ class AsyncRead(ReadABC[T]):
4250

4351
_client: AsyncClient
4452

45-
def __init__(self, client: AsyncClient, table: sql.Composable) -> None:
46-
super().__init__(table)
53+
def __init__(
54+
self,
55+
client: AsyncClient,
56+
table: sql.Composable,
57+
return_constructor: Type[T]
58+
) -> None:
59+
super().__init__(table, return_constructor)
4760
self._client = client
4861

4962
async def one_by_id(self, id_value: UUID) -> T:
5063
"""Read a row by it's id."""
51-
result: List[T] = await self._client.execute_and_return(
52-
self._query_one_by_id(id_value))
64+
query_result: List[RealDictRow] = \
65+
await self._client.execute_and_return(
66+
self._query_one_by_id(id_value))
5367

5468
# Should only return one item from DB
55-
ensure_exactly_one(result)
69+
ensure_exactly_one(query_result)
70+
71+
result: T = self._return_constructor(**query_result[0])
5672

57-
return result[0]
73+
return result
5874

5975

6076
class AsyncUpdate(UpdateABC[T]):
@@ -64,8 +80,13 @@ class AsyncUpdate(UpdateABC[T]):
6480

6581
_client: AsyncClient
6682

67-
def __init__(self, client: AsyncClient, table: sql.Composable) -> None:
68-
super().__init__(table)
83+
def __init__(
84+
self,
85+
client: AsyncClient,
86+
table: sql.Composable,
87+
return_constructor: Type[T]
88+
) -> None:
89+
super().__init__(table, return_constructor)
6990
self._client = client
7091

7192
async def one_by_id(self, id_value: str, changes: Dict[str, Any]) -> T:
@@ -79,12 +100,14 @@ async def one_by_id(self, id_value: str, changes: Dict[str, Any]) -> T:
79100
Returns:
80101
full value of row updated
81102
"""
82-
result: List[T] = await self._client.execute_and_return(
83-
self._query_one_by_id(id_value, changes))
103+
query_result: List[RealDictRow] = \
104+
await self._client.execute_and_return(
105+
self._query_one_by_id(id_value, changes))
84106

85-
ensure_exactly_one(result)
107+
ensure_exactly_one(query_result)
108+
result: T = self._return_constructor(**query_result[0])
86109

87-
return result[0]
110+
return result
88111

89112

90113
class AsyncDelete(DeleteABC[T]):
@@ -94,19 +117,26 @@ class AsyncDelete(DeleteABC[T]):
94117

95118
_client: AsyncClient
96119

97-
def __init__(self, client: AsyncClient, table: sql.Composable) -> None:
98-
super().__init__(table)
120+
def __init__(
121+
self,
122+
client: AsyncClient,
123+
table: sql.Composable,
124+
return_constructor: Type[T]
125+
) -> None:
126+
super().__init__(table, return_constructor)
99127
self._client = client
100128

101129
async def one_by_id(self, id_value: str) -> T:
102130
"""Delete one record with matching ID."""
103-
result: List[T] = await self._client.execute_and_return(
104-
self._query_one_by_id(id_value))
131+
query_result: List[RealDictRow] = \
132+
await self._client.execute_and_return(
133+
self._query_one_by_id(id_value))
105134

106135
# Should only return one item from DB
107-
ensure_exactly_one(result)
136+
ensure_exactly_one(query_result)
137+
result = self._return_constructor(**query_result[0])
108138

109-
return result[0]
139+
return result
110140

111141

112142
class AsyncModel(ModelABC[T]):
@@ -122,19 +152,22 @@ class AsyncModel(ModelABC[T]):
122152
_update: AsyncUpdate[T]
123153
_delete: AsyncDelete[T]
124154

125-
# PENDS python 3.9 support in pylint
126-
# pylint: disable=unsubscriptable-object
127155
def __init__(
128156
self,
129157
client: AsyncClient,
130158
table: str,
159+
return_constructor: Type[T],
131160
) -> None:
132161
super().__init__(client, table)
133162

134-
self._create = AsyncCreate[T](self.client, self.table)
135-
self._read = AsyncRead[T](self.client, self.table)
136-
self._update = AsyncUpdate[T](self.client, self.table)
137-
self._delete = AsyncDelete[T](self.client, self.table)
163+
self._create = AsyncCreate[T](
164+
self.client, self.table, return_constructor)
165+
self._read = AsyncRead[T](
166+
self.client, self.table, return_constructor)
167+
self._update = AsyncUpdate[T](
168+
self.client, self.table, return_constructor)
169+
self._delete = AsyncDelete[T](
170+
self.client, self.table, return_constructor)
138171

139172
@property
140173
def create(self) -> AsyncCreate[T]:

db_wrapper/model/sync_model.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77

88
from db_wrapper.client import SyncClient
99
from .base import (
10+
ensure_exactly_one,
11+
sql,
1012
T,
1113
CreateABC,
1214
DeleteABC,
1315
ReadABC,
1416
UpdateABC,
1517
ModelABC,
16-
ensure_exactly_one,
17-
sql,
1818
)
1919

2020

@@ -150,8 +150,6 @@ class SyncModel(ModelABC[T]):
150150
_update: SyncUpdate[T]
151151
_delete: SyncDelete[T]
152152

153-
# PENDS python 3.9 support in pylint
154-
# pylint: disable=unsubscriptable-object
155153
def __init__(
156154
self,
157155
client: SyncClient,
@@ -162,7 +160,8 @@ def __init__(
162160

163161
self._create = SyncCreate[T](
164162
self.client, self.table, return_constructor)
165-
self._read = SyncRead[T](self.client, self.table, return_constructor)
163+
self._read = SyncRead[T](
164+
self.client, self.table, return_constructor)
166165
self._update = SyncUpdate[T](
167166
self.client, self.table, return_constructor)
168167
self._delete = SyncDelete[T](

0 commit comments

Comments
 (0)