Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions docs/source/python/compute.rst
Original file line number Diff line number Diff line change
Expand Up @@ -512,3 +512,46 @@ For example, the "numpy_gcd" function that we've been using as an example above
function to use in a projection. A "cumulative sum" function would not be a valid function
since the result of each input row depends on the rows that came before. A "drop nulls"
function would also be invalid because it doesn't emit a value for some rows.


Standard Python Operators
=========================

PyArrow supports standard Python operators for element-wise operations on arrays and scalars.
Currently, the support is limited to some of the standard compute functions, namely
arithmetic (``+``, ``-``, ``*``, ``/``, ``**``, unary ``-`` and ``abs()``) and
bitwise (``&``, ``|``, ``^``, ``>>``, ``<<``) operations.

These operators use the checked versions of the underlying kernels wherever possible
and are subject to the same constraints, e.g. you cannot add two arrays of strings.

You can use the operators as follows:

.. code-block:: python

>>> import pyarrow as pa
>>> arr = pa.array([-1, 2, -3])
>>> val = pa.scalar(42.7)
>>> arr + val
<pyarrow.lib.DoubleArray object at ...>
[
41.7,
44.7,
39.7
]

>>> val ** arr
<pyarrow.lib.DoubleArray object at ...>
[
0.023419203747072598,
1823.2900000000002,
0.000012844475506953143
]

