Skip to content

Commit 19226b1

Browse files
committed
upmerge
1 parent 118e6ad commit 19226b1

File tree

2 files changed

+102
-84
lines changed

2 files changed

+102
-84
lines changed

native/core/src/execution/planner.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2828,6 +2828,7 @@ fn literal_to_array_ref(
28282828

28292829
// Build offsets and collect non-null child arrays
28302830
let mut offsets = Vec::with_capacity(list_literal.list_values.len() + 1);
2831+
// Offsets starts with 0
28312832
offsets.push(0i32);
28322833
let mut child_arrays: Vec<ArrayRef> = Vec::new();
28332834

spark/src/main/scala/org/apache/comet/serde/literals.scala

Lines changed: 101 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
package org.apache.comet.serde.literals
2121

22+
import java.lang
23+
2224
import org.apache.spark.internal.Logging
2325
import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
2426
import org.apache.spark.sql.catalyst.util.GenericArrayData
@@ -27,7 +29,6 @@ import org.apache.spark.unsafe.types.UTF8String
2729

2830
import com.google.protobuf.ByteString
2931

30-
import org.apache.comet.CometConf
3132
import org.apache.comet.CometSparkSessionExtensions.withInfo
3233
import org.apache.comet.DataTypeSupport.isComplexType
3334
import org.apache.comet.serde.{CometExpressionSerde, Compatible, ExprOuterClass, LiteralOuterClass, SupportLevel, Unsupported}
@@ -41,13 +42,15 @@ object CometLiteral extends CometExpressionSerde[Literal] with Logging {
4142
if (supportedDataType(
4243
expr.dataType,
4344
allowComplex = expr.value == null ||
45+
4446
// Nested literal support for native reader
4547
// can be tracked https://github.com/apache/datafusion-comet/issues/1937
46-
// now supports only Array of primitive
47-
(Seq(CometConf.SCAN_NATIVE_ICEBERG_COMPAT, CometConf.SCAN_NATIVE_DATAFUSION)
48-
.contains(CometConf.COMET_NATIVE_SCAN_IMPL.get()) && expr.dataType
49-
.isInstanceOf[ArrayType]) && !isComplexType(
50-
expr.dataType.asInstanceOf[ArrayType].elementType))) {
48+
(expr.dataType
49+
.isInstanceOf[ArrayType] && (!isComplexType(
50+
expr.dataType.asInstanceOf[ArrayType].elementType) || expr.dataType
51+
.asInstanceOf[ArrayType]
52+
.elementType
53+
.isInstanceOf[ArrayType])))) {
5154
Compatible(None)
5255
} else {
5356
Unsupported(Some(s"Unsupported data type ${expr.dataType}"))
@@ -86,84 +89,10 @@ object CometLiteral extends CometExpressionSerde[Literal] with Logging {
8689
val byteStr =
8790
com.google.protobuf.ByteString.copyFrom(value.asInstanceOf[Array[Byte]])
8891
exprBuilder.setBytesVal(byteStr)
89-
case a: ArrayType =>
90-
val listLiteralBuilder = ListLiteral.newBuilder()
91-
val array = value.asInstanceOf[GenericArrayData].array
92-
a.elementType match {
93-
case NullType =>
94-
array.foreach(_ => listLiteralBuilder.addNullMask(true))
95-
case BooleanType =>
96-
array.foreach(v => {
97-
val casted = v.asInstanceOf[java.lang.Boolean]
98-
listLiteralBuilder.addBooleanValues(casted)
99-
listLiteralBuilder.addNullMask(casted != null)
100-
})
101-
case ByteType =>
102-
array.foreach(v => {
103-
val casted = v.asInstanceOf[java.lang.Integer]
104-
listLiteralBuilder.addByteValues(casted)
105-
listLiteralBuilder.addNullMask(casted != null)
106-
})
107-
case ShortType =>
108-
array.foreach(v => {
109-
val casted = v.asInstanceOf[java.lang.Short]
110-
listLiteralBuilder.addShortValues(
111-
if (casted != null) casted.intValue()
112-
else null.asInstanceOf[java.lang.Integer])
113-
listLiteralBuilder.addNullMask(casted != null)
114-
})
115-
case IntegerType | DateType =>
116-
array.foreach(v => {
117-
val casted = v.asInstanceOf[java.lang.Integer]
118-
listLiteralBuilder.addIntValues(casted)
119-
listLiteralBuilder.addNullMask(casted != null)
120-
})
121-
case LongType | TimestampType | TimestampNTZType =>
122-
array.foreach(v => {
123-
val casted = v.asInstanceOf[java.lang.Long]
124-
listLiteralBuilder.addLongValues(casted)
125-
listLiteralBuilder.addNullMask(casted != null)
126-
})
127-
case FloatType =>
128-
array.foreach(v => {
129-
val casted = v.asInstanceOf[java.lang.Float]
130-
listLiteralBuilder.addFloatValues(casted)
131-
listLiteralBuilder.addNullMask(casted != null)
132-
})
133-
case DoubleType =>
134-
array.foreach(v => {
135-
val casted = v.asInstanceOf[java.lang.Double]
136-
listLiteralBuilder.addDoubleValues(casted)
137-
listLiteralBuilder.addNullMask(casted != null)
138-
})
139-
case StringType =>
140-
array.foreach(v => {
141-
val casted = v.asInstanceOf[org.apache.spark.unsafe.types.UTF8String]
142-
listLiteralBuilder.addStringValues(if (casted != null) casted.toString else "")
143-
listLiteralBuilder.addNullMask(casted != null)
144-
})
145-
case _: DecimalType =>
146-
array
147-
.foreach(v => {
148-
val casted =
149-
v.asInstanceOf[Decimal]
150-
listLiteralBuilder.addDecimalValues(if (casted != null) {
151-
com.google.protobuf.ByteString
152-
.copyFrom(casted.toBigDecimal.underlying.unscaledValue.toByteArray)
153-
} else ByteString.EMPTY)
154-
listLiteralBuilder.addNullMask(casted != null)
155-
})
156-
case _: BinaryType =>
157-
array
158-
.foreach(v => {
159-
val casted =
160-
v.asInstanceOf[Array[Byte]]
161-
listLiteralBuilder.addBytesValues(if (casted != null) {
162-
com.google.protobuf.ByteString.copyFrom(casted)
163-
} else ByteString.EMPTY)
164-
listLiteralBuilder.addNullMask(casted != null)
165-
})
166-
}
92+
93+
case arr: ArrayType =>
94+
val listLiteralBuilder: ListLiteral.Builder =
95+
makeListLiteral(value.asInstanceOf[GenericArrayData].array, arr)
16796
exprBuilder.setListVal(listLiteralBuilder.build())
16897
exprBuilder.setDatatype(serializeDataType(dataType).get)
16998
case dt =>
@@ -188,4 +117,92 @@ object CometLiteral extends CometExpressionSerde[Literal] with Logging {
188117
}
189118

190119
}
120+
121+
private def makeListLiteral(array: Array[Any], arrayType: ArrayType): ListLiteral.Builder = {
122+
val listLiteralBuilder = ListLiteral.newBuilder()
123+
arrayType.elementType match {
124+
case NullType =>
125+
array.foreach(_ => listLiteralBuilder.addNullMask(true))
126+
case BooleanType =>
127+
array.foreach(v => {
128+
val casted = v.asInstanceOf[lang.Boolean]
129+
listLiteralBuilder.addBooleanValues(casted)
130+
listLiteralBuilder.addNullMask(casted != null)
131+
})
132+
case ByteType =>
133+
array.foreach(v => {
134+
val casted = v.asInstanceOf[Integer]
135+
listLiteralBuilder.addByteValues(casted)
136+
listLiteralBuilder.addNullMask(casted != null)
137+
})
138+
case ShortType =>
139+
array.foreach(v => {
140+
val casted = v.asInstanceOf[lang.Short]
141+
listLiteralBuilder.addShortValues(
142+
if (casted != null) casted.intValue()
143+
else null.asInstanceOf[Integer])
144+
listLiteralBuilder.addNullMask(casted != null)
145+
})
146+
case IntegerType | DateType =>
147+
array.foreach(v => {
148+
val casted = v.asInstanceOf[Integer]
149+
listLiteralBuilder.addIntValues(casted)
150+
listLiteralBuilder.addNullMask(casted != null)
151+
})
152+
case LongType | TimestampType | TimestampNTZType =>
153+
array.foreach(v => {
154+
val casted = v.asInstanceOf[lang.Long]
155+
listLiteralBuilder.addLongValues(casted)
156+
listLiteralBuilder.addNullMask(casted != null)
157+
})
158+
case FloatType =>
159+
array.foreach(v => {
160+
val casted = v.asInstanceOf[lang.Float]
161+
listLiteralBuilder.addFloatValues(casted)
162+
listLiteralBuilder.addNullMask(casted != null)
163+
})
164+
case DoubleType =>
165+
array.foreach(v => {
166+
val casted = v.asInstanceOf[lang.Double]
167+
listLiteralBuilder.addDoubleValues(casted)
168+
listLiteralBuilder.addNullMask(casted != null)
169+
})
170+
case StringType =>
171+
array.foreach(v => {
172+
val casted = v.asInstanceOf[UTF8String]
173+
listLiteralBuilder.addStringValues(if (casted != null) casted.toString else "")
174+
listLiteralBuilder.addNullMask(casted != null)
175+
})
176+
case _: DecimalType =>
177+
array
178+
.foreach(v => {
179+
val casted =
180+
v.asInstanceOf[Decimal]
181+
listLiteralBuilder.addDecimalValues(if (casted != null) {
182+
com.google.protobuf.ByteString
183+
.copyFrom(casted.toBigDecimal.underlying.unscaledValue.toByteArray)
184+
} else ByteString.EMPTY)
185+
listLiteralBuilder.addNullMask(casted != null)
186+
})
187+
case _: BinaryType =>
188+
array
189+
.foreach(v => {
190+
val casted =
191+
v.asInstanceOf[Array[Byte]]
192+
listLiteralBuilder.addBytesValues(if (casted != null) {
193+
com.google.protobuf.ByteString.copyFrom(casted)
194+
} else ByteString.EMPTY)
195+
listLiteralBuilder.addNullMask(casted != null)
196+
})
197+
case a: ArrayType =>
198+
array.foreach(v => {
199+
val casted = v.asInstanceOf[GenericArrayData]
200+
listLiteralBuilder.addListValues(if (casted != null) {
201+
makeListLiteral(casted.array, a)
202+
} else ListLiteral.newBuilder())
203+
listLiteralBuilder.addNullMask(casted != null)
204+
})
205+
}
206+
listLiteralBuilder
207+
}
191208
}

0 commit comments

Comments
 (0)