@@ -60,6 +60,7 @@ def __init__(
60
60
user : str ,
61
61
password : str ,
62
62
table_name : Union [Callable [[SinkItem ], str ], str ],
63
+ schema_name : str = "public" ,
63
64
schema_auto_update : bool = True ,
64
65
connection_timeout_seconds : int = 30 ,
65
66
statement_timeout_seconds : int = 30 ,
@@ -77,6 +78,9 @@ def __init__(
77
78
:param password: Database user password.
78
79
:param table_name: PostgreSQL table name as either a string or a callable which
79
80
receives a SinkItem and returns a string.
81
+ :param schema_name: The schema name. Schemas are a way of organizing tables and
82
+ not related to the table data, referenced as `<schema_name>.<table_name>`.
83
+ PostrgeSQL uses "public" by default under the hood.
80
84
:param schema_auto_update: Automatically update the schema when new columns are detected.
81
85
:param connection_timeout_seconds: Timeout for connection.
82
86
:param statement_timeout_seconds: Timeout for DDL operations such as table
@@ -95,6 +99,7 @@ def __init__(
95
99
)
96
100
self ._table_name = _table_name_setter (table_name )
97
101
self ._tables = set ()
102
+ self ._schema_name = schema_name
98
103
self ._schema_auto_update = schema_auto_update
99
104
options = kwargs .pop ("options" , "" )
100
105
if "statement_timeout" not in options :
@@ -113,6 +118,7 @@ def __init__(
113
118
114
119
def setup (self ):
115
120
self ._client = psycopg2 .connect (** self ._client_settings )
121
+ self ._create_schema ()
116
122
117
123
def write (self , batch : SinkBatch ):
118
124
tables = {}
@@ -138,14 +144,22 @@ def write(self, batch: SinkBatch):
138
144
with self ._client :
139
145
for name , values in tables .items ():
140
146
if self ._schema_auto_update :
141
- self ._init_table (name )
147
+ self ._create_table (name )
142
148
self ._add_new_columns (name , values ["cols_types" ])
143
149
self ._insert_rows (name , values ["rows" ])
144
150
except psycopg2 .Error as e :
145
151
self ._client .rollback ()
146
152
raise PostgreSQLSinkException (f"Failed to write batch: { str (e )} " ) from e
147
153
table_counts = {table : len (values ["rows" ]) for table , values in tables .items ()}
148
- logger .info (f"Successfully wrote records to tables; row counts: { table_counts } " )
154
+ schema_log = (
155
+ " "
156
+ if self ._schema_name == "public"
157
+ else f" for schema '{ self ._schema_name } ' "
158
+ )
159
+ logger .info (
160
+ f"Successfully wrote records{ schema_log } to tables; "
161
+ f"table row counts: { table_counts } "
162
+ )
149
163
150
164
def add (
151
165
self ,
@@ -172,7 +186,15 @@ def add(
172
186
offset = offset ,
173
187
)
174
188
175
- def _init_table (self , table_name : str ):
189
+ def _create_schema (self ):
190
+ query = sql .SQL ("CREATE SCHEMA IF NOT EXISTS {}" ).format (
191
+ sql .Identifier (self ._schema_name )
192
+ )
193
+
194
+ with self ._client .cursor () as cursor :
195
+ cursor .execute (query )
196
+
197
+ def _create_table (self , table_name : str ):
176
198
if table_name in self ._tables :
177
199
return
178
200
query = sql .SQL (
@@ -183,7 +205,7 @@ def _init_table(self, table_name: str):
183
205
)
184
206
"""
185
207
).format (
186
- table = sql .Identifier (table_name ),
208
+ table = sql .Identifier (self . _schema_name , table_name ),
187
209
timestamp_col = sql .Identifier (_TIMESTAMP_COLUMN_NAME ),
188
210
key_col = sql .Identifier (_KEY_COLUMN_NAME ),
189
211
)
@@ -205,7 +227,7 @@ def _add_new_columns(self, table_name: str, columns: dict[str, type]) -> None:
205
227
ADD COLUMN IF NOT EXISTS {column} {col_type}
206
228
"""
207
229
).format (
208
- table = sql .Identifier (table_name ),
230
+ table = sql .Identifier (self . _schema_name , table_name ),
209
231
column = sql .Identifier (col_name ),
210
232
col_type = sql .SQL (postgres_col_type ),
211
233
)
@@ -223,7 +245,7 @@ def _insert_rows(self, table_name: str, rows: list[dict]) -> None:
223
245
values = [[row .get (col , None ) for col in columns ] for row in rows ]
224
246
225
247
query = sql .SQL ("INSERT INTO {table} ({columns}) VALUES %s" ).format (
226
- table = sql .Identifier (table_name ),
248
+ table = sql .Identifier (self . _schema_name , table_name ),
227
249
columns = sql .SQL (", " ).join (map (sql .Identifier , columns )),
228
250
)
229
251
0 commit comments