Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions dev_env/postgres/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@
conn.bulk_insert("employees", [{"name": "rob", "department": "hr"}])
base_query = "SELECT * FROM employees"

logger.info("creating table")
conn.query("CREATE TABLE IF NOT EXISTS vacation ( name VARCHAR(128), days INT )")
logger.info("inserting into vacation")
conn.query("INSERT INTO vacation (name, days) VALUES ('rob', 4)")
logger.info("selecting from vacation")
print(conn.query("SELECT * FROM vacation"))
logger.info("dropping vacation")
conn.query("DROP TABLE vacation")

logger.info("Querying with pagination:")
for i, row in enumerate(conn.query(base_query)):
print(row)
Expand Down
13 changes: 10 additions & 3 deletions src/pyapiary/dbms_connectors/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,13 @@ def query(self, query: str, params=None):
https://www.psycopg.org/psycopg3/docs/api/pool.html#module-psycopg_pool
"""
with self.connection_pool.connection() as conn:
with conn.transaction():
with conn.cursor() as cur:
# claude recommended a transaction wrapper here
return conn.execute(query, params).fetchall()
cur.execute(query, params)
if cur.description:
return cur.fetchall()
else:
return None

def bulk_insert(self, table: str, data: List[Dict[str, Any]]):
if not data:
Expand Down Expand Up @@ -82,7 +86,10 @@ async def async_query(self, query: str, params=None):
"""
async with self.connection_pool.connection() as conn:
cur = await conn.execute(query, params)
return await cur.fetchall()
if cur.description:
return await cur.fetchall()
else:
return None

async def async_bulk_insert(self, table_name: str, data: List[Dict[str, Any]]):
if not data:
Expand Down
116 changes: 73 additions & 43 deletions src/pyapiary/tests/test_postgres/test_unit_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,26 @@

import pytest

# Mock psycopg_pool before importing the module so tests run even when the
# driver is not installed in the environment.
sys.modules.setdefault("psycopg_pool", MagicMock())

from pyapiary.dbms_connectors.postgres import PostgresConnector, AsyncPostgresConnector


# ──────────────────────────────────────────────
# Helpers
# ──────────────────────────────────────────────

def make_sync_conn(cursor):
mock_conn = MagicMock()
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)
return mock_conn

def wire_pool(pg, conn):
pg.connection_pool.connection.return_value.__enter__ = MagicMock(return_value=conn)
pg.connection_pool.connection.return_value.__exit__ = MagicMock(return_value=False)


# ──────────────────────────────────────────────
# Sync PostgresConnector
# ──────────────────────────────────────────────
Expand Down Expand Up @@ -76,28 +89,42 @@ def test_close_when_pool_is_none(self, pg):


class TestPostgresConnectorQuery:
def test_query_returns_results(self, pg):
mock_conn = MagicMock()
mock_conn.execute.return_value.fetchall.return_value = [("row1",), ("row2",)]
pg.connection_pool.connection.return_value.__enter__ = MagicMock(return_value=mock_conn)
pg.connection_pool.connection.return_value.__exit__ = MagicMock(return_value=False)
mock_conn.transaction.return_value.__enter__ = MagicMock()
mock_conn.transaction.return_value.__exit__ = MagicMock(return_value=False)
def test_select_returns_rows(self, pg):
cur = MagicMock()
cur.description = [("col1",)]
cur.fetchall.return_value = [("row1",), ("row2",)]
wire_pool(pg, make_sync_conn(cur))

result = pg.query("SELECT 1")
assert result == [("row1",), ("row2",)]
mock_conn.execute.assert_called_once_with("SELECT 1", None)
cur.execute.assert_called_once_with("SELECT 1", None)

def test_query_passes_params(self, pg):
mock_conn = MagicMock()
mock_conn.execute.return_value.fetchall.return_value = []
pg.connection_pool.connection.return_value.__enter__ = MagicMock(return_value=mock_conn)
pg.connection_pool.connection.return_value.__exit__ = MagicMock(return_value=False)
mock_conn.transaction.return_value.__enter__ = MagicMock()
mock_conn.transaction.return_value.__exit__ = MagicMock(return_value=False)
def test_select_passes_params(self, pg):
cur = MagicMock()
cur.description = [("col1",)]
cur.fetchall.return_value = []
wire_pool(pg, make_sync_conn(cur))

pg.query("SELECT * FROM t WHERE id = %s", (42,))
mock_conn.execute.assert_called_once_with("SELECT * FROM t WHERE id = %s", (42,))
cur.execute.assert_called_once_with("SELECT * FROM t WHERE id = %s", (42,))

