@@ -768,8 +768,7 @@ def __init__(
768768 self .set_truncated = set_truncated
769769
770770 self ._make_shuttle ()
771- if self ._use_buffers :
772- self ._make_final_rollout ()
771+ self ._maybe_make_final_rollout (make_rollout = self ._use_buffers )
773772 self ._set_truncated_keys ()
774773
775774 if split_trajs is None :
@@ -806,28 +805,30 @@ def _make_shuttle(self):
806805 traj_ids ,
807806 )
808807
809- def _make_final_rollout (self ):
810- with torch .no_grad ():
811- self ._final_rollout = self .env .fake_tensordict ()
812-
813- # If storing device is not None, we use this to cast the storage.
814- # If it is None and the env and policy are on the same device,
815- # the storing device is already the same as those, so we don't need
816- # to consider this use case.
817- # In all other cases, we can't really put a device on the storage,
818- # since at least one data source has a device that is not clear.
819- if self .storing_device :
820- self ._final_rollout = self ._final_rollout .to (
821- self .storing_device , non_blocking = True
822- )
823- else :
824- # erase all devices
825- self ._final_rollout .clear_device_ ()
808+ def _maybe_make_final_rollout (self , make_rollout : bool ):
809+ if make_rollout :
810+ with torch .no_grad ():
811+ self ._final_rollout = self .env .fake_tensordict ()
812+
813+ # If storing device is not None, we use this to cast the storage.
814+ # If it is None and the env and policy are on the same device,
815+ # the storing device is already the same as those, so we don't need
816+ # to consider this use case.
817+ # In all other cases, we can't really put a device on the storage,
818+ # since at least one data source has a device that is not clear.
819+ if self .storing_device :
820+ self ._final_rollout = self ._final_rollout .to (
821+ self .storing_device , non_blocking = True
822+ )
823+ else :
824+ # erase all devices
825+ self ._final_rollout .clear_device_ ()
826826
827827 # If the policy has a valid spec, we use it
828828 self ._policy_output_keys = set ()
829829 if (
830- hasattr (self .policy , "spec" )
830+ make_rollout
831+ and hasattr (self .policy , "spec" )
831832 and self .policy .spec is not None
832833 and all (v is not None for v in self .policy .spec .values (True , True ))
833834 ):
@@ -846,14 +847,20 @@ def _make_final_rollout(self):
846847 if key in self ._final_rollout .keys (True ):
847848 continue
848849 self ._final_rollout .set (key , spec .zero ())
849-
850+ elif (
851+ not make_rollout
852+ and hasattr (self .policy , "out_keys" )
853+ and self .policy .out_keys
854+ ):
855+ self ._policy_output_keys = list (self .policy .out_keys )
850856 else :
851- # otherwise, we perform a small number of steps with the policy to
852- # determine the relevant keys with which to pre-populate _final_rollout.
853- # This is the safest thing to do if the spec has None fields or if there is
854- # no spec at all.
855- # See #505 for additional context.
856- self ._final_rollout .update (self ._shuttle .copy ())
857+ if make_rollout :
858+ # otherwise, we perform a small number of steps with the policy to
859+ # determine the relevant keys with which to pre-populate _final_rollout.
860+ # This is the safest thing to do if the spec has None fields or if there is
861+ # no spec at all.
862+ # See #505 for additional context.
863+ self ._final_rollout .update (self ._shuttle .copy ())
857864 with torch .no_grad ():
858865 policy_input = self ._shuttle .copy ()
859866 if self .policy_device :
@@ -911,33 +918,35 @@ def filter_policy(name, value_output, value_input, value_input_clone):
911918 set (filtered_policy_output .keys (True , True ))
912919 )
913920 )
914- self ._final_rollout .update (
915- policy_output .select (* self ._policy_output_keys )
916- )
921+ if make_rollout :
922+ self ._final_rollout .update (
923+ policy_output .select (* self ._policy_output_keys )
924+ )
917925 del filtered_policy_output , policy_output , policy_input
918926
919927 _env_output_keys = []
920928 for spec in ["full_observation_spec" , "full_done_spec" , "full_reward_spec" ]:
921929 _env_output_keys += list (self .env .output_spec [spec ].keys (True , True ))
922930 self ._env_output_keys = _env_output_keys
923- self ._final_rollout = (
924- self ._final_rollout .unsqueeze (- 1 )
925- .expand (* self .env .batch_size , self .frames_per_batch )
926- .clone ()
927- .zero_ ()
928- )
931+ if make_rollout :
932+ self ._final_rollout = (
933+ self ._final_rollout .unsqueeze (- 1 )
934+ .expand (* self .env .batch_size , self .frames_per_batch )
935+ .clone ()
936+ .zero_ ()
937+ )
929938
930- # in addition to outputs of the policy, we add traj_ids to
931- # _final_rollout which will be collected during rollout
932- self ._final_rollout .set (
933- ("collector" , "traj_ids" ),
934- torch .zeros (
935- * self ._final_rollout .batch_size ,
936- dtype = torch .int64 ,
937- device = self .storing_device ,
938- ),
939- )
940- self ._final_rollout .refine_names (..., "time" )
939+ # in addition to outputs of the policy, we add traj_ids to
940+ # _final_rollout which will be collected during rollout
941+ self ._final_rollout .set (
942+ ("collector" , "traj_ids" ),
943+ torch .zeros (
944+ * self ._final_rollout .batch_size ,
945+ dtype = torch .int64 ,
946+ device = self .storing_device ,
947+ ),
948+ )
949+ self ._final_rollout .refine_names (..., "time" )
941950
942951 def _set_truncated_keys (self ):
943952 self ._truncated_keys = []
0 commit comments