Commit 29f7971

[Feature] ConditionalPolicySwitch transform
ghstack-source-id: defb61a
Pull Request resolved: #2711
Parent: dd2bf20

7 files changed: +472, -21 lines
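
For orientation before the per-file diffs: the new ConditionalPolicySwitch transform hands control to a secondary policy whenever a user-supplied condition evaluates to True on the current tensordict. The sketch below is assembled from the API exercised by the unit tests added in this commit (a condition callable, a policy callable, composition with StepCounter, append_transform); the Pendulum-v1 environment and the two toy policies are illustrative assumptions, not part of the commit.

# Illustrative sketch only; the ConditionalPolicySwitch API follows the
# TestConditionalPolicySwitch tests added in this commit.
from torchrl.envs import Compose, ConditionalPolicySwitch, GymEnv, StepCounter

base_env = GymEnv("Pendulum-v1")  # assumption: any simple single-agent env

# The rollout policy applies zero torque; the switched-in policy acts randomly.
policy_zero = lambda td: td.set("action", base_env.action_spec.zero())
policy_rand = lambda td: td.set("action", base_env.action_spec.rand())

# Hand control to policy_rand whenever the step count is even.
condition = lambda td: ((td.get("step_count") % 2) == 0).all()

env = base_env.append_transform(
    Compose(
        StepCounter(),
        ConditionalPolicySwitch(condition=condition, policy=policy_rand),
    )
)
rollout = env.rollout(50, policy_zero)
print(rollout["step_count"].squeeze(-1))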

docs/source/reference/envs.rst (+1)

@@ -816,6 +816,7 @@ to be able to create this other composition:
     CenterCrop
     ClipTransform
     Compose
+    ConditionalPolicySwitch
     Crop
     DTypeCastTransform
     DeviceCastTransform

examples/agents/ppo-chess.py (+54 -18)

@@ -5,20 +5,24 @@
 import tensordict.nn
 import torch
 import tqdm
-from tensordict.nn import TensorDictSequential as TDSeq, TensorDictModule as TDMod, \
-    ProbabilisticTensorDictModule as TDProb, ProbabilisticTensorDictSequential as TDProbSeq
+from tensordict.nn import (
+    ProbabilisticTensorDictModule as TDProb,
+    ProbabilisticTensorDictSequential as TDProbSeq,
+    TensorDictModule as TDMod,
+    TensorDictSequential as TDSeq,
+)
 from torch import nn
 from torch.nn.utils import clip_grad_norm_
 from torch.optim import Adam
 
 from torchrl.collectors import SyncDataCollector
+from torchrl.data import LazyTensorStorage, ReplayBuffer, SamplerWithoutReplacement
 
 from torchrl.envs import ChessEnv, Tokenizer
 from torchrl.modules import MLP
 from torchrl.modules.distributions import MaskedCategorical
 from torchrl.objectives import ClipPPOLoss
 from torchrl.objectives.value import GAE
-from torchrl.data import ReplayBuffer, LazyTensorStorage, SamplerWithoutReplacement
 
 tensordict.nn.set_composite_lp_aggregate(False)
 
@@ -39,7 +43,9 @@
 embedding_moves = nn.Embedding(num_embeddings=n + 1, embedding_dim=64)
 
 # Embedding for the fen
-embedding_fen = nn.Embedding(num_embeddings=transform.tokenizer.vocab_size, embedding_dim=64)
+embedding_fen = nn.Embedding(
+    num_embeddings=transform.tokenizer.vocab_size, embedding_dim=64
+)
 
 backbone = MLP(out_features=512, num_cells=[512] * 8, activation_class=nn.ReLU)
 
@@ -49,20 +55,30 @@
 critic_head = nn.Linear(512, 1)
 critic_head.bias.data.fill_(0)
 
-prob = TDProb(in_keys=["logits", "mask"], out_keys=["action"], distribution_class=MaskedCategorical, return_log_prob=True)
+prob = TDProb(
+    in_keys=["logits", "mask"],
+    out_keys=["action"],
+    distribution_class=MaskedCategorical,
+    return_log_prob=True,
+)
+
 
 def make_mask(idx):
     mask = idx.new_zeros((*idx.shape[:-1], n + 1), dtype=torch.bool)
     return mask.scatter_(-1, idx, torch.ones_like(idx, dtype=torch.bool))[..., :-1]
 
+
 actor = TDProbSeq(
-    TDMod(
-        make_mask,
-        in_keys=["legal_moves"], out_keys=["mask"]),
+    TDMod(make_mask, in_keys=["legal_moves"], out_keys=["mask"]),
     TDMod(embedding_moves, in_keys=["legal_moves"], out_keys=["embedded_legal_moves"]),
     TDMod(embedding_fen, in_keys=["fen_tokenized"], out_keys=["embedded_fen"]),
-    TDMod(lambda *args: torch.cat([arg.view(*arg.shape[:-2], -1) for arg in args], dim=-1), in_keys=["embedded_legal_moves", "embedded_fen"],
-        out_keys=["features"]),
+    TDMod(
+        lambda *args: torch.cat(
+            [arg.view(*arg.shape[:-2], -1) for arg in args], dim=-1
+        ),
+        in_keys=["embedded_legal_moves", "embedded_fen"],
+        out_keys=["features"],
+    ),
     TDMod(backbone, in_keys=["features"], out_keys=["hidden"]),
     TDMod(actor_head, in_keys=["hidden"], out_keys=["logits"]),
     prob,
@@ -78,7 +94,9 @@ def make_mask(idx):
 
 optim = Adam(loss.parameters())
 
-gae = GAE(value_network=TDSeq(*actor[:-2], critic), gamma=0.99, lmbda=0.95, shifted=True)
+gae = GAE(
+    value_network=TDSeq(*actor[:-2], critic), gamma=0.99, lmbda=0.95, shifted=True
+)
 
 # Create a data collector
 collector = SyncDataCollector(
@@ -88,12 +106,20 @@ def make_mask(idx):
     total_frames=1_000_000,
 )
 
-replay_buffer0 = ReplayBuffer(storage=LazyTensorStorage(max_size=collector.frames_per_batch//2), batch_size=batch_size, sampler=SamplerWithoutReplacement())
-replay_buffer1 = ReplayBuffer(storage=LazyTensorStorage(max_size=collector.frames_per_batch//2), batch_size=batch_size, sampler=SamplerWithoutReplacement())
+replay_buffer0 = ReplayBuffer(
+    storage=LazyTensorStorage(max_size=collector.frames_per_batch // 2),
+    batch_size=batch_size,
+    sampler=SamplerWithoutReplacement(),
+)
+replay_buffer1 = ReplayBuffer(
+    storage=LazyTensorStorage(max_size=collector.frames_per_batch // 2),
+    batch_size=batch_size,
+    sampler=SamplerWithoutReplacement(),
+)
 
 for data in tqdm.tqdm(collector):
     data = data.filter_non_tensor_data()
-    print('data', data[0::2])
+    print("data", data[0::2])
     for i in range(num_epochs):
         replay_buffer0.empty()
         replay_buffer1.empty()
@@ -103,14 +129,24 @@ def make_mask(idx):
         # player 1
         data1 = gae(data[1::2])
         if i == 0:
-            print('win rate for 0', data0["next", "reward"].sum() / data["next", "done"].sum().clamp_min(1e-6))
-            print('win rate for 1', data1["next", "reward"].sum() / data["next", "done"].sum().clamp_min(1e-6))
+            print(
+                "win rate for 0",
+                data0["next", "reward"].sum()
+                / data["next", "done"].sum().clamp_min(1e-6),
+            )
+            print(
+                "win rate for 1",
+                data1["next", "reward"].sum()
+                / data["next", "done"].sum().clamp_min(1e-6),
+            )
 
         replay_buffer0.extend(data0)
         replay_buffer1.extend(data1)
 
-        n_iter = collector.frames_per_batch//(2 * batch_size)
-        for (d0, d1) in tqdm.tqdm(zip(replay_buffer0, replay_buffer1, strict=True), total=n_iter):
+        n_iter = collector.frames_per_batch // (2 * batch_size)
+        for (d0, d1) in tqdm.tqdm(
+            zip(replay_buffer0, replay_buffer1, strict=True), total=n_iter
+        ):
             loss_vals = (loss(d0) + loss(d1)) / 2
             loss_vals.sum(reduce=True).backward()
             gn = clip_grad_norm_(loss.parameters(), 100.0)

test/test_transforms.py (+171)

@@ -20,6 +20,7 @@
 
 import tensordict.tensordict
 import torch
+from tensordict.nn import WrapModule
 
 from torchrl.collectors import MultiSyncDataCollector
 
@@ -106,6 +107,7 @@
     CenterCrop,
     ClipTransform,
     Compose,
+    ConditionalPolicySwitch,
     Crop,
     DeviceCastTransform,
     DiscreteActionProjection,
@@ -13192,6 +13194,175 @@ def test_composite_reward_spec(self) -> None:
         assert transform.transform_reward_spec(reward_spec) == expected_reward_spec
 
 
+class TestConditionalPolicySwitch(TransformBase):
+    def test_single_trans_env_check(self):
+        base_env = CountingEnv(max_steps=15)
+        condition = lambda td: ((td.get("step_count") % 2) == 0).all()
+        # Player 0
+        policy_odd = lambda td: td.set("action", env.action_spec.zero())
+        policy_even = lambda td: td.set("action", env.action_spec.one())
+        transforms = Compose(
+            StepCounter(),
+            ConditionalPolicySwitch(condition=condition, policy=policy_even),
+        )
+        env = base_env.append_transform(transforms)
+        r = env.rollout(1000, policy_odd, break_when_all_done=True)
+        assert r.shape[0] == 15
+        assert (r["action"] == 0).all()
+        assert (
+            r["step_count"] == torch.arange(1, r.numel() * 2, 2).unsqueeze(-1)
+        ).all()
+        assert r["next", "done"].any()
+
+        # Player 1
+        condition = lambda td: ((td.get("step_count") % 2) == 1).all()
+        transforms = Compose(
+            StepCounter(),
+            ConditionalPolicySwitch(condition=condition, policy=policy_odd),
+        )
+        env = base_env.append_transform(transforms)
+        r = env.rollout(1000, policy_even, break_when_all_done=True)
+        assert r.shape[0] == 16
+        assert (r["action"] == 1).all()
+        assert (
+            r["step_count"] == torch.arange(0, r.numel() * 2, 2).unsqueeze(-1)
+        ).all()
+        assert r["next", "done"].any()
+
+    def _create_policy_odd(self, base_env):
+        return WrapModule(
+            lambda td, base_env=base_env: td.set(
+                "action", base_env.action_spec_unbatched.zero(td.shape)
+            ),
+            out_keys=["action"],
+        )
+
+    def _create_policy_even(self, base_env):
+        return WrapModule(
+            lambda td, base_env=base_env: td.set(
+                "action", base_env.action_spec_unbatched.one(td.shape)
+            ),
+            out_keys=["action"],
+        )
+
+    def _create_transforms(self, condition, policy_even):
+        return Compose(
+            StepCounter(),
+            ConditionalPolicySwitch(condition=condition, policy=policy_even),
+        )
+
+    def _make_env(self, max_count, env_cls):
+        torch.manual_seed(0)
+        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
+        base_env = env_cls(max_steps=max_count)
+        policy_even = self._create_policy_even(base_env)
+        transforms = self._create_transforms(condition, policy_even)
+        return base_env.append_transform(transforms)
+
+    def _test_env(self, env, policy_odd):
+        env.check_env_specs()
+        env.set_seed(0)
+        r = env.rollout(100, policy_odd, break_when_any_done=False)
+        # Check results are independent: one reset / step in one env should not impact results in another
+        r0, r1, r2 = r.unbind(0)
+        r0_split = r0.split(6)
+        assert all(((r == r0_split[0][: r.numel()]).all() for r in r0_split[1:]))
+        r1_split = r1.split(7)
+        assert all(((r == r1_split[0][: r.numel()]).all() for r in r1_split[1:]))
+        r2_split = r2.split(8)
+        assert all(((r == r2_split[0][: r.numel()]).all() for r in r2_split[1:]))
+
+    def test_trans_serial_env_check(self):
+        torch.manual_seed(0)
+        base_env = SerialEnv(
+            3,
+            [partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)],
+            batch_locked=False,
+        )
+        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
+        policy_odd = self._create_policy_odd(base_env)
+        policy_even = self._create_policy_even(base_env)
+        transforms = self._create_transforms(condition, policy_even)
+        env = base_env.append_transform(transforms)
+        self._test_env(env, policy_odd)
+
+    def test_trans_parallel_env_check(self):
+        torch.manual_seed(0)
+        base_env = ParallelEnv(
+            3,
+            [partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)],
+            batch_locked=False,
+            mp_start_method=mp_ctx,
+        )
+        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
+        policy_odd = self._create_policy_odd(base_env)
+        policy_even = self._create_policy_even(base_env)
+        transforms = self._create_transforms(condition, policy_even)
+        env = base_env.append_transform(transforms)
+        self._test_env(env, policy_odd)
+
+    def test_serial_trans_env_check(self):
+        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
+        policy_odd = self._create_policy_odd(CountingEnv())
+
+        def make_env(max_count):
+            return partial(self._make_env, max_count, CountingEnv)
+
+        env = SerialEnv(3, [make_env(6), make_env(7), make_env(8)])
+        self._test_env(env, policy_odd)
+
+    def test_parallel_trans_env_check(self):
+        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
+        policy_odd = self._create_policy_odd(CountingEnv())
+
+        def make_env(max_count):
+            return partial(self._make_env, max_count, CountingEnv)
+
+        env = ParallelEnv(
+            3, [make_env(6), make_env(7), make_env(8)], mp_start_method=mp_ctx
+        )
+        self._test_env(env, policy_odd)
+
+    def test_transform_no_env(self):
+        """tests the transform on dummy data, without an env."""
+        raise NotImplementedError
+
+    def test_transform_compose(self):
+        """tests the transform on dummy data, without an env but inside a Compose."""
+        raise NotImplementedError
+
+    def test_transform_env(self):
+        """tests the transform on a real env.
+
+        If possible, do not use a mock env, as bugs may go unnoticed if the dynamic is too
+        simplistic. A call to reset() and step() should be tested independently, ie
+        a check that reset produces the desired output and that step() does too.
+
+        """
+        raise NotImplementedError
+
+    def test_transform_model(self):
+        """tests the transform before an nn.Module that reads the output."""
+        raise NotImplementedError
+
+    def test_transform_rb(self):
+        """tests the transform when used with a replay buffer.
+
+        If your transform is not supposed to work with a replay buffer, test that
+        an error will be raised when called or appended to a RB.
+
+        """
+        raise NotImplementedError
+
+    def test_transform_inverse(self):
+        """tests the inverse transform. If not applicable, simply skip this test.
+
+        If your transform is not supposed to work offline, test that
+        an error will be raised when called in a nn.Module.
+        """
+        raise NotImplementedError
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/envs/__init__.py (+1)

@@ -55,6 +55,7 @@
     CenterCrop,
     ClipTransform,
     Compose,
+    ConditionalPolicySwitch,
     Crop,
     DeviceCastTransform,
     DiscreteActionProjection,

torchrl/envs/batched_envs.py (+8 -3)

@@ -191,6 +191,8 @@ class BatchedEnvBase(EnvBase):
             one of the environment has dynamic specs.
 
             .. note:: Learn more about dynamic specs and environments :ref:`here <dynamic_envs>`.
+        batch_locked (bool, optional): if provided, will override the ``batch_locked`` attribute of the
+            nested environments. `batch_locked=False` may allow for partial steps.
 
         .. note::
             One can pass keyword arguments to each sub-environments using the following
@@ -305,6 +307,7 @@ def __init__(
         non_blocking: bool = False,
         mp_start_method: str = None,
         use_buffers: bool = None,
+        batch_locked: bool | None = None,
     ):
         super().__init__(device=device)
         self.serial_for_single = serial_for_single
@@ -344,6 +347,7 @@ def __init__(
 
         # if share_individual_td is None, we will assess later if the output can be stacked
         self.share_individual_td = share_individual_td
+        self._batch_locked = batch_locked
         self._share_memory = shared_memory
         self._memmap = memmap
         self.allow_step_when_done = allow_step_when_done
@@ -610,8 +614,8 @@ def map_device(key, value, device_map=device_map):
            self._env_tensordict.named_apply(
                map_device, nested_keys=True, filter_empty=True
            )
-
-            self._batch_locked = meta_data.batch_locked
+            if self._batch_locked is None:
+                self._batch_locked = meta_data.batch_locked
         else:
             self._batch_size = torch.Size([self.num_workers, *meta_data[0].batch_size])
             devices = set()
@@ -652,7 +656,8 @@ def map_device(key, value, device_map=device_map):
             self._env_tensordict = torch.stack(
                 [meta_data.tensordict for meta_data in meta_data], 0
             )
-            self._batch_locked = meta_data[0].batch_locked
+            if self._batch_locked is None:
+                self._batch_locked = meta_data[0].batch_locked
         self.has_lazy_inputs = contains_lazy_spec(self.input_spec)
 
     def state_dict(self) -> OrderedDict:
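
The batch_locked override above can be passed straight to the batched env constructors. A rough sketch, assuming Pendulum-v1 as a stand-in for the nested environments: when the argument is omitted, the attribute is still inferred from the workers' metadata as before; passing batch_locked=False applies the explicit override, which may allow partial steps.

# Sketch only; Pendulum-v1 and the factory are illustrative assumptions.
from functools import partial

from torchrl.envs import GymEnv, SerialEnv

make_env = partial(GymEnv, "Pendulum-v1")

# Default: batch_locked is inferred from the nested envs' metadata.
default_env = SerialEnv(2, make_env)

# Explicit override added in this commit; False may allow partial steps.
unlocked_env = SerialEnv(2, make_env, batch_locked=False)

print(default_env.batch_locked, unlocked_env.batch_locked)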
