Skip to content

Commit

Permalink
FEAT-modin-project#7283: Introduce MinRowPartitionSize and MinColumnP…
Browse files Browse the repository at this point in the history
…artitionSize

Signed-off-by: Igoshev, Iaroslav <[email protected]>
  • Loading branch information
YarShev committed May 21, 2024
1 parent 43eff5b commit 53ccafa
Show file tree
Hide file tree
Showing 27 changed files with 228 additions and 81 deletions.
4 changes: 2 additions & 2 deletions docs/img/partitioning_mechanism/partitioning_examples.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
12 changes: 6 additions & 6 deletions docs/usage_guide/benchmarking.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ Consider the following ipython script:
.. code-block:: python
import modin.pandas as pd
from modin.config import MinPartitionSize
from modin.config import MinRowPartitionSize
import time
import ray
# Look at the Ray documentation with respect to the Ray configuration suited to you most.
ray.init()
df = pd.DataFrame(list(range(MinPartitionSize.get() * 2)))
df = pd.DataFrame(list(range(MinRowPartitionSize.get() * 2)))
%time result = df.map(lambda x: time.sleep(0.1) or x)
%time print(result)
Expand Down Expand Up @@ -82,12 +82,12 @@ following script with benchmark mode on:
from io import BytesIO
import modin.pandas as pd
from modin.config import BenchmarkMode, MinPartitionSize
from modin.config import BenchmarkMode, MinRowPartitionSize
BenchmarkMode.put(True)
start = time.time()
df = pd.DataFrame(list(range(MinPartitionSize.get())), columns=['A'])
df = pd.DataFrame(list(range(MinRowPartitionSize.get())), columns=['A'])
result1 = df.map(lambda x: time.sleep(0.2) or x + 1)
result2 = df.map(lambda x: time.sleep(0.2) or x + 2)
result1.to_parquet(BytesIO())
Expand Down Expand Up @@ -136,10 +136,10 @@ That will typically block on any asynchronous computation:
from io import BytesIO
import modin.pandas as pd
from modin.config import MinPartitionSize, NPartitions
from modin.config import MinRowPartitionSize, NPartitions
import modin.utils
MinPartitionSize.put(32)
MinRowPartitionSize.put(32)
NPartitions.put(16)
def slow_add_one(x):
Expand Down
7 changes: 5 additions & 2 deletions docs/usage_guide/optimization_notes/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,14 @@ How Modin partitions a dataframe
Modin uses a partitioning scheme that partitions a dataframe along both axes, resulting in a matrix
of partitions. The row and column chunk sizes are computed independently based
on the length of the appropriate axis and Modin's special :doc:`configuration variables </flow/modin/config>`
(``NPartitions`` and ``MinPartitionSize``):
(``NPartitions``, ``MinRowPartitionSize`` and ``MinColumnPartitionSize``):

