@@ -865,6 +865,8 @@ The inverse process is executed with the output tensordict, where the `in_keys`
865865
866866 Rename transform logic
867867
868+ .. note:: During a call to `inv`, the transforms are executed in reverse order (compared to the forward / step mode).
869+
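The ordering described in the note can be illustrated with a plain-Python sketch. The `Scale`, `Shift` and `ComposeSketch` classes below are hypothetical stand-ins, not TorchRL classes; they only mimic a composition that applies its transforms in the listed order on the forward pass and in reverse order on the inverse pass.

```python
class Scale:
    """Hypothetical transform: multiplies by 2 forward, divides by 2 on inverse."""

    def forward(self, x):
        return x * 2

    def inv(self, x):
        return x / 2


class Shift:
    """Hypothetical transform: adds 1 forward, subtracts 1 on inverse."""

    def forward(self, x):
        return x + 1

    def inv(self, x):
        return x - 1


class ComposeSketch:
    """Hypothetical composition (not torchrl's Compose), used only to show ordering."""

    def __init__(self, *transforms):
        self.transforms = list(transforms)

    def forward(self, x):
        # Forward / step mode: first transform first.
        for t in self.transforms:
            x = t.forward(x)
        return x

    def inv(self, x):
        # Inverse mode: last transform first.
        for t in reversed(self.transforms):
            x = t.inv(x)
        return x


composed = ComposeSketch(Scale(), Shift())
y = composed.forward(3.0)  # (3.0 * 2) + 1
x = composed.inv(y)        # (y - 1) / 2
```

Running the inverses in reverse order is what lets `inv` recover the original value here; applying them in the forward order would not.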
868870Transforming Tensors and Specs
869871^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
870872
@@ -900,6 +902,74 @@ tensor that should not be generated when using :meth:`~torchrl.envs.EnvBase.rand
900902environment. Instead, `"action_discrete"` should be generated, and its continuous counterpart obtained from the
901903transform. Therefore, the user should see the `"action_discrete"` entry being exposed, but not `"action"`.
902904
905+ Designing your own Transform
906+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
907+
908+ To create a basic custom transform, subclass the `Transform` class and implement the
909+ :meth:`~torchrl.envs.Transform._apply_transform` method. Here's an example of a simple transform that adds 1 to the observation
910+ tensor:
911+
912+ >>> class AddOneToObs(Transform):
913+ ...     """A transform that adds 1 to the observation tensor."""
914+ ...
915+ ...     def __init__(self):
916+ ...         super().__init__(in_keys=["observation"], out_keys=["observation"])
917+ ...
918+ ...     def _apply_transform(self, obs: torch.Tensor) -> torch.Tensor:
919+ ...         return obs + 1
920+
921+
922+ Tips for subclassing `Transform`
923+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
924+
925+ There are several ways to subclass a transform. The things to take into consideration are:
926+
927+ - Is the transform identical for each tensor / item being transformed? Use
928+   :meth:`~torchrl.envs.Transform._apply_transform` and :meth:`~torchrl.envs.Transform._inv_apply_transform`.
929+ - Does the transform need access to the input data to env.step as well as the output? Rewrite
930+   :meth:`~torchrl.envs.Transform._step`.
931+   Otherwise, rewrite :meth:`~torchrl.envs.Transform._call` (or :meth:`~torchrl.envs.Transform._inv_call`).
932+ - Is the transform to be used within a replay buffer? Override :meth:`~torchrl.envs.Transform.forward`,
933+   :meth:`~torchrl.envs.Transform.inv`, :meth:`~torchrl.envs.Transform._apply_transform` or
934+   :meth:`~torchrl.envs.Transform._inv_apply_transform`.
935+ - Within a transform, you can access (and make calls to) the parent environment using
936+   :attr:`~torchrl.envs.Transform.parent` (the base env + all transforms up to this one) or
937+   :attr:`~torchrl.envs.Transform.container` (the object that encapsulates the transform).
938+ - Don't forget to edit the specs if needed. Top level: :meth:`~torchrl.envs.Transform.transform_output_spec`,
939+   :meth:`~torchrl.envs.Transform.transform_input_spec`.
940+   Leaf level: :meth:`~torchrl.envs.Transform.transform_observation_spec`,
941+   :meth:`~torchrl.envs.Transform.transform_action_spec`, :meth:`~torchrl.envs.Transform.transform_state_spec`,
942+   :meth:`~torchrl.envs.Transform.transform_reward_spec` and
943+   :meth:`~torchrl.envs.Transform.transform_done_spec`.
944+
945+ For practical examples, see the methods listed above.
946+
947+ You can use a transform in an environment by passing it to the `TransformedEnv` constructor:
948+
949+ >>> env = TransformedEnv(GymEnv("Pendulum-v1"), AddOneToObs())
950+
951+ You can compose multiple transforms together using the `Compose` class:
952+
953+ >>> transform = Compose(AddOneToObs(), RewardSum())
954+ >>> env = TransformedEnv(GymEnv("Pendulum-v1"), transform)
955+
956+ Inverse Transforms
957+ ^^^^^^^^^^^^^^^^^^
958+
959+ Some transforms define an inverse pass, which is executed during a call to `inv`. For example, the `AddOneToAction`
960+ transform below adds 1 to the action tensor in its inverse pass:
961+
962+ >>> class AddOneToAction(Transform):
963+ ...     """A transform that adds 1 to the action tensor."""
964+ ...     def __init__(self):
965+ ...         super().__init__(in_keys=[], out_keys=[], in_keys_inv=["action"], out_keys_inv=["action"])
966+ ...     def _inv_apply_transform(self, action: torch.Tensor) -> torch.Tensor:
967+ ...         return action + 1
968+
969+ Using a Transform with a Replay Buffer
970+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
971+
972+ You can use a transform with a replay buffer by passing it to the ReplayBuffer constructor:
903973
904974Cloning transforms
905975~~~~~~~~~~~~~~~~~~
@@ -1000,6 +1070,7 @@ to be able to create this other composition:
10001070 TargetReturn
10011071 TensorDictPrimer
10021072 TimeMaxPool
1073+ Timer
10031074 Tokenizer
10041075 ToTensorImage
10051076 TrajCounter