diff --git a/scope_rl/dataset/synthetic.py b/scope_rl/dataset/synthetic.py
index d4199f2..a0f5f19 100644
--- a/scope_rl/dataset/synthetic.py
+++ b/scope_rl/dataset/synthetic.py
@@ -691,10 +691,10 @@ def _obtain_steps(
             actions = np.zeros(n_trajectories * step_per_trajectory, dtype=int)
             action_probs = np.zeros(n_trajectories * step_per_trajectory, dtype=int)
         else:
-            actions = np.zeros(n_trajectories * step_per_trajectory, self.action_dim)
-            action_probs = np.zeros(
+            actions = np.zeros((n_trajectories * step_per_trajectory, self.action_dim))
+            action_probs = np.zeros((
                 n_trajectories * step_per_trajectory, self.action_dim
-            )
+            ))

         rewards = np.zeros(n_trajectories * step_per_trajectory)
         dones = np.zeros(n_trajectories * step_per_trajectory)
@@ -709,7 +709,7 @@ def _obtain_steps(
         idx, step = 0, 0
         done = False
         state, info_ = self.env.reset()
-
+        next_state = None
         for i in tqdm(
             np.arange(n_trajectories),
             desc="[obtain_trajectories]",
@@ -720,7 +720,7 @@ def _obtain_steps(
             if not obtain_trajectories_from_single_interaction:
                 done = True

-            for rollout_step in rollout_lengths[i]:
+            for rollout_step in range(rollout_lengths[i]):
                 if done:
                     state, info_ = self.env.reset()
                     step = 0
@@ -750,7 +750,7 @@ def _obtain_steps(
                 (
                     action,
                     action_prob,
-                ) = self.behavior_policy.sample_action_and_output_pscore_online(state)
+                ) = behavior_policy.sample_action_and_output_pscore_online(state)
                 next_state, reward, done, truncated, info_ = self.env.step(action)

                 states[idx] = state
@@ -1268,7 +1268,6 @@ def obtain_steps(
                 path=path,
                 save_relative_path=save_relative_path,
             )
-
         for j in tqdm(
             np.arange(len(behavior_policies)),
             desc="[obtain_datasets: behavior_policy]",