@@ -360,7 @@ -360,7 +360,7 @@ def collect(self,
         action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)}
         to_play_dict = {i: to_ndarray(init_obs[i]['to_play']) for i in range(env_nums)}
-        step_index_dict = {i: to_ndarray(init_obs[i]['step_index']) for i in range(env_nums)}
+        timestep_dict = {i: to_ndarray(init_obs[i]['timestep']) for i in range(env_nums)}
         if self.policy_config.use_ture_chance_label_in_chance_encoder:
             chance_dict = {i: to_ndarray(init_obs[i]['chance']) for i in range(env_nums)}
@@ -421,11 +421,11 @@ def collect(self,
         action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id}
         to_play_dict = {env_id: to_play_dict[env_id] for env_id in ready_env_id}
-        step_index_dict = {env_id: step_index_dict[env_id] for env_id in ready_env_id}
+        timestep_dict = {env_id: timestep_dict[env_id] for env_id in ready_env_id}

         action_mask = [action_mask_dict[env_id] for env_id in ready_env_id]
         to_play = [to_play_dict[env_id] for env_id in ready_env_id]
-        step_index = [step_index_dict[env_id] for env_id in ready_env_id]
+        timestep = [timestep_dict[env_id] for env_id in ready_env_id]

         if self.policy_config.use_ture_chance_label_in_chance_encoder:
             chance_dict = {env_id: chance_dict[env_id] for env_id in ready_env_id}
@@ -439,13 +439,13 @@ def collect(self,
         # Key policy forward step
         # ==============================================================
         # print(f'ready_env_id:{ready_env_id}')
-        policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id, step_index=step_index)
+        policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id, timestep=timestep)

         # Extract relevant policy outputs
         actions_with_env_id = {k: v['action'] for k, v in policy_output.items()}
         value_dict_with_env_id = {k: v['searched_value'] for k, v in policy_output.items()}
         pred_value_dict_with_env_id = {k: v['predicted_value'] for k, v in policy_output.items()}
-        step_index_dict_with_env_id = {k: v['step_index'] for k, v in policy_output.items()}
+        timestep_dict_with_env_id = {k: v['timestep'] for k, v in policy_output.items()}

         if self.policy_config.sampled_algo:
             root_sampled_actions_dict_with_env_id = {
@@ -467,7 +467,7 @@ def collect(self,
         actions = {}
         value_dict = {}
         pred_value_dict = {}
-        step_index_dict = {}
+        timestep_dict = {}

         if not collect_with_pure_policy:
             distributions_dict = {}
@@ -485,7 +485,7 @@ def collect(self,
         actions[env_id] = actions_with_env_id.pop(env_id)
         value_dict[env_id] = value_dict_with_env_id.pop(env_id)
         pred_value_dict[env_id] = pred_value_dict_with_env_id.pop(env_id)
-        step_index_dict[env_id] = step_index_dict_with_env_id.pop(env_id)
+        timestep_dict[env_id] = timestep_dict_with_env_id.pop(env_id)

         if not collect_with_pure_policy:
             distributions_dict[env_id] = distributions_dict_with_env_id.pop(env_id)
@@ -536,19 +536,19 @@ def collect(self,
         if self.policy_config.use_ture_chance_label_in_chance_encoder:
             game_segments[env_id].append(
                 actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id],
-                to_play_dict[env_id], chance_dict[env_id], step_index_dict[env_id]
+                to_play_dict[env_id], chance_dict[env_id], timestep_dict[env_id]
             )
         else:
             game_segments[env_id].append(
                 actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id],
-                to_play_dict[env_id], step_index_dict[env_id]
+                to_play_dict[env_id], timestep_dict[env_id]
             )

         # NOTE: the position of code snippet is very important.
         # the obs['action_mask'] and obs['to_play'] are corresponding to the next action
         action_mask_dict[env_id] = to_ndarray(obs['action_mask'])
         to_play_dict[env_id] = to_ndarray(obs['to_play'])
-        step_index_dict[env_id] = to_ndarray(obs['step_index'])
+        timestep_dict[env_id] = to_ndarray(obs['timestep'])
         if self.policy_config.use_ture_chance_label_in_chance_encoder:
             chance_dict[env_id] = to_ndarray(obs['chance'])
@@ -679,7 +679,7 @@ def collect(self,
         action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask'])
         to_play_dict[env_id] = to_ndarray(init_obs[env_id]['to_play'])
-        step_index_dict[env_id] = to_ndarray(init_obs[env_id]['step_index'])
+        timestep_dict[env_id] = to_ndarray(init_obs[env_id]['timestep'])
         if self.policy_config.use_ture_chance_label_in_chance_encoder:
             chance_dict[env_id] = to_ndarray(init_obs[env_id]['chance'])
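
For context on how the renamed field travels through collect(), the following minimal sketch (not part of this patch, and not the LightZero implementation) mirrors the pattern the diff touches: read `timestep` from each environment's observation, restrict the dict to the ready environments, pass the list to the policy's forward call, and read the per-env `timestep` back out of the policy output. DummyPolicy, fake_init_obs, and the local to_ndarray stand-in are hypothetical placeholders used only for illustration.

import numpy as np

def to_ndarray(x):
    # Stand-in for the collector's to_ndarray helper.
    return np.asarray(x)

env_nums = 2
# Hypothetical initial observations, shaped like the dicts the collector consumes.
fake_init_obs = {
    i: {'observation': np.zeros(4), 'action_mask': np.ones(3), 'to_play': -1, 'timestep': 0}
    for i in range(env_nums)
}

# Per-environment dicts, following the same pattern as the patched collector.
action_mask_dict = {i: to_ndarray(fake_init_obs[i]['action_mask']) for i in range(env_nums)}
to_play_dict = {i: to_ndarray(fake_init_obs[i]['to_play']) for i in range(env_nums)}
timestep_dict = {i: to_ndarray(fake_init_obs[i]['timestep']) for i in range(env_nums)}

ready_env_id = [0, 1]
timestep = [timestep_dict[env_id] for env_id in ready_env_id]

class DummyPolicy:
    def forward(self, stack_obs, action_mask, temperature, to_play, epsilon,
                ready_env_id=None, timestep=None):
        # A real MCTS policy would search here; this stub only echoes the
        # per-env timestep back, mirroring policy_output[env_id]['timestep'].
        return {env_id: {'action': 0, 'timestep': t}
                for env_id, t in zip(ready_env_id, timestep)}

policy_output = DummyPolicy().forward(
    stack_obs=None,
    action_mask=[action_mask_dict[i] for i in ready_env_id],
    temperature=0.25,
    to_play=[to_play_dict[i] for i in ready_env_id],
    epsilon=0.0,
    ready_env_id=ready_env_id,
    timestep=timestep,
)

# Same extraction step as the patched hunk at line 448.
timestep_dict_with_env_id = {k: v['timestep'] for k, v in policy_output.items()}
print(timestep_dict_with_env_id)  # {0: array(0), 1: array(0)}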