>>> arr << 2
<pyarrow.lib.Int64Array object at ...>
[
-4,
8,
-12
]
73 changes: 73 additions & 0 deletions python/pyarrow/array.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,12 @@ def array(object obj, type=None, mask=None, size=None, from_pandas=None,
pa.int16() even if pa.int8() was passed to the function. Note that an
explicit index type will not be demoted even if it is wider than required.

The returned array supports Python's standard operators
for element-wise operations, i.e. arithmetic (`+`, `-`, `*`, `/`, `**`),
bitwise (`&`, `|`, `^`, `>>`, `<<`) and others.
They can be used directly instead of calling the underlying
`pyarrow.compute` functions explicitly.

Examples
--------
>>> import pandas as pd
Expand Down Expand Up @@ -229,6 +235,25 @@ def array(object obj, type=None, mask=None, size=None, from_pandas=None,
>>> arr = pa.array(range(1024), type=pa.dictionary(pa.int8(), pa.int64()))
>>> arr.type.index_type
DataType(int16)

>>> arr1 = pa.array([1, 2, 3], type=pa.int8())
>>> arr2 = pa.array([4, 5, 6], type=pa.int8())
>>> arr1 + arr2
<pyarrow.lib.Int8Array object at ...>
[
5,
7,
9
]

>>> val = pa.scalar(42)
>>> val - arr1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can I simply call arr1 - 42 or would that not work?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would work, yes. I wanted to show in the docstrings that the users can also use explicit scalars, that's why I went with this.

<pyarrow.lib.Int64Array object at ...>
[
41,
40,
39
]
"""
cdef:
CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
Expand Down Expand Up @@ -2259,6 +2284,54 @@ cdef class Array(_PandasConvertible):
stat.init(sp_stat)
return stat

def __abs__(self):
    # abs(array): run the _assert_cpu() guard, then dispatch to the
    # 'abs_checked' compute function with this array as the only argument.
    self._assert_cpu()
    compute = _pc()
    return compute.call_function('abs_checked', [self])

def __add__(self, object other):
    # array + other: run the _assert_cpu() guard, then dispatch to the
    # 'add_checked' compute function.
    self._assert_cpu()
    operands = [self, other]
    return _pc().call_function('add_checked', operands)

def __truediv__(self, object other):
    # array / other: run the _assert_cpu() guard, then dispatch to the
    # 'divide_checked' compute function.
    self._assert_cpu()
    operands = [self, other]
    return _pc().call_function('divide_checked', operands)

def __mul__(self, object other):
    # array * other: run the _assert_cpu() guard, then dispatch to the
    # 'multiply_checked' compute function.
    self._assert_cpu()
    operands = [self, other]
    return _pc().call_function('multiply_checked', operands)

def __neg__(self):
    # -array: run the _assert_cpu() guard, then dispatch to the
    # 'negate_checked' compute function with this array as the only argument.
    self._assert_cpu()
    compute = _pc()
    return compute.call_function('negate_checked', [self])

def __pow__(self, object other):
    # array ** other: run the _assert_cpu() guard, then dispatch to the
    # 'power_checked' compute function (self is the base, other the exponent).
    self._assert_cpu()
    operands = [self, other]
    return _pc().call_function('power_checked', operands)

def __sub__(self, object other):
    # array - other: run the _assert_cpu() guard, then dispatch to the
    # 'subtract_checked' compute function.
    self._assert_cpu()
    operands = [self, other]
    return _pc().call_function('subtract_checked', operands)

def __and__(self, object other):
    # array & other: run the _assert_cpu() guard, then dispatch to the
    # 'bit_wise_and' compute function.
    self._assert_cpu()
    operands = [self, other]
    return _pc().call_function('bit_wise_and', operands)

def __or__(self, object other):
    # array | other: run the _assert_cpu() guard, then dispatch to the
    # 'bit_wise_or' compute function.
    self._assert_cpu()
    operands = [self, other]
    return _pc().call_function('bit_wise_or', operands)

def __xor__(self, object other):
    # array ^ other: run the _assert_cpu() guard, then dispatch to the
    # 'bit_wise_xor' compute function.
    self._assert_cpu()
    operands = [self, other]
    return _pc().call_function('bit_wise_xor', operands)

def __lshift__(self, object other):
    # array << other: run the _assert_cpu() guard, then dispatch to the
    # 'shift_left_checked' compute function.
    self._assert_cpu()
    operands = [self, other]
    return _pc().call_function('shift_left_checked', operands)

def __rshift__(self, object other):
    # array >> other: run the _assert_cpu() guard, then dispatch to the
    # 'shift_right_checked' compute function.
    self._assert_cpu()
    operands = [self, other]
    return _pc().call_function('shift_right_checked', operands)


cdef _array_like_to_pandas(obj, options, types_mapper):
cdef:
Expand Down
63 changes: 63 additions & 0 deletions python/pyarrow/scalar.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,33 @@ from collections.abc import Sequence, Mapping
cdef class Scalar(_Weakrefable):
"""
The base class for scalars.

Notes
-----
This class supports Python's standard operators
for element-wise operations, i.e. arithmetic (`+`, `-`, `*`, `/`, `**`),
bitwise (`&`, `|`, `^`, `>>`, `<<`) and others.
They can be used directly instead of calling the underlying
`pyarrow.compute` functions explicitly.

Examples
--------
>>> import pyarrow as pa
>>> pa.scalar(42) + pa.scalar(17)
<pyarrow.Int64Scalar: 59>

>>> pa.scalar(6) ** 3
<pyarrow.Int64Scalar: 216>

>>> arr = pa.array([1, 2, 3], type=pa.int8())
>>> val = pa.scalar(42)
>>> val - arr
<pyarrow.lib.Int64Array object at ...>
[
41,
40,
39
]
"""

def __init__(self):
Expand Down Expand Up @@ -168,6 +195,42 @@ cdef class Scalar(_Weakrefable):
"""
raise NotImplementedError()

def __abs__(self):
    # abs(scalar): dispatch to the 'abs_checked' compute function with this
    # scalar as the only argument.
    compute = _pc()
    return compute.call_function('abs_checked', [self])

def __add__(self, object other):
    # scalar + other: dispatch to the 'add_checked' compute function.
    operands = [self, other]
    return _pc().call_function('add_checked', operands)

def __truediv__(self, object other):
    # scalar / other: dispatch to the 'divide_checked' compute function.
    operands = [self, other]
    return _pc().call_function('divide_checked', operands)

def __mul__(self, object other):
    # scalar * other: dispatch to the 'multiply_checked' compute function.
    operands = [self, other]
    return _pc().call_function('multiply_checked', operands)

def __neg__(self):
    # -scalar: dispatch to the 'negate_checked' compute function with this
    # scalar as the only argument.
    compute = _pc()
    return compute.call_function('negate_checked', [self])

def __pow__(self, object other):
    # scalar ** other: dispatch to the 'power_checked' compute function
    # (self is the base, other the exponent).
    operands = [self, other]
    return _pc().call_function('power_checked', operands)

def __sub__(self, object other):
    # scalar - other: dispatch to the 'subtract_checked' compute function.
    operands = [self, other]
    return _pc().call_function('subtract_checked', operands)

def __and__(self, object other):
    # scalar & other: dispatch to the 'bit_wise_and' compute function.
    operands = [self, other]
    return _pc().call_function('bit_wise_and', operands)

def __or__(self, object other):
    # scalar | other: dispatch to the 'bit_wise_or' compute function.
    operands = [self, other]
    return _pc().call_function('bit_wise_or', operands)

def __xor__(self, object other):
    # scalar ^ other: dispatch to the 'bit_wise_xor' compute function.
    operands = [self, other]
    return _pc().call_function('bit_wise_xor', operands)

def __lshift__(self, object other):
    # scalar << other: dispatch to the 'shift_left_checked' compute function.
    operands = [self, other]
    return _pc().call_function('shift_left_checked', operands)

def __rshift__(self, object other):
    # scalar >> other: dispatch to the 'shift_right_checked' compute function.
    operands = [self, other]
    return _pc().call_function('shift_right_checked', operands)


_NULL = NA = None

Expand Down
70 changes: 70 additions & 0 deletions python/pyarrow/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import pyarrow as pa
import pyarrow.tests.strategies as past
from pyarrow.vendored.version import Version
import pyarrow.compute as pc


@pytest.mark.processes
Expand Down Expand Up @@ -4398,3 +4399,72 @@ def test_non_cpu_array():
arr.tolist()
with pytest.raises(NotImplementedError):
arr.validate(full=True)


def test_arithmetic_dunders():
    """Arithmetic operators on arrays give the same result as the checked kernels."""
    # GH-32007
    lhs = pa.array([-1.1, 2.2, -3.3])
    rhs = pa.array([2.2, 4.4, 5.5])

    # Only operator/kernel equivalence is asserted here; the overflow
    # behaviour of the checked kernels is exercised separately in
    # test_dunders_checked_overflow.
    pairs = [
        (lhs + rhs, pc.add_checked(lhs, rhs)),
        (rhs / lhs, pc.divide_checked(rhs, lhs)),
        (lhs * rhs, pc.multiply_checked(lhs, rhs)),
        (-lhs, pc.negate_checked(lhs)),
        (lhs ** 2, pc.power_checked(lhs, 2)),
        (lhs - rhs, pc.subtract_checked(lhs, rhs)),
    ]
    for via_operator, via_kernel in pairs:
        assert via_operator.equals(via_kernel)


def test_bitwise_dunders():
    """Bitwise and shift operators on arrays give the same result as the kernels."""
    # GH-32007
    lhs = pa.array([-1, 2, -3])
    rhs = pa.array([2, 4, 5])

    pairs = [
        (lhs & rhs, pc.bit_wise_and(lhs, rhs)),
        (lhs | rhs, pc.bit_wise_or(lhs, rhs)),
        (lhs ^ rhs, pc.bit_wise_xor(lhs, rhs)),
        (lhs << rhs, pc.shift_left_checked(lhs, rhs)),
        (lhs >> rhs, pc.shift_right_checked(lhs, rhs)),
    ]
    for via_operator, via_kernel in pairs:
        assert via_operator.equals(via_kernel)


def test_dunders_unmatching_types():
    """Operators on arrays with unsupported type combinations raise ArrowNotImplementedError."""
    # GH-32007
    no_kernel = r"Function '\w+' has no kernel matching input types"
    strings = pa.array(["a", "b", "c"])
    structs = pa.array([{"x": 1, "y": True}, {"z": 3.4, "x": 4}])
    doubles = pa.array([1.0, 2.0, 3.0])

    with pytest.raises(pa.ArrowNotImplementedError, match=no_kernel):
        strings + structs

    with pytest.raises(pa.ArrowNotImplementedError, match=no_kernel):
        strings - doubles

    with pytest.raises(pa.ArrowNotImplementedError, match=no_kernel):
        doubles * structs


def test_dunders_mixed_types():
    """Array-with-scalar operators give the same result as the checked kernels."""
    # GH-32007
    doubles = pa.array([11.0, 17.0, 23.0])
    three = pa.scalar(3)

    pairs = [
        (doubles + three, pc.add_checked(doubles, three)),
        (doubles - three, pc.subtract_checked(doubles, three)),
        (doubles / three, pc.divide_checked(doubles, three)),
        (doubles * three, pc.multiply_checked(doubles, three)),
        (doubles ** three, pc.power_checked(doubles, three)),
    ]
    for via_operator, via_kernel in pairs:
        assert via_operator.equals(via_kernel)


def test_dunders_checked_overflow():
    """Each array operator must surface the overflow error of its checked kernel.

    Every ``with`` block contains exactly one operator, so a failure points
    at the kernel under test.  ``-arr`` itself overflows for int8 ``-128``,
    so it must not be used to build an operand for the subtraction or
    division cases: the negation would raise before the operator under
    test is ever evaluated.
    """
    # GH-32007
    int8 = pa.int8()
    arr = pa.array([127, -128], type=int8)
    error_match = "overflow"

    with pytest.raises(pa.ArrowInvalid, match=error_match):
        arr + arr
    with pytest.raises(pa.ArrowInvalid, match=error_match):
        arr * arr
    with pytest.raises(pa.ArrowInvalid, match=error_match):
        # Negating int8 -128 is not representable.
        -arr
    with pytest.raises(pa.ArrowInvalid, match=error_match):
        # 127 - (-1) and -128 - 1 both fall outside the int8 range.
        arr - pa.array([-1, 1], type=int8)
    with pytest.raises(pa.ArrowInvalid, match=error_match):
        arr ** pa.scalar(2, type=pa.int8())
    with pytest.raises(pa.ArrowInvalid, match=error_match):
        # -128 / -1 is the one integer division that overflows int8.
        arr / pa.array([1, -1], type=int8)
56 changes: 56 additions & 0 deletions python/pyarrow/tests/test_scalars.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,3 +995,59 @@ def test_map_scalar_with_empty_values():
s = pa.scalar(v, type=map_type)

assert s.as_py(maps_as_pydicts="strict") == v


def test_arithmetic_dunders():
    """Arithmetic operators on scalars give the same result as the checked kernels."""
    # GH-32007
    lhs = pa.scalar(42)
    rhs = pa.scalar(-17)

    pairs = [
        (lhs + rhs, pc.add_checked(lhs, rhs)),
        (rhs / lhs, pc.divide_checked(rhs, lhs)),
        (lhs * rhs, pc.multiply_checked(lhs, rhs)),
        (-lhs, pc.negate_checked(lhs)),
        (lhs ** 2, pc.power_checked(lhs, 2)),
        (lhs - rhs, pc.subtract_checked(lhs, rhs)),
    ]
    for via_operator, via_kernel in pairs:
        assert via_operator.equals(via_kernel)


def test_bitwise_dunders():
    """Bitwise and shift operators on scalars give the same result as the kernels."""
    # GH-32007
    lhs = pa.scalar(42)
    rhs = pa.scalar(-17)

    pairs = [
        (lhs & rhs, pc.bit_wise_and(lhs, rhs)),
        (lhs | rhs, pc.bit_wise_or(lhs, rhs)),
        (lhs ^ rhs, pc.bit_wise_xor(lhs, rhs)),
        (rhs << lhs, pc.shift_left_checked(rhs, lhs)),
        (rhs >> lhs, pc.shift_right_checked(rhs, lhs)),
    ]
    for via_operator, via_kernel in pairs:
        assert via_operator.equals(via_kernel)


def test_dunders_unmatching_types():
    """Operators between a string scalar and a double scalar raise ArrowNotImplementedError."""
    # GH-32007
    no_kernel = r"Function '\w+' has no kernel matching input types"
    text = pa.scalar("abc")
    number = pa.scalar(1.23)

    with pytest.raises(pa.ArrowNotImplementedError, match=no_kernel):
        text + number

    with pytest.raises(pa.ArrowNotImplementedError, match=no_kernel):
        text - number

    with pytest.raises(pa.ArrowNotImplementedError, match=no_kernel):
        text / number

    with pytest.raises(pa.ArrowNotImplementedError, match=no_kernel):
        text * number


def test_dunders_checked_overflow():
    """Scalar operators raise ArrowInvalid when the int8 result overflows."""
    # GH-32007
    overflow = "overflow"
    int8_max = pa.scalar(127, type=pa.int8())

    with pytest.raises(pa.ArrowInvalid, match=overflow):
        int8_max + int8_max

    with pytest.raises(pa.ArrowInvalid, match=overflow):
        int8_max - (-int8_max)

    with pytest.raises(pa.ArrowInvalid, match=overflow):
        int8_max ** int8_max

    with pytest.raises(pa.ArrowInvalid, match=overflow):
        int8_max * int8_max
Loading