@@ -353,7 +353,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
self._target_model.train()

current_batch, target_batch, _ = data
- obs_batch_ori, action_batch, target_action_batch, mask_batch, indices, weights, make_time, step_index_batch = current_batch
+ obs_batch_ori, action_batch, target_action_batch, mask_batch, indices, weights, make_time, timestep_batch = current_batch
target_reward, target_value, target_policy = target_batch

# Prepare observations based on frame stack number
@@ -371,7 +371,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
# Prepare action batch and convert to torch tensor
action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(
-1).long()  # For discrete action space
- step_index_batch = torch.from_numpy(step_index_batch).to(self._cfg.device).unsqueeze(
+ timestep_batch = torch.from_numpy(timestep_batch).to(self._cfg.device).unsqueeze(
-1).long()  # TODO: only for discrete action space
data_list = [mask_batch, target_reward, target_value, target_policy, weights]
mask_batch, target_reward, target_value, target_policy, weights = to_torch_float_tensor(data_list,
@@ -397,7 +397,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
self._cfg.batch_size, -1, *self._cfg.model.observation_shape)

batch_for_gpt['actions'] = action_batch.squeeze(-1)
- batch_for_gpt['step_index'] = step_index_batch.squeeze(-1)
+ batch_for_gpt['timestep'] = timestep_batch.squeeze(-1)


batch_for_gpt['rewards'] = target_reward_categorical[:, :-1]
@@ -569,7 +569,7 @@ def _forward_collect(
to_play: List = [-1],
epsilon: float = 0.25,
ready_env_id: np.ndarray = None,
- step_index: List = [0]
+ timestep: List = [0]
) -> Dict:
"""
Overview:
@@ -581,7 +581,7 @@ def _forward_collect(
- temperature (:obj:`float`): The temperature of the policy.
- to_play (:obj:`int`): The player to play.
- ready_env_id (:obj:`list`): The id of the env that is ready to collect.
- - step_index (:obj:`list`): The step index of the env in one episode
+ - timestep (:obj:`list`): The step index of the env in one episode
Shape:
- data (:obj:`torch.Tensor`):
- For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \
@@ -591,7 +591,7 @@ def _forward_collect(
- temperature: :math:`(1, )`.
- to_play: :math:`(N, 1)`, where N is the number of collect_env.
- ready_env_id: None
- - step_index: :math:`(N, 1)`, where N is the number of collect_env.
+ - timestep: :math:`(N, 1)`, where N is the number of collect_env.
Returns:
- output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \
``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``.
@@ -606,7 +606,7 @@ def _forward_collect(
output = {i: None for i in ready_env_id}

with torch.no_grad():
- network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, step_index)
+ network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, timestep)
latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output)

pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy()
@@ -627,7 +627,7 @@ def _forward_collect(
roots = MCTSPtree.roots(active_collect_env_num, legal_actions)

roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play)
- self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, step_index)
+ self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, timestep)

# list of list, shape: ``{list: batch_size} -> {list: action_space_size}``
roots_visit_count_distributions = roots.get_distributions()
@@ -669,7 +669,7 @@ def _forward_collect(
'searched_value': value,
'predicted_value': pred_values[i],
'predicted_policy_logits': policy_logits[i],
- 'step_index': step_index[i]
+ 'timestep': timestep[i]
}
batch_action.append(action)

@@ -706,7 +706,7 @@ def _init_eval(self) -> None:
self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)]

def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1,
- ready_env_id: np.array = None, step_index: int = 0) -> Dict:
+ ready_env_id: np.array = None, timestep: int = 0) -> Dict:
"""
Overview:
The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search.
@@ -734,7 +734,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1
ready_env_id = np.arange(active_eval_env_num)
output = {i: None for i in ready_env_id}
with torch.no_grad():
- network_output = self._eval_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, step_index)
+ network_output = self._eval_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, timestep)
latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output)

# if not in training, obtain the scalars of the value/reward
@@ -750,7 +750,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1
# python mcts_tree
roots = MCTSPtree.roots(active_eval_env_num, legal_actions)
roots.prepare_no_noise(reward_roots, policy_logits, to_play)
- self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play, step_index)
+ self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play, timestep)

# list of list, shape: ``{list: batch_size} -> {list: action_space_size}``
roots_visit_count_distributions = roots.get_distributions()
@@ -780,7 +780,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1
'searched_value': value,
'predicted_value': pred_values[i],
'predicted_policy_logits': policy_logits[i],
- 'step_index': step_index[i]
+ 'timestep': timestep[i]
}
batch_action.append(action)

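For reference, a minimal standalone sketch of the tensor handling that the renamed `timestep_batch` goes through in `_forward_learn` (this is not the policy's code; it only uses torch and numpy, and the batch values below are made up for illustration):

import numpy as np
import torch

# Hypothetical stand-ins for values that normally come from the replay buffer.
batch_size, num_unroll_steps = 4, 5
timestep_batch = np.arange(batch_size * num_unroll_steps).reshape(batch_size, num_unroll_steps)
device = 'cpu'  # would be self._cfg.device in the policy

# Mirror of the diffed conversion: numpy -> long tensor on the target device with a
# trailing singleton dim, which is then squeezed away when building batch_for_gpt.
timestep_tensor = torch.from_numpy(timestep_batch).to(device).unsqueeze(-1).long()
batch_for_gpt = {'timestep': timestep_tensor.squeeze(-1)}
print(batch_for_gpt['timestep'].shape)  # torch.Size([4, 5])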