@@ -55,7 +55,7 @@ case class FilePartition(index: Int, files: Array[PartitionedFile])
 
 object FilePartition extends SessionStateHelper with Logging {
 
-  private def getFilePartitions(
+  private def getFilePartitionsBySize(
       partitionedFiles: Seq[PartitionedFile],
       maxSplitBytes: Long,
       openCostInBytes: Long): Seq[FilePartition] = {
@@ -75,7 +75,7 @@ object FilePartition extends SessionStateHelper with Logging {
     }
 
     // Assign files to partitions using "Next Fit Decreasing"
-    partitionedFiles.foreach { file =>
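+    // Sorting descending first upgrades plain "Next Fit" to "Next Fit Decreasing",
+    // which generally packs partitions closer to maxSplitBytes when sizes are skewed.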
+    partitionedFiles.sortBy(_.length)(implicitly[Ordering[Long]].reverse).foreach { file =>
       if (currentSize + file.length > maxSplitBytes) {
         closePartition()
       }
@@ -87,28 +87,114 @@ object FilePartition extends SessionStateHelper with Logging {
     partitions.toSeq
   }
 
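+  /**
+   * Packs files into a fixed number of partitions. Files are assigned largest-first to
+   * the partition with the smallest accumulated size (greedy LPT scheduling), which
+   * keeps partition sizes balanced. When `smallFileThreshold` is positive, the smallest
+   * files, up to that fraction of the total input size, are spread by file count first
+   * so that no single partition accumulates a long tail of tiny files.
+   */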
+  private def getFilePartitionsByFileNum(
+      partitionedFiles: Seq[PartitionedFile],
+      outputPartitions: Int,
+      smallFileThreshold: Double): Seq[FilePartition] = {
+    // Pair each file with its length and sort descending by size.
+    val filesSorted: Seq[(PartitionedFile, Long)] =
+      partitionedFiles
+        .map(f => (f, f.length))
+        .sortBy(_._2)(Ordering.Long.reverse)
+
+    val partitions = Array.fill(outputPartitions)(mutable.ArrayBuffer.empty[PartitionedFile])
+
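+    // Pop the bucket at the head of the heap (the "emptiest" one under that heap's
+    // ordering), assign the file to it, and push it back with updated load and count.
+    // Each assignment costs O(log outputPartitions).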
+    def addToBucket(
+        heap: mutable.PriorityQueue[(Long, Int, Int)],
+        file: PartitionedFile,
+        sz: Long): Unit = {
+      val (load, numFiles, idx) = heap.dequeue()
+      partitions(idx) += file
+      heap.enqueue((load + sz, numFiles + 1, idx))
+    }
+
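+    // mutable.PriorityQueue is a max-heap, so the reversed orderings below make
+    // dequeue() return the minimum element, i.e. the least-filled bucket.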
+    // First by load, then by numFiles.
+    val heapByFileSize =
+      mutable.PriorityQueue.empty[(Long, Int, Int)](
+        Ordering
+          .by[(Long, Int, Int), (Long, Int)] {
+            case (load, numFiles, _) =>
+              (load, numFiles)
+          }
+          .reverse
+      )
+
+    if (smallFileThreshold > 0) {
+      val smallFileTotalSize = filesSorted.map(_._2).sum * smallFileThreshold
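+      // Byte budget for files routed through the file-count heap: a fixed fraction
+      // of the total input size.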
+      // First by numFiles, then by load.
+      val heapByFileNum =
+        mutable.PriorityQueue.empty[(Long, Int, Int)](
+          Ordering
+            .by[(Long, Int, Int), (Int, Long)] {
+              case (load, numFiles, _) =>
+                (numFiles, load)
+            }
+            .reverse
+        )
+
+      (0 until outputPartitions).foreach(i => heapByFileNum.enqueue((0L, 0, i)))
+
+      var numSmallFiles = 0
+      var smallFileSize = 0L
+      // Enqueue small files (smallest first) to the bucket with the fewest files and the
+      // least load. The iterator keeps takeWhile lazy, so the predicate sees the running
+      // total and stops once the small-file byte budget is exhausted.
+      filesSorted.reverseIterator
+        .takeWhile { case (_, sz) => sz + smallFileSize <= smallFileTotalSize }
+        .foreach { case (file, sz) =>
+          addToBucket(heapByFileNum, file, sz)
+          numSmallFiles += 1
+          smallFileSize += sz
+        }
146+
147+ // Move buckets from heapByFileNum to heapByFileSize.
148+ while (heapByFileNum.nonEmpty) {
149+ heapByFileSize.enqueue(heapByFileNum.dequeue())
150+ }
151+
152+ // Finally, enqueue remaining files.
153+ filesSorted.take(filesSorted.size - numSmallFiles).foreach {
154+ case (file, sz) =>
155+ addToBucket(heapByFileSize, file, sz)
156+ }
157+ } else {
158+ (0 until outputPartitions).foreach(i => heapByFileSize.enqueue((0L , 0 , i)))
159+
160+ filesSorted.foreach {
161+ case (file, sz) =>
162+ addToBucket(heapByFileSize, file, sz)
163+ }
164+ }
165+
166+ partitions.zipWithIndex.map { case (p, idx) => FilePartition (idx, p.toArray) }
167+ }
168+
   def getFilePartitions(
       sparkSession: SparkSession,
       partitionedFiles: Seq[PartitionedFile],
       maxSplitBytes: Long): Seq[FilePartition] = {
     val conf = getSqlConf(sparkSession)
     val openCostBytes = conf.filesOpenCostInBytes
     val maxPartNum = conf.filesMaxPartitionNum
-    val partitions = getFilePartitions(partitionedFiles, maxSplitBytes, openCostBytes)
-    if (maxPartNum.exists(partitions.size > _)) {
-      val totalSizeInBytes =
-        partitionedFiles.map(_.length + openCostBytes).map(BigDecimal(_)).sum[BigDecimal]
-      val desiredSplitBytes =
-        (totalSizeInBytes / BigDecimal(maxPartNum.get)).setScale(0, RoundingMode.UP).longValue
-      val desiredPartitions = getFilePartitions(partitionedFiles, desiredSplitBytes, openCostBytes)
-      logWarning(log"The number of partitions is ${MDC(NUM_PARTITIONS, partitions.size)}, " +
-        log"which exceeds the maximum number configured: " +
-        log"${MDC(MAX_NUM_PARTITIONS, maxPartNum.get)}. Spark rescales it to " +
-        log"${MDC(DESIRED_NUM_PARTITIONS, desiredPartitions.size)} by ignoring the " +
-        log"configuration of ${MDC(CONFIG, SQLConf.FILES_MAX_PARTITION_BYTES.key)}.")
-      desiredPartitions
-    } else {
-      partitions
+    val partitions = getFilePartitionsBySize(partitionedFiles, maxSplitBytes, openCostBytes)
+    conf.filesPartitionStrategy match {
+      case "file_based" =>
+        getFilePartitionsByFileNum(partitionedFiles, Math.min(partitions.size,
+          maxPartNum.getOrElse(Int.MaxValue)), conf.smallFileThreshold)
+      case "size_based" =>
+        if (maxPartNum.exists(partitions.size > _)) {
+          val totalSizeInBytes =
+            partitionedFiles.map(_.length + openCostBytes).map(BigDecimal(_)).sum[BigDecimal]
+          val desiredSplitBytes =
+            (totalSizeInBytes / BigDecimal(maxPartNum.get)).setScale(0, RoundingMode.UP).longValue
+          val desiredPartitions = getFilePartitionsBySize(
+            partitionedFiles, desiredSplitBytes, openCostBytes)
+          logWarning(log"The number of partitions is ${MDC(NUM_PARTITIONS, partitions.size)}, " +
+            log"which exceeds the maximum number configured: " +
+            log"${MDC(MAX_NUM_PARTITIONS, maxPartNum.get)}. Spark rescales it to " +
+            log"${MDC(DESIRED_NUM_PARTITIONS, desiredPartitions.size)} by ignoring the " +
+            log"configuration of ${MDC(CONFIG, SQLConf.FILES_MAX_PARTITION_BYTES.key)}.")
+          desiredPartitions
+        } else {
+          partitions
+        }
     }
   }
 
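For reference, here is a minimal, self-contained sketch of the greedy least-loaded assignment that `addToBucket` implements above. The `balance` helper and its sample sizes are illustrative only, not part of this change:

```scala
import scala.collection.mutable

// Distribute sizes across numBuckets, largest first, always into the
// currently least-loaded bucket (greedy LPT scheduling).
def balance(sizes: Seq[Long], numBuckets: Int): Array[Seq[Long]] = {
  val buckets = Array.fill(numBuckets)(mutable.ArrayBuffer.empty[Long])
  // The reversed ordering turns the max-heap PriorityQueue into a min-heap on load.
  val heap = mutable.PriorityQueue.empty[(Long, Int)](
    Ordering.by[(Long, Int), Long](_._1).reverse)
  (0 until numBuckets).foreach(i => heap.enqueue((0L, i)))
  sizes.sorted(Ordering.Long.reverse).foreach { sz =>
    val (load, idx) = heap.dequeue()
    buckets(idx) += sz
    heap.enqueue((load + sz, idx))
  }
  buckets.map(_.toSeq)
}

// balance(Seq(900L, 500L, 400L, 300L, 100L), 2) yields bucket loads of
// 1200 and 1000 (e.g. Seq(900, 300) and Seq(500, 400, 100)), whereas
// splitting the sorted list in half would give 1800 vs 400.
```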