Skip to content

Commit 51ce25a

Browse files
authored
feat: Add CleanRooms read module (#2366)
1 parent e8291b4 commit 51ce25a

16 files changed

+654
-60
lines changed

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,12 @@ building/lambda/arrow
153153
*.swp
154154

155155
# CDK
156+
node_modules
157+
*package.json
156158
*package-lock.json
157159
*.cdk.staging
158160
*cdk.out
161+
*cdk.context.json
159162

160163
# ruff
161164
.ruff_cache/

awswrangler/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
athena,
1212
catalog,
1313
chime,
14+
cleanrooms,
1415
cloudwatch,
1516
data_api,
1617
data_quality,
@@ -43,6 +44,7 @@
4344
"athena",
4445
"catalog",
4546
"chime",
47+
"cleanrooms",
4648
"cloudwatch",
4749
"emr",
4850
"emr_serverless",

awswrangler/_utils.py

+12
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from boto3.resources.base import ServiceResource
4646
from botocore.client import BaseClient
4747
from mypy_boto3_athena import AthenaClient
48+
from mypy_boto3_cleanrooms import CleanRoomsServiceClient
4849
from mypy_boto3_dynamodb import DynamoDBClient, DynamoDBServiceResource
4950
from mypy_boto3_ec2 import EC2Client
5051
from mypy_boto3_emr.client import EMRClient
@@ -68,6 +69,7 @@
6869

6970
ServiceName = Literal[
7071
"athena",
72+
"cleanrooms",
7173
"dynamodb",
7274
"ec2",
7375
"emr",
@@ -286,6 +288,16 @@ def client(
286288
...
287289

288290

291+
@overload
292+
def client(
293+
service_name: 'Literal["cleanrooms"]',
294+
session: Optional[boto3.Session] = None,
295+
botocore_config: Optional[Config] = None,
296+
verify: Optional[Union[str, bool]] = None,
297+
) -> "CleanRoomsServiceClient":
298+
...
299+
300+
289301
@overload
290302
def client(
291303
service_name: 'Literal["lakeformation"]',

awswrangler/cleanrooms/__init__.py

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""Amazon Clean Rooms Module."""
2+
3+
from awswrangler.cleanrooms._read import read_sql_query
4+
from awswrangler.cleanrooms._utils import wait_query
5+
6+
__all__ = [
7+
"read_sql_query",
8+
"wait_query",
9+
]

awswrangler/cleanrooms/_read.py

+128
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
"""Amazon Clean Rooms Module hosting read_* functions."""
2+
3+
import logging
4+
from typing import Any, Dict, Iterator, Optional, Union
5+
6+
import boto3
7+
8+
import awswrangler.pandas as pd
9+
from awswrangler import _utils, s3
10+
from awswrangler._sql_formatter import _process_sql_params
11+
from awswrangler.cleanrooms._utils import wait_query
12+
13+
_logger: logging.Logger = logging.getLogger(__name__)
14+
15+
16+
def _delete_after_iterate(
17+
dfs: Iterator[pd.DataFrame], keep_files: bool, kwargs: Dict[str, Any]
18+
) -> Iterator[pd.DataFrame]:
19+
for df in dfs:
20+
yield df
21+
if keep_files is False:
22+
s3.delete_objects(**kwargs)
23+
24+
25+
def read_sql_query(
    sql: str,
    membership_id: str,
    output_bucket: str,
    output_prefix: str,
    keep_files: bool = True,
    params: Optional[Dict[str, Any]] = None,
    chunksize: Optional[Union[int, bool]] = None,
    use_threads: Union[bool, int] = True,
    boto3_session: Optional[boto3.Session] = None,
    pyarrow_additional_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[Iterator[pd.DataFrame], pd.DataFrame]:
    """Execute Clean Rooms Protected SQL query and return the results as a Pandas DataFrame.

    Parameters
    ----------
    sql : str
        SQL query
    membership_id : str
        Membership ID
    output_bucket : str
        S3 output bucket name
    output_prefix : str
        S3 output prefix
    keep_files : bool, optional
        Whether files in S3 output bucket/prefix are retained. 'True' by default
    params : Dict[str, any], optional
        Dict of parameters used for constructing the SQL query. Only named parameters are supported.
        The dict must be in the form {'name': 'value'} and the SQL query must contain
        `:name`. Note that for varchar columns and similar, you must surround the value in single quotes
    chunksize : Union[int, bool], optional
        If passed, the data is split into an iterable of DataFrames (Memory friendly).
        If `True` an iterable of DataFrames is returned without guarantee of chunksize.
        If an `INTEGER` is passed, an iterable of DataFrames is returned with maximum rows
        equal to the received INTEGER
    use_threads : Union[bool, int], optional
        True to enable concurrent requests, False to disable multiple threads.
        If enabled os.cpu_count() is used as the maximum number of threads.
        If integer is provided, specified number is used
    boto3_session : boto3.Session, optional
        Boto3 Session. If None, the default boto3 session is used
    pyarrow_additional_kwargs : Optional[Dict[str, Any]]
        Forwarded to `to_pandas` method converting from PyArrow tables to Pandas DataFrame.
        Valid values include "split_blocks", "self_destruct", "ignore_metadata".
        e.g. pyarrow_additional_kwargs={'split_blocks': True}

    Returns
    -------
    Union[Iterator[pd.DataFrame], pd.DataFrame]
        Pandas DataFrame or Generator of Pandas DataFrames if chunksize is provided.

    Examples
    --------
    >>> import awswrangler as wr
    >>> df = wr.cleanrooms.read_sql_query(
    >>>     sql='SELECT DISTINCT...',
    >>>     membership_id='membership-id',
    >>>     output_bucket='output-bucket',
    >>>     output_prefix='output-prefix',
    >>> )
    """
    client_cleanrooms = _utils.client(service_name="cleanrooms", session=boto3_session)

    # Submit the protected query; Clean Rooms writes the results as Parquet to the
    # caller-supplied S3 bucket/prefix.
    query_id: str = client_cleanrooms.start_protected_query(
        type="SQL",
        membershipIdentifier=membership_id,
        sqlParameters={"queryString": _process_sql_params(sql, params, engine_type="partiql")},
        resultConfiguration={
            "outputConfiguration": {
                "s3": {
                    "bucket": output_bucket,
                    "keyPrefix": output_prefix,
                    "resultFormat": "PARQUET",
                }
            }
        },
    )["protectedQuery"]["id"]
    _logger.debug("query_id: %s", query_id)

    # Bug fix: forward boto3_session so status polling uses the same session/credentials
    # as query submission (previously wait_query always fell back to the default session,
    # which breaks callers using assumed-role or otherwise non-default sessions).
    path: str = wait_query(membership_id=membership_id, query_id=query_id, boto3_session=boto3_session)[
        "protectedQuery"
    ]["result"]["output"]["s3"]["location"]
    _logger.debug("path: %s", path)

    # s3.read_parquet expects False for "no chunking" rather than None.
    chunked: Union[bool, int] = False if chunksize is None else chunksize
    ret = s3.read_parquet(
        path=path,
        use_threads=use_threads,
        chunked=chunked,
        boto3_session=boto3_session,
        pyarrow_additional_kwargs=pyarrow_additional_kwargs,
    )
    _logger.debug("type(ret): %s", type(ret))

    kwargs: Dict[str, Any] = {
        "path": path,
        "use_threads": use_threads,
        "boto3_session": boto3_session,
    }
    if chunked is False:
        # Non-chunked: the whole result is already in memory, so clean up immediately.
        if keep_files is False:
            s3.delete_objects(**kwargs)
        return ret
    # Chunked: defer the S3 cleanup until the iterator has been fully consumed.
    return _delete_after_iterate(ret, keep_files, kwargs)

awswrangler/cleanrooms/_utils.py

+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
"""Utilities Module for Amazon Clean Rooms."""
2+
import logging
3+
import time
4+
from typing import TYPE_CHECKING, List, Optional
5+
6+
import boto3
7+
8+
from awswrangler import _utils, exceptions
9+
10+
if TYPE_CHECKING:
11+
from mypy_boto3_cleanrooms.type_defs import GetProtectedQueryOutputTypeDef
12+
13+
# Terminal states of a Clean Rooms protected query; polling stops on any of these.
_QUERY_FINAL_STATES: List[str] = ["CANCELLED", "FAILED", "SUCCESS", "TIMED_OUT"]
# Interval between successive get_protected_query status checks.
_QUERY_WAIT_POLLING_DELAY: float = 2  # SECONDS

_logger: logging.Logger = logging.getLogger(__name__)
17+
18+
19+
def wait_query(
    membership_id: str, query_id: str, boto3_session: Optional[boto3.Session] = None
) -> "GetProtectedQueryOutputTypeDef":
    """Wait for the Clean Rooms protected query to end.

    Parameters
    ----------
    membership_id : str
        Membership ID
    query_id : str
        Protected query execution ID
    boto3_session : boto3.Session, optional
        Boto3 Session. If None, the default boto3 session is used
    Returns
    -------
    Dict[str, Any]
        Dictionary with the get_protected_query response.

    Raises
    ------
    exceptions.QueryFailed
        Raises exception with error message if protected query is cancelled, times out or fails.

    Examples
    --------
    >>> import awswrangler as wr
    >>> res = wr.cleanrooms.wait_query(membership_id='membership-id', query_id='query-id')
    """
    client_cleanrooms = _utils.client(service_name="cleanrooms", session=boto3_session)
    state = "SUBMITTED"

    # Poll until the query reaches a terminal state (see _QUERY_FINAL_STATES).
    while state not in _QUERY_FINAL_STATES:
        time.sleep(_QUERY_WAIT_POLLING_DELAY)
        response = client_cleanrooms.get_protected_query(
            membershipIdentifier=membership_id, protectedQueryIdentifier=query_id
        )
        state = response["protectedQuery"].get("status")  # type: ignore[assignment]

    _logger.debug("state: %s", state)
    if state != "SUCCESS":
        # Bug fix: the GetProtectedQuery response exposes failure details under the
        # lowercase "error" member (ProtectedQueryError with "code"/"message");
        # .get("Error") always returned None, so QueryFailed carried no message.
        raise exceptions.QueryFailed(response["protectedQuery"].get("error"))
    return response

docs/source/api.rst

+12
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ API Reference
1717
* `Amazon Neptune`_
1818
* `DynamoDB`_
1919
* `Amazon Timestream`_
20+
* `AWS Clean Rooms`_
2021
* `Amazon EMR`_
2122
* `Amazon EMR Serverless`_
2223
* `Amazon CloudWatch Logs`_
@@ -351,6 +352,17 @@ Amazon Timestream
351352
unload_to_files
352353
unload
353354

355+
AWS Clean Rooms
356+
-----------------
357+
358+
.. currentmodule:: awswrangler.cleanrooms
359+
360+
.. autosummary::
361+
:toctree: stubs
362+
363+
read_sql_query
364+
wait_query
365+
354366
Amazon EMR
355367
----------
356368

poetry.lock

+19-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ wheel = "^0.38.1"
8585

8686
# Lint
8787
black = "^23.1.0"
88-
boto3-stubs = {version = "1.26.151", extras = ["athena", "chime", "cloudwatch", "dynamodb", "ec2", "emr", "emr-serverless", "glue", "kms", "lakeformation", "logs", "neptune", "opensearch", "opensearchserverless", "quicksight", "rds", "rds-data", "redshift", "redshift-data", "s3", "secretsmanager", "ssm", "sts", "timestream-query", "timestream-write"]}
88+
boto3-stubs = {version = "^1.26.151", extras = ["athena", "cleanrooms", "chime", "cloudwatch", "dynamodb", "ec2", "emr", "emr-serverless", "glue", "kms", "lakeformation", "logs", "neptune", "opensearch", "opensearchserverless", "quicksight", "rds", "rds-data", "redshift", "redshift-data", "s3", "secretsmanager", "ssm", "sts", "timestream-query", "timestream-write"]}
8989
doc8 = "^1.0"
9090
mypy = "^1.0"
9191
pylint = "^2.17"

test_infra/app.py

+7
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from aws_cdk import App, Environment
55
from stacks.base_stack import BaseStack
6+
from stacks.cleanrooms_stack import CleanRoomsStack
67
from stacks.databases_stack import DatabasesStack
78
from stacks.glueray_stack import GlueRayStack
89
from stacks.opensearch_stack import OpenSearchStack
@@ -42,4 +43,10 @@
4243
**env,
4344
)
4445

46+
CleanRoomsStack(
47+
app,
48+
"aws-sdk-pandas-cleanrooms",
49+
**env,
50+
)
51+
4552
app.synth()

0 commit comments

Comments
 (0)