
Commit 17800b5

committed: dump

1 parent 0606ffd

12 files changed: +450 -245 lines changed
python/pyspark/pipelines/add_pipeline_analysis_context.py

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from contextlib import contextmanager
+from typing import Generator
+from pyspark.sql import SparkSession
+
+@contextmanager
+def add_pipeline_analysis_context(spark: SparkSession, dataflow_graph_id: str, flow_name: str) -> Generator[None, None, None]:
+    """
+    Context manager that adds a PipelineAnalysisContext extension to the user context,
+    used for pipeline-specific analysis.
+    """
+    extension_id = None
+    try:
+        import pyspark.sql.connect.proto as pb2
+        from google.protobuf import any_pb2
+
+        analysis_context = pb2.PipelineAnalysisContext(
+            dataflow_graph_id=dataflow_graph_id,
+            flow_name=flow_name
+        )
+
+        extension = any_pb2.Any()
+        extension.Pack(analysis_context)
+
+        extension_id = spark.addThreadlocalUserContextExtension(extension)
+        yield
+    finally:
+        spark.removeUserContextExtension(extension_id)
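
For orientation, a minimal usage sketch of the new helper; the Connect endpoint and the flow function are hypothetical placeholders, not part of this commit:

    from pyspark.sql import SparkSession
    from pyspark.pipelines.add_pipeline_analysis_context import add_pipeline_analysis_context

    # Hypothetical Spark Connect session; any reachable Connect endpoint works.
    spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

    def example_flow_func():
        # Stand-in for a user-defined flow function.
        return spark.range(10)

    # While the context manager is active, analysis requests issued on this thread
    # carry a PipelineAnalysisContext extension holding the graph id and flow name,
    # which the server can use for pipeline-specific analysis.
    with add_pipeline_analysis_context(
        spark=spark, dataflow_graph_id="example-graph-id", flow_name="example_flow"
    ):
        df = example_flow_func()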

python/pyspark/pipelines/spark_connect_graph_element_registry.py

Lines changed: 5 additions & 2 deletions
@@ -35,6 +35,7 @@
from pyspark.sql.types import StructType
from typing import Any, cast
import pyspark.sql.connect.proto as pb2
+from pyspark.pipelines.add_pipeline_analysis_context import add_pipeline_analysis_context


class SparkConnectGraphElementRegistry(GraphElementRegistry):
@@ -43,6 +44,7 @@ class SparkConnectGraphElementRegistry(GraphElementRegistry):
    def __init__(self, spark: SparkSession, dataflow_graph_id: str) -> None:
        # Cast because mypy seems to think `spark`` is a function, not an object. Likely related to
        # SPARK-47544.
+        self._spark = spark
        self._client = cast(Any, spark).client
        self._dataflow_graph_id = dataflow_graph_id

@@ -110,8 +112,9 @@ def register_output(self, output: Output) -> None:
        self._client.execute_command(command)

    def register_flow(self, flow: Flow) -> None:
-        with block_spark_connect_execution_and_analysis():
-            df = flow.func()
+        with add_pipeline_analysis_context(spark=self._spark, dataflow_graph_id=self._dataflow_graph_id, flow_name=flow.name):
+            with block_spark_connect_execution_and_analysis():
+                df = flow.func()
        relation = cast(ConnectDataFrame, df)._plan.plan(self._client)

        relation_flow_details = pb2.PipelineCommand.DefineFlow.WriteRelationFlowDetails(

python/pyspark/sql/connect/proto/pipelines_pb2.py

Lines changed: 4 additions & 4 deletions
Large diffs are not rendered by default.

python/pyspark/sql/connect/proto/pipelines_pb2.pyi

Lines changed: 16 additions & 0 deletions
@@ -1499,11 +1499,14 @@ class PipelineAnalysisContext(google.protobuf.message.Message):

    DATAFLOW_GRAPH_ID_FIELD_NUMBER: builtins.int
    DEFINITION_PATH_FIELD_NUMBER: builtins.int
+    FLOW_NAME_FIELD_NUMBER: builtins.int
    EXTENSION_FIELD_NUMBER: builtins.int
    dataflow_graph_id: builtins.str
    """Unique identifier of the dataflow graph associated with this pipeline."""
    definition_path: builtins.str
    """The path of the top-level pipeline file determined at runtime during pipeline initialization."""
+    flow_name: builtins.str
+    """The name of the Flow involved in this analysis"""
    @property
    def extension(
        self,
@@ -1516,6 +1519,7 @@ class PipelineAnalysisContext(google.protobuf.message.Message):
        *,
        dataflow_graph_id: builtins.str | None = ...,
        definition_path: builtins.str | None = ...,
+        flow_name: builtins.str | None = ...,
        extension: collections.abc.Iterable[google.protobuf.any_pb2.Any] | None = ...,
    ) -> None: ...
    def HasField(
@@ -1525,10 +1529,14 @@ class PipelineAnalysisContext(google.protobuf.message.Message):
            b"_dataflow_graph_id",
            "_definition_path",
            b"_definition_path",
+            "_flow_name",
+            b"_flow_name",
            "dataflow_graph_id",
            b"dataflow_graph_id",
            "definition_path",
            b"definition_path",
+            "flow_name",
+            b"flow_name",
        ],
    ) -> builtins.bool: ...
    def ClearField(
@@ -1538,12 +1546,16 @@ class PipelineAnalysisContext(google.protobuf.message.Message):
            b"_dataflow_graph_id",
            "_definition_path",
            b"_definition_path",
+            "_flow_name",
+            b"_flow_name",
            "dataflow_graph_id",
            b"dataflow_graph_id",
            "definition_path",
            b"definition_path",
            "extension",
            b"extension",
+            "flow_name",
+            b"flow_name",
        ],
    ) -> None: ...
    @typing.overload
@@ -1554,5 +1566,9 @@ class PipelineAnalysisContext(google.protobuf.message.Message):
    def WhichOneof(
        self, oneof_group: typing_extensions.Literal["_definition_path", b"_definition_path"]
    ) -> typing_extensions.Literal["definition_path"] | None: ...
+    @typing.overload
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["_flow_name", b"_flow_name"]
+    ) -> typing_extensions.Literal["flow_name"] | None: ...

global___PipelineAnalysisContext = PipelineAnalysisContext
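
The regenerated stubs model flow_name as a proto3 optional field with a synthetic _flow_name oneof, so presence can be checked explicitly. A small illustrative sketch, assuming the regenerated pyspark.sql.connect.proto package from this commit is on the path:

    import pyspark.sql.connect.proto as pb2

    ctx = pb2.PipelineAnalysisContext(dataflow_graph_id="example-graph-id")
    assert not ctx.HasField("flow_name")                # optional field starts unset
    ctx.flow_name = "example_flow"
    assert ctx.WhichOneof("_flow_name") == "flow_name"  # synthetic oneof reports presence
    ctx.ClearField("flow_name")
    assert not ctx.HasField("flow_name")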

sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto

Lines changed: 2 additions & 0 deletions
@@ -299,6 +299,8 @@ message PipelineAnalysisContext {
  optional string dataflow_graph_id = 1;
  // The path of the top-level pipeline file determined at runtime during pipeline initialization.
  optional string definition_path = 2;
+  // The name of the Flow involved in this analysis
+  optional string flow_name = 3;

  // Reserved field for protocol extensions.
  repeated google.protobuf.Any extension = 999;
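
For illustration, a hedged sketch of how a client can populate the new field and pack the context into the google.protobuf.Any carried as a user-context extension, mirroring the Python helper added in this commit; the identifiers are placeholders:

    from google.protobuf import any_pb2
    import pyspark.sql.connect.proto as pb2

    analysis_context = pb2.PipelineAnalysisContext(
        dataflow_graph_id="example-graph-id",
        flow_name="example_flow",
    )

    # Pack into an Any, as add_pipeline_analysis_context does before registering
    # the thread-local user context extension.
    extension = any_pb2.Any()
    extension.Pack(analysis_context)

    # A receiver can check the payload type and unpack it back into the message.
    assert extension.Is(pb2.PipelineAnalysisContext.DESCRIPTOR)
    unpacked = pb2.PipelineAnalysisContext()
    extension.Unpack(unpacked)
    assert unpacked.flow_name == "example_flow"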

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala

Lines changed: 29 additions & 15 deletions
@@ -17,25 +17,25 @@

package org.apache.spark.sql.connect.pipelines

-import scala.jdk.CollectionConverters._
-import scala.util.Using
-
import io.grpc.stub.StreamObserver
-
import org.apache.spark.connect.proto
-import org.apache.spark.connect.proto.{ExecutePlanResponse, PipelineCommandResult, Relation, ResolvedIdentifier}
+import org.apache.spark.connect.proto._
import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.classic.DataFrame
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
import org.apache.spark.sql.connect.service.SessionHolder
import org.apache.spark.sql.pipelines.Language.Python
import org.apache.spark.sql.pipelines.common.RunState.{CANCELED, FAILED}
-import org.apache.spark.sql.pipelines.graph.{AllTables, FlowAnalysis, GraphIdentifierManager, GraphRegistrationContext, IdentifierHelper, NoTables, PipelineUpdateContextImpl, QueryContext, QueryOrigin, QueryOriginType, Sink, SinkImpl, SomeTables, SqlGraphRegistrationContext, Table, TableFilter, TemporaryView, UnresolvedFlow}
+import org.apache.spark.sql.pipelines.graph._
import org.apache.spark.sql.pipelines.logging.{PipelineEvent, RunProgress}
import org.apache.spark.sql.types.StructType

+import scala.jdk.CollectionConverters._
+import scala.util.Using
+
/** Handler for SparkConnect PipelineCommands */
private[connect] object PipelinesHandler extends Logging {

@@ -47,8 +47,6 @@ private[connect] object PipelinesHandler extends Logging {
   *   Command to be handled
   * @param responseObserver
   *   The response observer where the response will be sent
-   * @param sparkSession
-   *   The spark session
   * @param transformRelationFunc
   *   Function used to convert a relation to a LogicalPlan. This is used when determining the
   *   LogicalPlan that a flow returns.
@@ -108,7 +106,6 @@ private[connect] object PipelinesHandler extends Logging {
          identifierBuilder.addNamespace(ns)
        }
        identifierBuilder.setTableName(resolvedFlow.identifier)
-        val identifier = identifierBuilder.build()
        PipelineCommandResult
          .newBuilder()
          .setDefineFlowResult(
@@ -129,6 +126,24 @@ private[connect] object PipelinesHandler extends Logging {
    }
  }

+  def executeSQL(
+      sessionHolder: SessionHolder,
+      plan: LogicalPlan,
+      pipelineAnalysisContext: PipelineAnalysisContext
+  ): DataFrame = {
+    val graphRegistrationContext = {
+      sessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(
+        pipelineAnalysisContext.getDataflowGraphId)
+    }
+    val pipelineSqlProcessor = new PipelineSqlProcessor(graphRegistrationContext)
+    val context = ExternalQueryAnalysisContext(
+      queryContext = QueryContext(
+        currentCatalog = Option(graphRegistrationContext.defaultCatalog),
+        currentDatabase = Option(graphRegistrationContext.defaultDatabase)),
+      spark = sessionHolder.session)
+    pipelineSqlProcessor.processSparkSqlQuery(queryPlan = plan, context = context)
+  }
+
  private def createDataflowGraph(
      cmd: proto.PipelineCommand.CreateDataflowGraph,
      sessionHolder: SessionHolder): String = {
@@ -161,7 +176,7 @@ private[connect] object PipelinesHandler extends Logging {

    val graphElementRegistry =
      sessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)
-    val sqlGraphElementRegistrationContext = new SqlGraphRegistrationContext(graphElementRegistry)
+    val sqlGraphElementRegistrationContext = new PipelineSqlProcessor(graphElementRegistry)
    sqlGraphElementRegistrationContext.processSqlFile(
      cmd.getSqlText,
      cmd.getSqlFilePath,
@@ -293,8 +308,7 @@ private[connect] object PipelinesHandler extends Logging {
    val rawDestinationIdentifier = GraphIdentifierManager
      .parseTableIdentifier(name = flow.getTargetDatasetName, spark = sessionHolder.session)
    val flowWritesToView =
-      graphElementRegistry
-        .getViews()
+      graphElementRegistry.getViews
        .filter(_.isInstanceOf[TemporaryView])
        .exists(_.identifier == rawDestinationIdentifier)
    val flowWritesToSink =
@@ -304,7 +318,7 @@ private[connect] object PipelinesHandler extends Logging {
    // If the flow is created implicitly as part of defining a view or that it writes to a sink,
    // then we do not qualify the flow identifier and the flow destination. This is because
    // views and sinks are not permitted to have multipart
-    val isImplicitFlowForTempView = (isImplicitFlow && flowWritesToView)
+    val isImplicitFlowForTempView = isImplicitFlow && flowWritesToView
    val Seq(flowIdentifier, destinationIdentifier) =
      Seq(rawFlowIdentifier, rawDestinationIdentifier).map { rawIdentifier =>
        if (isImplicitFlowForTempView || flowWritesToSink) {
@@ -330,8 +344,8 @@ private[connect] object PipelinesHandler extends Logging {
        once = false,
        queryContext = QueryContext(Option(defaultCatalog), Option(defaultDatabase)),
        origin = QueryOrigin(
-          filePath = Option.when(flow.getSourceCodeLocation.hasFileName)(
-            flow.getSourceCodeLocation.getFileName),
+          filePath = Option
+            .when(flow.getSourceCodeLocation.hasFileName)(flow.getSourceCodeLocation.getFileName),
          line = Option.when(flow.getSourceCodeLocation.hasLineNumber)(
            flow.getSourceCodeLocation.getLineNumber),
          objectType = Option(QueryOriginType.Flow.toString),
