
Commit e491295

feat: Support ORC and CSV in redshift.copy_from_files function (#2849)
1 parent: 626a321

3 files changed (+163 −47 lines)
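To make the headline change concrete, here is a hedged usage sketch of the new `data_format` and `redshift_column_types` parameters on `wr.redshift.copy_from_files`; the Glue connection name, S3 path, IAM role, and column types below are illustrative placeholders, not values taken from this commit:

>>> import awswrangler as wr
>>> with wr.redshift.connect("MY_GLUE_CONNECTION") as con:
...     wr.redshift.copy_from_files(
...         path="s3://bucket/my_csv_files/",
...         con=con,
...         table="my_table",
...         schema="public",
...         iam_role="arn:aws:iam::XXX:role/XXX",
...         data_format="csv",  # new in this commit: "parquet" (default), "orc" or "csv"
...         redshift_column_types={"col1": "BIGINT", "col2": "VARCHAR(256)"},  # required for CSV loads
...     )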

awswrangler/redshift/_utils.py

+54 −24
@@ -6,14 +6,21 @@
 import json
 import logging
 import uuid
+from typing import TYPE_CHECKING, Literal
 
 import boto3
 import botocore
 import pandas as pd
 
 from awswrangler import _data_types, _sql_utils, _utils, exceptions, s3
 
-redshift_connector = _utils.import_optional_dependency("redshift_connector")
+if TYPE_CHECKING:
+    try:
+        import redshift_connector
+    except ImportError:
+        pass
+else:
+    redshift_connector = _utils.import_optional_dependency("redshift_connector")
 
 _logger: logging.Logger = logging.getLogger(__name__)
 
@@ -217,6 +224,7 @@ def _validate_parameters(
 
 def _redshift_types_from_path(
     path: str | list[str],
+    data_format: Literal["parquet", "orc"],
     varchar_lengths_default: int,
     varchar_lengths: dict[str, int] | None,
     parquet_infer_sampling: float,
@@ -229,16 +237,27 @@ def _redshift_types_from_path(
     """Extract Redshift data types from a Pandas DataFrame."""
     _varchar_lengths: dict[str, int] = {} if varchar_lengths is None else varchar_lengths
     _logger.debug("Scanning parquet schemas in S3 path: %s", path)
-    athena_types, _ = s3.read_parquet_metadata(
-        path=path,
-        sampling=parquet_infer_sampling,
-        path_suffix=path_suffix,
-        path_ignore_suffix=path_ignore_suffix,
-        dataset=False,
-        use_threads=use_threads,
-        boto3_session=boto3_session,
-        s3_additional_kwargs=s3_additional_kwargs,
-    )
+    if data_format == "orc":
+        athena_types, _ = s3.read_orc_metadata(
+            path=path,
+            path_suffix=path_suffix,
+            path_ignore_suffix=path_ignore_suffix,
+            dataset=False,
+            use_threads=use_threads,
+            boto3_session=boto3_session,
+            s3_additional_kwargs=s3_additional_kwargs,
+        )
+    else:
+        athena_types, _ = s3.read_parquet_metadata(
+            path=path,
+            sampling=parquet_infer_sampling,
+            path_suffix=path_suffix,
+            path_ignore_suffix=path_ignore_suffix,
+            dataset=False,
+            use_threads=use_threads,
+            boto3_session=boto3_session,
+            s3_additional_kwargs=s3_additional_kwargs,
+        )
     _logger.debug("Parquet metadata types: %s", athena_types)
     redshift_types: dict[str, str] = {}
     for col_name, col_type in athena_types.items():
@@ -248,7 +267,7 @@ def _redshift_types_from_path(
     return redshift_types
 
 
-def _create_table(  # noqa: PLR0912,PLR0915
+def _create_table(  # noqa: PLR0912,PLR0913,PLR0915
     df: pd.DataFrame | None,
     path: str | list[str] | None,
     con: "redshift_connector.Connection",
@@ -266,6 +285,8 @@ def _create_table(  # noqa: PLR0912,PLR0915
     primary_keys: list[str] | None,
     varchar_lengths_default: int,
     varchar_lengths: dict[str, int] | None,
+    data_format: Literal["parquet", "orc", "csv"] = "parquet",
+    redshift_column_types: dict[str, str] | None = None,
     parquet_infer_sampling: float = 1.0,
     path_suffix: str | None = None,
     path_ignore_suffix: str | list[str] | None = None,
@@ -336,19 +357,28 @@ def _create_table(  # noqa: PLR0912,PLR0915
                 path=path,
                 boto3_session=boto3_session,
             )
-        redshift_types = _redshift_types_from_path(
-            path=path,
-            varchar_lengths_default=varchar_lengths_default,
-            varchar_lengths=varchar_lengths,
-            parquet_infer_sampling=parquet_infer_sampling,
-            path_suffix=path_suffix,
-            path_ignore_suffix=path_ignore_suffix,
-            use_threads=use_threads,
-            boto3_session=boto3_session,
-            s3_additional_kwargs=s3_additional_kwargs,
-        )
+
+        if data_format in ["parquet", "orc"]:
+            redshift_types = _redshift_types_from_path(
+                path=path,
+                data_format=data_format,  # type: ignore[arg-type]
+                varchar_lengths_default=varchar_lengths_default,
+                varchar_lengths=varchar_lengths,
+                parquet_infer_sampling=parquet_infer_sampling,
+                path_suffix=path_suffix,
+                path_ignore_suffix=path_ignore_suffix,
+                use_threads=use_threads,
+                boto3_session=boto3_session,
+                s3_additional_kwargs=s3_additional_kwargs,
+            )
+        else:
+            if redshift_column_types is None:
+                raise ValueError(
+                    "redshift_column_types is None. It must be specified for files formats other than Parquet or ORC."
+                )
+            redshift_types = redshift_column_types
     else:
-        raise ValueError("df and path are None.You MUST pass at least one.")
+        raise ValueError("df and path are None. You MUST pass at least one.")
     _validate_parameters(
         redshift_types=redshift_types,
         diststyle=diststyle,
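As a side note on the ORC branch added above: the type inference goes through `s3.read_orc_metadata`, which is also exposed publicly, so the Athena-style types it would feed into the Redshift mapping can be previewed directly. A minimal sketch (the S3 path and the shown result are illustrative, not from this commit):

>>> import awswrangler as wr
>>> columns_types, _ = wr.s3.read_orc_metadata(path="s3://bucket/my_orc_files/", dataset=False)
>>> columns_types  # e.g. {'col1': 'bigint', 'col2': 'string'}; the helper then maps these to Redshift types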

awswrangler/redshift/_write.py

+54 −18
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 import logging
-from typing import Literal
+from typing import TYPE_CHECKING, Literal, get_args
 
 import boto3
 
@@ -15,21 +15,29 @@
 from ._connect import _validate_connection
 from ._utils import _create_table, _make_s3_auth_string, _upsert
 
-redshift_connector = _utils.import_optional_dependency("redshift_connector")
+if TYPE_CHECKING:
+    try:
+        import redshift_connector
+    except ImportError:
+        pass
+else:
+    redshift_connector = _utils.import_optional_dependency("redshift_connector")
 
 _logger: logging.Logger = logging.getLogger(__name__)
 
 _ToSqlModeLiteral = Literal["append", "overwrite", "upsert"]
 _ToSqlOverwriteModeLiteral = Literal["drop", "cascade", "truncate", "delete"]
 _ToSqlDistStyleLiteral = Literal["AUTO", "EVEN", "ALL", "KEY"]
 _ToSqlSortStyleLiteral = Literal["COMPOUND", "INTERLEAVED"]
+_CopyFromFilesDataFormatLiteral = Literal["parquet", "orc", "csv"]
 
 
 def _copy(
-    cursor: "redshift_connector.Cursor",  # type: ignore[name-defined]
+    cursor: "redshift_connector.Cursor",
     path: str,
     table: str,
     serialize_to_json: bool,
+    data_format: _CopyFromFilesDataFormatLiteral = "parquet",
    iam_role: str | None = None,
    aws_access_key_id: str | None = None,
    aws_secret_access_key: str | None = None,
@@ -45,6 +53,11 @@ def _copy(
     else:
         table_name = f'"{schema}"."{table}"'
 
+    if data_format not in ["parquet", "orc"] and serialize_to_json:
+        raise exceptions.InvalidArgumentCombination(
+            "You can only use SERIALIZETOJSON with data_format='parquet' or 'orc'."
+        )
+
     auth_str: str = _make_s3_auth_string(
         iam_role=iam_role,
         aws_access_key_id=aws_access_key_id,
@@ -54,7 +67,9 @@ def _copy(
     )
     ser_json_str: str = " SERIALIZETOJSON" if serialize_to_json else ""
     column_names_str: str = f"({','.join(column_names)})" if column_names else ""
-    sql = f"COPY {table_name} {column_names_str}\nFROM '{path}' {auth_str}\nFORMAT AS PARQUET{ser_json_str}"
+    sql = (
+        f"COPY {table_name} {column_names_str}\nFROM '{path}' {auth_str}\nFORMAT AS {data_format.upper()}{ser_json_str}"
+    )
 
     if manifest:
         sql += "\nMANIFEST"
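For illustration, a small self-contained sketch of how the statement assembled in `_copy` comes out for an ORC load. The table, path, and especially the auth clause are assumptions (the exact output of `_make_s3_auth_string` is not shown in this diff):

# Sketch only -- mirrors the f-string above with placeholder values.
table_name = '"public"."my_table"'
column_names_str = ""  # no explicit column list
path = "s3://bucket/my_orc_files/"
auth_str = "IAM_ROLE 'arn:aws:iam::XXX:role/XXX'"  # assumed shape of the _make_s3_auth_string output
ser_json_str = ""  # SERIALIZETOJSON is only allowed for parquet/orc; unused here
data_format = "orc"
sql = f"COPY {table_name} {column_names_str}\nFROM '{path}' {auth_str}\nFORMAT AS {data_format.upper()}{ser_json_str}"
print(sql)
# COPY "public"."my_table"
# FROM 's3://bucket/my_orc_files/' IAM_ROLE 'arn:aws:iam::XXX:role/XXX'
# FORMAT AS ORC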
@@ -68,7 +83,7 @@ def _copy(
 @apply_configs
 def to_sql(
     df: pd.DataFrame,
-    con: "redshift_connector.Connection",  # type: ignore[name-defined]
+    con: "redshift_connector.Connection",
     table: str,
     schema: str,
     mode: _ToSqlModeLiteral = "append",
@@ -240,13 +255,15 @@ def to_sql(
 @_utils.check_optional_dependency(redshift_connector, "redshift_connector")
 def copy_from_files(  # noqa: PLR0913
     path: str,
-    con: "redshift_connector.Connection",  # type: ignore[name-defined]
+    con: "redshift_connector.Connection",
     table: str,
     schema: str,
     iam_role: str | None = None,
     aws_access_key_id: str | None = None,
     aws_secret_access_key: str | None = None,
     aws_session_token: str | None = None,
+    data_format: _CopyFromFilesDataFormatLiteral = "parquet",
+    redshift_column_types: dict[str, str] | None = None,
     parquet_infer_sampling: float = 1.0,
     mode: _ToSqlModeLiteral = "append",
     overwrite_method: _ToSqlOverwriteModeLiteral = "drop",
@@ -270,16 +287,19 @@ def copy_from_files(  # noqa: PLR0913
     precombine_key: str | None = None,
     column_names: list[str] | None = None,
 ) -> None:
-    """Load Parquet files from S3 to a Table on Amazon Redshift (Through COPY command).
+    """Load files from S3 to a Table on Amazon Redshift (Through COPY command).
 
     https://docs.aws.amazon.com/redshift/latest/dg/r_COPY.html
 
     Note
     ----
     If the table does not exist yet,
     it will be automatically created for you
-    using the Parquet metadata to
+    using the Parquet/ORC/CSV metadata to
     infer the columns data types.
+    If the data is in the CSV format,
+    the Redshift column types need to be
+    specified manually using ``redshift_column_types``.
 
     Note
     ----
@@ -305,6 +325,15 @@ def copy_from_files(  # noqa: PLR0913
         The secret key for your AWS account.
     aws_session_token : str, optional
         The session key for your AWS account. This is only needed when you are using temporary credentials.
+    data_format: str, optional
+        Data format to be loaded.
+        Supported values are Parquet, ORC, and CSV.
+        Default is Parquet.
+    redshift_column_types: dict, optional
+        Dictionary with keys as column names and values as Redshift column types.
+        Only used when ``data_format`` is CSV.
+
+        e.g. ```{'col1': 'BIGINT', 'col2': 'VARCHAR(256)'}```
     parquet_infer_sampling : float
         Random sample ratio of files that will have the metadata inspected.
         Must be `0.0 < sampling <= 1.0`.
@@ -382,25 +411,30 @@ def copy_from_files(  # noqa: PLR0913
     Examples
     --------
     >>> import awswrangler as wr
-    >>> con = wr.redshift.connect("MY_GLUE_CONNECTION")
-    >>> wr.redshift.copy_from_files(
-    ...     path="s3://bucket/my_parquet_files/",
-    ...     con=con,
-    ...     table="my_table",
-    ...     schema="public",
-    ...     iam_role="arn:aws:iam::XXX:role/XXX"
-    ... )
-    >>> con.close()
+    >>> with wr.redshift.connect("MY_GLUE_CONNECTION") as con:
+    ...     wr.redshift.copy_from_files(
+    ...         path="s3://bucket/my_parquet_files/",
+    ...         con=con,
+    ...         table="my_table",
+    ...         schema="public",
+    ...         iam_role="arn:aws:iam::XXX:role/XXX"
+    ...     )
 
     """
     _logger.debug("Copying objects from S3 path: %s", path)
+
+    data_format = data_format.lower()  # type: ignore[assignment]
+    if data_format not in get_args(_CopyFromFilesDataFormatLiteral):
+        raise exceptions.InvalidArgumentValue(f"The specified data_format {data_format} is not supported.")
+
     autocommit_temp: bool = con.autocommit
     con.autocommit = False
     try:
         with con.cursor() as cursor:
             created_table, created_schema = _create_table(
                 df=None,
                 path=path,
+                data_format=data_format,
                 parquet_infer_sampling=parquet_infer_sampling,
                 path_suffix=path_suffix,
                 path_ignore_suffix=path_ignore_suffix,
@@ -410,6 +444,7 @@ def copy_from_files(  # noqa: PLR0913
                 schema=schema,
                 mode=mode,
                 overwrite_method=overwrite_method,
+                redshift_column_types=redshift_column_types,
                 diststyle=diststyle,
                 sortstyle=sortstyle,
                 distkey=distkey,
@@ -431,6 +466,7 @@ def copy_from_files(  # noqa: PLR0913
                 table=created_table,
                 schema=created_schema,
                 iam_role=iam_role,
+                data_format=data_format,
                 aws_access_key_id=aws_access_key_id,
                 aws_secret_access_key=aws_secret_access_key,
                 aws_session_token=aws_session_token,
@@ -467,7 +503,7 @@ def copy_from_files(  # noqa: PLR0913
 def copy(  # noqa: PLR0913
     df: pd.DataFrame,
     path: str,
-    con: "redshift_connector.Connection",  # type: ignore[name-defined]
+    con: "redshift_connector.Connection",
     table: str,
     schema: str,
     iam_role: str | None = None,
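Finally, a companion sketch for an ORC load in the same style as the updated docstring example above (connection name, path, and role ARN are placeholders); no `redshift_column_types` is needed here because the column types are inferred from the ORC metadata:

>>> import awswrangler as wr
>>> with wr.redshift.connect("MY_GLUE_CONNECTION") as con:
...     wr.redshift.copy_from_files(
...         path="s3://bucket/my_orc_files/",
...         con=con,
...         table="my_table",
...         schema="public",
...         iam_role="arn:aws:iam::XXX:role/XXX",
...         data_format="orc",
...     )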
