@@ -549,20 +549,92 @@ def f(x):
549
549
)
550
550
self .assertEqual (df .collect (), res .collect ())
551
551
552
- def test_register_nondeterministic_arrow_udf (self ):
552
def test_udf_register_arrow_udf_basic(self):
    """Register a scalar Arrow UDF through ``spark.udf.register``.

    Checks that the object returned by registration keeps the original
    eval type and determinism flag, and that the function yields ``a + b``
    both when called through the DataFrame API and when called by name
    from SQL.
    """
    import pyarrow as pa

    scalar_original_add = arrow_udf(
        lambda x, y: pa.compute.add(x, y).cast(pa.int32()), IntegerType()
    )
    self.assertEqual(scalar_original_add.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
    self.assertEqual(scalar_original_add.deterministic, True)

    drop_add1 = "DROP TEMPORARY FUNCTION IF EXISTS add1"
    # Start from a clean slate in case an earlier test left "add1" behind.
    self.spark.sql(drop_add1)
    # Also drop it afterwards so this test does not leak the function
    # into the session used by the tests that follow.
    self.addCleanup(self.spark.sql, drop_add1)
    new_add = self.spark.udf.register("add1", scalar_original_add)

    self.assertEqual(new_add.deterministic, True)
    self.assertEqual(new_add.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)

    df = self.spark.range(10).select(
        F.col("id").cast("int").alias("a"), F.col("id").cast("int").alias("b")
    )
    res1 = df.select(new_add(F.col("a"), F.col("b")))
    res2 = self.spark.sql(
        "SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t"
    )
    # Collect the reference rows once and reuse them for both comparisons.
    expected_rows = df.select(F.expr("a + b")).collect()
    self.assertEqual(expected_rows, res1.collect())
    self.assertEqual(expected_rows, res2.collect())
577
+
578
def test_catalog_register_arrow_udf_basic(self):
    """Register a scalar Arrow UDF through ``spark.catalog.registerFunction``.

    Checks that the object returned by registration keeps the original
    eval type and determinism flag, and that the function yields ``a + b``
    both when called through the DataFrame API and when called by name
    from SQL.
    """
    import pyarrow as pa

    scalar_original_add = arrow_udf(
        lambda x, y: pa.compute.add(x, y).cast(pa.int32()), IntegerType()
    )
    self.assertEqual(scalar_original_add.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
    self.assertEqual(scalar_original_add.deterministic, True)

    drop_add1 = "DROP TEMPORARY FUNCTION IF EXISTS add1"
    # Start from a clean slate in case an earlier test left "add1" behind.
    self.spark.sql(drop_add1)
    # Also drop it afterwards so this test does not leak the function
    # into the session used by the tests that follow.
    self.addCleanup(self.spark.sql, drop_add1)
    new_add = self.spark.catalog.registerFunction("add1", scalar_original_add)

    self.assertEqual(new_add.deterministic, True)
    self.assertEqual(new_add.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)

    df = self.spark.range(10).select(
        F.col("id").cast("int").alias("a"), F.col("id").cast("int").alias("b")
    )
    res1 = df.select(new_add(F.col("a"), F.col("b")))
    res2 = self.spark.sql(
        "SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t"
    )
    # Collect the reference rows once and reuse them for both comparisons.
    expected_rows = df.select(F.expr("a + b")).collect()
    self.assertEqual(expected_rows, res1.collect())
    self.assertEqual(expected_rows, res2.collect())
603
+
604
def test_udf_register_nondeterministic_arrow_udf(self):
    """Register a nondeterministic Arrow UDF through ``spark.udf.register``.

    Checks that the nondeterministic flag and the scalar Arrow eval type
    are still reported by the registered function, and that the function
    can be invoked by name from SQL.
    """
    import pyarrow as pa

    # randint(6, 6) can only return 6, so the SQL result below is
    # predictable (1 + 6) even though the UDF is marked nondeterministic.
    random_arrow_udf = arrow_udf(
        lambda x: pa.compute.add(x, random.randint(6, 6)), LongType()
    ).asNondeterministic()
    self.assertEqual(random_arrow_udf.deterministic, False)
    self.assertEqual(random_arrow_udf.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)

    drop_udf = "DROP TEMPORARY FUNCTION IF EXISTS randomArrowUDF"
    # Start from a clean slate in case an earlier test left the name behind.
    self.spark.sql(drop_udf)
    # Also drop it afterwards so this test does not leak the function
    # into the session used by the tests that follow.
    self.addCleanup(self.spark.sql, drop_udf)
    nondeterministic_arrow_udf = self.spark.udf.register("randomArrowUDF", random_arrow_udf)

    self.assertEqual(nondeterministic_arrow_udf.deterministic, False)
    self.assertEqual(nondeterministic_arrow_udf.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
    [row] = self.spark.sql("SELECT randomArrowUDF(1)").collect()
    self.assertEqual(row[0], 7)
620
+
621
def test_catalog_register_nondeterministic_arrow_udf(self):
    """Register a nondeterministic Arrow UDF through ``spark.catalog.registerFunction``.

    Checks that the nondeterministic flag and the scalar Arrow eval type
    are still reported by the registered function, and that the function
    can be invoked by name from SQL.
    """
    import pyarrow as pa

    # randint(6, 6) can only return 6, so the SQL result below is
    # predictable (1 + 6) even though the UDF is marked nondeterministic.
    random_arrow_udf = arrow_udf(
        lambda x: pa.compute.add(x, random.randint(6, 6)), LongType()
    ).asNondeterministic()
    self.assertEqual(random_arrow_udf.deterministic, False)
    self.assertEqual(random_arrow_udf.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)

    drop_udf = "DROP TEMPORARY FUNCTION IF EXISTS randomArrowUDF"
    # Start from a clean slate in case an earlier test left the name behind.
    self.spark.sql(drop_udf)
    # Also drop it afterwards so this test does not leak the function
    # into the session used by the tests that follow.
    self.addCleanup(self.spark.sql, drop_udf)
    nondeterministic_arrow_udf = self.spark.catalog.registerFunction(
        "randomArrowUDF", random_arrow_udf
    )

    self.assertEqual(nondeterministic_arrow_udf.deterministic, False)
    self.assertEqual(nondeterministic_arrow_udf.evalType, PythonEvalType.SQL_SCALAR_ARROW_UDF)
    [row] = self.spark.sql("SELECT randomArrowUDF(1)").collect()
    self.assertEqual(row[0], 7)
567
639
568
640
def test_nondeterministic_arrow_udf (self ):
@@ -599,7 +671,6 @@ def test_nondeterministic_arrow_udf_in_aggregate(self):
599
671
with self .assertRaisesRegex (AnalysisException , "Non-deterministic" ):
600
672
df .agg (F .sum (random_udf ("id" ))).collect ()
601
673
602
- # TODO: add tests for registering Arrow UDF
603
674
# TODO: add tests for chained Arrow UDFs
604
675
# TODO: add tests for named arguments
605
676
0 commit comments