diff --git a/slime/backends/megatron_utils/data.py b/slime/backends/megatron_utils/data.py
index a6f83c06f..e9e55a073 100644
--- a/slime/backends/megatron_utils/data.py
+++ b/slime/backends/megatron_utils/data.py
@@ -308,14 +308,14 @@ def _generate_data_iterator(rollout_data, micro_batch_size, micro_batch_indices=
         data_iterator = _generate_data_iterator(rollout_data, args.micro_batch_size)
     else:
         assert args.max_tokens_per_gpu is not None
-        # calculate the number of mirobatches for each step
-        samples = rollout_data["total_lengths"]
-        assert len(samples) == num_local_samples
+        # calculate the number of microbatches for each step
+        seq_lens = rollout_data["total_lengths"]
+        assert len(seq_lens) == num_local_samples
         num_microbatches = []
         for i in range(num_steps_per_rollout):
             start, end = i * num_local_gbs, (i + 1) * num_local_gbs
             num_microbatches.append(
-                get_minimum_num_micro_batch_size(samples[start:end], args.max_tokens_per_gpu * cp_size)
+                get_minimum_num_micro_batch_size(seq_lens[start:end], args.max_tokens_per_gpu * cp_size)
             )
 
         num_microbatches = torch.tensor(num_microbatches, dtype=torch.int, device=torch.cuda.current_device())
@@ -330,14 +330,12 @@ def _generate_data_iterator(rollout_data, micro_batch_size, micro_batch_indices=
         num_microbatches = num_microbatches.tolist()
 
-        # balance the each micro batch
-        samples = rollout_data["total_lengths"]
 
-        # balance the number of mirobatches across steps
+        # balance the number of microbatches across steps
         micro_batch_indices = []
         for i, num_mbs in enumerate(num_microbatches):
             start, end = i * num_local_gbs, (i + 1) * num_local_gbs
-            samples = rollout_data["total_lengths"][start:end]
-            partitions = get_seqlen_balanced_partitions(samples, num_mbs, equal_size=False)
+            seq_lens = rollout_data["total_lengths"][start:end]
+            partitions = get_seqlen_balanced_partitions(seq_lens, num_mbs, equal_size=False)
             for j in range(num_mbs):
                 for k in range(len(partitions[j])):
                     partitions[j][k] += start
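
For context on what the renamed variables feed into: `get_minimum_num_micro_batch_size` and `get_seqlen_balanced_partitions` are slime's own helpers, and their internals are not shown in this diff. The sketch below is a hypothetical illustration of the two steps the loop performs, not slime's implementation: first, a greedy first-fit pass that counts how many micro batches the token budget (`max_tokens_per_gpu * cp_size`) forces for each step; second, a greedy longest-first partition that spreads sequence lengths evenly across that many micro batches. The names `min_num_micro_batches` and `balanced_partitions` are invented for the example.

import heapq

def min_num_micro_batches(seq_lens, max_tokens_per_gpu):
    # Greedy first-fit estimate: open a new micro batch whenever the
    # next sequence would overflow the per-GPU token budget.
    num_batches, current_tokens = 1, 0
    for length in seq_lens:
        if current_tokens + length > max_tokens_per_gpu:
            num_batches += 1
            current_tokens = length
        else:
            current_tokens += length
    return num_batches

def balanced_partitions(seq_lens, k):
    # Greedy longest-processing-time partition: assign each sequence,
    # longest first, to the currently lightest of k groups. Returns
    # index lists, analogous to get_seqlen_balanced_partitions with
    # equal_size=False (slime's real helper may balance differently).
    heap = [(0, g) for g in range(k)]  # (token load, group id)
    groups = [[] for _ in range(k)]
    for idx in sorted(range(len(seq_lens)), key=lambda i: -seq_lens[i]):
        load, g = heapq.heappop(heap)
        groups[g].append(idx)
        heapq.heappush(heap, (load + seq_lens[idx], g))
    return groups

# Toy run: six sequences against a 4096-token budget.
seq_lens = [1200, 300, 4096, 800, 950, 2048]
k = min_num_micro_batches(seq_lens, max_tokens_per_gpu=4096)
print(k, balanced_partitions(seq_lens, k))  # 3 micro batches

Because the partitioner only sees the slice for one step, the `partitions[j][k] += start` fix-up at the end of the diff shifts the per-step local indices back into global positions within `rollout_data["total_lengths"]`.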