def test_non_select_returns_none(self, pg):
cur = MagicMock()
cur.description = None # INSERT/DDL has no description
wire_pool(pg, make_sync_conn(cur))

result = pg.query("INSERT INTO t VALUES (1)")
assert result is None
cur.fetchall.assert_not_called()

def test_select_empty_result_returns_empty_list(self, pg):
cur = MagicMock()
cur.description = [("col1",)]
cur.fetchall.return_value = []
wire_pool(pg, make_sync_conn(cur))

result = pg.query("SELECT 1 WHERE false")
assert result == []


class TestPostgresConnectorBulkInsert:
Expand All @@ -111,12 +138,7 @@ def test_bulk_insert_calls_copy(self, pg):
mock_cursor.copy.return_value.__enter__ = MagicMock(return_value=mock_copy)
mock_cursor.copy.return_value.__exit__ = MagicMock(return_value=False)

mock_conn = MagicMock()
mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = MagicMock(return_value=False)

pg.connection_pool.connection.return_value.__enter__ = MagicMock(return_value=mock_conn)
pg.connection_pool.connection.return_value.__exit__ = MagicMock(return_value=False)
wire_pool(pg, make_sync_conn(mock_cursor))

data = [{"name": "alice", "age": 30}, {"name": "bob", "age": 25}]
pg.bulk_insert("users", data)
Expand Down Expand Up @@ -196,36 +218,44 @@ async def test_aexit_closes_pool(self, async_pg):


class TestAsyncPostgresConnectorQuery:
@pytest.mark.asyncio
async def test_async_query_returns_results(self, async_pg):
mock_cursor = AsyncMock()
mock_cursor.fetchall.return_value = [("row1",)]

def _wire(self, async_pg, cur):
mock_conn = AsyncMock()
mock_conn.execute.return_value = mock_cursor

mock_conn.execute = AsyncMock(return_value=cur)
async_cm = AsyncMock()
async_cm.__aenter__.return_value = mock_conn
async_pg.connection_pool.connection.return_value = async_cm
return mock_conn

@pytest.mark.asyncio
async def test_select_returns_rows(self, async_pg):
cur = AsyncMock()
cur.description = [("col1",)]
cur.fetchall = AsyncMock(return_value=[("row1",)])
conn = self._wire(async_pg, cur)

result = await async_pg.async_query("SELECT 1")
assert result == [("row1",)]
mock_conn.execute.assert_awaited_once_with("SELECT 1", None)
conn.execute.assert_awaited_once_with("SELECT 1", None)

@pytest.mark.asyncio
async def test_async_query_passes_params(self, async_pg):
mock_cursor = AsyncMock()
mock_cursor.fetchall.return_value = []
async def test_select_passes_params(self, async_pg):
cur = AsyncMock()
cur.description = [("col1",)]
cur.fetchall = AsyncMock(return_value=[])
conn = self._wire(async_pg, cur)

mock_conn = AsyncMock()
mock_conn.execute.return_value = mock_cursor
await async_pg.async_query("SELECT * FROM t WHERE id = %s", (1,))
conn.execute.assert_awaited_once_with("SELECT * FROM t WHERE id = %s", (1,))

async_cm = AsyncMock()
async_cm.__aenter__.return_value = mock_conn
async_pg.connection_pool.connection.return_value = async_cm
@pytest.mark.asyncio
async def test_non_select_returns_none(self, async_pg):
cur = AsyncMock()
cur.description = None
conn = self._wire(async_pg, cur)

await async_pg.async_query("SELECT * FROM t WHERE id = %s", (1,))
mock_conn.execute.assert_awaited_once_with("SELECT * FROM t WHERE id = %s", (1,))
result = await async_pg.async_query("INSERT INTO t VALUES (1)")
assert result is None
cur.fetchall.assert_not_called()


class TestAsyncPostgresConnectorBulkInsert:
Expand Down Expand Up @@ -256,4 +286,4 @@ async def test_async_bulk_insert_calls_copy(self, async_pg):
mock_cursor.copy.assert_called_once_with("COPY users (name, age) FROM STDIN")
assert mock_copy.write_row.await_count == 2
mock_copy.write_row.assert_any_await(("alice", 30))
mock_copy.write_row.assert_any_await(("bob", 25))
mock_copy.write_row.assert_any_await(("bob", 25))
3 changes: 3 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading