Commit e515b85

Commit message: "done"
1 parent ecaec3d commit e515b85

13 files changed: +564 −29 lines changed

common/utils/src/main/resources/error/error-conditions.json

Lines changed: 6 additions & 0 deletions
@@ -6261,6 +6261,12 @@
     },
     "sqlState" : "0A000"
   },
+  "UNSUPPORTED_PIPELINE_SPARK_SQL_COMMAND": {
+    "message" : [
+      "'<command>' is not supported in spark.sql(\"...\") API in Spark Declarative Pipeline."
+    ],
+    "sqlState" : "0A000"
+  },
   "UNSUPPORTED_CHAR_OR_VARCHAR_AS_STRING" : {
     "message" : [
       "The char/varchar type can't be used in the table schema.",

python/pyspark/pipelines/add_pipeline_analysis_context.py

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from contextlib import contextmanager
+from typing import Generator, Optional
+from pyspark.sql import SparkSession
+
+from typing import Any, cast
+
+
+@contextmanager
+def add_pipeline_analysis_context(
+    spark: SparkSession, dataflow_graph_id: str, flow_name_opt: Optional[str]
+) -> Generator[None, None, None]:
+    """
+    Context manager that adds a PipelineAnalysisContext extension to the user context,
+    used for pipeline-specific analysis.
+    """
+    _extension_id = None
+    _client = cast(Any, spark).client
+    try:
+        import pyspark.sql.connect.proto as pb2
+        from google.protobuf import any_pb2
+
+        _analysis_context = pb2.PipelineAnalysisContext(dataflow_graph_id=dataflow_graph_id)
+        if flow_name_opt is not None:
+            _analysis_context.flow_name = flow_name_opt
+
+        _extension = any_pb2.Any()
+        _extension.Pack(_analysis_context)
+
+        _extension_id = _client.add_threadlocal_user_context_extension(_extension)
+        yield
+    finally:
+        _client.remove_user_context_extension(_extension_id)
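
Elsewhere in this commit (see cli.py and spark_connect_graph_element_registry.py below), this context manager is wrapped around code that issues Spark Connect requests. A minimal usage sketch, assuming an existing Spark Connect session `spark` and a placeholder graph id:

```python
# Sketch: every request issued inside the block carries a PipelineAnalysisContext
# extension in its user context; the extension is removed when the block exits.
from pyspark.pipelines.add_pipeline_analysis_context import add_pipeline_analysis_context


def analyze_for_pipeline(spark, graph_id, flow_name=None):
    with add_pipeline_analysis_context(
        spark=spark, dataflow_graph_id=graph_id, flow_name_opt=flow_name
    ):
        # Requests sent here are attributable to the given graph/flow on the server.
        return spark.sql("SELECT 1 AS id").schema
```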

python/pyspark/pipelines/block_connect_access.py

Lines changed: 27 additions & 1 deletion
@@ -24,6 +24,22 @@
 BLOCKED_RPC_NAMES = ["AnalyzePlan", "ExecutePlan"]
 
 
+def _is_sql_command_request(request: object) -> bool:
+    """Check if the request is a spark.sql() command (an ExecutePlanRequest with a sql_command)."""
+    try:
+        if not hasattr(request, "plan"):
+            return False
+
+        plan = request.plan
+
+        if not plan.HasField("command"):
+            return False
+
+        return plan.command.HasField("sql_command")
+    except Exception:
+        return False
+
+
 @contextmanager
 def block_spark_connect_execution_and_analysis() -> Generator[None, None, None]:
     """
@@ -41,7 +57,17 @@ def blocked_getattr(self: SparkConnectServiceStub, name: str) -> Callable:
         if name not in BLOCKED_RPC_NAMES:
             return original_getattr(self, name)
 
-        def blocked_method(*args: object, **kwargs: object) -> NoReturn:
+        # Get the original method first
+        original_method = original_getattr(self, name)
+
+        def blocked_method(*args: object, **kwargs: object):
+            # allowlist spark.sql() command (ExecutePlan with sql_command)
+            if name == "ExecutePlan" and len(args) > 0:
+                request = args[0]
+                if _is_sql_command_request(request):
+                    return original_method(*args, **kwargs)
+
+            # Block all other ExecutePlan and AnalyzePlan calls
             raise PySparkException(
                 errorClass="ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION",
                 messageParameters={},
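
The net effect of the allowlist, as a rough sketch: inside `block_spark_connect_execution_and_analysis()`, an `ExecutePlan` request whose plan is a `sql_command` (what `spark.sql(...)` sends) is forwarded to the real stub method, while other `ExecutePlan` and all `AnalyzePlan` calls still raise `ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION`. A hypothetical illustration, assuming a Spark Connect session `spark`:

```python
# Hypothetical illustration of the behaviour after this change.
from pyspark.errors import PySparkException
from pyspark.pipelines.block_connect_access import block_spark_connect_execution_and_analysis


def demo(spark):
    with block_spark_connect_execution_and_analysis():
        df = spark.sql("SELECT 1 AS id")  # ExecutePlan carrying a sql_command: allowed
        try:
            df.collect()  # plain ExecutePlan of a relation: still blocked
        except PySparkException as e:
            print(e)  # ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION
```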

python/pyspark/pipelines/cli.py

Lines changed: 13 additions & 4 deletions
@@ -49,6 +49,8 @@
     handle_pipeline_events,
 )
 
+from pyspark.pipelines.add_pipeline_analysis_context import add_pipeline_analysis_context
+
 PIPELINE_SPEC_FILE_NAMES = ["pipeline.yaml", "pipeline.yml"]
 
 
@@ -216,7 +218,11 @@ def validate_str_dict(d: Mapping[str, str], field_name: str) -> Mapping[str, str
 
 
 def register_definitions(
-    spec_path: Path, registry: GraphElementRegistry, spec: PipelineSpec
+    spec_path: Path,
+    registry: GraphElementRegistry,
+    spec: PipelineSpec,
+    spark: SparkSession,
+    dataflow_graph_id: str,
 ) -> None:
     """Register the graph element definitions in the pipeline spec with the given registry.
     - Looks for Python files matching the glob patterns in the spec and imports them.
@@ -245,8 +251,11 @@ def register_definitions(
                 assert (
                     module_spec.loader is not None
                 ), f"Module spec has no loader for {file}"
-                with block_session_mutations():
-                    module_spec.loader.exec_module(module)
+                with add_pipeline_analysis_context(
+                    spark=spark, dataflow_graph_id=dataflow_graph_id, flow_name_opt=None
+                ):
+                    with block_session_mutations():
+                        module_spec.loader.exec_module(module)
             elif file.suffix == ".sql":
                 log_with_curr_timestamp(f"Registering SQL file {file}...")
                 with file.open("r") as f:
@@ -324,7 +333,7 @@ def run(
 
     log_with_curr_timestamp("Registering graph elements...")
     registry = SparkConnectGraphElementRegistry(spark, dataflow_graph_id)
-    register_definitions(spec_path, registry, spec)
+    register_definitions(spec_path, registry, spec, spark, dataflow_graph_id)
 
     log_with_curr_timestamp("Starting run...")
     result_iter = start_run(
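
For reference, a sketch of invoking `register_definitions` with its widened signature outside of `run()` (for example from a test harness); `spec_path`, `spec`, `spark`, and `dataflow_graph_id` are assumed to have been built the same way `run()` builds them:

```python
# Sketch: register_definitions now needs the session and graph id so it can wrap
# Python definition files in add_pipeline_analysis_context during import.
registry = SparkConnectGraphElementRegistry(spark, dataflow_graph_id)
register_definitions(
    spec_path=spec_path,
    registry=registry,
    spec=spec,
    spark=spark,
    dataflow_graph_id=dataflow_graph_id,
)
```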

python/pyspark/pipelines/spark_connect_graph_element_registry.py

Lines changed: 7 additions & 2 deletions
@@ -35,6 +35,7 @@
 from pyspark.sql.types import StructType
 from typing import Any, cast
 import pyspark.sql.connect.proto as pb2
+from pyspark.pipelines.add_pipeline_analysis_context import add_pipeline_analysis_context
 
 
 class SparkConnectGraphElementRegistry(GraphElementRegistry):
@@ -43,6 +44,7 @@ class SparkConnectGraphElementRegistry(GraphElementRegistry):
     def __init__(self, spark: SparkSession, dataflow_graph_id: str) -> None:
         # Cast because mypy seems to think `spark` is a function, not an object. Likely related to
         # SPARK-47544.
+        self._spark = spark
         self._client = cast(Any, spark).client
         self._dataflow_graph_id = dataflow_graph_id
 
@@ -110,8 +112,11 @@ def register_output(self, output: Output) -> None:
         self._client.execute_command(command)
 
     def register_flow(self, flow: Flow) -> None:
-        with block_spark_connect_execution_and_analysis():
-            df = flow.func()
+        with add_pipeline_analysis_context(
+            spark=self._spark, dataflow_graph_id=self._dataflow_graph_id, flow_name_opt=flow.name
+        ):
+            with block_spark_connect_execution_and_analysis():
+                df = flow.func()
         relation = cast(ConnectDataFrame, df)._plan.plan(self._client)
 
         relation_flow_details = pb2.PipelineCommand.DefineFlow.WriteRelationFlowDetails(
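
In practice this means a flow function can call `spark.sql(...)` while it is being registered: the SQL command is the one RPC allowed through `block_spark_connect_execution_and_analysis()`, and it now carries the dataflow graph id and flow name via the analysis context. A hedged sketch of such a pipeline definition; the `@dp.materialized_view` decorator and the `spark` session available to definitions files are assumptions about the surrounding pipelines API, not part of this diff:

```python
# Sketch of a pipeline definitions file; decorator and session availability are
# assumed, not shown in this commit.
from pyspark import pipelines as dp


@dp.materialized_view
def my_view():
    # Sent as an ExecutePlan sql_command tagged with the graph id and flow name.
    return spark.sql("SELECT 1 AS id")
```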

python/pyspark/sql/connect/client/core.py

Lines changed: 58 additions & 1 deletion
@@ -727,6 +727,9 @@ def __init__(
         # cleanup ml cache if possible
         atexit.register(self._cleanup_ml_cache)
 
+        self.global_user_context_extensions = []
+        self.global_user_context_extensions_lock = threading.Lock()
+
     @property
     def _stub(self) -> grpc_lib.SparkConnectServiceStub:
         if self.is_closed:
@@ -1277,6 +1280,24 @@ def token(self) -> Optional[str]:
         """
         return self._builder.token
 
+    def _update_request_with_user_context_extensions(
+        self,
+        req: Union[
+            pb2.AnalyzePlanRequest,
+            pb2.ConfigRequest,
+            pb2.ExecutePlanRequest,
+            pb2.FetchErrorDetailsRequest,
+            pb2.InterruptRequest,
+        ],
+    ) -> None:
+        with self.global_user_context_extensions_lock:
+            for _, extension in self.global_user_context_extensions:
+                req.user_context.extensions.append(extension)
+        if not hasattr(self.thread_local, "user_context_extensions"):
+            return
+        for _, extension in self.thread_local.user_context_extensions:
+            req.user_context.extensions.append(extension)
+
     def _execute_plan_request_with_metadata(
         self, operation_id: Optional[str] = None
     ) -> pb2.ExecutePlanRequest:
@@ -1307,6 +1328,7 @@ def _execute_plan_request_with_metadata(
                     messageParameters={"arg_name": "operation_id", "origin": str(ve)},
                 )
             req.operation_id = operation_id
+        self._update_request_with_user_context_extensions(req)
         return req
 
     def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest:
@@ -1317,6 +1339,7 @@ def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest:
         req.client_type = self._builder.userAgent
         if self._user_id:
             req.user_context.user_id = self._user_id
+        self._update_request_with_user_context_extensions(req)
         return req
 
     def _analyze(self, method: str, **kwargs: Any) -> AnalyzeResult:
@@ -1731,6 +1754,7 @@ def _config_request_with_metadata(self) -> pb2.ConfigRequest:
         req.client_type = self._builder.userAgent
         if self._user_id:
             req.user_context.user_id = self._user_id
+        self._update_request_with_user_context_extensions(req)
         return req
 
     def get_configs(self, *keys: str) -> Tuple[Optional[str], ...]:
@@ -1807,6 +1831,7 @@ def _interrupt_request(
             )
         if self._user_id:
             req.user_context.user_id = self._user_id
+        self._update_request_with_user_context_extensions(req)
         return req
 
     def interrupt_all(self) -> Optional[List[str]]:
@@ -1905,6 +1930,38 @@ def _throw_if_invalid_tag(self, tag: str) -> None:
                 messageParameters={"arg_name": "Spark Connect tag", "arg_value": tag},
             )
 
+    def add_threadlocal_user_context_extension(self, extension: any_pb2.Any) -> str:
+        if not hasattr(self.thread_local, "user_context_extensions"):
+            self.thread_local.user_context_extensions = list()
+        extension_id = "threadlocal_" + str(uuid.uuid4())
+        self.thread_local.user_context_extensions.append((extension_id, extension))
+        return extension_id
+
+    def add_global_user_context_extension(self, extension: any_pb2.Any) -> str:
+        extension_id = "global_" + str(uuid.uuid4())
+        with self.global_user_context_extensions_lock:
+            self.global_user_context_extensions.append((extension_id, extension))
+        return extension_id
+
+    def remove_user_context_extension(self, extension_id: str) -> None:
+        if extension_id.find("threadlocal_") == 0:
+            if not hasattr(self.thread_local, "user_context_extensions"):
+                return
+            self.thread_local.user_context_extensions = list(
+                filter(lambda ex: ex[0] != extension_id, self.thread_local.user_context_extensions)
+            )
+        elif extension_id.find("global_") == 0:
+            with self.global_user_context_extensions_lock:
+                self.global_user_context_extensions = list(
+                    filter(lambda ex: ex[0] != extension_id, self.global_user_context_extensions)
+                )
+
+    def clear_user_context_extensions(self) -> None:
+        if hasattr(self.thread_local, "user_context_extensions"):
+            self.thread_local.user_context_extensions = list()
+        with self.global_user_context_extensions_lock:
+            self.global_user_context_extensions = list()
+
     def _handle_error(self, error: Exception) -> NoReturn:
         """
         Handle errors that occur during RPC calls.
@@ -1945,7 +2002,7 @@ def _fetch_enriched_error(self, info: "ErrorInfo") -> Optional[pb2.FetchErrorDet
         req.client_observed_server_side_session_id = self._server_session_id
         if self._user_id:
             req.user_context.user_id = self._user_id
-
+        self._update_request_with_user_context_extensions(req)
         try:
             return self._stub.FetchErrorDetails(req, metadata=self._builder.metadata())
         except grpc.RpcError:
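
Taken together, these client additions let a caller attach arbitrary protobuf `Any` payloads to the user context of outgoing requests, either for the current thread or globally, and detach them again by id. A minimal usage sketch; the connect endpoint and graph id string below are placeholders:

```python
# Sketch of the new user-context-extension API on the Spark Connect client.
from typing import Any, cast

import pyspark.sql.connect.proto as pb2
from google.protobuf import any_pb2
from pyspark.sql import SparkSession

spark = SparkSession.builder.remote("sc://localhost").getOrCreate()  # placeholder endpoint
client = cast(Any, spark).client

payload = any_pb2.Any()
payload.Pack(pb2.PipelineAnalysisContext(dataflow_graph_id="graph-123"))  # placeholder id

# Attach to every request issued from this thread until removed.
ext_id = client.add_threadlocal_user_context_extension(payload)
spark.sql("SELECT 1").collect()  # this request's user_context carries the extension
client.remove_user_context_extension(ext_id)
```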
