Commit 0bc8ed6

[Test] Fix flaky tests
1 parent 3d5dd1a commit 0bc8ed6

5 files changed: +60 -38 lines changed

test/services/test_python_executor_service.py

Lines changed: 9 additions & 7 deletions

@@ -73,7 +73,7 @@ def test_service_execution(self, ray_init):
 result = x + y
 print(f"Result: {result}")
 """
-        result = ray.get(executor.execute.remote(code), timeout=2)
+        result = ray.get(executor.execute.remote(code), timeout=10)

         assert result["success"] is True
         assert "Result: 30" in result["stdout"]

@@ -101,7 +101,7 @@ def test_service_execution_error(self, ray_init):

         # Execute code with an error
         code = "raise ValueError('Test error')"
-        result = ray.get(executor.execute.remote(code), timeout=2)
+        result = ray.get(executor.execute.remote(code), timeout=10)

         assert result["success"] is False
         assert "ValueError: Test error" in result["stderr"]

@@ -119,7 +119,7 @@ def test_multiple_executions(self, ray_init):
             "python_executor",
             PythonExecutorService,
             pool_size=4,
-            timeout=5.0,
+            timeout=10.0,
             num_cpus=4,
             max_concurrency=4,
         )

@@ -132,14 +132,16 @@ def test_multiple_executions(self, ray_init):
                 code = f"print('Execution {i}')"
                 futures.append(executor.execute.remote(code))

-            # Wait for all to complete
-            results = ray.get(futures, timeout=5)
+            # Wait for all to complete with longer timeout
+            results = ray.get(futures, timeout=30)

             # All should succeed
             assert len(results) == 8
             for i, result in enumerate(results):
-                assert result["success"] is True
-                assert f"Execution {i}" in result["stdout"]
+                assert result["success"] is True, f"Execution {i} failed: {result}"
+                assert (
+                    f"Execution {i}" in result["stdout"]
+                ), f"Expected 'Execution {i}' in stdout, got: {result['stdout']!r}"

         finally:
             services.reset()
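
What makes the original timeouts flaky is that ray.get(ref, timeout=...) raises ray.exceptions.GetTimeoutError as soon as the deadline passes, so a tight 2 s budget fails whenever a loaded CI worker is slow to schedule or run the actor call. A minimal sketch of that behavior, using a hypothetical SlowActor purely for illustration:

import time

import ray
from ray.exceptions import GetTimeoutError

ray.init(ignore_reinit_error=True)


@ray.remote
class SlowActor:
    """Hypothetical actor that simulates a busy executor on a loaded CI node."""

    def work(self, seconds: float) -> str:
        time.sleep(seconds)
        return "done"


actor = SlowActor.remote()

try:
    # A tight 2 s budget fails as soon as scheduling or execution is a bit slow.
    ray.get(actor.work.remote(3.0), timeout=2)
except GetTimeoutError:
    print("timed out - this is the flakiness the commit removes")

# A roomier budget only triggers when something is genuinely stuck.
print(ray.get(actor.work.remote(3.0), timeout=10))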

test/test_collector.py

Lines changed: 14 additions & 6 deletions

@@ -13,6 +13,7 @@
 import subprocess
 import sys
 import time
+from contextlib import nullcontext
 from unittest.mock import patch

 import numpy as np

@@ -1487,12 +1488,14 @@ def env_fn(seed):
         assert_allclose_td(data10, data20)

     @pytest.mark.parametrize("use_async", [False, True])
-    @pytest.mark.parametrize("cudagraph", [False, True])
+    @pytest.mark.parametrize(
+        "cudagraph", [False, True] if torch.cuda.is_available() else [False]
+    )
     @pytest.mark.parametrize(
         "weight_sync_scheme",
         [None, MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme],
     )
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device found")
+    # @pytest.mark.skipif(not torch.cuda.is_available() and not torch.mps.is_available(), reason="no cuda/mps device found")
     def test_update_weights(self, use_async, cudagraph, weight_sync_scheme):
         def create_env():
             return ContinuousActionVecMockEnv()

@@ -1509,11 +1512,12 @@ def create_env():
         kwargs = {}
         if weight_sync_scheme is not None:
             kwargs["weight_sync_schemes"] = {"policy": weight_sync_scheme()}
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
         collector = collector_class(
             [create_env] * 3,
             policy=policy,
-            device=[torch.device("cuda:0")] * 3,
-            storing_device=[torch.device("cuda:0")] * 3,
+            device=[torch.device(device)] * 3,
+            storing_device=[torch.device(device)] * 3,
             frames_per_batch=20,
             cat_results="stack",
             cudagraph_policy=cudagraph,

@@ -1544,7 +1548,9 @@ def create_env():
         # check they don't match
         for worker in range(3):
             for k in state_dict[f"worker{worker}"]["policy_state_dict"]:
-                with pytest.raises(AssertionError):
+                with pytest.raises(
+                    AssertionError
+                ) if torch.cuda.is_available() else nullcontext():
                     torch.testing.assert_close(
                         state_dict[f"worker{worker}"]["policy_state_dict"][k],
                         policy_state_dict[k].cpu(),

@@ -2401,7 +2407,9 @@ def test_auto_wrap_error(self, collector_class, env_maker, num_envs):
         policy = UnwrappablePolicy(out_features=env_maker().action_spec.shape[-1])
         with pytest.raises(
             TypeError,
-            match=("Arguments to policy.forward are incompatible with entries in"),
+            match=(
+                "Arguments to policy.forward are incompatible with entries in|Failed to wrap the policy. If the policy needs to be trusted, set trust_policy=True."
+            ),
         ):
             collector_class(
                 **self._create_collector_kwargs(
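
The pattern used above, parametrizing the CUDA-only case only when a GPU is present and swapping pytest.raises for contextlib.nullcontext when no mismatch is expected, is what lets the same test run on CPU-only machines. A standalone sketch of that pattern, with a hypothetical test that is not part of the suite:

from contextlib import nullcontext

import pytest
import torch


@pytest.mark.parametrize(
    # Only exercise the CUDA case when a GPU is present; otherwise keep CPU only.
    "device", ["cuda:0", "cpu"] if torch.cuda.is_available() else ["cpu"]
)
def test_weights_should_differ(device):
    reference = torch.zeros(3)
    # On CUDA the "updated" weights are shifted, so a mismatch is expected;
    # on CPU they stay identical and no assertion error should be raised.
    updated = torch.zeros(3, device=device) + (1.0 if device != "cpu" else 0.0)

    ctx = pytest.raises(AssertionError) if device != "cpu" else nullcontext()
    with ctx:
        torch.testing.assert_close(reference, updated.cpu())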

torchrl/collectors/collectors.py

Lines changed: 19 additions & 15 deletions

@@ -135,18 +135,19 @@ class _InterruptorManager(SyncManager):
 _InterruptorManager.register("_Interruptor", _Interruptor)


-def recursive_map_to_cpu(dictionary: OrderedDict) -> OrderedDict:
-    """Maps the tensors to CPU through a nested dictionary."""
-    return OrderedDict(
-        **{
-            k: recursive_map_to_cpu(item)
-            if isinstance(item, OrderedDict)
-            else item.cpu()
-            if isinstance(item, torch.Tensor)
-            else item
-            for k, item in dictionary.items()
-        }
-    )
+def _map_to_cpu_if_needed(x):
+    """Map tensors on exotic devices (MPS, NPU, etc.) to CPU.
+
+    CPU and CUDA tensors are kept as-is since they can be shared across processes.
+    Only exotic devices that don't support multiprocessing are mapped to CPU.
+    """
+    if isinstance(x, torch.Tensor):
+        # CPU and CUDA can be shared across processes
+        if x.device.type in ("cpu", "cuda"):
+            return x
+        # Exotic devices (MPS, NPU, etc.) need to be mapped to CPU
+        return x.cpu()
+    return x


 class DataCollectorBase(IterableDataset, metaclass=abc.ABCMeta):

@@ -1149,7 +1150,7 @@ def _setup_policy_and_weights(self, policy: TensorDictModule | Callable) -> None
                 )
             except (TypeError, AttributeError, ValueError) as err:
                 raise TypeError(
-                    "Failed to wrap the policy. If the policy needs to be trusted, set trust_policy=True."
+                    "Failed to wrap the policy. If the policy needs to be trusted, set trust_policy=True. Scroll up for more details."
                 ) from err
             self._wrapped_policy = wrapped_policy
         else:

@@ -4880,9 +4881,12 @@ def cast_tensor(x, MPS_ERROR=MPS_ERROR):
                 continue

             elif msg == "state_dict":
+                from torch.utils._pytree import tree_map
+
                 state_dict = inner_collector.state_dict()
-                # send state_dict to cpu first
-                state_dict = recursive_map_to_cpu(state_dict)
+                # Map exotic devices (MPS, NPU, etc.) to CPU for multiprocessing compatibility
+                # CPU and CUDA tensors are already shareable and don't need conversion
+                state_dict = tree_map(_map_to_cpu_if_needed, state_dict)
                 pipe_child.send((state_dict, "state_dict"))
                 has_timed_out = False
                 continue
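
torch.utils._pytree.tree_map applies a function to every leaf of a nested container, which is why it can replace the hand-rolled recursion over OrderedDicts. A small sketch of how the new helper composes with it, assuming a toy state dict:

from collections import OrderedDict

import torch
from torch.utils._pytree import tree_map


def _map_to_cpu_if_needed(x):
    # Same rule as the helper above: only move tensors whose device cannot be
    # shared across processes (MPS, NPU, ...); keep CPU/CUDA tensors untouched.
    if isinstance(x, torch.Tensor):
        if x.device.type in ("cpu", "cuda"):
            return x
        return x.cpu()
    return x


state_dict = OrderedDict(
    layer=OrderedDict(weight=torch.randn(2, 2), bias=torch.zeros(2)),
    step=3,  # non-tensor leaves are returned unchanged
)

# tree_map walks every leaf of the nested structure, so no manual recursion is needed.
cpu_state_dict = tree_map(_map_to_cpu_if_needed, state_dict)
print(type(cpu_state_dict), cpu_state_dict["layer"]["weight"].device)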

torchrl/envs/batched_envs.py

Lines changed: 6 additions & 2 deletions

@@ -2701,7 +2701,6 @@ def _run_worker_pipe_direct(
         if event is not None:
             event.record()
             event.synchronize()
-        mp_event.set()
         if consolidate:
             try:
                 child_pipe.send(

@@ -2713,6 +2712,9 @@ def _run_worker_pipe_direct(
                 raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err
         else:
             child_pipe.send(cur_td)
+        # Set event after successfully sending through pipe to avoid race condition
+        # where event is set but pipe send fails (BrokenPipeError)
+        mp_event.set()

         del cur_td

@@ -2726,7 +2728,6 @@ def _run_worker_pipe_direct(
         if event is not None:
             event.record()
             event.synchronize()
-        mp_event.set()
         if consolidate:
             try:
                 next_td = next_td.consolidate(

@@ -2735,6 +2736,9 @@ def _run_worker_pipe_direct(
             except Exception as err:
                 raise RuntimeError(_CONSOLIDATE_ERR_CAPTURE) from err
         child_pipe.send(next_td)
+        # Set event after successfully sending through pipe to avoid race condition
+        # where event is set but pipe send fails (BrokenPipeError)
+        mp_event.set()

         del next_td

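
The reordering matters because the parent side typically waits on mp_event and then reads from the pipe; if the worker sets the event first and the following send raises BrokenPipeError, the parent wakes up with nothing to receive. A stripped-down sketch of the corrected ordering, with an illustrative worker that is not the actual TorchRL loop:

import multiprocessing as mp


def worker(child_pipe, mp_event):
    payload = {"step": 1, "reward": 0.5}
    try:
        # Send first; only signal readiness once the data really is in the pipe.
        child_pipe.send(payload)
        mp_event.set()
    except BrokenPipeError:
        # The parent side is gone: leave the event unset instead of signalling
        # a result that will never arrive.
        pass


if __name__ == "__main__":
    parent_end, child_end = mp.Pipe()
    event = mp.Event()
    proc = mp.Process(target=worker, args=(child_end, event))
    proc.start()

    # The parent waits on the event and then reads; with the corrected ordering
    # a set event guarantees there is something to receive.
    if event.wait(timeout=5):
        print(parent_end.recv())
    proc.join()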

torchrl/envs/llm/transforms/tools.py

Lines changed: 12 additions & 8 deletions

@@ -906,9 +906,9 @@ def execute(self, prompt: str) -> dict[str, Any]:
                 except queue.Empty:
                     pass

-                if not start_found:
-                    timeout_val -= 0.1
-                    time.sleep(0.1)
+                # Always sleep a bit to avoid busy-waiting and give subprocess time
+                timeout_val -= 0.01
+                time.sleep(0.01)

         except Exception as e:
             return {

@@ -1007,8 +1007,10 @@ def __init__(self, pool_size: int = 32, timeout: float = 10.0):
         self.processes = [
             PersistentPythonProcess(timeout=timeout) for _ in range(pool_size)
         ]
+        # Create a lock for each process to prevent concurrent access
+        self.process_locks = [threading.Lock() for _ in range(pool_size)]
         self.next_idx = 0
-        self._lock = threading.Lock()
+        self._selection_lock = threading.Lock()

     def execute(self, code: str) -> dict:
         """Execute Python code using next available process (round-robin).

@@ -1019,12 +1021,14 @@ def execute(self, code: str) -> dict:
         Returns:
             dict: Execution result with keys 'success', 'stdout', 'stderr', 'returncode'.
         """
-        # Simple round-robin - Ray handles the queuing via max_concurrency
-        with self._lock:
-            process = self.processes[self.next_idx]
+        # Select a process using round-robin
+        with self._selection_lock:
+            process_idx = self.next_idx
             self.next_idx = (self.next_idx + 1) % self.pool_size

-        return process.execute(code)
+        # Lock the selected process for the duration of execution
+        with self.process_locks[process_idx]:
+            return self.processes[process_idx].execute(code)

     def cleanup(self):
         """Cleanup all processes in the pool."""

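The pool change splits one coarse lock into two levels: a short-lived selection lock that only guards the round-robin index, and a per-process lock held for the whole execution, so two callers never share the same interpreter while different processes still run concurrently. A generic sketch of that locking scheme, with a hypothetical WorkerPool standing in for the persistent Python processes:

import threading


class WorkerPool:
    """Hypothetical pool illustrating the selection-lock / per-worker-lock split."""

    def __init__(self, pool_size: int = 4):
        self.pool_size = pool_size
        # Stand-ins for the persistent Python subprocesses in the real class.
        self.workers = [f"worker-{i}" for i in range(pool_size)]
        self.worker_locks = [threading.Lock() for _ in range(pool_size)]
        self._selection_lock = threading.Lock()
        self.next_idx = 0

    def run(self, job: str) -> str:
        # Hold the selection lock only long enough to pick an index (round-robin).
        with self._selection_lock:
            idx = self.next_idx
            self.next_idx = (self.next_idx + 1) % self.pool_size
        # Hold the chosen worker's lock for the whole job, so the same worker is
        # never used by two callers at once while other workers remain available.
        with self.worker_locks[idx]:
            return f"{self.workers[idx]} handled {job}"


pool = WorkerPool(pool_size=2)
threads = [
    threading.Thread(target=lambda i=i: print(pool.run(f"job-{i}"))) for i in range(4)
]
for t in threads:
    t.start()
for t in threads:
    t.join()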