6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -6961,6 +6961,12 @@
],
"sqlState" : "0A000"
},
"UNSUPPORTED_PIPELINE_SPARK_SQL_COMMAND" : {
"message" : [
"'<command>' is not supported in spark.sql(\"...\") API in Spark Declarative Pipeline."
],
"sqlState" : "0A000"
},
"UNSUPPORTED_SAVE_MODE" : {
"message" : [
"The save mode <saveMode> is not supported for:"
48 changes: 48 additions & 0 deletions python/pyspark/pipelines/add_pipeline_analysis_context.py
@@ -0,0 +1,48 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from contextlib import contextmanager
from typing import Any, Generator, Optional, cast

from pyspark.sql import SparkSession


@contextmanager
def add_pipeline_analysis_context(
spark: SparkSession, dataflow_graph_id: str, flow_name: Optional[str]
) -> Generator[None, None, None]:
"""
Context manager that adds a PipelineAnalysisContext extension to the user context,
used for pipeline-specific analysis.
"""
extension_id = None
# Cast because mypy seems to think `spark` is a function, not an object.
# Likely related to SPARK-47544.
client = cast(Any, spark).client
try:
import pyspark.sql.connect.proto as pb2
from google.protobuf import any_pb2

analysis_context = pb2.PipelineAnalysisContext(
dataflow_graph_id=dataflow_graph_id, flow_name=flow_name
)
extension = any_pb2.Any()
extension.Pack(analysis_context)
extension_id = client.add_threadlocal_user_context_extension(extension)
yield
finally:
client.remove_user_context_extension(extension_id)
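
For orientation, a minimal usage sketch of this context manager; the graph id and flow name below are placeholders, and `spark` is assumed to be a Spark Connect SparkSession:

```python
from pyspark.pipelines.add_pipeline_analysis_context import add_pipeline_analysis_context

# While the context is active, requests sent from this thread carry a
# PipelineAnalysisContext extension in their user context, so the server can
# associate the analysis with the dataflow graph (and, optionally, a flow).
with add_pipeline_analysis_context(spark, "example-graph-id", flow_name="example_flow"):
    spark.sql("SELECT 1")  # request includes the extension
# On exit, the extension is removed from the thread-local user context.
```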
46 changes: 37 additions & 9 deletions python/pyspark/pipelines/block_connect_access.py
@@ -15,7 +15,7 @@
# limitations under the License.
#
from contextlib import contextmanager
from typing import Callable, Generator, NoReturn
from typing import Any, Callable, Generator

from pyspark.errors import PySparkException
from pyspark.sql.connect.proto.base_pb2_grpc import SparkConnectServiceStub
@@ -24,6 +24,27 @@
BLOCKED_RPC_NAMES = ["AnalyzePlan", "ExecutePlan"]


def _is_sql_command_request(rpc_name: str, args: tuple) -> bool:
"""
Check if the RPC call is a spark.sql() command (ExecutePlan with sql_command).

:param rpc_name: Name of the RPC being called
:param args: Arguments passed to the RPC
:return: True if this is an ExecutePlan request with a sql_command
"""
if rpc_name != "ExecutePlan" or len(args) == 0:
return False

request = args[0]
if not hasattr(request, "plan"):
return False
plan = request.plan
if not plan.HasField("command"):
return False
command = plan.command
return command.HasField("sql_command")


@contextmanager
def block_spark_connect_execution_and_analysis() -> Generator[None, None, None]:
[Review comment, Contributor] Now that we have the context on the server side, it might make more sense to block these operations there – then we don't need to replicate this weird monkeypatching logic across all the clients when we add support for other languages. Doesn't need to be part of this PR though.

"""
@@ -38,16 +59,23 @@ def block_spark_connect_execution_and_analysis() -> Generator[None, None, None]:

# Define a new __getattribute__ method that blocks RPC calls
def blocked_getattr(self: SparkConnectServiceStub, name: str) -> Callable:
if name not in BLOCKED_RPC_NAMES:
return original_getattr(self, name)
original_method = original_getattr(self, name)

def blocked_method(*args: object, **kwargs: object) -> NoReturn:
raise PySparkException(
errorClass="ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION",
messageParameters={},
)
def intercepted_method(*args: object, **kwargs: object) -> Any:
# Allow all RPCs that are not AnalyzePlan or ExecutePlan
if name not in BLOCKED_RPC_NAMES:
return original_method(*args, **kwargs)
# Allow spark.sql() commands (ExecutePlan with sql_command)
elif _is_sql_command_request(name, args):
return original_method(*args, **kwargs)
# Block all other AnalyzePlan and ExecutePlan calls
else:
raise PySparkException(
errorClass="ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION",
messageParameters={},
)

return blocked_method
return intercepted_method

try:
# Apply our custom __getattribute__ method
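The net effect inside the guarded region: spark.sql() commands are now allowed through, while other plan analysis and execution RPCs are still rejected. A behavioral sketch, assuming `spark` is a Spark Connect SparkSession and the patch above is active:

```python
from pyspark.errors import PySparkException
from pyspark.pipelines.block_connect_access import block_spark_connect_execution_and_analysis

with block_spark_connect_execution_and_analysis():
    # ExecutePlan carrying a sql_command: allowed by _is_sql_command_request.
    spark.sql("SELECT 1")

    # ExecutePlan without a sql_command (a plain DataFrame action): blocked.
    try:
        spark.range(10).collect()
    except PySparkException as e:
        assert e.getCondition() == "ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION"
```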
17 changes: 13 additions & 4 deletions python/pyspark/pipelines/cli.py
@@ -49,6 +49,8 @@
handle_pipeline_events,
)

from pyspark.pipelines.add_pipeline_analysis_context import add_pipeline_analysis_context

PIPELINE_SPEC_FILE_NAMES = ["pipeline.yaml", "pipeline.yml"]


@@ -216,7 +218,11 @@ def validate_str_dict(d: Mapping[str, str], field_name: str) -> Mapping[str, str


def register_definitions(
spec_path: Path, registry: GraphElementRegistry, spec: PipelineSpec
spec_path: Path,
registry: GraphElementRegistry,
spec: PipelineSpec,
spark: SparkSession,
dataflow_graph_id: str,
) -> None:
"""Register the graph element definitions in the pipeline spec with the given registry.
- Looks for Python files matching the glob patterns in the spec and imports them.
@@ -245,8 +251,11 @@ def register_definitions(
assert (
module_spec.loader is not None
), f"Module spec has no loader for {file}"
with block_session_mutations():
module_spec.loader.exec_module(module)
with add_pipeline_analysis_context(
spark=spark, dataflow_graph_id=dataflow_graph_id, flow_name=None
):
with block_session_mutations():
module_spec.loader.exec_module(module)
elif file.suffix == ".sql":
log_with_curr_timestamp(f"Registering SQL file {file}...")
with file.open("r") as f:
@@ -324,7 +333,7 @@ def run(

log_with_curr_timestamp("Registering graph elements...")
registry = SparkConnectGraphElementRegistry(spark, dataflow_graph_id)
register_definitions(spec_path, registry, spec)
register_definitions(spec_path, registry, spec, spark, dataflow_graph_id)

log_with_curr_timestamp("Starting run...")
result_iter = start_run(
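For reference, a sketch of the updated wiring in the CLI; the ids are placeholders, and in a real run they come from the pipeline spec and the graph creation step:

```python
registry = SparkConnectGraphElementRegistry(spark, dataflow_graph_id)
# register_definitions now also takes the session and graph id so it can wrap
# module execution in add_pipeline_analysis_context(..., flow_name=None).
register_definitions(spec_path, registry, spec, spark, dataflow_graph_id)
```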
python/pyspark/pipelines/spark_connect_graph_element_registry.py
@@ -35,6 +35,7 @@
from pyspark.sql.types import StructType
from typing import Any, cast
import pyspark.sql.connect.proto as pb2
from pyspark.pipelines.add_pipeline_analysis_context import add_pipeline_analysis_context


class SparkConnectGraphElementRegistry(GraphElementRegistry):
@@ -43,6 +44,7 @@ class SparkConnectGraphElementRegistry(GraphElementRegistry):
def __init__(self, spark: SparkSession, dataflow_graph_id: str) -> None:
# Cast because mypy seems to think `spark` is a function, not an object. Likely related to
# SPARK-47544.
self._spark = spark
self._client = cast(Any, spark).client
self._dataflow_graph_id = dataflow_graph_id

@@ -110,8 +112,11 @@ def register_output(self, output: Output) -> None:
self._client.execute_command(command)

def register_flow(self, flow: Flow) -> None:
with block_spark_connect_execution_and_analysis():
df = flow.func()
with add_pipeline_analysis_context(
[Review comment, Contributor] If I understand correctly, this will result in two PipelineAnalysisContexts added to the same request. And the server code knows to expect that.

[Reply, Contributor Author] Yes, exactly: for spark.sql() outside a query function there would only be one extension associated with it, but two for spark.sql() inside a query function.

spark=self._spark, dataflow_graph_id=self._dataflow_graph_id, flow_name=flow.name
):
with block_spark_connect_execution_and_analysis():
df = flow.func()
relation = cast(ConnectDataFrame, df)._plan.plan(self._client)

relation_flow_details = pb2.PipelineCommand.DefineFlow.WriteRelationFlowDetails(
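To make the discussion above concrete, a rough sketch of the nesting; the graph id and flow name are placeholders. register_definitions installs the outer, graph-scoped context, and register_flow installs a second, flow-scoped one before invoking the query function:

```python
with add_pipeline_analysis_context(spark, "graph-id", flow_name=None):
    # A spark.sql() issued here (module top level) carries one extension.
    with add_pipeline_analysis_context(spark, "graph-id", flow_name="my_flow"):
        # A spark.sql() issued here (inside the flow's query function) carries
        # two extensions: one graph-scoped and one flow-scoped.
        spark.sql("SELECT 1")
```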
100 changes: 100 additions & 0 deletions python/pyspark/pipelines/tests/test_add_pipeline_analysis_context.py
@@ -0,0 +1,100 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest

from pyspark.testing.connectutils import (
ReusedConnectTestCase,
should_test_connect,
connect_requirement_message,
)

if should_test_connect:
from pyspark.pipelines.add_pipeline_analysis_context import add_pipeline_analysis_context


@unittest.skipIf(not should_test_connect, connect_requirement_message)
class AddPipelineAnalysisContextTests(ReusedConnectTestCase):
def test_add_pipeline_analysis_context_with_flow_name(self):
with add_pipeline_analysis_context(self.spark, "test_dataflow_graph_id", "test_flow_name"):
import pyspark.sql.connect.proto as pb2

thread_local_extensions = self.spark.client.thread_local.user_context_extensions
self.assertEqual(len(thread_local_extensions), 1)
# Extension is stored as (id, extension), unpack the extension
_extension_id, extension = thread_local_extensions[0]
context = pb2.PipelineAnalysisContext()
extension.Unpack(context)
self.assertEqual(context.dataflow_graph_id, "test_dataflow_graph_id")
self.assertEqual(context.flow_name, "test_flow_name")
thread_local_extensions_after = self.spark.client.thread_local.user_context_extensions
self.assertEqual(len(thread_local_extensions_after), 0)

def test_add_pipeline_analysis_context_without_flow_name(self):
with add_pipeline_analysis_context(self.spark, "test_dataflow_graph_id", None):
import pyspark.sql.connect.proto as pb2

thread_local_extensions = self.spark.client.thread_local.user_context_extensions
self.assertEqual(len(thread_local_extensions), 1)
# Extension is stored as (id, extension), unpack the extension
_extension_id, extension = thread_local_extensions[0]
context = pb2.PipelineAnalysisContext()
extension.Unpack(context)
self.assertEqual(context.dataflow_graph_id, "test_dataflow_graph_id")
# Empty string means no flow name
self.assertEqual(context.flow_name, "")
thread_local_extensions_after = self.spark.client.thread_local.user_context_extensions
self.assertEqual(len(thread_local_extensions_after), 0)

def test_nested_add_pipeline_analysis_context(self):
import pyspark.sql.connect.proto as pb2

with add_pipeline_analysis_context(self.spark, "test_dataflow_graph_id_1", flow_name=None):
with add_pipeline_analysis_context(
self.spark, "test_dataflow_graph_id_2", flow_name="test_flow_name"
):
thread_local_extensions = self.spark.client.thread_local.user_context_extensions
self.assertEqual(len(thread_local_extensions), 2)
# Extension is stored as (id, extension), unpack the extensions
_, extension_1 = thread_local_extensions[0]
context_1 = pb2.PipelineAnalysisContext()
extension_1.Unpack(context_1)
self.assertEqual(context_1.dataflow_graph_id, "test_dataflow_graph_id_1")
self.assertEqual(context_1.flow_name, "")
_, extension_2 = thread_local_extensions[1]
context_2 = pb2.PipelineAnalysisContext()
extension_2.Unpack(context_2)
self.assertEqual(context_2.dataflow_graph_id, "test_dataflow_graph_id_2")
self.assertEqual(context_2.flow_name, "test_flow_name")
thread_local_extensions_after_1 = self.spark.client.thread_local.user_context_extensions
self.assertEqual(len(thread_local_extensions_after_1), 1)
_, extension_3 = thread_local_extensions_after_1[0]
context_3 = pb2.PipelineAnalysisContext()
extension_3.Unpack(context_3)
self.assertEqual(context_3.dataflow_graph_id, "test_dataflow_graph_id_1")
self.assertEqual(context_3.flow_name, "")
thread_local_extensions_after_2 = self.spark.client.thread_local.user_context_extensions
self.assertEqual(len(thread_local_extensions_after_2), 0)


if __name__ == "__main__":
try:
import xmlrunner # type: ignore

testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)
15 changes: 11 additions & 4 deletions python/pyspark/pipelines/tests/test_cli.py
@@ -22,6 +22,7 @@

from pyspark.errors import PySparkException
from pyspark.testing.connectutils import (
ReusedConnectTestCase,
should_test_connect,
connect_requirement_message,
)
@@ -45,7 +46,7 @@
not should_test_connect or not have_yaml,
connect_requirement_message or yaml_requirement_message,
)
class CLIUtilityTests(unittest.TestCase):
class CLIUtilityTests(ReusedConnectTestCase):
def test_load_pipeline_spec(self):
with tempfile.NamedTemporaryFile(mode="w") as tmpfile:
tmpfile.write(
@@ -294,7 +295,9 @@ def mv2():
)

registry = LocalGraphElementRegistry()
register_definitions(outer_dir / "pipeline.yaml", registry, spec)
register_definitions(
outer_dir / "pipeline.yaml", registry, spec, self.spark, "test_graph_id"
)
self.assertEqual(len(registry.outputs), 1)
self.assertEqual(registry.outputs[0].name, "mv1")

@@ -315,7 +318,9 @@ def test_register_definitions_file_raises_error(self):

registry = LocalGraphElementRegistry()
with self.assertRaises(RuntimeError) as context:
register_definitions(outer_dir / "pipeline.yml", registry, spec)
register_definitions(
outer_dir / "pipeline.yml", registry, spec, self.spark, "test_graph_id"
)
self.assertIn("This is a test exception", str(context.exception))

def test_register_definitions_unsupported_file_extension_matches_glob(self):
@@ -334,7 +339,7 @@ def test_register_definitions_unsupported_file_extension_matches_glob(self):

registry = LocalGraphElementRegistry()
with self.assertRaises(PySparkException) as context:
register_definitions(outer_dir, registry, spec)
register_definitions(outer_dir, registry, spec, self.spark, "test_graph_id")
self.assertEqual(
context.exception.getCondition(), "PIPELINE_UNSUPPORTED_DEFINITIONS_FILE_EXTENSION"
)
@@ -382,6 +387,8 @@ def test_python_import_current_directory(self):
configuration={},
libraries=[LibrariesGlob(include="defs.py")],
),
self.spark,
"test_graph_id",
)

def test_full_refresh_all_conflicts_with_full_refresh(self):
2 changes: 1 addition & 1 deletion python/pyspark/pipelines/tests/test_init_cli.py
@@ -60,7 +60,7 @@ def test_init(self):
self.assertTrue((Path.cwd() / "pipeline-storage").exists())

registry = LocalGraphElementRegistry()
register_definitions(spec_path, registry, spec)
register_definitions(spec_path, registry, spec, self.spark, "test_graph_id")
self.assertEqual(len(registry.outputs), 1)
self.assertEqual(registry.outputs[0].name, "example_python_materialized_view")
self.assertEqual(len(registry.flows), 1)