Skip to content

Commit 9675264

Browse files
committed
add extension
1 parent d9f7ce1 commit 9675264

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

python/pyspark/pipelines/block_connect_access.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,29 @@ def blocked_method(*args: object, **kwargs: object) -> NoReturn:
5353
extension_id = None
5454
try:
5555
# Apply our custom __getattribute__ method
56+
extension_id = add_user_context_extension(spark, definition_path="pipeline.path")
5657
setattr(SparkConnectServiceStub, "__getattribute__", blocked_getattr)
5758
yield
5859
finally:
5960
# Restore the original __getattribute__ method
61+
if extension_id is not None:
62+
spark.removeUserContextExtension(extension_id)
6063
setattr(SparkConnectServiceStub, "__getattribute__", original_getattr)
6164

65+
def add_user_context_extension(spark: SparkSession, **context_fields: Any) -> str:
66+
"""
67+
Adds a user context extension to Spark's thread-local user context and
68+
returns the extension ID.
69+
"""
70+
if not context_fields:
71+
raise ValueError("At least one field must be provided for PipelineAnalysisContext.")
72+
73+
from pyspark.sql.connect.proto import pipelines_pb2
74+
75+
analysis_context = pipelines_pb2.PipelineAnalysisContext(**context_fields)
76+
77+
from google.protobuf import any_pb2
78+
79+
extension = any_pb2.Any()
80+
81+
return spark.addThreadlocalUserContextExtension(extension)

0 commit comments

Comments
 (0)