22 changes: 22 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -906,6 +906,28 @@
},
"sqlState" : "56K00"
},
"CONNECT_INVALID_PLAN" : {
"message" : [
"The Spark Connect plan is invalid."
],
"subClass" : {
"CANNOT_PARSE" : {
"message" : [
"Cannot decompress or parse the input plan (<errorMsg>)",
"This may be caused by a corrupted compressed plan.",
"To disable plan compression, set 'spark.connect.session.planCompression.threshold' to -1."
]
},
"PLAN_SIZE_LARGER_THAN_MAX" : {
"message" : [
"The plan size is larger than max (<planSize> vs. <maxPlanSize>)",
"This typically occurs when building very complex queries with many operations, large literals, or deeply nested expressions.",
"Consider splitting the query into smaller parts using temporary views for intermediate results or reducing the number of operations."
]
}
},
"sqlState" : "56K00"
},
"CONNECT_ML" : {
"message" : [
"Generic Spark Connect ML error."
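The two new error messages point at concrete remedies. A minimal, hypothetical sketch of both (the Spark Connect endpoint and DataFrame shapes are placeholders, and it is an assumption that the threshold config can be adjusted from the client session):

```python
from pyspark.sql import SparkSession

# Placeholder Spark Connect endpoint.
spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

# PLAN_SIZE_LARGER_THAN_MAX: split a very large query by materializing an
# intermediate result as a temporary view, so each individual plan stays small.
intermediate = spark.range(1_000_000).selectExpr("id", "id % 100 AS bucket")
intermediate.createOrReplaceTempView("intermediate")
result = spark.sql("SELECT bucket, count(*) AS cnt FROM intermediate GROUP BY bucket")

# CANNOT_PARSE: disable plan compression entirely, as the message suggests
# (assumption: the config can be set on the running session).
spark.conf.set("spark.connect.session.planCompression.threshold", "-1")
```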
1 change: 1 addition & 0 deletions dev/requirements.txt
@@ -65,6 +65,7 @@ grpcio>=1.76.0
grpcio-status>=1.76.0
googleapis-common-protos>=1.71.0
protobuf==6.33.0
zstandard>=0.25.0

# Spark Connect python proto generation plugin (optional)
mypy-protobuf==3.3.0
1 change: 1 addition & 0 deletions dev/spark-test-image/lint/Dockerfile
@@ -84,6 +84,7 @@ RUN python3.11 -m pip install \
'grpc-stubs==1.24.11' \
'grpcio-status==1.76.0' \
'grpcio==1.76.0' \
'zstandard==0.25.0' \
'ipython' \
'ipython_genutils' \
'jinja2' \
2 changes: 1 addition & 1 deletion dev/spark-test-image/numpy-213/Dockerfile
@@ -71,7 +71,7 @@ RUN apt-get update && apt-get install -y \
# Pin numpy==2.1.3
ARG BASIC_PIP_PKGS="numpy==2.1.3 pyarrow>=22.0.0 six==1.16.0 pandas==2.2.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
# Python deps for Spark Connect
ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0 googleapis-common-protos==1.71.0 graphviz==0.20.3"
ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0 googleapis-common-protos==1.71.0 zstandard==0.25.0 graphviz==0.20.3"

# Install Python 3.11 packages
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.11
2 changes: 1 addition & 1 deletion dev/spark-test-image/python-310/Dockerfile
@@ -66,7 +66,7 @@ RUN apt-get update && apt-get install -y \

ARG BASIC_PIP_PKGS="numpy pyarrow>=22.0.0 six==1.16.0 pandas==2.3.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
# Python deps for Spark Connect
ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0 googleapis-common-protos==1.71.0 graphviz==0.20.3"
ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0 googleapis-common-protos==1.71.0 zstandard==0.25.0 graphviz==0.20.3"

# Install Python 3.10 packages
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10
2 changes: 1 addition & 1 deletion dev/spark-test-image/python-311/Dockerfile
@@ -70,7 +70,7 @@ RUN apt-get update && apt-get install -y \

ARG BASIC_PIP_PKGS="numpy pyarrow>=22.0.0 six==1.16.0 pandas==2.3.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
# Python deps for Spark Connect
ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0 googleapis-common-protos==1.71.0 graphviz==0.20.3"
ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0 googleapis-common-protos==1.71.0 zstandard==0.25.0 graphviz==0.20.3"

# Install Python 3.11 packages
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.11
2 changes: 1 addition & 1 deletion dev/spark-test-image/python-312/Dockerfile
@@ -70,7 +70,7 @@ RUN apt-get update && apt-get install -y \

ARG BASIC_PIP_PKGS="numpy pyarrow>=22.0.0 six==1.16.0 pandas==2.3.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
# Python deps for Spark Connect
ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0 googleapis-common-protos==1.71.0 graphviz==0.20.3"
ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0 googleapis-common-protos==1.71.0 zstandard==0.25.0 graphviz==0.20.3"

# Install Python 3.12 packages
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.12
2 changes: 1 addition & 1 deletion dev/spark-test-image/python-313-nogil/Dockerfile
@@ -69,7 +69,7 @@ RUN apt-get update && apt-get install -y \


ARG BASIC_PIP_PKGS="numpy pyarrow>=22.0.0 six==1.16.0 pandas==2.3.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0 googleapis-common-protos==1.71.0 graphviz==0.20.3"
ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0 googleapis-common-protos==1.71.0 zstandard==0.25.0 graphviz==0.20.3"


# Install Python 3.13 packages
2 changes: 1 addition & 1 deletion dev/spark-test-image/python-313/Dockerfile
@@ -70,7 +70,7 @@ RUN apt-get update && apt-get install -y \

ARG BASIC_PIP_PKGS="numpy pyarrow>=22.0.0 six==1.16.0 pandas==2.3.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
# Python deps for Spark Connect
ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0 googleapis-common-protos==1.71.0 graphviz==0.20.3"
ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0 googleapis-common-protos==1.71.0 zstandard==0.25.0 graphviz==0.20.3"

# Install Python 3.13 packages
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.13
2 changes: 1 addition & 1 deletion dev/spark-test-image/python-314/Dockerfile
@@ -70,7 +70,7 @@ RUN apt-get update && apt-get install -y \

ARG BASIC_PIP_PKGS="numpy pyarrow>=22.0.0 six==1.16.0 pandas==2.3.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
# Python deps for Spark Connect
ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0 googleapis-common-protos==1.71.0 graphviz==0.20.3"
ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0 googleapis-common-protos==1.71.0 zstandard==0.25.0 graphviz==0.20.3"

# Install Python 3.14 packages
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.14
2 changes: 1 addition & 1 deletion dev/spark-test-image/python-minimum/Dockerfile
@@ -64,7 +64,7 @@ RUN apt-get update && apt-get install -y \

ARG BASIC_PIP_PKGS="numpy==1.22.4 pyarrow==15.0.0 pandas==2.2.0 six==1.16.0 scipy scikit-learn coverage unittest-xml-reporting"
# Python deps for Spark Connect
ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 googleapis-common-protos==1.71.0 graphviz==0.20 protobuf"
ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 googleapis-common-protos==1.71.0 zstandard==0.25.0 graphviz==0.20 protobuf"

# Install Python 3.10 packages
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10
2 changes: 1 addition & 1 deletion dev/spark-test-image/python-ps-minimum/Dockerfile
@@ -65,7 +65,7 @@ RUN apt-get update && apt-get install -y \

ARG BASIC_PIP_PKGS="pyarrow==15.0.0 pandas==2.2.0 six==1.16.0 numpy scipy coverage unittest-xml-reporting"
# Python deps for Spark Connect
ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 googleapis-common-protos==1.71.0 graphviz==0.20 protobuf"
ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 googleapis-common-protos==1.71.0 zstandard==0.25.0 graphviz==0.20 protobuf"

# Install Python 3.10 packages
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10
3 changes: 3 additions & 0 deletions python/packaging/classic/setup.py
@@ -156,6 +156,7 @@ def _supports_symlinks():
_minimum_grpc_version = "1.76.0"
_minimum_googleapis_common_protos_version = "1.71.0"
_minimum_pyyaml_version = "3.11"
_minimum_zstandard_version = "0.25.0"


class InstallCommand(install):
@@ -366,6 +367,7 @@ def run(self):
"grpcio>=%s" % _minimum_grpc_version,
"grpcio-status>=%s" % _minimum_grpc_version,
"googleapis-common-protos>=%s" % _minimum_googleapis_common_protos_version,
"zstandard>=%s" % _minimum_zstandard_version,
"numpy>=%s" % _minimum_numpy_version,
],
"pipelines": [
@@ -375,6 +377,7 @@
"grpcio>=%s" % _minimum_grpc_version,
"grpcio-status>=%s" % _minimum_grpc_version,
"googleapis-common-protos>=%s" % _minimum_googleapis_common_protos_version,
"zstandard>=%s" % _minimum_zstandard_version,
"pyyaml>=%s" % _minimum_pyyaml_version,
],
},
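Since zstandard is now listed alongside the other Spark Connect client dependencies in setup.py, it arrives with the usual installation. A quick, optional sanity check (a standard-library-only sketch) shows whether the compression path can be used at all:

```python
import importlib.metadata

# If the package is missing, the client silently falls back to uncompressed
# plans (see _import_zstandard_if_available in core.py below).
try:
    print("zstandard", importlib.metadata.version("zstandard"), "is installed")
except importlib.metadata.PackageNotFoundError:
    print("zstandard not installed; plans will be sent uncompressed")
```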
2 changes: 2 additions & 0 deletions python/packaging/client/setup.py
@@ -139,6 +139,7 @@
_minimum_grpc_version = "1.76.0"
_minimum_googleapis_common_protos_version = "1.71.0"
_minimum_pyyaml_version = "3.11"
_minimum_zstandard_version = "0.25.0"

with open("README.md") as f:
long_description = f.read()
@@ -211,6 +212,7 @@
"grpcio>=%s" % _minimum_grpc_version,
"grpcio-status>=%s" % _minimum_grpc_version,
"googleapis-common-protos>=%s" % _minimum_googleapis_common_protos_version,
"zstandard>=%s" % _minimum_zstandard_version,
"numpy>=%s" % _minimum_numpy_version,
"pyyaml>=%s" % _minimum_pyyaml_version,
],
2 changes: 2 additions & 0 deletions python/packaging/connect/setup.py
@@ -92,6 +92,7 @@
_minimum_grpc_version = "1.76.0"
_minimum_googleapis_common_protos_version = "1.71.0"
_minimum_pyyaml_version = "3.11"
_minimum_zstandard_version = "0.25.0"

with open("README.md") as f:
long_description = f.read()
@@ -121,6 +122,7 @@
"grpcio>=%s" % _minimum_grpc_version,
"grpcio-status>=%s" % _minimum_grpc_version,
"googleapis-common-protos>=%s" % _minimum_googleapis_common_protos_version,
"zstandard>=%s" % _minimum_zstandard_version,
"numpy>=%s" % _minimum_numpy_version,
"pyyaml>=%s" % _minimum_pyyaml_version,
],
131 changes: 129 additions & 2 deletions python/pyspark/sql/connect/client/core.py
@@ -34,6 +34,7 @@
import urllib.parse
import uuid
import sys
import time
from typing import (
Iterable,
Iterator,
@@ -113,6 +114,19 @@
from pyspark.sql.datasource import DataSource


def _import_zstandard_if_available() -> Optional[Any]:
"""
Import zstandard if available, otherwise return None.
This is used to handle the case when zstandard is not installed.
"""
try:
import zstandard

return zstandard
except ImportError:
return None


class ChannelBuilder:
"""
This is a helper class that is used to create a GRPC channel based on the given
@@ -706,6 +720,10 @@ def __init__(

self._progress_handlers: List[ProgressHandler] = []

self._zstd_module = _import_zstandard_if_available()
self._plan_compression_threshold: Optional[int] = None # Will be fetched lazily
self._plan_compression_algorithm: Optional[str] = None # Will be fetched lazily

# cleanup ml cache if possible
atexit.register(self._cleanup_ml_cache)

@@ -1156,7 +1174,7 @@ def execute_command(
req = self._execute_plan_request_with_metadata()
if self._user_id:
req.user_context.user_id = self._user_id
req.plan.command.CopyFrom(command)
self._set_command_in_plan(req.plan, command)
data, _, metrics, observed_metrics, properties = self._execute_and_fetch(
req, observations or {}
)
@@ -1182,7 +1200,7 @@
req = self._execute_plan_request_with_metadata()
if self._user_id:
req.user_context.user_id = self._user_id
req.plan.command.CopyFrom(command)
self._set_command_in_plan(req.plan, command)
for response in self._execute_and_fetch_as_iterator(req, observations or {}):
if isinstance(response, dict):
yield response
@@ -1963,6 +1981,17 @@ def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn:
if info.metadata.get("errorClass") == "INVALID_HANDLE.SESSION_CHANGED":
self._closed = True

if info.metadata.get("errorClass") == "CONNECT_INVALID_PLAN.CANNOT_PARSE":
# Disable plan compression if the server fails to interpret the plan.
logger.info(
"Disabling plan compression for the session due to "
"CONNECT_INVALID_PLAN.CANNOT_PARSE error."
)
self._plan_compression_threshold, self._plan_compression_algorithm = (
-1,
"NONE",
)

raise convert_exception(
info,
status.message,
@@ -2112,6 +2141,104 @@ def _query_model_size(self, model_ref_id: str) -> int:
ml_command_result = properties["ml_command_result"]
return ml_command_result.param.long

def _set_relation_in_plan(self, plan: pb2.Plan, relation: pb2.Relation) -> None:
"""Sets the relation in the plan, attempting compression if configured."""
self._try_compress_and_set_plan(
plan=plan,
message=relation,
op_type=pb2.Plan.CompressedOperation.OpType.OP_TYPE_RELATION,
)

def _set_command_in_plan(self, plan: pb2.Plan, command: pb2.Command) -> None:
"""Sets the command in the plan, attempting compression if configured."""
self._try_compress_and_set_plan(
plan=plan,
message=command,
op_type=pb2.Plan.CompressedOperation.OpType.OP_TYPE_COMMAND,
)

def _try_compress_and_set_plan(
self,
plan: pb2.Plan,
message: google.protobuf.message.Message,
op_type: pb2.Plan.CompressedOperation.OpType.ValueType,
) -> None:
"""
Tries to compress a protobuf message and sets it on the plan.
If compression is not enabled, not effective, or not available,
it falls back to the original message.
"""
(
plan_compression_threshold,
plan_compression_algorithm,
) = self._get_plan_compression_threshold_and_algorithm()
plan_compression_enabled = (
plan_compression_threshold is not None
and plan_compression_threshold >= 0
and plan_compression_algorithm is not None
and plan_compression_algorithm != "NONE"
)
if plan_compression_enabled:
serialized_msg = message.SerializeToString()
original_size = len(serialized_msg)
if (
original_size > plan_compression_threshold
and plan_compression_algorithm == "ZSTD"
and self._zstd_module
):
start_time = time.time()
compressed_operation = pb2.Plan.CompressedOperation(
data=self._zstd_module.compress(serialized_msg),
op_type=op_type,
compression_codec=pb2.CompressionCodec.COMPRESSION_CODEC_ZSTD,
)
duration = time.time() - start_time
compressed_size = len(compressed_operation.data)
logger.debug(
f"Plan compression: original_size={original_size}, "
f"compressed_size={compressed_size}, "
f"saving_ratio={1 - compressed_size / original_size:.2f}, "
f"duration_s={duration:.1f}"
)
if compressed_size < original_size:
plan.compressed_operation.CopyFrom(compressed_operation)
return
else:
logger.debug("Plan compression not effective. Using original plan.")

if op_type == pb2.Plan.CompressedOperation.OpType.OP_TYPE_RELATION:
plan.root.CopyFrom(message) # type: ignore[arg-type]
else:
plan.command.CopyFrom(message) # type: ignore[arg-type]

def _get_plan_compression_threshold_and_algorithm(self) -> Tuple[int, str]:
if self._plan_compression_threshold is None or self._plan_compression_algorithm is None:
try:
(
plan_compression_threshold_str,
self._plan_compression_algorithm,
) = self.get_configs(
"spark.connect.session.planCompression.threshold",
"spark.connect.session.planCompression.defaultAlgorithm",
)
self._plan_compression_threshold = (
int(plan_compression_threshold_str) if plan_compression_threshold_str else -1
)
logger.debug(
f"Plan compression threshold: {self._plan_compression_threshold}, "
f"algorithm: {self._plan_compression_algorithm}"
)
except Exception as e:
self._plan_compression_threshold = -1
self._plan_compression_algorithm = "NONE"
logger.debug(
"Plan compression is disabled because the server does not support it.", e
)
return (
self._plan_compression_threshold,
self._plan_compression_algorithm,
) # type: ignore[return-value]

def clone(self, new_session_id: Optional[str] = None) -> "SparkConnectClient":
"""
Clone this client session on the server side. The server-side session is cloned with
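The heart of `_try_compress_and_set_plan` is a simple size-based decision: serialize the message, compress it with zstandard only when it exceeds the server-advertised threshold, and keep the compressed bytes only if they are actually smaller. A self-contained sketch of that decision, detached from the protobuf plumbing (the payload and threshold below are made up):

```python
import time
from typing import Tuple

try:
    import zstandard  # optional dependency, mirroring _import_zstandard_if_available()
except ImportError:
    zstandard = None


def maybe_compress(payload: bytes, threshold: int) -> Tuple[bytes, bool]:
    """Return (data, is_compressed); compress only when it pays off."""
    if zstandard is None or threshold < 0 or len(payload) <= threshold:
        return payload, False
    start = time.time()
    compressed = zstandard.compress(payload)
    duration = time.time() - start
    print(
        f"original={len(payload)} compressed={len(compressed)} "
        f"saving_ratio={1 - len(compressed) / len(payload):.2f} duration_s={duration:.3f}"
    )
    if len(compressed) < len(payload):
        return compressed, True
    # Compression did not shrink the payload; fall back to the original bytes.
    return payload, False


# Example: a large, repetitive "plan" compresses well; a tiny one is left alone.
big = b"SELECT " + b", ".join(b"col_%d" % i for i in range(5_000)) + b" FROM t"
print(maybe_compress(big, threshold=1024)[1])          # True
print(maybe_compress(b"SELECT 1", threshold=1024)[1])  # False
```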
3 changes: 2 additions & 1 deletion python/pyspark/sql/connect/plan.py
@@ -143,7 +143,8 @@ def to_proto(self, session: "SparkConnectClient", debug: bool = False) -> proto.
if enabled, the proto plan will be printed.
"""
plan = proto.Plan()
plan.root.CopyFrom(self.plan(session))
relation = self.plan(session)
session._set_relation_in_plan(plan, relation)

if debug:
print(plan)
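With `to_proto` now delegating to `_set_relation_in_plan`, compression is transparent to user code: the DataFrame API is unchanged, and the client chooses between `plan.root` and `plan.compressed_operation` on its own. A hypothetical end-to-end usage sketch (the endpoint is a placeholder):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

# Nothing new to call here: if the serialized relation exceeds the
# server-advertised threshold and zstandard is available, the plan is sent
# as a compressed_operation; otherwise it is sent as plan.root, as before.
df = spark.range(100).filter("id % 2 = 0")
print(df.count())
```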
384 changes: 195 additions & 189 deletions python/pyspark/sql/connect/proto/base_pb2.py

Large diffs are not rendered by default.
