
Commit 3c23794

xi-db authored and hvanhovell committed
[SPARK-54194][CONNECT][FOLLOWUP] Spark Connect Proto Plan Compression - Scala Client
### What changes were proposed in this pull request?

In the previous PR #52894 of Spark Connect Proto Plan Compression, both the server-side and PySpark client changes were implemented. In this PR, the corresponding Scala client changes are implemented, so plan compression is now supported on the Scala client as well.

To reproduce the existing issue being solved here, run this code on the Spark Connect Scala client:

```
import scala.util.Random
import org.apache.spark.sql.DataFrame
import spark.implicits._

def randomLetters(n: Int): String = {
  Iterator.continually(Random.nextPrintableChar())
    .filter(_.isLetter)
    .take(n)
    .mkString
}

val numUniqueSmallRelations = 5
val sizePerSmallRelation = 512 * 1024
val smallDfs: Seq[DataFrame] = (0 until numUniqueSmallRelations).map { _ =>
  Seq(randomLetters(sizePerSmallRelation)).toDF("value")
}

var resultDf = smallDfs.head
for (_ <- 0 until 500) {
  val idx = Random.nextInt(smallDfs.length)
  resultDf = resultDf.unionByName(smallDfs(idx))
}
resultDf.collect()
```

It fails with a RESOURCE_EXHAUSTED error and the message `gRPC message exceeds maximum size 134217728: 269207219`, because the server is trying to send an ExecutePlanResponse of ~260MB to the client. With the improvement introduced by this PR, the above code runs successfully and prints the expected result.

### Why are the changes needed?

It improves Spark Connect stability when handling large plans.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #53003 from xi-db/plan-compression-scala-client.

Authored-by: Xi Lyu <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
(cherry picked from commit 6cb88c1)
Signed-off-by: Herman van Hovell <[email protected]>
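To make the mechanism concrete, here is a minimal, illustrative sketch of the threshold check described above, using the zstd-jni dependency this PR adds. It is not the actual `SparkConnectClient` code path: `PlanCompressionOptions` is redefined locally, `maybeCompress` is a made-up helper name, and the step that wraps the compressed bytes into the plan's compressed-operation proto message is omitted.

```scala
import com.github.luben.zstd.Zstd

object PlanCompressionSketch {
  // Illustrative stand-in for the option pair exercised in the tests below:
  // a size threshold in bytes and a codec name.
  final case class PlanCompressionOptions(thresholdBytes: Int, algorithm: String)

  // Hypothetical helper: compress the serialized plan only when compression is
  // enabled, the codec is "ZSTD", and the payload crosses the threshold.
  def maybeCompress(
      planBytes: Array[Byte],
      options: Option[PlanCompressionOptions]): Array[Byte] = options match {
    case Some(PlanCompressionOptions(threshold, "ZSTD")) if planBytes.length > threshold =>
      Zstd.compress(planBytes) // one-shot compression from zstd-jni
    case _ =>
      planBytes // small plan, non-ZSTD codec, or compression disabled: send as-is
  }
}
```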
1 parent 0ca1bad commit 3c23794

File tree

8 files changed (+436, -94 lines)

python/docs/source/getting_started/install.rst

Lines changed: 2 additions & 0 deletions

```diff
@@ -230,6 +230,7 @@ Package Supported version Note
 `grpcio`                   >=1.76.0          Required for Spark Connect
 `grpcio-status`            >=1.76.0          Required for Spark Connect
 `googleapis-common-protos` >=1.71.0          Required for Spark Connect
+`zstandard`                >=0.25.0          Required for Spark Connect
 `graphviz`                 >=0.20            Optional for Spark Connect
 ========================== ================= ==========================

@@ -313,6 +314,7 @@ Package Supported version Note
 `grpcio`                   >=1.76.0          Required for Spark Connect
 `grpcio-status`            >=1.76.0          Required for Spark Connect
 `googleapis-common-protos` >=1.71.0          Required for Spark Connect
+`zstandard`                >=0.25.0          Required for Spark Connect
 `pyyaml`                   >=3.11            Required for spark-pipelines command line interface
 `graphviz`                 >=0.20            Optional for Spark Connect
 ========================== ================= ===================================================
```

sql/connect/client/jvm/pom.xml

Lines changed: 5 additions & 0 deletions

```diff
@@ -84,6 +84,11 @@
       <artifactId>failureaccess</artifactId>
       <scope>compile</scope>
     </dependency>
+    <dependency>
+      <groupId>com.github.luben</groupId>
+      <artifactId>zstd-jni</artifactId>
+      <scope>compile</scope>
+    </dependency>
     <!--
       When upgrading ammonite, consider upgrading semanticdb-shared too.
     -->
```
sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala

Lines changed: 17 additions & 1 deletion

```diff
@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.connect.ConnectConversions._
-import org.apache.spark.sql.connect.client.{RetryPolicy, SparkConnectClient, SparkResult}
+import org.apache.spark.sql.connect.client.{PlanCompressionOptions, RetryPolicy, SparkConnectClient, SparkResult}
 import org.apache.spark.sql.connect.test.{ConnectFunSuite, IntegrationTestUtils, QueryTest, RemoteSparkSession, SQLHelper}
 import org.apache.spark.sql.connect.test.SparkConnectServerUtils.{createSparkSession, port}
 import org.apache.spark.sql.functions._
@@ -2013,6 +2013,22 @@ class ClientE2ETestSuite
       }
     }
   }
+
+  test("Plan compression works correctly") {
+    val originalPlanCompressionOptions = spark.client.getPlanCompressionOptions
+    assert(originalPlanCompressionOptions.nonEmpty)
+    assert(originalPlanCompressionOptions.get.thresholdBytes > 0)
+    assert(originalPlanCompressionOptions.get.algorithm == "ZSTD")
+    try {
+      spark.client.setPlanCompressionOptions(Some(PlanCompressionOptions(1000, "ZSTD")))
+      // Execution should work
+      assert(spark.sql(s"select '${"Apache Spark" * 10000}' as value").collect().length == 1)
+      // Analysis should work
+      assert(spark.sql(s"select '${"Apache Spark" * 10000}' as value").columns.length == 1)
+    } finally {
+      spark.client.setPlanCompressionOptions(originalPlanCompressionOptions)
+    }
+  }
 }

 private[sql] case class ClassData(a: String, b: Int)
```

sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala

Lines changed: 134 additions & 0 deletions

```diff
@@ -187,6 +187,9 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
       .builder()
       .connectionString(s"sc://localhost:${server.getPort}")
       .build()
+    // Disable plan compression to make sure there is only one RPC request in client.analyze,
+    // so the interceptor can capture the initial header.
+    client.setPlanCompressionOptions(None)

     val session = SparkSession.builder().client(client).create()
     val df = session.range(10)
@@ -521,6 +524,9 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
       .connectionString(s"sc://localhost:${server.getPort}")
       .enableReattachableExecute()
       .build()
+    // Disable plan compression to make sure there is only one RPC request in client.analyze,
+    // so the interceptor can capture the initial header.
+    client.setPlanCompressionOptions(None)

     val plan = buildPlan("select * from range(10000000)")
     val dummyUUID = "10a4c38e-7e87-40ee-9d6f-60ff0751e63b"
@@ -533,16 +539,105 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
       assert(resp.getOperationId == dummyUUID)
     }
   }
+
+  test("Plan compression works correctly for execution") {
+    startDummyServer(0)
+    client = SparkConnectClient
+      .builder()
+      .connectionString(s"sc://localhost:${server.getPort}")
+      .enableReattachableExecute()
+      .build()
+    // Set plan compression options for testing
+    client.setPlanCompressionOptions(Some(PlanCompressionOptions(1000, "ZSTD")))
+
+    // Small plan should not be compressed
+    val plan = buildPlan("select * from range(10)")
+    val iter = client.execute(plan)
+    val reattachableIter =
+      ExecutePlanResponseReattachableIterator.fromIterator(iter)
+    while (reattachableIter.hasNext) {
+      reattachableIter.next()
+    }
+    assert(service.getAndClearLatestInputPlan().hasRoot)
+
+    // Large plan should be compressed
+    val plan2 = buildPlan(s"select ${"Apache Spark" * 10000} as value")
+    val iter2 = client.execute(plan2)
+    val reattachableIter2 =
+      ExecutePlanResponseReattachableIterator.fromIterator(iter2)
+    while (reattachableIter2.hasNext) {
+      reattachableIter2.next()
+    }
+    assert(service.getAndClearLatestInputPlan().hasCompressedOperation)
+  }
+
+  test("Plan compression works correctly for analysis") {
+    startDummyServer(0)
+    client = SparkConnectClient
+      .builder()
+      .connectionString(s"sc://localhost:${server.getPort}")
+      .enableReattachableExecute()
+      .build()
+    // Set plan compression options for testing
+    client.setPlanCompressionOptions(Some(PlanCompressionOptions(1000, "ZSTD")))
+
+    // Small plan should not be compressed
+    val plan = buildPlan("select * from range(10)")
+    client.analyze(proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA, Some(plan))
+    assert(service.getAndClearLatestInputPlan().hasRoot)
+
+    // Large plan should be compressed
+    val plan2 = buildPlan(s"select ${"Apache Spark" * 10000} as value")
+    client.analyze(proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA, Some(plan2))
+    assert(service.getAndClearLatestInputPlan().hasCompressedOperation)
+  }
+
+  test("Plan compression will be disabled if the configs are not defined on the server") {
+    startDummyServer(0)
+    client = SparkConnectClient
+      .builder()
+      .connectionString(s"sc://localhost:${server.getPort}")
+      .enableReattachableExecute()
+      .build()
+
+    service.setErrorToThrowOnConfig(
+      "spark.connect.session.planCompression.defaultAlgorithm",
+      new StatusRuntimeException(Status.INTERNAL.withDescription("SQL_CONF_NOT_FOUND")))
+
+    // Execute a few queries to make sure the client fetches the configs only once.
+    (1 to 3).foreach { _ =>
+      val plan = buildPlan(s"select ${"Apache Spark" * 10000} as value")
+      val iter = client.execute(plan)
+      val reattachableIter =
+        ExecutePlanResponseReattachableIterator.fromIterator(iter)
+      while (reattachableIter.hasNext) {
+        reattachableIter.next()
+      }
+      assert(service.getAndClearLatestInputPlan().hasRoot)
+    }
+    // The plan compression options should be empty.
+    assert(client.getPlanCompressionOptions.isEmpty)
+    // The client should try to fetch the config only once.
+    assert(service.getAndClearLatestConfigRequests().size == 1)
+  }
 }

 class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectServiceImplBase {

   private var inputPlan: proto.Plan = _
   private val inputArtifactRequests: mutable.ListBuffer[AddArtifactsRequest] =
     mutable.ListBuffer.empty
+  private val inputConfigRequests = mutable.ListBuffer.empty[proto.ConfigRequest]
+  private val sparkConfigs = mutable.Map.empty[String, String]

   var errorToThrowOnExecute: Option[Throwable] = None

+  private var errorToThrowOnConfig: Map[String, Throwable] = Map.empty
+
+  private[sql] def setErrorToThrowOnConfig(key: String, error: Throwable): Unit = synchronized {
+    errorToThrowOnConfig = errorToThrowOnConfig + (key -> error)
+  }
+
   private[sql] def getAndClearLatestInputPlan(): proto.Plan = synchronized {
     val plan = inputPlan
     inputPlan = null
@@ -556,6 +651,13 @@ class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectSer
     requests
   }

+  private[sql] def getAndClearLatestConfigRequests(): Seq[proto.ConfigRequest] =
+    synchronized {
+      val requests = inputConfigRequests.clone().toSeq
+      inputConfigRequests.clear()
+      requests
+    }
+
   override def executePlan(
       request: ExecutePlanRequest,
       responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
@@ -666,6 +768,38 @@ class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectSer
     responseObserver.onCompleted()
   }

+  override def config(
+      request: proto.ConfigRequest,
+      responseObserver: StreamObserver[proto.ConfigResponse]): Unit = {
+    inputConfigRequests.synchronized {
+      inputConfigRequests.append(request)
+    }
+    require(
+      request.getOperation.hasGetOption,
+      "Only GetOption is supported. Other operations " +
+        "can be implemented by following the same procedure below.")
+
+    val responseBuilder = proto.ConfigResponse.newBuilder().setSessionId(request.getSessionId)
+    request.getOperation.getGetOption.getKeysList.asScala.iterator.foreach { key =>
+      if (errorToThrowOnConfig.contains(key)) {
+        val error = errorToThrowOnConfig(key)
+        responseObserver.onError(error)
+        return
+      }
+
+      val kvBuilder = proto.KeyValue.newBuilder()
+      synchronized {
+        sparkConfigs.get(key).foreach { value =>
+          kvBuilder.setKey(key)
+          kvBuilder.setValue(value)
+        }
+      }
+      responseBuilder.addPairs(kvBuilder.build())
+    }
+    responseObserver.onNext(responseBuilder.build())
+    responseObserver.onCompleted()
+  }
+
   override def interrupt(
       request: proto.InterruptRequest,
       responseObserver: StreamObserver[proto.InterruptResponse]): Unit = {
```
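The last test above relies on the client resolving the server-side compression config once and falling back to "disabled" when the config is not defined. A hypothetical sketch of that fetch-once behavior follows; `CompressionConfigCache`, `fetchServerConfig`, and the 1 MB default threshold are placeholders for illustration, not the real client internals.

```scala
import scala.util.control.NonFatal

object ConfigFetchSketch {
  final case class PlanCompressionOptions(thresholdBytes: Int, algorithm: String)

  // Hypothetical cache: resolve the compression options lazily and remember the
  // result, so later queries do not trigger additional Config RPCs.
  final class CompressionConfigCache(fetchServerConfig: String => Option[String]) {
    @volatile private var cached: Option[Option[PlanCompressionOptions]] = None

    def planCompressionOptions: Option[PlanCompressionOptions] = cached match {
      case Some(resolved) => resolved // already resolved: reuse the answer
      case None =>
        val resolved =
          try {
            fetchServerConfig("spark.connect.session.planCompression.defaultAlgorithm")
              .map(algo => PlanCompressionOptions(thresholdBytes = 1024 * 1024, algorithm = algo))
          } catch {
            case NonFatal(_) => None // e.g. SQL_CONF_NOT_FOUND: disable compression
          }
        cached = Some(resolved)
        resolved
    }
  }
}
```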

sql/connect/common/pom.xml

Lines changed: 4 additions & 0 deletions

```diff
@@ -87,6 +87,10 @@
       <artifactId>netty-transport-native-unix-common</artifactId>
       <version>${netty.version}</version>
     </dependency>
+    <dependency>
+      <groupId>com.github.luben</groupId>
+      <artifactId>zstd-jni</artifactId>
+    </dependency>
     <!--
       This spark-tags test-dep is needed even though it isn't used in this module,
       otherwise testing-cmds that exclude them will yield errors.
```
