Property inheritance for a custom environment when wrapped as TransformedEnv #2522
I'm not sure whether this is desired behaviour or a bug, so for now I wanted to ask the question.

I'm writing a custom environment, let's say MyEnv, and I want to use the EnvBase.rollout() method but possibly change a few options like "_simple_done_value" or "_step_mdp_value". MyEnv is then wrapped into a TransformedEnv(MyEnv) with a couple of transforms.

Overriding any method in EnvBase works fine until the environment is turned into a TransformedEnv: after that, overrides of EnvBase attributes or methods (for example, a custom rollout method or a new _step_mdp property) no longer take effect. Is this desired behaviour or a bug?

To Reproduce

Super simple example:

import torch
from torchrl.envs import TransformedEnv
from torchrl.envs.common import EnvBase

class MyEnv(EnvBase):
    def __init__(self):
        super().__init__()

    def _reset(self):
        print("Resetting")

    def _set_seed(self, seed: int):
        print(f"Seeding {seed}")

    def _step(self):
        print("Stepping")

    def rollout(self, max_steps):  # Override rollout here
        print(f"Rolling for {max_steps}")

env = MyEnv()
env.rollout(5)  # -> Rolling for 5
tenv = TransformedEnv(env)
tenv.rollout(5)  # -> TypeError: MyEnv._reset() takes 1 positional argument but 2 were given

The error in the second call arises because EnvBase.rollout() is the method that runs, i.e., MyEnv.rollout() doesn't override EnvBase.rollout().

Why does this happen?

When you instantiate a TransformedEnv, it inherits from EnvBase and not from MyEnv, so any overridden methods/attributes are lost:

super().__init__(device=None, allow_done_after_reset=None, **kwargs)

Discussion

This behaviour seems strange to me; is it meant to be like this? Even if I try to set _step_mdp directly on the TransformedEnv, I can't, as it has no setter. I would expect to neatly describe all aspects of the environment in a custom MyEnv class, override EnvBase attributes/methods if necessary (or even provide custom arguments to e.g. _step_mdp), and have a TransformedEnv inherit these.

If it is the desired behaviour, how do I set properties like "_step_mdp_value" without them being reset to None every step of the TransformedEnv?
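For illustration, the wrapping behaviour described above can be reproduced without torchrl. The following is a minimal sketch in plain Python; Base, Custom, Wrapper and DelegatingWrapper are hypothetical stand-ins (roughly for EnvBase, MyEnv and TransformedEnv), not the real classes, and the delegating variant is only one possible workaround, not something torchrl is known to provide.

```python
class Base:
    """Hypothetical stand-in for EnvBase (not the torchrl class)."""

    def rollout(self, max_steps):
        return f"Base.rollout({max_steps})"


class Custom(Base):
    """Hypothetical stand-in for MyEnv: overrides rollout."""

    def rollout(self, max_steps):
        return f"Custom.rollout({max_steps})"


class Wrapper(Base):
    """Hypothetical stand-in for TransformedEnv.

    It subclasses Base and only *holds* the wrapped env, so method
    lookup never reaches Custom's overrides.
    """

    def __init__(self, base_env):
        self.base_env = base_env


class DelegatingWrapper(Wrapper):
    """Assumed workaround: forward rollout to the wrapped env explicitly."""

    def rollout(self, max_steps):
        return self.base_env.rollout(max_steps)


env = Custom()
plain = Wrapper(env).rollout(5)               # resolved on Base, override ignored
delegated = DelegatingWrapper(env).rollout(5)  # explicitly forwarded to Custom
print(plain)
print(delegated)
```

Note that forwarding a call straight to the wrapped environment means the wrapper's own logic is skipped for that call, so in a real TransformedEnv this kind of delegation would presumably bypass the transforms.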
Replies: 1 comment 3 replies
Hey, thanks for posting this.
I see what the issue is, but it's going to be hard to fix consistently. Maybe the best option would be to provide some documentation on what can and can't be done, and how to do fancy stuff without breaking the class.
Things are indeed sort of "meant to be like this". You've probably realized by now that TransformedEnv is rather complex, because it needs to do a lot of things during step: call the inverse transform, execute the base_env._step, make some quality checks, execute the forward transforms, aggregate done states, etc. During reset, for some transforms we need to know what the input tensordict was (remember that the _reset signature takes a tensordict as input too), so …
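
The order of operations listed for step can be sketched roughly as follows. This is a toy illustration on plain dicts, not the actual TransformedEnv implementation: ToyEnv, DoubleAction and ToyTransformedEnv are hypothetical names, the quality check is reduced to a key check, and combining "terminated" and "truncated" into "done" is an assumption about what aggregating done states means.

```python
class ToyEnv:
    """Hypothetical base env working on plain dicts instead of tensordicts."""

    def _step(self, data):
        action = data["action"]
        # Echo the action as the observation; terminate once it exceeds 10.
        return {"observation": action, "terminated": action > 10, "truncated": False}


class DoubleAction:
    """Hypothetical transform: inv edits the input, forward edits the output."""

    def inv(self, data):
        return {**data, "action": data["action"] * 2}

    def forward(self, data):
        return {**data, "observation": data["observation"] + 1}


class ToyTransformedEnv:
    def __init__(self, base_env, transform):
        self.base_env = base_env
        self.transform = transform

    def step(self, data):
        data = self.transform.inv(data)      # 1. inverse transform on the input
        out = self.base_env._step(data)      # 2. the base env's _step
        for key in ("observation", "terminated", "truncated"):
            if key not in out:               # 3. a (very reduced) quality check
                raise KeyError(f"missing key {key!r} in step output")
        out = self.transform.forward(out)    # 4. forward transform on the output
        # 5. aggregate done states (assumed here: terminated OR truncated)
        out["done"] = out["terminated"] or out["truncated"]
        return out


tenv = ToyTransformedEnv(ToyEnv(), DoubleAction())
first = tenv.step({"action": 3})
second = tenv.step({"action": 6})
print(first)
print(second)
```

Because the wrapper drives every stage itself, an override on the base env is only picked up at the points where the wrapper explicitly calls into it (here, _step).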