Commit ebb476a
[SPARK-52239][PYTHON][CONNECT] Support register an arrow UDF
### What changes were proposed in this pull request?

Support registering an Arrow UDF.

### Why are the changes needed?

To make sure an Arrow UDF works with:
- `spark.catalog.registerFunction`
- `spark.udf.register`

### Does this PR introduce _any_ user-facing change?

No, Arrow UDFs are not exposed to end users for now.

### How was this patch tested?

Added tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #50964 from zhengruifeng/py_arrow_udf_test_1.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
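In a nutshell, the flow the new tests cover looks like this (a minimal sketch distilled from the diff below; `arrow_udf` is an internal API, so its import path is assumed here, and `spark` is assumed to be an active SparkSession):

```python
import pyarrow as pa

from pyspark.sql import functions as F
from pyspark.sql.functions import arrow_udf  # assumed import path; not yet a public API
from pyspark.sql.types import IntegerType

# A scalar Arrow UDF receives pyarrow.Array batches instead of single rows.
add = arrow_udf(lambda x, y: pa.compute.add(x, y).cast(pa.int32()), IntegerType())

# Both entry points register the UDF under a SQL-callable name and return a
# callable for the DataFrame API; the tests assert that evalType and the
# deterministic flag survive either route.
new_add = spark.udf.register("add1", add)
# new_add = spark.catalog.registerFunction("add1", add)  # equivalent route

df = spark.range(10).select(
    F.col("id").cast("int").alias("a"), F.col("id").cast("int").alias("b")
)
df.select(new_add(F.col("a"), F.col("b"))).show()
spark.sql("SELECT add1(t.a, t.b) FROM (SELECT id AS a, id AS b FROM range(10)) t").show()
```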
1 parent: 9df43ac

1 file changed: +82 -11 lines
python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py

```diff
@@ -549,20 +549,92 @@ def f(x):
         )
         self.assertEqual(df.collect(), res.collect())
 
-    def test_register_nondeterministic_arrow_udf(self):
+    def test_udf_register_arrow_udf_basic(self):
         import pyarrow as pa
 
-        random_pandas_udf = arrow_udf(
+        scalar_original_add = arrow_udf(
+            lambda x, y: pa.compute.add(x, y).cast(pa.int32()), IntegerType()
+        )
+        self.assertEqual(scalar_original_add.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
+        self.assertEqual(scalar_original_add.deterministic, True)
+
+        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS add1")
+        new_add = self.spark.udf.register("add1", scalar_original_add)
+
+        self.assertEqual(new_add.deterministic, True)
+        self.assertEqual(new_add.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
+
+        df = self.spark.range(10).select(
+            F.col("id").cast("int").alias("a"), F.col("id").cast("int").alias("b")
+        )
+        res1 = df.select(new_add(F.col("a"), F.col("b")))
+        res2 = self.spark.sql(
+            "SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t"
+        )
+        expected = df.select(F.expr("a + b"))
+        self.assertEqual(expected.collect(), res1.collect())
+        self.assertEqual(expected.collect(), res2.collect())
+
+    def test_catalog_register_arrow_udf_basic(self):
+        import pyarrow as pa
+
+        scalar_original_add = arrow_udf(
+            lambda x, y: pa.compute.add(x, y).cast(pa.int32()), IntegerType()
+        )
+        self.assertEqual(scalar_original_add.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
+        self.assertEqual(scalar_original_add.deterministic, True)
+
+        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS add1")
+        new_add = self.spark.catalog.registerFunction("add1", scalar_original_add)
+
+        self.assertEqual(new_add.deterministic, True)
+        self.assertEqual(new_add.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
+
+        df = self.spark.range(10).select(
+            F.col("id").cast("int").alias("a"), F.col("id").cast("int").alias("b")
+        )
+        res1 = df.select(new_add(F.col("a"), F.col("b")))
+        res2 = self.spark.sql(
+            "SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t"
+        )
+        expected = df.select(F.expr("a + b"))
+        self.assertEqual(expected.collect(), res1.collect())
+        self.assertEqual(expected.collect(), res2.collect())
+
+    def test_udf_register_nondeterministic_arrow_udf(self):
+        import pyarrow as pa
+
+        random_arrow_udf = arrow_udf(
             lambda x: pa.compute.add(x, random.randint(6, 6)), LongType()
         ).asNondeterministic()
-        self.assertEqual(random_pandas_udf.deterministic, False)
-        self.assertEqual(random_pandas_udf.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
-        nondeterministic_pandas_udf = self.spark.catalog.registerFunction(
-            "randomPandasUDF", random_pandas_udf
-        )
-        self.assertEqual(nondeterministic_pandas_udf.deterministic, False)
-        self.assertEqual(nondeterministic_pandas_udf.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
-        [row] = self.spark.sql("SELECT randomPandasUDF(1)").collect()
+        self.assertEqual(random_arrow_udf.deterministic, False)
+        self.assertEqual(random_arrow_udf.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
+
+        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS randomArrowUDF")
+        nondeterministic_arrow_udf = self.spark.udf.register("randomArrowUDF", random_arrow_udf)
+
+        self.assertEqual(nondeterministic_arrow_udf.deterministic, False)
+        self.assertEqual(nondeterministic_arrow_udf.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
+        [row] = self.spark.sql("SELECT randomArrowUDF(1)").collect()
+        self.assertEqual(row[0], 7)
+
+    def test_catalog_register_nondeterministic_arrow_udf(self):
+        import pyarrow as pa
+
+        random_arrow_udf = arrow_udf(
+            lambda x: pa.compute.add(x, random.randint(6, 6)), LongType()
+        ).asNondeterministic()
+        self.assertEqual(random_arrow_udf.deterministic, False)
+        self.assertEqual(random_arrow_udf.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
+
+        self.spark.sql("DROP TEMPORARY FUNCTION IF EXISTS randomArrowUDF")
+        nondeterministic_arrow_udf = self.spark.catalog.registerFunction(
+            "randomArrowUDF", random_arrow_udf
+        )
+
+        self.assertEqual(nondeterministic_arrow_udf.deterministic, False)
+        self.assertEqual(nondeterministic_arrow_udf.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
+        [row] = self.spark.sql("SELECT randomArrowUDF(1)").collect()
         self.assertEqual(row[0], 7)
 
     def test_nondeterministic_arrow_udf(self):
@@ -599,7 +671,6 @@ def test_nondeterministic_arrow_udf_in_aggregate(self):
         with self.assertRaisesRegex(AnalysisException, "Non-deterministic"):
             df.agg(F.sum(random_udf("id"))).collect()
 
-    # TODO: add tests for registering Arrow UDF
     # TODO: add tests for chained Arrow UDFs
     # TODO: add tests for named arguments
```
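For context on the nondeterministic path these tests exercise: a UDF marked with `asNondeterministic()` must keep that flag across registration, and Spark's analyzer then rejects it inside aggregate expressions (which is what `test_nondeterministic_arrow_udf_in_aggregate` checks). A minimal sketch, again assuming the internal `arrow_udf` import path and an active `spark` session:

```python
import random

import pyarrow as pa

from pyspark.errors import AnalysisException
from pyspark.sql import functions as F
from pyspark.sql.functions import arrow_udf  # assumed import path; internal API
from pyspark.sql.types import LongType

# Mark the UDF nondeterministic so the optimizer won't collapse or reorder calls.
random_udf = arrow_udf(
    lambda x: pa.compute.add(x, random.randint(0, 10)), LongType()
).asNondeterministic()

registered = spark.udf.register("randomArrowUDF", random_udf)
assert registered.deterministic is False  # the flag survives registration

df = spark.range(10)
try:
    # Nondeterministic expressions are not allowed inside aggregates.
    df.agg(F.sum(random_udf("id"))).collect()
except AnalysisException as e:
    print(e)  # message mentions "Non-deterministic"
```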
