diff --git a/docs/sql-ref-datatypes.md b/docs/sql-ref-datatypes.md index 25f847508ce01..fd3e0b7a3164f 100644 --- a/docs/sql-ref-datatypes.md +++ b/docs/sql-ref-datatypes.md @@ -131,7 +131,7 @@ from pyspark.sql.types import * |**StringType**|str|StringType()| |**CharType(length)**|str|CharType(length)| |**VarcharType(length)**|str|VarcharType(length)| -|**BinaryType**|bytearray|BinaryType()| +|**BinaryType**|bytes|BinaryType()| |**BooleanType**|bool|BooleanType()| |**TimestampType**|datetime.datetime|TimestampType()| |**TimestampNTZType**|datetime.datetime|TimestampNTZType()| diff --git a/python/docs/source/tutorial/sql/type_conversions.rst b/python/docs/source/tutorial/sql/type_conversions.rst index 2f13701995ef2..8fea54cd3eae9 100644 --- a/python/docs/source/tutorial/sql/type_conversions.rst +++ b/python/docs/source/tutorial/sql/type_conversions.rst @@ -105,7 +105,7 @@ All Conversions - string - StringType() * - **BinaryType** - - bytearray + - bytes - BinaryType() * - **BooleanType** - bool diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 821e142304aa9..66415e3ba4556 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -545,6 +545,44 @@ def __repr__(self): return "CompressedSerializer(%s)" % self.serializer +class BinaryConvertingSerializer(Serializer): + """ + Converts bytearray to bytes for binary data when binary_as_bytes is enabled + """ + + def __init__(self, serializer, binary_as_bytes=False): + self.serializer = serializer + self.binary_as_bytes = binary_as_bytes + + def _convert_binary(self, obj): + """Recursively convert bytearray to bytes in data structures""" + if not self.binary_as_bytes: + return obj + + if isinstance(obj, bytearray): + return bytes(obj) + elif isinstance(obj, (list, tuple)): + converted = [self._convert_binary(item) for item in obj] + return type(obj)(converted) + elif isinstance(obj, dict): + return {key: self._convert_binary(value) for key, value in obj.items()} + else: + return obj + + def dump_stream(self, iterator, stream): + self.serializer.dump_stream(iterator, stream) + + def load_stream(self, stream): + for obj in self.serializer.load_stream(stream): + yield self._convert_binary(obj) + + def __repr__(self): + return "BinaryConvertingSerializer(%s, binary_as_bytes=%s)" % ( + str(self.serializer), + self.binary_as_bytes, + ) + + class UTF8Deserializer(Serializer): """ diff --git a/python/pyspark/sql/avro/functions.py b/python/pyspark/sql/avro/functions.py index 5c4be6570cd6a..e1b70c5fd3597 100644 --- a/python/pyspark/sql/avro/functions.py +++ b/python/pyspark/sql/avro/functions.py @@ -69,7 +69,7 @@ def from_avro( >>> df = spark.createDataFrame(data, ("key", "value")) >>> avroDf = df.select(to_avro(df.value).alias("avro")) >>> avroDf.collect() - [Row(avro=bytearray(b'\\x00\\x00\\x04\\x00\\nAlice'))] + [Row(avro=b'\\x00\\x00\\x04\\x00\\nAlice')] >>> jsonFormatSchema = '''{"type":"record","name":"topLevelRecord","fields": ... [{"name":"avro","type":[{"type":"record","name":"value","namespace":"topLevelRecord", @@ -141,12 +141,12 @@ def to_avro(data: "ColumnOrName", jsonFormatSchema: str = "") -> Column: >>> data = ['SPADES'] >>> df = spark.createDataFrame(data, "string") >>> df.select(to_avro(df.value).alias("suite")).collect() - [Row(suite=bytearray(b'\\x00\\x0cSPADES'))] + [Row(suite=b'\\x00\\x0cSPADES')] >>> jsonFormatSchema = '''["null", {"type": "enum", "name": "value", ... 
"symbols": ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"]}]''' >>> df.select(to_avro(df.value, jsonFormatSchema).alias("suite")).collect() - [Row(suite=bytearray(b'\\x02\\x00'))] + [Row(suite=b'\\x02\\x00')] """ from py4j.java_gateway import JVMView from pyspark.sql.classic.column import _to_java_column diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py index f1aa55c2039ac..d4f43fff62e69 100644 --- a/python/pyspark/sql/conversion.py +++ b/python/pyspark/sql/conversion.py @@ -46,6 +46,21 @@ import pyarrow as pa +def _should_use_bytes_for_binary(binary_as_bytes: Optional[bool] = None) -> bool: + """Check if BINARY type should be converted to bytes instead of bytearray.""" + if binary_as_bytes is not None: + return binary_as_bytes + + from pyspark.sql import SparkSession + + spark = SparkSession.getActiveSession() + if spark is not None: + v = spark.conf.get("spark.sql.execution.pyspark.binaryAsBytes.enabled", "true") + return str(v).lower() == "true" + + return True + + class LocalDataToArrowConversion: """ Conversion from local data (except pandas DataFrame and numpy ndarray) to Arrow. @@ -518,13 +533,16 @@ def _create_converter(dataType: DataType) -> Callable: @overload @staticmethod def _create_converter( - dataType: DataType, *, none_on_identity: bool = True + dataType: DataType, *, none_on_identity: bool = True, binary_as_bytes: Optional[bool] = None ) -> Optional[Callable]: pass @staticmethod def _create_converter( - dataType: DataType, *, none_on_identity: bool = False + dataType: DataType, + *, + none_on_identity: bool = False, + binary_as_bytes: Optional[bool] = None, ) -> Optional[Callable]: assert dataType is not None and isinstance(dataType, DataType) @@ -542,7 +560,9 @@ def _create_converter( dedup_field_names = _dedup_names(field_names) field_convs = [ - ArrowTableToRowsConversion._create_converter(f.dataType, none_on_identity=True) + ArrowTableToRowsConversion._create_converter( + f.dataType, none_on_identity=True, binary_as_bytes=binary_as_bytes + ) for f in dataType.fields ] @@ -564,7 +584,7 @@ def convert_struct(value: Any) -> Any: elif isinstance(dataType, ArrayType): element_conv = ArrowTableToRowsConversion._create_converter( - dataType.elementType, none_on_identity=True + dataType.elementType, none_on_identity=True, binary_as_bytes=binary_as_bytes ) if element_conv is None: @@ -589,10 +609,10 @@ def convert_array(value: Any) -> Any: elif isinstance(dataType, MapType): key_conv = ArrowTableToRowsConversion._create_converter( - dataType.keyType, none_on_identity=True + dataType.keyType, none_on_identity=True, binary_as_bytes=binary_as_bytes ) value_conv = ArrowTableToRowsConversion._create_converter( - dataType.valueType, none_on_identity=True + dataType.valueType, none_on_identity=True, binary_as_bytes=binary_as_bytes ) if key_conv is None: @@ -646,7 +666,10 @@ def convert_binary(value: Any) -> Any: return None else: assert isinstance(value, bytes) - return bytearray(value) + if _should_use_bytes_for_binary(binary_as_bytes): + return value + else: + return bytearray(value) return convert_binary @@ -676,7 +699,7 @@ def convert_timestample_ntz(value: Any) -> Any: udt: UserDefinedType = dataType conv = ArrowTableToRowsConversion._create_converter( - udt.sqlType(), none_on_identity=True + udt.sqlType(), none_on_identity=True, binary_as_bytes=binary_as_bytes ) if conv is None: @@ -722,20 +745,28 @@ def convert_variant(value: Any) -> Any: @overload @staticmethod def convert( # type: ignore[overload-overlap] - table: "pa.Table", schema: StructType + 
table: "pa.Table", schema: StructType, *, binary_as_bytes: Optional[bool] = None ) -> List[Row]: pass @overload @staticmethod def convert( - table: "pa.Table", schema: StructType, *, return_as_tuples: bool = True + table: "pa.Table", + schema: StructType, + *, + return_as_tuples: bool = True, + binary_as_bytes: Optional[bool] = None, ) -> List[tuple]: pass @staticmethod # type: ignore[misc] def convert( - table: "pa.Table", schema: StructType, *, return_as_tuples: bool = False + table: "pa.Table", + schema: StructType, + *, + return_as_tuples: bool = False, + binary_as_bytes: Optional[bool] = None, ) -> List[Union[Row, tuple]]: require_minimum_pyarrow_version() import pyarrow as pa @@ -748,7 +779,9 @@ def convert( if len(fields) > 0: field_converters = [ - ArrowTableToRowsConversion._create_converter(f.dataType, none_on_identity=True) + ArrowTableToRowsConversion._create_converter( + f.dataType, none_on_identity=True, binary_as_bytes=binary_as_bytes + ) for f in schema.fields ] diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index b09713e0c289e..79c6475a3d28d 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -16591,15 +16591,15 @@ def to_binary(col: "ColumnOrName", format: Optional["ColumnOrName"] = None) -> C >>> import pyspark.sql.functions as sf >>> df = spark.createDataFrame([("abc",)], ["e"]) - >>> df.select(sf.try_to_binary(df.e, sf.lit("utf-8")).alias('r')).collect() - [Row(r=bytearray(b'abc'))] + >>> df.select(sf.to_binary(df.e, sf.lit("utf-8")).alias('r')).collect() + [Row(r=b'abc')] Example 2: Convert string to a timestamp without encoding specified >>> import pyspark.sql.functions as sf >>> df = spark.createDataFrame([("414243",)], ["e"]) - >>> df.select(sf.try_to_binary(df.e).alias('r')).collect() - [Row(r=bytearray(b'ABC'))] + >>> df.select(sf.to_binary(df.e).alias('r')).collect() + [Row(r=b'ABC')] """ if format is not None: return _invoke_function_over_columns("to_binary", col, format) @@ -17615,14 +17615,14 @@ def try_to_binary(col: "ColumnOrName", format: Optional["ColumnOrName"] = None) >>> import pyspark.sql.functions as sf >>> df = spark.createDataFrame([("abc",)], ["e"]) >>> df.select(sf.try_to_binary(df.e, sf.lit("utf-8")).alias('r')).collect() - [Row(r=bytearray(b'abc'))] + [Row(r=b'abc')] Example 2: Convert string to a timestamp without encoding specified >>> import pyspark.sql.functions as sf >>> df = spark.createDataFrame([("414243",)], ["e"]) >>> df.select(sf.try_to_binary(df.e).alias('r')).collect() - [Row(r=bytearray(b'ABC'))] + [Row(r=b'ABC')] Example 3: Converion failure results in NULL when ANSI mode is on diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index d1bdfa9e8d01e..b975217be50e1 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -350,11 +350,12 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer): This has performance penalties. 
""" - def __init__(self, timezone, safecheck, int_to_decimal_coercion_enabled): + def __init__(self, timezone, safecheck, int_to_decimal_coercion_enabled, binary_as_bytes=None): super(ArrowStreamPandasSerializer, self).__init__() self._timezone = timezone self._safecheck = safecheck self._int_to_decimal_coercion_enabled = int_to_decimal_coercion_enabled + self._binary_as_bytes = binary_as_bytes self._converter_cache = {} @staticmethod @@ -583,9 +584,10 @@ def __init__( arrow_cast=False, input_types=None, int_to_decimal_coercion_enabled=False, + binary_as_bytes=None, ): super(ArrowStreamPandasUDFSerializer, self).__init__( - timezone, safecheck, int_to_decimal_coercion_enabled + timezone, safecheck, int_to_decimal_coercion_enabled, binary_as_bytes ) self._assign_cols_by_name = assign_cols_by_name self._df_for_struct = df_for_struct @@ -782,12 +784,14 @@ def __init__( safecheck, assign_cols_by_name, arrow_cast, + binary_as_bytes=None, ): super(ArrowStreamArrowUDFSerializer, self).__init__() self._timezone = timezone self._safecheck = safecheck self._assign_cols_by_name = assign_cols_by_name self._arrow_cast = arrow_cast + self._binary_as_bytes = binary_as_bytes def _create_array(self, arr, arrow_type, arrow_cast): import pyarrow as pa @@ -862,12 +866,14 @@ def __init__( safecheck, input_types, int_to_decimal_coercion_enabled=False, + binary_as_bytes=None, ): super().__init__( timezone=timezone, safecheck=safecheck, assign_cols_by_name=False, arrow_cast=True, + binary_as_bytes=binary_as_bytes, ) self._input_types = input_types self._int_to_decimal_coercion_enabled = int_to_decimal_coercion_enabled @@ -887,7 +893,9 @@ def load_stream(self, stream): List of columns containing list of Python values. """ converters = [ - ArrowTableToRowsConversion._create_converter(dt, none_on_identity=True) + ArrowTableToRowsConversion._create_converter( + dt, none_on_identity=True, binary_as_bytes=self._binary_as_bytes + ) for dt in self._input_types ] @@ -949,7 +957,14 @@ class ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer): Serializer used by Python worker to evaluate Arrow-optimized Python UDTFs. 
""" - def __init__(self, timezone, safecheck, input_types, int_to_decimal_coercion_enabled): + def __init__( + self, + timezone, + safecheck, + input_types, + int_to_decimal_coercion_enabled, + binary_as_bytes=None, + ): super(ArrowStreamPandasUDTFSerializer, self).__init__( timezone=timezone, safecheck=safecheck, @@ -972,6 +987,7 @@ def __init__(self, timezone, safecheck, input_types, int_to_decimal_coercion_ena input_types=input_types, # Enable additional coercions for UDTF serialization int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, + binary_as_bytes=binary_as_bytes, ) self._converter_map = dict() diff --git a/python/pyspark/sql/tests/arrow/test_arrow.py b/python/pyspark/sql/tests/arrow/test_arrow.py index 819639c63a2cb..3a7e6dc8c35f4 100644 --- a/python/pyspark/sql/tests/arrow/test_arrow.py +++ b/python/pyspark/sql/tests/arrow/test_arrow.py @@ -1781,6 +1781,47 @@ def test_createDataFrame_arrow_fixed_size_binary(self): df = self.spark.createDataFrame(t) self.assertIsInstance(df.schema["fsb"].dataType, BinaryType) + def test_binary_type_default_bytes_behavior(self): + """Test that binary values are returned as bytes by default""" + df = self.spark.createDataFrame([(bytearray(b"test"),)], ["binary_col"]) + collected = df.collect() + self.assertIsInstance(collected[0].binary_col, bytes) + self.assertEqual(collected[0].binary_col, b"test") + + def test_binary_type_config_enabled_bytes(self): + """Test that binary values are returned as bytes when config is enabled""" + with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes.enabled": "true"}): + df = self.spark.createDataFrame([(bytearray(b"test"),)], ["binary_col"]) + collected = df.collect() + self.assertIsInstance(collected[0].binary_col, bytes) + self.assertEqual(collected[0].binary_col, b"test") + + def test_binary_type_config_disabled_bytearray(self): + """Test that binary values are returned as bytearray when config is disabled""" + with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes.enabled": "false"}): + df = self.spark.createDataFrame([(bytearray(b"test"),)], ["binary_col"]) + collected = df.collect() + self.assertIsInstance(collected[0].binary_col, bytearray) + self.assertEqual(collected[0].binary_col, bytearray(b"test")) + + def test_binary_type_to_local_iterator_bytes_mode(self): + """Test binary type with toLocalIterator when bytes mode is enabled""" + with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes.enabled": "true"}): + df = self.spark.createDataFrame([(b"test1",), (b"test2",)], ["binary_col"]) + local_iter = df.toLocalIterator() + rows = list(local_iter) + for row in rows: + self.assertIsInstance(row.binary_col, bytes) + + def test_binary_type_to_local_iterator_bytearray_mode(self): + """Test binary type with toLocalIterator when bytearray mode is enabled""" + with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes.enabled": "false"}): + df = self.spark.createDataFrame([(b"test1",), (b"test2",)], ["binary_col"]) + local_iter = df.toLocalIterator() + rows = list(local_iter) + for row in rows: + self.assertIsInstance(row.binary_col, bytearray) + def test_createDataFrame_arrow_fixed_size_list(self): a = pa.array([[-1, 3]] * 5, type=pa.list_(pa.int32(), 2)) t = pa.table([a], ["fsl"]) diff --git a/python/pyspark/sql/tests/test_serde.py b/python/pyspark/sql/tests/test_serde.py index eab1ad043ef33..63f494751bb5b 100644 --- a/python/pyspark/sql/tests/test_serde.py +++ b/python/pyspark/sql/tests/test_serde.py @@ -128,9 +128,9 @@ def test_BinaryType_serialization(self): # The 
empty bytearray is test for SPARK-21534. schema = StructType([StructField("mybytes", BinaryType())]) data = [ - [bytearray(b"here is my data")], - [bytearray(b"and here is some more")], - [bytearray(b"")], + [bytes(b"here is my data")], + [bytes(b"and here is some more")], + [bytes(b"")], ] df = self.spark.createDataFrame(data, schema=schema) df.collect() @@ -143,7 +143,7 @@ def test_int_array_serialization(self): def test_bytes_as_binary_type(self): df = self.spark.createDataFrame([[b"abcd"]], "col binary") - self.assertEqual(df.first().col, bytearray(b"abcd")) + self.assertEqual(df.first().col, bytes(b"abcd")) class SerdeTests(SerdeTestsMixin, ReusedSQLTestCase): diff --git a/python/pyspark/sql/tests/udf_type_tests/golden_pandas_udf_input_types.txt b/python/pyspark/sql/tests/udf_type_tests/golden_pandas_udf_input_types.txt index d21e7f2eb24a1..2b87384586237 100644 --- a/python/pyspark/sql/tests/udf_type_tests/golden_pandas_udf_input_types.txt +++ b/python/pyspark/sql/tests/udf_type_tests/golden_pandas_udf_input_types.txt @@ -17,8 +17,8 @@ |decimal_null |decimal(3,2) |[None, Decimal('9.99')] |['object', 'object'] |[None, Decimal('9.99')] | |string_values |string |['abc', '', 'hello'] |['object', 'object', 'object'] |['abc', '', 'hello'] | |string_null |string |[None, 'test'] |['object', 'object'] |[None, 'test'] | -|binary_values |binary |[bytearray(b'abc'), bytearray(b''), bytearray(b'ABC')] |['object', 'object', 'object'] |[bytearray(b'abc'), bytearray(b''), bytearray(b'ABC')] | -|binary_null |binary |[None, bytearray(b'test')] |['object', 'object'] |[None, bytearray(b'test')] | +|binary_values |binary |[b'abc', b'', b'ABC'] |['object', 'object', 'object'] |[b'abc', b'', b'ABC'] | +|binary_null |binary |[None, b'test'] |['object', 'object'] |[None, b'test'] | |boolean_values |boolean |[True, False] |['bool', 'bool'] |[True, False] | |boolean_null |boolean |[None, True] |['object', 'object'] |[None, True] | |date_values |date |[datetime.date(2020, 2, 2), datetime.date(1970, 1, 1)] |['object', 'object'] |[datetime.date(2020, 2, 2), datetime.date(1970, 1, 1)] | diff --git a/python/pyspark/sql/tests/udf_type_tests/golden_pandas_udf_return_type_coercion.txt b/python/pyspark/sql/tests/udf_type_tests/golden_pandas_udf_return_type_coercion.txt index 7719d1805d9e9..ad9af8add2e02 100644 --- a/python/pyspark/sql/tests/udf_type_tests/golden_pandas_udf_return_type_coercion.txt +++ b/python/pyspark/sql/tests/udf_type_tests/golden_pandas_udf_return_type_coercion.txt @@ -12,7 +12,7 @@ |float |[None, None] |[1.0, 0.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |X |X |X |X |[12.0, 34.0] |X |[1.0, 2.0] |X |X |X |X |X |X | |double |[None, None] |[1.0, 0.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |[1.0, 2.0] |X |X |X |X |[12.0, 34.0] |X |[1.0, 2.0] |X |X |X |X |X |X | |array |[None, None] |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |[[1, 2], [3, 4]] |[[1, 2, 3], [1, 2, 3]] |X |X |X |X |X |X |X | -|binary |[None, None] |[bytearray(b'\x01'), bytearray |[bytearray(b'\x01'), bytearray |[bytearray(b'\x01'), bytearray |[bytearray(b'\x01'), bytearray |[bytearray(b'\x01'), bytearray |[bytearray(b'\x01'), bytearray |[bytearray(b'\x01'), bytearray |[bytearray(b'\x01'), bytearray |[bytearray(b'\x01'), bytearray |[bytearray(b''), bytearray(b'' |[bytearray(b''), bytearray(b'' |[bytearray(b''), bytearray(b'' |[bytearray(b''), bytearray(b'' 
|[bytearray(b''), bytearray(b'' |[bytearray(b''), bytearray(b'' |[bytearray(b'a'), bytearray(b' |[bytearray(b'12'), bytearray(b |X |X |[bytearray(b''), bytearray(b'' |[bytearray(b''), bytearray(b'' |[bytearray(b''), bytearray(b'' |[bytearray(b'A'), bytearray(b' |X |X | +|binary |[None, None] |[b'\x01', b'\x02'] |[b'\x01', b'\x02'] |[b'\x01', b'\x02'] |[b'\x01', b'\x02'] |[b'\x01', b'\x02'] |[b'\x01', b'\x02'] |[b'\x01', b'\x02'] |[b'\x01', b'\x02'] |[b'\x01', b'\x02'] |[b'', b''] |[b'', b''] |[b'', b''] |[b'', b''] |[b'', b''] |[b'', b''] |[b'a', b'b'] |[b'12', b'34'] |X |X |[b'', b''] |[b'', b''] |[b'', b''] |[b'A', b'B'] |X |X | |decimal(10,0) |[None, None] |X |[Decimal('1'), Decimal('2')] |[Decimal('1'), Decimal('2')] |[Decimal('1'), Decimal('2')] |X |[Decimal('1'), Decimal('2')] |[Decimal('1'), Decimal('2')] |[Decimal('1'), Decimal('2')] |X |X |[Decimal('1'), Decimal('2')] |[Decimal('1'), Decimal('2')] |X |X |X |X |X |X |[Decimal('1'), Decimal('2')] |X |X |X |X |X |X | |map |[None, None] |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |[{'a': 1}, {'b': 2}] | |struct<_1:int> |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |X |[Row(_1=1), Row(_1=2)] |X | diff --git a/python/pyspark/sql/tests/udf_type_tests/golden_udf_input_types_arrow_disabled.txt b/python/pyspark/sql/tests/udf_type_tests/golden_udf_input_types_arrow_disabled.txt index 2572d48dbec7c..a3727dfd5d6b7 100644 --- a/python/pyspark/sql/tests/udf_type_tests/golden_udf_input_types_arrow_disabled.txt +++ b/python/pyspark/sql/tests/udf_type_tests/golden_udf_input_types_arrow_disabled.txt @@ -17,8 +17,8 @@ |decimal_null |decimal(3,2) |[None, Decimal('9.99')] |['NoneType', 'Decimal'] |['None', '9.99'] | |string_values |string |['abc', '', 'hello'] |['str', 'str', 'str'] |['abc', '', 'hello'] | |string_null |string |[None, 'test'] |['NoneType', 'str'] |['None', 'test'] | -|binary_values |binary |[bytearray(b'abc'), bytearray(b''), bytearray(b'ABC')] |['bytearray', 'bytearray', 'bytearray'] |["bytearray(b'abc')", "bytearray(b'')", "bytearray(b'ABC')"] | -|binary_null |binary |[None, bytearray(b'test')] |['NoneType', 'bytearray'] |['None', "bytearray(b'test')"] | +|binary_values |binary |[b'abc', b'', b'ABC'] |['bytes', 'bytes', 'bytes'] |["b'abc'", "b''", "b'ABC'"] | +|binary_null |binary |[None, b'test'] |['NoneType', 'bytes'] |['None', "b'test'"] | |boolean_values |boolean |[True, False] |['bool', 'bool'] |['True', 'False'] | |boolean_null |boolean |[None, True] |['NoneType', 'bool'] |['None', 'True'] | |date_values |date |[datetime.date(2020, 2, 2), datetime.date(1970, 1, 1)] |['date', 'date'] |['2020-02-02', '1970-01-01'] | diff --git a/python/pyspark/sql/tests/udf_type_tests/golden_udf_input_types_arrow_enabled.txt b/python/pyspark/sql/tests/udf_type_tests/golden_udf_input_types_arrow_enabled.txt index 2572d48dbec7c..a3727dfd5d6b7 100644 --- a/python/pyspark/sql/tests/udf_type_tests/golden_udf_input_types_arrow_enabled.txt +++ b/python/pyspark/sql/tests/udf_type_tests/golden_udf_input_types_arrow_enabled.txt @@ -17,8 +17,8 @@ |decimal_null |decimal(3,2) |[None, Decimal('9.99')] |['NoneType', 'Decimal'] |['None', '9.99'] | |string_values |string |['abc', '', 'hello'] |['str', 'str', 'str'] |['abc', '', 'hello'] | |string_null |string |[None, 'test'] |['NoneType', 'str'] |['None', 'test'] | -|binary_values |binary |[bytearray(b'abc'), bytearray(b''), bytearray(b'ABC')] |['bytearray', 'bytearray', 'bytearray'] |["bytearray(b'abc')", "bytearray(b'')", "bytearray(b'ABC')"] | -|binary_null 
|binary |[None, bytearray(b'test')] |['NoneType', 'bytearray'] |['None', "bytearray(b'test')"] | +|binary_values |binary |[b'abc', b'', b'ABC'] |['bytes', 'bytes', 'bytes'] |["b'abc'", "b''", "b'ABC'"] | +|binary_null |binary |[None, b'test'] |['NoneType', 'bytes'] |['None', "b'test'"] | |boolean_values |boolean |[True, False] |['bool', 'bool'] |['True', 'False'] | |boolean_null |boolean |[None, True] |['NoneType', 'bool'] |['None', 'True'] | |date_values |date |[datetime.date(2020, 2, 2), datetime.date(1970, 1, 1)] |['date', 'date'] |['2020-02-02', '1970-01-01'] | diff --git a/python/pyspark/sql/tests/udf_type_tests/golden_udf_input_types_arrow_legacy_pandas.txt b/python/pyspark/sql/tests/udf_type_tests/golden_udf_input_types_arrow_legacy_pandas.txt index 92f8af100e743..576af5f12102d 100644 --- a/python/pyspark/sql/tests/udf_type_tests/golden_udf_input_types_arrow_legacy_pandas.txt +++ b/python/pyspark/sql/tests/udf_type_tests/golden_udf_input_types_arrow_legacy_pandas.txt @@ -17,8 +17,8 @@ |decimal_null |decimal(3,2) |[None, Decimal('9.99')] |['NoneType', 'Decimal'] |['None', '9.99'] | |string_values |string |['abc', '', 'hello'] |['str', 'str', 'str'] |['abc', '', 'hello'] | |string_null |string |[None, 'test'] |['NoneType', 'str'] |['None', 'test'] | -|binary_values |binary |[bytearray(b'abc'), bytearray(b''), bytearray(b'ABC')] |['bytes', 'bytes', 'bytes'] |["b'abc'", "b''", "b'ABC'"] | -|binary_null |binary |[None, bytearray(b'test')] |['NoneType', 'bytes'] |['None', "b'test'"] | +|binary_values |binary |[b'abc', b'', b'ABC'] |['bytes', 'bytes', 'bytes'] |["b'abc'", "b''", "b'ABC'"] | +|binary_null |binary |[None, b'test'] |['NoneType', 'bytes'] |['None', "b'test'"] | |boolean_values |boolean |[True, False] |['bool', 'bool'] |['True', 'False'] | |boolean_null |boolean |[None, True] |['NoneType', 'bool'] |['None', 'True'] | |date_values |date |[datetime.date(2020, 2, 2), datetime.date(1970, 1, 1)] |['date', 'date'] |['2020-02-02', '1970-01-01'] | diff --git a/python/pyspark/sql/tests/udf_type_tests/golden_udf_return_type_coercion_arrow_enabled.txt b/python/pyspark/sql/tests/udf_type_tests/golden_udf_return_type_coercion_arrow_enabled.txt index c117113369e56..0fb6301735668 100644 --- a/python/pyspark/sql/tests/udf_type_tests/golden_udf_return_type_coercion_arrow_enabled.txt +++ b/python/pyspark/sql/tests/udf_type_tests/golden_udf_return_type_coercion_arrow_enabled.txt @@ -1,18 +1,18 @@ +-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+ -|SQL Type \ Python Value(Type) |None(NoneType) |True(bool) |1(int) |a(str) |1970-01-01(date) |1970-01-01 00:00:00(datetime) |1.0(float) |array('i', [1])(array) |[1](list) |(1,)(tuple) |bytearray(b'ABC')(bytearray) |1(Decimal) |{'a': 1}(dict) |Row(kwargs=1)(Row) |Row(namedtuple=1)(Row) | +|SQL Type \ Python Value(Type) |None(NoneType) |True(bool) |1(int) |a(str) |1970-01-01(date) |1970-01-01 00:00:00(datetime) |1.0(float) |array('i', [1])(array) |[1](list) |(1,)(tuple) |b'ABC'(bytes) |1(Decimal) |{'a': 1}(dict) |Row(kwargs=1)(Row) |Row(namedtuple=1)(Row) | 
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+ |boolean |None |True |True |X |X |X |True |X |X |X |X |X |X |X |X | |tinyint |None |X |1 |X |X |X |1 |X |X |X |X |1 |X |X |X | |smallint |None |X |1 |X |X |X |1 |X |X |X |X |1 |X |X |X | |int |None |X |1 |X |0 |X |1 |X |X |X |X |1 |X |X |X | |bigint |None |X |1 |X |X |0 |1 |X |X |X |X |1 |X |X |X | -|string |None |'true' |'1' |'a' |'1970-01-01' |'1970-01-01 00:00:00' |'1.0' |"array('i', [1])" |'[1]' |'(1,)' |"bytearray(b'ABC')" |'1' |"{'a': 1}" |'Row(kwargs=1)' |'Row(namedtuple=1)' | +|string |None |'true' |'1' |'a' |'1970-01-01' |'1970-01-01 00:00:00' |'1.0' |"array('i', [1])" |'[1]' |'(1,)' |"b'ABC'" |'1' |"{'a': 1}" |'Row(kwargs=1)' |'Row(namedtuple=1)' | |date |None |X |datetime.date(1970, 1, 2) |X |datetime.date(1970, 1, 1) |datetime.date(1970, 1, 1) |datetime.date(1970, 1, 2) |X |X |X |X |datetime.date(1970, 1, 2) |X |X |X | |timestamp |None |X |X |X |X |datetime.datetime(1970, 1, 1, |X |X |X |X |X |X |X |X |X | |float |None |1.0 |1.0 |X |X |X |1.0 |X |X |X |X |1.0 |X |X |X | |double |None |1.0 |1.0 |X |X |X |1.0 |X |X |X |X |1.0 |X |X |X | |array |None |X |X |X |X |X |X |[1] |[1] |[1] |[65, 66, 67] |X |X |[1] |[1] | -|binary |None |X |X |X |X |X |X |X |X |X |bytearray(b'ABC') |X |X |X |X | +|binary |None |X |X |X |X |X |X |X |X |X |b'ABC' |X |X |X |X | |decimal(10,0) |None |X |X |X |X |X |X |X |X |X |X |Decimal('1') |X |X |X | |map |None |X |X |X |X |X |X |X |X |X |X |X |{'a': 1} |X |X | |struct<_1:int> |None |X |X |X |X |X |X |X |X |Row(_1=1) |X |X |Row(_1=None) |Row(_1=1) |Row(_1=1) | diff --git a/python/pyspark/sql/tests/udf_type_tests/golden_udf_return_type_coercion_arrow_legacy_pandas.txt b/python/pyspark/sql/tests/udf_type_tests/golden_udf_return_type_coercion_arrow_legacy_pandas.txt index a1809dfa9aab6..2cbbe1f18aa51 100644 --- a/python/pyspark/sql/tests/udf_type_tests/golden_udf_return_type_coercion_arrow_legacy_pandas.txt +++ b/python/pyspark/sql/tests/udf_type_tests/golden_udf_return_type_coercion_arrow_legacy_pandas.txt @@ -1,18 +1,18 @@ +-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+ -|SQL Type \ Python Value(Type) |None(NoneType) |True(bool) |1(int) |a(str) |1970-01-01(date) |1970-01-01 00:00:00(datetime) |1.0(float) |array('i', [1])(array) |[1](list) |(1,)(tuple) |bytearray(b'ABC')(bytearray) |1(Decimal) |{'a': 1}(dict) |Row(kwargs=1)(Row) |Row(namedtuple=1)(Row) | +|SQL Type \ Python Value(Type) |None(NoneType) |True(bool) |1(int) |a(str) |1970-01-01(date) |1970-01-01 00:00:00(datetime) |1.0(float) |array('i', [1])(array) |[1](list) |(1,)(tuple) |b'ABC'(bytes) |1(Decimal) |{'a': 1}(dict) |Row(kwargs=1)(Row) 
|Row(namedtuple=1)(Row) | +-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+-------------------------------+ |boolean |None |True |True |X |X |X |True |X |X |X |X |X |X |X |X | |tinyint |None |1 |1 |X |X |X |1 |X |X |X |X |1 |X |X |X | |smallint |None |1 |1 |X |X |X |1 |X |X |X |X |1 |X |X |X | |int |None |1 |1 |X |0 |X |1 |X |X |X |X |1 |X |X |X | |bigint |None |1 |1 |X |X |0 |1 |X |X |X |X |1 |X |X |X | -|string |None |'True' |'1' |'a' |'1970-01-01' |'1970-01-01 00:00:00' |'1.0' |"array('i', [1])" |'[1]' |'(1,)' |"bytearray(b'ABC')" |'1' |"{'a': 1}" |'Row(kwargs=1)' |'Row(namedtuple=1)' | +|string |None |'True' |'1' |'a' |'1970-01-01' |'1970-01-01 00:00:00' |'1.0' |"array('i', [1])" |'[1]' |'(1,)' |"b'ABC'" |'1' |"{'a': 1}" |'Row(kwargs=1)' |'Row(namedtuple=1)' | |date |None |X |X |X |datetime.date(1970, 1, 1) |datetime.date(1970, 1, 1) |X |X |X |X |X |datetime.date(1970, 1, 2) |X |X |X | |timestamp |None |X |datetime.datetime(1970, 1, 1, |X |X |datetime.datetime(1970, 1, 1, |X |X |X |X |X |datetime.datetime(1970, 1, 1, |X |X |X | |float |None |1.0 |1.0 |X |X |X |1.0 |X |X |X |X |1.0 |X |X |X | |double |None |1.0 |1.0 |X |X |X |1.0 |X |X |X |X |1.0 |X |X |X | |array |None |X |X |X |X |X |X |[1] |[1] |[1] |[65, 66, 67] |X |X |[1] |[1] | -|binary |None |bytearray(b'\x00') |bytearray(b'\x00') |X |X |X |X |bytearray(b'\x01\x00\x00\x00') |bytearray(b'\x01') |bytearray(b'\x01') |bytearray(b'ABC') |X |X |bytearray(b'\x01') |bytearray(b'\x01') | +|binary |None |b'\x00' |b'\x00' |X |X |X |X |b'\x01\x00\x00\x00' |b'\x01' |b'\x01' |b'ABC' |X |X |b'\x01' |b'\x01' | |decimal(10,0) |None |X |X |X |X |X |Decimal('1') |X |X |X |X |Decimal('1') |X |X |X | |map |None |X |X |X |X |X |X |X |X |X |X |X |{'a': 1} |X |X | |struct<_1:int> |None |X |X |X |X |X |X |Row(_1=1) |Row(_1=1) |Row(_1=1) |Row(_1=65) |X |Row(_1=None) |Row(_1=1) |Row(_1=1) | diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index db162e8b1c521..13d773b7df714 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -357,7 +357,22 @@ def __repr__(self) -> str: class BinaryType(AtomicType, metaclass=DataTypeSingleton): """Binary (byte array) data type.""" - pass + def needConversion(self) -> bool: + return True + + def fromInternal(self, obj: Any) -> Any: + if obj is None: + return None + from pyspark.sql.conversion import _should_use_bytes_for_binary + + if _should_use_bytes_for_binary(): + if isinstance(obj, bytearray): + return bytes(obj) + return obj + else: + if isinstance(obj, bytes): + return bytearray(obj) + return obj class BooleanType(AtomicType, metaclass=DataTypeSingleton): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index e73f16464aaa1..4c258f0ed5498 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -47,6 +47,7 @@ SpecialLengths, CPickleSerializer, BatchedSerializer, + BinaryConvertingSerializer, ) from pyspark.sql.conversion import LocalDataToArrowConversion, ArrowTableToRowsConversion from pyspark.sql.functions import SkipRestOfInputTableException @@ -113,7 +114,7 @@ def chain(f, g): return lambda *a: g(f(*a)) -def 
wrap_udf(f, args_offsets, kwargs_offsets, return_type): +def wrap_udf(f, args_offsets, kwargs_offsets, return_type, runner_conf=None): func, args_kwargs_offsets = wrap_kwargs_support(f, args_offsets, kwargs_offsets) if return_type.needConversion(): @@ -1248,7 +1249,7 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, profil func, args_offsets, kwargs_offsets, return_type, runner_conf, udf_index ) elif eval_type == PythonEvalType.SQL_BATCHED_UDF: - return wrap_udf(func, args_offsets, kwargs_offsets, return_type) + return wrap_udf(func, args_offsets, kwargs_offsets, return_type, runner_conf) else: raise ValueError("Unknown eval type: {}".format(eval_type)) @@ -1271,6 +1272,13 @@ def use_large_var_types(runner_conf): return runner_conf.get("spark.sql.execution.arrow.useLargeVarTypes", "false").lower() == "true" +def use_bytes_for_binary(runner_conf): + return ( + runner_conf.get("spark.sql.execution.pyspark.binaryAsBytes.enabled", "true").lower() + == "true" + ) + + def use_legacy_pandas_udf_conversion(runner_conf): return ( runner_conf.get( @@ -1297,6 +1305,7 @@ def read_udtf(pickleSer, infile, eval_type): v = utf8_deserializer.loads(infile) runner_conf[k] = v prefers_large_var_types = use_large_var_types(runner_conf) + binary_as_bytes = use_bytes_for_binary(runner_conf) legacy_pandas_conversion = ( runner_conf.get( "spark.sql.legacy.execution.pythonUDTF.pandas.conversion.enabled", "false" @@ -1326,6 +1335,7 @@ def read_udtf(pickleSer, infile, eval_type): safecheck, input_types=input_types, int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, + binary_as_bytes=binary_as_bytes, ) else: ser = ArrowStreamUDTFSerializer() @@ -1337,6 +1347,7 @@ def read_udtf(pickleSer, infile, eval_type): v = utf8_deserializer.loads(infile) runner_conf[k] = v prefers_large_var_types = use_large_var_types(runner_conf) + binary_as_bytes = use_bytes_for_binary(runner_conf) # Read the table argument offsets num_table_arg_offsets = read_int(infile) table_arg_offsets = [read_int(infile) for _ in range(num_table_arg_offsets)] @@ -2184,7 +2195,7 @@ def read_udfs(pickleSer, infile, eval_type): PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF, PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF, ): - # Load conf used for pandas_udf evaluation + # Load conf used for all UDF evaluations num_conf = read_int(infile) for i in range(num_conf): k = utf8_deserializer.loads(infile) @@ -2208,6 +2219,7 @@ def read_udfs(pickleSer, infile, eval_type): # NOTE: if timezone is set here, that implies respectSessionTimeZone is True timezone = runner_conf.get("spark.sql.session.timeZone", None) prefers_large_var_types = use_large_var_types(runner_conf) + binary_as_bytes = use_bytes_for_binary(runner_conf) safecheck = ( runner_conf.get("spark.sql.execution.pandas.convertToArrowArraySafely", "false").lower() == "true" @@ -2296,7 +2308,9 @@ def read_udfs(pickleSer, infile, eval_type): PythonEvalType.SQL_WINDOW_AGG_ARROW_UDF, ): # Arrow cast and safe check are always enabled - ser = ArrowStreamArrowUDFSerializer(timezone, True, _assign_cols_by_name, True) + ser = ArrowStreamArrowUDFSerializer( + timezone, True, _assign_cols_by_name, True, binary_as_bytes + ) elif ( eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF and not use_legacy_pandas_udf_conversion(runner_conf) @@ -2305,7 +2319,7 @@ def read_udfs(pickleSer, infile, eval_type): f.dataType for f in _parse_datatype_json_string(utf8_deserializer.loads(infile)) ] ser = ArrowBatchUDFSerializer( - timezone, safecheck, input_types, 
int_to_decimal_coercion_enabled + timezone, safecheck, input_types, int_to_decimal_coercion_enabled, binary_as_bytes ) else: # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of @@ -2337,10 +2351,19 @@ def read_udfs(pickleSer, infile, eval_type): True, input_types, int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled, + binary_as_bytes=binary_as_bytes, ) - else: + elif eval_type == PythonEvalType.SQL_BATCHED_UDF: + num_conf = read_int(infile) + for i in range(num_conf): + k = utf8_deserializer.loads(infile) + v = utf8_deserializer.loads(infile) + runner_conf[k] = v + batch_size = int(os.environ.get("PYTHON_UDF_BATCH_SIZE", "100")) - ser = BatchedSerializer(CPickleSerializer(), batch_size) + binary_as_bytes = use_bytes_for_binary(runner_conf) + batched_ser = BatchedSerializer(CPickleSerializer(), batch_size) + ser = BinaryConvertingSerializer(batched_ser, binary_as_bytes) is_profiling = read_bool(infile) if is_profiling: @@ -2723,6 +2746,7 @@ def mapper(a): def mapper(a): result = tuple(f(*[a[o] for o in arg_offsets]) for arg_offsets, f in udfs) + # In the special case of a single UDF this will return a single result rather # than a tuple of results; this is the format that the JVM side expects. if len(result) == 1: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 17b8dd493cf80..19ce1f1197c21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -3913,6 +3913,14 @@ object SQLConf { .booleanConf .createWithDefault(false) + val PYSPARK_USE_BYTES_FOR_BINARY_ENABLED = + buildConf("spark.sql.execution.pyspark.binaryAsBytes.enabled") + .doc("When true, BINARY type values are returned as bytes objects in PySpark. " + + "When false, BINARY type values are returned as bytearray objects. 
") + .version("4.1.0") + .booleanConf + .createWithDefault(true) + val PYTHON_UDF_MAX_RECORDS_PER_BATCH = buildConf("spark.sql.execution.python.udf.maxRecordsPerBatch") .doc("When using Python UDFs, limit the maximum number of records that can be batched " + @@ -7101,6 +7109,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def arrowPySparkSelfDestructEnabled: Boolean = getConf(ARROW_PYSPARK_SELF_DESTRUCT_ENABLED) + def PySparkUseBytesForBinaryEnabled: Boolean = getConf(PYSPARK_USE_BYTES_FOR_BINARY_ENABLED) + def pysparkJVMStacktraceEnabled: Boolean = getConf(PYSPARK_JVM_STACKTRACE_ENABLED) def pythonUDFProfiler: Option[String] = getConf(PYTHON_UDF_PROFILER) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 609fa218f1288..27e2dcdc2f5fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -102,8 +102,14 @@ class ArrowPythonRunner( funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes, workerConf, pythonMetrics, jobArtifactUUID) { - override protected def writeUDF(dataOut: DataOutputStream): Unit = + override protected def writeUDF(dataOut: DataOutputStream): Unit = { + dataOut.writeInt(workerConf.size) + for ((k, v) <- workerConf) { + PythonWorkerUtils.writeUTF(k, dataOut) + PythonWorkerUtils.writeUTF(v, dataOut) + } PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets, profiler) + } } /** @@ -129,6 +135,11 @@ class ArrowPythonWithNamedArgumentRunner( if (evalType == PythonEvalType.SQL_ARROW_BATCHED_UDF) { PythonWorkerUtils.writeUTF(schema.json, dataOut) } + dataOut.writeInt(workerConf.size) + for ((k, v) <- workerConf) { + PythonWorkerUtils.writeUTF(k, dataOut) + PythonWorkerUtils.writeUTF(v, dataOut) + } PythonUDFRunner.writeUDFs(dataOut, funcs, argMetas, profiler) } } @@ -155,9 +166,13 @@ object ArrowPythonRunner { val intToDecimalCoercion = Seq( SQLConf.PYTHON_UDF_PANDAS_INT_TO_DECIMAL_COERCION_ENABLED.key -> conf.getConf(SQLConf.PYTHON_UDF_PANDAS_INT_TO_DECIMAL_COERCION_ENABLED, false).toString) + val binaryAsBytes = Seq( + SQLConf.PYSPARK_USE_BYTES_FOR_BINARY_ENABLED.key -> + conf.getConf(SQLConf.PYSPARK_USE_BYTES_FOR_BINARY_ENABLED, defaultValue = true).toString) Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck ++ arrowAyncParallelism ++ useLargeVarTypes ++ intToDecimalCoercion ++ + binaryAsBytes ++ legacyPandasConversion ++ legacyPandasConversionUDF: _*) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala index 28318a319b088..da410053b2e52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala @@ -76,10 +76,20 @@ class BatchEvalPythonEvaluatorFactory( // Input iterator to Python. val inputIterator = BatchEvalPythonExec.getInputIterator(iter, schema, batchSize) + val binaryAsBytes = SQLConf.get.getConf(SQLConf.PYSPARK_USE_BYTES_FOR_BINARY_ENABLED) + val workerConf = Map( + SQLConf.PYSPARK_USE_BYTES_FOR_BINARY_ENABLED.key -> binaryAsBytes.toString) + // Output iterator for results from Python. 
val outputIterator = new PythonUDFWithNamedArgumentsRunner( - funcs, PythonEvalType.SQL_BATCHED_UDF, argMetas, pythonMetrics, jobArtifactUUID, profiler) + funcs, + PythonEvalType.SQL_BATCHED_UDF, + argMetas, + workerConf, + pythonMetrics, + jobArtifactUUID, + profiler) .compute(inputIterator, context.partitionId(), context) val unpickle = new Unpickler diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala index 8ff7e57d9421e..2f334291ca114 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala @@ -65,9 +65,7 @@ abstract class BasePythonUDFRunner( context: TaskContext): Writer = { new Writer(env, worker, inputIterator, partitionIndex, context) { - protected override def writeCommand(dataOut: DataOutputStream): Unit = { - writeUDF(dataOut) - } + protected override def writeCommand(dataOut: DataOutputStream): Unit = writeUDF(dataOut) override def writeNextInputToStream(dataOut: DataOutputStream): Boolean = { val startData = dataOut.size() @@ -127,7 +125,11 @@ class PythonUDFRunner( pythonMetrics: Map[String, SQLMetric], jobArtifactUUID: Option[String], profiler: Option[String]) - extends BasePythonUDFRunner(funcs, evalType, argOffsets, pythonMetrics, jobArtifactUUID) { + extends BasePythonUDFRunner(funcs, + evalType, + argOffsets, + pythonMetrics, + jobArtifactUUID) { override protected def writeUDF(dataOut: DataOutputStream): Unit = { PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets, profiler) @@ -138,13 +140,23 @@ class PythonUDFWithNamedArgumentsRunner( funcs: Seq[(ChainedPythonFunctions, Long)], evalType: Int, argMetas: Array[Array[ArgumentMetadata]], + workerConf: Map[String, String], pythonMetrics: Map[String, SQLMetric], jobArtifactUUID: Option[String], profiler: Option[String]) extends BasePythonUDFRunner( - funcs, evalType, argMetas.map(_.map(_.offset)), pythonMetrics, jobArtifactUUID) { + funcs, + evalType, + argMetas.map(_.map(_.offset)), + pythonMetrics, + jobArtifactUUID) { override protected def writeUDF(dataOut: DataOutputStream): Unit = { + dataOut.writeInt(workerConf.size) + for ((k, v) <- workerConf) { + PythonWorkerUtils.writeUTF(k, dataOut) + PythonWorkerUtils.writeUTF(v, dataOut) + } PythonUDFRunner.writeUDFs(dataOut, funcs, argMetas, profiler) } } @@ -212,4 +224,5 @@ object PythonUDFRunner { } } } + }
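
Below is a minimal, self-contained sketch (not part of the patch itself) illustrating the user-facing behavior the new spark.sql.execution.pyspark.binaryAsBytes.enabled conf is intended to provide. It mirrors the tests added in test_arrow.py above; the local SparkSession setup and app name are assumptions for illustration only.

    # Sketch only: shows the intended effect of the new conf on collected BINARY values.
    # Assumes a local SparkSession; mirrors test_binary_type_config_* added in test_arrow.py.
    from pyspark.sql import SparkSession

    spark = (
        SparkSession.builder.master("local[1]")
        .appName("binary-as-bytes-demo")  # hypothetical app name for this sketch
        .getOrCreate()
    )

    df = spark.createDataFrame([(b"abc",)], "col binary")

    # Default (true): BINARY column values are collected as immutable bytes.
    spark.conf.set("spark.sql.execution.pyspark.binaryAsBytes.enabled", "true")
    assert isinstance(df.collect()[0].col, bytes)

    # Legacy behavior (false): BINARY column values are collected as bytearray.
    spark.conf.set("spark.sql.execution.pyspark.binaryAsBytes.enabled", "false")
    assert isinstance(df.collect()[0].col, bytearray)

    spark.stop()

The same switch is plumbed through both result paths touched by this patch: the Arrow converters (ArrowTableToRowsConversion / BinaryType.fromInternal) and the pickle-based worker path, where BinaryConvertingSerializer wraps the batched serializer and recursively rewrites bytearray values to bytes when the flag is enabled.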