- ``NPartitions`` is the maximum number of splits along an axis; by default, it equals to the number of cores
on your local machine or cluster of nodes.
- ``MinPartitionSize`` is the minimum number of rows/columns to do a split. For instance, if ``MinPartitionSize``
- ``MinRowPartitionSize`` is the minimum number of rows to do a split. For instance, if ``MinRowPartitionSize``
is 32, the row axis will not be split unless the amount of rows is greater than 32. If it is is greater, for example, 34,
then the row axis is sliced into two partitions: containing 32 and 2 rows accordingly.
- ``MinColumnPartitionSize`` is the minimum number of columns to do a split. For instance, if ``MinColumnPartitionSize``
is 32, the column axis will not be split unless the amount of columns is greater than 32. If it is is greater, for example, 34,
then the column axis is sliced into two partitions: containing 32 and 2 columns accordingly.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ by reading its source code:
else:
raise NotImplementedError(dtype)
pd.DataFrame(np.arange(cfg.NPartitions.get() * cfg.MinPartitionSize.get())).to_numpy()
pd.DataFrame(np.arange(cfg.NPartitions.get() * cfg.MinRowPartitionSize.get())).to_numpy()
nrows = [1_000_000, 5_000_000, 10_000_000, 25_000_000, 50_000_000, 100_000_000]
duplicate_rate = [0, 0.1, 0.5, 0.95]
Expand Down Expand Up @@ -192,7 +192,7 @@ micro-benchmark by reading its source code:
cfg.CpuCount.put(16)
pd.DataFrame(np.arange(cfg.NPartitions.get() * cfg.MinPartitionSize.get())).to_numpy()
pd.DataFrame(np.arange(cfg.NPartitions.get() * cfg.MinRowPartitionSize.get())).to_numpy()
nrows = [1_000_000, 5_000_000, 10_000_000, 25_000_000]
duplicate_rate = [0, 0.1, 0.5, 0.95]
Expand Down Expand Up @@ -312,7 +312,7 @@ by reading its source code:
else:
raise NotImplementedError(dtype)
pd.DataFrame(np.arange(cfg.NPartitions.get() * cfg.MinPartitionSize.get())).to_numpy()
pd.DataFrame(np.arange(cfg.NPartitions.get() * cfg.MinRowPartitionSize.get())).to_numpy()
nrows = [1_000_000, 5_000_000, 10_000_000, 25_000_000, 50_000_000, 100_000_000]
duplicate_rate = [0, 0.1, 0.5, 0.95]
Expand Down
4 changes: 4 additions & 0 deletions modin/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@
LogMemoryInterval,
LogMode,
Memory,
MinColumnPartitionSize,
MinPartitionSize,
MinRowPartitionSize,
ModinNumpy,
NPartitions,
PersistentPickle,
Expand Down Expand Up @@ -82,6 +84,8 @@
# Partitioning
"NPartitions",
"MinPartitionSize",
"MinRowPartitionSize",
"MinColumnPartitionSize",
# ASV specific
"TestDatasetSize",
"AsvImplementation",
Expand Down
97 changes: 97 additions & 0 deletions modin/config/envvars.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,8 @@ def put(cls, value: int) -> None:
if value <= 0:
raise ValueError(f"Min partition size should be > 0, passed value {value}")
super().put(value)
MinRowPartitionSize.put(value)
MinColumnPartitionSize.put(value)

