
Commit 215ed76

[Feature] transformers policy
ghstack-source-id: b333d2e
Pull Request resolved: #2825
1 parent 8c1271f commit 215ed76

File tree

11 files changed: +530 −150 lines changed


test/test_env.py

+113 −46
@@ -4581,11 +4581,13 @@ def __next__(self):
     @pytest.mark.parametrize("batch_size", [0, 4])
     @pytest.mark.parametrize("device", [None, "cpu"])
     def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
-        env = LLMEnv(str2str=str2str, device=device)
+        env = LLMEnv(
+            str2str=str2str, device=device, has_attention=False, no_stack=False
+        )
         if str2str:
             primer = DataLoadingPrimer(
                 dataloader=self.DummyDataLoader(batch_size=batch_size),
-                data_keys=["observation"],
+                data_keys=[LLMEnv._DEFAULT_STR_KEY],
                 example_data="a string!",
             )
         else:
@@ -4595,7 +4597,7 @@ def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
                 dataloader=self.DummyTensorDataLoader(
                     batch_size=batch_size, padding=True
                 ),
-                data_keys=["observation"],
+                data_keys=[LLMEnv._DEFAULT_TOKEN_KEY],
                 data_specs=[Unbounded(shape=(-1,), dtype=torch.int64)],
                 stack_method=stack_method,
             )
@@ -4605,7 +4607,7 @@ def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
         if batched:
             td = env.reset(TensorDict(batch_size=[3]))
             env.check_env_specs(break_when_any_done="both", tensordict=td)
-            r = env.rollout(10, tensordict=TensorDict(batch_size=[3]))
+            env.rollout(10, tensordict=TensorDict(batch_size=[3]))
         else:
             env.check_env_specs(break_when_any_done="both")

@@ -4628,7 +4630,7 @@ def test_llm_from_dataloader(
         if str2str:
             kwargs = {
                 "dataloader": self.DummyDataLoader(batch_size=batch_size),
-                "data_keys": ["observation"],
+                "data_keys": [LLMEnv._DEFAULT_STR_KEY],
                 "example_data": "a string!",
             }
         else:
@@ -4638,11 +4640,18 @@ def test_llm_from_dataloader(
                 "dataloader": self.DummyTensorDataLoader(
                     padding=True, batch_size=batch_size
                 ),
-                "data_keys": ["observation"],
+                "data_keys": [LLMEnv._DEFAULT_TOKEN_KEY],
                 "data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
                 "stack_method": stack_method,
             }
-        kwargs.update({"str2str": str2str, "device": device})
+        kwargs.update(
+            {
+                "str2str": str2str,
+                "device": device,
+                "has_attention": False,
+                "no_stack": False,
+            }
+        )
         env = LLMEnv.from_dataloader(**kwargs)
         assert not env.batch_locked
         if batched:
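The recurring edit in this file is a single refactor: hard-coded key strings ("observation", "action") are replaced by class-level constants on LLMEnv (_DEFAULT_STR_KEY, _DEFAULT_TOKEN_KEY, _DEFAULT_ACTION_KEY), and the new has_attention / no_stack flags are passed explicitly. A minimal, self-contained sketch of the constant-key pattern follows; the class and its key values are illustrative stand-ins, not the torchrl LLMEnv:

import torch
from tensordict import TensorDict

class MyEnv:
    # Centralised key names, mirroring the LLMEnv._DEFAULT_*_KEY constants
    # referenced in the diff above; these string values are illustrative only.
    _DEFAULT_TOKEN_KEY = "tokens"
    _DEFAULT_ACTION_KEY = "action"

td = TensorDict(batch_size=[3])
# Callers go through the constants, so renaming a key becomes a one-line
# change on the class instead of a sweep over every test and policy.
td[MyEnv._DEFAULT_ACTION_KEY] = torch.ones(td.shape + (1,), dtype=torch.int64)
assert td[MyEnv._DEFAULT_ACTION_KEY].shape == (3, 1)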
@@ -4655,46 +4664,64 @@ def test_llm_from_dataloader(
         def policy(td):
             if str2str:
                 if not td.shape:
-                    td["action"] = "<nothing>"
+                    td[LLMEnv._DEFAULT_ACTION_KEY] = "<nothing>"
                 else:
-                    td["action"] = NonTensorStack(
+                    td[LLMEnv._DEFAULT_ACTION_KEY] = NonTensorStack(
                         *["<nothing>" for _ in range(td.shape[0])]
                     )
             else:
-                td["action"] = torch.ones(td.shape + (1,), dtype=torch.int64)
+                td[LLMEnv._DEFAULT_ACTION_KEY] = torch.ones(
+                    td.shape + (1,), dtype=torch.int64
+                )
             return td

         if batched:
             # Tell the env that we want 3 sub-envs
             r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[3]))
             assert r.ndim == 2
             if str2str:
-                assert isinstance(r[0, 0]["observation"], str)
-                assert isinstance(r[0, 1]["observation"], str)
+                assert isinstance(r[0, 0][LLMEnv._DEFAULT_STR_KEY], str)
+                assert isinstance(r[0, 1][LLMEnv._DEFAULT_STR_KEY], str)
                 assert (
-                    r[0, 0]["observation"]
-                    == r[0, 1]["observation"][: -len(r[0, 0]["action"])]
+                    r[0, 0][LLMEnv._DEFAULT_STR_KEY]
+                    == r[0, 1][LLMEnv._DEFAULT_STR_KEY][
+                        : -len(r[0, 0][LLMEnv._DEFAULT_ACTION_KEY])
+                    ]
                 )
                 assert (
-                    r[0, 1]["observation"]
-                    == r[0, 2]["observation"][: -len(r[0, 1]["action"])]
+                    r[0, 1][LLMEnv._DEFAULT_STR_KEY]
+                    == r[0, 2][LLMEnv._DEFAULT_STR_KEY][
+                        : -len(r[0, 1][LLMEnv._DEFAULT_ACTION_KEY])
+                    ]
                 )
                 assert (
-                    r[-1, 0]["observation"]
-                    == r[-1, 1]["observation"][: -len(r[-1, 0]["action"])]
+                    r[-1, 0][LLMEnv._DEFAULT_STR_KEY]
+                    == r[-1, 1][LLMEnv._DEFAULT_STR_KEY][
+                        : -len(r[-1, 0][LLMEnv._DEFAULT_ACTION_KEY])
+                    ]
                 )
                 assert (
-                    r[-1, 1]["observation"]
-                    == r[-1, 2]["observation"][: -len(r[-1, 1]["action"])]
+                    r[-1, 1][LLMEnv._DEFAULT_STR_KEY]
+                    == r[-1, 2][LLMEnv._DEFAULT_STR_KEY][
+                        : -len(r[-1, 1][LLMEnv._DEFAULT_ACTION_KEY])
+                    ]
                 )
             else:
-                assert (r[0, 0]["observation"] == r[0, 1]["observation"][:-1]).all()
-                assert (r[0, 1]["observation"] == r[0, 2]["observation"][:-1]).all()
                 assert (
-                    r[-1, 0]["observation"] == r[-1, 1]["observation"][:-1]
+                    r[0, 0][LLMEnv._DEFAULT_TOKEN_KEY]
+                    == r[0, 1][LLMEnv._DEFAULT_TOKEN_KEY][:-1]
+                ).all()
+                assert (
+                    r[0, 1][LLMEnv._DEFAULT_TOKEN_KEY]
+                    == r[0, 2][LLMEnv._DEFAULT_TOKEN_KEY][:-1]
                 ).all()
                 assert (
-                    r[-1, 1]["observation"] == r[-1, 2]["observation"][:-1]
+                    r[-1, 0][LLMEnv._DEFAULT_TOKEN_KEY]
+                    == r[-1, 1][LLMEnv._DEFAULT_TOKEN_KEY][:-1]
+                ).all()
+                assert (
+                    r[-1, 1][LLMEnv._DEFAULT_TOKEN_KEY]
+                    == r[-1, 2][LLMEnv._DEFAULT_TOKEN_KEY][:-1]
                 ).all()
         else:
             r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[]))
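The assertions in this hunk spell out the autoregressive contract the rollout is expected to satisfy: in string mode, each step's observation equals the previous observation with the previous action appended; in token mode, the policy emits a single token per step, so dropping the last token of an observation recovers the previous one. A minimal sketch of that invariant with plain Python objects, independent of TorchRL:

import torch

# String mode: next observation = previous observation + action text.
obs_t, action = "a string!", "<nothing>"
obs_tp1 = obs_t + action
assert obs_t == obs_tp1[: -len(action)]

# Token mode: one token is appended per step, so obs_t == obs_tp1[:-1].
obs_t = torch.tensor([3, 5, 7], dtype=torch.int64)
obs_tp1 = torch.cat([obs_t, torch.ones(1, dtype=torch.int64)])
assert (obs_t == obs_tp1[:-1]).all()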
@@ -4720,7 +4747,7 @@ def test_llm_from_dataloader_repeats(
         if str2str:
             kwargs = {
                 "dataloader": self.DummyDataLoader(batch_size=batch_size),
-                "data_keys": ["observation"],
+                "data_keys": [LLMEnv._DEFAULT_STR_KEY],
                 "example_data": "a string!",
                 "repeats": repeats,
             }
@@ -4731,12 +4758,19 @@ def test_llm_from_dataloader_repeats(
                 "dataloader": self.DummyTensorDataLoader(
                     padding=True, batch_size=batch_size
                 ),
-                "data_keys": ["observation"],
+                "data_keys": [LLMEnv._DEFAULT_TOKEN_KEY],
                 "data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
                 "stack_method": stack_method,
                 "repeats": repeats,
             }
-        kwargs.update({"str2str": str2str, "device": device})
+        kwargs.update(
+            {
+                "str2str": str2str,
+                "device": device,
+                "has_attention": False,
+                "no_stack": False,
+            }
+        )
         env = LLMEnv.from_dataloader(**kwargs)
         assert env.transform.repeats == repeats

@@ -4746,13 +4780,15 @@ def test_llm_from_dataloader_repeats(
         def policy(td):
             if str2str:
                 if not td.shape:
-                    td["action"] = "<nothing>"
+                    td[LLMEnv._DEFAULT_ACTION_KEY] = "<nothing>"
                 else:
-                    td["action"] = NonTensorStack(
+                    td[LLMEnv._DEFAULT_ACTION_KEY] = NonTensorStack(
                         *["<nothing>" for _ in range(td.shape[0])]
                     )
             else:
-                td["action"] = torch.ones(td.shape + (1,), dtype=torch.int64)
+                td[LLMEnv._DEFAULT_ACTION_KEY] = torch.ones(
+                    td.shape + (1,), dtype=torch.int64
+                )
             return td

         if batched:
@@ -4768,34 +4804,58 @@ def policy(td):
         r_reset = r[..., ::max_steps]
         if not batched:
             if str2str:
-                assert r_reset[..., 0]["observation"] == r_reset[..., 1]["observation"]
-                assert r_reset[..., 0]["observation"] == r_reset[..., 2]["observation"]
-                assert r_reset[..., 0]["observation"] != r_reset[..., 3]["observation"]
+                assert (
+                    r_reset[..., 0][LLMEnv._DEFAULT_STR_KEY]
+                    == r_reset[..., 1][LLMEnv._DEFAULT_STR_KEY]
+                )
+                assert (
+                    r_reset[..., 0][LLMEnv._DEFAULT_STR_KEY]
+                    == r_reset[..., 2][LLMEnv._DEFAULT_STR_KEY]
+                )
+                assert (
+                    r_reset[..., 0][LLMEnv._DEFAULT_STR_KEY]
+                    != r_reset[..., 3][LLMEnv._DEFAULT_STR_KEY]
+                )
             else:
                 assert (
-                    r_reset[..., 0]["observation"] == r_reset[..., 1]["observation"]
+                    r_reset[..., 0][LLMEnv._DEFAULT_TOKEN_KEY]
+                    == r_reset[..., 1][LLMEnv._DEFAULT_TOKEN_KEY]
                 ).all()
                 assert (
-                    r_reset[..., 0]["observation"] == r_reset[..., 2]["observation"]
+                    r_reset[..., 0][LLMEnv._DEFAULT_TOKEN_KEY]
+                    == r_reset[..., 2][LLMEnv._DEFAULT_TOKEN_KEY]
                 ).all()
                 assert (
-                    r_reset[..., 0]["observation"] != r_reset[..., 3]["observation"]
+                    r_reset[..., 0][LLMEnv._DEFAULT_TOKEN_KEY]
+                    != r_reset[..., 3][LLMEnv._DEFAULT_TOKEN_KEY]
                 ).any()
         else:
             # When batched, each block contains the 3 reset packs
             if str2str:
-                assert r_reset[0, 0]["observation"] == r_reset[1, 0]["observation"]
-                assert r_reset[0, 0]["observation"] == r_reset[2, 0]["observation"]
-                assert r_reset[0, 0]["observation"] != r_reset[0, 1]["observation"]
+                assert (
+                    r_reset[0, 0][LLMEnv._DEFAULT_STR_KEY]
+                    == r_reset[1, 0][LLMEnv._DEFAULT_STR_KEY]
+                )
+                assert (
+                    r_reset[0, 0][LLMEnv._DEFAULT_STR_KEY]
+                    == r_reset[2, 0][LLMEnv._DEFAULT_STR_KEY]
+                )
+                assert (
+                    r_reset[0, 0][LLMEnv._DEFAULT_STR_KEY]
+                    != r_reset[0, 1][LLMEnv._DEFAULT_STR_KEY]
+                )
             else:
                 assert (
-                    r_reset[0, 0]["observation"] == r_reset[1, 0]["observation"]
+                    r_reset[0, 0][LLMEnv._DEFAULT_TOKEN_KEY]
+                    == r_reset[1, 0][LLMEnv._DEFAULT_TOKEN_KEY]
                 ).all()
                 assert (
-                    r_reset[0, 0]["observation"] == r_reset[2, 0]["observation"]
+                    r_reset[0, 0][LLMEnv._DEFAULT_TOKEN_KEY]
+                    == r_reset[2, 0][LLMEnv._DEFAULT_TOKEN_KEY]
                 ).all()
                 assert (
-                    r_reset[0, 0]["observation"] != r_reset[0, 1]["observation"]
+                    r_reset[0, 0][LLMEnv._DEFAULT_TOKEN_KEY]
+                    != r_reset[0, 1][LLMEnv._DEFAULT_TOKEN_KEY]
                 ).any()

     @pytest.mark.parametrize(
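These checks slice the rollout at every reset boundary (r[..., ::max_steps]) and verify that the same prompt is served for `repeats` consecutive resets before a new one is drawn. The slicing itself is ordinary strided indexing; a standalone sketch of what it selects, with made-up values for illustration:

import torch

max_steps, repeats = 3, 3
# Per-step prompt ids for a single env: prompt 0 is re-served for 3 resets
# (3 * max_steps steps), then prompt 1 begins.
prompt_ids = torch.tensor([0] * max_steps * repeats + [1] * max_steps)
resets = prompt_ids[::max_steps]  # one entry per reset: tensor([0, 0, 0, 1])
assert resets[0] == resets[1]
assert resets[0] == resets[2]
assert resets[0] != resets[3]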
@@ -4829,7 +4889,7 @@ def test_done_and_reward(
         if str2str:
            kwargs = {
                "dataloader": self.DummyDataLoader(batch_size=batch_size),
-                "data_keys": ["observation"],
+                "data_keys": [LLMEnv._DEFAULT_STR_KEY],
                "example_data": "a string!",
                "repeats": repeats,
                "assign_reward": assign_reward,
@@ -4842,20 +4902,27 @@ def test_done_and_reward(
                 "dataloader": self.DummyTensorDataLoader(
                     padding=True, batch_size=batch_size
                 ),
-                "data_keys": ["observation"],
+                "data_keys": [LLMEnv._DEFAULT_TOKEN_KEY],
                 "data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
                 "stack_method": stack_method,
                 "repeats": repeats,
                 "assign_reward": assign_reward,
                 "assign_done": assign_done,
             }
-        kwargs.update({"str2str": str2str, "device": device})
+        kwargs.update(
+            {
+                "str2str": str2str,
+                "device": device,
+                "has_attention": False,
+                "no_stack": False,
+            }
+        )
         env = LLMEnv.from_dataloader(**kwargs)
         # We want to make sure that transforms that rely on the done state work appropriately
         env.append_transform(StepCounter(max_steps=10))

         def policy(td):
-            td["action"] = torch.ones(
+            td[LLMEnv._DEFAULT_ACTION_KEY] = torch.ones(
                 td.shape + (torch.randint(10, (1,)).item(),), dtype=torch.int64
             )
             return td

torchrl/data/postprocs/postprocs.py

+1 −1
@@ -12,7 +12,6 @@
 from torch import nn
 
 
-
 def _get_reward(
     gamma: float,
     reward: torch.Tensor,
@@ -367,6 +366,7 @@ def __init__(
         discount: float = 1.0,
     ):
         from torchrl.objectives.value.functional import reward2go
+
         super().__init__()
         self.in_keys = [unravel_key(reward_key), unravel_key(done_key)]
         if reward_key_out is None:

torchrl/data/replay_buffers/storages.py

+3 −3
@@ -1536,10 +1536,10 @@ def _collate_id(x):
 
 
 def _get_default_collate(storage, _is_tensordict=False):
-    if isinstance(storage, ListStorage):
-        return _stack_anything
-    elif isinstance(storage, TensorStorage):
+    if isinstance(storage, LazyStackStorage) or isinstance(storage, TensorStorage):
         return _collate_id
+    elif isinstance(storage, ListStorage):
+        return _stack_anything
     else:
         raise NotImplementedError(
             f"Could not find a default collate_fn for storage {type(storage)}."
