8 changes: 7 additions & 1 deletion python/ray/air/_internal/torch_utils.py
@@ -418,7 +418,7 @@ def concat_tensors_to_device(

Args:
tensor_sequence: Sequence of tensors to stack
device: The device to move tensors to
device: The device to move tensors to. If None, tensors are not moved.
non_blocking: If True, perform device transfer without forcing a
synchronization.

@@ -436,6 +436,12 @@ def concat_tensors_to_device(
[type(t) for t in tensor_sequence if not isinstance(t, torch.Tensor)]
)

# If there is only one tensor and its device already matches, return it directly.
if len(tensor_sequence) == 1 and (
device is None or tensor_sequence[0].device == torch.device(device)
):
return tensor_sequence[0]

first_dtype = tensor_sequence[0].dtype
assert all(t.dtype == first_dtype for t in tensor_sequence), (
"All tensors must have the same dtype. "
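
A minimal, torch-only sketch of the fast path added above (the helper name is hypothetical, not Ray API): with a single tensor already on the requested device, the input is returned as-is instead of being concatenated and copied.

import torch

def _maybe_short_circuit(tensors, device=None):
    # Mirrors the new check: exactly one tensor, and either no target device or a matching one.
    if len(tensors) == 1 and (device is None or tensors[0].device == torch.device(device)):
        return tensors[0]
    return torch.cat([t.to(device) if device is not None else t for t in tensors])

t = torch.arange(6).reshape(2, 3)
assert _maybe_short_circuit([t], device="cpu") is t     # returned unchanged, zero copy
assert _maybe_short_circuit([t, t]).shape == (4, 3)     # falls back to concatenation
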
@@ -322,6 +322,7 @@ def __init__(
batch_format: Optional[BatchFormat] = None,
zero_copy_batch: bool = True,
output_block_size_option: Optional[OutputBlockSizeOption] = None,
disable_block_shaping: bool = False,
):
super().__init__(
input_type=MapTransformFnDataType.Batch,
@@ -333,10 +334,11 @@ def __init__(
self._batch_format = batch_format
self._zero_copy_batch = zero_copy_batch
self._ensure_copy = not zero_copy_batch and batch_size is not None
self._disable_block_shaping = disable_block_shaping

self._batch_fn = batch_fn

def _pre_process(self, blocks: Iterable[Block]) -> Iterable[MapTransformFnData]:
def _pre_process(self, blocks: Iterable[Block]) -> Iterable[DataBatch]:
# TODO make batch-udf zero-copy by default
ensure_copy = not self._zero_copy_batch and self._batch_size is not None

@@ -349,12 +351,16 @@ def _pre_process(self, blocks: Iterable[Block]) -> Iterable[MapTransformFnData]:
)

def _apply_transform(
self, ctx: TaskContext, batches: Iterable[MapTransformFnData]
) -> Iterable[MapTransformFnData]:
self, ctx: TaskContext, batches: Iterable[DataBatch]
) -> Iterable[DataBatch]:
yield from self._batch_fn(batches, ctx)

def _post_process(self, results: Iterable[MapTransformFnData]) -> Iterable[Block]:
return self._shape_blocks(results)
def _post_process(self, results: Iterable[DataBatch]) -> Iterable[Block]:
if self._disable_block_shaping:
for batch in results:
yield BlockAccessor.batch_to_block(batch)
else:
yield from self._shape_blocks(results)

def __repr__(self) -> str:
return f"BatchMapTransformFn({self._batch_fn=}, {self._batch_format=}, {self._batch_size=}, {self._zero_copy_batch=})"
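
A small sketch of the new `disable_block_shaping` path, assuming a working Ray installation: when shaping is disabled, each batch emitted by the UDF is converted into exactly one block via `BlockAccessor.batch_to_block`, with no further splitting or coalescing.

import numpy as np
from ray.data.block import BlockAccessor

batch = {"id": np.arange(4)}
block = BlockAccessor.batch_to_block(batch)          # one output block per batch
print(BlockAccessor.for_block(block).num_rows())     # 4
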
2 changes: 2 additions & 0 deletions python/ray/data/_internal/logical/operators/map_operator.py
@@ -389,12 +389,14 @@ def __init__(
self,
input_op: LogicalOperator,
target_num_rows_per_block: int,
supports_fusion: bool = True,
):
super().__init__(
f"StreamingRepartition[num_rows_per_block={target_num_rows_per_block}]",
input_op,
)
self._target_num_rows_per_block = target_num_rows_per_block
self._supports_fusion = supports_fusion

@property
def target_num_rows_per_block(self) -> int:
122 changes: 111 additions & 11 deletions python/ray/data/_internal/logical/rules/operator_fusion.py
@@ -21,7 +21,11 @@
from ray.data._internal.execution.operators.base_physical_operator import (
AllToAllOperator,
)
from ray.data._internal.execution.operators.map_operator import MapOperator
from ray.data._internal.execution.operators.map_operator import (
BaseRefBundler,
BlockRefBundler,
MapOperator,
)
from ray.data._internal.execution.operators.task_pool_map_operator import (
TaskPoolMapOperator,
)
@@ -35,6 +39,7 @@
AbstractMap,
AbstractUDFMap,
)
from ray.data._internal.streaming_repartition import StreamingRepartitionRefBundler
from ray.util.annotations import DeveloperAPI

# Scheduling strategy can be inherited from upstream operator if not specified.
@@ -146,6 +151,16 @@ def _can_fuse(self, down_op: PhysicalOperator, up_op: PhysicalOperator) -> bool:
if not up_op.supports_fusion() or not down_op.supports_fusion():
return False

if (
isinstance(down_op, MapOperator)
and isinstance(up_op, MapOperator)
and self._get_compatible_ref_bundler(
up_op._block_ref_bundler, down_op._block_ref_bundler
)
is None
):
return False

# We currently only support fusing for the following cases:
# - TaskPoolMapOperator -> TaskPoolMapOperator/ActorPoolMapOperator
# - TaskPoolMapOperator -> AllToAllOperator
@@ -288,9 +303,10 @@ def _get_fused_map_operator(
assert isinstance(up_logical_op, AbstractMap)

# Derive min num rows per input bundle
min_rows_per_bundled_input = self._derive_bundle_min_num_rows(
down_logical_op, up_logical_op
ref_bundler = self._get_compatible_ref_bundler(
up_op._block_ref_bundler, down_op._block_ref_bundler
)
assert ref_bundler is not None

target_max_block_size = self._get_merged_target_max_block_size(
up_op.target_max_block_size_override, down_op.target_max_block_size_override
@@ -322,7 +338,7 @@
target_max_block_size_override=target_max_block_size,
name=name,
compute_strategy=compute,
min_rows_per_bundle=min_rows_per_bundled_input,
ref_bundler=ref_bundler,
map_task_kwargs=map_task_kwargs,
ray_remote_args=ray_remote_args,
ray_remote_args_fn=ray_remote_args_fn,
@@ -333,6 +349,11 @@
):
op.add_map_task_kwargs_fn(map_task_kwargs_fn)

min_rows_per_bundled_input = (
ref_bundler._min_rows_per_bundle
if isinstance(ref_bundler, BlockRefBundler)
else ref_bundler._target_num_rows
)
# Build a map logical operator to be used as a reference for further fusion.
# TODO(Scott): This is hacky, remove this once we push fusion to be purely based
# on a lower-level operator spec.
@@ -371,22 +392,23 @@
@classmethod
def _derive_bundle_min_num_rows(
cls,
down_logical_op: AbstractMap,
up_logical_op: AbstractMap,
min_rows_per_bundled_input_up: Optional[int],
min_rows_per_bundled_input_down: Optional[int],
) -> Optional[int]:
us_bundle_min_rows_req = up_logical_op._min_rows_per_bundled_input
ds_bundle_min_rows_req = down_logical_op._min_rows_per_bundled_input

# In case neither of the ops specify `min_rows_per_bundled_input`,
# return None
if us_bundle_min_rows_req is None and ds_bundle_min_rows_req is None:
if (
min_rows_per_bundled_input_up is None
and min_rows_per_bundled_input_down is None
):
return None

# Target min bundle size is selected as max of upstream and downstream ones
# such that it could satisfy both of their requirements
return max(
ds_bundle_min_rows_req or 0,
us_bundle_min_rows_req or 0,
min_rows_per_bundled_input_down or 0,
min_rows_per_bundled_input_up or 0,
)

def _get_fused_all_to_all_operator(
@@ -505,6 +527,84 @@ def _can_fuse_map_ops(

return True

@classmethod
def _get_compatible_ref_bundler(
cls, up_ref_bundler: BaseRefBundler, down_ref_bundler: BaseRefBundler
) -> Optional[BaseRefBundler]:
"""Determine if two ref bundlers are compatible for operator fusion.

This method checks whether upstream and downstream operators with different
ref bundler configurations can be fused together, and returns the merged
bundler if fusion is possible.

Args:
up_ref_bundler: The ref bundler from the upstream operator.
down_ref_bundler: The ref bundler from the downstream operator.

Returns:
A compatible merged BaseRefBundler if the bundlers can be fused together,
None if they are incompatible and fusion should not occur.

Fusion rules:
- BlockRefBundler + BlockRefBundler: Always compatible. Returns a
BlockRefBundler with min_rows_per_bundle set to the maximum of the two.
- StreamingRepartitionRefBundler + StreamingRepartitionRefBundler: Only
compatible if both have the same target_num_rows_per_block.
- BlockRefBundler + StreamingRepartitionRefBundler (mixed): Compatible
if target_num_rows >= min_rows_per_bundle (or min_rows_per_bundle is None).
Returns a StreamingRepartitionRefBundler.

Note:
This is a naive implementation that should be revisited once
StreamingRepartitionRefBundler is more widely used. There are potential
optimizations such as using the least common multiple for
StreamingRepartitionRefBundlers with different target row counts, or
finding a least multiple that satisfies both min_rows_per_bundle and
target_num_rows_per_block constraints in mixed cases.
"""
if isinstance(up_ref_bundler, BlockRefBundler) and isinstance(
down_ref_bundler, BlockRefBundler
):
return BlockRefBundler(
min_rows_per_bundle=cls._derive_bundle_min_num_rows(
up_ref_bundler._min_rows_per_bundle,
down_ref_bundler._min_rows_per_bundle,
)
)
elif isinstance(up_ref_bundler, StreamingRepartitionRefBundler) and isinstance(
down_ref_bundler, StreamingRepartitionRefBundler
):
if up_ref_bundler._target_num_rows == down_ref_bundler._target_num_rows:
return StreamingRepartitionRefBundler(
target_num_rows_per_block=up_ref_bundler._target_num_rows
)
else:
# TODO(xgui): Explore if we can use least common multiple of the two target_num_rows_per_block
return None
else:
supported_types = (BlockRefBundler, StreamingRepartitionRefBundler)
assert isinstance(up_ref_bundler, supported_types) and isinstance(
down_ref_bundler, supported_types
)
target_num_rows = (
up_ref_bundler._target_num_rows
if isinstance(up_ref_bundler, StreamingRepartitionRefBundler)
else down_ref_bundler._target_num_rows
)
min_rows_per_bundle = (
up_ref_bundler._min_rows_per_bundle
if isinstance(up_ref_bundler, BlockRefBundler)
else down_ref_bundler._min_rows_per_bundle
)
if min_rows_per_bundle is None or target_num_rows >= min_rows_per_bundle:
return StreamingRepartitionRefBundler(
target_num_rows_per_block=target_num_rows
)
else:
# TODO(xgui): Explore if we can use least multiple of target_num_rows_per_block that is greater than min_rows_per_bundle
return None
return None


@DeveloperAPI
def are_remote_args_compatible(
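
A worked example of the merge rules documented in `_get_compatible_ref_bundler` above. This is a standalone model, not Ray code: each bundler is represented as ("block", min_rows_per_bundle) or ("repartition", target_num_rows_per_block).

def merge(up, down):
    # Block + block: take the max of the min-rows requirements (None if both are None).
    if up[0] == "block" and down[0] == "block":
        mins = [m for _, m in (up, down) if m is not None]
        return ("block", max(mins) if mins else None)
    # Repartition + repartition: only compatible when the targets agree.
    if up[0] == "repartition" and down[0] == "repartition":
        return up if up[1] == down[1] else None
    # Mixed: compatible when the repartition target also satisfies the block min-rows bound.
    target = up[1] if up[0] == "repartition" else down[1]
    min_rows = up[1] if up[0] == "block" else down[1]
    return ("repartition", target) if min_rows is None or target >= min_rows else None

assert merge(("block", 100), ("block", None)) == ("block", 100)
assert merge(("repartition", 128), ("repartition", 128)) == ("repartition", 128)
assert merge(("repartition", 128), ("repartition", 64)) is None
assert merge(("block", 100), ("repartition", 128)) == ("repartition", 128)  # 128 >= 100
assert merge(("block", 256), ("repartition", 128)) is None                  # 128 < 256
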
11 changes: 6 additions & 5 deletions python/ray/data/_internal/planner/plan_udf_map_op.py
@@ -156,11 +156,12 @@ def plan_streaming_repartition_op(
assert len(physical_children) == 1
input_physical_dag = physical_children[0]
compute = get_compute(op._compute)
transform_fn = BlockMapTransformFn(
lambda blocks, ctx: blocks,
output_block_size_option=OutputBlockSizeOption.of(
target_num_rows_per_block=op.target_num_rows_per_block, # To split n*target_max_block_size row into n blocks
transform_fn = BatchMapTransformFn(
_generate_transform_fn_for_map_batches(
lambda blocks: blocks,
),
batch_size=op.target_num_rows_per_block,
disable_block_shaping=True,
)
map_transformer = MapTransformer([transform_fn])

@@ -174,7 +175,7 @@
ref_bundler=StreamingRepartitionRefBundler(op.target_num_rows_per_block),
ray_remote_args=op._ray_remote_args,
ray_remote_args_fn=op._ray_remote_args_fn,
supports_fusion=False,
supports_fusion=op._supports_fusion,
)

return operator
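
A pure-Python sketch (not Ray internals) of what the planner now relies on: an identity batch UDF behind a batcher that re-slices incoming blocks into batches of exactly `target_num_rows_per_block` rows, each of which becomes one output block because block shaping is disabled.

from typing import Iterable, Iterator, List

def rebatch(blocks: Iterable[List[int]], target: int) -> Iterator[List[int]]:
    buf: List[int] = []
    for block in blocks:
        buf.extend(block)
        while len(buf) >= target:
            yield buf[:target]     # full batch -> one output block
            buf = buf[target:]
    if buf:
        yield buf                  # trailing remainder block

assert [len(b) for b in rebatch([[1] * 7, [1] * 5], target=4)] == [4, 4, 4]
assert [len(b) for b in rebatch([[1] * 10], target=4)] == [4, 4, 2]
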
7 changes: 1 addition & 6 deletions python/ray/data/_internal/streaming_repartition.py
@@ -18,12 +18,7 @@
and add the remaining bundle to the pending bundles for the next iteration.
4. Submit that ready bundle to a remote map task; the task slices each block according to the slice metadata stored
in the RefBundle (the bundle now contains n × target rows for n ≥ 1).
5. We configured the `OutputBlockSizeOption.target_num_rows_per_block` to the target number of rows per block in
plan_streaming_repartition_op so the output buffer further splits the n × target rows into n blocks of exactly
the target size.
Note: the output buffer only splits a bundle when its row count exceeds `target_rows × MAX_SAFE_ROWS_PER_BLOCK_FACTOR`
(default 1.5). Because we split bundles into target-row blocks, `MAX_SAFE_ROWS_PER_BLOCK_FACTOR` must stay < 2 to
output the target-row blocks.
5. We create a single BatchMapTransformFn whose batcher emits batches of exactly the target number of rows; with block shaping disabled, each batch becomes one output block.
6. Once upstream input is exhausted, flush any leftover pending bundles and repeat steps 1‑5 for the tail.
7. The resulting blocks have lengths `[target, …, target, (total_rows % target)]`; ordering isn’t guaranteed, but the
remainder block should appear near the end.
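
A small end-to-end illustration of steps 1-7, assuming a Ray build that includes this change: 1000 rows repartitioned with a target of 300 rows per block should yield three full blocks plus a 100-row remainder.

import ray

ds = ray.data.range(1000).repartition(target_num_rows_per_block=300).materialize()
# Per step 7: block lengths [300, 300, 300, 100] -> 4 blocks in total.
print(ds.num_blocks())  # expected: 4
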
2 changes: 2 additions & 0 deletions python/ray/data/dataset.py
@@ -1595,6 +1595,7 @@ def repartition(
shuffle: bool = False,
keys: Optional[List[str]] = None,
sort: bool = False,
supports_fusion: bool = True,
) -> "Dataset":
"""Repartition the :class:`Dataset` into exactly this number of
:ref:`blocks <dataset_concept>`.
@@ -1699,6 +1700,7 @@
op = StreamingRepartition(
self._logical_plan.dag,
target_num_rows_per_block=target_num_rows_per_block,
supports_fusion=supports_fusion,
)
else:
op = Repartition(
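
A hedged usage sketch of the new flag (the parameter name is taken from this diff; the described effect is my reading of it): passing `supports_fusion=False` keeps the streaming repartition as a standalone operator rather than letting it fuse with the neighboring map.

import ray

ds = (
    ray.data.range(10_000)
    .repartition(target_num_rows_per_block=512, supports_fusion=False)
    .map_batches(lambda batch: batch)  # stays a separate operator from the repartition
)
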