diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst
index d75a0e67c54..da9c2114161 100644
--- a/docs/source/reference/envs.rst
+++ b/docs/source/reference/envs.rst
@@ -867,6 +867,7 @@ to be able to create this other composition:
     CenterCrop
     ClipTransform
     Compose
+    ConditionalPolicySwitch
     Crop
     DTypeCastTransform
     DeviceCastTransform
diff --git a/examples/agents/ppo-chess.py b/examples/agents/ppo-chess.py
index f9527339e2a..6c3a7886ee5 100644
--- a/examples/agents/ppo-chess.py
+++ b/examples/agents/ppo-chess.py
@@ -5,20 +5,24 @@
 import tensordict.nn
 import torch
 import tqdm
-from tensordict.nn import TensorDictSequential as TDSeq, TensorDictModule as TDMod, \
-    ProbabilisticTensorDictModule as TDProb, ProbabilisticTensorDictSequential as TDProbSeq
+from tensordict.nn import (
+    ProbabilisticTensorDictModule as TDProb,
+    ProbabilisticTensorDictSequential as TDProbSeq,
+    TensorDictModule as TDMod,
+    TensorDictSequential as TDSeq,
+)
 from torch import nn
 from torch.nn.utils import clip_grad_norm_
 from torch.optim import Adam
 
 from torchrl.collectors import SyncDataCollector
+from torchrl.data import LazyTensorStorage, ReplayBuffer, SamplerWithoutReplacement
 
 from torchrl.envs import ChessEnv, Tokenizer
 from torchrl.modules import MLP
 from torchrl.modules.distributions import MaskedCategorical
 from torchrl.objectives import ClipPPOLoss
 from torchrl.objectives.value import GAE
-from torchrl.data import ReplayBuffer, LazyTensorStorage, SamplerWithoutReplacement
 
 tensordict.nn.set_composite_lp_aggregate(False)
 
@@ -39,7 +43,9 @@
 embedding_moves = nn.Embedding(num_embeddings=n + 1, embedding_dim=64)
 
 # Embedding for the fen
-embedding_fen = nn.Embedding(num_embeddings=transform.tokenizer.vocab_size, embedding_dim=64)
+embedding_fen = nn.Embedding(
+    num_embeddings=transform.tokenizer.vocab_size, embedding_dim=64
+)
 
 backbone = MLP(out_features=512, num_cells=[512] * 8, activation_class=nn.ReLU)
 
@@ -49,20 +55,30 @@
 critic_head = nn.Linear(512, 1)
 critic_head.bias.data.fill_(0)
 
-prob = TDProb(in_keys=["logits", "mask"], out_keys=["action"], distribution_class=MaskedCategorical, return_log_prob=True)
+prob = TDProb(
+    in_keys=["logits", "mask"],
+    out_keys=["action"],
+    distribution_class=MaskedCategorical,
+    return_log_prob=True,
+)
+
 
 def make_mask(idx):
     mask = idx.new_zeros((*idx.shape[:-1], n + 1), dtype=torch.bool)
     return mask.scatter_(-1, idx, torch.ones_like(idx, dtype=torch.bool))[..., :-1]
 
+
 actor = TDProbSeq(
-    TDMod(
-        make_mask,
-        in_keys=["legal_moves"], out_keys=["mask"]),
+    TDMod(make_mask, in_keys=["legal_moves"], out_keys=["mask"]),
     TDMod(embedding_moves, in_keys=["legal_moves"], out_keys=["embedded_legal_moves"]),
     TDMod(embedding_fen, in_keys=["fen_tokenized"], out_keys=["embedded_fen"]),
-    TDMod(lambda *args: torch.cat([arg.view(*arg.shape[:-2], -1) for arg in args], dim=-1), in_keys=["embedded_legal_moves", "embedded_fen"],
-          out_keys=["features"]),
+    TDMod(
+        lambda *args: torch.cat(
+            [arg.view(*arg.shape[:-2], -1) for arg in args], dim=-1
+        ),
+        in_keys=["embedded_legal_moves", "embedded_fen"],
+        out_keys=["features"],
+    ),
     TDMod(backbone, in_keys=["features"], out_keys=["hidden"]),
     TDMod(actor_head, in_keys=["hidden"], out_keys=["logits"]),
     prob,
@@ -78,7 +94,9 @@ def make_mask(idx):
 
 optim = Adam(loss.parameters())
 
-gae = GAE(value_network=TDSeq(*actor[:-2], critic), gamma=0.99, lmbda=0.95, shifted=True)
+gae = GAE(
+    value_network=TDSeq(*actor[:-2], critic), gamma=0.99, lmbda=0.95, shifted=True
+)
 
 # Create a data collector
 collector = SyncDataCollector(
@@ -88,12 +106,20 @@ def make_mask(idx):
     total_frames=1_000_000,
 )
 
-replay_buffer0 = ReplayBuffer(storage=LazyTensorStorage(max_size=collector.frames_per_batch//2), batch_size=batch_size, sampler=SamplerWithoutReplacement())
-replay_buffer1 = ReplayBuffer(storage=LazyTensorStorage(max_size=collector.frames_per_batch//2), batch_size=batch_size, sampler=SamplerWithoutReplacement())
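+# One replay buffer per player: player 0's transitions are the even frames of each
+# collected batch (data[0::2]) and player 1's are the odd frames (data[1::2]), so
+# each buffer holds at most half of the frames collected per batch.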
+replay_buffer0 = ReplayBuffer(
+    storage=LazyTensorStorage(max_size=collector.frames_per_batch // 2),
+    batch_size=batch_size,
+    sampler=SamplerWithoutReplacement(),
+)
+replay_buffer1 = ReplayBuffer(
+    storage=LazyTensorStorage(max_size=collector.frames_per_batch // 2),
+    batch_size=batch_size,
+    sampler=SamplerWithoutReplacement(),
+)
 
 for data in tqdm.tqdm(collector):
     data = data.filter_non_tensor_data()
-    print('data', data[0::2])
+    print("data", data[0::2])
     for i in range(num_epochs):
         replay_buffer0.empty()
         replay_buffer1.empty()
@@ -103,14 +129,24 @@ def make_mask(idx):
             # player 1
             data1 = gae(data[1::2])
             if i == 0:
-                print('win rate for 0', data0["next", "reward"].sum() / data["next", "done"].sum().clamp_min(1e-6))
-                print('win rate for 1', data1["next", "reward"].sum() / data["next", "done"].sum().clamp_min(1e-6))
+                print(
+                    "win rate for 0",
+                    data0["next", "reward"].sum()
+                    / data["next", "done"].sum().clamp_min(1e-6),
+                )
+                print(
+                    "win rate for 1",
+                    data1["next", "reward"].sum()
+                    / data["next", "done"].sum().clamp_min(1e-6),
+                )
 
             replay_buffer0.extend(data0)
             replay_buffer1.extend(data1)
 
-        n_iter = collector.frames_per_batch//(2 * batch_size)
-        for (d0, d1) in tqdm.tqdm(zip(replay_buffer0, replay_buffer1, strict=True), total=n_iter):
+        n_iter = collector.frames_per_batch // (2 * batch_size)
+        for (d0, d1) in tqdm.tqdm(
+            zip(replay_buffer0, replay_buffer1, strict=True), total=n_iter
+        ):
             loss_vals = (loss(d0) + loss(d1)) / 2
             loss_vals.sum(reduce=True).backward()
             gn = clip_grad_norm_(loss.parameters(), 100.0)
diff --git a/test/test_transforms.py b/test/test_transforms.py
index c480015bf17..adf30a69bac 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -20,6 +20,8 @@
 
 import tensordict.tensordict
 import torch
+from tensordict.nn import WrapModule
+
 from tensordict import (
     NonTensorData,
     NonTensorStack,
@@ -56,6 +58,7 @@
     CenterCrop,
     ClipTransform,
     Compose,
+    ConditionalPolicySwitch,
     Crop,
     DeviceCastTransform,
     DiscreteActionProjection,
@@ -13338,6 +13341,206 @@ def test_composite_reward_spec(self) -> None:
         assert transform.transform_reward_spec(reward_spec) == expected_reward_spec
 
 
+class TestConditionalPolicySwitch(TransformBase):
+    def test_single_trans_env_check(self):
+        base_env = CountingEnv(max_steps=15)
+        condition = lambda td: ((td.get("step_count") % 2) == 0).all()
+        # Player 0
+        policy_odd = lambda td: td.set("action", env.action_spec.zero())
+        policy_even = lambda td: td.set("action", env.action_spec.one())
+        transforms = Compose(
+            StepCounter(),
+            ConditionalPolicySwitch(condition=condition, policy=policy_even),
+        )
+        env = base_env.append_transform(transforms)
+        env.check_env_specs()
+
+    def _create_policy_odd(self, base_env):
+        return WrapModule(
+            lambda td, base_env=base_env: td.set(
+                "action", base_env.action_spec_unbatched.zero(td.shape)
+            ),
+            out_keys=["action"],
+        )
+
+    def _create_policy_even(self, base_env):
+        return WrapModule(
+            lambda td, base_env=base_env: td.set(
+                "action", base_env.action_spec_unbatched.one(td.shape)
+            ),
+            out_keys=["action"],
+        )
+
+    def _create_transforms(self, condition, policy_even):
+        return Compose(
+            StepCounter(),
+            ConditionalPolicySwitch(condition=condition, policy=policy_even),
+        )
+
+    def _make_env(self, max_count, env_cls):
+        torch.manual_seed(0)
+        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
+        base_env = env_cls(max_steps=max_count)
+        policy_even = self._create_policy_even(base_env)
+        transforms = self._create_transforms(condition, policy_even)
+        return base_env.append_transform(transforms)
+
+    def _test_env(self, env, policy_odd):
+        env.check_env_specs()
+        env.set_seed(0)
+        r = env.rollout(100, policy_odd, break_when_any_done=False)
+        # Check results are independent: one reset / step in one env should not impact results in another
+        r0, r1, r2 = r.unbind(0)
+        r0_split = r0.split(6)
+        assert all(((r == r0_split[0][: r.numel()]).all() for r in r0_split[1:]))
+        r1_split = r1.split(7)
+        assert all(((r == r1_split[0][: r.numel()]).all() for r in r1_split[1:]))
+        r2_split = r2.split(8)
+        assert all(((r == r2_split[0][: r.numel()]).all() for r in r2_split[1:]))
+
+    def test_trans_serial_env_check(self):
+        torch.manual_seed(0)
+        base_env = SerialEnv(
+            3,
+            [partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)],
+            batch_locked=False,
+        )
+        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
+        policy_odd = self._create_policy_odd(base_env)
+        policy_even = self._create_policy_even(base_env)
+        transforms = self._create_transforms(condition, policy_even)
+        env = base_env.append_transform(transforms)
+        self._test_env(env, policy_odd)
+
+    def test_trans_parallel_env_check(self):
+        torch.manual_seed(0)
+        base_env = ParallelEnv(
+            3,
+            [partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)],
+            batch_locked=False,
+            mp_start_method=mp_ctx,
+        )
+        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
+        policy_odd = self._create_policy_odd(base_env)
+        policy_even = self._create_policy_even(base_env)
+        transforms = self._create_transforms(condition, policy_even)
+        env = base_env.append_transform(transforms)
+        self._test_env(env, policy_odd)
+
+    def test_serial_trans_env_check(self):
+        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
+        policy_odd = self._create_policy_odd(CountingEnv())
+
+        def make_env(max_count):
+            return partial(self._make_env, max_count, CountingEnv)
+
+        env = SerialEnv(3, [make_env(6), make_env(7), make_env(8)])
+        self._test_env(env, policy_odd)
+
+    def test_parallel_trans_env_check(self):
+        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
+        policy_odd = self._create_policy_odd(CountingEnv())
+
+        def make_env(max_count):
+            return partial(self._make_env, max_count, CountingEnv)
+
+        env = ParallelEnv(
+            3, [make_env(6), make_env(7), make_env(8)], mp_start_method=mp_ctx
+        )
+        self._test_env(env, policy_odd)
+
+    def test_transform_no_env(self):
+        policy_odd = lambda td: td
+        policy_even = lambda td: td
+        condition = lambda td: True
+        transforms = ConditionalPolicySwitch(condition=condition, policy=policy_even)
+        with pytest.raises(
+            RuntimeError,
+            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
+        ):
+            transforms(TensorDict())
+
+    def test_transform_compose(self):
+        policy_odd = lambda td: td
+        policy_even = lambda td: td
+        condition = lambda td: True
+        transforms = Compose(
+            ConditionalPolicySwitch(condition=condition, policy=policy_even),
+        )
+        with pytest.raises(
+            RuntimeError,
+            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
+        ):
+            transforms(TensorDict())
+
+    def test_transform_env(self):
+        base_env = CountingEnv(max_steps=15)
+        condition = lambda td: ((td.get("step_count") % 2) == 0).all()
+        # Player 0
+        policy_odd = lambda td: td.set("action", env.action_spec.zero())
+        policy_even = lambda td: td.set("action", env.action_spec.one())
+        transforms = Compose(
+            StepCounter(),
+            ConditionalPolicySwitch(condition=condition, policy=policy_even),
+        )
+        env = base_env.append_transform(transforms)
+        env.check_env_specs()
+        r = env.rollout(1000, policy_odd, break_when_all_done=True)
+        assert r.shape[0] == 15
+        assert (r["action"] == 0).all()
+        assert (
+            r["step_count"] == torch.arange(1, r.numel() * 2, 2).unsqueeze(-1)
+        ).all()
+        assert r["next", "done"].any()
+
+        # Player 1
+        condition = lambda td: ((td.get("step_count") % 2) == 1).all()
+        transforms = Compose(
+            StepCounter(),
+            ConditionalPolicySwitch(condition=condition, policy=policy_odd),
+        )
+        env = base_env.append_transform(transforms)
+        r = env.rollout(1000, policy_even, break_when_all_done=True)
+        assert r.shape[0] == 16
+        assert (r["action"] == 1).all()
+        assert (
+            r["step_count"] == torch.arange(0, r.numel() * 2, 2).unsqueeze(-1)
+        ).all()
+        assert r["next", "done"].any()
+
+    def test_transform_model(self):
+        policy_odd = lambda td: td
+        policy_even = lambda td: td
+        condition = lambda td: True
+        transforms = nn.Sequential(
+            ConditionalPolicySwitch(condition=condition, policy=policy_even),
+        )
+        with pytest.raises(
+            RuntimeError,
+            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
+        ):
+            transforms(TensorDict())
+
+    @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
+    def test_transform_rb(self, rbclass):
+        policy_odd = lambda td: td
+        policy_even = lambda td: td
+        condition = lambda td: True
+        rb = rbclass(storage=LazyTensorStorage(10))
+        rb.append_transform(
+            ConditionalPolicySwitch(condition=condition, policy=policy_even)
+        )
+        rb.extend(TensorDict(batch_size=[2]))
+        with pytest.raises(
+            RuntimeError,
+            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
+        ):
+            rb.sample(2)
+
+    def test_transform_inverse(self):
+        return
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py
index 52f0cfdbbc5..a103a979174 100644
--- a/torchrl/envs/__init__.py
+++ b/torchrl/envs/__init__.py
@@ -56,6 +56,7 @@
     CenterCrop,
     ClipTransform,
     Compose,
+    ConditionalPolicySwitch,
     Crop,
     DeviceCastTransform,
     DiscreteActionProjection,
diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py
index 51331a86346..50a77a8f557 100644
--- a/torchrl/envs/batched_envs.py
+++ b/torchrl/envs/batched_envs.py
@@ -191,6 +191,8 @@ class BatchedEnvBase(EnvBase):
             one of the environment has dynamic specs.
 
               .. note:: Learn more about dynamic specs and environments :ref:`here <dynamic_envs>`.
+        batch_locked (bool, optional): if provided, will override the ``batch_locked`` attribute of the
+            nested environments. Setting ``batch_locked=False`` may allow partial steps.
 
     .. note::
         One can pass keyword arguments to each sub-environments using the following
@@ -305,6 +307,7 @@ def __init__(
         non_blocking: bool = False,
         mp_start_method: str = None,
         use_buffers: bool = None,
+        batch_locked: bool | None = None,
     ):
         super().__init__(device=device)
         self.serial_for_single = serial_for_single
@@ -344,6 +347,7 @@ def __init__(
 
         # if share_individual_td is None, we will assess later if the output can be stacked
         self.share_individual_td = share_individual_td
+        self._batch_locked = batch_locked
         self._share_memory = shared_memory
         self._memmap = memmap
         self.allow_step_when_done = allow_step_when_done
@@ -610,8 +614,8 @@ def map_device(key, value, device_map=device_map):
                 self._env_tensordict.named_apply(
                     map_device, nested_keys=True, filter_empty=True
                 )
-
-            self._batch_locked = meta_data.batch_locked
+            if self._batch_locked is None:
+                self._batch_locked = meta_data.batch_locked
         else:
             self._batch_size = torch.Size([self.num_workers, *meta_data[0].batch_size])
             devices = set()
@@ -652,7 +656,8 @@ def map_device(key, value, device_map=device_map):
                 self._env_tensordict = torch.stack(
                     [meta_data.tensordict for meta_data in meta_data], 0
                 )
-            self._batch_locked = meta_data[0].batch_locked
+            if self._batch_locked is None:
+                self._batch_locked = meta_data[0].batch_locked
         self.has_lazy_inputs = contains_lazy_spec(self.input_spec)
 
     def state_dict(self) -> OrderedDict:
diff --git a/torchrl/envs/custom/chess.py b/torchrl/envs/custom/chess.py
index d5b744cfc84..24d4205e5a1 100644
--- a/torchrl/envs/custom/chess.py
+++ b/torchrl/envs/custom/chess.py
@@ -176,9 +176,7 @@ class ChessEnv(EnvBase, metaclass=_ChessMeta):
             batch_size=torch.Size([96]),
             device=None,
             is_shared=False)
-
-
-    """
+    """  # noqa: D301
 
     _hash_table: Dict[int, str] = {}
     _PGN_RESTART = """[Event "?"]
diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py
index 7ee142fe811..5f661cdee6e 100644
--- a/torchrl/envs/transforms/__init__.py
+++ b/torchrl/envs/transforms/__init__.py
@@ -20,6 +20,7 @@
     CenterCrop,
     ClipTransform,
     Compose,
+    ConditionalPolicySwitch,
     Crop,
     DeviceCastTransform,
     DiscreteActionProjection,
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 65eda4bc6ec..491cf295a03 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -85,6 +85,7 @@
 )
 from torchrl.envs.utils import (
     _sort_keys,
+    _terminated_or_truncated,
     _update_during_reset,
     make_composite_from_td,
     step_mdp,
@@ -10142,3 +10143,243 @@ def _apply_transform(self, reward: Tensor) -> TensorDictBase:
             )
 
         return (self.weights * reward).sum(dim=-1)
+
+
+class ConditionalPolicySwitch(Transform):
+    """A transform that conditionally switches between policies based on a specified condition.
+
+    This transform evaluates a condition on the data returned by the environment's `step` (or
+    `reset`) method. If the condition is met, the provided policy is applied to that data and an
+    extra step of the parent environment is executed before the result is returned. Otherwise, the
+    data is returned unaltered. This is useful for scenarios where different policies need to be
+    applied based on certain criteria, such as alternating turns in a game.
+
+    Args:
+        policy (Callable[[TensorDictBase], TensorDictBase]):
+            The policy to be applied when the condition is met. This should be a callable that
+            takes a `TensorDictBase` and returns a `TensorDictBase`.
+        condition (Callable[[TensorDictBase], bool]):
+            A callable that takes a `TensorDictBase` and returns a boolean or a tensor indicating
+            whether the policy should be applied.
+
+    .. warning:: This transform must have a parent environment.
+
+    .. note:: Ideally, this should be the last transform in the stack. If the policy requires transformed
+        data (e.g., images) and this transform is applied before those transformations, the policy will
+        not receive the transformed data.
+
+    Examples:
+        >>> import torch
+        >>> from tensordict.nn import TensorDictModule as Mod
+        >>>
+        >>> from torchrl.envs import GymEnv, ConditionalPolicySwitch, Compose, StepCounter
+        >>> # Create a CartPole environment. We look at the observation: if its first element (the cart
+        >>> # position) is non-negative, the switch policy takes over and outputs action=0. Otherwise, our
+        >>> # main policy is used and outputs action=1.
+        >>> base_env = GymEnv("CartPole-v1", categorical_action_encoding=True)
+        >>>
+        >>> policy = Mod(lambda: torch.ones((), dtype=torch.int64), in_keys=[], out_keys=["action"])
+        >>> policy_switch = Mod(lambda: torch.zeros((), dtype=torch.int64), in_keys=[], out_keys=["action"])
+        >>>
+        >>> cond = lambda td: td.get("observation")[..., 0] >= 0
+        >>>
+        >>> env = base_env.append_transform(
+        ...     Compose(
+        ...         # We use two step counters to show that one counts the global steps, whereas the other
+        ...         # only counts the steps where the main policy is executed
+        ...         StepCounter(step_count_key="step_count_total"),
+        ...         ConditionalPolicySwitch(condition=cond, policy=policy_switch),
+        ...         StepCounter(step_count_key="step_count_main"),
+        ...     )
+        ... )
+        >>>
+        >>> env.set_seed(0)
+        >>> torch.manual_seed(0)
+        >>>
+        >>> r = env.rollout(100, policy=policy)
+        >>> print("action", r["action"])
+        action tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
+        >>> print("obs", r["observation"])
+        obs tensor([[ 0.0322, -0.1540,  0.0111,  0.3190],
+                [ 0.0299, -0.1544,  0.0181,  0.3280],
+                [ 0.0276, -0.1550,  0.0255,  0.3414],
+                [ 0.0253, -0.1558,  0.0334,  0.3596],
+                [ 0.0230, -0.1569,  0.0422,  0.3828],
+                [ 0.0206, -0.1582,  0.0519,  0.4117],
+                [ 0.0181, -0.1598,  0.0629,  0.4469],
+                [ 0.0156, -0.1617,  0.0753,  0.4891],
+                [ 0.0130, -0.1639,  0.0895,  0.5394],
+                [ 0.0104, -0.1665,  0.1058,  0.5987],
+                [ 0.0076, -0.1696,  0.1246,  0.6685],
+                [ 0.0047, -0.1732,  0.1463,  0.7504],
+                [ 0.0016, -0.1774,  0.1715,  0.8459],
+                [-0.0020,  0.0150,  0.1884,  0.6117],
+                [-0.0017,  0.2071,  0.2006,  0.3838]])
+        >>> print("obs'", r["next", "observation"])
+        obs' tensor([[ 0.0299, -0.1544,  0.0181,  0.3280],
+                [ 0.0276, -0.1550,  0.0255,  0.3414],
+                [ 0.0253, -0.1558,  0.0334,  0.3596],
+                [ 0.0230, -0.1569,  0.0422,  0.3828],
+                [ 0.0206, -0.1582,  0.0519,  0.4117],
+                [ 0.0181, -0.1598,  0.0629,  0.4469],
+                [ 0.0156, -0.1617,  0.0753,  0.4891],
+                [ 0.0130, -0.1639,  0.0895,  0.5394],
+                [ 0.0104, -0.1665,  0.1058,  0.5987],
+                [ 0.0076, -0.1696,  0.1246,  0.6685],
+                [ 0.0047, -0.1732,  0.1463,  0.7504],
+                [ 0.0016, -0.1774,  0.1715,  0.8459],
+                [-0.0020,  0.0150,  0.1884,  0.6117],
+                [-0.0017,  0.2071,  0.2006,  0.3838],
+                [ 0.0105,  0.2015,  0.2115,  0.5110]])
+        >>> print("total step count", r["step_count_total"].squeeze())
+        total step count tensor([ 1,  3,  5,  7,  9, 11, 13, 15, 17, 19, 21, 23, 25, 26, 27])
+        >>> print("total step with main policy", r["step_count_main"].squeeze())
+        total step with main policy tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14])
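+
+        >>> # A batched base env must be unlocked (batch_locked=False) for the transform to run the
+        >>> # switch policy on a subset of the sub-envs only. This is a minimal sketch reusing cond and
+        >>> # policy_switch from above:
+        >>> from torchrl.envs import SerialEnv
+        >>> batched_env = SerialEnv(
+        ...     2,
+        ...     lambda: GymEnv("CartPole-v1", categorical_action_encoding=True),
+        ...     batch_locked=False,
+        ... )
+        >>> env = batched_env.append_transform(
+        ...     ConditionalPolicySwitch(condition=cond, policy=policy_switch)
+        ... )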
+
+    """
+
+    def __init__(
+        self,
+        policy: Callable[[TensorDictBase], TensorDictBase],
+        condition: Callable[[TensorDictBase], bool],
+    ):
+        super().__init__([], [])
+        self.__dict__["policy"] = policy
+        self.condition = condition
+
+    def _step(
+        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
+    ) -> TensorDictBase:
+        cond = self.condition(next_tensordict)
+        if not isinstance(cond, (bool, torch.Tensor)):
+            raise RuntimeError(
+                "The condition function must return a boolean or a tensor."
+            )
+        elif isinstance(cond, (torch.Tensor,)):
+            if tuple(cond.shape) not in ((1,), (), tuple(tensordict.shape)):
+                raise RuntimeError(
+                    "Tensor outputs must have the shape of the tensordict, or contain a single element."
+                )
+        else:
+            cond = torch.tensor(cond, device=tensordict.device)
+
+        if cond.any():
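+            # An upstream caller may have already restricted this step to a subset of
+            # the batch through the "_step" mask; only run the switch policy where
+            # both that mask and the condition are true.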
+            step = tensordict.get("_step", cond)
+            if step.shape != cond.shape:
+                step = step.view_as(cond)
+            cond = cond & step
+
+            parent: TransformedEnv = self.parent
+            any_done, done = self._check_done(next_tensordict)
+            next_td_save = None
+            if any_done:
+                if next_tensordict.numel() == 1 or done.all():
+                    return next_tensordict
+                if parent.base_env.batch_locked:
+                    raise RuntimeError(
+                        "Cannot run partial steps in a batch-locked environment. "
+                        "Hint: Parallel and Serial envs can be unlocked through a keyword argument in "
+                        "the constructor."
+                    )
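+                # Sub-envs that are already done are excluded from the extra policy step.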
+                done = done.view(next_tensordict.shape)
+                cond = cond & ~done
+            if not cond.all():
+                if parent.base_env.batch_locked:
+                    raise RuntimeError(
+                        "Cannot run partial steps in a batch-locked environment. "
+                        "Hint: Parallel and Serial envs can be unlocked through a keyword argument in "
+                        "the constructor."
+                    )
+                next_td_save = next_tensordict
+                next_tensordict = next_tensordict[cond]
+                tensordict = tensordict[cond]
+
+            # The policy may be expensive or raise an exception when executed with
+            # inadequate data, so we index the tensordict first.
+            td = self.policy(
+                parent.step_mdp(tensordict.copy().set("next", next_tensordict))
+            )
+            # Mark the partial steps if needed
+            if next_td_save is not None:
+                td_new = td.new_zeros(cond.shape)
+                # TODO: swap with masked_scatter when avail
+                td_new[cond] = td
+                td = td_new
+                td.set("_step", cond)
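+            # Execute the switch policy's action through an extra step of the parent env.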
+            next_tensordict = parent._step(td)
+            if next_td_save is not None:
+                return torch.where(cond, next_tensordict, next_td_save)
+            return next_tensordict
+        return next_tensordict
+
+    def _check_done(self, tensordict):
+        env = self.parent
+        if env._simple_done:
+            done = tensordict._get_str("done", default=None)
+            if done is not None:
+                any_done = done.any()
+            else:
+                any_done = False
+        else:
+            any_done = _terminated_or_truncated(
+                tensordict,
+                full_done_spec=env.output_spec["full_done_spec"],
+                key="_reset",
+            )
+            done = tensordict.pop("_reset")
+        return any_done, done
+
+    def _reset(
+        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
+    ) -> TensorDictBase:
+        cond = self.condition(tensordict_reset)
+        # TODO: move to validate
+        if not isinstance(cond, (bool, torch.Tensor)):
+            raise RuntimeError(
+                "The condition function must return a boolean or a tensor."
+            )
+        elif isinstance(cond, (torch.Tensor,)):
+            if tuple(cond.shape) not in ((1,), (), tuple(tensordict.shape)):
+                raise RuntimeError(
+                    "Tensor outputs must have the shape of the tensordict, or contain a single element."
+                )
+        else:
+            cond = torch.tensor(cond, device=tensordict.device)
+
+        if cond.any():
+            reset = tensordict.get("_reset", cond)
+            if reset.shape != cond.shape:
+                reset = reset.view_as(cond)
+            cond = cond & reset
+
+            parent: TransformedEnv = self.parent
+            reset_td_save = None
+            if not cond.all():
+                if parent.base_env.batch_locked:
+                    raise RuntimeError(
+                        "Cannot run partial steps in a batch-locked environment. "
+                        "Hint: Parallel and Serial envs can be unlocked through a keyword argument in "
+                        "the constructor."
+                    )
+                reset_td_save = tensordict_reset.copy()
+                tensordict_reset = tensordict_reset[cond]
+                tensordict = tensordict[cond]
+
+            td = self.policy(tensordict_reset)
+            # Mark the partial steps if needed
+            if reset_td_save is not None:
+                td_new = td.new_zeros(cond.shape)
+                # TODO: swap with masked_scatter when avail
+                td_new[cond] = td
+                td = td_new
+                td.set("_step", cond)
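+            # The switch policy's action is executed through the parent env's step; any
+            # reward produced by this extra step is stripped from the reset data.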
+            tensordict_reset = parent._step(td).exclude(*parent.reward_keys)
+            if reset_td_save is not None:
+                return torch.where(cond, tensordict_reset, reset_td_save)
+            return tensordict_reset
+
+        return tensordict_reset
+
+    def forward(self, tensordict: TensorDictBase) -> Any:
+        raise RuntimeError(
+            "ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional."
+        )