Skip to content

Commit

Permalink
Shape hints
Browse files Browse the repository at this point in the history
Signed-off-by: Iaroslav Igoshev <[email protected]>
  • Loading branch information
YarShev committed Mar 12, 2024
1 parent fe3a229 commit 4b1efd2
Show file tree
Hide file tree
Showing 6 changed files with 212 additions and 60 deletions.
39 changes: 34 additions & 5 deletions modin/core/storage_formats/base/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1474,15 +1474,21 @@ def to_datetime(self, *args, **kwargs):

# Abstract map partitions operations
# These operations are operations that apply a function to every partition.
def abs(self):
def abs(self, **kwargs):
"""
Get absolute numeric value of each element.
Parameters
----------
**kwargs : dict
Serves the compatibility purpose. Does not affect the result.
Returns
-------
BaseQueryCompiler
QueryCompiler with absolute numeric value of each element.
"""
_ = kwargs.pop("shape_hint", None)
return DataFrameDefault.register(pandas.DataFrame.abs)(self)

def map(self, func, *args, **kwargs):
Expand All @@ -1501,6 +1507,7 @@ def map(self, func, *args, **kwargs):
BaseQueryCompiler
Transformed QueryCompiler.
"""
_ = kwargs.pop("shape_hint", None)
return DataFrameDefault.register(pandas.DataFrame.map)(
self, func, *args, **kwargs
)
Expand Down Expand Up @@ -1578,16 +1585,22 @@ def isin(self, values, ignore_indices=False, **kwargs): # noqa: PR02
self, values, **kwargs
)

def isna(self):
def isna(self, **kwargs):
"""
Check for each element of self whether it's NaN.
Parameters
----------
**kwargs : dict
Serves the compatibility purpose. Does not affect the result.
Returns
-------
BaseQueryCompiler
Boolean mask for self of whether an element at the corresponding
position is NaN.
"""
_ = kwargs.pop("shape_hint", None)
return DataFrameDefault.register(pandas.DataFrame.isna)(self)

# FIXME: this method is not supposed to take any parameters (Modin issue #3108).
Expand All @@ -1608,12 +1621,18 @@ def negative(self, **kwargs):
-----
Be aware, that all QueryCompiler values have to be numeric.
"""
_ = kwargs.pop("shape_hint", None)
return DataFrameDefault.register(pandas.DataFrame.__neg__)(self, **kwargs)

def notna(self):
def notna(self, **kwargs):
"""
Check for each element of `self` whether it's existing (non-missing) value.
Parameters
----------
**kwargs : dict
Serves the compatibility purpose. Does not affect the result.
Returns
-------
BaseQueryCompiler
Expand All @@ -1639,6 +1658,7 @@ def round(self, **kwargs): # noqa: PR02
BaseQueryCompiler
QueryCompiler with rounded values.
"""
_ = kwargs.pop("shape_hint", None)
return DataFrameDefault.register(pandas.DataFrame.round)(self, **kwargs)

# FIXME:
Expand Down Expand Up @@ -1666,6 +1686,7 @@ def replace(self, **kwargs): # noqa: PR02
BaseQueryCompiler
QueryCompiler with all `to_replace` values replaced by `value`.
"""
_ = kwargs.pop("shape_hint", None)
return DataFrameDefault.register(pandas.DataFrame.replace)(self, **kwargs)

@doc_utils.add_refer_to("Series.argsort")
Expand Down Expand Up @@ -1830,7 +1851,7 @@ def stack(self, level, dropna):
)

# Abstract map partitions across select indices
def astype(self, col_dtypes, errors: str = "raise"): # noqa: PR02
def astype(self, col_dtypes, errors: str = "raise", **kwargs): # noqa: PR02
"""
Convert columns dtypes to given dtypes.
Expand All @@ -1842,6 +1863,8 @@ def astype(self, col_dtypes, errors: str = "raise"): # noqa: PR02
Control raising of exceptions on invalid data for provided dtype.
- raise : allow exceptions to be raised
- ignore : suppress exceptions. On error return original object.
**kwargs : dict
Serves the compatibility purpose. Does not affect the result.
Returns
-------
Expand Down Expand Up @@ -6649,15 +6672,21 @@ def struct_explode(self):

# DataFrame methods

