Skip to content

Commit 19ea6ff

Browse files
stefankandiccloud-fan
authored andcommitted
[SPARK-53130][SQL][PYTHON] Fix toJson behavior of collated string types
### What changes were proposed in this pull request? Changing the behavior of collated string types to return their collation in the `toJson` methods and to still keep backwards compatibility with older engine versions reading tables with collations by propagating this fix upstream in `StructField` where the collation will be removed from the type but still kept in the metadata. ### Why are the changes needed? Old way of handling `toJson` meant that collated string types will not be able to be serialized and deserialized correctly unless they are a part of `StructField`. Initially, we thought that this is not a big deal, but then later we faced some issues regarding this, especially in pyspark which uses json primarily to parse types back and forth. This could avoid hacky changes in future like the one in #51688 without changing any behavior for how tables/schemas work. ### Does this PR introduce _any_ user-facing change? Technically yes, but it is a small change that should not impact any queries, just how StringType is represented when not in a StructField object. ### How was this patch tested? New and existing unit tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #51850 from stefankandic/fixStringJson. Authored-by: Stefan Kandic <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent f6cd385 commit 19ea6ff

File tree

6 files changed

+84
-13
lines changed

6 files changed

+84
-13
lines changed

python/pyspark/sql/tests/test_types.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,17 @@ def test_schema_with_collations_json_ser_de(self):
647647
from pyspark.sql.types import _parse_datatype_json_string
648648

649649
unicode_collation = "UNICODE"
650+
utf8_lcase_collation = "UTF8_LCASE"
651+
652+
standalone_string = StringType(unicode_collation)
653+
654+
standalone_array = ArrayType(StringType(unicode_collation))
655+
656+
standalone_map = MapType(StringType(utf8_lcase_collation), StringType(unicode_collation))
657+
658+
standalone_nested = ArrayType(
659+
MapType(StringType(utf8_lcase_collation), ArrayType(StringType(unicode_collation)))
660+
)
650661

651662
simple_struct = StructType([StructField("c1", StringType(unicode_collation))])
652663

@@ -718,6 +729,10 @@ def test_schema_with_collations_json_ser_de(self):
718729
)
719730

