1313import subprocess
1414import sys
1515import time
16+ from contextlib import nullcontext
1617from unittest .mock import patch
1718
1819import numpy as np
@@ -1487,12 +1488,14 @@ def env_fn(seed):
14871488 assert_allclose_td (data10 , data20 )
14881489
14891490 @pytest .mark .parametrize ("use_async" , [False , True ])
1490- @pytest .mark .parametrize ("cudagraph" , [False , True ])
1491+ @pytest .mark .parametrize (
1492+ "cudagraph" , [False , True ] if torch .cuda .is_available () else [False ]
1493+ )
14911494 @pytest .mark .parametrize (
14921495 "weight_sync_scheme" ,
14931496 [None , MultiProcessWeightSyncScheme , SharedMemWeightSyncScheme ],
14941497 )
1495- @pytest .mark .skipif (not torch .cuda .is_available (), reason = "no cuda device found" )
1498+ # @pytest.mark.skipif(not torch.cuda.is_available() and not torch.mps.is_available() , reason="no cuda/mps device found")
14961499 def test_update_weights (self , use_async , cudagraph , weight_sync_scheme ):
14971500 def create_env ():
14981501 return ContinuousActionVecMockEnv ()
@@ -1509,11 +1512,12 @@ def create_env():
15091512 kwargs = {}
15101513 if weight_sync_scheme is not None :
15111514 kwargs ["weight_sync_schemes" ] = {"policy" : weight_sync_scheme ()}
1515+ device = "cuda:0" if torch .cuda .is_available () else "cpu"
15121516 collector = collector_class (
15131517 [create_env ] * 3 ,
15141518 policy = policy ,
1515- device = [torch .device ("cuda:0" )] * 3 ,
1516- storing_device = [torch .device ("cuda:0" )] * 3 ,
1519+ device = [torch .device (device )] * 3 ,
1520+ storing_device = [torch .device (device )] * 3 ,
15171521 frames_per_batch = 20 ,
15181522 cat_results = "stack" ,
15191523 cudagraph_policy = cudagraph ,
@@ -1544,7 +1548,9 @@ def create_env():
15441548 # check they don't match
15451549 for worker in range (3 ):
15461550 for k in state_dict [f"worker{ worker } " ]["policy_state_dict" ]:
1547- with pytest .raises (AssertionError ):
1551+ with pytest .raises (
1552+ AssertionError
1553+ ) if torch .cuda .is_available () else nullcontext ():
15481554 torch .testing .assert_close (
15491555 state_dict [f"worker{ worker } " ]["policy_state_dict" ][k ],
15501556 policy_state_dict [k ].cpu (),
@@ -2401,7 +2407,9 @@ def test_auto_wrap_error(self, collector_class, env_maker, num_envs):
24012407 policy = UnwrappablePolicy (out_features = env_maker ().action_spec .shape [- 1 ])
24022408 with pytest .raises (
24032409 TypeError ,
2404- match = ("Arguments to policy.forward are incompatible with entries in" ),
2410+ match = (
2411+ "Arguments to policy.forward are incompatible with entries in|Failed to wrap the policy. If the policy needs to be trusted, set trust_policy=True."
2412+ ),
24052413 ):
24062414 collector_class (
24072415 ** self ._create_collector_kwargs (
0 commit comments