Commit a691e76

stefankandic authored and cloud-fan committed
[SPARK-53130][SQL][PYTHON] Fix toJson behavior of collated string types
### What changes were proposed in this pull request?

Change collated string types to return their collation from the `toJson` methods, while preserving backwards compatibility with older engine versions reading tables with collations: the fix is propagated upstream in `StructField`, where the collation is removed from the type but still kept in the metadata.

### Why are the changes needed?

The old `toJson` behavior meant that collated string types could not be serialized and deserialized correctly unless they were part of a `StructField`. Initially we thought this was not a big deal, but we later ran into issues because of it, especially in PySpark, which primarily uses JSON to parse types back and forth. This fix avoids hacky changes in the future, like the one in #51688, without changing any behavior for how tables/schemas work.

### Does this PR introduce _any_ user-facing change?

Technically yes, but it is a small change that should not impact any queries, only how `StringType` is represented when not inside a `StructField` object.

### How was this patch tested?

New and existing unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #51850 from stefankandic/fixStringJson.

Authored-by: Stefan Kandic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>

(cherry picked from commit 19ea6ff)

Signed-off-by: Wenchen Fan <[email protected]>
1 parent 03c8d50 commit a691e76
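To make the behavior change concrete before the per-file diffs, here is a minimal PySpark sketch (not part of the commit; `_parse_datatype_json_string` is the internal helper the new tests below also use):

```python
from pyspark.sql.types import StringType, _parse_datatype_json_string

# Before this commit, a standalone collated string serialized to plain "string",
# so its collation was lost unless the type sat inside a StructField.
st = StringType("UNICODE")
print(st.json())  # now expected: "string collate UNICODE"

# The new regex-based parsing path should round-trip it back to an equal type.
assert _parse_datatype_json_string(st.json()) == st
```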

File tree

6 files changed (+84, -13 lines)

python/pyspark/sql/tests/test_types.py (15 additions, 0 deletions)

```diff
@@ -623,6 +623,17 @@ def test_schema_with_collations_json_ser_de(self):
         from pyspark.sql.types import _parse_datatype_json_string

         unicode_collation = "UNICODE"
+        utf8_lcase_collation = "UTF8_LCASE"
+
+        standalone_string = StringType(unicode_collation)
+
+        standalone_array = ArrayType(StringType(unicode_collation))
+
+        standalone_map = MapType(StringType(utf8_lcase_collation), StringType(unicode_collation))
+
+        standalone_nested = ArrayType(
+            MapType(StringType(utf8_lcase_collation), ArrayType(StringType(unicode_collation)))
+        )

         simple_struct = StructType([StructField("c1", StringType(unicode_collation))])

@@ -694,6 +705,10 @@ def test_schema_with_collations_json_ser_de(self):
         )

         schemas = [
+            standalone_string,
+            standalone_array,
+            standalone_map,
+            standalone_nested,
             simple_struct,
             nested_struct,
             array_in_schema,
```
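The assertion part of this test is not shown in the hunks above; presumably each entry in `schemas` is round-tripped through JSON, along the lines of this hypothetical sketch:

```python
# Hypothetical reconstruction of the test's assertion loop; the real body of
# test_schema_with_collations_json_ser_de is not part of the diff shown here.
for schema in schemas:
    assert _parse_datatype_json_string(schema.json()) == schema
```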

python/pyspark/sql/types.py (34 additions, 5 deletions)

```diff
@@ -296,11 +296,8 @@ def simpleString(self) -> str:

         return f"string collate {self.collation}"

-    # For backwards compatibility and compatibility with other readers all string types
-    # are serialized in json as regular strings and the collation info is written to
-    # struct field metadata
     def jsonValue(self) -> str:
-        return "string"
+        return self.simpleString()

     def __repr__(self) -> str:
         return (
@@ -1010,11 +1007,39 @@ def jsonValue(self) -> Dict[str, Any]:

         return {
             "name": self.name,
-            "type": self.dataType.jsonValue(),
+            "type": self._dataTypeJsonValue(collationMetadata),
             "nullable": self.nullable,
             "metadata": metadata,
         }

+    def _dataTypeJsonValue(self, collationMetadata: Dict[str, str]) -> Union[str, Dict[str, Any]]:
+        if not collationMetadata:
+            return self.dataType.jsonValue()
+
+        def removeCollations(dt: DataType) -> DataType:
+            # Only recurse into map and array types as any child struct type
+            # will have already been processed.
+            if isinstance(dt, ArrayType):
+                return ArrayType(removeCollations(dt.elementType), dt.containsNull)
+            elif isinstance(dt, MapType):
+                return MapType(
+                    removeCollations(dt.keyType),
+                    removeCollations(dt.valueType),
+                    dt.valueContainsNull,
+                )
+            elif isinstance(dt, StringType):
+                return StringType()
+            elif isinstance(dt, VarcharType):
+                return VarcharType(dt.length)
+            elif isinstance(dt, CharType):
+                return CharType(dt.length)
+            else:
+                return dt
+
+        # As we want to be backwards compatible we should remove all collations information from the
+        # json and only keep that information in the metadata.
+        return removeCollations(self.dataType).jsonValue()
+
     @classmethod
     def fromJson(cls, json: Dict[str, Any]) -> "StructField":
         metadata = json.get("metadata")
@@ -1843,6 +1868,7 @@ def parseJson(cls, json_str: str) -> "VariantVal":

 _LENGTH_CHAR = re.compile(r"char\(\s*(\d+)\s*\)")
 _LENGTH_VARCHAR = re.compile(r"varchar\(\s*(\d+)\s*\)")
+_STRING_WITH_COLLATION = re.compile(r"string\s+collate\s+(\w+)")
 _FIXED_DECIMAL = re.compile(r"decimal\(\s*(\d+)\s*,\s*(-?\d+)\s*\)")
 _INTERVAL_DAYTIME = re.compile(r"interval (day|hour|minute|second)( to (day|hour|minute|second))?")
 _INTERVAL_YEARMONTH = re.compile(r"interval (year|month)( to (year|month))?")
@@ -2003,6 +2029,9 @@ def _parse_datatype_json_value(
         if first_field is not None and second_field is None:
             return YearMonthIntervalType(first_field)
         return YearMonthIntervalType(first_field, second_field)
+    elif _STRING_WITH_COLLATION.match(json_value):
+        m = _STRING_WITH_COLLATION.match(json_value)
+        return StringType(m.group(1))  # type: ignore[union-attr]
     elif _LENGTH_CHAR.match(json_value):
         m = _LENGTH_CHAR.match(json_value)
         return CharType(int(m.group(1)))  # type: ignore[union-attr]
```
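A short sketch of the two serialization paths this file now implements (illustrative only; the `__COLLATIONS` metadata key is the constant defined in `DataType.scala` below, and the exact metadata layout is Spark-internal):

```python
import json

from pyspark.sql.types import StringType, StructField, StructType

# Standalone: the collation now survives in the JSON value itself.
print(StringType("UNICODE").json())  # expected: "string collate UNICODE"

# Inside a StructField: the type stays the plain "string" older readers expect,
# and the collation is kept in the field metadata instead.
schema = StructType([StructField("c1", StringType("UNICODE"))])
field = json.loads(schema.json())["fields"][0]
print(field["type"])      # expected: string
print(field["metadata"])  # expected to carry the collation, e.g. under "__COLLATIONS"
```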

sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala (2 additions, 0 deletions)

```diff
@@ -126,6 +126,7 @@ object DataType {
   private val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\-?\d+)\s*\)""".r
   private val CHAR_TYPE = """char\(\s*(\d+)\s*\)""".r
   private val VARCHAR_TYPE = """varchar\(\s*(\d+)\s*\)""".r
+  private val STRING_WITH_COLLATION = """string\s+collate\s+(\w+)""".r

   val COLLATIONS_METADATA_KEY = "__COLLATIONS"

@@ -214,6 +215,7 @@ object DataType {
     case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt)
     case CHAR_TYPE(length) => CharType(length.toInt)
     case VARCHAR_TYPE(length) => VarcharType(length.toInt)
+    case STRING_WITH_COLLATION(collation) => StringType(collation)
     // For backwards compatibility, previously the type name of NullType is "null"
     case "null" => NullType
     case "timestamp_ltz" => TimestampType
```

sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala (0 additions, 7 deletions)

```diff
@@ -17,8 +17,6 @@

 package org.apache.spark.sql.types

-import org.json4s.JsonAST.{JString, JValue}
-
 import org.apache.spark.annotation.Stable
 import org.apache.spark.sql.catalyst.util.CollationFactory
 import org.apache.spark.sql.internal.SqlApiConf
@@ -90,11 +88,6 @@ class StringType private[sql] (
   private[sql] def collationName: String =
     CollationFactory.fetchCollation(collationId).collationName

-  // Due to backwards compatibility and compatibility with other readers
-  // all string types are serialized in json as regular strings and
-  // the collation information is written to struct field metadata
-  override def jsonValue: JValue = JString("string")
-
   override def equals(obj: Any): Boolean = {
     obj match {
       case s: StringType => s.collationId == collationId && s.constraint == constraint
```

sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala (20 additions, 1 deletion)

```diff
@@ -70,11 +70,30 @@ case class StructField(

   private[sql] def jsonValue: JValue = {
     ("name" -> name) ~
-      ("type" -> dataType.jsonValue) ~
+      ("type" -> dataTypeJsonValue) ~
       ("nullable" -> nullable) ~
       ("metadata" -> metadataJson)
   }

+  private[sql] def dataTypeJsonValue: JValue = {
+    if (collationMetadata.isEmpty) return dataType.jsonValue
+
+    def removeCollations(dt: DataType): DataType = dt match {
+      // Only recurse into map and array types as any child struct type
+      // will have already been processed.
+      case ArrayType(et, nullable) =>
+        ArrayType(removeCollations(et), nullable)
+      case MapType(kt, vt, nullable) =>
+        MapType(removeCollations(kt), removeCollations(vt), nullable)
+      case st: StringType => StringHelper.removeCollation(st)
+      case _ => dt
+    }
+
+    // As we want to be backwards compatible we should remove all collations information from the
+    // json and only keep that information in the metadata.
+    removeCollations(dataType).jsonValue
+  }
+
   private def metadataJson: JValue = {
     val metadataJsonValue = metadata.jsonValue
     metadataJsonValue match {
```
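The comment about recursing only into maps and arrays is worth unpacking: a nested struct serializes through its own `StructField`s, each of which strips its own collations, so the outer field never needs to descend into it. A hedged Python sketch of that effect:

```python
import json

from pyspark.sql.types import ArrayType, StringType, StructField, StructType

inner = StructType([StructField("s", StringType("UNICODE"))])
outer = StructType([StructField("arr", ArrayType(inner))])

# The outer field only recurses through the ArrayType wrapper; the inner
# StructField strips its own collation when it serializes itself.
inner_field = json.loads(outer.json())["fields"][0]["type"]["elementType"]["fields"][0]
print(inner_field["type"])  # expected: string
```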

sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala (13 additions, 0 deletions)

```diff
@@ -30,6 +30,7 @@ import org.apache.spark.sql.types.DataTypeTestUtils.{dayTimeIntervalTypes, yearM
 class DataTypeSuite extends SparkFunSuite {

   private val UNICODE_COLLATION_ID = CollationFactory.collationNameToId("UNICODE")
+  private val UTF8_LCASE_COLLATION_ID = CollationFactory.collationNameToId("UTF8_LCASE")

   test("construct an ArrayType") {
     val array = ArrayType(StringType)
@@ -1143,6 +1144,17 @@ class DataTypeSuite extends SparkFunSuite {
   }

   test("schema with collation should not change during ser/de") {
+    val standaloneString = StringType(UNICODE_COLLATION_ID)
+
+    val standaloneArray = ArrayType(StringType(UNICODE_COLLATION_ID))
+
+    val standaloneMap = MapType(StringType(UNICODE_COLLATION_ID),
+      StringType(UTF8_LCASE_COLLATION_ID))
+
+    val standaloneNested = ArrayType(MapType(
+      StringType(UNICODE_COLLATION_ID),
+      ArrayType(StringType(UTF8_LCASE_COLLATION_ID))))
+
     val simpleStruct = StructType(
       StructField("c1", StringType(UNICODE_COLLATION_ID)) :: Nil)

@@ -1183,6 +1195,7 @@ class DataTypeSuite extends SparkFunSuite {
       mapWithKeyInNameInSchema ++ arrayInMapInNestedSchema.fields ++ nestedArrayInMap.fields)

     Seq(
+      standaloneString, standaloneArray, standaloneMap, standaloneNested,
       simpleStruct, caseInsensitiveNames, specialCharsInName, nestedStruct, arrayInSchema,
       mapInSchema, mapWithKeyInNameInSchema, nestedArrayInMap, arrayInMapInNestedSchema,
       schemaWithMultipleFields)
```