720731
schemas = [
732+
standalone_string,
733+
standalone_array,
734+
standalone_map,
735+
standalone_nested,
721736
simple_struct,
722737
nested_struct,
723738
array_in_schema,

python/pyspark/sql/types.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -298,11 +298,8 @@ def simpleString(self) -> str:
298298

299299
return f"string collate {self.collation}"
300300

301-
# For backwards compatibility and compatibility with other readers all string types
302-
# are serialized in json as regular strings and the collation info is written to
303-
# struct field metadata
304301
def jsonValue(self) -> str:
305-
return "string"
302+
return self.simpleString()
306303

307304
def __repr__(self) -> str:
308305
return (
@@ -1058,11 +1055,39 @@ def jsonValue(self) -> Dict[str, Any]:
10581055

10591056
return {
10601057
"name": self.name,
1061-
"type": self.dataType.jsonValue(),
1058+
"type": self._dataTypeJsonValue(collationMetadata),
10621059
"nullable": self.nullable,
10631060
"metadata": metadata,
10641061
}
10651062

1063+
def _dataTypeJsonValue(self, collationMetadata: Dict[str, str]) -> Union[str, Dict[str, Any]]:
1064+
if not collationMetadata:
1065+
return self.dataType.jsonValue()
1066+
1067+
def removeCollations(dt: DataType) -> DataType:
1068+
# Only recurse into map and array types as any child struct type
1069+
# will have already been processed.
1070+
if isinstance(dt, ArrayType):
1071+
return ArrayType(removeCollations(dt.elementType), dt.containsNull)
1072+
elif isinstance(dt, MapType):
1073+
return MapType(
1074+
removeCollations(dt.keyType),
1075+
removeCollations(dt.valueType),
1076+
dt.valueContainsNull,
1077+
)
1078+
elif isinstance(dt, StringType):
1079+
return StringType()
1080+
elif isinstance(dt, VarcharType):
1081+
return VarcharType(dt.length)
1082+
elif isinstance(dt, CharType):
1083+
return CharType(dt.length)
1084+
else:
1085+
return dt
1086+
1087+
# As we want to be backwards compatible we should remove all collations information from the
1088+
# json and only keep that information in the metadata.
1089+
return removeCollations(self.dataType).jsonValue()
1090+
10661091
@classmethod
10671092
def fromJson(cls, json: Dict[str, Any]) -> "StructField":
10681093
metadata = json.get("metadata")
@@ -1891,6 +1916,7 @@ def parseJson(cls, json_str: str) -> "VariantVal":
18911916

18921917
_LENGTH_CHAR = re.compile(r"char\(\s*(\d+)\s*\)")
18931918
_LENGTH_VARCHAR = re.compile(r"varchar\(\s*(\d+)\s*\)")
1919+
_STRING_WITH_COLLATION = re.compile(r"string\s+collate\s+(\w+)")
18941920
_FIXED_DECIMAL = re.compile(r"decimal\(\s*(\d+)\s*,\s*(-?\d+)\s*\)")
18951921
_INTERVAL_DAYTIME = re.compile(r"interval (day|hour|minute|second)( to (day|hour|minute|second))?")
18961922
_INTERVAL_YEARMONTH = re.compile(r"interval (year|month)( to (year|month))?")
@@ -2055,6 +2081,9 @@ def _parse_datatype_json_value(
20552081
if first_field is not None and second_field is None:
20562082
return YearMonthIntervalType(first_field)
20572083
return YearMonthIntervalType(first_field, second_field)
2084+
elif _STRING_WITH_COLLATION.match(json_value):
2085+
m = _STRING_WITH_COLLATION.match(json_value)
2086+
return StringType(m.group(1)) # type: ignore[union-attr]
20582087
elif _LENGTH_CHAR.match(json_value):
20592088
m = _LENGTH_CHAR.match(json_value)
20602089
return CharType(int(m.group(1))) # type: ignore[union-attr]

sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ object DataType {
126126
private val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\-?\d+)\s*\)""".r
127127
private val CHAR_TYPE = """char\(\s*(\d+)\s*\)""".r
128128
private val VARCHAR_TYPE = """varchar\(\s*(\d+)\s*\)""".r
129+
private val STRING_WITH_COLLATION = """string\s+collate\s+(\w+)""".r
129130

130131
val COLLATIONS_METADATA_KEY = "__COLLATIONS"
131132

@@ -215,6 +216,7 @@ object DataType {
215216
case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt)
216217
case CHAR_TYPE(length) => CharType(length.toInt)
217218
case VARCHAR_TYPE(length) => VarcharType(length.toInt)
219+
case STRING_WITH_COLLATION(collation) => StringType(collation)
218220
// For backwards compatibility, previously the type name of NullType is "null"
219221
case "null" => NullType
220222
case "timestamp_ltz" => TimestampType

sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717

1818
package org.apache.spark.sql.types
1919

20-
import org.json4s.JsonAST.{JString, JValue}
21-
2220
import org.apache.spark.annotation.Stable
2321
import org.apache.spark.sql.catalyst.util.CollationFactory
2422
import org.apache.spark.sql.internal.SqlApiConf
@@ -90,11 +88,6 @@ class StringType private[sql] (
9088
private[sql] def collationName: String =
9189
CollationFactory.fetchCollation(collationId).collationName
9290

93-
// Due to backwards compatibility and compatibility with other readers
94-
// all string types are serialized in json as regular strings and
95-
// the collation information is written to struct field metadata
96-
override def jsonValue: JValue = JString("string")
97-
9891
override def equals(obj: Any): Boolean = {
9992
obj match {
10093
case s: StringType => s.collationId == collationId && s.constraint == constraint

sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,30 @@ case class StructField(
7070

7171
private[sql] def jsonValue: JValue = {
7272
("name" -> name) ~
73-
("type" -> dataType.jsonValue) ~
73+
("type" -> dataTypeJsonValue) ~
7474
("nullable" -> nullable) ~
7575
("metadata" -> metadataJson)
7676
}
7777

78+
private[sql] def dataTypeJsonValue: JValue = {
79+
if (collationMetadata.isEmpty) return dataType.jsonValue
80+
81+
def removeCollations(dt: DataType): DataType = dt match {
82+
// Only recurse into map and array types as any child struct type
83+
// will have already been processed.
84+
case ArrayType(et, nullable) =>
85+
ArrayType(removeCollations(et), nullable)
86+
case MapType(kt, vt, nullable) =>
87+
MapType(removeCollations(kt), removeCollations(vt), nullable)
88+
case st: StringType => StringHelper.removeCollation(st)
89+
case _ => dt
90+
}
91+
92+
// As we want to be backwards compatible we should remove all collations information from the
93+
// json and only keep that information in the metadata.
94+
removeCollations(dataType).jsonValue
95+
}
96+
7897
private def metadataJson: JValue = {
7998
val metadataJsonValue = metadata.jsonValue
8099
metadataJsonValue match {

sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import org.apache.spark.sql.types.DataTypeTestUtils.{dayTimeIntervalTypes, yearM
3030
class DataTypeSuite extends SparkFunSuite {
3131

3232
private val UNICODE_COLLATION_ID = CollationFactory.collationNameToId("UNICODE")
33+
private val UTF8_LCASE_COLLATION_ID = CollationFactory.collationNameToId("UTF8_LCASE")
3334

3435
test("construct an ArrayType") {
3536
val array = ArrayType(StringType)
@@ -1145,6 +1146,17 @@ class DataTypeSuite extends SparkFunSuite {
11451146
}
11461147

11471148
test("schema with collation should not change during ser/de") {
1149+
val standaloneString = StringType(UNICODE_COLLATION_ID)
1150+
1151+
val standaloneArray = ArrayType(StringType(UNICODE_COLLATION_ID))
1152+
1153+
val standaloneMap = MapType(StringType(UNICODE_COLLATION_ID),
1154+
StringType(UTF8_LCASE_COLLATION_ID))
1155+
1156+
val standaloneNested = ArrayType(MapType(
1157+
StringType(UNICODE_COLLATION_ID),
1158+
ArrayType(StringType(UTF8_LCASE_COLLATION_ID))))
1159+
11481160
val simpleStruct = StructType(
11491161
StructField("c1", StringType(UNICODE_COLLATION_ID)) :: Nil)
11501162

@@ -1185,6 +1197,7 @@ class DataTypeSuite extends SparkFunSuite {
11851197
mapWithKeyInNameInSchema ++ arrayInMapInNestedSchema.fields ++ nestedArrayInMap.fields)
11861198

11871199
Seq(
1200+
standaloneString, standaloneArray, standaloneMap, standaloneNested,
11881201
simpleStruct, caseInsensitiveNames, specialCharsInName, nestedStruct, arrayInSchema,
11891202
mapInSchema, mapWithKeyInNameInSchema, nestedArrayInMap, arrayInMapInNestedSchema,
11901203
schemaWithMultipleFields)

0 commit comments

Comments
 (0)