Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 122 additions & 2 deletions torchrec/distributed/test_utils/multi_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import os
import unittest
from typing import Any, Callable, Dict, List, Optional
from unittest.mock import patch

import torch
import torch.distributed as dist
Expand All @@ -25,6 +26,65 @@
)


class MultiProcessMock:
"""
Manages cross-process mocks for multi-process testing.

This class maintains a collection of mocks that can be applied across
different processes in distributed testing scenarios.
"""

def __init__(self) -> None:
self.mocks: List[Dict[str, Any]] = []

def add_mock(
self,
target: str,
return_value: Any = None,
side_effect: Any = None,
**kwargs: Any,
) -> None:
"""
Add a new cross-process mock.

Args:
target: The target to mock (e.g., 'module.function')
return_value: The return value for the mock
side_effect: The side effect for the mock
**kwargs: Additional arguments to pass to the mock
"""
mock_config = {
"target": target,
"return_value": return_value,
"side_effect": side_effect,
**kwargs,
}
self.mocks.append(mock_config)

def apply_mocks(self) -> List[Any]:
"""
Apply all registered mocks and return context managers.

Returns:
List of active mock context managers
"""
active_patches = []
for mock_config in self.mocks:
target = mock_config["target"]
return_value = mock_config.get("return_value")
side_effect = mock_config.get("side_effect")

patcher = patch(target, return_value=return_value, side_effect=side_effect)
active_patch = patcher.__enter__()
active_patches.append((patcher, active_patch))

return active_patches

def clear_mocks(self) -> None:
"""Clear all registered mocks."""
self.mocks.clear()


class MultiProcessContext:
def __init__(
self,
Expand Down Expand Up @@ -111,6 +171,32 @@ def __init__(
self._mp_init_mode: str = mp_init_mode
logging.info(f"Using {self._mp_init_mode} for multiprocessing")

# Initialize MultiProcessMock
self._mock_manager = MultiProcessMock()

def add_mock(
self,
target: str,
return_value: Any = None,
side_effect: Any = None,
**kwargs: Any,
) -> None:
"""
Add a new cross-process mock that will be applied during test execution.

Args:
target: The target to mock (e.g., 'module.function')
return_value: The return value for the mock
side_effect: The side effect for the mock
**kwargs: Additional arguments to pass to the mock
"""
self._mock_manager.add_mock(
target=target,
return_value=return_value,
side_effect=side_effect,
**kwargs,
)

@seed_and_log
def setUp(self) -> None:
os.environ["MASTER_ADDR"] = str("localhost")
Expand Down Expand Up @@ -149,8 +235,10 @@ def _run_multi_process_test(
for rank in range(world_size):
kwargs["rank"] = rank
kwargs["world_size"] = world_size
kwargs["_mock_manager"] = self._mock_manager
p = ctx.Process(
target=callable,
target=self._callable_wrapper_with_mocks,
args=(callable,),
kwargs=kwargs,
)
p.start()
Expand All @@ -176,9 +264,11 @@ def _run_multi_process_test_per_rank(
kwargs = {}
kwargs["rank"] = rank
kwargs["world_size"] = world_size
kwargs["_mock_manager"] = self._mock_manager
kwargs.update(kwargs_per_rank[rank])
p = ctx.Process(
target=callable,
target=self._callable_wrapper_with_mocks,
args=(callable,),
kwargs=kwargs,
)
p.start()
Expand All @@ -188,6 +278,36 @@ def _run_multi_process_test_per_rank(
p.join()
self.assertEqual(0, p.exitcode)

@staticmethod
def _callable_wrapper_with_mocks(
callable: Callable[..., None],
_mock_manager: Optional[MultiProcessMock] = None,
**kwargs: Any,
) -> None:
"""
Wrapper that applies mocks before calling the target callable.

Args:
callable: The function to call
_mock_manager: Optional mock manager containing mocks to apply
**kwargs: Additional keyword arguments to pass to the callable
"""
active_patches = []
try:
# Apply mocks if a mock manager is provided
if _mock_manager is not None:
active_patches = _mock_manager.apply_mocks()

# Remove _mock_manager from kwargs before calling the target
kwargs.pop("_mock_manager", None)

# Call the actual test callable
callable(**kwargs)
finally:
# Clean up all patches
for patcher, _ in active_patches:
patcher.__exit__(None, None, None)


def _wrapper_func_for_multiprocessing(args): # pyre-ignore[2, 3]
"""Wrapper function that unpacks arguments and calls the original func"""
Expand Down
Loading
Loading