Skip to content

Commit 2e73a09

Browse files
committed
[Test] Fix flaky tests
1 parent 3d5dd1a commit 2e73a09

File tree

6 files changed

+353
-148
lines changed

6 files changed

+353
-148
lines changed

test/services/test_python_executor_service.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def test_service_execution(self, ray_init):
7373
result = x + y
7474
print(f"Result: {result}")
7575
"""
76-
result = ray.get(executor.execute.remote(code), timeout=2)
76+
result = ray.get(executor.execute.remote(code), timeout=10)
7777

7878
assert result["success"] is True
7979
assert "Result: 30" in result["stdout"]
@@ -101,7 +101,7 @@ def test_service_execution_error(self, ray_init):
101101

102102
# Execute code with an error
103103
code = "raise ValueError('Test error')"
104-
result = ray.get(executor.execute.remote(code), timeout=2)
104+
result = ray.get(executor.execute.remote(code), timeout=10)
105105

106106
assert result["success"] is False
107107
assert "ValueError: Test error" in result["stderr"]
@@ -119,7 +119,7 @@ def test_multiple_executions(self, ray_init):
119119
"python_executor",
120120
PythonExecutorService,
121121
pool_size=4,
122-
timeout=5.0,
122+
timeout=10.0,
123123
num_cpus=4,
124124
max_concurrency=4,
125125
)
@@ -132,14 +132,16 @@ def test_multiple_executions(self, ray_init):
132132
code = f"print('Execution {i}')"
133133
futures.append(executor.execute.remote(code))
134134

135-
# Wait for all to complete
136-
results = ray.get(futures, timeout=5)
135+
# Wait for all to complete with longer timeout
136+
results = ray.get(futures, timeout=30)
137137

138138
# All should succeed
139139
assert len(results) == 8
140140
for i, result in enumerate(results):
141-
assert result["success"] is True
142-
assert f"Execution {i}" in result["stdout"]
141+
assert result["success"] is True, f"Execution {i} failed: {result}"
142+
assert (
143+
f"Execution {i}" in result["stdout"]
144+
), f"Expected 'Execution {i}' in stdout, got: {result['stdout']!r}"
143145

144146
finally:
145147
services.reset()

test/test_collector.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import subprocess
1414
import sys
1515
import time
16+
from contextlib import nullcontext
1617
from unittest.mock import patch
1718

1819
import numpy as np
@@ -1487,12 +1488,14 @@ def env_fn(seed):
14871488
assert_allclose_td(data10, data20)
14881489

14891490
@pytest.mark.parametrize("use_async", [False, True])
1490-
@pytest.mark.parametrize("cudagraph", [False, True])
1491+
@pytest.mark.parametrize(
1492+
"cudagraph", [False, True] if torch.cuda.is_available() else [False]
1493+
)
14911494
@pytest.mark.parametrize(
14921495
"weight_sync_scheme",
14931496
[None, MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme],
14941497
)
1495-
@pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device found")
1498+
# @pytest.mark.skipif(not torch.cuda.is_available() and not torch.mps.is_available(), reason="no cuda/mps device found")
14961499
def test_update_weights(self, use_async, cudagraph, weight_sync_scheme):
14971500
def create_env():
14981501
return ContinuousActionVecMockEnv()
@@ -1509,11 +1512,12 @@ def create_env():
15091512
kwargs = {}
15101513
if weight_sync_scheme is not None:
15111514
kwargs["weight_sync_schemes"] = {"policy": weight_sync_scheme()}
1515+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
15121516
collector = collector_class(
15131517
[create_env] * 3,
15141518
policy=policy,
1515-
device=[torch.device("cuda:0")] * 3,
1516-
storing_device=[torch.device("cuda:0")] * 3,
1519+
device=[torch.device(device)] * 3,
1520+
storing_device=[torch.device(device)] * 3,
15171521
frames_per_batch=20,
15181522
cat_results="stack",
15191523
cudagraph_policy=cudagraph,
@@ -1544,7 +1548,9 @@ def create_env():
15441548
# check they don't match
15451549
for worker in range(3):
15461550
for k in state_dict[f"worker{worker}"]["policy_state_dict"]:
1547-
with pytest.raises(AssertionError):
1551+
with pytest.raises(
1552+
AssertionError
1553+
) if torch.cuda.is_available() else nullcontext():
15481554
torch.testing.assert_close(
15491555
state_dict[f"worker{worker}"]["policy_state_dict"][k],
15501556
policy_state_dict[k].cpu(),
@@ -2401,7 +2407,9 @@ def test_auto_wrap_error(self, collector_class, env_maker, num_envs):
24012407
policy = UnwrappablePolicy(out_features=env_maker().action_spec.shape[-1])
24022408
with pytest.raises(
24032409
TypeError,
2404-
match=("Arguments to policy.forward are incompatible with entries in"),
2410+
match=(
2411+
"Arguments to policy.forward are incompatible with entries in|Failed to wrap the policy. If the policy needs to be trusted, set trust_policy=True."
2412+
),
24052413
):
24062414
collector_class(
24072415
**self._create_collector_kwargs(

0 commit comments

Comments
 (0)