diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java b/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java index 0e2c4816a3c1a..acb585cb88031 100644 --- a/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java +++ b/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java @@ -99,6 +99,11 @@ public BigDecimal getDecimal() { return VariantUtil.getDecimal(value, pos); } + // Get the decimal value, including trailing zeros + public BigDecimal getDecimalWithOriginalScale() { + return VariantUtil.getDecimalWithOriginalScale(value, pos); + } + // Get a float value from the variant. public float getFloat() { return VariantUtil.getFloat(value, pos); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 17b8dd493cf80..8548bce694b18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -5343,6 +5343,31 @@ object SQLConf { .stringConf .createWithDefault("") + val VARIANT_SHREDDING_MAX_SCHEMA_WIDTH = + buildConf("spark.sql.variant.shredding.maxSchemaWidth") + .internal() + .doc("Maximum number of shredded fields to create when inferring a schema for Variant") + .version("4.2.0") + .intConf + .createWithDefault(300) + + val VARIANT_SHREDDING_MAX_SCHEMA_DEPTH = + buildConf("spark.sql.variant.shredding.maxSchemaDepth") + .internal() + .doc("Maximum depth in Variant value to traverse when inferring a schema. " + + "Any array/object below this depth will be shredded as a single binary.") + .version("4.2.0") + .intConf + .createWithDefault(50) + + val VARIANT_INFER_SHREDDING_SCHEMA = + buildConf("spark.sql.variant.inferShreddingSchema") + .internal() + .doc("Infer shredding schema when writing Variant columns in Parquet tables.") + .version("4.2.0") + .booleanConf + .createWithDefault(false) + val LEGACY_CSV_ENABLE_DATE_TIME_PARSING_FALLBACK = buildConf("spark.sql.legacy.csv.enableDateTimeParsingFallback") .internal() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/InferVariantShreddingSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/InferVariantShreddingSchema.scala new file mode 100644 index 0000000000000..a249e191fc177 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/InferVariantShreddingSchema.scala @@ -0,0 +1,380 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.execution.datasources.parquet
+
+import org.apache.spark.SparkRuntimeException
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.types.variant._
+import org.apache.spark.types.variant.VariantUtil.Type
+import org.apache.spark.unsafe.types._
+
+/**
+ * Infer a shredding schema for the Variant values in a write schema.
+ * Only VariantType values at the top level or nested in struct fields are replaced.
+ * VariantType values nested in arrays or maps are not modified.
+ * @param schema The original schema containing VariantType.
+ */
+class InferVariantShreddingSchema(val schema: StructType) {
+
+  /**
+   * Create a list of paths to Variant values in the schema.
+   * Variant fields nested in arrays or maps are not included.
+   * For example, if the schema is
+   * struct<v1 variant, s struct<a int, b int, v2 variant>>
+   * the function will return [[0], [1, 2]].
+   */
+  private def getPathsToVariant(s: StructType): Seq[Seq[Int]] = {
+    s.fields.zipWithIndex
+      .map {
+        case (field, idx) =>
+          field.dataType match {
+            case VariantType =>
+              Seq(Seq(idx))
+            case inner: StructType =>
+              // Prepend this index to each downstream path.
+              getPathsToVariant(inner).map { path =>
+                idx +: path
+              }
+            case _ => Seq()
+          }
+      }
+      .toSeq
+      .flatten
+  }
+
+  private def getValueAtPath(s: StructType, row: InternalRow, p: Seq[Int]): Option[VariantVal] = {
+    if (row.isNullAt(p.head)) {
+      None
+    } else if (p.length == 1) {
+      // We've reached the Variant value.
+      Some(row.getVariant(p.head))
+    } else {
+      // The field must be a struct.
+      val childStruct = s.fields(p.head).dataType.asInstanceOf[StructType]
+      getValueAtPath(
+        childStruct,
+        row.getStruct(p.head, childStruct.length),
+        p.tail
+      )
+    }
+  }
+
+  private val pathsToVariant = getPathsToVariant(schema)
+
+  private val maxShreddedFieldsPerFile =
+    SQLConf.get.getConf(SQLConf.VARIANT_SHREDDING_MAX_SCHEMA_WIDTH)
+
+  private val maxShreddingDepth =
+    SQLConf.get.getConf(SQLConf.VARIANT_SHREDDING_MAX_SCHEMA_DEPTH)
+
+  private val COUNT_METADATA_KEY = "COUNT"
+
+  /**
+   * Return an appropriate schema for shredding a Variant value.
+   * It is similar to the SchemaOfVariant expression, but the rules are somewhat different, because
+   * we want the types to be consistent with what will be allowed during shredding. E.g.
+   * SchemaOfVariant will consider the common type across Integer and Double to be double, but we
+   * consider it to be VariantType, since shredding will not allow those types to be written to
+   * the same typed_value.
+   * We also maintain metadata on struct fields to track how frequently they occur. Rare fields
+   * are dropped in the final schema.
+   */
+  private def schemaOf(v: Variant, maxDepth: Int): DataType = v.getType match {
+    case Type.OBJECT =>
+      if (maxDepth <= 0) return VariantType
+      val size = v.objectSize()
+      val fields = new Array[StructField](size)
+      for (i <- 0 until size) {
+        val field = v.getFieldAtIndex(i)
+        fields(i) = StructField(field.key, schemaOf(field.value, maxDepth - 1),
+          metadata = new MetadataBuilder().putLong(COUNT_METADATA_KEY, 1).build())
+      }
+      // According to the variant spec, object fields must be sorted alphabetically. So we don't
+      // have to sort, but just need to validate they are sorted.
+ for (i <- 1 until size) { + if (fields(i - 1).name >= fields(i).name) { + throw new SparkRuntimeException( + errorClass = "MALFORMED_VARIANT", + messageParameters = Map.empty + ) + } + } + StructType(fields) + case Type.ARRAY => + if (maxDepth <= 0) return VariantType + var elementType: DataType = NullType + for (i <- 0 until v.arraySize()) { + elementType = mergeSchema(elementType, schemaOf(v.getElementAtIndex(i), maxDepth - 1)) + } + ArrayType(elementType) + case Type.NULL => NullType + case Type.BOOLEAN => BooleanType + case Type.LONG => + // Compute the smallest decimal that can contain this value. + // This will allow us to merge with decimal later without introducing excessive precision. + // If we only end up encountering integer values, we'll convert back to LongType when we + // finalize. + val d = BigDecimal(v.getLong()) + val precision = d.precision + if (precision <= Decimal.MAX_LONG_DIGITS) { + DecimalType(precision, 0) + } else { + // Value is too large for Decimal(18, 0), so record its type as long. + LongType + } + case Type.STRING => StringType + case Type.DOUBLE => DoubleType + case Type.DECIMAL => + // Don't strip trailing zeros to determine scale. Even if we allow scale relaxation during + // shredding, it's useful to take trailing zeros as a hint that the extra digits may be used + // in later values, and use the larger scale. + val d = Decimal(v.getDecimalWithOriginalScale()) + DecimalType(d.precision, d.scale) + case Type.DATE => DateType + case Type.TIMESTAMP => TimestampType + case Type.TIMESTAMP_NTZ => TimestampNTZType + case Type.FLOAT => FloatType + case Type.BINARY => BinaryType + // Spark doesn't support UUID, so shred it as an untyped value. + case Type.UUID => VariantType + } + + private def getFieldCount(field: StructField): Long = { + field.metadata.getLong(COUNT_METADATA_KEY) + } + + // Merge two decimals with possibly different scales. + private def mergeDecimal(d1: DecimalType, d2: DecimalType): DataType = { + val scale = Math.max(d1.scale, d2.scale) + val range = Math.max(d1.precision - d1.scale, d2.precision - d2.scale) + if (range + scale > DecimalType.MAX_PRECISION) { + // DecimalType can't support precision > 38 + VariantType + } else { + DecimalType(range + scale, scale) + } + } + + private def mergeDecimalWithLong(d: DecimalType): DataType = { + if (d.scale == 0 && d.precision <= 18) { + // It's an integer-like Decimal. Rather than widen to a precision of 19, we can + // use LongType + LongType + } else { + // Long can always fit in a Decimal(19, 0) + mergeDecimal(d, DecimalType(19, 0)) + } + } + + private def mergeSchema(dt1: DataType, dt2: DataType): DataType = { + (dt1, dt2) match { + // Allow VariantNull to appear in any typed schema + case (NullType, t) => t + case (t, NullType) => t + case (d1: DecimalType, d2: DecimalType) => + mergeDecimal(d1, d2) + case (d: DecimalType, LongType) => + mergeDecimalWithLong(d) + case (LongType, d: DecimalType) => + mergeDecimalWithLong(d) + case (StructType(fields1), StructType(fields2)) => + // Rely on fields being sorted by name, and merge fields with the same name recursively. + val newFields = new java.util.ArrayList[StructField]() + + var f1Idx = 0 + var f2Idx = 0 + // We end up dropping all but 300 fields in the final schema, but add a cap on how many + // we'll try to track to avoid memory/time blow-ups in the intermediate state. 
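+        // (This intermediate cap is separate from spark.sql.variant.shredding.maxSchemaWidth,
+        // which is applied later in finalizeSimpleSchema.)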
+ val maxStructSize = 1000 + + while (f1Idx < fields1.length && f2Idx < fields2.length && newFields.size < maxStructSize) { + val f1Name = fields1(f1Idx).name + val f2Name = fields2(f2Idx).name + val comp = f1Name.compareTo(f2Name) + if (comp == 0) { + val dataType = mergeSchema(fields1(f1Idx).dataType, fields2(f2Idx).dataType) + val c1 = getFieldCount(fields1(f1Idx)) + val c2 = getFieldCount(fields2(f2Idx)) + newFields.add( + StructField( + f1Name, + dataType, + metadata = new MetadataBuilder().putLong(COUNT_METADATA_KEY, c1 + c2).build() + ) + ) + f1Idx += 1 + f2Idx += 1 + } else if (comp < 0) { // f1Name < f2Name + newFields.add(fields1(f1Idx)) + f1Idx += 1 + } else { // f1Name > f2Name + newFields.add(fields2(f2Idx)) + f2Idx += 1 + } + } + while (f1Idx < fields1.length && newFields.size < maxStructSize) { + newFields.add(fields1(f1Idx)) + f1Idx += 1 + } + while (f2Idx < fields2.length && newFields.size < maxStructSize) { + newFields.add(fields2(f2Idx)) + f2Idx += 1 + } + StructType(newFields.toArray(Array.empty[StructField])) + case (ArrayType(e1, _), ArrayType(e2, _)) => + ArrayType(mergeSchema(e1, e2)) + // For any other scalar types, the types must be identical, or we give up and use Variant. + case (_, _) if dt1 == dt2 => dt1 + case _ => VariantType + } + } + + /** + * Update each VariantType with its inferred schema. + */ + private def updateSchema( + schema: StructType, + inferredSchemas: Map[Seq[Int], StructType], + path: Seq[Int] = Seq()): StructType = { + val newFields = schema.fields.zipWithIndex.map { + case (field, idx) => + field.dataType match { + case VariantType => + // Right now, we infer a schema for every VariantType that isn't nested in an array or + // map, so we should always find a replacement. + val fullPath = path :+ idx + assert(inferredSchemas.contains(fullPath)) + field.copy(dataType = inferredSchemas(fullPath)) + case inner: StructType => + val newType = updateSchema(inner, inferredSchemas, path :+ idx) + field.copy(dataType = newType) + case dt => field + } + } + StructType(newFields) + } + + // Container for a mutable integer, to track the total number of shredded fields we can add across + // the file. It should be initialized to the maximum allowed across the file schema. + // `finalizeSimpleSchema` decrements it, and stops adding new fields once it hits 0. + private case class MaxFields(var remaining: Int) + + /** + * Given the schema of a Variant type, finalize the schema. Specifically: + * 1) Widen integer types to LongType, since it adds flexibility for shredding, and + * shouldn't have much storage size impact after encoding. + * 2) Replace empty structs with VariantType, since empty structs are invalid in Parquet. + * 3) Limit the total number of shredded fields in the schema + */ + private def finalizeSimpleSchema( + dt: DataType, + minCardinality: Int, + maxFields: MaxFields): DataType = { + // Every field uses a value column. + maxFields.remaining -= 1 + if (maxFields.remaining <= 0) { + // No space left for a typed_value. Use VariantType, which only consumes a value column. + return VariantType + } + + dt match { + case StructType(fields) => + val newFields = new java.util.ArrayList[StructField]() + // Drop fields with less than the required cardinality. 
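+        // For example, with 100 non-null rows the caller passes minCardinality = 10, so a field
+        // that appeared in only 5 of those rows is dropped from the shredded schema.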
+ fields + .filter(getFieldCount(_) >= minCardinality) + .foreach { field => + if (maxFields.remaining > 0) { + newFields.add( + field.copy( + dataType = finalizeSimpleSchema(field.dataType, minCardinality, maxFields) + ) + ) + } + } + // If we weren't able to retain any fields, just use VariantType + if (newFields.size() > 0) StructType(newFields) else VariantType + case ArrayType(elementType, _) => + ArrayType(finalizeSimpleSchema(elementType, minCardinality, maxFields)) + case ByteType | ShortType | IntegerType | LongType => + maxFields.remaining -= 1 + // We widen all integer types to long. There isn't much benefit to shredding as a + // narrower integer type. + LongType + case d: DecimalType if d.precision <= 18 && d.scale == 0 => + // This was probably an integer type originally, and we converted to Decimal(N, 0) to + // allow it to merge with decimal. Since it still has 0 scale, we can convert back to + // LongType in the final schema. + maxFields.remaining -= 1 + LongType + case d: DecimalType => + // Store as 8-byte if precision is small enough, otherwise use 16-byte decimal. + maxFields.remaining -= 1 + if (d.precision <= Decimal.MAX_LONG_DIGITS) { + DecimalType(Decimal.MAX_LONG_DIGITS, d.scale) + } else { + DecimalType(DecimalType.MAX_PRECISION, d.scale) + } + case VariantType | NullType => + // VariantType and NullType don't have a corresponding typed_value. They just write + // to the value column. + VariantType + case t => + // All other scalar types use typed_value. + maxFields.remaining -= 1 + t + } + } + + def inferSchema(rows: Seq[InternalRow]): StructType = { + // For each path to a Variant value, iterate over all rows and update the inferred schema. + // Add the result to a map, which we'll use to update the full schema. + // maxShreddedFieldsPerFile is a global max for all fields, so initialize it here. + val maxFields = MaxFields(maxShreddedFieldsPerFile) + val inferredSchemas = pathsToVariant.map { path => + var numNonNullValues = 0 + val simpleSchema = rows.foldLeft(NullType: DataType) { + case (partialSchema, row) => + getValueAtPath(schema, row, path).map { variantVal => + numNonNullValues += 1 + val v = new Variant(variantVal.getValue, variantVal.getMetadata) + val schemaOfRow = schemaOf(v, maxShreddingDepth) + mergeSchema(partialSchema, schemaOfRow) + // If getValueAtPath returned None, the value is null in this row; just ignore. + } + .getOrElse(partialSchema) + // If we didn't find any non-null rows, use an unshredded schema. + } + + // Don't infer a schema for fields that appear in less than 10% of rows. + // Ensure that minCardinality is at least 1 if we have any rows. + val minCardinality = (numNonNullValues + 9) / 10 + + val finalizedSchema = finalizeSimpleSchema(simpleSchema, minCardinality, maxFields) + val shreddingSchema = SparkShreddingUtils.variantShreddingSchema(finalizedSchema) + val schemaWithMetadata = SparkShreddingUtils.addWriteShreddingMetadata(shreddingSchema) + (path, schemaWithMetadata) + }.toMap + + // Insert each inferred schema into the full schema. 
updateSchema(schema, inferredSchemas)
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala
index dd5669bda07c9..36658156cb8be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala
@@ -77,6 +77,14 @@ class ParquetOptions(
   def datetimeRebaseModeInRead: String = parameters
     .get(DATETIME_REBASE_MODE)
     .getOrElse(sqlConf.getConf(SQLConf.PARQUET_REBASE_MODE_IN_READ).toString)
+
+  val inferShreddingForVariant: Boolean = {
+    // VARIANT_WRITE_SHREDDING_ENABLED is a global kill switch.
+    sqlConf.getConf(SQLConf.VARIANT_WRITE_SHREDDING_ENABLED) &&
+      parameters.get(SQLConf.VARIANT_INFER_SHREDDING_SCHEMA.key).map(_.toBoolean)
+        .getOrElse(sqlConf.getConf(SQLConf.VARIANT_INFER_SHREDDING_SCHEMA))
+  }
+
   /**
    * The rebasing mode for INT96 timestamp values in reads.
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriterWithVariantShredding.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriterWithVariantShredding.scala
new file mode 100644
index 0000000000000..80c4fbac7ae7f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriterWithVariantShredding.scala
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.parquet
+
+import org.apache.hadoop.mapreduce._
+import org.apache.parquet.hadoop.metadata.ParquetMetadata
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.datasources.OutputWriter
+import org.apache.spark.sql.execution.datasources.parquet.InferVariantShreddingSchema
+import org.apache.spark.sql.execution.datasources.parquet.ParquetOutputWriter
+import org.apache.spark.sql.execution.datasources.parquet.ParquetWriteSupport
+
+/**
+ * A wrapper around ParquetOutputWriter that defers creating the writer until some number of input
+ * rows have been buffered and analyzed to determine an appropriate schema for shredded Variant.
+ * @param inferenceHelper An object that is aware of the input schema. It takes the buffered rows,
+ *                        and returns a schema that replaces VariantType with its shredded
+ *                        representation.
+ * @param isShreddingSchemaForced If true, shredding schema inference is not performed because a
+ *                                forced schema is assumed to be used.
+ */ +// NOTE: This class is instantiated and used on executor side only, no need to be serializable. +class ParquetOutputWriterWithVariantShredding( + val path: String, + context: TaskAttemptContext, + inferenceHelper: InferVariantShreddingSchema, + val isShreddingSchemaForced: Boolean) + extends OutputWriter { + + private var parquetOutputWriter: Option[ParquetOutputWriter] = None + private var latestFooter: Option[ParquetMetadata] = None + + private val rows = scala.collection.mutable.Buffer[InternalRow]() + private var bufferedSize = 0L + private lazy val toUnsafeRow = UnsafeProjection.create(inferenceHelper.schema) + + private val maxBufferSize = 64L * 1024L * 1024L + private val maxBufferRows = 4096 + + private def finalizeSchemaAndFlush(): Unit = { + if (!isShreddingSchemaForced) { + val finalSchema = inferenceHelper.inferSchema(rows.toSeq) + ParquetWriteSupport.setShreddingSchema(finalSchema, context.getConfiguration) + } + parquetOutputWriter = Some(new ParquetOutputWriter(path, context)) + rows.foreach(row => parquetOutputWriter.get.write(row)) + rows.clear() + bufferedSize = 0 + } + + def getLatestFooterOpt: Option[ParquetMetadata] = latestFooter + + // Check if adding the current row will exceed the thresholds for buffering + // As a side effect, updates `bufferedSize` on the assumption that if we return false, the row + // will be added to the list of buffered rows. + private def stopBuffering(row: UnsafeRow): Boolean = { + // Buffer the first `maxBufferRows` rows. + // Use +1 to account for the current row, which will be added by the caller if we return true + // here. + if (rows.size + 1 >= maxBufferRows) { + true + } else { + bufferedSize += row.getSizeInBytes + bufferedSize > maxBufferSize + } + } + + override def write(row: InternalRow): Unit = { + if (parquetOutputWriter.isEmpty) { + // We need an UnsafeRow in order to determine the memory usage. Normally, the row should + // already be UnsafeRow, but if it isn't, make one. + val unsafeRow = row match { + case u: UnsafeRow => u + case _ => toUnsafeRow(row) + } + if (stopBuffering(unsafeRow)) { + // We're ready to pick a schema. + // Copy the last row by reference, since we'll clear it after finalizeSchemaAndFlush. + rows += unsafeRow + finalizeSchemaAndFlush() + } else { + rows += unsafeRow.copy() + } + } else { + // We've already picked a schema, and can write directly. + parquetOutputWriter.get.write(row) + } + } + + override def close(): Unit = { + try { + if (parquetOutputWriter.isEmpty) { + // We haven't written any rows yet. Pick a schema, and write them all. 
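+        // If the buffer is empty or contains only null Variant values, inference falls back to
+        // an unshredded Variant schema.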
+ finalizeSchemaAndFlush() + } + } finally { + parquetOutputWriter.foreach { writer => + writer.close() + } + parquetOutputWriter = None + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index 5db3b0671db9f..65a77322549f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -38,6 +38,7 @@ import org.apache.spark.internal.LogKeys.{CLASS_NAME, CONFIG} import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.JoinedRow +import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar, Max, Min} import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, OutputWriter, OutputWriterFactory} @@ -435,6 +436,21 @@ object ParquetUtils extends Logging { StructType(newFields) } + private def needShreddingInference( + parquetOptions: ParquetOptions, + schema: StructType): Boolean = { + // If shredding inference is enabled, use the shredding writer, but only if the schema + // actually contains at least one Variant column. + if (parquetOptions.inferShreddingForVariant) { + // This will return true even if the type contains Variant as an array or map element. + // Even though we don't try to infer a schema right now, there isn't much downside, and + // we may want to allow schema inference for these cases in the future. + VariantExpressionEvalUtils.typeContainsVariant(schema) + } else { + false + } + } + def prepareWrite( sqlConf: SQLConf, job: Job, @@ -522,12 +538,22 @@ object ParquetUtils extends Logging { log"${MDC(CONFIG, ParquetOutputFormat.JOB_SUMMARY_LEVEL)} to NONE.") } + val useVariantShreddingWriter = needShreddingInference(parquetOptions, dataSchema) + val shreddingSchemaForced = + sqlConf.getConf(SQLConf.VARIANT_FORCE_SHREDDING_SCHEMA_FOR_TEST).nonEmpty + new OutputWriterFactory { override def newInstance( path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new ParquetOutputWriter(path, context) + if (useVariantShreddingWriter) { + val inferenceHelper = new InferVariantShreddingSchema(dataSchema) + new ParquetOutputWriterWithVariantShredding(path, context, inferenceHelper, + shreddingSchemaForced) + } else { + new ParquetOutputWriter(path, context) + } } override def getFileExtension(context: TaskAttemptContext): String = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/VariantInferShreddingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/VariantInferShreddingSuite.scala new file mode 100644 index 0000000000000..cdaf6c488dc2a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/VariantInferShreddingSuite.scala @@ -0,0 +1,634 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.sql.execution.datasources.parquet + +import java.io.File + +import org.apache.hadoop.fs.{FileSystem, Path} + +import org.apache.spark.SparkConf +import org.apache.spark.sql.{DataFrame, QueryTest, Row} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.classic.Dataset +import org.apache.spark.sql.execution.LogicalRDD +import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetTest, ParquetToSparkSchemaConverter, SparkShreddingUtils} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types._ +import org.apache.spark.types.variant._ +import org.apache.spark.unsafe.types.VariantVal + +// Unit tests for variant shredding inference. +class VariantInferShreddingSuite extends QueryTest with SharedSparkSession with ParquetTest { + override def sparkConf: SparkConf = { + super.sparkConf.set(SQLConf.PUSH_VARIANT_INTO_SCAN.key, "true") + .set(SQLConf.VARIANT_WRITE_SHREDDING_ENABLED.key, "true") + .set(SQLConf.VARIANT_INFER_SHREDDING_SCHEMA.key, "true") + } + + private def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(spark.catalog.dropTempView) + } + + + private def testWithTempDir(name: String)(block: File => Unit): Unit = test(name) { + withTempDir { dir => + block(dir) + } + } + + def getFooters(dir: File): Seq[org.apache.parquet.hadoop.Footer] = { + val fs = FileSystem.get(spark.sessionState.newHadoopConf()) + val fileStatuses = fs.listStatus(new Path(dir.getPath)) + .filter(_.getPath.toString.endsWith(".parquet")) + .toIndexedSeq + ParquetFileFormat.readParquetFootersInParallel( + spark.sessionState.newHadoopConf(), fileStatuses, ignoreCorruptFiles = false) + } + + // Checks that exactly one parquet file exists at the provided path, and returns its schema. + def getFileSchema(dir: File): StructType = { + val footers = getFooters(dir) + assert(footers.size == 1) + new ParquetToSparkSchemaConverter() + .convert(footers(0).getParquetMetadata.getFileMetaData.getSchema) + } + + // Check that the parquet file at the given path contains exactly one field, "v", + // with the expected schema. + def checkFileSchema(expectedSimpleSchema: DataType, dir: File): Unit = { + val expected = SparkShreddingUtils.variantShreddingSchema(expectedSimpleSchema) + val actual = getFileSchema(dir) + // Depending on the data, Spark may be able to infer that the top-level column is + // non-nullable, so accept either one. + assert(actual == StructType(Seq(StructField("v", expected, nullable = false))) || + actual == StructType(Seq(StructField("v", expected, nullable = true)))) + } + + // Given a DF with a field named "v", check that the string representation and schema_of_variant + // match. 
We can't check the binary values directly, because they can change before and after
+  // shredding.
+  def checkStringAndSchema(dir: File, expected: DataFrame, field: String = "v"): Unit = {
+    checkAnswer(
+      spark.read.parquet(dir.getAbsolutePath).selectExpr(s"$field::string",
+        s"schema_of_variant($field)"),
+      expected.selectExpr(s"$field::string", s"schema_of_variant($field)").collect()
+    )
+  }
+
+  testWithTempDir("infer shredding schema basic") { dir =>
+    // Check that we can write and read normally when shredding is enabled if
+    // we don't provide a shredding schema.
+    val df = spark.sql(
+      """
+        | select parse_json('{"a": ' || id || ', "b": "' || id || '"}') as v
+        | from range(0, 3, 1, 1)
+        |""".stripMargin)
+    df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+    // Inferred integer columns are always widened to long.
+    val expected = DataType.fromDDL("struct<a bigint, b string>")
+    checkFileSchema(expected, dir)
+    checkAnswer(spark.read.parquet(dir.getAbsolutePath), df.collect())
+  }
+
+  test("infer shredding does not infer rare rows") {
+    Seq(2, 9, 10, 11, 19, 20, 21, 100).foreach { inverseFreq =>
+      withTempDir { dir =>
+        // Test two infrequent array-of-object cases:
+        // 1) rareArray: If an array is rarely present, it is dropped from the schema, even if a
+        //    field in the array appears many times when the array is present.
+        // 2) rareArray2: If an array is usually present, then its fields are weighted by how many
+        //    times they occur in total. E.g. a field that is only present in 1% of rows, but
+        //    occurs in more than 10 elements in that row will be included in the schema, since it
+        //    appears "on average" in 10% of rows.
+        // These rules are a bit arbitrary, and the second one especially is mainly done to
+        // keep the algorithm simple, but validate it here, and we can consider revisiting it
+        // in the future.
+        val df = spark.sql(
+          s"""
+             | select case when id % $inverseFreq = 0 then
+             |   parse_json('{"a": ' || id ||
+             |     ', "rareArray": [{"x": 1}, {"x": 2}, {"y": 3}]' ||
+             |     ', "rareArray2": [{"x": 1}, {"x": 2}, {"y": 3}]' ||
+             |     ', "rareField" : "xyz"}')
+             | else
+             |   parse_json('{"a": ' || id ||
+             |     ', "rareArray2": []' ||
+             |     ', "b": "' || id || '"}')
+             | end as v
+             | from range(0, 10000, 1, 1)
+             |""".stripMargin)
+        df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+        // "a" appears in all rows, and "b" appears in at least 50% of rows, so they should always
+        // show up in the schema. The other fields depend on how frequently they appear.
+        val expected = if (inverseFreq > 20) {
+          // rareArray2 is always present, but none of its fields are frequent enough to
+          // include in the schema, so the element shows up as `variant`.
+          DataType.fromDDL("struct<a bigint, b string, rareArray2 array<variant>>")
+        } else if (inverseFreq > 10) {
+          // The "x" field in rareArray2 appears twice in each of the 5-10% of rows, so on average
+          // it appears in over 10% of rows.
+          DataType.fromDDL("struct<a bigint, b string, rareArray2 array<struct<x bigint>>>")
+        } else {
+          // All fields appear in at least 10% of rows, so should be in the inferred schema.
+          DataType.fromDDL("struct<a bigint, b string, " +
+            "rareArray array<struct<x bigint, y bigint>>, " +
+            "rareArray2 array<struct<x bigint, y bigint>>, " +
+            "rareField string>")
+        }
+        checkFileSchema(expected, dir)
+        checkStringAndSchema(dir, df)
+      }
+    }
+  }
+
+  test("infer shredding does not infer wide schemas") {
+    Seq(50, 60, 70).foreach { topLevelFields =>
+      // If this changes, we should change the test, or set it explicitly.
+      assert(SQLConf.VARIANT_SHREDDING_MAX_SCHEMA_WIDTH.defaultValue.get == 300)
+      withTempDir { dir =>
+
+        // Each field is a 2-element object. The leaf fields will consume 4 columns, and
+        // the value column for the field will consume one, for a total of 5, so we should
+        // hit the limit at around 60 top-level fields.
+        val bigObject = (0 until topLevelFields).map { i =>
+          s""" "col_$i": {"x": $i, "y": "${i + 1}"} """
+        }.mkString(start = "{", sep = ",", end = "}")
+        val df = spark.sql(
+          // In addition to the large object, add a smaller one. Once we hit the limit, we should
+          // not shred that one, because it comes later in the schema.
+          s"""select parse_json('$bigObject') as v,
+                parse_json('{"x": ' || id || ', "y": 2}') as v2
+                from range(0, 100, 1, 1) """)
+        df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+        val footers = getFooters(dir)
+        assert(footers.size == 1)
+        // Checking the exact schema is a hassle, so just check that we get the expected number
+        // of column chunks.
+        val cols = footers(0).getParquetMetadata.getFileMetaData.getSchema.getColumns
+        if (topLevelFields == 50) {
+          // Columns should be 50 * 5 for the nested fields, plus a top-level value and metadata.
+          // The second object should shred with 6 columns, using a top-level value and metadata,
+          // and a value and typed_value for each of "x" and "y".
+          assert(cols.size() == 252 + 6)
+        } else {
+          // Limit is 300, and metadata does not count towards the limit. Once we hit the limit on
+          // the first column, we'll use unshredded for the second, which adds two more columns.
+          assert(cols.size() == 301 + 2)
+        }
+        // Binary should be identical, so we can call checkAnswer directly.
+        checkStringAndSchema(dir, df)
+      }
+    }
+  }
+
+  testWithTempDir("infer shredding key as data") { dir =>
+    // The first 50 fields in each object include the row ID in the field name, so they'll be
+    // unique. Because we impose a 1000-field limit when building up the schema, we'll end up
+    // dropping all but the first 1000, so we won't include the non-unique fields in the schema.
+    // Since the unique names are below the count threshold, we'll end up with an unshredded
+    // schema.
+    // In the future, we could consider trying to improve this by dropping the least-common fields
+    // when we hit the limit of 1000.
+    val bigObject = (0 until 100).map { i =>
+      if (i < 50) {
+        s""" "first_${i}_' || id || '": {"x": $i, "y": "${i + 1}"} """
+      } else {
+        s""" "last_${i}": {"x": $i, "y": "${i + 1}"} """
+      }
+    }.mkString(start = "{", sep = ",", end = "}")
+    val df = spark.sql(
+      // In addition to the large object, add a smaller one. It should be shredded correctly.
+      s"""select parse_json('$bigObject') as v,
+            parse_json('{"x": ' || id || ', "y": 2}') as v2
+            from range(0, 100, 1, 1) """)
+    df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+    val footers = getFooters(dir)
+    assert(footers.size == 1)
+
+    // We can't call checkFileSchema, because it only handles the case of one Variant column in
+    // the file.
+    val largeExpected = SparkShreddingUtils.variantShreddingSchema(DataType.fromDDL("variant"))
+    val smallExpected = SparkShreddingUtils.variantShreddingSchema(
+      DataType.fromDDL("struct<x bigint, y bigint>"))
+    val actual = getFileSchema(dir)
+    assert(actual == StructType(Seq(
+      StructField("v", largeExpected, nullable = false),
+      StructField("v2", smallExpected, nullable = false))))
+    checkStringAndSchema(dir, df)
+  }
+
+  testWithTempDir("infer shredding from sparse data") { dir =>
+    // Infer a schema when there is only one row per batch.
+    // The second case only starts at row 4096 * 2048, but we should still see it when
+    // we infer a schema, since there is only one active row per batch.
+    val df = spark.sql(
+      """
+        | select
+        |   case when floor(id / (4096 * 2048)) = 0 then
+        |     parse_json('{"a": ' || id || ', "c": "' || id || '"}')
+        |   else parse_json('{"d": ' || id || ', "b": "' || id || '"}')
+        |   end as v
+        | from range(0, 4096 * 4096, 1, 1)
+        | where id % 4096 = 0
+        |""".stripMargin)
+    df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+    val expected = DataType.fromDDL("struct<a bigint, b string, c string, d bigint>")
+    checkFileSchema(expected, dir)
+    checkStringAndSchema(dir, df)
+  }
+
+  testWithTempDir("infer shredding non-null after null") { dir =>
+    // When the first batch is less than the max batch size, the writer buffers it, and eventually
+    // infers a schema based on the buffered data and non-buffered data. Ensure that we behave
+    // correctly if either batch is all-null.
+    val df = spark.sql(
+      """
+        | select
+        |   case when id >= 4096 then
+        |     parse_json('{"a": ' || id || ', "b": "' || id || '"}')
+        |   else null
+        |   end as v
+        | from range(0, 10000, 1, 1)
+        | -- Filter out one row per batch so that we buffer the first batch rather than
+        | -- immediately finalizing the schema.
+        | where id % 4096 != 0
+        |""".stripMargin)
+    df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+    val expected = DataType.fromDDL("struct<a bigint, b string>")
+    checkFileSchema(expected, dir)
+    checkStringAndSchema(dir, df)
+  }
+
+  testWithTempDir("infer shredding null after non-null") { dir =>
+    // Same as the previous test, but the first batch is non-null, and the second is all-null.
+    val df = spark.sql(
+      """
+        | select
+        |   case when id < 4096 then
+        |     parse_json('{"a": ' || id || ', "b": "' || id || '"}')
+        |   else null
+        |   end as v
+        | from range(0, 10000, 1, 1)
+        | -- Filter out one row per batch so that we buffer the first batch rather than
+        | -- immediately finalizing the schema.
+        | where id % 4096 != 0
+        |""".stripMargin)
+    df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+    val expected = DataType.fromDDL("struct<a bigint, b string>")
+    checkFileSchema(expected, dir)
+    checkStringAndSchema(dir, df)
+  }
+
+  testWithTempDir("infer shredding with empty file") { dir =>
+    // When there is no data, we should produce a sane schema.
+    val df = spark.sql(
+      """
+        | select
+        |   case when id < 4096 then
+        |     parse_json('{"a": ' || id || ', "b": "' || id || '"}')
+        |   else null
+        |   end as v
+        | from range(0, 10000, 1, 1)
+        | where cast(id * id as string) = '123'
+        |""".stripMargin)
+    df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+    val expected = DataType.fromDDL("variant")
+    checkFileSchema(expected, dir)
+    checkStringAndSchema(dir, df)
+  }
+
+  testWithTempDir("infer a simple schema when data is null") { dir =>
+    // The second case only starts at row 4096, so we won't use it in the shredding schema.
+    val df = spark.sql(
+      """
+        | select
+        |   case when floor(id / 4096) = 0 then null
+        |   else parse_json('{"a": ' || id || ', "b": "' || id || '"}')
+        |   end as v
+        | from range(0, 20000, 1, 1)
+        |""".stripMargin)
+    df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+    checkFileSchema(VariantType, dir)
+    checkAnswer(spark.read.parquet(dir.getAbsolutePath), df.collect())
+  }
+
+  testWithTempDir("infer a schema when data is mostly null") { dir =>
+    // Even if there is only one non-null row, use it to infer a schema.
+    val df = spark.sql(
+      """
+        | select
+        |   case when id % 4096 != 123 then null
+        |   else parse_json('{"a": ' || id || ', "b": "' || id || '"}')
+        |   end as v
+        | from range(0, 20000, 1, 1)
+        |""".stripMargin)
+    df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+    val expected = DataType.fromDDL("struct<a bigint, b string>")
+    checkFileSchema(expected, dir)
+    checkAnswer(spark.read.parquet(dir.getAbsolutePath), df.collect())
+  }
+
+  testWithTempDir("infer a schema when there is one row") { dir =>
+    // Even a single row is enough to infer a shredding schema.
+    val df = spark.sql(
+      """ select parse_json('{"a": ' || id || ', "b": "' || id || '"}') as v
+        | from range(0, 1, 1, 1)
+        |""".stripMargin)
+    df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+    val expected = DataType.fromDDL("struct<a bigint, b string>")
+    checkFileSchema(expected, dir)
+    checkAnswer(spark.read.parquet(dir.getAbsolutePath), df.collect())
+  }
+
+  testWithTempDir("Nested variant values") { dir =>
+    val df = spark.sql(
+      """
+        | select
+        |   struct(
+        |     struct(
+        |       id,
+        |       parse_json('{"a": ' || id || ', "b": "' || id || '"}') as v,
+        |       array(parse_json("1"), parse_json("2")) as a
+        |     ) as s1,
+        |     array(struct(parse_json('{"a": 1, "b": 2}') as v2)) as a1
+        |   ) as s
+        | from range(0, 3, 1, 1)
+        |""".stripMargin)
+    df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+    // The struct field that is not in an array should be shredded.
+    val variantSchema = DataType.fromDDL("struct<a bigint, b string>")
+    val shreddedSchema = SparkShreddingUtils.variantShreddingSchema(variantSchema)
+    // Unshredded variant prints "value" before "metadata". According to the current wording
+    // of the spec "The Parquet columns used to store variant metadata and values must be accessed
+    // by name, not by position", so this should be okay, but should we do anything to be
+    // consistent between the shredded and unshredded versions?
+    val unshreddedSchema = StructType(Seq(
+      StructField("value", BinaryType, nullable = false),
+      StructField("metadata", BinaryType, nullable = false)))
+    // Only the nested struct field should be shredded, none of the fields that are in an array.
+    val expected = StructType(Seq(StructField("s", StructType(Seq(
+      StructField("s1", StructType(Seq(
+        StructField("id", LongType, nullable = false),
+        StructField("v", shreddedSchema, nullable = false),
+        StructField("a", ArrayType(unshreddedSchema, containsNull = false), nullable = false))),
+        nullable = false),
+      StructField("a1", ArrayType(StructType(Seq(
+        StructField("v2", unshreddedSchema, nullable = false))), containsNull = false),
+        nullable = false))),
+      nullable = false)))
+    val actual = getFileSchema(dir)
+    assert(actual == expected)
+    // I think the binary should be identical, but right now the reader doesn't support reading
+    // a full struct that contains a shredded Variant, so we need to check the individual fields.
+    // TODO(cashmand) re-enable once we have support.
+    // checkAnswer(spark.read.parquet(dir.getAbsolutePath), df.collect())
+    checkAnswer(
+      spark.read.parquet(dir.getAbsolutePath).selectExpr(
+        "s.s1.id", "s.s1.v", "s.s1.a", "s.a1"
+      ),
+      df.selectExpr(
+        "s.s1.id", "s.s1.v", "s.s1.a", "s.a1"
+      ).collect())
+  }
+
+  test("infer shredding with mixed scale") {
+    withTempDir { dir =>
+      // a: Test combining large positive/negative integers with decimal to produce a decimal
+      //    that doesn't fit in precision 18.
+      // b: Mix of numeric and string, won't shred.
+      // c: Similar to a, different ordering of decimal and integer.
+      // d: Test that Long.MinValue is handled correctly, merges appropriately with decimals.
+      // e: Test that an integer (127) that fits in int8_t but not Decimal(2, 0) is correctly
+      //    merged with an integer that is represented as a Decimal(18, 0).
+      val df = spark.sql(
+        s"""
+           | select
+           |  case when id % 3 = 0 then
+           |    parse_json('{"a": -123456789012345, "b": ' || id || ', "c": 0.1, "d": -123,
+           |      "e": 127}')
+           |  when id % 3 = 1 then
+           |    parse_json('{"a": 0.03, "b": 1.23, "c": -0.123456789012345678, "d": 0.1234567890,
+           |      "e": 123456789012345678}')
+           |  when id % 3 = 2 then
+           |    parse_json('{"a": -1.10000, "b": "' || id || '", "c": 0.12,
+           |      "d": ${Long.MinValue}, "e": 0.123}')
+           |  end as v
+           | from range(0, 3, 1, 1)
+           |""".stripMargin)
+      df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+      // We don't narrow in the presence of trailing zeros, so the scale of a should be 5. We
+      // always use a precision of at least 18, to leave room for larger values.
+      val expected = DataType.fromDDL(
+        "struct<a decimal(38,5), b variant, c decimal(18,18), d decimal(38,10), e decimal(38,3)>")
+
+      checkFileSchema(expected, dir)
+      // We can't call checkStringAndSchema, because the schema_of_variant doesn't match: the
+      // first "a" value is reported as a Decimal(1, 0) after shredding. I think the right way
+      // to fix this is to change schema_of_variant to return BIGINT for Decimal(N, 0) when N is
+      // <= 18. There are a number of related behaviour changes that we probably need for
+      // schema_of_variant.
+      checkAnswer(
+        spark.read.parquet(dir.getAbsolutePath).selectExpr("v::string"),
+        df.selectExpr("v::string").collect())
+
+      // Check that the values were actually shredded into typed_value.
+      val fullSchema = StructType(Seq(StructField("v",
+        SparkShreddingUtils.variantShreddingSchema(expected))))
+      val shreddedDf = spark.read.schema(fullSchema).parquet(dir.getAbsolutePath)
+      checkAnswer(shreddedDf.selectExpr("v.typed_value.a.value"),
+        Seq(Row(null), Row(null), Row(null)))
+      checkAnswer(shreddedDf.selectExpr("v.typed_value.c.value"),
+        Seq(Row(null), Row(null), Row(null)))
+      checkAnswer(shreddedDf.selectExpr("v.typed_value.d.value"),
+        Seq(Row(null), Row(null), Row(null)))
+      checkAnswer(shreddedDf.selectExpr("v.typed_value.e.value"),
+        Seq(Row(null), Row(null), Row(null)))
+      checkAnswer(shreddedDf.selectExpr("v.typed_value.a.typed_value"),
+        Seq(
+          Row(BigDecimal("-123456789012345")),
+          Row(BigDecimal("0.03000")),
+          Row(BigDecimal("-1.10000"))))
+      checkAnswer(shreddedDf.selectExpr("v.typed_value.c.typed_value"),
+        Seq(
+          Row(BigDecimal("0.100000000000000000")),
+          Row(BigDecimal("-0.123456789012345678")),
+          Row(BigDecimal("0.120000000000000000"))))
+      checkAnswer(shreddedDf.selectExpr("v.typed_value.d.typed_value"),
+        Seq(
+          Row(BigDecimal("-123.0000000000")),
+          Row(BigDecimal("0.1234567890")),
+          Row(BigDecimal("-9223372036854775808.0000000000"))))
+      checkAnswer(shreddedDf.selectExpr("v.typed_value.e.typed_value"),
+        Seq(
+          Row(BigDecimal("127.000")),
+          Row(BigDecimal("123456789012345678.000")),
+          Row(BigDecimal("0.123"))))
+    }
+  }
+
+  // Test with a few values of maxRecordsPerFile. It is the other situation besides partitioning
+  // where we write multiple files within a task. 0 means no limit.
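+  // Note that a new output writer (and therefore a separate shredding inference pass) is created
+  // for each output file, so different files may end up with different shredded schemas.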
+  Seq((0, 100),
+    (0, 50000),
+    (23, 200),
+    (9950, 50000)).foreach { case (maxRecordsPerFile, numRows) =>
+    Seq(false, true).foreach { useSort =>
+      val sortStr = if (useSort) "sorted" else "clustered"
+      testWithTempDir(
+        s"infer shredding with partitions: $numRows $sortStr rows, " +
+          s"$maxRecordsPerFile per file") { dir =>
+        withSQLConf(SQLConf.MAX_RECORDS_PER_FILE.key -> maxRecordsPerFile.toString) {
+          val sortClause = if (useSort) "sort by p, v:a::string" else ""
+          val df = spark.sql(
+            s"""
+               | select
+               |   id % 5 as p,
+               |   case
+               |     when id % 5 = 0 then parse_json('{"a": 1, "b": "hello"}')
+               |     when id % 5 = 1 then parse_json('{"a": 1.2, "b": "world"}')
+               |     else parse_json('{"a": "not a number", "b": "goodbye"}')
+               |   end as v
+               | from range(0, $numRows, 1, 1)
+               | $sortClause
+               |""".stripMargin)
+          df.write.mode("overwrite").partitionBy("p").parquet(dir.getAbsolutePath)
+          // The inferred type of "a" depends on which partition's rows land in a given file, so
+          // there are a few possible shredded schemas.
+          val possibleSchemas = Seq(
+            "struct<a bigint, b string>",
+            "struct<a decimal(18,1), b string>",
+            "struct<a string, b string>")
+            .map(DataType.fromDDL)
+            .map(SparkShreddingUtils.variantShreddingSchema(_))
+            .map(shreddedType => StructType(Seq(StructField("v", shreddedType, nullable = false))))
+          // Each partition is stored in a sub-directory.
+          dir.listFiles().filter(_.isDirectory).foreach { subdir =>
+            // We compute a new shredding schema for every partition. Check that each schema we see
+            // is in the list of possibilities.
+            val footers = getFooters(subdir)
+            footers.foreach { footer =>
+              val actual = new ParquetToSparkSchemaConverter()
+                .convert(footer.getParquetMetadata.getFileMetaData.getSchema)
+              assert(possibleSchemas.contains(actual))
+            }
+          }
+          checkAnswer(spark.read.parquet(dir.getAbsolutePath).selectExpr("p", "v"), df.collect())
+        }
+      }
+    }
+  }
+
+  // Spark hits JSON parsing limits at depth 1000. Ensure that we can shred until fairly close
+  // to that limit. If we tried to shred the full schema, we'd hit other Spark limits on schema
+  // size, but shredding inference should limit the depth at which we shred.
+  // At depth 40, we should shred the full schema.
+  Seq(40, 900).foreach { depth =>
+    test(s"infer shredding on deep schemas - depth=$depth ") {
+      withTempDir { dir =>
+        val deepArray = "[" * depth + "1" + "]" * depth
+        val deepStruct = """{"a": """ * depth + "1" + "}" * depth
+        val deepMixed = """[{"a": """ * (depth / 2) + "1" + "}]" * (depth / 2)
+        val df = spark.sql(
+          s"""select parse_json('$deepArray') as a,
+                parse_json('$deepStruct') as s,
+                parse_json('$deepMixed') as m
+                from range(0, 100, 1, 1) """)
+        df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+        val footers = getFooters(dir)
+        assert(footers.size == 1)
+        // Checking the exact schema is a hassle, so just check that we get the expected number of
+        // column chunks.
+        val cols = footers(0).getParquetMetadata.getFileMetaData.getSchema.getColumns.size()
+        if (depth == 40) {
+          // There is one value column for each of the 40 array/struct levels, plus a value and
+          // typed_value column at the leaf, plus metadata, for a total of 43 * 3 columns.
+          assert(cols == 129)
+        } else {
+          // Each level produces a `value` column, and the max depth is 50, so roughly 50 columns
+          // for each of the three fields. Don't be too picky about exactly how deep the inferred
+          // schema is; the test is mainly meant to ensure correctness/stability.
+          assert(cols > 150 && cols < 160)
+        }
+        checkStringAndSchema(dir, df, field = "a")
+        checkStringAndSchema(dir, df, field = "s")
+        checkStringAndSchema(dir, df, field = "m")
+      }
+    }
+  }
+
+  testWithTempDir("non-json types") { dir =>
+    // Ensure that we infer a correct schema for types that do not appear in JSON.
+    val df = spark.sql(
+      """
+        | -- Note: field names must be alphabetically ordered, or binary details will differ and
+        | -- cause `checkAnswer` to fail.
+        | select
+        |   to_variant_object(struct(
+        |     cast('abc' as binary) as _bin,
+        |     cast(id as boolean) as _bool,
+        |     date_from_unix_date(id) as _date,
+        |     cast(id as float) as _float,
+        |     cast(id as timestamp) as _time,
+        |     cast(cast(id as timestamp) as timestamp_ntz) as _time_ntz
+        |   ))
+        | as v
+        | from range(0, 3, 1, 1)
+        |""".stripMargin)
+    df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+    val expected = DataType.fromDDL(
+      "struct<_bin binary, _bool boolean, _date date, _float float, _time timestamp, " +
+        "_time_ntz timestamp_ntz>")
+    checkFileSchema(expected, dir)
+    checkAnswer(spark.read.parquet(dir.getAbsolutePath), df.collect())
+  }
+
+  testWithTempDir("uuid") { dir =>
+    // There's no way to generate a UUID Variant value in Spark, so we need to do it manually.
+    val numRows = 10
+    val rdd = spark.sparkContext.parallelize[InternalRow](Nil, numSlices = 1).mapPartitions { _ =>
+      val uuid = java.util.UUID.fromString("01020304-0506-0708-090a-0b0c0d0e0f10")
+      val builder = new VariantBuilder(false)
+      val start = builder.getWritePos
+      val fields = new java.util.ArrayList[VariantBuilder.FieldEntry](2)
+      val a_id = builder.addKey("a")
+      fields.add(new VariantBuilder.FieldEntry("a", a_id, builder.getWritePos - start))
+      builder.appendUuid(uuid)
+      // Add an extra field to make Variant reconstruction a bit more interesting.
+      val b_id = builder.addKey("b")
+      fields.add(new VariantBuilder.FieldEntry("b", b_id, builder.getWritePos - start))
+      builder.appendLong(123)
+      builder.finishWritingObject(start, fields)
+      val v = builder.result()
+      Iterator.tabulate(numRows) { _ =>
+        InternalRow(new VariantVal(v.getValue, v.getMetadata))
+      }
+    }
+    // Ensure that we infer a correct schema for types that do not appear in JSON.
+    val writeSchema = DataType.fromDDL("struct<v variant>").asInstanceOf[StructType]
+    val df = Dataset.ofRows(spark, LogicalRDD(DataTypeUtils.toAttributes(writeSchema), rdd)(spark))
+    df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+    // The UUID field should not be shredded.
+    val expected = DataType.fromDDL("struct<a variant, b bigint>")
+    checkFileSchema(expected, dir)
+    checkAnswer(spark.read.parquet(dir.getAbsolutePath), df.collect())
+  }
+}
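
A minimal usage sketch, assuming the configuration keys introduced in this patch; the query and
output path are illustrative, not part of the change. Inference only runs when both the existing
write-shredding kill switch and the new inferShreddingSchema flag are enabled (the latter can also
be passed as a Parquet write option):

    // Sketch: enable shredding-schema inference and write a Variant column to Parquet.
    import org.apache.spark.sql.internal.SQLConf

    spark.conf.set(SQLConf.VARIANT_WRITE_SHREDDING_ENABLED.key, "true")   // global kill switch
    spark.conf.set(SQLConf.VARIANT_INFER_SHREDDING_SCHEMA.key, "true")    // spark.sql.variant.inferShreddingSchema

    spark.sql("""select parse_json('{"a": ' || id || ', "b": "' || id || '"}') as v
                 from range(0, 1000)""")
      .write.mode("overwrite").parquet("/tmp/variant_shredding_example")  // illustrative path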