16
16
#
17
17
18
18
import numbers
19
- from typing import Any , Union
19
+ from typing import Any , Union , Callable
20
20
21
21
import numpy as np
22
22
import pandas as pd
@@ -271,13 +271,22 @@ def floordiv(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
271
271
_sanitize_list_like (right )
272
272
if not is_valid_operand_for_numeric_arithmetic (right ):
273
273
raise TypeError ("Floor division can not be applied to given types." )
274
+ spark_session = left ._internal .spark_frame .sparkSession
275
+ use_try_divide = is_ansi_mode_enabled (spark_session )
276
+
277
+ def fallback_div (x : PySparkColumn , y : PySparkColumn ) -> PySparkColumn :
278
+ return x .__div__ (y )
279
+
280
+ safe_div : Callable [[PySparkColumn , PySparkColumn ], PySparkColumn ] = (
281
+ F .try_divide if use_try_divide else fallback_div
282
+ )
274
283
275
284
def floordiv (left : PySparkColumn , right : Any ) -> PySparkColumn :
276
285
return F .when (F .lit (right is np .nan ), np .nan ).otherwise (
277
286
F .when (
278
287
F .lit (right != 0 ) | F .lit (right ).isNull (),
279
288
F .floor (left .__div__ (right )),
280
- ).otherwise (F .lit (np .inf ). __div__ ( left ))
289
+ ).otherwise (safe_div ( F .lit (np .inf ), left ))
281
290
)
282
291
283
292
right = transform_boolean_operand_to_numeric (right , spark_type = left .spark .data_type )
@@ -369,6 +378,15 @@ def floordiv(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
369
378
_sanitize_list_like (right )
370
379
if not is_valid_operand_for_numeric_arithmetic (right ):
371
380
raise TypeError ("Floor division can not be applied to given types." )
381
+ spark_session = left ._internal .spark_frame .sparkSession
382
+ use_try_divide = is_ansi_mode_enabled (spark_session )
383
+
384
+ def fallback_div (x : PySparkColumn , y : PySparkColumn ) -> PySparkColumn :
385
+ return x .__div__ (y )
386
+
387
+ safe_div : Callable [[PySparkColumn , PySparkColumn ], PySparkColumn ] = (
388
+ F .try_divide if use_try_divide else fallback_div
389
+ )
372
390
373
391
def floordiv (left : PySparkColumn , right : Any ) -> PySparkColumn :
374
392
return F .when (F .lit (right is np .nan ), np .nan ).otherwise (
@@ -377,7 +395,7 @@ def floordiv(left: PySparkColumn, right: Any) -> PySparkColumn:
377
395
F .floor (left .__div__ (right )),
378
396
).otherwise (
379
397
F .when (F .lit (left == np .inf ) | F .lit (left == - np .inf ), left ).otherwise (
380
- F .lit (np .inf ). __div__ ( left )
398
+ safe_div ( F .lit (np .inf ), left )
381
399
)
382
400
)
383
401
)
0 commit comments