Skip to content

Commit 24c9d74

Browse files
jeguzzi and Vincent Moens committed
[Environment] Complete PettingZooWrapper state support (#2953)
Co-authored-by: Vincent Moens <[email protected]>
(cherry picked from commit d882ea2)
1 parent 55f6074 commit 24c9d74

File tree

2 files changed: 23 additions (+23) and 0 deletions (-0).

test/test_libs.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3944,6 +3944,20 @@ def __call__(self, td):
39443944
td[-1]["next", "player", "reward"] == torch.tensor([[-1], [1]])
39453945
).all()
39463946

3947+
@pytest.mark.parametrize("task", ["simple_v3"])
3948+
def test_return_state(self, task):
3949+
env = PettingZooEnv(
3950+
task=task,
3951+
parallel=True,
3952+
seed=0,
3953+
use_mask=False,
3954+
return_state=True,
3955+
)
3956+
check_env_specs(env)
3957+
r = env.rollout(10)
3958+
assert (r["state"] != 0).any()
3959+
assert (r["next", "state"] != 0).any()
3960+
39473961
@pytest.mark.parametrize(
39483962
"task",
39493963
[

torchrl/envs/libs/pettingzoo.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,10 @@ def _reset(
584584
value, device=self.device
585585
)
586586

587+
if self.return_state:
588+
state = torch.as_tensor(self.state(), device=self.device)
589+
tensordict_out.set("state", state)
590+
587591
return tensordict_out
588592

589593
def _reset_aec(self, **kwargs) -> tuple[dict, dict]:
@@ -702,6 +706,11 @@ def _step(
702706
tensordict_out.set("done", done)
703707
tensordict_out.set("terminated", terminated)
704708
tensordict_out.set("truncated", truncated)
709+
710+
if self.return_state:
711+
state = torch.as_tensor(self.state(), device=self.device)
712+
tensordict_out.set("state", state)
713+
705714
return tensordict_out
706715

707716
def _aggregate_done(self, tensordict_out, use_any):

0 commit comments

Comments
 (0)