@@ -209,6 +209,7 @@ class Transform(nn.Module):
209209 """
210210
211211 invertible = False
212+ enable_inv_on_reset = False
212213
213214 def __init__ (
214215 self ,
@@ -293,6 +294,13 @@ def _reset(
293294 """Resets a transform if it is stateful."""
294295 return tensordict_reset
295296
297+ def _reset_env_preprocess (self , tensordict : TensorDictBase ) -> TensorDictBase :
298+ """Inverts the input to :meth:`TransformedEnv._reset`, if needed."""
299+ if self .enable_inv_on_reset :
300+ with _set_missing_tolerance (self , True ):
301+ tensordict = self .inv (tensordict )
302+ return tensordict
303+
296304 def init (self , tensordict ) -> None :
297305 pass
298306
@@ -1018,10 +1026,7 @@ def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs):
10181026 tensordict = tensordict .select (
10191027 * self .reset_keys , * self .state_spec .keys (True , True ), strict = False
10201028 )
1021- # Inputs might be transformed, so need to apply inverse transform
1022- # before passing to the env reset function.
1023- with _set_missing_tolerance (self .transform , True ):
1024- tensordict = self .transform .inv (tensordict )
1029+ tensordict = self .transform ._reset_env_preprocess (tensordict )
10251030 tensordict_reset = self .base_env ._reset (tensordict , ** kwargs )
10261031 if tensordict is None :
10271032 # make sure all transforms see a source tensordict
@@ -1369,6 +1374,11 @@ def _reset(
13691374 tensordict_reset = t ._reset (tensordict , tensordict_reset )
13701375 return tensordict_reset
13711376
1377+ def _reset_env_preprocess (self , tensordict : TensorDictBase ) -> TensorDictBase :
1378+ for t in reversed (self .transforms ):
1379+ tensordict = t ._reset_env_preprocess (tensordict )
1380+ return tensordict
1381+
13721382 def init (self , tensordict : TensorDictBase ) -> None :
13731383 for t in self .transforms :
13741384 t .init (tensordict )
@@ -4725,6 +4735,7 @@ class UnaryTransform(Transform):
47254735 [torchrl][INFO] check_env_specs succeeded!
47264736
47274737 """
4738+ enable_inv_on_reset = True
47284739
47294740 def __init__ (
47304741 self ,
0 commit comments