
Commit 8c1271f

[Feature] batch_size, reward, done, attention_key in LLMEnv
ghstack-source-id: 19c1ea1
Pull Request resolved: #2824
1 parent 1bfb9e8 commit 8c1271f

File tree

11 files changed: +243 -37 lines
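In short, `LLMEnv.from_dataloader` can now build environments that expose reward and done entries (along with a batch size and an attention key). A minimal usage sketch in the spirit of the new test added to test/test_env.py below; the placeholder `dataloader`, the dummy policy, and the import paths are illustrative assumptions, not part of this commit:

import torch
from tensordict import TensorDict
from torchrl.data import Unbounded
from torchrl.envs import LLMEnv, StepCounter  # import paths assumed, may differ by version

# `dataloader` stands in for any DataLoader-like iterable yielding token-id batches
# under the "observation" key (the test below uses a DummyTensorDataLoader helper).
env = LLMEnv.from_dataloader(
    dataloader=dataloader,
    data_keys=["observation"],
    data_specs=[Unbounded(shape=(-1,), dtype=torch.int64)],
    repeats=3,            # each dataset item is served 3 times
    assign_reward=True,   # new: the env registers a reward entry/spec
    assign_done=True,     # new: the env registers done/terminated entries
    str2str=False,        # per the test, reward/done assignment requires tokenized (non-string) data
    device=None,
)
env.append_transform(StepCounter(max_steps=10))

def policy(td):
    # dummy policy: emit an integer "action" of fixed length
    td["action"] = torch.ones(td.shape + (5,), dtype=torch.int64)
    return td

r = env.rollout(
    100,
    policy,
    tensordict=TensorDict(batch_size=[3]),
    break_when_any_done=False,
)
assert "done" in r and "terminated" in r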

test/test_env.py

+76
@@ -4798,6 +4798,82 @@ def policy(td):
             r_reset[0, 0]["observation"] != r_reset[0, 1]["observation"]
         ).any()
 
+    @pytest.mark.parametrize(
+        "str2str,stack_method",
+        [
+            [True, None],
+            [False, "as_padded_tensor"],
+        ],
+    )
+    @pytest.mark.parametrize("batched", [True])
+    @pytest.mark.parametrize("device", [None])
+    @pytest.mark.parametrize("batch_size", [4])
+    @pytest.mark.parametrize("repeats", [3])
+    @pytest.mark.parametrize(
+        "assign_reward,assign_done", [[True, False], [True, True], [False, True]]
+    )
+    def test_done_and_reward(
+        self,
+        str2str,
+        batched,
+        stack_method,
+        device,
+        batch_size,
+        repeats,
+        assign_reward,
+        assign_done,
+    ):
+        with pytest.raises(
+            ValueError, match="str2str"
+        ) if str2str else contextlib.nullcontext():
+            if str2str:
+                kwargs = {
+                    "dataloader": self.DummyDataLoader(batch_size=batch_size),
+                    "data_keys": ["observation"],
+                    "example_data": "a string!",
+                    "repeats": repeats,
+                    "assign_reward": assign_reward,
+                    "assign_done": assign_done,
+                }
+            else:
+                if stack_method is None:
+                    stack_method = as_padded_tensor
+                kwargs = {
+                    "dataloader": self.DummyTensorDataLoader(
+                        padding=True, batch_size=batch_size
+                    ),
+                    "data_keys": ["observation"],
+                    "data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
+                    "stack_method": stack_method,
+                    "repeats": repeats,
+                    "assign_reward": assign_reward,
+                    "assign_done": assign_done,
+                }
+            kwargs.update({"str2str": str2str, "device": device})
+            env = LLMEnv.from_dataloader(**kwargs)
+            # We want to make sure that transforms that rely on the done state work appropriately
+            env.append_transform(StepCounter(max_steps=10))
+
+            def policy(td):
+                td["action"] = torch.ones(
+                    td.shape + (torch.randint(10, (1,)).item(),), dtype=torch.int64
+                )
+                return td
+
+            if batched:
+                r = env.rollout(
+                    100,
+                    policy,
+                    tensordict=TensorDict(batch_size=[3]),
+                    break_when_any_done=False,
+                )
+            else:
+                r = env.rollout(100, policy, break_when_any_done=False)
+            if assign_done:
+                assert "terminated" in r
+                assert "done" in r
+            print(r)
+
 
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()

torchrl/data/map/tdstorage.py

-1
@@ -128,7 +128,6 @@ def __init__(
         self.in_keys = query_module.in_keys
         if out_keys is not None:
             self.out_keys = out_keys
-            assert not self._has_lazy_out_keys()
 
         self.query_module = query_module
         self.index_key = query_module.index_key

torchrl/data/postprocs/postprocs.py

+3 -2

@@ -11,7 +11,6 @@
 from tensordict.utils import expand_right
 from torch import nn
 
-from torchrl.objectives.value.functional import reward2go
 
 
 def _get_reward(
@@ -367,13 +366,15 @@ def __init__(
         time_dim: int = 2,
         discount: float = 1.0,
     ):
+        from torchrl.objectives.value.functional import reward2go
         super().__init__()
         self.in_keys = [unravel_key(reward_key), unravel_key(done_key)]
         if reward_key_out is None:
             reward_key_out = reward_key
         self.out_keys = [unravel_key(reward_key_out)]
         self.time_dim = time_dim
         self.discount = discount
+        self.reward2go = reward2go
 
     def forward(self, tensordict):
         # Get done
@@ -385,6 +386,6 @@ def forward(self, tensordict):
                 f"reward and done state are expected to have the same shape. Got reward.shape={reward.shape} "
                 f"and done.shape={done.shape}."
             )
-        reward = reward2go(reward, done, time_dim=-2, gamma=self.discount)
+        reward = self.reward2go(reward, done, time_dim=-2, gamma=self.discount)
         tensordict.set(("next", self.out_keys[0]), reward)
         return tensordict
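The `reward2go` import now happens inside `Reward2GoTransform.__init__` and the function is bound as `self.reward2go`, presumably to avoid importing `torchrl.objectives` at module load time. For reference, a naive sketch of the semantics the function implements (not the torchrl implementation): the discounted return-to-go along the time dimension, restarted wherever done is set.

import torch

def reward_to_go_reference(reward: torch.Tensor, done: torch.Tensor, gamma: float) -> torch.Tensor:
    # reward: [T] float tensor, done: [T] bool tensor marking episode ends
    out = torch.zeros_like(reward)
    running = torch.zeros(())
    for t in reversed(range(reward.shape[0])):
        # future returns are zeroed out at episode boundaries
        running = reward[t] + gamma * running * (~done[t]).float()
        out[t] = running
    return out

# reward_to_go_reference(torch.tensor([1.0, 1.0, 1.0]), torch.tensor([False, False, True]), 0.9)
# -> tensor([2.71, 1.90, 1.00])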

torchrl/envs/common.py

+6 -2

@@ -2788,7 +2788,11 @@ def _reset_check_done(self, tensordict, tensordict_reset):
            if reset_value is not None:
                for done_key in done_key_group:
                    done_val = tensordict_reset.get(done_key)
-                    if done_val[reset_value].any() and not self._allow_done_after_reset:
+                    if (
+                        done_val.any()
+                        and done_val[reset_value].any()
+                        and not self._allow_done_after_reset
+                    ):
                        raise RuntimeError(
                            f"Env done entry '{done_key}' was (partially) True after reset on specified '_reset' dimensions. This is not allowed."
                        )
@@ -3588,7 +3592,7 @@ def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase:
         """
         any_done = self.any_done(tensordict)
         if any_done:
-            return self.reset(tensordict, select_reset_only=True)
+            tensordict = self.reset(tensordict, select_reset_only=True)
         return tensordict
 
     def empty_cache(self):
