
Commit 3dfc1f2

Author: puyuan
Commit message: polish(pu): rename step_index to timestep
1 parent a93c0ac, commit 3dfc1f2

7 files changed (+40 -43 lines)

lzero/mcts/buffer/game_buffer_muzero.py (+1 -1)

@@ -286,7 +286,7 @@ def _prepare_reward_value_context(
 game_segment_lens = []
 # for board games
 action_mask_segment, to_play_segment = [], []
-# step_index_segment = []
+# timestep_segment = []

 root_values = []

lzero/mcts/buffer/game_buffer_unizero.py (+14 -14)

@@ -68,7 +68,7 @@ def sample(
 batch_size, self._cfg.reanalyze_ratio
 )

-# current_batch = [obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, step_index_list]
+# current_batch = [obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list]

 # target reward, target value
 batch_rewards, batch_target_values = self._compute_target_reward_value(
@@ -118,7 +118,7 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
 game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time_list = orig_data
 batch_size = len(batch_index_list)
 obs_list, action_list, mask_list = [], [], []
-step_index_list = []
+timestep_list = []
 bootstrap_action_list = []

 # prepare the inputs of a batch
@@ -129,7 +129,7 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
 actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment +
 self._cfg.num_unroll_steps].tolist()

-step_index_tmp = game.step_index_segment[pos_in_game_segment:pos_in_game_segment +
+timestep_tmp = game.timestep_segment[pos_in_game_segment:pos_in_game_segment +
 self._cfg.num_unroll_steps].tolist()
 # add mask for invalid actions (out of trajectory), 1 for valid, 0 for invalid
 # mask_tmp = [1. for i in range(len(actions_tmp))]
@@ -146,9 +146,9 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
 for _ in range(self._cfg.num_unroll_steps - len(actions_tmp))
 ]
 # TODO
-step_index_tmp += [
+timestep_tmp += [
 0
-for _ in range(self._cfg.num_unroll_steps - len(step_index_tmp))
+for _ in range(self._cfg.num_unroll_steps - len(timestep_tmp))
 ]

 # obtain the current observations sequence
@@ -160,7 +160,7 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
 action_list.append(actions_tmp)

 mask_list.append(mask_tmp)
-step_index_list.append(step_index_tmp)
+timestep_list.append(timestep_tmp)

 # NOTE: for unizero
 bootstrap_action_tmp = game.action_segment[pos_in_game_segment+self._cfg.td_steps:pos_in_game_segment +
@@ -177,7 +177,7 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
 obs_list = prepare_observation(obs_list, self._cfg.model.model_type)

 # formalize the inputs of a batch
-current_batch = [obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, step_index_list]
+current_batch = [obs_list, action_list, bootstrap_action_list, mask_list, batch_index_list, weights_list, make_time_list, timestep_list]
 for i in range(len(current_batch)):
 current_batch[i] = np.asarray(current_batch[i])

@@ -345,15 +345,15 @@ def _prepare_policy_reanalyzed_context(
 rewards, child_visits, game_segment_lens = [], [], []
 # for board games
 action_mask_segment, to_play_segment = [], []
-step_index_segment = []
+timestep_segment = []
 for game_segment, state_index in zip(game_segment_list, pos_in_game_segment_list):
 game_segment_len = len(game_segment)
 game_segment_lens.append(game_segment_len)
 rewards.append(game_segment.reward_segment)
 # for board games
 action_mask_segment.append(game_segment.action_mask_segment)
 to_play_segment.append(game_segment.to_play_segment)
-step_index_segment.append(game_segment.step_index_segment)
+timestep_segment.append(game_segment.timestep_segment)

 child_visits.append(game_segment.child_visit_segment)
 # prepare the corresponding observations
@@ -372,7 +372,7 @@ def _prepare_policy_reanalyzed_context(

 policy_re_context = [
 policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, game_segment_lens,
-action_mask_segment, to_play_segment, step_index_segment
+action_mask_segment, to_play_segment, timestep_segment
 ]
 return policy_re_context

@@ -391,11 +391,11 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:

 # for board games
 policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, game_segment_lens, action_mask_segment, \
-to_play_segment, step_index_segment = policy_re_context # noqa
+to_play_segment, timestep_segment = policy_re_context # noqa
 transition_batch_size = len(policy_obs_list)
 game_segment_batch_size = len(pos_in_game_segment_list)

-# TODO: step_index_segment
+# TODO: timestep_segment
 to_play, action_mask = self._preprocess_to_play_and_action_mask(
 game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list
 )
@@ -505,7 +505,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:

 return batch_target_policies_re

-def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any, batch_action, step_index_batch) -> Tuple[
+def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any, batch_action, timestep_batch) -> Tuple[
 Any, Any]:
 """
 Overview:
@@ -531,7 +531,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
 # =============== NOTE: The key difference with MuZero =================
 # calculate the bootstrapped value and target value
 # NOTE: batch_obs(value_obs_list) is at t+td_steps, batch_action is at timestep t+td_steps
-m_output = model.initial_inference(batch_obs, batch_action, start_pos=step_index_batch) # TODO: step_index
+m_output = model.initial_inference(batch_obs, batch_action, start_pos=timestep_batch) # TODO: timestep
 # ======================================================================

 # if not in training, obtain the scalars of the value/reward
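
Note on the padding hunk above: the per-sample timestep list is padded with zeros out to `num_unroll_steps`, mirroring how the action list is padded. A minimal sketch of that pattern, with the variable names following the diff and the concrete values invented for illustration:

    # Sketch of the timestep padding used in `_make_batch`; values are illustrative.
    num_unroll_steps = 5
    timestep_tmp = [7, 8, 9]  # timesteps actually available in the sampled segment

    # pad out-of-trajectory positions with 0 so every sample has a fixed length
    timestep_tmp += [0 for _ in range(num_unroll_steps - len(timestep_tmp))]

    assert len(timestep_tmp) == num_unroll_steps
    print(timestep_tmp)  # [7, 8, 9, 0, 0]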

lzero/model/unizero_world_models/transformer.py (+1 -1)

@@ -137,7 +137,7 @@ def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues
 # If RoPE is used, slice freqs_cis accordingly
 if self.config.rotary_emb:
 # Fix: if start_pos is a scalar, expand it to a tensor of the same value with the current batch size
-# *2 because step_index only counts obs, while the sequence interleaves obs and act
+# *2 because timestep only counts obs, while the sequence interleaves obs and act
 if isinstance(start_pos, int) or isinstance(start_pos, float):
 start_pos_tensor = torch.full((sequences.shape[0],), int(start_pos), device=sequences.device) * 2
 else:
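
For context, the `* 2` exists because the world model's token sequence interleaves observation and action tokens, so an observation-only timestep t starts at token position 2 * t; a scalar `start_pos` is first broadcast to a per-sequence tensor. A rough, self-contained sketch of that conversion (not the repo's exact code; `expand_start_pos` is a hypothetical helper):

    import torch

    # Map an obs-only timestep to a token position in an interleaved (obs, act) sequence.
    def expand_start_pos(start_pos, batch_size, device="cpu"):
        if isinstance(start_pos, (int, float)):
            # scalar: broadcast to the batch, then scale by 2 for the obs/act interleaving
            return torch.full((batch_size,), int(start_pos), device=device) * 2
        # otherwise assume an array-like of per-sample timesteps
        return torch.as_tensor(start_pos, device=device).long() * 2

    print(expand_start_pos(3, batch_size=4))  # tensor([6, 6, 6, 6])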

lzero/model/unizero_world_models/world_model.py (+4 -4)

@@ -1128,7 +1128,7 @@ def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int,

 def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar_transform_handle=None,
 **kwargs: Any) -> LossWithIntermediateLosses:
-start_pos = batch['step_index']
+start_pos = batch['timestep']
 # Encode observations into latent state representations
 obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'])

@@ -1345,9 +1345,9 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
 first_step_losses[loss_name] = loss_tmp[:, 0][first_step_mask].mean()

 # Middle step loss
-middle_step_index = seq_len // 2
-middle_step_mask = mask_padding[:, middle_step_index]
-middle_step_losses[loss_name] = loss_tmp[:, middle_step_index][middle_step_mask].mean()
+middle_timestep = seq_len // 2
+middle_step_mask = mask_padding[:, middle_timestep]
+middle_step_losses[loss_name] = loss_tmp[:, middle_timestep][middle_step_mask].mean()

 # Last step loss
 last_step_mask = mask_padding[:, -1]
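
The surrounding logging code averages the per-position loss only over sequences that are still valid (not padded) at that position; the rename just makes `middle_timestep` consistent with the new naming. An illustrative sketch of that masking, assuming `loss_tmp` has shape (batch, seq_len) and `mask_padding` marks valid steps:

    import torch

    # Illustrative first/middle/last step loss bookkeeping with padding masks.
    batch, seq_len = 4, 10
    loss_tmp = torch.rand(batch, seq_len)
    mask_padding = torch.ones(batch, seq_len, dtype=torch.bool)
    mask_padding[2, 6:] = False  # one sequence is padded after step 6

    first_step_loss = loss_tmp[:, 0][mask_padding[:, 0]].mean()

    middle_timestep = seq_len // 2
    middle_step_mask = mask_padding[:, middle_timestep]
    middle_step_loss = loss_tmp[:, middle_timestep][middle_step_mask].mean()

    last_step_mask = mask_padding[:, -1]
    last_step_loss = loss_tmp[:, -1][last_step_mask].mean()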

lzero/policy/unizero.py (+13 -13)

@@ -353,7 +353,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
 self._target_model.train()

 current_batch, target_batch, _ = data
-obs_batch_ori, action_batch, target_action_batch, mask_batch, indices, weights, make_time, step_index_batch = current_batch
+obs_batch_ori, action_batch, target_action_batch, mask_batch, indices, weights, make_time, timestep_batch = current_batch
 target_reward, target_value, target_policy = target_batch

 # Prepare observations based on frame stack number
@@ -371,7 +371,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
 # Prepare action batch and convert to torch tensor
 action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(
 -1).long() # For discrete action space
-step_index_batch = torch.from_numpy(step_index_batch).to(self._cfg.device).unsqueeze(
+timestep_batch = torch.from_numpy(timestep_batch).to(self._cfg.device).unsqueeze(
 -1).long() # TODO: only for discrete action space
 data_list = [mask_batch, target_reward, target_value, target_policy, weights]
 mask_batch, target_reward, target_value, target_policy, weights = to_torch_float_tensor(data_list,
@@ -397,7 +397,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
 self._cfg.batch_size, -1, *self._cfg.model.observation_shape)

 batch_for_gpt['actions'] = action_batch.squeeze(-1)
-batch_for_gpt['step_index'] = step_index_batch.squeeze(-1)
+batch_for_gpt['timestep'] = timestep_batch.squeeze(-1)


 batch_for_gpt['rewards'] = target_reward_categorical[:, :-1]
@@ -569,7 +569,7 @@ def _forward_collect(
 to_play: List = [-1],
 epsilon: float = 0.25,
 ready_env_id: np.ndarray = None,
-step_index: List = [0]
+timestep: List = [0]
 ) -> Dict:
 """
 Overview:
@@ -581,7 +581,7 @@ def _forward_collect(
 - temperature (:obj:`float`): The temperature of the policy.
 - to_play (:obj:`int`): The player to play.
 - ready_env_id (:obj:`list`): The id of the env that is ready to collect.
-- step_index (:obj:`list`): The step index of the env in one episode
+- timestep (:obj:`list`): The step index of the env in one episode
 Shape:
 - data (:obj:`torch.Tensor`):
 - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \
@@ -591,7 +591,7 @@ def _forward_collect(
 - temperature: :math:`(1, )`.
 - to_play: :math:`(N, 1)`, where N is the number of collect_env.
 - ready_env_id: None
-- step_index: :math:`(N, 1)`, where N is the number of collect_env.
+- timestep: :math:`(N, 1)`, where N is the number of collect_env.
 Returns:
 - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \
 ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``.
@@ -606,7 +606,7 @@ def _forward_collect(
 output = {i: None for i in ready_env_id}

 with torch.no_grad():
-network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, step_index)
+network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, timestep)
 latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output)

 pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy()
@@ -627,7 +627,7 @@ def _forward_collect(
 roots = MCTSPtree.roots(active_collect_env_num, legal_actions)

 roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play)
-self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, step_index)
+self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, timestep)

 # list of list, shape: ``{list: batch_size} -> {list: action_space_size}``
 roots_visit_count_distributions = roots.get_distributions()
@@ -669,7 +669,7 @@ def _forward_collect(
 'searched_value': value,
 'predicted_value': pred_values[i],
 'predicted_policy_logits': policy_logits[i],
-'step_index': step_index[i]
+'timestep': timestep[i]
 }
 batch_action.append(action)

@@ -706,7 +706,7 @@ def _init_eval(self) -> None:
 self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)]

 def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1,
-ready_env_id: np.array = None, step_index: int = 0) -> Dict:
+ready_env_id: np.array = None, timestep: int = 0) -> Dict:
 """
 Overview:
 The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search.
@@ -734,7 +734,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1
 ready_env_id = np.arange(active_eval_env_num)
 output = {i: None for i in ready_env_id}
 with torch.no_grad():
-network_output = self._eval_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, step_index)
+network_output = self._eval_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, timestep)
 latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output)

 # if not in training, obtain the scalars of the value/reward
@@ -750,7 +750,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1
 # python mcts_tree
 roots = MCTSPtree.roots(active_eval_env_num, legal_actions)
 roots.prepare_no_noise(reward_roots, policy_logits, to_play)
-self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play, step_index)
+self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play, timestep)

 # list of list, shape: ``{list: batch_size} -> {list: action_space_size}``
 roots_visit_count_distributions = roots.get_distributions()
@@ -780,7 +780,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1
 'searched_value': value,
 'predicted_value': pred_values[i],
 'predicted_policy_logits': policy_logits[i],
-'step_index': step_index[i]
+'timestep': timestep[i]
 }
 batch_action.append(action)

zoo/atari/envs/atari_lightzero_env.py (+2 -5)

@@ -99,7 +99,7 @@ def __init__(self, cfg: EasyDict) -> None:
 self.channel_last = cfg.channel_last
 self.clip_rewards = cfg.clip_rewards
 self.episode_life = cfg.episode_life
-self.step_index = 0
+self.timestep = 0

 def reset(self) -> dict:
 """
@@ -134,9 +134,7 @@ def reset(self) -> dict:
 self.obs = to_ndarray(obs)
 self._eval_episode_return = 0.
 self.timestep = 0
-
 obs = self.observe()
-self.step_index = 0
 return obs

 def step(self, action: int) -> BaseEnvTimestep:
@@ -155,7 +153,6 @@ def step(self, action: int) -> BaseEnvTimestep:
 self.timestep += 1
 # print(f'self.timestep: {self.timestep}')
 observation = self.observe()
-self.step_index += 1
 if done:
 info['eval_episode_return'] = self._eval_episode_return
 return BaseEnvTimestep(observation, self.reward, done, info)
@@ -175,7 +172,7 @@ def observe(self) -> dict:
 observation = np.transpose(observation, (2, 0, 1))

 action_mask = np.ones(self._action_space.n, 'int8')
-return {'observation': observation, 'action_mask': action_mask, 'to_play': -1, 'step_index': self.step_index}
+return {'observation': observation, 'action_mask': action_mask, 'to_play': -1, 'timestep': self.timestep}

 @property
 def legal_actions(self):
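
The env changes collapse the redundant `self.step_index` counter into the existing `self.timestep`, which is now the value exposed under the `timestep` key of the observation dict. A hypothetical minimal env mirroring that pattern (TinyEnv is not part of the repo):

    # A single self.timestep counter feeds the 'timestep' key of the observation dict.
    class TinyEnv:
        def __init__(self):
            self.timestep = 0

        def reset(self):
            self.timestep = 0
            return self.observe()

        def step(self, action):
            self.timestep += 1
            return self.observe()

        def observe(self):
            return {"observation": None, "action_mask": None, "to_play": -1, "timestep": self.timestep}

    env = TinyEnv()
    assert env.reset()["timestep"] == 0
    assert env.step(action=0)["timestep"] == 1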

zoo/classic_control/cartpole/envs/cartpole_lightzero_env.py (+5 -5)

@@ -68,7 +68,7 @@ def __init__(self, cfg: dict = {}) -> None:
 )
 self._action_space = gym.spaces.Discrete(2)
 self._reward_space = gym.spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32)
-self.step_index = 0
+self.timestep = 0


 def reset(self) -> Dict[str, np.ndarray]:
@@ -93,9 +93,9 @@ def reset(self) -> Dict[str, np.ndarray]:

 # Initialize the action mask and return the observation.
 action_mask = np.ones(self.action_space.n, 'int8')
-self.step_index = 0
+self.timestep = 0

-obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1, 'step_index': self.step_index}
+obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1, 'timestep': self.timestep}

 # this is to artificially introduce randomness in order to evaluate the performance of
 # stochastic_muzero on state input
@@ -143,9 +143,9 @@ def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep:
 self.save_gif_replay()

 action_mask = np.ones(self.action_space.n, 'int8')
-self.step_index += 1
+self.timestep += 1

-obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1, 'step_index': self.step_index}
+obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1, 'timestep': self.timestep}

 # this is to artificially introduce randomness in order to evaluate the performance of
 # stochastic_muzero on state input