def invert(self):
def invert(self, **kwargs):
"""
Apply bitwise inversion for each element of the QueryCompiler.
Parameters
----------
**kwargs : dict
Serves the compatibility purpose. Does not affect the result.
Returns
-------
BaseQueryCompiler
New QueryCompiler containing bitwise inversion for each value.
"""
_ = kwargs.pop("shape_hint", None)
return DataFrameDefault.register(pandas.DataFrame.__invert__)(self)

@doc_utils.doc_reduce_agg(
Expand Down
99 changes: 66 additions & 33 deletions modin/core/storage_formats/pandas/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1779,60 +1779,91 @@ def isin_func(df, values):
series_view = Map.register(
lambda df, *args, **kwargs: pandas.DataFrame(
df.squeeze(axis=1).view(*args, **kwargs)
)
),
shape_hint="column",
)
to_numeric = Map.register(
lambda df, *args, **kwargs: pandas.DataFrame(
pandas.to_numeric(df.squeeze(axis=1), *args, **kwargs)
)
),
shape_hint="column",
)
to_timedelta = Map.register(
lambda s, *args, **kwargs: pandas.to_timedelta(
s.squeeze(axis=1), *args, **kwargs
).to_frame(),
dtypes="timedelta64[ns]",
shape_hint="column",
)

# END Map partitions operations

# String map partitions operations

str_capitalize = Map.register(_str_map("capitalize"), dtypes="copy")
str_center = Map.register(_str_map("center"), dtypes="copy")
str_contains = Map.register(_str_map("contains"), dtypes=np.bool_)
str_count = Map.register(_str_map("count"), dtypes=int)
str_endswith = Map.register(_str_map("endswith"), dtypes=np.bool_)
str_find = Map.register(_str_map("find"), dtypes=np.int64)
str_findall = Map.register(_str_map("findall"), dtypes="copy")
str_get = Map.register(_str_map("get"), dtypes="copy")
str_index = Map.register(_str_map("index"), dtypes=np.int64)
str_isalnum = Map.register(_str_map("isalnum"), dtypes=np.bool_)
str_isalpha = Map.register(_str_map("isalpha"), dtypes=np.bool_)
str_isdecimal = Map.register(_str_map("isdecimal"), dtypes=np.bool_)
str_isdigit = Map.register(_str_map("isdigit"), dtypes=np.bool_)
str_islower = Map.register(_str_map("islower"), dtypes=np.bool_)
str_isnumeric = Map.register(_str_map("isnumeric"), dtypes=np.bool_)
str_isspace = Map.register(_str_map("isspace"), dtypes=np.bool_)
str_istitle = Map.register(_str_map("istitle"), dtypes=np.bool_)
str_isupper = Map.register(_str_map("isupper"), dtypes=np.bool_)
str_join = Map.register(_str_map("join"), dtypes="copy")
str_len = Map.register(_str_map("len"), dtypes=int)
str_ljust = Map.register(_str_map("ljust"), dtypes="copy")
str_lower = Map.register(_str_map("lower"), dtypes="copy")
str_lstrip = Map.register(_str_map("lstrip"), dtypes="copy")
str_match = Map.register(_str_map("match"), dtypes="copy")
str_normalize = Map.register(_str_map("normalize"), dtypes="copy")
str_pad = Map.register(_str_map("pad"), dtypes="copy")
_str_partition = Map.register(_str_map("partition"), dtypes="copy")
str_capitalize = Map.register(
_str_map("capitalize"), dtypes="copy", shape_hint="column"
)
str_center = Map.register(_str_map("center"), dtypes="copy", shape_hint="column")
str_contains = Map.register(
_str_map("contains"), dtypes=np.bool_, shape_hint="column"
)
str_count = Map.register(_str_map("count"), dtypes=int, shape_hint="column")
str_endswith = Map.register(
_str_map("endswith"), dtypes=np.bool_, shape_hint="column"
)
str_find = Map.register(_str_map("find"), dtypes=np.int64, shape_hint="column")
str_findall = Map.register(_str_map("findall"), dtypes="copy", shape_hint="column")
str_get = Map.register(_str_map("get"), dtypes="copy", shape_hint="column")
str_index = Map.register(_str_map("index"), dtypes=np.int64, shape_hint="column")
str_isalnum = Map.register(
_str_map("isalnum"), dtypes=np.bool_, shape_hint="column"
)
str_isalpha = Map.register(
_str_map("isalpha"), dtypes=np.bool_, shape_hint="column"
)
str_isdecimal = Map.register(
_str_map("isdecimal"), dtypes=np.bool_, shape_hint="column"
)
str_isdigit = Map.register(
_str_map("isdigit"), dtypes=np.bool_, shape_hint="column"
)
str_islower = Map.register(
_str_map("islower"), dtypes=np.bool_, shape_hint="column"
)
str_isnumeric = Map.register(
_str_map("isnumeric"), dtypes=np.bool_, shape_hint="column"
)
str_isspace = Map.register(
_str_map("isspace"), dtypes=np.bool_, shape_hint="column"
)
str_istitle = Map.register(
_str_map("istitle"), dtypes=np.bool_, shape_hint="column"
)
str_isupper = Map.register(
_str_map("isupper"), dtypes=np.bool_, shape_hint="column"
)
str_join = Map.register(_str_map("join"), dtypes="copy", shape_hint="column")
str_len = Map.register(_str_map("len"), dtypes=int, shape_hint="column")
str_ljust = Map.register(_str_map("ljust"), dtypes="copy", shape_hint="column")
str_lower = Map.register(_str_map("lower"), dtypes="copy", shape_hint="column")
str_lstrip = Map.register(_str_map("lstrip"), dtypes="copy", shape_hint="column")
str_match = Map.register(_str_map("match"), dtypes="copy", shape_hint="column")
str_normalize = Map.register(
_str_map("normalize"), dtypes="copy", shape_hint="column"
)
str_pad = Map.register(_str_map("pad"), dtypes="copy", shape_hint="column")
_str_partition = Map.register(
_str_map("partition"), dtypes="copy", shape_hint="column"
)

