
Commit a54d2ae

sryza authored and dongjoon-hyun committed
[SPARK-54280][SDP] Require pipeline checkpoint storage dir to be absolute path
### What changes were proposed in this pull request?

- Raises an error if the pipeline checkpoint storage dir is not an absolute path
- Updates the init CLI to create a checkpoint storage dir and set it as an absolute path

### Why are the changes needed?

Prevents users from accidentally losing checkpoints.

### Does this PR introduce _any_ user-facing change?

Yes, but to unreleased functionality.

### How was this patch tested?

- New unit tests
- Ran the init CLI and then ran a pipeline with a streaming table

### Was this patch authored or co-authored using generative AI tooling?

Closes #52999 from sryza/storage-location-absolute.

Authored-by: Sandy Ryza <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
(cherry picked from commit 4020794)
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 3f506f5 commit a54d2ae

File tree

9 files changed: +147 -6 lines changed

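For orientation before the per-file diffs, here is a minimal Scala sketch of the requirement described above, mirroring the `validateStorageRoot` helper added to `PipelineUpdateContextImpl.scala` below. The object name, helper name, and sample paths are illustrative only and are not part of this change:

```scala
import org.apache.hadoop.fs.Path

// Illustrative sketch: a pipeline storage root is accepted only if it is an
// absolute path that also carries a URI scheme (file://, s3a://, hdfs://, ...).
object StorageRootCheckSketch {
  def isValidStorageRoot(storageRoot: String): Boolean = {
    val path = new Path(storageRoot)
    val scheme = path.toUri.getScheme
    path.isAbsolute && scheme != null && scheme.nonEmpty
  }

  def main(args: Array[String]): Unit = {
    println(isValidStorageRoot("file:///tmp/pipeline-storage")) // true: absolute and has a scheme
    println(isValidStorageRoot("/tmp/pipeline-storage"))        // false: absolute but no scheme
    println(isValidStorageRoot("pipeline-storage"))             // false: relative path
  }
}
```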

common/utils/src/main/resources/error/error-conditions.json

Lines changed: 6 additions & 0 deletions
@@ -4997,6 +4997,12 @@
     ],
     "sqlState" : "42000"
   },
+  "PIPELINE_STORAGE_ROOT_INVALID" : {
+    "message" : [
+      "Pipeline storage root must be an absolute path with a URI scheme (e.g., file://, s3a://, hdfs://). Got: `<storage_root>`."
+    ],
+    "sqlState" : "42K03"
+  },
   "PIPE_OPERATOR_AGGREGATE_EXPRESSION_CONTAINS_NO_AGGREGATE_FUNCTION" : {
     "message" : [
       "Non-grouping expression <expr> is provided as an argument to the |> AGGREGATE pipe operator but does not contain any aggregate function; please update it to include an aggregate function and then retry the query again."

python/pyspark/pipelines/init_cli.py

Lines changed: 10 additions & 2 deletions
@@ -19,7 +19,7 @@

 SPEC = """
 name: {{ name }}
-storage: storage-root
+storage: {{ storage_root }}
 libraries:
   - glob:
       include: transformations/**
@@ -46,10 +46,18 @@ def init(name: str) -> None:
     project_dir = Path.cwd() / name
     project_dir.mkdir(parents=True, exist_ok=False)

+    # Create the storage directory
+    storage_dir = project_dir / "pipeline-storage"
+    storage_dir.mkdir(parents=True)
+
+    # Create absolute file URI for storage path
+    storage_path = f"file://{storage_dir.resolve()}"
+
     # Write the spec file to the project directory
     spec_file = project_dir / "pipeline.yml"
     with open(spec_file, "w") as f:
-        f.write(SPEC.replace("{{ name }}", name))
+        spec_content = SPEC.replace("{{ name }}", name).replace("{{ storage_root }}", storage_path)
+        f.write(spec_content)

     # Create the transformations directory
     transformations_dir = project_dir / "transformations"

python/pyspark/pipelines/tests/test_init_cli.py

Lines changed: 8 additions & 0 deletions
@@ -51,6 +51,14 @@ def test_init(self):
         spec_path = find_pipeline_spec(Path.cwd())
         spec = load_pipeline_spec(spec_path)
         assert spec.name == project_name
+
+        # Verify that the storage path is an absolute URI with file scheme
+        expected_storage_path = f"file://{Path.cwd() / 'pipeline-storage'}"
+        self.assertEqual(spec.storage, expected_storage_path)
+
+        # Verify that the storage directory was created
+        self.assertTrue((Path.cwd() / "pipeline-storage").exists())
+
         registry = LocalGraphElementRegistry()
         register_definitions(spec_path, registry, spec)
         self.assertEqual(len(registry.outputs), 1)

sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -437,7 +437,7 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
     val pipelineUpdateContext = new PipelineUpdateContextImpl(
       new DataflowGraph(Seq(), Seq(), Seq(), Seq()),
       (_: PipelineEvent) => None,
-      storageRoot = "test_storage_root")
+      storageRoot = "file:///test_storage_root")
     sessionHolder.cachePipelineExecution(graphId, pipelineUpdateContext)
     assert(
       sessionHolder.getPipelineExecution(graphId).nonEmpty,

sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -161,7 +161,7 @@ class SparkConnectSessionManagerSuite extends SharedSparkSession with BeforeAndA
     val pipelineUpdateContext = new PipelineUpdateContextImpl(
       new DataflowGraph(Seq(), Seq(), Seq(), Seq()),
       (_: PipelineEvent) => None,
-      storageRoot = "test_storage_root")
+      storageRoot = "file:///test_storage_root")
     sessionHolder.cachePipelineExecution(graphId, pipelineUpdateContext)
     assert(
       sessionHolder.getPipelineExecution(graphId).nonEmpty,

sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContextImpl.scala

Lines changed: 21 additions & 0 deletions
@@ -17,6 +17,9 @@

 package org.apache.spark.sql.pipelines.graph

+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.SparkException
 import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.pipelines.logging.{FlowProgressEventLogger, PipelineEvent}

@@ -36,6 +39,8 @@ class PipelineUpdateContextImpl(
     override val storageRoot: String
 ) extends PipelineUpdateContext {

+  PipelineUpdateContextImpl.validateStorageRoot(storageRoot)
+
   override val spark: SparkSession = SparkSession.getActiveSession.getOrElse(
     throw new IllegalStateException("SparkSession is not available")
   )
@@ -45,3 +50,19 @@

   override val resetCheckpointFlows: FlowFilter = NoFlows
 }
+
+object PipelineUpdateContextImpl {
+  def validateStorageRoot(storageRoot: String): Unit = {
+    // Use the same validation logic as streaming checkpoint directories
+    val path = new Path(storageRoot)
+
+    val uri = path.toUri
+    if (!path.isAbsolute || uri.getScheme == null || uri.getScheme.isEmpty) {
+      throw new SparkException(
+        errorClass = "PIPELINE_STORAGE_ROOT_INVALID",
+        messageParameters = Map("storage_root" -> storageRoot),
+        cause = null
+      )
+    }
+  }
+}
sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContextImplSuite.scala

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.graph
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.pipelines.utils.{PipelineTest, TestGraphRegistrationContext}
+import org.apache.spark.sql.test.SharedSparkSession
+
+class PipelineUpdateContextImplSuite extends PipelineTest with SharedSparkSession {
+
+  test("validateStorageRoot should accept valid URIs with schemes") {
+    val validStorageRoots = Seq(
+      "file:///tmp/test",
+      "hdfs://localhost:9000/pipelines",
+      "s3a://my-bucket/pipelines",
+      "abfss://[email protected]/pipelines"
+    )
+
+    validStorageRoots.foreach(PipelineUpdateContextImpl.validateStorageRoot)
+  }
+
+  test("validateStorageRoot should reject relative paths") {
+    val invalidStorageRoots = Seq(
+      "relative/path",
+      "./relative/path",
+      "../relative/path",
+      "pipelines"
+    )
+
+    invalidStorageRoots.foreach { storageRoot =>
+      val exception = intercept[SparkException] {
+        PipelineUpdateContextImpl.validateStorageRoot(storageRoot)
+      }
+      assert(exception.getCondition == "PIPELINE_STORAGE_ROOT_INVALID")
+      assert(exception.getMessageParameters.get("storage_root") == storageRoot)
+    }
+  }
+
+  test("validateStorageRoot should reject absolute paths without URI scheme") {
+    val invalidStorageRoots = Seq(
+      "/tmp/test",
+      "/absolute/path",
+      "/pipelines/storage"
+    )
+
+    invalidStorageRoots.foreach { storageRoot =>
+      val exception = intercept[SparkException] {
+        PipelineUpdateContextImpl.validateStorageRoot(storageRoot)
+      }
+      assert(exception.getCondition == "PIPELINE_STORAGE_ROOT_INVALID")
+      assert(exception.getMessageParameters.get("storage_root") == storageRoot)
+    }
+  }
+
+  test("PipelineUpdateContextImpl constructor should validate storage root") {
+    val session = spark
+    import session.implicits._
+
+    class TestPipeline extends TestGraphRegistrationContext(spark) {
+      registerPersistedView("test", query = dfFlowFunc(Seq(1).toDF("value")))
+    }
+    val graph = new TestPipeline().resolveToDataflowGraph()
+
+    val validStorageRoot = "file:///tmp/test"
+    val context = new PipelineUpdateContextImpl(
+      unresolvedGraph = graph,
+      eventCallback = _ => {},
+      storageRoot = validStorageRoot
+    )
+    assert(context.storageRoot == validStorageRoot)
+
+    val invalidStorageRoot = "/tmp/test"
+    val exception = intercept[SparkException] {
+      new PipelineUpdateContextImpl(
+        unresolvedGraph = graph,
+        eventCallback = _ => {},
+        storageRoot = invalidStorageRoot
+      )
+    }
+    assert(exception.getCondition == "PIPELINE_STORAGE_ROOT_INVALID")
+    assert(exception.getMessageParameters.get("storage_root") == invalidStorageRoot)
+  }
+}

sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SinkExecutionSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -115,7 +115,7 @@ class SinkExecutionSuite extends ExecutionTest with SharedSparkSession {
       sinkIdentifier: TableIdentifier,
       flowIdentifier: TableIdentifier): Unit = {
     val expectedCheckpointLocation = new Path(
-      "file://" + rootDirectory + s"/_checkpoints/${sinkIdentifier.table}/${flowIdentifier.table}/0"
+      rootDirectory + s"/_checkpoints/${sinkIdentifier.table}/${flowIdentifier.table}/0"
     )
     val streamingQuery = graphExecution
       .flowExecutions(flowIdentifier)

sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/StorageRootMixin.scala

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ trait StorageRootMixin extends BeforeAndAfterEach { self: Suite =>
   override protected def beforeEach(): Unit = {
     super.beforeEach()
     storageRoot =
-      Files.createTempDirectory(getClass.getSimpleName).normalize.toString
+      s"file://${Files.createTempDirectory(getClass.getSimpleName).normalize.toString}"
   }

   override protected def afterEach(): Unit = {
