from __future__ import annotations

import logging
- from typing import Literal
+ from typing import TYPE_CHECKING, Literal, get_args

import boto3

from ._connect import _validate_connection
from ._utils import _create_table, _make_s3_auth_string, _upsert

- redshift_connector = _utils.import_optional_dependency("redshift_connector")
+ if TYPE_CHECKING:
+     try:
+         import redshift_connector
+     except ImportError:
+         pass
+ else:
+     redshift_connector = _utils.import_optional_dependency("redshift_connector")

_logger: logging.Logger = logging.getLogger(__name__)

_ToSqlModeLiteral = Literal["append", "overwrite", "upsert"]
_ToSqlOverwriteModeLiteral = Literal["drop", "cascade", "truncate", "delete"]
_ToSqlDistStyleLiteral = Literal["AUTO", "EVEN", "ALL", "KEY"]
_ToSqlSortStyleLiteral = Literal["COMPOUND", "INTERLEAVED"]
+ _CopyFromFilesDataFormatLiteral = Literal["parquet", "orc", "csv"]
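For context, a minimal sketch of how the new alias behaves with `typing.get_args`, which the validation added in `copy_from_files` further down relies on (standalone illustration, not part of the diff):

from typing import Literal, get_args

_CopyFromFilesDataFormatLiteral = Literal["parquet", "orc", "csv"]

# get_args() unpacks the allowed values, so the membership check stays in
# sync with the Literal definition.
print(get_args(_CopyFromFilesDataFormatLiteral))  # ('parquet', 'orc', 'csv')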


def _copy(
-     cursor: "redshift_connector.Cursor",  # type: ignore[name-defined]
+     cursor: "redshift_connector.Cursor",
    path: str,
    table: str,
    serialize_to_json: bool,
+     data_format: _CopyFromFilesDataFormatLiteral = "parquet",
    iam_role: str | None = None,
    aws_access_key_id: str | None = None,
    aws_secret_access_key: str | None = None,
@@ -45,6 +53,11 @@ def _copy(
    else:
        table_name = f'"{schema}"."{table}"'

+     if data_format not in ["parquet", "orc"] and serialize_to_json:
+         raise exceptions.InvalidArgumentCombination(
+             "You can only use SERIALIZETOJSON with data_format='parquet' or 'orc'."
+         )
+
    auth_str: str = _make_s3_auth_string(
        iam_role=iam_role,
        aws_access_key_id=aws_access_key_id,
@@ -54,7 +67,9 @@ def _copy(
    )
    ser_json_str: str = " SERIALIZETOJSON" if serialize_to_json else ""
    column_names_str: str = f"({','.join(column_names)})" if column_names else ""
-     sql = f"COPY {table_name} {column_names_str}\nFROM '{path}' {auth_str}\nFORMAT AS PARQUET{ser_json_str}"
+     sql = (
+         f"COPY {table_name} {column_names_str}\nFROM '{path}' {auth_str}\nFORMAT AS {data_format.upper()}{ser_json_str}"
+     )

    if manifest:
        sql += "\nMANIFEST"
@@ -68,7 +83,7 @@ def _copy(
@apply_configs
def to_sql(
    df: pd.DataFrame,
-     con: "redshift_connector.Connection",  # type: ignore[name-defined]
+     con: "redshift_connector.Connection",
    table: str,
    schema: str,
    mode: _ToSqlModeLiteral = "append",
@@ -240,13 +255,15 @@ def to_sql(
@_utils.check_optional_dependency(redshift_connector, "redshift_connector")
def copy_from_files(  # noqa: PLR0913
    path: str,
-     con: "redshift_connector.Connection",  # type: ignore[name-defined]
+     con: "redshift_connector.Connection",
    table: str,
    schema: str,
    iam_role: str | None = None,
    aws_access_key_id: str | None = None,
    aws_secret_access_key: str | None = None,
    aws_session_token: str | None = None,
+     data_format: _CopyFromFilesDataFormatLiteral = "parquet",
+     redshift_column_types: dict[str, str] | None = None,
    parquet_infer_sampling: float = 1.0,
    mode: _ToSqlModeLiteral = "append",
    overwrite_method: _ToSqlOverwriteModeLiteral = "drop",
@@ -270,16 +287,19 @@ def copy_from_files(  # noqa: PLR0913
    precombine_key: str | None = None,
    column_names: list[str] | None = None,
) -> None:
-     """Load Parquet files from S3 to a Table on Amazon Redshift (Through COPY command).
+     """Load files from S3 to a Table on Amazon Redshift (Through COPY command).

    https://docs.aws.amazon.com/redshift/latest/dg/r_COPY.html

    Note
    ----
    If the table does not exist yet,
    it will be automatically created for you
-     using the Parquet metadata to
+     using the Parquet/ORC/CSV metadata to
    infer the columns data types.
+     If the data is in the CSV format,
+     the Redshift column types need to be
+     specified manually using ``redshift_column_types``.

    Note
    ----
@@ -305,6 +325,15 @@ def copy_from_files(  # noqa: PLR0913
        The secret key for your AWS account.
    aws_session_token : str, optional
        The session key for your AWS account. This is only needed when you are using temporary credentials.
+     data_format: str, optional
+         Data format to be loaded.
+         Supported values are Parquet, ORC, and CSV.
+         Default is Parquet.
+     redshift_column_types: dict, optional
+         Dictionary with keys as column names and values as Redshift column types.
+         Only used when ``data_format`` is CSV.
+
+         e.g. ```{'col1': 'BIGINT', 'col2': 'VARCHAR(256)'}```
    parquet_infer_sampling : float
        Random sample ratio of files that will have the metadata inspected.
        Must be `0.0 < sampling <= 1.0`.
@@ -382,25 +411,30 @@ def copy_from_files(  # noqa: PLR0913
    Examples
    --------
    >>> import awswrangler as wr
-     >>> con = wr.redshift.connect("MY_GLUE_CONNECTION")
-     >>> wr.redshift.copy_from_files(
-     ...     path="s3://bucket/my_parquet_files/",
-     ...     con=con,
-     ...     table="my_table",
-     ...     schema="public",
-     ...     iam_role="arn:aws:iam::XXX:role/XXX"
-     ... )
-     >>> con.close()
+     >>> with wr.redshift.connect("MY_GLUE_CONNECTION") as con:
+     ...     wr.redshift.copy_from_files(
+     ...         path="s3://bucket/my_parquet_files/",
+     ...         con=con,
+     ...         table="my_table",
+     ...         schema="public",
+     ...         iam_role="arn:aws:iam::XXX:role/XXX"
+     ...     )

    """
    _logger.debug("Copying objects from S3 path: %s", path)
+
+     data_format = data_format.lower()  # type: ignore[assignment]
+     if data_format not in get_args(_CopyFromFilesDataFormatLiteral):
+         raise exceptions.InvalidArgumentValue(f"The specified data_format {data_format} is not supported.")
+
    autocommit_temp: bool = con.autocommit
    con.autocommit = False
    try:
        with con.cursor() as cursor:
            created_table, created_schema = _create_table(
                df=None,
                path=path,
+                 data_format=data_format,
                parquet_infer_sampling=parquet_infer_sampling,
                path_suffix=path_suffix,
                path_ignore_suffix=path_ignore_suffix,
@@ -410,6 +444,7 @@ def copy_from_files(  # noqa: PLR0913
                schema=schema,
                mode=mode,
                overwrite_method=overwrite_method,
+                 redshift_column_types=redshift_column_types,
                diststyle=diststyle,
                sortstyle=sortstyle,
                distkey=distkey,
@@ -431,6 +466,7 @@ def copy_from_files(  # noqa: PLR0913
                table=created_table,
                schema=created_schema,
                iam_role=iam_role,
+                 data_format=data_format,
                aws_access_key_id=aws_access_key_id,
                aws_secret_access_key=aws_secret_access_key,
                aws_session_token=aws_session_token,
@@ -467,7 +503,7 @@ def copy_from_files(  # noqa: PLR0913
def copy(  # noqa: PLR0913
    df: pd.DataFrame,
    path: str,
-     con: "redshift_connector.Connection",  # type: ignore[name-defined]
+     con: "redshift_connector.Connection",
    table: str,
    schema: str,
    iam_role: str | None = None,
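A usage sketch of the new CSV path, following the docstring example above (bucket, role ARN, and column types are placeholders); CSV loads need an explicit ``redshift_column_types`` mapping because there is no embedded schema to infer from:

import awswrangler as wr

# Load CSV files into Redshift via COPY, supplying the column types manually.
with wr.redshift.connect("MY_GLUE_CONNECTION") as con:
    wr.redshift.copy_from_files(
        path="s3://bucket/my_csv_files/",
        con=con,
        table="my_table",
        schema="public",
        iam_role="arn:aws:iam::XXX:role/XXX",
        data_format="csv",
        redshift_column_types={"col1": "BIGINT", "col2": "VARCHAR(256)"},
    )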