Commit 8c3adfd ("more")

1 parent ccae316, 22 files changed: +121 -73 lines changed

bin/spark-pipelines

Lines changed: 1 addition & 1 deletion

@@ -30,4 +30,4 @@ fi
 export PYTHONPATH="${SPARK_HOME}/python/:$PYTHONPATH"
 export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.9.9-src.zip:$PYTHONPATH"

-${SPARK_HOME}/bin/spark-submit --conf spark.api.mode=connect "${SPARK_HOME}"/python/pyspark/sql/pipelines/cli.py "$@"
+$PYSPARK_PYTHON "${SPARK_HOME}"/python/pyspark/pipelines/cli.py "$@"
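The launcher now runs the pipelines CLI directly with the configured Python interpreter rather than routing through spark-submit with spark.api.mode=connect. A minimal Python sketch of the equivalent invocation, assuming SPARK_HOME is set; the fallback to the current interpreter when PYSPARK_PYTHON is unset is an assumption, not part of the script:

import os
import subprocess
import sys

# Mirror of the revised launcher line: run the pipelines CLI with the
# configured interpreter, forwarding all command-line arguments.
spark_home = os.environ["SPARK_HOME"]
python = os.environ.get("PYSPARK_PYTHON", sys.executable)  # fallback is assumed
cli = os.path.join(spark_home, "python", "pyspark", "pipelines", "cli.py")
subprocess.run([python, cli, *sys.argv[1:]], check=True)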

dev/sparktestsupport/modules.py

Lines changed: 8 additions & 5 deletions

@@ -411,10 +411,17 @@ def __hash__(self):
 pipelines = Module(
     name="pipelines",
     dependencies=[],
-    source_file_regexes=["sql/pipelines"],
+    source_file_regexes=["sql/pipelines", "python/pyspark/pipelines"],
     sbt_test_goals=[
         "pipelines/test",
     ],
+    python_test_goals=[
+        "pyspark.pipelines.tests.test_block_connect_access",
+        "pyspark.pipelines.tests.test_cli",
+        "pyspark.pipelines.tests.test_decorators",
+        "pyspark.pipelines.tests.test_graph_element_registry",
+        "pyspark.pipelines.tests.test_init_cli",
+    ],
 )

 connect = Module(
@@ -556,10 +563,6 @@ def __hash__(self):
         "pyspark.sql.tests.pandas.test_pandas_udf_window",
         "pyspark.sql.tests.pandas.test_pandas_sqlmetrics",
         "pyspark.sql.tests.pandas.test_converter",
-        "pyspark.sql.tests.pipelines.test_block_connect_access",
-        "pyspark.sql.tests.pipelines.test_cli",
-        "pyspark.sql.tests.pipelines.test_decorators",
-        "pyspark.sql.tests.pipelines.test_graph_element_registry",
         "pyspark.sql.tests.test_python_datasource",
         "pyspark.sql.tests.test_python_streaming_datasource",
         "pyspark.sql.tests.test_readwriter",

python/mypy.ini

Lines changed: 3 additions & 3 deletions

@@ -117,6 +117,9 @@ ignore_errors = True
 [mypy-pyspark.pandas.tests.*]
 ignore_errors = True

+[mypy-pyspark.pipelines.tests.*]
+ignore_errors = True
+
 [mypy-pyspark.tests.*]
 ignore_errors = True

@@ -191,6 +194,3 @@ ignore_missing_imports = True
 ; Ignore errors for proto generated code
 [mypy-pyspark.sql.connect.proto.*, pyspark.sql.connect.proto, pyspark.sql.streaming.proto]
 ignore_errors = True
-
-[mypy-pyspark.sql.pipelines.proto.*]
-ignore_errors = True

python/pyspark/sql/pipelines/__init__.py renamed to python/pyspark/pipelines/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from pyspark.sql.pipelines.api import (
+from pyspark.pipelines.api import (
     append_flow,
     create_streaming_table,
     materialized_view,

python/pyspark/sql/pipelines/api.py renamed to python/pyspark/pipelines/api.py

Lines changed: 5 additions & 5 deletions

@@ -17,13 +17,13 @@
 from typing import Callable, Dict, List, Optional, Union, overload

 from pyspark.errors import PySparkTypeError
-from pyspark.sql.pipelines.graph_element_registry import get_active_graph_element_registry
-from pyspark.sql.pipelines.type_error_utils import validate_optional_list_of_str_arg
-from pyspark.sql.pipelines.flow import Flow, QueryFunction
-from pyspark.sql.pipelines.source_code_location import (
+from pyspark.pipelines.graph_element_registry import get_active_graph_element_registry
+from pyspark.pipelines.type_error_utils import validate_optional_list_of_str_arg
+from pyspark.pipelines.flow import Flow, QueryFunction
+from pyspark.pipelines.source_code_location import (
     get_caller_source_code_location,
 )
-from pyspark.sql.pipelines.dataset import (
+from pyspark.pipelines.dataset import (
     MaterializedView,
     StreamingTable,
     TemporaryView,
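The module's public surface (append_flow, create_streaming_table, materialized_view, as re-exported from __init__.py above) is unchanged; only the internal import paths move. A hedged sketch of declaring a streaming table and attaching a flow to it, assuming the decorator takes the target table via a target= keyword (the argument name and the rate source are illustrative assumptions):

from pyspark import pipelines as sdp
from pyspark.sql import DataFrame, SparkSession

spark = SparkSession.active()

# Declare the streaming table, then attach a query to it with append_flow.
sdp.create_streaming_table("events")

@sdp.append_flow(target="events")  # target= is an assumed keyword
def ingest() -> DataFrame:
    return spark.readStream.format("rate").load()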

python/pyspark/sql/pipelines/cli.py renamed to python/pyspark/pipelines/cli.py

Lines changed: 4 additions & 4 deletions

@@ -32,15 +32,15 @@

 from pyspark.errors import PySparkException, PySparkTypeError
 from pyspark.sql import SparkSession
-from pyspark.sql.pipelines.graph_element_registry import (
+from pyspark.pipelines.graph_element_registry import (
     graph_element_registration_context,
     GraphElementRegistry,
 )
-from pyspark.sql.pipelines.init_cli import init
-from pyspark.sql.pipelines.spark_connect_graph_element_registry import (
+from pyspark.pipelines.init_cli import init
+from pyspark.pipelines.spark_connect_graph_element_registry import (
     SparkConnectGraphElementRegistry,
 )
-from pyspark.sql.pipelines.spark_connect_pipeline import (
+from pyspark.pipelines.spark_connect_pipeline import (
     create_dataflow_graph,
     start_run,
     handle_pipeline_events,

python/pyspark/sql/pipelines/dataset.py renamed to python/pyspark/pipelines/dataset.py

Lines changed: 1 addition & 1 deletion

@@ -17,7 +17,7 @@
 from dataclasses import dataclass
 from typing import Mapping, Optional, Sequence, Union

-from pyspark.sql.pipelines.source_code_location import SourceCodeLocation
+from pyspark.pipelines.source_code_location import SourceCodeLocation
 from pyspark.sql.types import StructType

python/pyspark/sql/pipelines/flow.py renamed to python/pyspark/pipelines/flow.py

Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@
 from typing import Callable, Dict

 from pyspark.sql import DataFrame
-from pyspark.sql.pipelines.source_code_location import SourceCodeLocation
+from pyspark.pipelines.source_code_location import SourceCodeLocation

 QueryFunction = Callable[[], DataFrame]

python/pyspark/sql/pipelines/graph_element_registry.py renamed to python/pyspark/pipelines/graph_element_registry.py

Lines changed: 2 additions & 2 deletions

@@ -18,8 +18,8 @@
 from abc import ABC, abstractmethod
 from pathlib import Path

-from pyspark.sql.pipelines.dataset import Dataset
-from pyspark.sql.pipelines.flow import Flow
+from pyspark.pipelines.dataset import Dataset
+from pyspark.pipelines.flow import Flow
 from contextlib import contextmanager
 from contextvars import ContextVar
 from typing import Generator, Optional

python/pyspark/sql/pipelines/init_cli.py renamed to python/pyspark/pipelines/init_cli.py

Lines changed: 1 addition & 1 deletion

@@ -26,8 +26,8 @@
 """

 PYTHON_EXAMPLE = """
+from pyspark import pipelines as sdp
 from pyspark.sql import DataFrame, SparkSession
-from pyspark.sql import pipelines as sdp

 spark = SparkSession.active()
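The generated template now imports the package from its top-level home. A minimal sketch of a definitions file in the new style (the query body is illustrative, not the template's contents):

from pyspark import pipelines as sdp
from pyspark.sql import DataFrame, SparkSession

spark = SparkSession.active()

@sdp.materialized_view
def example_materialized_view() -> DataFrame:
    # Any DataFrame-returning query works here; range() keeps the
    # sketch self-contained.
    return spark.range(10)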

python/pyspark/sql/pipelines/spark_connect_graph_element_registry.py renamed to python/pyspark/pipelines/spark_connect_graph_element_registry.py

Lines changed: 12 additions & 10 deletions

@@ -19,25 +19,27 @@
 from pyspark.errors import PySparkTypeError
 from pyspark.sql import SparkSession
 from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
-from pyspark.sql.pipelines.block_connect_access import block_spark_connect_execution_and_analysis
-from pyspark.sql.pipelines.dataset import (
+from pyspark.pipelines.block_connect_access import block_spark_connect_execution_and_analysis
+from pyspark.pipelines.dataset import (
     Dataset,
     MaterializedView,
     Table,
     StreamingTable,
     TemporaryView,
 )
-from pyspark.sql.pipelines.flow import Flow
-from pyspark.sql.pipelines.graph_element_registry import GraphElementRegistry
-from typing import cast
+from pyspark.pipelines.flow import Flow
+from pyspark.pipelines.graph_element_registry import GraphElementRegistry
+from typing import Any, cast
 import pyspark.sql.connect.proto as pb2


 class SparkConnectGraphElementRegistry(GraphElementRegistry):
     """Registers datasets and flows in a dataflow graph held in a Spark Connect server."""

     def __init__(self, spark: SparkSession, dataflow_graph_id: str) -> None:
-        self._spark = spark
+        # Cast because mypy seems to think `spark` is a function, not an object.
+        # Likely related to SPARK-47544.
+        self._client = cast(Any, spark).client
         self._dataflow_graph_id = dataflow_graph_id

     def register_dataset(self, dataset: Dataset) -> None:
@@ -80,12 +82,12 @@ def register_dataset(self, dataset: Dataset) -> None:
         )
         command = pb2.Command()
         command.pipeline_command.define_dataset.CopyFrom(inner_command)
-        self._spark.client.execute_command(command)
+        self._client.execute_command(command)

     def register_flow(self, flow: Flow) -> None:
         with block_spark_connect_execution_and_analysis():
             df = flow.func()
-            relation = cast(ConnectDataFrame, df)._plan.plan(self._spark.client)
+            relation = cast(ConnectDataFrame, df)._plan.plan(self._client)

         inner_command = pb2.PipelineCommand.DefineFlow(
             dataflow_graph_id=self._dataflow_graph_id,
@@ -97,7 +99,7 @@ def register_flow(self, flow: Flow) -> None:
         )
         command = pb2.Command()
         command.pipeline_command.define_flow.CopyFrom(inner_command)
-        self._spark.client.execute_command(command)
+        self._client.execute_command(command)

     def register_sql(self, sql_text: str, file_path: Path) -> None:
         inner_command = pb2.DefineSqlGraphElements(
@@ -107,4 +109,4 @@ def register_sql(self, sql_text: str, file_path: Path) -> None:
         )
         command = pb2.Command()
         command.pipeline_command.define_sql_graph_elements.CopyFrom(inner_command)
-        self._spark.client.execute_command(command)
+        self._client.execute_command(command)
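The registry now grabs spark.client once in the constructor, using cast(Any, ...) to sidestep the mypy mis-resolution the comment attributes to SPARK-47544. A standalone sketch of that workaround (the helper name is hypothetical):

from typing import Any, cast

def client_of(spark: object) -> Any:
    # Hypothetical helper: cast(Any, ...) erases the static type, so mypy
    # accepts the .client attribute access it otherwise rejects.
    return cast(Any, spark).client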

python/pyspark/sql/pipelines/spark_connect_pipeline.py renamed to python/pyspark/pipelines/spark_connect_pipeline.py

Lines changed: 8 additions & 5 deletions

@@ -14,11 +14,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Any, Dict, Mapping, Iterator, Optional
-
-from pyspark.sql import SparkSession
+from typing import Any, Dict, Mapping, Iterator, Optional, cast

 import pyspark.sql.connect.proto as pb2
+from pyspark.sql import SparkSession
 from pyspark.errors.exceptions.base import PySparkValueError


@@ -39,7 +38,9 @@ def create_dataflow_graph(
     )
     command = pb2.Command()
     command.pipeline_command.create_dataflow_graph.CopyFrom(inner_command)
-    (_, properties, _) = spark.client.execute_command(command)
+    # Cast because mypy seems to think `spark` is a function, not an object.
+    # Likely related to SPARK-47544.
+    (_, properties, _) = cast(Any, spark).client.execute_command(command)
     return properties["pipeline_command_result"].create_dataflow_graph_result.dataflow_graph_id


@@ -69,4 +70,6 @@ def start_run(spark: SparkSession, dataflow_graph_id: str) -> Iterator[Dict[str,
     inner_command = pb2.PipelineCommand.StartRun(dataflow_graph_id=dataflow_graph_id)
     command = pb2.Command()
     command.pipeline_command.start_run.CopyFrom(inner_command)
-    return spark.client.execute_command_as_iterator(command)
+    # Cast because mypy seems to think `spark` is a function, not an object.
+    # Likely related to SPARK-47544.
+    return cast(Any, spark).client.execute_command_as_iterator(command)
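start_run's signature is visible in the hunk header above: it takes a session and a dataflow graph id and returns an iterator of event dicts. A hedged sketch of consuming that stream (printing stands in for the real handle_pipeline_events, whose signature is not shown here):

from typing import Any, Dict, Iterator

from pyspark.pipelines.spark_connect_pipeline import start_run
from pyspark.sql import SparkSession

def drain_run_events(spark: SparkSession, graph_id: str) -> None:
    # start_run issues the StartRun command and yields pipeline events as
    # they stream back from the Spark Connect server.
    events: Iterator[Dict[str, Any]] = start_run(spark, dataflow_graph_id=graph_id)
    for event in events:
        print(event)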

python/pyspark/sql/tests/pipelines/local_graph_element_registry.py renamed to python/pyspark/pipelines/tests/local_graph_element_registry.py

Lines changed: 3 additions & 3 deletions

@@ -19,9 +19,9 @@
 from pathlib import Path
 from typing import List, Sequence

-from pyspark.sql.pipelines.dataset import Dataset
-from pyspark.sql.pipelines.flow import Flow
-from pyspark.sql.pipelines.graph_element_registry import GraphElementRegistry
+from pyspark.pipelines.dataset import Dataset
+from pyspark.pipelines.flow import Flow
+from pyspark.pipelines.graph_element_registry import GraphElementRegistry


 @dataclass(frozen=True)

python/pyspark/sql/tests/pipelines/test_block_connect_access.py renamed to python/pyspark/pipelines/tests/test_block_connect_access.py

Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@

 from pyspark.errors import PySparkException
 from pyspark.testing.connectutils import ReusedConnectTestCase
-from pyspark.sql.pipelines.block_connect_access import block_spark_connect_execution_and_analysis
+from pyspark.pipelines.block_connect_access import block_spark_connect_execution_and_analysis
 from pyspark.testing.connectutils import (
     ReusedConnectTestCase,
     should_test_connect,

python/pyspark/sql/tests/pipelines/test_cli.py renamed to python/pyspark/pipelines/tests/test_cli.py

Lines changed: 4 additions & 26 deletions

@@ -21,17 +21,16 @@
 from pathlib import Path

 from pyspark.errors import PySparkException
-from pyspark.sql.pipelines.cli import (
+from pyspark.pipelines.cli import (
     change_dir,
     find_pipeline_spec,
-    init,
     load_pipeline_spec,
     register_definitions,
     unpack_pipeline_spec,
     DefinitionsGlob,
     PipelineSpec,
 )
-from pyspark.sql.tests.pipelines.local_graph_element_registry import LocalGraphElementRegistry
+from pyspark.pipelines.tests.local_graph_element_registry import LocalGraphElementRegistry


 class CLIUtilityTests(unittest.TestCase):
@@ -210,7 +209,7 @@ def test_register_definitions(self):
             f.write(
                 textwrap.dedent(
                     """
-                    from pyspark.sql import pipelines as sdp
+                    from pyspark import pipelines as sdp
                     @sdp.materialized_view
                     def mv1():
                         raise NotImplementedError()
@@ -222,7 +221,7 @@ def mv1():
             f.write(
                 textwrap.dedent(
                     """
-                    from pyspark.sql import pipelines as sdp
+                    from pyspark import pipelines as sdp
                     def mv2():
                         raise NotImplementedError()
                     """
@@ -314,27 +313,6 @@ def test_python_import_current_directory(self):
             ),
         )

-    def test_init(self):
-        with tempfile.TemporaryDirectory() as temp_dir:
-            project_name = "test_project"
-            with change_dir(Path(temp_dir)):
-                init(project_name)
-            with change_dir(Path(temp_dir) / project_name):
-                spec_path = find_pipeline_spec(Path.cwd())
-                spec = load_pipeline_spec(spec_path)
-                registry = LocalGraphElementRegistry()
-                register_definitions(spec_path, registry, spec)
-                self.assertEqual(len(registry.datasets), 1)
-                self.assertEqual(registry.datasets[0].name, "example_python_materialized_view")
-                self.assertEqual(len(registry.flows), 1)
-                self.assertEqual(registry.flows[0].name, "example_python_materialized_view")
-                self.assertEqual(registry.flows[0].target, "example_python_materialized_view")
-                self.assertEqual(len(registry.sql_files), 1)
-                self.assertEqual(
-                    registry.sql_files[0].file_path,
-                    Path("transformations") / "example_sql_materialized_view.sql",
-                )
-

 if __name__ == "__main__":
     try:

python/pyspark/sql/tests/pipelines/test_decorators.py renamed to python/pyspark/pipelines/tests/test_decorators.py

Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@
 import unittest

 from pyspark.errors import PySparkTypeError
-from pyspark.sql import pipelines as sdp
+from pyspark import pipelines as sdp


 class DecoratorsTest(unittest.TestCase):

python/pyspark/sql/tests/pipelines/test_graph_element_registry.py renamed to python/pyspark/pipelines/tests/test_graph_element_registry.py

Lines changed: 3 additions & 3 deletions

@@ -18,9 +18,9 @@
 import unittest

 from pyspark.errors import PySparkException
-from pyspark.sql.pipelines.graph_element_registry import graph_element_registration_context
-from pyspark.sql import pipelines as sdp
-from pyspark.sql.tests.pipelines.local_graph_element_registry import LocalGraphElementRegistry
+from pyspark.pipelines.graph_element_registry import graph_element_registration_context
+from pyspark import pipelines as sdp
+from pyspark.pipelines.tests.local_graph_element_registry import LocalGraphElementRegistry


 class GraphElementRegistryTest(unittest.TestCase):
