
Commit 399f962

[SPARK-52444][SQL][Connect] Add support for Variant/Char/Varchar Literal
1 parent 420ac24 commit 399f962

File tree

13 files changed: +339 −71 lines

python/pyspark/sql/connect/proto/expressions_pb2.py

Lines changed: 69 additions & 63 deletions
Large diffs are not rendered by default.

python/pyspark/sql/connect/proto/expressions_pb2.pyi

Lines changed: 99 additions & 0 deletions
@@ -654,6 +654,78 @@ class Expression(google.protobuf.message.Message):
                | None
            ): ...

+        class Variant(google.protobuf.message.Message):
+            DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+            VALUE_FIELD_NUMBER: builtins.int
+            METADATA_FIELD_NUMBER: builtins.int
+            value: builtins.bytes
+            metadata: builtins.bytes
+            def __init__(
+                self,
+                *,
+                value: builtins.bytes = ...,
+                metadata: builtins.bytes = ...,
+            ) -> None: ...
+            def ClearField(
+                self,
+                field_name: typing_extensions.Literal["metadata", b"metadata", "value", b"value"],
+            ) -> None: ...
+
+        class Char(google.protobuf.message.Message):
+            DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+            VALUE_FIELD_NUMBER: builtins.int
+            LENGTH_FIELD_NUMBER: builtins.int
+            value: builtins.str
+            length: builtins.int
+            def __init__(
+                self,
+                *,
+                value: builtins.str = ...,
+                length: builtins.int | None = ...,
+            ) -> None: ...
+            def HasField(
+                self,
+                field_name: typing_extensions.Literal["_length", b"_length", "length", b"length"],
+            ) -> builtins.bool: ...
+            def ClearField(
+                self,
+                field_name: typing_extensions.Literal[
+                    "_length", b"_length", "length", b"length", "value", b"value"
+                ],
+            ) -> None: ...
+            def WhichOneof(
+                self, oneof_group: typing_extensions.Literal["_length", b"_length"]
+            ) -> typing_extensions.Literal["length"] | None: ...
+
+        class Varchar(google.protobuf.message.Message):
+            DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+            VALUE_FIELD_NUMBER: builtins.int
+            LENGTH_FIELD_NUMBER: builtins.int
+            value: builtins.str
+            length: builtins.int
+            def __init__(
+                self,
+                *,
+                value: builtins.str = ...,
+                length: builtins.int | None = ...,
+            ) -> None: ...
+            def HasField(
+                self,
+                field_name: typing_extensions.Literal["_length", b"_length", "length", b"length"],
+            ) -> builtins.bool: ...
+            def ClearField(
+                self,
+                field_name: typing_extensions.Literal[
+                    "_length", b"_length", "length", b"length", "value", b"value"
+                ],
+            ) -> None: ...
+            def WhichOneof(
+                self, oneof_group: typing_extensions.Literal["_length", b"_length"]
+            ) -> typing_extensions.Literal["length"] | None: ...
+
        NULL_FIELD_NUMBER: builtins.int
        BINARY_FIELD_NUMBER: builtins.int
        BOOLEAN_FIELD_NUMBER: builtins.int
@@ -675,6 +747,9 @@ class Expression(google.protobuf.message.Message):
        MAP_FIELD_NUMBER: builtins.int
        STRUCT_FIELD_NUMBER: builtins.int
        SPECIALIZED_ARRAY_FIELD_NUMBER: builtins.int
+        VARIANT_FIELD_NUMBER: builtins.int
+        CHAR_FIELD_NUMBER: builtins.int
+        VARCHAR_FIELD_NUMBER: builtins.int
        @property
        def null(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
        binary: builtins.bytes
@@ -706,6 +781,12 @@ class Expression(google.protobuf.message.Message):
        def struct(self) -> global___Expression.Literal.Struct: ...
        @property
        def specialized_array(self) -> global___Expression.Literal.SpecializedArray: ...
+        @property
+        def variant(self) -> global___Expression.Literal.Variant: ...
+        @property
+        def char(self) -> global___Expression.Literal.Char: ...
+        @property
+        def varchar(self) -> global___Expression.Literal.Varchar: ...
        def __init__(
            self,
            *,
@@ -730,6 +811,9 @@ class Expression(google.protobuf.message.Message):
            map: global___Expression.Literal.Map | None = ...,
            struct: global___Expression.Literal.Struct | None = ...,
            specialized_array: global___Expression.Literal.SpecializedArray | None = ...,
+            variant: global___Expression.Literal.Variant | None = ...,
+            char: global___Expression.Literal.Char | None = ...,
+            varchar: global___Expression.Literal.Varchar | None = ...,
        ) -> None: ...
        def HasField(
            self,
@@ -744,6 +828,8 @@ class Expression(google.protobuf.message.Message):
                b"byte",
                "calendar_interval",
                b"calendar_interval",
+                "char",
+                b"char",
                "date",
                b"date",
                "day_time_interval",
@@ -776,6 +862,10 @@ class Expression(google.protobuf.message.Message):
                b"timestamp",
                "timestamp_ntz",
                b"timestamp_ntz",
+                "varchar",
+                b"varchar",
+                "variant",
+                b"variant",
                "year_month_interval",
                b"year_month_interval",
            ],
@@ -793,6 +883,8 @@ class Expression(google.protobuf.message.Message):
                b"byte",
                "calendar_interval",
                b"calendar_interval",
+                "char",
+                b"char",
                "date",
                b"date",
                "day_time_interval",
@@ -825,6 +917,10 @@ class Expression(google.protobuf.message.Message):
                b"timestamp",
                "timestamp_ntz",
                b"timestamp_ntz",
+                "varchar",
+                b"varchar",
+                "variant",
+                b"variant",
                "year_month_interval",
                b"year_month_interval",
            ],
@@ -854,6 +950,9 @@ class Expression(google.protobuf.message.Message):
                "map",
                "struct",
                "specialized_array",
+                "variant",
+                "char",
+                "varchar",
            ]
            | None
        ): ...

sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala

Lines changed: 5 additions & 2 deletions
@@ -46,7 +46,7 @@ import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.protobuf.{functions => pbFn}
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.CalendarInterval
+import org.apache.spark.unsafe.types.{CalendarInterval, VariantVal}
import org.apache.spark.util.SparkFileUtils

// scalastyle:off
@@ -3319,7 +3319,9 @@ class PlanGenerationTestSuite
      fn.lit(java.sql.Date.valueOf("2023-02-23")),
      fn.lit(java.time.Duration.ofSeconds(200L)),
      fn.lit(java.time.Period.ofDays(100)),
-      fn.lit(new CalendarInterval(2, 20, 100L)))
+      fn.lit(new CalendarInterval(2, 20, 100L)),
+      fn.lit(new VariantVal(Array[Byte](1), Array[Byte](1)))
+    )
  }

  test("function lit array") {
@@ -3390,6 +3392,7 @@ class PlanGenerationTestSuite
      fn.typedLit(java.time.Duration.ofSeconds(200L)),
      fn.typedLit(java.time.Period.ofDays(100)),
      fn.typedLit(new CalendarInterval(2, 20, 100L)),
+      fn.typedLit(new VariantVal(Array[Byte](1), Array[Byte](1))),

      // Handle parameterized scala types e.g.: List, Seq and Map.
      fn.typedLit(Some(1)),

sql/connect/common/src/main/protobuf/spark/connect/expressions.proto

Lines changed: 18 additions & 0 deletions
@@ -195,6 +195,9 @@ message Expression {
      Struct struct = 24;

      SpecializedArray specialized_array = 25;
+      Variant variant = 26;
+      Char char = 27;
+      Varchar varchar = 28;
    }

    message Decimal {
@@ -240,6 +243,21 @@
        Strings strings = 6;
      }
    }
+
+    message Variant {
+      bytes value = 1;
+      bytes metadata = 2;
+    }
+
+    message Char {
+      string value = 1;
+      optional int32 length = 2;
+    }
+
+    message Varchar {
+      string value = 1;
+      optional int32 length = 2;
+    }
  }

  // An unresolved attribute that is not explicitly bound to a specific column, but the column
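
The three new messages mirror the Catalyst types they carry: Variant holds the binary-encoded value plus its metadata, while Char and Varchar pair a string with an optional declared length. Below is a minimal sketch, not part of this commit, of assembling a Variant literal against the generated Java bindings in org.apache.spark.connect.proto; the builder methods follow protobuf's standard codegen, and the single-byte arrays mirror the test data used elsewhere in this change.

  import com.google.protobuf.ByteString
  import org.apache.spark.connect.proto.Expression

  // Builds Expression.Literal { variant { value: 0x01, metadata: 0x01 } }.
  val variantLiteral = Expression.Literal
    .newBuilder()
    .setVariant(
      Expression.Literal.Variant
        .newBuilder()
        .setValue(ByteString.copyFrom(Array[Byte](1)))
        .setMetadata(ByteString.copyFrom(Array[Byte](1))))
    .build()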

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala

Lines changed: 5 additions & 1 deletion
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.util.{SparkDateTimeUtils, SparkIntervalUtils}
import org.apache.spark.sql.connect.common.DataTypeProtoConverter._
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.CalendarInterval
+import org.apache.spark.unsafe.types.{CalendarInterval, VariantVal}
import org.apache.spark.util.SparkClassUtils

object LiteralValueProtoConverter {
@@ -99,6 +99,10 @@ object LiteralValueProtoConverter {
      case v: Array[_] => builder.setArray(arrayBuilder(v))
      case v: CalendarInterval =>
        builder.setCalendarInterval(calendarIntervalBuilder(v.months, v.days, v.microseconds))
+      case v: VariantVal =>
+        builder.setVariant(builder.getVariantBuilder
+          .setValue(ByteString.copyFrom(v.getValue))
+          .setMetadata(ByteString.copyFrom(v.getMetadata)))
      case null => builder.setNull(ProtoDataTypes.NullType)
      case _ => throw new UnsupportedOperationException(s"literal $literal not supported (yet).")
    }
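
On the client, this conversion is reached through lit() and typedLit(), as exercised in PlanGenerationTestSuite above. A usage sketch, assuming the Connect flavor of org.apache.spark.sql.functions:

  import org.apache.spark.sql.{functions => fn}
  import org.apache.spark.unsafe.types.VariantVal

  // lit() routes the VariantVal through LiteralValueProtoConverter,
  // producing a literal with the new `variant` message set.
  val col = fn.lit(new VariantVal(Array[Byte](1), Array[Byte](1)))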
Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
-Project [id#0L, id#0L, true AS true#0, 68 AS 68#0, 9872 AS 9872#0, -8726532 AS -8726532#0, 7834609328726532 AS 7834609328726532#0L, 2.718281828459045 AS 2.718281828459045#0, -0.8 AS -0.8#0, 89.97620 AS 89.97620#0, 89889.7667231 AS 89889.7667231#0, connect! AS connect!#0, T AS T#0, ABCDEFGHIJ AS ABCDEFGHIJ#0, 0x78797A7B7C7D7E7F808182838485868788898A8B8C8D8E AS X'78797A7B7C7D7E7F808182838485868788898A8B8C8D8E'#0, 0x0806 AS X'0806'#0, [8,6] AS ARRAY(8, 6)#0, null AS NULL#0, 2020-10-10 AS DATE '2020-10-10'#0, 8.997620 AS 8.997620#0, 2023-02-23 04:31:59.808 AS TIMESTAMP '2023-02-23 04:31:59.808'#0, 1969-12-31 16:00:12.345 AS TIMESTAMP '1969-12-31 16:00:12.345'#0, 2023-02-23 20:36:00 AS TIMESTAMP_NTZ '2023-02-23 20:36:00'#0, 2023-02-23 AS DATE '2023-02-23'#0, INTERVAL '0 00:03:20' DAY TO SECOND AS INTERVAL '0 00:03:20' DAY TO SECOND#0, ... 2 more fields]
+Project [id#0L, id#0L, true AS true#0, 68 AS 68#0, 9872 AS 9872#0, -8726532 AS -8726532#0, 7834609328726532 AS 7834609328726532#0L, 2.718281828459045 AS 2.718281828459045#0, -0.8 AS -0.8#0, 89.97620 AS 89.97620#0, 89889.7667231 AS 89889.7667231#0, connect! AS connect!#0, T AS T#0, ABCDEFGHIJ AS ABCDEFGHIJ#0, 0x78797A7B7C7D7E7F808182838485868788898A8B8C8D8E AS X'78797A7B7C7D7E7F808182838485868788898A8B8C8D8E'#0, 0x0806 AS X'0806'#0, [8,6] AS ARRAY(8, 6)#0, null AS NULL#0, 2020-10-10 AS DATE '2020-10-10'#0, 8.997620 AS 8.997620#0, 2023-02-23 04:31:59.808 AS TIMESTAMP '2023-02-23 04:31:59.808'#0, 1969-12-31 16:00:12.345 AS TIMESTAMP '1969-12-31 16:00:12.345'#0, 2023-02-23 20:36:00 AS TIMESTAMP_NTZ '2023-02-23 20:36:00'#0, 2023-02-23 AS DATE '2023-02-23'#0, INTERVAL '0 00:03:20' DAY TO SECOND AS INTERVAL '0 00:03:20' DAY TO SECOND#0, ... 3 more fields]
 +- LocalRelation <empty>, [id#0L, a#0, b#0]
Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
-Project [id#0L, id#0L, 1 AS 1#0, null AS NULL#0, true AS true#0, 68 AS 68#0, 9872 AS 9872#0, -8726532 AS -8726532#0, 7834609328726532 AS 7834609328726532#0L, 2.718281828459045 AS 2.718281828459045#0, -0.8 AS -0.8#0, 89.97620 AS 89.97620#0, 89889.7667231 AS 89889.7667231#0, connect! AS connect!#0, T AS T#0, ABCDEFGHIJ AS ABCDEFGHIJ#0, 0x78797A7B7C7D7E7F808182838485868788898A8B8C8D8E AS X'78797A7B7C7D7E7F808182838485868788898A8B8C8D8E'#0, 0x0806 AS X'0806'#0, [8,6] AS ARRAY(8, 6)#0, null AS NULL#0, 2020-10-10 AS DATE '2020-10-10'#0, 8.997620 AS 8.997620#0, 2023-02-23 04:31:59.808 AS TIMESTAMP '2023-02-23 04:31:59.808'#0, 1969-12-31 16:00:12.345 AS TIMESTAMP '1969-12-31 16:00:12.345'#0, 2023-02-23 20:36:00 AS TIMESTAMP_NTZ '2023-02-23 20:36:00'#0, ... 18 more fields]
+Project [id#0L, id#0L, 1 AS 1#0, null AS NULL#0, true AS true#0, 68 AS 68#0, 9872 AS 9872#0, -8726532 AS -8726532#0, 7834609328726532 AS 7834609328726532#0L, 2.718281828459045 AS 2.718281828459045#0, -0.8 AS -0.8#0, 89.97620 AS 89.97620#0, 89889.7667231 AS 89889.7667231#0, connect! AS connect!#0, T AS T#0, ABCDEFGHIJ AS ABCDEFGHIJ#0, 0x78797A7B7C7D7E7F808182838485868788898A8B8C8D8E AS X'78797A7B7C7D7E7F808182838485868788898A8B8C8D8E'#0, 0x0806 AS X'0806'#0, [8,6] AS ARRAY(8, 6)#0, null AS NULL#0, 2020-10-10 AS DATE '2020-10-10'#0, 8.997620 AS 8.997620#0, 2023-02-23 04:31:59.808 AS TIMESTAMP '2023-02-23 04:31:59.808'#0, 1969-12-31 16:00:12.345 AS TIMESTAMP '1969-12-31 16:00:12.345'#0, 2023-02-23 20:36:00 AS TIMESTAMP_NTZ '2023-02-23 20:36:00'#0, ... 19 more fields]
 +- LocalRelation <empty>, [id#0L, a#0, b#0]

sql/connect/common/src/test/resources/query-tests/queries/function_lit.json

Lines changed: 24 additions & 0 deletions
@@ -607,6 +607,30 @@
        }
      }
    }
+  }, {
+    "literal": {
+      "variant": {
+        "value": "AQ==",
+        "metadata": "AQ=="
+      }
+    },
+    "common": {
+      "origin": {
+        "jvmOrigin": {
+          "stackTrace": [{
+            "classLoaderName": "app",
+            "declaringClass": "org.apache.spark.sql.functions$",
+            "methodName": "lit",
+            "fileName": "functions.scala"
+          }, {
+            "classLoaderName": "app",
+            "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite",
+            "methodName": "~~trimmed~anonfun~~",
+            "fileName": "PlanGenerationTestSuite.scala"
+          }]
+        }
+      }
+    }
  }]
 }
}
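
(For reference, "AQ==" is the base64 encoding of the single byte 0x01, i.e. the Array[Byte](1) value and metadata passed to VariantVal in the test suite.)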
Binary file not shown.

sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json

Lines changed: 24 additions & 0 deletions
@@ -652,6 +652,30 @@
        }
      }
    }
+  }, {
+    "literal": {
+      "variant": {
+        "value": "AQ==",
+        "metadata": "AQ=="
+      }
+    },
+    "common": {
+      "origin": {
+        "jvmOrigin": {
+          "stackTrace": [{
+            "classLoaderName": "app",
+            "declaringClass": "org.apache.spark.sql.functions$",
+            "methodName": "typedLit",
+            "fileName": "functions.scala"
+          }, {
+            "classLoaderName": "app",
+            "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite",
+            "methodName": "~~trimmed~anonfun~~",
+            "fileName": "PlanGenerationTestSuite.scala"
+          }]
+        }
+      }
+    }
  }, {
    "literal": {
      "integer": 1

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala

Lines changed: 19 additions & 1 deletion
@@ -21,7 +21,7 @@ import org.apache.spark.connect.proto
import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters}
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, LiteralValueProtoConverter}
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal}

object LiteralExpressionProtoConverter {

@@ -75,6 +75,24 @@
    case proto.Expression.Literal.LiteralTypeCase.STRING =>
      expressions.Literal(UTF8String.fromString(lit.getString), StringType)

+    case proto.Expression.Literal.LiteralTypeCase.CHAR =>
+      var length = lit.getChar.getValue.length
+      if (lit.getChar.hasLength) {
+        length = lit.getChar.getLength
+      }
+      expressions.Literal(UTF8String.fromString(lit.getChar.getValue), CharType(length))
+
+    case proto.Expression.Literal.LiteralTypeCase.VARCHAR =>
+      var length = lit.getVarchar.getValue.length
+      if (lit.getVarchar.hasLength) {
+        length = lit.getVarchar.getLength
+      }
+      expressions.Literal(UTF8String.fromString(lit.getVarchar.getValue), VarcharType(length))
+
+    case proto.Expression.Literal.LiteralTypeCase.VARIANT =>
+      expressions.Literal(new VariantVal(
+        lit.getVariant.getValue.toByteArray, lit.getVariant.getMetadata.toByteArray), VariantType)
+
    case proto.Expression.Literal.LiteralTypeCase.DATE =>
      expressions.Literal(lit.getDate, DateType)
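
The Char and Varchar branches above default the declared length to the length of the value string whenever the optional length field is unset. A minimal sketch of that rule, with charProto standing in as a hypothetical instance of the generated Char message:

  // Hypothetical example, not part of this commit.
  // When `length` is absent, fall back to the value's own length.
  val length =
    if (charProto.hasLength) charProto.getLength
    else charProto.getValue.length
  // Char(value = "ab")              -> CharType(2)
  // Char(value = "ab", length = 10) -> CharType(10)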
