Skip to content

Commit 9bedca3

Browse files
committed
[Feature] ConditionalPolicySwitch transform
ghstack-source-id: 5fcf89c Pull Request resolved: #2711
1 parent 0ba2317 commit 9bedca3

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

torchrl/envs/transforms/transforms.py

+18
Original file line numberDiff line numberDiff line change
@@ -9974,3 +9974,21 @@ def _apply_transform(self, reward: Tensor) -> TensorDictBase:
99749974
)
99759975

99769976
return (self.weights * reward).sum(dim=-1)
9977+
9978+
9979+
class ConditionalPolicySwitch(Transform):
9980+
def __init__(self, policy: Callable[[TensorDictBase], TensorDictBase], condition: Callable[[TensorDictBase], bool]):
9981+
super().__init__([], [])
9982+
self.__dict__["policy"] = policy
9983+
self.condition = condition
9984+
def _step(
9985+
self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
9986+
) -> TensorDictBase:
9987+
if self.condition(tensordict):
9988+
parent: TransformedEnv = self.parent
9989+
tensordict = parent.step(tensordict)
9990+
tensordict_ = parent.step_mdp(tensordict)
9991+
tensordict_ = self.policy(tensordict_)
9992+
return parent.step(tensordict_)
9993+
return tensordict
9994+
return

0 commit comments

Comments
 (0)