Skip to content

Commit 3183c5e

Browse files
committed
add schema option
1 parent 3336e7c commit 3183c5e

File tree

2 files changed

+39
-7
lines changed

2 files changed

+39
-7
lines changed

docs/connectors/sinks/postgresql-sink.md

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,22 @@ PostgreSQLSink provides at-least-once guarantees, meaning that the same records
7575

7676
PostgreSQLSink accepts the following configuration parameters:
7777

78+
### Required
79+
7880
- `host`: The address of the PostgreSQL server.
7981
- `port`: The port of the PostgreSQL server.
8082
- `dbname`: The name of the PostgreSQL database.
8183
- `user`: The database user name.
8284
- `password`: The database user password.
83-
- `table_name`: The name of the PostgreSQL table where data will be written.
85+
- `table_name`: PostgreSQL table name as either a string or a callable which receives
86+
a `SinkItem` (from quixstreams.sinks.base.item) and returns a string.
87+
88+
89+
### Optional
90+
91+
- `schema_name`: The schema name. Schemas are a way of organizing tables and
92+
are not related to the table data; tables are referenced as `<schema_name>.<table_name>`.
93+
PostgreSQL uses "public" by default under the hood.
8494
- `schema_auto_update`: If True, the sink will automatically update the schema by adding new columns when new fields are detected. Default: True.
8595

8696

quixstreams/sinks/community/postgresql.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def __init__(
6060
user: str,
6161
password: str,
6262
table_name: Union[Callable[[SinkItem], str], str],
63+
schema_name: str = "public",
6364
schema_auto_update: bool = True,
6465
connection_timeout_seconds: int = 30,
6566
statement_timeout_seconds: int = 30,
@@ -77,6 +78,9 @@ def __init__(
7778
:param password: Database user password.
7879
:param table_name: PostgreSQL table name as either a string or a callable which
7980
receives a SinkItem and returns a string.
81+
:param schema_name: The schema name. Schemas are a way of organizing tables and
82+
are not related to the table data; tables are referenced as `<schema_name>.<table_name>`.
83+
PostgreSQL uses "public" by default under the hood.
8084
:param schema_auto_update: Automatically update the schema when new columns are detected.
8185
:param connection_timeout_seconds: Timeout for connection.
8286
:param statement_timeout_seconds: Timeout for DDL operations such as table
@@ -95,6 +99,7 @@ def __init__(
9599
)
96100
self._table_name = _table_name_setter(table_name)
97101
self._tables = set()
102+
self._schema_name = schema_name
98103
self._schema_auto_update = schema_auto_update
99104
options = kwargs.pop("options", "")
100105
if "statement_timeout" not in options:
@@ -113,6 +118,7 @@ def __init__(
113118

114119
def setup(self):
115120
self._client = psycopg2.connect(**self._client_settings)
121+
self._create_schema()
116122

117123
def write(self, batch: SinkBatch):
118124
tables = {}
@@ -138,14 +144,22 @@ def write(self, batch: SinkBatch):
138144
with self._client:
139145
for name, values in tables.items():
140146
if self._schema_auto_update:
141-
self._init_table(name)
147+
self._create_table(name)
142148
self._add_new_columns(name, values["cols_types"])
143149
self._insert_rows(name, values["rows"])
144150
except psycopg2.Error as e:
145151
self._client.rollback()
146152
raise PostgreSQLSinkException(f"Failed to write batch: {str(e)}") from e
147153
table_counts = {table: len(values["rows"]) for table, values in tables.items()}
148-
logger.info(f"Successfully wrote records to tables; row counts: {table_counts}")
154+
schema_log = (
155+
" "
156+
if self._schema_name == "public"
157+
else f" for schema '{self._schema_name}' "
158+
)
159+
logger.info(
160+
f"Successfully wrote records{schema_log}to tables; "
161+
f"table row counts: {table_counts}"
162+
)
149163

150164
def add(
151165
self,
@@ -172,7 +186,15 @@ def add(
172186
offset=offset,
173187
)
174188

175-
def _init_table(self, table_name: str):
189+
def _create_schema(self):
190+
query = sql.SQL("CREATE SCHEMA IF NOT EXISTS {}").format(
191+
sql.Identifier(self._schema_name)
192+
)
193+
194+
with self._client.cursor() as cursor:
195+
cursor.execute(query)
196+
197+
def _create_table(self, table_name: str):
176198
if table_name in self._tables:
177199
return
178200
query = sql.SQL(
@@ -183,7 +205,7 @@ def _init_table(self, table_name: str):
183205
)
184206
"""
185207
).format(
186-
table=sql.Identifier(table_name),
208+
table=sql.Identifier(self._schema_name, table_name),
187209
timestamp_col=sql.Identifier(_TIMESTAMP_COLUMN_NAME),
188210
key_col=sql.Identifier(_KEY_COLUMN_NAME),
189211
)
@@ -205,7 +227,7 @@ def _add_new_columns(self, table_name: str, columns: dict[str, type]) -> None:
205227
ADD COLUMN IF NOT EXISTS {column} {col_type}
206228
"""
207229
).format(
208-
table=sql.Identifier(table_name),
230+
table=sql.Identifier(self._schema_name, table_name),
209231
column=sql.Identifier(col_name),
210232
col_type=sql.SQL(postgres_col_type),
211233
)
@@ -223,7 +245,7 @@ def _insert_rows(self, table_name: str, rows: list[dict]) -> None:
223245
values = [[row.get(col, None) for col in columns] for row in rows]
224246

225247
query = sql.SQL("INSERT INTO {table} ({columns}) VALUES %s").format(
226-
table=sql.Identifier(table_name),
248+
table=sql.Identifier(self._schema_name, table_name),
227249
columns=sql.SQL(", ").join(map(sql.Identifier, columns)),
228250
)
229251

0 commit comments

Comments
 (0)