diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 63148f581f04..326c0bf843e8 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -6961,6 +6961,12 @@
     ],
     "sqlState" : "0A000"
   },
+  "UNSUPPORTED_PIPELINE_SPARK_SQL_COMMAND" : {
+    "message" : [
+      "'<command>' is not supported in spark.sql(\"...\") API in Spark Declarative Pipeline."
+    ],
+    "sqlState" : "0A000"
+  },
   "UNSUPPORTED_SAVE_MODE" : {
     "message" : [
       "The save mode <saveMode> is not supported for:"
diff --git a/python/pyspark/pipelines/add_pipeline_analysis_context.py b/python/pyspark/pipelines/add_pipeline_analysis_context.py
new file mode 100644
index 000000000000..6ecabdf43b07
--- /dev/null
+++ b/python/pyspark/pipelines/add_pipeline_analysis_context.py
@@ -0,0 +1,48 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from contextlib import contextmanager
+from typing import Generator, Optional
+from pyspark.sql import SparkSession
+
+from typing import Any, cast
+
+
+@contextmanager
+def add_pipeline_analysis_context(
+    spark: SparkSession, dataflow_graph_id: str, flow_name: Optional[str]
+) -> Generator[None, None, None]:
+    """
+    Context manager that adds a PipelineAnalysisContext extension to the user context,
+    used for pipeline-specific analysis.
+    """
+    extension_id = None
+    # Cast because mypy seems to think `spark` is a function, not an object.
+    # Likely related to SPARK-47544.
+    client = cast(Any, spark).client
+    try:
+        import pyspark.sql.connect.proto as pb2
+        from google.protobuf import any_pb2
+
+        analysis_context = pb2.PipelineAnalysisContext(
+            dataflow_graph_id=dataflow_graph_id, flow_name=flow_name
+        )
+        extension = any_pb2.Any()
+        extension.Pack(analysis_context)
+        extension_id = client.add_threadlocal_user_context_extension(extension)
+        yield
+    finally:
+        client.remove_user_context_extension(extension_id)
diff --git a/python/pyspark/pipelines/block_connect_access.py b/python/pyspark/pipelines/block_connect_access.py
index c5dacbbc2c5c..696d0e39b005 100644
--- a/python/pyspark/pipelines/block_connect_access.py
+++ b/python/pyspark/pipelines/block_connect_access.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 #
 from contextlib import contextmanager
-from typing import Callable, Generator, NoReturn
+from typing import Any, Callable, Generator
 from pyspark.errors import PySparkException
 from pyspark.sql.connect.proto.base_pb2_grpc import SparkConnectServiceStub
 
@@ -24,6 +24,27 @@
 BLOCKED_RPC_NAMES = ["AnalyzePlan", "ExecutePlan"]
 
 
+def _is_sql_command_request(rpc_name: str, args: tuple) -> bool:
+    """
+    Check if the RPC call is a spark.sql() command (ExecutePlan with sql_command).
+ + :param rpc_name: Name of the RPC being called + :param args: Arguments passed to the RPC + :return: True if this is an ExecutePlan request with a sql_command + """ + if rpc_name != "ExecutePlan" or len(args) == 0: + return False + + request = args[0] + if not hasattr(request, "plan"): + return False + plan = request.plan + if not plan.HasField("command"): + return False + command = plan.command + return command.HasField("sql_command") + + @contextmanager def block_spark_connect_execution_and_analysis() -> Generator[None, None, None]: """ @@ -38,16 +59,23 @@ def block_spark_connect_execution_and_analysis() -> Generator[None, None, None]: # Define a new __getattribute__ method that blocks RPC calls def blocked_getattr(self: SparkConnectServiceStub, name: str) -> Callable: - if name not in BLOCKED_RPC_NAMES: - return original_getattr(self, name) + original_method = original_getattr(self, name) - def blocked_method(*args: object, **kwargs: object) -> NoReturn: - raise PySparkException( - errorClass="ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION", - messageParameters={}, - ) + def intercepted_method(*args: object, **kwargs: object) -> Any: + # Allow all RPCs that are not AnalyzePlan or ExecutePlan + if name not in BLOCKED_RPC_NAMES: + return original_method(*args, **kwargs) + # Allow spark.sql() commands (ExecutePlan with sql_command) + elif _is_sql_command_request(name, args): + return original_method(*args, **kwargs) + # Block all other AnalyzePlan and ExecutePlan calls + else: + raise PySparkException( + errorClass="ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION", + messageParameters={}, + ) - return blocked_method + return intercepted_method try: # Apply our custom __getattribute__ method diff --git a/python/pyspark/pipelines/cli.py b/python/pyspark/pipelines/cli.py index ca198f1c3aff..3ba0bb58fe94 100644 --- a/python/pyspark/pipelines/cli.py +++ b/python/pyspark/pipelines/cli.py @@ -49,6 +49,8 @@ handle_pipeline_events, ) +from pyspark.pipelines.add_pipeline_analysis_context import add_pipeline_analysis_context + PIPELINE_SPEC_FILE_NAMES = ["pipeline.yaml", "pipeline.yml"] @@ -216,7 +218,11 @@ def validate_str_dict(d: Mapping[str, str], field_name: str) -> Mapping[str, str def register_definitions( - spec_path: Path, registry: GraphElementRegistry, spec: PipelineSpec + spec_path: Path, + registry: GraphElementRegistry, + spec: PipelineSpec, + spark: SparkSession, + dataflow_graph_id: str, ) -> None: """Register the graph element definitions in the pipeline spec with the given registry. - Looks for Python files matching the glob patterns in the spec and imports them. 
@@ -245,8 +251,11 @@ def register_definitions( assert ( module_spec.loader is not None ), f"Module spec has no loader for {file}" - with block_session_mutations(): - module_spec.loader.exec_module(module) + with add_pipeline_analysis_context( + spark=spark, dataflow_graph_id=dataflow_graph_id, flow_name=None + ): + with block_session_mutations(): + module_spec.loader.exec_module(module) elif file.suffix == ".sql": log_with_curr_timestamp(f"Registering SQL file {file}...") with file.open("r") as f: @@ -324,7 +333,7 @@ def run( log_with_curr_timestamp("Registering graph elements...") registry = SparkConnectGraphElementRegistry(spark, dataflow_graph_id) - register_definitions(spec_path, registry, spec) + register_definitions(spec_path, registry, spec, spark, dataflow_graph_id) log_with_curr_timestamp("Starting run...") result_iter = start_run( diff --git a/python/pyspark/pipelines/spark_connect_graph_element_registry.py b/python/pyspark/pipelines/spark_connect_graph_element_registry.py index e8a8561c3e74..b8d297fced3f 100644 --- a/python/pyspark/pipelines/spark_connect_graph_element_registry.py +++ b/python/pyspark/pipelines/spark_connect_graph_element_registry.py @@ -35,6 +35,7 @@ from pyspark.sql.types import StructType from typing import Any, cast import pyspark.sql.connect.proto as pb2 +from pyspark.pipelines.add_pipeline_analysis_context import add_pipeline_analysis_context class SparkConnectGraphElementRegistry(GraphElementRegistry): @@ -43,6 +44,7 @@ class SparkConnectGraphElementRegistry(GraphElementRegistry): def __init__(self, spark: SparkSession, dataflow_graph_id: str) -> None: # Cast because mypy seems to think `spark`` is a function, not an object. Likely related to # SPARK-47544. + self._spark = spark self._client = cast(Any, spark).client self._dataflow_graph_id = dataflow_graph_id @@ -110,8 +112,11 @@ def register_output(self, output: Output) -> None: self._client.execute_command(command) def register_flow(self, flow: Flow) -> None: - with block_spark_connect_execution_and_analysis(): - df = flow.func() + with add_pipeline_analysis_context( + spark=self._spark, dataflow_graph_id=self._dataflow_graph_id, flow_name=flow.name + ): + with block_spark_connect_execution_and_analysis(): + df = flow.func() relation = cast(ConnectDataFrame, df)._plan.plan(self._client) relation_flow_details = pb2.PipelineCommand.DefineFlow.WriteRelationFlowDetails( diff --git a/python/pyspark/pipelines/tests/test_add_pipeline_analysis_context.py b/python/pyspark/pipelines/tests/test_add_pipeline_analysis_context.py new file mode 100644 index 000000000000..57c5da22d460 --- /dev/null +++ b/python/pyspark/pipelines/tests/test_add_pipeline_analysis_context.py @@ -0,0 +1,100 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import unittest + +from pyspark.testing.connectutils import ( + ReusedConnectTestCase, + should_test_connect, + connect_requirement_message, +) + +if should_test_connect: + from pyspark.pipelines.add_pipeline_analysis_context import add_pipeline_analysis_context + + +@unittest.skipIf(not should_test_connect, connect_requirement_message) +class AddPipelineAnalysisContextTests(ReusedConnectTestCase): + def test_add_pipeline_analysis_context_with_flow_name(self): + with add_pipeline_analysis_context(self.spark, "test_dataflow_graph_id", "test_flow_name"): + import pyspark.sql.connect.proto as pb2 + + thread_local_extensions = self.spark.client.thread_local.user_context_extensions + self.assertEqual(len(thread_local_extensions), 1) + # Extension is stored as (id, extension), unpack the extension + _extension_id, extension = thread_local_extensions[0] + context = pb2.PipelineAnalysisContext() + extension.Unpack(context) + self.assertEqual(context.dataflow_graph_id, "test_dataflow_graph_id") + self.assertEqual(context.flow_name, "test_flow_name") + thread_local_extensions_after = self.spark.client.thread_local.user_context_extensions + self.assertEqual(len(thread_local_extensions_after), 0) + + def test_add_pipeline_analysis_context_without_flow_name(self): + with add_pipeline_analysis_context(self.spark, "test_dataflow_graph_id", None): + import pyspark.sql.connect.proto as pb2 + + thread_local_extensions = self.spark.client.thread_local.user_context_extensions + self.assertEqual(len(thread_local_extensions), 1) + # Extension is stored as (id, extension), unpack the extension + _extension_id, extension = thread_local_extensions[0] + context = pb2.PipelineAnalysisContext() + extension.Unpack(context) + self.assertEqual(context.dataflow_graph_id, "test_dataflow_graph_id") + # Empty string means no flow name + self.assertEqual(context.flow_name, "") + thread_local_extensions_after = self.spark.client.thread_local.user_context_extensions + self.assertEqual(len(thread_local_extensions_after), 0) + + def test_nested_add_pipeline_analysis_context(self): + import pyspark.sql.connect.proto as pb2 + + with add_pipeline_analysis_context(self.spark, "test_dataflow_graph_id_1", flow_name=None): + with add_pipeline_analysis_context( + self.spark, "test_dataflow_graph_id_2", flow_name="test_flow_name" + ): + thread_local_extensions = self.spark.client.thread_local.user_context_extensions + self.assertEqual(len(thread_local_extensions), 2) + # Extension is stored as (id, extension), unpack the extensions + _, extension_1 = thread_local_extensions[0] + context_1 = pb2.PipelineAnalysisContext() + extension_1.Unpack(context_1) + self.assertEqual(context_1.dataflow_graph_id, "test_dataflow_graph_id_1") + self.assertEqual(context_1.flow_name, "") + _, extension_2 = thread_local_extensions[1] + context_2 = pb2.PipelineAnalysisContext() + extension_2.Unpack(context_2) + self.assertEqual(context_2.dataflow_graph_id, "test_dataflow_graph_id_2") + self.assertEqual(context_2.flow_name, "test_flow_name") + thread_local_extensions_after_1 = self.spark.client.thread_local.user_context_extensions + self.assertEqual(len(thread_local_extensions_after_1), 1) + _, extension_3 = thread_local_extensions_after_1[0] + context_3 = pb2.PipelineAnalysisContext() + extension_3.Unpack(context_3) + self.assertEqual(context_3.dataflow_graph_id, "test_dataflow_graph_id_1") + self.assertEqual(context_3.flow_name, "") + thread_local_extensions_after_2 = self.spark.client.thread_local.user_context_extensions + 
self.assertEqual(len(thread_local_extensions_after_2), 0) + + +if __name__ == "__main__": + try: + import xmlrunner # type: ignore + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/pipelines/tests/test_cli.py b/python/pyspark/pipelines/tests/test_cli.py index e8445e63d439..ff3022fa2966 100644 --- a/python/pyspark/pipelines/tests/test_cli.py +++ b/python/pyspark/pipelines/tests/test_cli.py @@ -22,6 +22,7 @@ from pyspark.errors import PySparkException from pyspark.testing.connectutils import ( + ReusedConnectTestCase, should_test_connect, connect_requirement_message, ) @@ -45,7 +46,7 @@ not should_test_connect or not have_yaml, connect_requirement_message or yaml_requirement_message, ) -class CLIUtilityTests(unittest.TestCase): +class CLIUtilityTests(ReusedConnectTestCase): def test_load_pipeline_spec(self): with tempfile.NamedTemporaryFile(mode="w") as tmpfile: tmpfile.write( @@ -294,7 +295,9 @@ def mv2(): ) registry = LocalGraphElementRegistry() - register_definitions(outer_dir / "pipeline.yaml", registry, spec) + register_definitions( + outer_dir / "pipeline.yaml", registry, spec, self.spark, "test_graph_id" + ) self.assertEqual(len(registry.outputs), 1) self.assertEqual(registry.outputs[0].name, "mv1") @@ -315,7 +318,9 @@ def test_register_definitions_file_raises_error(self): registry = LocalGraphElementRegistry() with self.assertRaises(RuntimeError) as context: - register_definitions(outer_dir / "pipeline.yml", registry, spec) + register_definitions( + outer_dir / "pipeline.yml", registry, spec, self.spark, "test_graph_id" + ) self.assertIn("This is a test exception", str(context.exception)) def test_register_definitions_unsupported_file_extension_matches_glob(self): @@ -334,7 +339,7 @@ def test_register_definitions_unsupported_file_extension_matches_glob(self): registry = LocalGraphElementRegistry() with self.assertRaises(PySparkException) as context: - register_definitions(outer_dir, registry, spec) + register_definitions(outer_dir, registry, spec, self.spark, "test_graph_id") self.assertEqual( context.exception.getCondition(), "PIPELINE_UNSUPPORTED_DEFINITIONS_FILE_EXTENSION" ) @@ -382,6 +387,8 @@ def test_python_import_current_directory(self): configuration={}, libraries=[LibrariesGlob(include="defs.py")], ), + self.spark, + "test_graph_id", ) def test_full_refresh_all_conflicts_with_full_refresh(self): diff --git a/python/pyspark/pipelines/tests/test_init_cli.py b/python/pyspark/pipelines/tests/test_init_cli.py index 43c553eddc38..e51bab6a4a69 100644 --- a/python/pyspark/pipelines/tests/test_init_cli.py +++ b/python/pyspark/pipelines/tests/test_init_cli.py @@ -60,7 +60,7 @@ def test_init(self): self.assertTrue((Path.cwd() / "pipeline-storage").exists()) registry = LocalGraphElementRegistry() - register_definitions(spec_path, registry, spec) + register_definitions(spec_path, registry, spec, self.spark, "test_graph_id") self.assertEqual(len(registry.outputs), 1) self.assertEqual(registry.outputs[0].name, "example_python_materialized_view") self.assertEqual(len(registry.flows), 1) diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 2a2ac0e6b539..48e07642e157 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -727,6 +727,9 @@ def __init__( # cleanup ml cache if possible atexit.register(self._cleanup_ml_cache) + 
self.global_user_context_extensions: List[Tuple[str, any_pb2.Any]] = [] + self.global_user_context_extensions_lock = threading.Lock() + @property def _stub(self) -> grpc_lib.SparkConnectServiceStub: if self.is_closed: @@ -1277,6 +1280,24 @@ def token(self) -> Optional[str]: """ return self._builder.token + def _update_request_with_user_context_extensions( + self, + req: Union[ + pb2.AnalyzePlanRequest, + pb2.ConfigRequest, + pb2.ExecutePlanRequest, + pb2.FetchErrorDetailsRequest, + pb2.InterruptRequest, + ], + ) -> None: + with self.global_user_context_extensions_lock: + for _, extension in self.global_user_context_extensions: + req.user_context.extensions.append(extension) + if not hasattr(self.thread_local, "user_context_extensions"): + return + for _, extension in self.thread_local.user_context_extensions: + req.user_context.extensions.append(extension) + def _execute_plan_request_with_metadata( self, operation_id: Optional[str] = None ) -> pb2.ExecutePlanRequest: @@ -1307,6 +1328,7 @@ def _execute_plan_request_with_metadata( messageParameters={"arg_name": "operation_id", "origin": str(ve)}, ) req.operation_id = operation_id + self._update_request_with_user_context_extensions(req) return req def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest: @@ -1317,6 +1339,7 @@ def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest: req.client_type = self._builder.userAgent if self._user_id: req.user_context.user_id = self._user_id + self._update_request_with_user_context_extensions(req) return req def _analyze(self, method: str, **kwargs: Any) -> AnalyzeResult: @@ -1731,6 +1754,7 @@ def _config_request_with_metadata(self) -> pb2.ConfigRequest: req.client_type = self._builder.userAgent if self._user_id: req.user_context.user_id = self._user_id + self._update_request_with_user_context_extensions(req) return req def get_configs(self, *keys: str) -> Tuple[Optional[str], ...]: @@ -1807,6 +1831,7 @@ def _interrupt_request( ) if self._user_id: req.user_context.user_id = self._user_id + self._update_request_with_user_context_extensions(req) return req def interrupt_all(self) -> Optional[List[str]]: @@ -1905,6 +1930,38 @@ def _throw_if_invalid_tag(self, tag: str) -> None: messageParameters={"arg_name": "Spark Connect tag", "arg_value": tag}, ) + def add_threadlocal_user_context_extension(self, extension: any_pb2.Any) -> str: + if not hasattr(self.thread_local, "user_context_extensions"): + self.thread_local.user_context_extensions = list() + extension_id = "threadlocal_" + str(uuid.uuid4()) + self.thread_local.user_context_extensions.append((extension_id, extension)) + return extension_id + + def add_global_user_context_extension(self, extension: any_pb2.Any) -> str: + extension_id = "global_" + str(uuid.uuid4()) + with self.global_user_context_extensions_lock: + self.global_user_context_extensions.append((extension_id, extension)) + return extension_id + + def remove_user_context_extension(self, extension_id: str) -> None: + if extension_id.find("threadlocal_") == 0: + if not hasattr(self.thread_local, "user_context_extensions"): + return + self.thread_local.user_context_extensions = list( + filter(lambda ex: ex[0] != extension_id, self.thread_local.user_context_extensions) + ) + elif extension_id.find("global_") == 0: + with self.global_user_context_extensions_lock: + self.global_user_context_extensions = list( + filter(lambda ex: ex[0] != extension_id, self.global_user_context_extensions) + ) + + def clear_user_context_extensions(self) -> None: + if 
hasattr(self.thread_local, "user_context_extensions"): + self.thread_local.user_context_extensions = list() + with self.global_user_context_extensions_lock: + self.global_user_context_extensions = list() + def _handle_error(self, error: Exception) -> NoReturn: """ Handle errors that occur during RPC calls. @@ -1945,7 +2002,7 @@ def _fetch_enriched_error(self, info: "ErrorInfo") -> Optional[pb2.FetchErrorDet req.client_observed_server_side_session_id = self._server_session_id if self._user_id: req.user_context.user_id = self._user_id - + self._update_request_with_user_context_extensions(req) try: return self._stub.FetchErrorDetails(req, metadata=self._builder.metadata()) except grpc.RpcError: diff --git a/python/pyspark/sql/connect/proto/pipelines_pb2.py b/python/pyspark/sql/connect/proto/pipelines_pb2.py index 0eb77c84b5b5..7a30def861d2 100644 --- a/python/pyspark/sql/connect/proto/pipelines_pb2.py +++ b/python/pyspark/sql/connect/proto/pipelines_pb2.py @@ -42,7 +42,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1dspark/connect/pipelines.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1aspark/connect/common.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xed"\n\x0fPipelineCommand\x12h\n\x15\x63reate_dataflow_graph\x18\x01 \x01(\x0b\x32\x32.spark.connect.PipelineCommand.CreateDataflowGraphH\x00R\x13\x63reateDataflowGraph\x12R\n\rdefine_output\x18\x02 \x01(\x0b\x32+.spark.connect.PipelineCommand.DefineOutputH\x00R\x0c\x64\x65\x66ineOutput\x12L\n\x0b\x64\x65\x66ine_flow\x18\x03 \x01(\x0b\x32).spark.connect.PipelineCommand.DefineFlowH\x00R\ndefineFlow\x12\x62\n\x13\x64rop_dataflow_graph\x18\x04 \x01(\x0b\x32\x30.spark.connect.PipelineCommand.DropDataflowGraphH\x00R\x11\x64ropDataflowGraph\x12\x46\n\tstart_run\x18\x05 \x01(\x0b\x32\'.spark.connect.PipelineCommand.StartRunH\x00R\x08startRun\x12r\n\x19\x64\x65\x66ine_sql_graph_elements\x18\x06 \x01(\x0b\x32\x35.spark.connect.PipelineCommand.DefineSqlGraphElementsH\x00R\x16\x64\x65\x66ineSqlGraphElements\x12\xa1\x01\n*get_query_function_execution_signal_stream\x18\x07 \x01(\x0b\x32\x44.spark.connect.PipelineCommand.GetQueryFunctionExecutionSignalStreamH\x00R%getQueryFunctionExecutionSignalStream\x12\x88\x01\n!define_flow_query_function_result\x18\x08 \x01(\x0b\x32<.spark.connect.PipelineCommand.DefineFlowQueryFunctionResultH\x00R\x1d\x64\x65\x66ineFlowQueryFunctionResult\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x1a\xb4\x02\n\x13\x43reateDataflowGraph\x12,\n\x0f\x64\x65\x66\x61ult_catalog\x18\x01 \x01(\tH\x00R\x0e\x64\x65\x66\x61ultCatalog\x88\x01\x01\x12.\n\x10\x64\x65\x66\x61ult_database\x18\x02 \x01(\tH\x01R\x0f\x64\x65\x66\x61ultDatabase\x88\x01\x01\x12Z\n\x08sql_conf\x18\x05 \x03(\x0b\x32?.spark.connect.PipelineCommand.CreateDataflowGraph.SqlConfEntryR\x07sqlConf\x1a:\n\x0cSqlConfEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x12\n\x10_default_catalogB\x13\n\x11_default_database\x1aZ\n\x11\x44ropDataflowGraph\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_id\x1a\x92\n\n\x0c\x44\x65\x66ineOutput\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x01R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12$\n\x0boutput_name\x18\x02 \x01(\tH\x02R\noutputName\x88\x01\x01\x12?\n\x0boutput_type\x18\x03 
\x01(\x0e\x32\x19.spark.connect.OutputTypeH\x03R\noutputType\x88\x01\x01\x12\x1d\n\x07\x63omment\x18\x04 \x01(\tH\x04R\x07\x63omment\x88\x01\x01\x12X\n\x14source_code_location\x18\x05 \x01(\x0b\x32!.spark.connect.SourceCodeLocationH\x05R\x12sourceCodeLocation\x88\x01\x01\x12_\n\rtable_details\x18\x06 \x01(\x0b\x32\x38.spark.connect.PipelineCommand.DefineOutput.TableDetailsH\x00R\x0ctableDetails\x12\\\n\x0csink_details\x18\x07 \x01(\x0b\x32\x37.spark.connect.PipelineCommand.DefineOutput.SinkDetailsH\x00R\x0bsinkDetails\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x1a\xc0\x03\n\x0cTableDetails\x12x\n\x10table_properties\x18\x01 \x03(\x0b\x32M.spark.connect.PipelineCommand.DefineOutput.TableDetails.TablePropertiesEntryR\x0ftableProperties\x12%\n\x0epartition_cols\x18\x02 \x03(\tR\rpartitionCols\x12\x1b\n\x06\x66ormat\x18\x03 \x01(\tH\x01R\x06\x66ormat\x88\x01\x01\x12\x43\n\x10schema_data_type\x18\x04 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x0eschemaDataType\x12%\n\rschema_string\x18\x05 \x01(\tH\x00R\x0cschemaString\x12-\n\x12\x63lustering_columns\x18\x06 \x03(\tR\x11\x63lusteringColumns\x1a\x42\n\x14TablePropertiesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x08\n\x06schemaB\t\n\x07_format\x1a\xd1\x01\n\x0bSinkDetails\x12^\n\x07options\x18\x01 \x03(\x0b\x32\x44.spark.connect.PipelineCommand.DefineOutput.SinkDetails.OptionsEntryR\x07options\x12\x1b\n\x06\x66ormat\x18\x02 \x01(\tH\x00R\x06\x66ormat\x88\x01\x01\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07_formatB\t\n\x07\x64\x65tailsB\x14\n\x12_dataflow_graph_idB\x0e\n\x0c_output_nameB\x0e\n\x0c_output_typeB\n\n\x08_commentB\x17\n\x15_source_code_location\x1a\xff\x06\n\nDefineFlow\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x01R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12 \n\tflow_name\x18\x02 \x01(\tH\x02R\x08\x66lowName\x88\x01\x01\x12\x33\n\x13target_dataset_name\x18\x03 \x01(\tH\x03R\x11targetDatasetName\x88\x01\x01\x12Q\n\x08sql_conf\x18\x04 \x03(\x0b\x32\x36.spark.connect.PipelineCommand.DefineFlow.SqlConfEntryR\x07sqlConf\x12 \n\tclient_id\x18\x05 \x01(\tH\x04R\x08\x63lientId\x88\x01\x01\x12X\n\x14source_code_location\x18\x06 \x01(\x0b\x32!.spark.connect.SourceCodeLocationH\x05R\x12sourceCodeLocation\x88\x01\x01\x12x\n\x15relation_flow_details\x18\x07 \x01(\x0b\x32\x42.spark.connect.PipelineCommand.DefineFlow.WriteRelationFlowDetailsH\x00R\x13relationFlowDetails\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x17\n\x04once\x18\x08 \x01(\x08H\x06R\x04once\x88\x01\x01\x1a:\n\x0cSqlConfEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x61\n\x18WriteRelationFlowDetails\x12\x38\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x08relation\x88\x01\x01\x42\x0b\n\t_relation\x1a:\n\x08Response\x12 \n\tflow_name\x18\x01 \x01(\tH\x00R\x08\x66lowName\x88\x01\x01\x42\x0c\n\n_flow_nameB\t\n\x07\x64\x65tailsB\x14\n\x12_dataflow_graph_idB\x0c\n\n_flow_nameB\x16\n\x14_target_dataset_nameB\x0c\n\n_client_idB\x17\n\x15_source_code_locationB\x07\n\x05_once\x1a\xc2\x02\n\x08StartRun\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12\x34\n\x16\x66ull_refresh_selection\x18\x02 \x03(\tR\x14\x66ullRefreshSelection\x12-\n\x10\x66ull_refresh_all\x18\x03 
\x01(\x08H\x01R\x0e\x66ullRefreshAll\x88\x01\x01\x12+\n\x11refresh_selection\x18\x04 \x03(\tR\x10refreshSelection\x12\x15\n\x03\x64ry\x18\x05 \x01(\x08H\x02R\x03\x64ry\x88\x01\x01\x12\x1d\n\x07storage\x18\x06 \x01(\tH\x03R\x07storage\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x13\n\x11_full_refresh_allB\x06\n\x04_dryB\n\n\x08_storage\x1a\xc7\x01\n\x16\x44\x65\x66ineSqlGraphElements\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12\'\n\rsql_file_path\x18\x02 \x01(\tH\x01R\x0bsqlFilePath\x88\x01\x01\x12\x1e\n\x08sql_text\x18\x03 \x01(\tH\x02R\x07sqlText\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x10\n\x0e_sql_file_pathB\x0b\n\t_sql_text\x1a\x9e\x01\n%GetQueryFunctionExecutionSignalStream\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12 \n\tclient_id\x18\x02 \x01(\tH\x01R\x08\x63lientId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x0c\n\n_client_id\x1a\xdd\x01\n\x1d\x44\x65\x66ineFlowQueryFunctionResult\x12 \n\tflow_name\x18\x01 \x01(\tH\x00R\x08\x66lowName\x88\x01\x01\x12/\n\x11\x64\x61taflow_graph_id\x18\x02 \x01(\tH\x01R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12\x38\n\x08relation\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationH\x02R\x08relation\x88\x01\x01\x42\x0c\n\n_flow_nameB\x14\n\x12_dataflow_graph_idB\x0b\n\t_relationB\x0e\n\x0c\x63ommand_type"\xf0\x05\n\x15PipelineCommandResult\x12\x81\x01\n\x1c\x63reate_dataflow_graph_result\x18\x01 \x01(\x0b\x32>.spark.connect.PipelineCommandResult.CreateDataflowGraphResultH\x00R\x19\x63reateDataflowGraphResult\x12k\n\x14\x64\x65\x66ine_output_result\x18\x02 \x01(\x0b\x32\x37.spark.connect.PipelineCommandResult.DefineOutputResultH\x00R\x12\x64\x65\x66ineOutputResult\x12\x65\n\x12\x64\x65\x66ine_flow_result\x18\x03 \x01(\x0b\x32\x35.spark.connect.PipelineCommandResult.DefineFlowResultH\x00R\x10\x64\x65\x66ineFlowResult\x1a\x62\n\x19\x43reateDataflowGraphResult\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_id\x1a\x85\x01\n\x12\x44\x65\x66ineOutputResult\x12W\n\x13resolved_identifier\x18\x01 \x01(\x0b\x32!.spark.connect.ResolvedIdentifierH\x00R\x12resolvedIdentifier\x88\x01\x01\x42\x16\n\x14_resolved_identifier\x1a\x83\x01\n\x10\x44\x65\x66ineFlowResult\x12W\n\x13resolved_identifier\x18\x01 \x01(\x0b\x32!.spark.connect.ResolvedIdentifierH\x00R\x12resolvedIdentifier\x88\x01\x01\x42\x16\n\x14_resolved_identifierB\r\n\x0bresult_type"I\n\x13PipelineEventResult\x12\x32\n\x05\x65vent\x18\x01 \x01(\x0b\x32\x1c.spark.connect.PipelineEventR\x05\x65vent"t\n\rPipelineEvent\x12\x38\n\ttimestamp\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\ttimestamp\x12\x1d\n\x07message\x18\x02 \x01(\tH\x00R\x07message\x88\x01\x01\x42\n\n\x08_message"\xf1\x01\n\x12SourceCodeLocation\x12 \n\tfile_name\x18\x01 \x01(\tH\x00R\x08\x66ileName\x88\x01\x01\x12$\n\x0bline_number\x18\x02 \x01(\x05H\x01R\nlineNumber\x88\x01\x01\x12,\n\x0f\x64\x65\x66inition_path\x18\x03 \x01(\tH\x02R\x0e\x64\x65\x66initionPath\x88\x01\x01\x12\x33\n\textension\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\textensionB\x0c\n\n_file_nameB\x0e\n\x0c_line_numberB\x12\n\x10_definition_path"E\n$PipelineQueryFunctionExecutionSignal\x12\x1d\n\nflow_names\x18\x01 \x03(\tR\tflowNames"\xd7\x01\n\x17PipelineAnalysisContext\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12,\n\x0f\x64\x65\x66inition_path\x18\x02 
\x01(\tH\x01R\x0e\x64\x65\x66initionPath\x88\x01\x01\x12\x33\n\textension\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\textensionB\x14\n\x12_dataflow_graph_idB\x12\n\x10_definition_path*i\n\nOutputType\x12\x1b\n\x17OUTPUT_TYPE_UNSPECIFIED\x10\x00\x12\x15\n\x11MATERIALIZED_VIEW\x10\x01\x12\t\n\x05TABLE\x10\x02\x12\x12\n\x0eTEMPORARY_VIEW\x10\x03\x12\x08\n\x04SINK\x10\x04\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' + b'\n\x1dspark/connect/pipelines.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1aspark/connect/common.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xed"\n\x0fPipelineCommand\x12h\n\x15\x63reate_dataflow_graph\x18\x01 \x01(\x0b\x32\x32.spark.connect.PipelineCommand.CreateDataflowGraphH\x00R\x13\x63reateDataflowGraph\x12R\n\rdefine_output\x18\x02 \x01(\x0b\x32+.spark.connect.PipelineCommand.DefineOutputH\x00R\x0c\x64\x65\x66ineOutput\x12L\n\x0b\x64\x65\x66ine_flow\x18\x03 \x01(\x0b\x32).spark.connect.PipelineCommand.DefineFlowH\x00R\ndefineFlow\x12\x62\n\x13\x64rop_dataflow_graph\x18\x04 \x01(\x0b\x32\x30.spark.connect.PipelineCommand.DropDataflowGraphH\x00R\x11\x64ropDataflowGraph\x12\x46\n\tstart_run\x18\x05 \x01(\x0b\x32\'.spark.connect.PipelineCommand.StartRunH\x00R\x08startRun\x12r\n\x19\x64\x65\x66ine_sql_graph_elements\x18\x06 \x01(\x0b\x32\x35.spark.connect.PipelineCommand.DefineSqlGraphElementsH\x00R\x16\x64\x65\x66ineSqlGraphElements\x12\xa1\x01\n*get_query_function_execution_signal_stream\x18\x07 \x01(\x0b\x32\x44.spark.connect.PipelineCommand.GetQueryFunctionExecutionSignalStreamH\x00R%getQueryFunctionExecutionSignalStream\x12\x88\x01\n!define_flow_query_function_result\x18\x08 \x01(\x0b\x32<.spark.connect.PipelineCommand.DefineFlowQueryFunctionResultH\x00R\x1d\x64\x65\x66ineFlowQueryFunctionResult\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x1a\xb4\x02\n\x13\x43reateDataflowGraph\x12,\n\x0f\x64\x65\x66\x61ult_catalog\x18\x01 \x01(\tH\x00R\x0e\x64\x65\x66\x61ultCatalog\x88\x01\x01\x12.\n\x10\x64\x65\x66\x61ult_database\x18\x02 \x01(\tH\x01R\x0f\x64\x65\x66\x61ultDatabase\x88\x01\x01\x12Z\n\x08sql_conf\x18\x05 \x03(\x0b\x32?.spark.connect.PipelineCommand.CreateDataflowGraph.SqlConfEntryR\x07sqlConf\x1a:\n\x0cSqlConfEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x12\n\x10_default_catalogB\x13\n\x11_default_database\x1aZ\n\x11\x44ropDataflowGraph\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_id\x1a\x92\n\n\x0c\x44\x65\x66ineOutput\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x01R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12$\n\x0boutput_name\x18\x02 \x01(\tH\x02R\noutputName\x88\x01\x01\x12?\n\x0boutput_type\x18\x03 \x01(\x0e\x32\x19.spark.connect.OutputTypeH\x03R\noutputType\x88\x01\x01\x12\x1d\n\x07\x63omment\x18\x04 \x01(\tH\x04R\x07\x63omment\x88\x01\x01\x12X\n\x14source_code_location\x18\x05 \x01(\x0b\x32!.spark.connect.SourceCodeLocationH\x05R\x12sourceCodeLocation\x88\x01\x01\x12_\n\rtable_details\x18\x06 \x01(\x0b\x32\x38.spark.connect.PipelineCommand.DefineOutput.TableDetailsH\x00R\x0ctableDetails\x12\\\n\x0csink_details\x18\x07 \x01(\x0b\x32\x37.spark.connect.PipelineCommand.DefineOutput.SinkDetailsH\x00R\x0bsinkDetails\x12\x35\n\textension\x18\xe7\x07 
\x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x1a\xc0\x03\n\x0cTableDetails\x12x\n\x10table_properties\x18\x01 \x03(\x0b\x32M.spark.connect.PipelineCommand.DefineOutput.TableDetails.TablePropertiesEntryR\x0ftableProperties\x12%\n\x0epartition_cols\x18\x02 \x03(\tR\rpartitionCols\x12\x1b\n\x06\x66ormat\x18\x03 \x01(\tH\x01R\x06\x66ormat\x88\x01\x01\x12\x43\n\x10schema_data_type\x18\x04 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x0eschemaDataType\x12%\n\rschema_string\x18\x05 \x01(\tH\x00R\x0cschemaString\x12-\n\x12\x63lustering_columns\x18\x06 \x03(\tR\x11\x63lusteringColumns\x1a\x42\n\x14TablePropertiesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x08\n\x06schemaB\t\n\x07_format\x1a\xd1\x01\n\x0bSinkDetails\x12^\n\x07options\x18\x01 \x03(\x0b\x32\x44.spark.connect.PipelineCommand.DefineOutput.SinkDetails.OptionsEntryR\x07options\x12\x1b\n\x06\x66ormat\x18\x02 \x01(\tH\x00R\x06\x66ormat\x88\x01\x01\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07_formatB\t\n\x07\x64\x65tailsB\x14\n\x12_dataflow_graph_idB\x0e\n\x0c_output_nameB\x0e\n\x0c_output_typeB\n\n\x08_commentB\x17\n\x15_source_code_location\x1a\xff\x06\n\nDefineFlow\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x01R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12 \n\tflow_name\x18\x02 \x01(\tH\x02R\x08\x66lowName\x88\x01\x01\x12\x33\n\x13target_dataset_name\x18\x03 \x01(\tH\x03R\x11targetDatasetName\x88\x01\x01\x12Q\n\x08sql_conf\x18\x04 \x03(\x0b\x32\x36.spark.connect.PipelineCommand.DefineFlow.SqlConfEntryR\x07sqlConf\x12 \n\tclient_id\x18\x05 \x01(\tH\x04R\x08\x63lientId\x88\x01\x01\x12X\n\x14source_code_location\x18\x06 \x01(\x0b\x32!.spark.connect.SourceCodeLocationH\x05R\x12sourceCodeLocation\x88\x01\x01\x12x\n\x15relation_flow_details\x18\x07 \x01(\x0b\x32\x42.spark.connect.PipelineCommand.DefineFlow.WriteRelationFlowDetailsH\x00R\x13relationFlowDetails\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x17\n\x04once\x18\x08 \x01(\x08H\x06R\x04once\x88\x01\x01\x1a:\n\x0cSqlConfEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x61\n\x18WriteRelationFlowDetails\x12\x38\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x08relation\x88\x01\x01\x42\x0b\n\t_relation\x1a:\n\x08Response\x12 \n\tflow_name\x18\x01 \x01(\tH\x00R\x08\x66lowName\x88\x01\x01\x42\x0c\n\n_flow_nameB\t\n\x07\x64\x65tailsB\x14\n\x12_dataflow_graph_idB\x0c\n\n_flow_nameB\x16\n\x14_target_dataset_nameB\x0c\n\n_client_idB\x17\n\x15_source_code_locationB\x07\n\x05_once\x1a\xc2\x02\n\x08StartRun\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12\x34\n\x16\x66ull_refresh_selection\x18\x02 \x03(\tR\x14\x66ullRefreshSelection\x12-\n\x10\x66ull_refresh_all\x18\x03 \x01(\x08H\x01R\x0e\x66ullRefreshAll\x88\x01\x01\x12+\n\x11refresh_selection\x18\x04 \x03(\tR\x10refreshSelection\x12\x15\n\x03\x64ry\x18\x05 \x01(\x08H\x02R\x03\x64ry\x88\x01\x01\x12\x1d\n\x07storage\x18\x06 \x01(\tH\x03R\x07storage\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x13\n\x11_full_refresh_allB\x06\n\x04_dryB\n\n\x08_storage\x1a\xc7\x01\n\x16\x44\x65\x66ineSqlGraphElements\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12\'\n\rsql_file_path\x18\x02 \x01(\tH\x01R\x0bsqlFilePath\x88\x01\x01\x12\x1e\n\x08sql_text\x18\x03 
\x01(\tH\x02R\x07sqlText\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x10\n\x0e_sql_file_pathB\x0b\n\t_sql_text\x1a\x9e\x01\n%GetQueryFunctionExecutionSignalStream\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12 \n\tclient_id\x18\x02 \x01(\tH\x01R\x08\x63lientId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x0c\n\n_client_id\x1a\xdd\x01\n\x1d\x44\x65\x66ineFlowQueryFunctionResult\x12 \n\tflow_name\x18\x01 \x01(\tH\x00R\x08\x66lowName\x88\x01\x01\x12/\n\x11\x64\x61taflow_graph_id\x18\x02 \x01(\tH\x01R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12\x38\n\x08relation\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationH\x02R\x08relation\x88\x01\x01\x42\x0c\n\n_flow_nameB\x14\n\x12_dataflow_graph_idB\x0b\n\t_relationB\x0e\n\x0c\x63ommand_type"\xf0\x05\n\x15PipelineCommandResult\x12\x81\x01\n\x1c\x63reate_dataflow_graph_result\x18\x01 \x01(\x0b\x32>.spark.connect.PipelineCommandResult.CreateDataflowGraphResultH\x00R\x19\x63reateDataflowGraphResult\x12k\n\x14\x64\x65\x66ine_output_result\x18\x02 \x01(\x0b\x32\x37.spark.connect.PipelineCommandResult.DefineOutputResultH\x00R\x12\x64\x65\x66ineOutputResult\x12\x65\n\x12\x64\x65\x66ine_flow_result\x18\x03 \x01(\x0b\x32\x35.spark.connect.PipelineCommandResult.DefineFlowResultH\x00R\x10\x64\x65\x66ineFlowResult\x1a\x62\n\x19\x43reateDataflowGraphResult\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_id\x1a\x85\x01\n\x12\x44\x65\x66ineOutputResult\x12W\n\x13resolved_identifier\x18\x01 \x01(\x0b\x32!.spark.connect.ResolvedIdentifierH\x00R\x12resolvedIdentifier\x88\x01\x01\x42\x16\n\x14_resolved_identifier\x1a\x83\x01\n\x10\x44\x65\x66ineFlowResult\x12W\n\x13resolved_identifier\x18\x01 \x01(\x0b\x32!.spark.connect.ResolvedIdentifierH\x00R\x12resolvedIdentifier\x88\x01\x01\x42\x16\n\x14_resolved_identifierB\r\n\x0bresult_type"I\n\x13PipelineEventResult\x12\x32\n\x05\x65vent\x18\x01 \x01(\x0b\x32\x1c.spark.connect.PipelineEventR\x05\x65vent"t\n\rPipelineEvent\x12\x38\n\ttimestamp\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\ttimestamp\x12\x1d\n\x07message\x18\x02 \x01(\tH\x00R\x07message\x88\x01\x01\x42\n\n\x08_message"\xf1\x01\n\x12SourceCodeLocation\x12 \n\tfile_name\x18\x01 \x01(\tH\x00R\x08\x66ileName\x88\x01\x01\x12$\n\x0bline_number\x18\x02 \x01(\x05H\x01R\nlineNumber\x88\x01\x01\x12,\n\x0f\x64\x65\x66inition_path\x18\x03 \x01(\tH\x02R\x0e\x64\x65\x66initionPath\x88\x01\x01\x12\x33\n\textension\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\textensionB\x0c\n\n_file_nameB\x0e\n\x0c_line_numberB\x12\n\x10_definition_path"E\n$PipelineQueryFunctionExecutionSignal\x12\x1d\n\nflow_names\x18\x01 \x03(\tR\tflowNames"\x87\x02\n\x17PipelineAnalysisContext\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12,\n\x0f\x64\x65\x66inition_path\x18\x02 \x01(\tH\x01R\x0e\x64\x65\x66initionPath\x88\x01\x01\x12 \n\tflow_name\x18\x03 \x01(\tH\x02R\x08\x66lowName\x88\x01\x01\x12\x33\n\textension\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\textensionB\x14\n\x12_dataflow_graph_idB\x12\n\x10_definition_pathB\x0c\n\n_flow_name*i\n\nOutputType\x12\x1b\n\x17OUTPUT_TYPE_UNSPECIFIED\x10\x00\x12\x15\n\x11MATERIALIZED_VIEW\x10\x01\x12\t\n\x05TABLE\x10\x02\x12\x12\n\x0eTEMPORARY_VIEW\x10\x03\x12\x08\n\x04SINK\x10\x04\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' ) _globals = globals() @@ -69,8 +69,8 @@ ]._serialized_options = b"8\001" 
_globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._loaded_options = None _globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_options = b"8\001" - _globals["_OUTPUTTYPE"]._serialized_start = 6139 - _globals["_OUTPUTTYPE"]._serialized_end = 6244 + _globals["_OUTPUTTYPE"]._serialized_start = 6187 + _globals["_OUTPUTTYPE"]._serialized_end = 6292 _globals["_PIPELINECOMMAND"]._serialized_start = 195 _globals["_PIPELINECOMMAND"]._serialized_end = 4656 _globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH"]._serialized_start = 1129 @@ -126,5 +126,5 @@ _globals["_PIPELINEQUERYFUNCTIONEXECUTIONSIGNAL"]._serialized_start = 5850 _globals["_PIPELINEQUERYFUNCTIONEXECUTIONSIGNAL"]._serialized_end = 5919 _globals["_PIPELINEANALYSISCONTEXT"]._serialized_start = 5922 - _globals["_PIPELINEANALYSISCONTEXT"]._serialized_end = 6137 + _globals["_PIPELINEANALYSISCONTEXT"]._serialized_end = 6185 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/pipelines_pb2.pyi b/python/pyspark/sql/connect/proto/pipelines_pb2.pyi index e0768a1f6bae..39a1e29ae7dd 100644 --- a/python/pyspark/sql/connect/proto/pipelines_pb2.pyi +++ b/python/pyspark/sql/connect/proto/pipelines_pb2.pyi @@ -1499,11 +1499,14 @@ class PipelineAnalysisContext(google.protobuf.message.Message): DATAFLOW_GRAPH_ID_FIELD_NUMBER: builtins.int DEFINITION_PATH_FIELD_NUMBER: builtins.int + FLOW_NAME_FIELD_NUMBER: builtins.int EXTENSION_FIELD_NUMBER: builtins.int dataflow_graph_id: builtins.str """Unique identifier of the dataflow graph associated with this pipeline.""" definition_path: builtins.str """The path of the top-level pipeline file determined at runtime during pipeline initialization.""" + flow_name: builtins.str + """The name of the Flow involved in this analysis""" @property def extension( self, @@ -1516,6 +1519,7 @@ class PipelineAnalysisContext(google.protobuf.message.Message): *, dataflow_graph_id: builtins.str | None = ..., definition_path: builtins.str | None = ..., + flow_name: builtins.str | None = ..., extension: collections.abc.Iterable[google.protobuf.any_pb2.Any] | None = ..., ) -> None: ... def HasField( @@ -1525,10 +1529,14 @@ class PipelineAnalysisContext(google.protobuf.message.Message): b"_dataflow_graph_id", "_definition_path", b"_definition_path", + "_flow_name", + b"_flow_name", "dataflow_graph_id", b"dataflow_graph_id", "definition_path", b"definition_path", + "flow_name", + b"flow_name", ], ) -> builtins.bool: ... def ClearField( @@ -1538,12 +1546,16 @@ class PipelineAnalysisContext(google.protobuf.message.Message): b"_dataflow_graph_id", "_definition_path", b"_definition_path", + "_flow_name", + b"_flow_name", "dataflow_graph_id", b"dataflow_graph_id", "definition_path", b"definition_path", "extension", b"extension", + "flow_name", + b"flow_name", ], ) -> None: ... @typing.overload @@ -1554,5 +1566,9 @@ class PipelineAnalysisContext(google.protobuf.message.Message): def WhichOneof( self, oneof_group: typing_extensions.Literal["_definition_path", b"_definition_path"] ) -> typing_extensions.Literal["definition_path"] | None: ... + @typing.overload + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_flow_name", b"_flow_name"] + ) -> typing_extensions.Literal["flow_name"] | None: ... 
global___PipelineAnalysisContext = PipelineAnalysisContext diff --git a/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto b/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto index a92e24fda915..0874c2d10ec5 100644 --- a/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto +++ b/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto @@ -299,6 +299,8 @@ message PipelineAnalysisContext { optional string dataflow_graph_id = 1; // The path of the top-level pipeline file determined at runtime during pipeline initialization. optional string definition_path = 2; + // The name of the Flow involved in this analysis + optional string flow_name = 3; // Reserved field for protocol extensions. repeated google.protobuf.Any extension = 999; diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala index 1a3b0d2231c6..4c60e0f70ff4 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.connect.pipelines +import scala.collection.Seq import scala.jdk.CollectionConverters._ import scala.util.Using @@ -27,9 +28,10 @@ import org.apache.spark.connect.proto.{ExecutePlanResponse, PipelineCommandResul import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{Command, CreateNamespace, CreateTable, CreateTableAsSelect, CreateView, DescribeRelation, DropView, InsertIntoStatement, LogicalPlan, RenameTable, ShowColumns, ShowCreateTable, ShowFunctions, ShowTableProperties, ShowTables, ShowViews} import org.apache.spark.sql.connect.common.DataTypeProtoConverter import org.apache.spark.sql.connect.service.SessionHolder +import org.apache.spark.sql.execution.command.{ShowCatalogsCommand, ShowNamespacesCommand} import org.apache.spark.sql.pipelines.Language.Python import org.apache.spark.sql.pipelines.common.RunState.{CANCELED, FAILED} import org.apache.spark.sql.pipelines.graph.{AllTables, FlowAnalysis, GraphIdentifierManager, GraphRegistrationContext, IdentifierHelper, NoTables, PipelineUpdateContextImpl, QueryContext, QueryOrigin, QueryOriginType, Sink, SinkImpl, SomeTables, SqlGraphRegistrationContext, Table, TableFilter, TemporaryView, UnresolvedFlow} @@ -129,6 +131,50 @@ private[connect] object PipelinesHandler extends Logging { } } + /** + * Block SQL commands that have side effects or modify data. + * + * Pipeline definitions should be declarative and side-effect free. This prevents users from + * inadvertently modifying catalogs, creating tables, or performing other stateful operations + * outside the pipeline API boundary during pipeline registration or analysis. + * + * This is a best-effort approach: we block known problematic commands while allowing a curated + * set of read-only operations (e.g., SHOW, DESCRIBE). 
+ */ + def blockUnsupportedSqlCommand(queryPlan: LogicalPlan): Unit = { + val allowlistedCommands = Set( + classOf[DescribeRelation], + classOf[ShowTables], + classOf[ShowTableProperties], + classOf[ShowNamespacesCommand], + classOf[ShowColumns], + classOf[ShowFunctions], + classOf[ShowViews], + classOf[ShowCatalogsCommand], + classOf[ShowCreateTable]) + val isSqlCommandExplicitlyAllowlisted = allowlistedCommands.exists(_.isInstance(queryPlan)) + val isUnsupportedSqlPlan = if (isSqlCommandExplicitlyAllowlisted) { + false + } else { + // Disable all [[Command]] except the ones that are explicitly allowlisted + // in "allowlistedCommands". + queryPlan.isInstanceOf[Command] || + // Following commands are not subclasses of [[Command]] but have side effects. + queryPlan.isInstanceOf[CreateTableAsSelect] || + queryPlan.isInstanceOf[CreateTable] || + queryPlan.isInstanceOf[CreateView] || + queryPlan.isInstanceOf[InsertIntoStatement] || + queryPlan.isInstanceOf[RenameTable] || + queryPlan.isInstanceOf[CreateNamespace] || + queryPlan.isInstanceOf[DropView] + } + if (isUnsupportedSqlPlan) { + throw new AnalysisException( + "UNSUPPORTED_PIPELINE_SPARK_SQL_COMMAND", + Map("command" -> queryPlan.getClass.getSimpleName)) + } + } + private def createDataflowGraph( cmd: proto.PipelineCommand.CreateDataflowGraph, sessionHolder: SessionHolder): String = { @@ -148,6 +194,9 @@ private[connect] object PipelinesHandler extends Logging { val defaultSqlConf = cmd.getSqlConfMap.asScala.toMap + sessionHolder.session.catalog.setCurrentCatalog(defaultCatalog) + sessionHolder.session.catalog.setCurrentDatabase(defaultDatabase) + sessionHolder.dataflowGraphRegistry.createDataflowGraph( defaultCatalog = defaultCatalog, defaultDatabase = defaultDatabase, diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 8bc33c41b3a3..644784fa3db6 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -21,11 +21,12 @@ import java.util.{HashMap, Properties, UUID} import scala.collection.mutable import scala.jdk.CollectionConverters._ +import scala.reflect.ClassTag import scala.util.Try import scala.util.control.NonFatal import com.google.common.collect.Lists -import com.google.protobuf.{Any => ProtoAny, ByteString} +import com.google.protobuf.{Any => ProtoAny, ByteString, Message} import io.grpc.{Context, Status, StatusRuntimeException} import io.grpc.stub.StreamObserver @@ -33,7 +34,7 @@ import org.apache.spark.{SparkClassNotFoundException, SparkEnv, SparkException, import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction} import org.apache.spark.connect.proto -import org.apache.spark.connect.proto.{CheckpointCommand, CreateResourceProfileCommand, ExecutePlanResponse, SqlCommand, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult} +import org.apache.spark.connect.proto.{CheckpointCommand, CreateResourceProfileCommand, ExecutePlanResponse, PipelineAnalysisContext, SqlCommand, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, 
StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult} import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult import org.apache.spark.connect.proto.Parse.ParseFormat import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance @@ -2941,10 +2942,28 @@ class SparkConnectPlanner( .build()) } + private def getExtensionList[T <: Message: ClassTag]( + extensions: mutable.Buffer[ProtoAny]): Seq[T] = { + val cls = implicitly[ClassTag[T]].runtimeClass + .asInstanceOf[Class[_ <: Message]] + extensions.collect { + case any if any.is(cls) => any.unpack(cls).asInstanceOf[T] + }.toSeq + } + private def handleSqlCommand( command: SqlCommand, responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { val tracker = executeHolder.eventsManager.createQueryPlanningTracker() + val userContextExtensions = executeHolder.request.getUserContext.getExtensionsList.asScala + val pipelineAnalysisContextList = { + getExtensionList[PipelineAnalysisContext](userContextExtensions) + } + val hasPipelineAnalysisContext = pipelineAnalysisContextList.nonEmpty + val insidePipelineFlowFunction = pipelineAnalysisContextList.exists(_.hasFlowName) + // To avoid explicit handling of the result on the client, we build the expected input + // of the relation on the server. The client has to simply forward the result. + val result = SqlCommandResult.newBuilder() val relation = if (command.hasInput) { command.getInput @@ -2964,6 +2983,18 @@ class SparkConnectPlanner( .build() } + // Block unsupported SQL commands if the request comes from Spark Declarative Pipelines. + if (hasPipelineAnalysisContext) { + PipelinesHandler.blockUnsupportedSqlCommand(queryPlan = transformRelation(relation)) + } + + // If the spark.sql() is called inside a pipeline flow function, we don't need to execute + // the SQL command and defer the actual analysis and execution to the flow function. + if (insidePipelineFlowFunction) { + result.setRelation(relation) + return + } + val df = relation.getRelTypeCase match { case proto.Relation.RelTypeCase.SQL => executeSQL(relation.getSql, tracker) @@ -2982,9 +3013,6 @@ class SparkConnectPlanner( case _ => Seq.empty } - // To avoid explicit handling of the result on the client, we build the expected input - // of the relation on the server. The client has to simply forward the result. 
- val result = SqlCommandResult.newBuilder() // Only filled when isCommand val metrics = ExecutePlanResponse.Metrics.newBuilder() if (isCommand || isSqlScript) { diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala index 1a72d112aa2e..1850241f0702 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala @@ -50,7 +50,7 @@ class PythonPipelineSuite def buildGraph(pythonText: String): DataflowGraph = { assume(PythonTestDepsChecker.isConnectDepsAvailable) - val indentedPythonText = pythonText.linesIterator.map(" " + _).mkString("\n") + val indentedPythonText = pythonText.linesIterator.map(" " + _).mkString("\n") // create a unique identifier to allow identifying the session and dataflow graph val customSessionIdentifier = UUID.randomUUID().toString val pythonCode = @@ -64,6 +64,9 @@ class PythonPipelineSuite |from pyspark.pipelines.graph_element_registry import ( | graph_element_registration_context, |) + |from pyspark.pipelines.add_pipeline_analysis_context import ( + | add_pipeline_analysis_context + |) | |spark = SparkSession.builder \\ | .remote("sc://localhost:$serverPort") \\ @@ -79,7 +82,10 @@ class PythonPipelineSuite |) | |registry = SparkConnectGraphElementRegistry(spark, dataflow_graph_id) - |with graph_element_registration_context(registry): + |with add_pipeline_analysis_context( + | spark=spark, dataflow_graph_id=dataflow_graph_id, flow_name=None + |): + | with graph_element_registration_context(registry): |$indentedPythonText |""".stripMargin @@ -143,7 +149,7 @@ class PythonPipelineSuite QueryOrigin( language = Option(Python()), filePath = Option(""), - line = Option(28), + line = Option(34), objectName = Option("spark_catalog.default.table1"), objectType = Option(QueryOriginType.Flow.toString))), errorChecker = ex => @@ -195,7 +201,7 @@ class PythonPipelineSuite QueryOrigin( language = Option(Python()), filePath = Option(""), - line = Option(34), + line = Option(40), objectName = Option("spark_catalog.default.mv2"), objectType = Option(QueryOriginType.Flow.toString))), expectedEventLevel = EventLevel.INFO) @@ -209,7 +215,7 @@ class PythonPipelineSuite QueryOrigin( language = Option(Python()), filePath = Option(""), - line = Option(38), + line = Option(44), objectName = Option("spark_catalog.default.mv"), objectType = Option(QueryOriginType.Flow.toString))), expectedEventLevel = EventLevel.INFO) @@ -227,7 +233,7 @@ class PythonPipelineSuite QueryOrigin( language = Option(Python()), filePath = Option(""), - line = Option(28), + line = Option(34), objectName = Option("spark_catalog.default.table1"), objectType = Option(QueryOriginType.Flow.toString))), expectedEventLevel = EventLevel.INFO) @@ -241,7 +247,7 @@ class PythonPipelineSuite QueryOrigin( language = Option(Python()), filePath = Option(""), - line = Option(43), + line = Option(49), objectName = Option("spark_catalog.default.standalone_flow1"), objectType = Option(QueryOriginType.Flow.toString))), expectedEventLevel = EventLevel.INFO) @@ -334,21 +340,35 @@ class PythonPipelineSuite |@dp.table |def b(): | return spark.readStream.table("src") + | + |@dp.materialized_view + |def c(): + | return spark.sql("SELECT * FROM src") + | + |@dp.table + |def d(): + | return spark.sql("SELECT * FROM STREAM src") 
|""".stripMargin).resolve().validate() assert( graph.table.keySet == Set( graphIdentifier("src"), graphIdentifier("a"), - graphIdentifier("b"))) - Seq("a", "b").foreach { flowName => + graphIdentifier("b"), + graphIdentifier("c"), + graphIdentifier("d"))) + Seq("a", "b", "c").foreach { flowName => // dependency is properly tracked assert(graph.resolvedFlow(graphIdentifier(flowName)).inputs == Set(graphIdentifier("src"))) } val (streamingFlows, batchFlows) = graph.resolvedFlows.partition(_.df.isStreaming) - assert(batchFlows.map(_.identifier) == Seq(graphIdentifier("src"), graphIdentifier("a"))) - assert(streamingFlows.map(_.identifier) == Seq(graphIdentifier("b"))) + assert( + batchFlows.map(_.identifier) == Seq( + graphIdentifier("src"), + graphIdentifier("a"), + graphIdentifier("c"))) + assert(streamingFlows.map(_.identifier) == Seq(graphIdentifier("b"), graphIdentifier("d"))) } test("referencing external datasets") { @@ -365,18 +385,33 @@ class PythonPipelineSuite |@dp.table |def c(): | return spark.readStream.table("spark_catalog.default.src") + | + |@dp.materialized_view + |def d(): + | return spark.sql("SELECT * FROM spark_catalog.default.src") + | + |@dp.table + |def e(): + | return spark.sql("SELECT * FROM STREAM spark_catalog.default.src") |""".stripMargin).resolve().validate() assert( graph.tables.map(_.identifier).toSet == Set( graphIdentifier("a"), graphIdentifier("b"), - graphIdentifier("c"))) + graphIdentifier("c"), + graphIdentifier("d"), + graphIdentifier("e"))) // dependency is not tracked assert(graph.resolvedFlows.forall(_.inputs.isEmpty)) val (streamingFlows, batchFlows) = graph.resolvedFlows.partition(_.df.isStreaming) - assert(batchFlows.map(_.identifier).toSet == Set(graphIdentifier("a"), graphIdentifier("b"))) - assert(streamingFlows.map(_.identifier) == Seq(graphIdentifier("c"))) + assert( + batchFlows.map(_.identifier).toSet == Set( + graphIdentifier("a"), + graphIdentifier("b"), + graphIdentifier("d"))) + assert( + streamingFlows.map(_.identifier).toSet == Set(graphIdentifier("c"), graphIdentifier("e"))) } test("referencing internal datasets failed") { @@ -392,9 +427,17 @@ class PythonPipelineSuite |@dp.table |def c(): | return spark.readStream.table("src") + | + |@dp.materialized_view + |def d(): + | return spark.sql("SELECT * FROM src") + | + |@dp.table + |def e(): + | return spark.sql("SELECT * FROM STREAM src") |""".stripMargin).resolve() - assert(graph.resolutionFailedFlows.size == 3) + assert(graph.resolutionFailedFlows.size == 5) graph.resolutionFailedFlows.foreach { flow => assert(flow.failure.head.getMessage.contains("[TABLE_OR_VIEW_NOT_FOUND]")) assert(flow.failure.head.getMessage.contains("`src`")) @@ -414,12 +457,94 @@ class PythonPipelineSuite |@dp.materialized_view |def c(): | return spark.readStream.table("spark_catalog.default.src") + | + |@dp.materialized_view + |def d(): + | return spark.sql("SELECT * FROM spark_catalog.default.src") + | + |@dp.table + |def e(): + | return spark.sql("SELECT * FROM STREAM spark_catalog.default.src") |""".stripMargin).resolve() + assert(graph.resolutionFailedFlows.size == 5) graph.resolutionFailedFlows.foreach { flow => - assert(flow.failure.head.getMessage.contains("[TABLE_OR_VIEW_NOT_FOUND] The table or view")) + assert(flow.failure.head.getMessage.contains("[TABLE_OR_VIEW_NOT_FOUND]")) + assert(flow.failure.head.getMessage.contains("`spark_catalog`.`default`.`src`")) } } + test("reading external datasets outside query function works") { + sql("CREATE TABLE spark_catalog.default.src AS SELECT * FROM RANGE(5)") + 
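Read together, the flows exercised above and in the tests that follow amount to the following user-facing pattern in a pipeline definition file. The sketch below is illustrative only and makes two assumptions not spelled out in the patch: that `dp` is the pyspark.pipelines decorator module and that `spark` is the Spark Connect session provided by the pipeline runtime.

    from pyspark.sql import SparkSession
    from pyspark import pipelines as dp  # assumed import path for the decorators

    # Assumed: the pipeline runtime has already created the active session.
    spark = SparkSession.getActiveSession()

    @dp.materialized_view
    def src():
        return spark.range(5)

    # Inside a flow function, spark.sql() is returned as an unexecuted relation and
    # resolved later by the pipeline, so it may reference the internal dataset `src`;
    # SELECT ... FROM STREAM yields a streaming flow, like spark.readStream.table().
    @dp.table
    def streaming_copy():
        return spark.sql("SELECT * FROM STREAM src")

    # Outside a flow function, spark.sql() over external tables still works, but forcing
    # eager analysis of an internal dataset (e.g. spark.read.table("src").collect())
    # fails with TABLE_OR_VIEW_NOT_FOUND, and DDL/DML such as CREATE TABLE is rejected
    # with UNSUPPORTED_PIPELINE_SPARK_SQL_COMMAND.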
val graph = buildGraph(s""" + |spark_sql_df = spark.sql("SELECT * FROM spark_catalog.default.src") + |read_table_df = spark.read.table("spark_catalog.default.src") + | + |@dp.materialized_view + |def mv_from_spark_sql_df(): + | return spark_sql_df + | + |@dp.materialized_view + |def mv_from_read_table_df(): + | return read_table_df + |""".stripMargin).resolve().validate() + + assert( + graph.resolvedFlows.map(_.identifier).toSet == Set( + graphIdentifier("mv_from_spark_sql_df"), + graphIdentifier("mv_from_read_table_df"))) + assert(graph.resolvedFlows.forall(_.inputs.isEmpty)) + assert(graph.resolvedFlows.forall(!_.df.isStreaming)) + } + + test( + "reading internal datasets outside query function that don't trigger " + + "eager analysis or execution") { + val graph = buildGraph(""" + |@dp.materialized_view + |def src(): + | return spark.range(5) + | + |read_table_df = spark.read.table("src") + | + |@dp.materialized_view + |def mv_from_read_table_df(): + | return read_table_df + | + |""".stripMargin).resolve().validate() + assert( + graph.resolvedFlows.map(_.identifier).toSet == Set( + graphIdentifier("mv_from_read_table_df"), + graphIdentifier("src"))) + assert(graph.resolvedFlows.forall(!_.df.isStreaming)) + assert( + graph + .resolvedFlow(graphIdentifier("mv_from_read_table_df")) + .inputs + .contains(graphIdentifier("src"))) + } + + gridTest( + "reading internal datasets outside query function that trigger " + + "eager analysis or execution will fail")( + Seq("""spark.sql("SELECT * FROM src")""", """spark.read.table("src").collect()""")) { + command => + val ex = intercept[RuntimeException] { + buildGraph(s""" + |@dp.materialized_view + |def src(): + | return spark.range(5) + | + |spark_sql_df = $command + | + |@dp.materialized_view + |def mv_from_spark_sql_df(): + | return spark_sql_df + |""".stripMargin) + } + assert(ex.getMessage.contains("TABLE_OR_VIEW_NOT_FOUND")) + assert(ex.getMessage.contains("`src`")) + } + test("create dataset with the same name will fail") { assume(PythonTestDepsChecker.isConnectDepsAvailable) val ex = intercept[AnalysisException] { @@ -902,4 +1027,82 @@ class PythonPipelineSuite s"Table should have no transforms, but got: ${stTransforms.mkString(", ")}") } } + + // List of unsupported SQL commands that should result in a failure. 
+ private val unsupportedSqlCommandList: Seq[String] = Seq( + "SET CATALOG some_catalog", + "USE SCHEMA some_schema", + "SET `test_conf` = `true`", + "CREATE TABLE some_table (id INT)", + "CREATE VIEW some_view AS SELECT * FROM some_table", + "INSERT INTO some_table VALUES (1)", + "ALTER TABLE some_table RENAME TO some_new_table", + "CREATE NAMESPACE some_namespace", + "DROP VIEW some_view", + "CREATE MATERIALIZED VIEW some_view AS SELECT * FROM some_table", + "CREATE STREAMING TABLE some_table AS SELECT * FROM some_table") + + gridTest("Unsupported SQL command outside query function should result in a failure")( + unsupportedSqlCommandList) { unsupportedSqlCommand => + val ex = intercept[RuntimeException] { + buildGraph(s""" + |spark.sql("$unsupportedSqlCommand") + | + |@dp.materialized_view() + |def mv(): + | return spark.range(5) + |""".stripMargin) + } + assert(ex.getMessage.contains("UNSUPPORTED_PIPELINE_SPARK_SQL_COMMAND")) + } + + gridTest("Unsupported SQL command inside query function should result in a failure")( + unsupportedSqlCommandList) { unsupportedSqlCommand => + val ex = intercept[RuntimeException] { + buildGraph(s""" + |@dp.materialized_view() + |def mv(): + | spark.sql("$unsupportedSqlCommand") + | return spark.range(5) + |""".stripMargin) + } + assert(ex.getMessage.contains("UNSUPPORTED_PIPELINE_SPARK_SQL_COMMAND")) + } + + // List of supported SQL commands that should work. + val supportedSqlCommandList: Seq[String] = Seq( + "DESCRIBE TABLE spark_catalog.default.src", + "SHOW TABLES", + "SHOW TBLPROPERTIES spark_catalog.default.src", + "SHOW NAMESPACES", + "SHOW COLUMNS FROM spark_catalog.default.src", + "SHOW FUNCTIONS", + "SHOW VIEWS", + "SHOW CATALOGS", + "SHOW CREATE TABLE spark_catalog.default.src", + "SELECT * FROM RANGE(5)", + "SELECT * FROM spark_catalog.default.src") + + gridTest("Supported SQL command outside query function should work")(supportedSqlCommandList) { + supportedSqlCommand => + sql("CREATE TABLE spark_catalog.default.src AS SELECT * FROM RANGE(5)") + buildGraph(s""" + |spark.sql("$supportedSqlCommand") + | + |@dp.materialized_view() + |def mv(): + | return spark.range(5) + |""".stripMargin) + } + + gridTest("Supported SQL command inside query function should work")(supportedSqlCommandList) { + supportedSqlCommand => + sql("CREATE TABLE spark_catalog.default.src AS SELECT * FROM RANGE(5)") + buildGraph(s""" + |@dp.materialized_view() + |def mv(): + | spark.sql("$supportedSqlCommand") + | return spark.range(5) + |""".stripMargin) + } } diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala index ab60462e8735..c9551646385c 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala @@ -51,6 +51,32 @@ class SparkDeclarativePipelinesServerSuite } } + test( + "create dataflow graph set session catalog and database to pipeline " + + "default catalog and database") { + withRawBlockingStub { implicit stub => + // Use default spark_catalog and create a test database + sql("CREATE DATABASE IF NOT EXISTS test_db") + try { + val graphId = sendPlan( + buildCreateDataflowGraphPlan( + proto.PipelineCommand.CreateDataflowGraph + .newBuilder() + .setDefaultCatalog("spark_catalog") + 
.setDefaultDatabase("test_db") + .build())).getPipelineCommandResult.getCreateDataflowGraphResult.getDataflowGraphId + val definition = + getDefaultSessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(graphId) + assert(definition.defaultCatalog == "spark_catalog") + assert(definition.defaultDatabase == "test_db") + assert(getDefaultSessionHolder.session.catalog.currentCatalog() == "spark_catalog") + assert(getDefaultSessionHolder.session.catalog.currentDatabase == "test_db") + } finally { + sql("DROP DATABASE IF EXISTS test_db") + } + } + } + test("Define a flow for a graph that does not exist") { val ex = intercept[Exception] { withRawBlockingStub { implicit stub => @@ -515,8 +541,7 @@ class SparkDeclarativePipelinesServerSuite name: String, datasetType: OutputType, datasetName: String, - defaultCatalog: String = "", - defaultDatabase: String = "", + defaultDatabase: String, expectedResolvedDatasetName: String, expectedResolvedCatalog: String, expectedResolvedNamespace: Seq[String]) @@ -526,6 +551,7 @@ class SparkDeclarativePipelinesServerSuite name = "TEMPORARY_VIEW", datasetType = OutputType.TEMPORARY_VIEW, datasetName = "tv", + defaultDatabase = "default", expectedResolvedDatasetName = "tv", expectedResolvedCatalog = "", expectedResolvedNamespace = Seq.empty), @@ -533,6 +559,7 @@ class SparkDeclarativePipelinesServerSuite name = "TABLE", datasetType = OutputType.TABLE, datasetName = "`tb`", + defaultDatabase = "default", expectedResolvedDatasetName = "tb", expectedResolvedCatalog = "spark_catalog", expectedResolvedNamespace = Seq("default")), @@ -540,6 +567,7 @@ class SparkDeclarativePipelinesServerSuite name = "MV", datasetType = OutputType.MATERIALIZED_VIEW, datasetName = "mv", + defaultDatabase = "default", expectedResolvedDatasetName = "mv", expectedResolvedCatalog = "spark_catalog", expectedResolvedNamespace = Seq("default"))).map(tc => tc.name -> tc).toMap @@ -549,7 +577,6 @@ class SparkDeclarativePipelinesServerSuite name = "TEMPORARY_VIEW", datasetType = OutputType.TEMPORARY_VIEW, datasetName = "tv", - defaultCatalog = "custom_catalog", defaultDatabase = "custom_db", expectedResolvedDatasetName = "tv", expectedResolvedCatalog = "", @@ -558,19 +585,17 @@ class SparkDeclarativePipelinesServerSuite name = "TABLE", datasetType = OutputType.TABLE, datasetName = "`tb`", - defaultCatalog = "`my_catalog`", defaultDatabase = "`my_db`", expectedResolvedDatasetName = "tb", - expectedResolvedCatalog = "`my_catalog`", + expectedResolvedCatalog = "spark_catalog", expectedResolvedNamespace = Seq("`my_db`")), DefineOutputTestCase( name = "MV", datasetType = OutputType.MATERIALIZED_VIEW, datasetName = "mv", - defaultCatalog = "another_catalog", defaultDatabase = "another_db", expectedResolvedDatasetName = "mv", - expectedResolvedCatalog = "another_catalog", + expectedResolvedCatalog = "spark_catalog", expectedResolvedNamespace = Seq("another_db"))) .map(tc => tc.name -> tc) .toMap @@ -604,40 +629,45 @@ class SparkDeclarativePipelinesServerSuite } } - namedGridTest("DefineOutput returns resolved data name for custom catalog/schema")( + namedGridTest("DefineOutput returns resolved data name for custom schema")( defineDatasetCustomTests) { testCase => withRawBlockingStub { implicit stub => - // Build and send the CreateDataflowGraph command with custom catalog/db - val graphId = sendPlan( - buildCreateDataflowGraphPlan( - proto.PipelineCommand.CreateDataflowGraph - .newBuilder() - .setDefaultCatalog(testCase.defaultCatalog) - .setDefaultDatabase(testCase.defaultDatabase) - 
.build())).getPipelineCommandResult.getCreateDataflowGraphResult.getDataflowGraphId + sql(s"CREATE DATABASE IF NOT EXISTS spark_catalog.${testCase.defaultDatabase}") + try { + // Build and send the CreateDataflowGraph command with custom catalog/db + val graphId = sendPlan( + buildCreateDataflowGraphPlan( + proto.PipelineCommand.CreateDataflowGraph + .newBuilder() + .setDefaultCatalog("spark_catalog") + .setDefaultDatabase(testCase.defaultDatabase) + .build())).getPipelineCommandResult.getCreateDataflowGraphResult.getDataflowGraphId - assert(graphId.nonEmpty) + assert(graphId.nonEmpty) - // Build DefineOutput with the created graphId and dataset info - val defineDataset = DefineOutput - .newBuilder() - .setDataflowGraphId(graphId) - .setOutputName(testCase.datasetName) - .setOutputType(testCase.datasetType) - val pipelineCmd = PipelineCommand - .newBuilder() - .setDefineOutput(defineDataset) - .build() - - val res = sendPlan(buildPlanFromPipelineCommand(pipelineCmd)).getPipelineCommandResult - assert(res !== PipelineCommandResult.getDefaultInstance) - assert(res.hasDefineOutputResult) - val graphResult = res.getDefineOutputResult - val identifier = graphResult.getResolvedIdentifier + // Build DefineOutput with the created graphId and dataset info + val defineDataset = DefineOutput + .newBuilder() + .setDataflowGraphId(graphId) + .setOutputName(testCase.datasetName) + .setOutputType(testCase.datasetType) + val pipelineCmd = PipelineCommand + .newBuilder() + .setDefineOutput(defineDataset) + .build() - assert(identifier.getCatalogName == testCase.expectedResolvedCatalog) - assert(identifier.getNamespaceList.asScala == testCase.expectedResolvedNamespace) - assert(identifier.getTableName == testCase.expectedResolvedDatasetName) + val res = sendPlan(buildPlanFromPipelineCommand(pipelineCmd)).getPipelineCommandResult + assert(res !== PipelineCommandResult.getDefaultInstance) + assert(res.hasDefineOutputResult) + val graphResult = res.getDefineOutputResult + val identifier = graphResult.getResolvedIdentifier + + assert(identifier.getCatalogName == testCase.expectedResolvedCatalog) + assert(identifier.getNamespaceList.asScala == testCase.expectedResolvedNamespace) + assert(identifier.getTableName == testCase.expectedResolvedDatasetName) + } finally { + sql(s"DROP DATABASE IF EXISTS spark_catalog.${testCase.defaultDatabase}") + } } } @@ -645,7 +675,6 @@ class SparkDeclarativePipelinesServerSuite name: String, datasetType: OutputType, flowName: String, - defaultCatalog: String, defaultDatabase: String, expectedResolvedFlowName: String, expectedResolvedCatalog: String, @@ -656,7 +685,6 @@ class SparkDeclarativePipelinesServerSuite name = "MV", datasetType = OutputType.MATERIALIZED_VIEW, flowName = "`mv`", - defaultCatalog = "`spark_catalog`", defaultDatabase = "`default`", expectedResolvedFlowName = "mv", expectedResolvedCatalog = "spark_catalog", @@ -665,7 +693,6 @@ class SparkDeclarativePipelinesServerSuite name = "TV", datasetType = OutputType.TEMPORARY_VIEW, flowName = "tv", - defaultCatalog = "spark_catalog", defaultDatabase = "default", expectedResolvedFlowName = "tv", expectedResolvedCatalog = "", @@ -676,16 +703,14 @@ class SparkDeclarativePipelinesServerSuite name = "MV custom", datasetType = OutputType.MATERIALIZED_VIEW, flowName = "mv", - defaultCatalog = "custom_catalog", defaultDatabase = "custom_db", expectedResolvedFlowName = "mv", - expectedResolvedCatalog = "custom_catalog", + expectedResolvedCatalog = "spark_catalog", expectedResolvedNamespace = Seq("custom_db")), 
DefineFlowTestCase( name = "TV custom", datasetType = OutputType.TEMPORARY_VIEW, flowName = "tv", - defaultCatalog = "custom_catalog", defaultDatabase = "custom_db", expectedResolvedFlowName = "tv", expectedResolvedCatalog = "", @@ -756,68 +781,73 @@ class SparkDeclarativePipelinesServerSuite namedGridTest("DefineFlow returns resolved data name for custom catalog/schema")( defineFlowCustomTests) { testCase => withRawBlockingStub { implicit stub => - val graphId = sendPlan( - buildCreateDataflowGraphPlan( - proto.PipelineCommand.CreateDataflowGraph + sql(s"CREATE DATABASE IF NOT EXISTS spark_catalog.${testCase.defaultDatabase}") + try { + val graphId = sendPlan( + buildCreateDataflowGraphPlan( + proto.PipelineCommand.CreateDataflowGraph + .newBuilder() + .setDefaultCatalog("spark_catalog") + .setDefaultDatabase(testCase.defaultDatabase) + .build())).getPipelineCommandResult.getCreateDataflowGraphResult.getDataflowGraphId + assert(graphId.nonEmpty) + + // If the dataset type is TEMPORARY_VIEW, define the dataset explicitly first + if (testCase.datasetType == OutputType.TEMPORARY_VIEW) { + val defineDataset = DefineOutput .newBuilder() - .setDefaultCatalog(testCase.defaultCatalog) - .setDefaultDatabase(testCase.defaultDatabase) - .build())).getPipelineCommandResult.getCreateDataflowGraphResult.getDataflowGraphId - assert(graphId.nonEmpty) + .setDataflowGraphId(graphId) + .setOutputName(testCase.flowName) + .setOutputType(OutputType.TEMPORARY_VIEW) - // If the dataset type is TEMPORARY_VIEW, define the dataset explicitly first - if (testCase.datasetType == OutputType.TEMPORARY_VIEW) { - val defineDataset = DefineOutput + val defineDatasetCmd = PipelineCommand + .newBuilder() + .setDefineOutput(defineDataset) + .build() + + val datasetRes = + sendPlan(buildPlanFromPipelineCommand(defineDatasetCmd)).getPipelineCommandResult + assert(datasetRes.hasDefineOutputResult) + } + + val defineFlow = DefineFlow .newBuilder() .setDataflowGraphId(graphId) - .setOutputName(testCase.flowName) - .setOutputType(OutputType.TEMPORARY_VIEW) - - val defineDatasetCmd = PipelineCommand + .setFlowName(testCase.flowName) + .setTargetDatasetName(testCase.flowName) + .setRelationFlowDetails( + DefineFlow.WriteRelationFlowDetails + .newBuilder() + .setRelation( + Relation + .newBuilder() + .setUnresolvedTableValuedFunction( + UnresolvedTableValuedFunction + .newBuilder() + .setFunctionName("range") + .addArguments(Expression + .newBuilder() + .setLiteral(Expression.Literal.newBuilder().setInteger(5).build()) + .build()) + .build()) + .build()) + .build()) + .build() + val pipelineCmd = PipelineCommand .newBuilder() - .setDefineOutput(defineDataset) + .setDefineFlow(defineFlow) .build() - - val datasetRes = - sendPlan(buildPlanFromPipelineCommand(defineDatasetCmd)).getPipelineCommandResult - assert(datasetRes.hasDefineOutputResult) + val res = sendPlan(buildPlanFromPipelineCommand(pipelineCmd)).getPipelineCommandResult + assert(res.hasDefineFlowResult) + val graphResult = res.getDefineFlowResult + val identifier = graphResult.getResolvedIdentifier + + assert(identifier.getCatalogName == testCase.expectedResolvedCatalog) + assert(identifier.getNamespaceList.asScala == testCase.expectedResolvedNamespace) + assert(identifier.getTableName == testCase.expectedResolvedFlowName) + } finally { + sql(s"DROP DATABASE IF EXISTS spark_catalog.${testCase.defaultDatabase}") } - - val defineFlow = DefineFlow - .newBuilder() - .setDataflowGraphId(graphId) - .setFlowName(testCase.flowName) - .setTargetDatasetName(testCase.flowName) - 
.setRelationFlowDetails( - DefineFlow.WriteRelationFlowDetails - .newBuilder() - .setRelation( - Relation - .newBuilder() - .setUnresolvedTableValuedFunction( - UnresolvedTableValuedFunction - .newBuilder() - .setFunctionName("range") - .addArguments(Expression - .newBuilder() - .setLiteral(Expression.Literal.newBuilder().setInteger(5).build()) - .build()) - .build()) - .build()) - .build()) - .build() - val pipelineCmd = PipelineCommand - .newBuilder() - .setDefineFlow(defineFlow) - .build() - val res = sendPlan(buildPlanFromPipelineCommand(pipelineCmd)).getPipelineCommandResult - assert(res.hasDefineFlowResult) - val graphResult = res.getDefineFlowResult - val identifier = graphResult.getResolvedIdentifier - - assert(identifier.getCatalogName == testCase.expectedResolvedCatalog) - assert(identifier.getNamespaceList.asScala == testCase.expectedResolvedNamespace) - assert(identifier.getTableName == testCase.expectedResolvedFlowName) } } }
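In summary, with the graphs in these tests now created against the spark_catalog default catalog, the DefineOutput and DefineFlow cases pin down a simple resolution rule: tables and materialized views are qualified with the graph's default catalog and database, while temporary views are left unqualified. A small Python sketch of that rule (the helper name is hypothetical, introduced only for illustration):

    def resolve_identifier(name, output_type, default_catalog, default_database):
        """Mirror the resolved-identifier expectations asserted in the tests above."""
        unquoted = name.strip("`")
        if output_type == "TEMPORARY_VIEW":
            # Temporary views carry no catalog and no namespace.
            return ("", [], unquoted)
        # Tables and materialized views pick up the graph's defaults.
        return (default_catalog, [default_database], unquoted)

    assert resolve_identifier("`tb`", "TABLE", "spark_catalog", "`my_db`") == (
        "spark_catalog", ["`my_db`"], "tb")
    assert resolve_identifier("tv", "TEMPORARY_VIEW", "spark_catalog", "default") == (
        "", [], "tv")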