@classmethod
def get(cls) -> int:
Expand All @@ -654,6 +656,13 @@ def get(cls) -> int:
-------
int
"""
from modin.error_message import ErrorMessage

ErrorMessage.single_warning(
"`MinPartitionSize` is deprecated and will be removed in a future version. "
+ "Use `MinRowPartitionSize` and `MinColumnPartitionSize` instead.",
FutureWarning,
)
min_partition_size = super().get()
if min_partition_size <= 0:
raise ValueError(
Expand All @@ -662,6 +671,94 @@ def get(cls) -> int:
return min_partition_size


class MinRowPartitionSize(EnvironmentVariable, type=int):
"""
Minimum number of rows in a single pandas partition split.
Once a partition for a pandas dataframe has more than this many elements,
Modin adds another partition.
"""

varname = "MODIN_MIN_ROW_PARTITION_SIZE"
default = 32

@classmethod
def put(cls, value: int) -> None:
"""
Set ``MinRowPartitionSize`` with extra checks.
Parameters
----------
value : int
Config value to set.
"""
if value <= 0:
raise ValueError(
f"Min row partition size should be > 0, passed value {value}"
)
super().put(value)

@classmethod
def get(cls) -> int:
"""
Get ``MinRowPartitionSize`` with extra checks.
Returns
-------
int
"""
min_row_partition_size = super().get()
if min_row_partition_size <= 0:
raise ValueError(
f"`MinRowPartitionSize` should be > 0; current value: {min_row_partition_size}"
)
return min_row_partition_size


class MinColumnPartitionSize(EnvironmentVariable, type=int):
"""
Minimum number of columns in a single pandas partition split.
Once a partition for a pandas dataframe has more than this many elements,
Modin adds another partition.
"""

varname = "MODIN_MIN_COLUMN_PARTITION_SIZE"
default = 32

@classmethod
def put(cls, value: int) -> None:
"""
Set ``MinColumnPartitionSize`` with extra checks.
Parameters
----------
value : int
Config value to set.
"""
if value <= 0:
raise ValueError(
f"Min column partition size should be > 0, passed value {value}"
)
super().put(value)

@classmethod
def get(cls) -> int:
"""
Get ``MinColumnPartitionSize`` with extra checks.
Returns
-------
int
"""
min_column_partition_size = super().get()
if min_column_partition_size <= 0:
raise ValueError(
f"`MinColumnPartitionSize` should be > 0; current value: {min_column_partition_size}"
)
return min_column_partition_size


class TestReadFromSqlServer(EnvironmentVariable, type=bool):
"""Set to true to test reading from SQL server."""

Expand Down
10 changes: 5 additions & 5 deletions modin/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def _saving_make_api_url(token, _make_api_url=modin.utils._make_api_url):
BenchmarkMode,
GithubCI,
IsExperimental,
MinPartitionSize,
MinRowPartitionSize,
NPartitions,
)
from modin.core.execution.dispatching.factories import factories # noqa: E402
Expand Down Expand Up @@ -487,11 +487,11 @@ def set_async_read_mode(request):


@pytest.fixture
def set_min_partition_size(request):
old_min_partition_size = MinPartitionSize.get()
MinPartitionSize.put(request.param)
def set_min_row_partition_size(request):
old_min_row_partition_size = MinRowPartitionSize.get()
MinRowPartitionSize.put(request.param)
yield
MinPartitionSize.put(old_min_partition_size)
MinRowPartitionSize.put(old_min_row_partition_size)


ray_client_server = None
Expand Down
22 changes: 14 additions & 8 deletions modin/core/dataframe/pandas/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,13 @@
from pandas.core.dtypes.common import is_dtype_equal, is_list_like, is_numeric_dtype
from pandas.core.indexes.api import Index, RangeIndex

from modin.config import Engine, IsRayCluster, MinPartitionSize, NPartitions
from modin.config import (
Engine,
IsRayCluster,
MinColumnPartitionSize,
MinRowPartitionSize,
NPartitions,
)
from modin.core.dataframe.base.dataframe.dataframe import ModinDataframe
from modin.core.dataframe.base.dataframe.utils import Axis, JoinType, is_trivial_index
from modin.core.dataframe.pandas.dataframe.utils import (
Expand Down Expand Up @@ -1592,7 +1598,7 @@ def _reorder_labels(self, row_positions=None, col_positions=None):
new_lengths = get_length_list(
axis_len=len(row_idx),
num_splits=ordered_rows.shape[0],
min_block_size=MinPartitionSize.get(),
min_block_size=MinRowPartitionSize.get(),
)
else:
# If the frame's partitioning was preserved then
Expand Down Expand Up @@ -1630,7 +1636,7 @@ def _reorder_labels(self, row_positions=None, col_positions=None):
new_widths = get_length_list(
axis_len=len(col_idx),
num_splits=ordered_cols.shape[1],
min_block_size=MinPartitionSize.get(),
min_block_size=MinColumnPartitionSize.get(),
)
else:
# If the frame's partitioning was preserved then
Expand Down Expand Up @@ -2635,10 +2641,10 @@ def _apply_func_to_range_partitioning(
# algorithm how many partitions we want to end up with, so it samples and finds pivots
# according to that.
if sampling_probability >= 1:
from modin.config import MinPartitionSize
from modin.config import MinRowPartitionSize

ideal_num_new_partitions = round(len(grouper) / MinPartitionSize.get())
if len(grouper) < MinPartitionSize.get() or ideal_num_new_partitions < 2:
ideal_num_new_partitions = round(len(grouper) / MinRowPartitionSize.get())
if len(grouper) < MinRowPartitionSize.get() or ideal_num_new_partitions < 2:
# If the data is too small, we shouldn't try reshuffling/repartitioning but rather
# simply combine all partitions and apply the sorting to the whole dataframe
return grouper.combine_and_apply(func=func)
Expand Down Expand Up @@ -3582,7 +3588,7 @@ def broadcast_apply_full_axis(
kw["row_lengths"] = get_length_list(
axis_len=len(new_index),
num_splits=new_partitions.shape[0],
min_block_size=MinPartitionSize.get(),
min_block_size=MinRowPartitionSize.get(),
)
elif axis == 1:
if self._row_lengths_cache is not None and len(new_index) == sum(
Expand All @@ -3594,7 +3600,7 @@ def broadcast_apply_full_axis(
kw["column_widths"] = get_length_list(
axis_len=len(new_columns),
num_splits=new_partitions.shape[1],
min_block_size=MinPartitionSize.get(),
min_block_size=MinColumnPartitionSize.get(),
)
elif axis == 0:
if self._column_widths_cache is not None and len(
Expand Down
14 changes: 11 additions & 3 deletions modin/core/dataframe/pandas/partitioning/axis_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import numpy as np
import pandas

from modin.config import MinPartitionSize
from modin.config import MinColumnPartitionSize, MinRowPartitionSize
from modin.core.dataframe.base.partitioning.axis_partition import (
BaseDataframeAxisPartition,
)
Expand Down Expand Up @@ -277,7 +277,11 @@ def apply(
for part in axis_partition.list_of_blocks
]
),
min_block_size=MinPartitionSize.get(),
min_block_size=(
MinRowPartitionSize.get()
if self.axis == 0
else MinColumnPartitionSize.get()
),
)
)
result = self._wrap_partitions(
Expand All @@ -289,7 +293,11 @@ def apply(
num_splits,
maintain_partitioning,
*self.list_of_blocks,
min_block_size=MinPartitionSize.get(),
min_block_size=(
MinRowPartitionSize.get()
if self.axis == 0
else MinColumnPartitionSize.get()
),
lengths=lengths,
manual_partition=manual_partition,
)
Expand Down
12 changes: 8 additions & 4 deletions modin/core/dataframe/pandas/partitioning/partition_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
BenchmarkMode,
CpuCount,
Engine,
MinPartitionSize,
MinColumnPartitionSize,
MinRowPartitionSize,
NPartitions,
PersistentPickle,
ProgressBar,
Expand Down Expand Up @@ -1024,9 +1025,12 @@ def from_pandas(cls, df, return_dims=False):
A NumPy array with partitions (with dimensions or not).
"""
num_splits = NPartitions.get()
min_block_size = MinPartitionSize.get()
row_chunksize = compute_chunksize(df.shape[0], num_splits, min_block_size)
col_chunksize = compute_chunksize(df.shape[1], num_splits, min_block_size)
min_row_block_size = MinRowPartitionSize.get()
min_column_block_size = MinColumnPartitionSize.get()
row_chunksize = compute_chunksize(df.shape[0], num_splits, min_row_block_size)
col_chunksize = compute_chunksize(
df.shape[1], num_splits, min_column_block_size
)

bar_format = (
"{l_bar}{bar}{r_bar}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import numpy as np
import ray

from modin.config import GpuCount, MinPartitionSize
from modin.config import GpuCount, MinRowPartitionSize
from modin.core.execution.ray.common import RayWrapper
from modin.core.execution.ray.generic.partitioning import (
GenericRayDataframePartitionManager,
Expand Down Expand Up @@ -126,7 +126,7 @@ def from_pandas(cls, df, return_dims=False):
put_func = cls._partition_class.put
# For now, we default to row partitioning
pandas_dfs = split_result_of_axis_func_pandas(
0, num_splits, df, min_block_size=MinPartitionSize.get()
0, num_splits, df, min_block_size=MinRowPartitionSize.get()
)
keys = [
put_func(cls._get_gpu_managers()[i], pandas_dfs[i])
Expand Down
Loading

0 comments on commit 53ccafa

Please sign in to comment.