
Commit 4cadadc

add partition by file num strategy
1 parent aad6c51 commit 4cadadc

File tree

3 files changed: +144, -17 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 26 additions & 0 deletions
@@ -2430,6 +2430,28 @@ object SQLConf {
     .checkValue(v => v > 0, "The maximum number of partitions must be a positive integer.")
     .createOptional
 
+  val FILES_PARTITION_STRATEGY = buildConf("spark.sql.files.partitionStrategy")
+    .doc("The strategy to coalesce small files into larger partitions when reading files. " +
+      "Options are `size_based` (coalesce based on the size of files) and `file_based` " +
+      "(coalesce based on the number of files). The number of output partitions depends on " +
+      "`spark.sql.files.maxPartitionBytes` and `spark.sql.files.maxPartitionNum`. " +
+      "This configuration is effective only when using file-based sources such as " +
+      "Parquet, JSON and ORC.")
+    .version("3.5.0")
+    .stringConf
+    .checkValues(Set("size_based", "file_based"))
+    .createWithDefault("size_based")
+
+  val SMALL_FILE_THRESHOLD =
+    buildConf("spark.sql.files.smallFileThreshold")
+      .doc(
+        "Defines the fraction of the total scan size that is treated as small files. If the " +
+        "cumulative size of the smallest files stays below this fraction, they are distributed " +
+        "across multiple partitions to avoid concentrating them in a single partition. This " +
+        "configuration is used when `spark.sql.files.partitionStrategy` is set to `file_based`.")
+      .doubleConf
+      .createWithDefault(0.5)
+
   val IGNORE_CORRUPT_FILES = buildConf("spark.sql.files.ignoreCorruptFiles")
     .doc("Whether to ignore corrupt files. If true, the Spark jobs will continue to run when " +
       "encountering corrupted files and the contents that have been read will still be returned. " +

@@ -6949,6 +6971,10 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
 
   def filesMaxPartitionNum: Option[Int] = getConf(FILES_MAX_PARTITION_NUM)
 
+  def filesPartitionStrategy: String = getConf(FILES_PARTITION_STRATEGY)
+
+  def smallFileThreshold: Double = getConf(SMALL_FILE_THRESHOLD)
+
   def ignoreCorruptFiles: Boolean = getConf(IGNORE_CORRUPT_FILES)
 
   def ignoreMissingFiles: Boolean = getConf(IGNORE_MISSING_FILES)
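
Editor's note: for context, a minimal sketch of how these two knobs could be exercised together once this change is in a build. It assumes only the configuration keys added above; the app name and input path are hypothetical.

```scala
// Sketch only: assumes the configs added in this commit are available in the running build.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("file-based-coalescing-demo").getOrCreate()

// Default behaviour: pack files into partitions by cumulative size.
spark.conf.set("spark.sql.files.partitionStrategy", "size_based")

// Opt into the new strategy: balance partitions by file count, spreading the smallest
// files (up to 50% of the total scan size by default) across the output partitions.
spark.conf.set("spark.sql.files.partitionStrategy", "file_based")
spark.conf.set("spark.sql.files.smallFileThreshold", "0.5")

// Hypothetical directory containing many small Parquet files.
val df = spark.read.parquet("/tmp/many_small_files")
println(df.rdd.getNumPartitions)
```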

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FilePartition.scala

Lines changed: 103 additions & 17 deletions
@@ -55,7 +55,7 @@ case class FilePartition(index: Int, files: Array[PartitionedFile])
 
 object FilePartition extends SessionStateHelper with Logging {
 
-  private def getFilePartitions(
+  private def getFilePartitionsBySize(
       partitionedFiles: Seq[PartitionedFile],
       maxSplitBytes: Long,
       openCostInBytes: Long): Seq[FilePartition] = {

@@ -75,7 +75,7 @@ object FilePartition extends SessionStateHelper with Logging {
     }
 
     // Assign files to partitions using "Next Fit Decreasing"
-    partitionedFiles.foreach { file =>
+    partitionedFiles.sortBy(_.length)(implicitly[Ordering[Long]].reverse).foreach { file =>
       if (currentSize + file.length > maxSplitBytes) {
         closePartition()
       }

@@ -87,28 +87,114 @@ object FilePartition extends SessionStateHelper with Logging {
     partitions.toSeq
   }
 
+  private def getFilePartitionsByFileNum(
+      partitionedFiles: Seq[PartitionedFile],
+      outputPartitions: Int,
+      smallFileThreshold: Double): Seq[FilePartition] = {
+    // Pair each file with its size and sort descending by file size.
+    val filesSorted: Seq[(PartitionedFile, Long)] =
+      partitionedFiles
+        .map(f => (f, f.length))
+        .sortBy(_._2)(Ordering.Long.reverse)
+
+    val partitions = Array.fill(outputPartitions)(mutable.ArrayBuffer.empty[PartitionedFile])
+
+    def addToBucket(
+        heap: mutable.PriorityQueue[(Long, Int, Int)],
+        file: PartitionedFile,
+        sz: Long): Unit = {
+      val (load, numFiles, idx) = heap.dequeue()
+      partitions(idx) += file
+      heap.enqueue((load + sz, numFiles + 1, idx))
+    }
+
+    // Min-heap of (load, numFiles, bucketIndex), ordered first by load, then by numFiles.
+    val heapByFileSize =
+      mutable.PriorityQueue.empty[(Long, Int, Int)](
+        Ordering
+          .by[(Long, Int, Int), (Long, Int)] {
+            case (load, numFiles, _) =>
+              (load, numFiles)
+          }
+          .reverse
+      )
+
+    if (smallFileThreshold > 0) {
+      val smallFileTotalSize = filesSorted.map(_._2).sum * smallFileThreshold
+      // Min-heap of (load, numFiles, bucketIndex), ordered first by numFiles, then by load.
+      val heapByFileNum =
+        mutable.PriorityQueue.empty[(Long, Int, Int)](
+          Ordering
+            .by[(Long, Int, Int), (Int, Long)] {
+              case (load, numFiles, _) =>
+                (numFiles, load)
+            }
+            .reverse
+        )
+
+      (0 until outputPartitions).foreach(i => heapByFileNum.enqueue((0L, 0, i)))
+
+      var numSmallFiles = 0
+      var smallFileSize = 0L
+      // Assign small files (smallest first, up to the cumulative budget) to the bucket with the
+      // fewest files and the least load; the lazy iterator keeps the cumulative check in effect.
+      filesSorted.reverseIterator.takeWhile(f => f._2 + smallFileSize <= smallFileTotalSize).foreach {
+        case (file, sz) =>
+          addToBucket(heapByFileNum, file, sz)
+          numSmallFiles += 1
+          smallFileSize += sz
+      }
+
+      // Move buckets from heapByFileNum to heapByFileSize.
+      while (heapByFileNum.nonEmpty) {
+        heapByFileSize.enqueue(heapByFileNum.dequeue())
+      }
+
+      // Finally, assign the remaining (larger) files to the least-loaded buckets.
+      filesSorted.take(filesSorted.size - numSmallFiles).foreach {
+        case (file, sz) =>
+          addToBucket(heapByFileSize, file, sz)
+      }
+    } else {
+      (0 until outputPartitions).foreach(i => heapByFileSize.enqueue((0L, 0, i)))
+
+      filesSorted.foreach {
+        case (file, sz) =>
+          addToBucket(heapByFileSize, file, sz)
+      }
+    }
+
+    partitions.zipWithIndex.map { case (p, idx) => FilePartition(idx, p.toArray) }
+  }
+
   def getFilePartitions(
       sparkSession: SparkSession,
       partitionedFiles: Seq[PartitionedFile],
       maxSplitBytes: Long): Seq[FilePartition] = {
     val conf = getSqlConf(sparkSession)
     val openCostBytes = conf.filesOpenCostInBytes
     val maxPartNum = conf.filesMaxPartitionNum
-    val partitions = getFilePartitions(partitionedFiles, maxSplitBytes, openCostBytes)
-    if (maxPartNum.exists(partitions.size > _)) {
-      val totalSizeInBytes =
-        partitionedFiles.map(_.length + openCostBytes).map(BigDecimal(_)).sum[BigDecimal]
-      val desiredSplitBytes =
-        (totalSizeInBytes / BigDecimal(maxPartNum.get)).setScale(0, RoundingMode.UP).longValue
-      val desiredPartitions = getFilePartitions(partitionedFiles, desiredSplitBytes, openCostBytes)
-      logWarning(log"The number of partitions is ${MDC(NUM_PARTITIONS, partitions.size)}, " +
-        log"which exceeds the maximum number configured: " +
-        log"${MDC(MAX_NUM_PARTITIONS, maxPartNum.get)}. Spark rescales it to " +
-        log"${MDC(DESIRED_NUM_PARTITIONS, desiredPartitions.size)} by ignoring the " +
-        log"configuration of ${MDC(CONFIG, SQLConf.FILES_MAX_PARTITION_BYTES.key)}.")
-      desiredPartitions
-    } else {
-      partitions
+    val partitions = getFilePartitionsBySize(partitionedFiles, maxSplitBytes, openCostBytes)
+    conf.filesPartitionStrategy match {
+      case "file_based" =>
+        getFilePartitionsByFileNum(partitionedFiles, Math.min(partitions.size,
+          maxPartNum.getOrElse(Int.MaxValue)), conf.smallFileThreshold)
+      case "size_based" =>
+        if (maxPartNum.exists(partitions.size > _)) {
+          val totalSizeInBytes =
+            partitionedFiles.map(_.length + openCostBytes).map(BigDecimal(_)).sum[BigDecimal]
+          val desiredSplitBytes =
+            (totalSizeInBytes / BigDecimal(maxPartNum.get)).setScale(0, RoundingMode.UP).longValue
+          val desiredPartitions = getFilePartitionsBySize(
+            partitionedFiles, desiredSplitBytes, openCostBytes)
+          logWarning(log"The number of partitions is ${MDC(NUM_PARTITIONS, partitions.size)}, " +
+            log"which exceeds the maximum number configured: " +
+            log"${MDC(MAX_NUM_PARTITIONS, maxPartNum.get)}. Spark rescales it to " +
+            log"${MDC(DESIRED_NUM_PARTITIONS, desiredPartitions.size)} by ignoring the " +
+            log"configuration of ${MDC(CONFIG, SQLConf.FILES_MAX_PARTITION_BYTES.key)}.")
+          desiredPartitions
+        } else {
+          partitions
+        }
     }
   }
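
Editor's note: to make the balancing scheme in `getFilePartitionsByFileNum` easier to follow in isolation, here is a standalone sketch in plain Scala (no Spark types). The key trick is that `mutable.PriorityQueue` is a max-heap, so the ordering is reversed to make `dequeue()` return the least-loaded bucket. The committed code additionally routes the smallest files through a second heap ordered by file count first; this sketch shows only the size-balancing part.

```scala
// Standalone illustration of "assign each file to the currently least-loaded bucket".
import scala.collection.mutable

object BalanceDemo {
  def balanceBySize(files: Seq[(String, Long)], numBuckets: Int): Array[Seq[String]] = {
    val buckets = Array.fill(numBuckets)(mutable.ArrayBuffer.empty[String])
    // Entries are (load, numFiles, bucketIndex); reversed ordering => min-heap by (load, numFiles).
    val heap = mutable.PriorityQueue.empty[(Long, Int, Int)](
      Ordering.by[(Long, Int, Int), (Long, Int)] { case (load, n, _) => (load, n) }.reverse)
    (0 until numBuckets).foreach(i => heap.enqueue((0L, 0, i)))

    // Largest files first; each goes to the bucket currently carrying the least bytes.
    files.sortBy(-_._2).foreach { case (name, size) =>
      val (load, n, idx) = heap.dequeue()
      buckets(idx) += name
      heap.enqueue((load + size, n + 1, idx))
    }
    buckets.map(_.toSeq)
  }

  def main(args: Array[String]): Unit = {
    val files = Seq("a" -> 500L, "b" -> 300L, "c" -> 200L, "d" -> 10L, "e" -> 10L, "f" -> 10L)
    balanceBySize(files, 3).zipWithIndex.foreach { case (b, i) => println(s"bucket $i: $b") }
  }
}
```

With three buckets, the loads come out roughly balanced (500 vs. 300 vs. 230), which is the property the size-ordered heap gives each output partition in the commit above.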

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala

Lines changed: 15 additions & 0 deletions
@@ -627,6 +627,21 @@ class FileSourceStrategySuite extends QueryTest with SharedSparkSession {
     }
   }
 
+  test(s"Test ${SQLConf.FILES_PARTITION_STRATEGY.key} works as expected") {
+    val files = {
+      Range(0, 20000 - 10).map(p => PartitionedFile(InternalRow.empty, sp(s"$p"), 0, 50000))
+    } ++ Range(0, 10).map(p => PartitionedFile(InternalRow.empty, sp(s"small_$p"), 0, 5000))
+
+    withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "50000",
+      SQLConf.FILES_OPEN_COST_IN_BYTES.key -> "0",
+      SQLConf.FILES_PARTITION_STRATEGY.key -> "file_based"
+    ) {
+      val partitions = FilePartition.getFilePartitions(
+        spark, files, conf.filesMaxPartitionBytes)
+      assert(!partitions.exists(_.files.length >= 10))
+    }
+  }
+
   // Helpers for checking the arguments passed to the FileFormat.
 
   protected val checkPartitionSchema =
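
Editor's note: a possible companion check, kept hypothetical here, would exercise the balancing property more directly using the same suite helpers (`sp`, `withSQLConf`, `conf`, `spark`). With roughly one output bucket per sizable file, no bucket should end up collecting more than a handful of files under `file_based`.

```scala
// Hypothetical additional test, assuming the FileSourceStrategySuite context above.
test("file_based strategy balances the number of files per partition") {
  val files = Range(0, 1000).map(p => PartitionedFile(InternalRow.empty, sp(s"$p"), 0, 4000)) ++
    Range(0, 100).map(p => PartitionedFile(InternalRow.empty, sp(s"small_$p"), 0, 10))

  withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "4000",
    SQLConf.FILES_OPEN_COST_IN_BYTES.key -> "0",
    SQLConf.FILES_PARTITION_STRATEGY.key -> "file_based") {
    val partitions = FilePartition.getFilePartitions(spark, files, conf.filesMaxPartitionBytes)
    // 1100 files spread over roughly 1000 buckets => at most a handful of files per bucket.
    assert(partitions.map(_.files.length).max <= 3)
  }
}
```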
