diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index 798c4ea00a..0de351b502 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -364,7 +364,7 @@ def train( ( data_iterator, num_microbatches, - micro_batch_size, + mbs, seq_length, padded_seq_length, ) = get_microbatch_iterator(