
Commit 1de64a4

harshmotw-db authored and cloud-fan committed
[SPARK-53032] Fix parquet format of shredded timestamp values within arrays
### What changes were proposed in this pull request?

This PR is an extension of the [previous PR](#51609), which did not account for the fact that timestamps within Variant arrays could be shredded as well. This PR makes sure that these timestamps are stored in compliance with the shredding spec.

### Why are the changes needed?

Variants representing arrays of timestamps could be shredded, and the written format must follow the Parquet variant shredding spec.

### Does this PR introduce _any_ user-facing change?

This PR must ship in the same release as the [previous PR](#51609); the physical format of shredded timestamps within Parquet files will change.

### How was this patch tested?

Incorporated `array<timestamp>` into the previous unit test.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #51734 from harshmotw-db/harsh-motwani_data/shredding_timestamp_fix.

Authored-by: Harsh Motwani <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent a7e9d93 · commit 1de64a4
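For context on the format being fixed (a summary of the spec behavior that the test below asserts, not text from the patch): under the Parquet variant shredding spec, a shredded timestamp's `typed_value` is stored as a physical INT64 counting microseconds since the Unix epoch, annotated `TIMESTAMP(isAdjustedToUTC = true, unit = MICROS)`, regardless of `spark.sql.parquet.outputTimestampType`. A minimal standalone sketch (the object name is ours) showing the value the test's `5::timestamp` shreds to:

```scala
import java.time.Instant
import java.time.temporal.ChronoUnit

// Illustration only: the INT64 value a shredded timestamp typed_value holds
// for `5::timestamp` (i.e. epoch second 5), per the shredding spec.
object ShreddedTimestampMicros {
  def main(args: Array[String]): Unit = {
    val ts = Instant.ofEpochSecond(5L)
    val micros = ChronoUnit.MICROS.between(Instant.EPOCH, ts)
    assert(micros == 5000000L) // written as INT64 TIMESTAMP(MICROS, UTC-adjusted)
    println(s"typed_value = $micros")
  }
}
```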

File tree

2 files changed: +30, -10 lines


sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala

Lines changed: 10 additions & 7 deletions
@@ -308,9 +308,9 @@ class ParquetWriteSupport extends WriteSupport[InternalRow] with Logging {
           writeFields(row.getStruct(ordinal, t.length), t, fieldWriters)
         }

-      case t: ArrayType => makeArrayWriter(t)
+      case t: ArrayType => makeArrayWriter(t, inShredded)

-      case t: MapType => makeMapWriter(t)
+      case t: MapType => makeMapWriter(t, inShredded)

       case t: UserDefinedType[_] => makeWriter(t.sqlType, inShredded)

@@ -391,9 +391,9 @@ class ParquetWriteSupport extends WriteSupport[InternalRow] with Logging {
       }
     }

-  def makeArrayWriter(arrayType: ArrayType): ValueWriter = {
+  def makeArrayWriter(arrayType: ArrayType, inShredded: Boolean): ValueWriter = {
     // The shredded schema should not have an array inside
-    val elementWriter = makeWriter(arrayType.elementType, inShredded = false)
+    val elementWriter = makeWriter(arrayType.elementType, inShredded)

     def threeLevelArrayWriter(repeatedGroupName: String, elementFieldName: String): ValueWriter =
       (row: SpecializedGetters, ordinal: Int) => {
@@ -472,9 +472,12 @@ class ParquetWriteSupport extends WriteSupport[InternalRow] with Logging {
       }
     }

-  private def makeMapWriter(mapType: MapType): ValueWriter = {
-    val keyWriter = makeWriter(mapType.keyType, inShredded = false)
-    val valueWriter = makeWriter(mapType.valueType, inShredded = false)
+  private def makeMapWriter(mapType: MapType, inShredded: Boolean): ValueWriter = {
+    // TODO: If maps are ever supported in the shredded schema, we should add a test in
+    // `ParquetVariantShreddingSuite` to make sure that timestamps within maps are shredded
+    // correctly as INT64.
+    val keyWriter = makeWriter(mapType.keyType, inShredded)
+    val valueWriter = makeWriter(mapType.valueType, inShredded)
     val repeatedGroupName = if (writeLegacyParquetFormat) {
       // Legacy mode:
       //
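The change above threads the `inShredded` flag through the recursive writer factory, so element writers for arrays (and, defensively, key and value writers for maps) inherit the shredded context instead of resetting it to `false`. A simplified, self-contained sketch of that pattern, using hypothetical mini-types rather than Spark's actual `ValueWriter` machinery:

```scala
// Hypothetical sketch of the flag-threading pattern; not Spark's real code.
object ShreddedFlagSketch {
  sealed trait DataType
  case object TimestampType extends DataType
  final case class ArrayType(elementType: DataType) extends DataType

  // A writer consumes a timestamp in microseconds; everything else is elided.
  type ValueWriter = Long => Unit

  private def writeInt64Micros(micros: Long): Unit =
    println(s"INT64 TIMESTAMP(MICROS, UTC): $micros")
  private def writeSessionConfigured(micros: Long): Unit =
    println(s"physical type per spark.sql.parquet.outputTimestampType: $micros")

  def makeWriter(t: DataType, inShredded: Boolean): ValueWriter = t match {
    // Inside a shredded variant, timestamps must always be INT64 micros.
    case TimestampType if inShredded => writeInt64Micros
    case TimestampType               => writeSessionConfigured
    // The fix: propagate `inShredded` to the element writer instead of
    // hard-coding `inShredded = false`. (In this sketch the array writer is
    // simply its element writer.)
    case ArrayType(e)                => makeWriter(e, inShredded)
  }

  def main(args: Array[String]): Unit = {
    val elementWriter = makeWriter(ArrayType(TimestampType), inShredded = true)
    elementWriter(5000000L) // emits INT64 micros, per the shredding spec
  }
}
```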

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVariantShreddingSuite.scala

Lines changed: 20 additions & 3 deletions
@@ -47,17 +47,21 @@ class ParquetVariantShreddingSuite extends QueryTest with ParquetTest with Share
     ParquetOutputTimestampType.values.foreach { timestampParquetType =>
       withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> timestampParquetType.toString) {
         withTempDir { dir =>
-          val schema = "t timestamp, st struct<t timestamp>"
+          val schema = "t timestamp, st struct<t timestamp>, at array<timestamp>"
           val fullSchema = "v struct<metadata binary, value binary, typed_value struct<" +
             "t struct<value binary, typed_value timestamp>," +
             "st struct<" +
-            "value binary, typed_value struct<t struct<value binary, typed_value timestamp>>>>>, " +
+            "value binary, typed_value struct<t struct<value binary, typed_value timestamp>>>," +
+            "at struct<" +
+            "value binary, typed_value array<struct<value binary, typed_value timestamp>>>" +
+            ">>, " +
             "t1 timestamp, st1 struct<t1 timestamp>"
           val df = spark.sql(
             """
               | select
               |   to_variant_object(
-              |     named_struct('t', 1::timestamp, 'st', named_struct('t', 2::timestamp))
+              |     named_struct('t', 1::timestamp, 'st', named_struct('t', 2::timestamp),
+              |       'at', array(5::timestamp))
               |   ) v, 3::timestamp t1, named_struct('t1', 4::timestamp) st1
               | from range(1)
               |""".stripMargin)
@@ -82,6 +86,9 @@ class ParquetVariantShreddingSuite extends QueryTest with ParquetTest with Share
           checkAnswer(
             shreddedDf.selectExpr("st1.t1::long"),
             Seq(Row(4)))
+          checkAnswer(
+            shreddedDf.selectExpr("v.typed_value.at.typed_value[0].typed_value::long"),
+            Seq(Row(5)))
           val file = dir.listFiles().find(_.getName.endsWith(".parquet")).get
           val parquetFilePath = file.getAbsolutePath
           val inputFile = HadoopInputFile.fromPath(new Path(parquetFilePath), new Configuration())
@@ -106,6 +113,16 @@ class ParquetVariantShreddingSuite extends QueryTest with ParquetTest with Share
           assert(typedValue2.getLogicalTypeAnnotation == LogicalTypeAnnotation.timestampType(
             true, LogicalTypeAnnotation.TimeUnit.MICROS))

+          // v.typed_value.at.typed_value[0].typed_value
+          val atGroup = typedValueGroup.getType("at").asGroupType()
+          val atTypedValueGroup = atGroup.getType("typed_value").asGroupType()
+          val atLGroup = atTypedValueGroup.getType("list").asGroupType()
+          val atLEGroup = atLGroup.getType("element").asGroupType()
+          val typedValue3 = atLEGroup.getType("typed_value").asPrimitiveType()
+          assert(typedValue3.getPrimitiveTypeName == PrimitiveTypeName.INT64)
+          assert(typedValue3.getLogicalTypeAnnotation == LogicalTypeAnnotation.timestampType(
+            true, LogicalTypeAnnotation.TimeUnit.MICROS))
+
           def verifyNonVariantTimestampType(t: PrimitiveType): Unit = {
             timestampParquetType match {
               case ParquetOutputTimestampType.INT96 =>
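The footer checks added above can also be reproduced outside the suite against a file the test writes. A standalone sketch, assuming `parquet-hadoop` and a Hadoop client on the classpath; the object name and argument handling are illustrative, and the walk mirrors the test's `v.typed_value.at.typed_value.list.element.typed_value` path:

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.parquet.hadoop.ParquetFileReader
import org.apache.parquet.hadoop.util.HadoopInputFile
import org.apache.parquet.schema.{LogicalTypeAnnotation, PrimitiveType}

object InspectShreddedArrayTimestamp {
  def main(args: Array[String]): Unit = {
    // args(0): path to a parquet file written with the shredding schema above
    val reader = ParquetFileReader.open(
      HadoopInputFile.fromPath(new Path(args(0)), new Configuration()))
    try {
      val schema = reader.getFooter.getFileMetaData.getSchema
      // Walk down to the shredded array element's typed_value.
      val element = schema
        .getType("v").asGroupType()
        .getType("typed_value").asGroupType()
        .getType("at").asGroupType()
        .getType("typed_value").asGroupType()
        .getType("list").asGroupType()
        .getType("element").asGroupType()
        .getType("typed_value").asPrimitiveType()
      // The fix guarantees INT64 micros here for every output timestamp type.
      assert(element.getPrimitiveTypeName == PrimitiveType.PrimitiveTypeName.INT64)
      assert(element.getLogicalTypeAnnotation ==
        LogicalTypeAnnotation.timestampType(true, LogicalTypeAnnotation.TimeUnit.MICROS))
    } finally reader.close()
  }
}
```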