def str_partition(self, sep=" ", expand=True):
# For `expand`, need an operator that can create more columns than before
if expand:
return super().str_partition(sep=sep, expand=expand)
return self._str_partition(sep=sep, expand=False)

str_repeat = Map.register(_str_map("repeat"), dtypes="copy")
_str_extract = Map.register(_str_map("extract"), dtypes="copy")
str_repeat = Map.register(_str_map("repeat"), dtypes="copy", shape_hint="column")
_str_extract = Map.register(_str_map("extract"), dtypes="copy", shape_hint="column")

def str_extract(self, pat, flags, expand):
regex = re.compile(pat, flags=flags)
Expand Down Expand Up @@ -1970,12 +2001,14 @@ def searchsorted(df):

# END Dt map partitions operations

def astype(self, col_dtypes, errors: str = "raise"):
def astype(self, col_dtypes, errors: str = "raise", shape_hint=None):
# `errors` parameter needs to be part of the function signature because
# other query compilers may not take care of error handling at the API
# layer. This query compiler assumes there won't be any errors due to
# invalid type keys.
return self.__constructor__(self._modin_frame.astype(col_dtypes, errors=errors))
return self.__constructor__(
self._modin_frame.astype(col_dtypes, errors=errors), shape_hint=shape_hint
)

def infer_objects(self):
return self.__constructor__(self._modin_frame.infer_objects())
Expand Down
20 changes: 12 additions & 8 deletions modin/experimental/core/storage_formats/hdk/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,14 +618,18 @@ def dropna(self, axis=0, how=no_default, thresh=no_default, subset=None):
shape_hint=self._shape_hint,
)

def isna(self):
return self.__constructor__(self._modin_frame.isna(invert=False))
def isna(self, shape_hint=None):
return self.__constructor__(
self._modin_frame.isna(invert=False), shape_hint=shape_hint
)

def notna(self):
return self.__constructor__(self._modin_frame.isna(invert=True))
def notna(self, shape_hint=None):
return self.__constructor__(
self._modin_frame.isna(invert=True), shape_hint=shape_hint
)

def invert(self):
return self.__constructor__(self._modin_frame.invert())
def invert(self, shape_hint=None):
return self.__constructor__(self._modin_frame.invert(), shape_hint=shape_hint)

def dt_year(self):
return self.__constructor__(
Expand Down Expand Up @@ -806,15 +810,15 @@ def reset_index(self, **kwargs):
self._modin_frame.reset_index(drop), shape_hint=shape_hint
)

def astype(self, col_dtypes, errors: str = "raise"):
def astype(self, col_dtypes, errors: str = "raise", shape_hint=None):
if errors != "raise":
raise NotImplementedError(
"This lazy query compiler will always "
+ "raise an error on invalid type keys."
)
return self.__constructor__(
self._modin_frame.astype(col_dtypes),
self._shape_hint,
shape_hint=shape_hint or self._shape_hint,
)

def setitem(self, axis, key, value):
Expand Down
Loading

0 comments on commit 4b1efd2

Please sign in to comment.