2 changes: 2 additions & 0 deletions python/docs/source/getting_started/install.rst
@@ -230,6 +230,7 @@ Package Supported version Note
`grpcio` >=1.76.0 Required for Spark Connect
`grpcio-status` >=1.76.0 Required for Spark Connect
`googleapis-common-protos` >=1.71.0 Required for Spark Connect
`zstandard` >=0.25.0 Required for Spark Connect
`graphviz` >=0.20 Optional for Spark Connect
========================== ================= ==========================

@@ -313,6 +314,7 @@ Package Supported version Note
`grpcio` >=1.76.0 Required for Spark Connect
`grpcio-status` >=1.76.0 Required for Spark Connect
`googleapis-common-protos` >=1.71.0 Required for Spark Connect
`zstandard` >=0.25.0 Required for Spark Connect
`pyyaml` >=3.11 Required for spark-pipelines command line interface
`graphviz` >=0.20 Optional for Spark Connect
========================== ================= ===================================================
5 changes: 5 additions & 0 deletions sql/connect/client/jvm/pom.xml
@@ -84,6 +84,11 @@
<artifactId>failureaccess</artifactId>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>com.github.luben</groupId>
<artifactId>zstd-jni</artifactId>
<scope>compile</scope>
</dependency>
<!--
When upgrading ammonite, consider upgrading semanticdb-shared too.
-->
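The zstd-jni artifact added above provides the ZSTD codec the Scala client relies on for plan compression. As a rough illustration of what the dependency is used for (a sketch only; the threshold, compression level, and helper names are assumptions, not the client's actual implementation):

import com.github.luben.zstd.Zstd

// Sketch: compress a serialized plan only when it is larger than a threshold,
// so small plans skip the compression overhead. The threshold (1000 bytes) and
// level (3) are illustrative values, not the client's defaults.
def maybeCompressPlan(planBytes: Array[Byte], thresholdBytes: Int = 1000): Array[Byte] =
  if (planBytes.length > thresholdBytes) Zstd.compress(planBytes, 3) else planBytes

// The simple byte-array decompress API needs the original (uncompressed) size.
def decompressPlan(compressed: Array[Byte], originalSize: Int): Array[Byte] =
  Zstd.decompress(compressed, originalSize)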
@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.connect.client.{RetryPolicy, SparkConnectClient, SparkResult}
import org.apache.spark.sql.connect.client.{PlanCompressionOptions, RetryPolicy, SparkConnectClient, SparkResult}
import org.apache.spark.sql.connect.test.{ConnectFunSuite, IntegrationTestUtils, QueryTest, RemoteSparkSession, SQLHelper}
import org.apache.spark.sql.connect.test.SparkConnectServerUtils.{createSparkSession, port}
import org.apache.spark.sql.functions._
@@ -2005,6 +2005,22 @@ class ClientE2ETestSuite
}
}
}

test("Plan compression works correctly") {
val originalPlanCompressionOptions = spark.client.getPlanCompressionOptions
assert(originalPlanCompressionOptions.nonEmpty)
assert(originalPlanCompressionOptions.get.thresholdBytes > 0)
assert(originalPlanCompressionOptions.get.algorithm == "ZSTD")
try {
spark.client.setPlanCompressionOptions(Some(PlanCompressionOptions(1000, "ZSTD")))
// Execution should work
assert(spark.sql(s"select '${"Apache Spark" * 10000}' as value").collect().length == 1)
// Analysis should work
assert(spark.sql(s"select '${"Apache Spark" * 10000}' as value").columns.length == 1)
} finally {
spark.client.setPlanCompressionOptions(originalPlanCompressionOptions)
}
}
}

private[sql] case class ClassData(a: String, b: Int)
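The E2E test above pushes a roughly 120 KB literal through both execution and analysis with a 1000-byte threshold. The compress-or-not decision it exercises is driven by the serialized plan size relative to thresholdBytes; a minimal sketch of that check, assuming the (thresholdBytes, algorithm) shape the test constructs:

// Sketch only: a plan is eligible for compression when options are present
// and the serialized plan exceeds the configured threshold.
case class PlanCompressionOptions(thresholdBytes: Int, algorithm: String) // shape as used in the test

def shouldCompress(serializedSizeBytes: Int, options: Option[PlanCompressionOptions]): Boolean =
  options.exists(o => serializedSizeBytes > o.thresholdBytes)

// shouldCompress(120000, Some(PlanCompressionOptions(1000, "ZSTD")))  => true
// shouldCompress(200, Some(PlanCompressionOptions(1000, "ZSTD")))     => false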
@@ -187,6 +187,9 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
.builder()
.connectionString(s"sc://localhost:${server.getPort}")
.build()
// Disable plan compression to make sure there is only one RPC request in client.analyze,
// so the interceptor can capture the initial header.
client.setPlanCompressionOptions(None)

val session = SparkSession.builder().client(client).create()
val df = session.range(10)
@@ -521,6 +524,9 @@
.connectionString(s"sc://localhost:${server.getPort}")
.enableReattachableExecute()
.build()
// Disable plan compression to make sure there is only one RPC request in client.analyze,
// so the interceptor can capture the initial header.
client.setPlanCompressionOptions(None)

val plan = buildPlan("select * from range(10000000)")
val dummyUUID = "10a4c38e-7e87-40ee-9d6f-60ff0751e63b"
@@ -533,16 +539,105 @@
assert(resp.getOperationId == dummyUUID)
}
}

test("Plan compression works correctly for execution") {
startDummyServer(0)
client = SparkConnectClient
.builder()
.connectionString(s"sc://localhost:${server.getPort}")
.enableReattachableExecute()
.build()
// Set plan compression options for testing
client.setPlanCompressionOptions(Some(PlanCompressionOptions(1000, "ZSTD")))

// Small plan should not be compressed
val plan = buildPlan("select * from range(10)")
val iter = client.execute(plan)
val reattachableIter =
ExecutePlanResponseReattachableIterator.fromIterator(iter)
while (reattachableIter.hasNext) {
reattachableIter.next()
}
assert(service.getAndClearLatestInputPlan().hasRoot)

// Large plan should be compressed
val plan2 = buildPlan(s"select ${"Apache Spark" * 10000} as value")
val iter2 = client.execute(plan2)
val reattachableIter2 =
ExecutePlanResponseReattachableIterator.fromIterator(iter2)
while (reattachableIter2.hasNext) {
reattachableIter2.next()
}
assert(service.getAndClearLatestInputPlan().hasCompressedOperation)
}

test("Plan compression works correctly for analysis") {
startDummyServer(0)
client = SparkConnectClient
.builder()
.connectionString(s"sc://localhost:${server.getPort}")
.enableReattachableExecute()
.build()
// Set plan compression options for testing
client.setPlanCompressionOptions(Some(PlanCompressionOptions(1000, "ZSTD")))

// Small plan should not be compressed
val plan = buildPlan("select * from range(10)")
client.analyze(proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA, Some(plan))
assert(service.getAndClearLatestInputPlan().hasRoot)

// Large plan should be compressed
val plan2 = buildPlan(s"select ${"Apache Spark" * 10000} as value")
client.analyze(proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA, Some(plan2))
assert(service.getAndClearLatestInputPlan().hasCompressedOperation)
}

test("Plan compression will be disabled if the configs are not defined on the server") {
startDummyServer(0)
client = SparkConnectClient
.builder()
.connectionString(s"sc://localhost:${server.getPort}")
.enableReattachableExecute()
.build()

service.setErrorToThrowOnConfig(
"spark.connect.session.planCompression.defaultAlgorithm",
new StatusRuntimeException(Status.INTERNAL.withDescription("SQL_CONF_NOT_FOUND")))

// Execute a few queries to make sure the client fetches the configs only once.
(1 to 3).foreach { _ =>
val plan = buildPlan(s"select ${"Apache Spark" * 10000} as value")
val iter = client.execute(plan)
val reattachableIter =
ExecutePlanResponseReattachableIterator.fromIterator(iter)
while (reattachableIter.hasNext) {
reattachableIter.next()
}
assert(service.getAndClearLatestInputPlan().hasRoot)
}
// The plan compression options should be empty.
assert(client.getPlanCompressionOptions.isEmpty)
// The client should try to fetch the config only once.
assert(service.getAndClearLatestConfigRequests().size == 1)
}
}
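
The last test above also asserts that the client issues only one ConfigRequest across several queries and falls back to "compression disabled" when the server does not know the config. One plausible way to get that behaviour is to resolve the options lazily and cache the outcome; this is a sketch under that assumption, not the actual client code:

import scala.util.Try

// Sketch: resolve the server-side plan-compression settings at most once.
// A failed lookup (e.g. SQL_CONF_NOT_FOUND) is also cached, as
// "compression disabled", so later queries do not retry the RPC.
final class OncePerSessionConfig[T](fetch: () => Option[T]) {
  private lazy val resolved: Option[T] = Try(fetch()).toOption.flatten
  def get(): Option[T] = resolved
}

// Usage sketch (fetchOptionsFromServer is a hypothetical helper):
// val planCompression = new OncePerSessionConfig(() => fetchOptionsFromServer())
// planCompression.get() // first call performs the RPC; later calls reuse the result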

class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectServiceImplBase {

private var inputPlan: proto.Plan = _
private val inputArtifactRequests: mutable.ListBuffer[AddArtifactsRequest] =
mutable.ListBuffer.empty
private val inputConfigRequests = mutable.ListBuffer.empty[proto.ConfigRequest]
private val sparkConfigs = mutable.Map.empty[String, String]

var errorToThrowOnExecute: Option[Throwable] = None

private var errorToThrowOnConfig: Map[String, Throwable] = Map.empty

private[sql] def setErrorToThrowOnConfig(key: String, error: Throwable): Unit = synchronized {
errorToThrowOnConfig = errorToThrowOnConfig + (key -> error)
}

private[sql] def getAndClearLatestInputPlan(): proto.Plan = synchronized {
val plan = inputPlan
inputPlan = null
@@ -556,6 +651,13 @@ class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectSer
requests
}

private[sql] def getAndClearLatestConfigRequests(): Seq[proto.ConfigRequest] =
synchronized {
val requests = inputConfigRequests.clone().toSeq
inputConfigRequests.clear()
requests
}

override def executePlan(
request: ExecutePlanRequest,
responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
@@ -666,6 +768,38 @@ class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectSer
responseObserver.onCompleted()
}

override def config(
request: proto.ConfigRequest,
responseObserver: StreamObserver[proto.ConfigResponse]): Unit = {
inputConfigRequests.synchronized {
inputConfigRequests.append(request)
}
require(
request.getOperation.hasGetOption,
"Only GetOption is supported. Other operations " +
"can be implemented by following the same procedure below.")

val responseBuilder = proto.ConfigResponse.newBuilder().setSessionId(request.getSessionId)
request.getOperation.getGetOption.getKeysList.asScala.iterator.foreach { key =>
if (errorToThrowOnConfig.contains(key)) {
val error = errorToThrowOnConfig(key)
responseObserver.onError(error)
return
}

val kvBuilder = proto.KeyValue.newBuilder()
synchronized {
sparkConfigs.get(key).foreach { value =>
kvBuilder.setKey(key)
kvBuilder.setValue(value)
}
}
responseBuilder.addPairs(kvBuilder.build())
}
responseObserver.onNext(responseBuilder.build())
responseObserver.onCompleted()
}
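
The require message in the handler above notes that other config operations could be added by following the same procedure. A hedged sketch of how a Set operation could be applied to the dummy service's in-memory map, mirroring the GetOption flow (illustrative only, not part of the PR; the proto accessors are the standard generated ones):

import scala.jdk.CollectionConverters._
import scala.collection.mutable

// Sketch: apply a Set operation's key/value pairs to the config map and build
// the reply carrying the session id, just like the GetOption branch above.
def applySetOperation(
    request: proto.ConfigRequest,
    sparkConfigs: mutable.Map[String, String]): proto.ConfigResponse = {
  request.getOperation.getSet.getPairsList.asScala.foreach { kv =>
    sparkConfigs.put(kv.getKey, kv.getValue)
  }
  proto.ConfigResponse.newBuilder().setSessionId(request.getSessionId).build()
}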

override def interrupt(
request: proto.InterruptRequest,
responseObserver: StreamObserver[proto.InterruptResponse]): Unit = {
4 changes: 4 additions & 0 deletions sql/connect/common/pom.xml
@@ -87,6 +87,10 @@
<artifactId>netty-transport-native-unix-common</artifactId>
<version>${netty.version}</version>
</dependency>
<dependency>
<groupId>com.github.luben</groupId>
<artifactId>zstd-jni</artifactId>
</dependency>
<!--
This spark-tags test-dep is needed even though it isn't used in this module,
otherwise testing-cmds that exclude them will yield errors.