diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py index 6987729ba8f..f79a64c2126 100644 --- a/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py +++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py @@ -1164,7 +1164,7 @@ def reset_mixed_precision_policy(self, mixed_precision_policy: MixedPrecisionPol self.mp_policy = mp_policy_reset self.param_and_grad_buffer.mp_policy = mp_policy_reset - def start_param_sync(self, *unused, force_sync: bool = False, force_dispatch: bool = False): + def start_param_sync(self, *unused, force_sync: bool = False, force_dispatch: bool = False, sync_and_return: bool = False): """ Initiates param sync (all-gather) communication operations for all model parameters. @@ -1177,6 +1177,11 @@ def start_param_sync(self, *unused, force_sync: bool = False, force_dispatch: bo other settings. force_dispatch (bool, optional): force dispatch regardless of other settings. + sync_and_return (bool, optional): if True, only wait for any in-flight param all-gather to complete and return immediately; no new param sync is dispatched. 
""" + if sync_and_return: + self.synchronize_param_gather() + return + self._replace_param_with_raw_if_needed() if not force_sync and self.ddp_config.overlap_param_gather: diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 24ec25e5150..a0b9a7cef3d 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -734,6 +734,15 @@ def forward_backward_no_pipelining( if getattr(config, 'fine_grained_activation_offloading', False): off_interface.reset() + # Reset all_gather_pipeline bucket status before next validation iteration + if forward_only: + for model_chunk in [model]: + if ( + model_chunk.ddp_config.overlap_param_gather + and model_chunk.ddp_config.use_megatron_fsdp + ): + model_chunk.start_param_sync(sync_and_return=True) + if config.timers is not None: config.timers('forward-backward').stop() @@ -1954,6 +1963,15 @@ def pp_post_backward(input_tensor_grad, vp_stage=None): if getattr(config, 'fine_grained_activation_offloading', False): off_interface.reset() + # Reset all_gather_pipeline bucket status before next validation iteration + if forward_only: + for model_chunk in model: + if ( + model_chunk.ddp_config.overlap_param_gather + and model_chunk.ddp_config.use_megatron_fsdp + ): + model_chunk.start_param_sync(sync_and_return=True) + # Restore config.grad_sync_func and config.param_sync_func. 
if forward_only: config.grad_sync_func, config.param_sync_func = grad_sync_func, param_sync_func @@ -2361,6 +2379,15 @@ def enable_grad_sync(): if getattr(config, 'fine_grained_activation_offloading', False): off_interface.reset() + # Reset all_gather_pipeline bucket status before next validation iteration + if forward_only: + for model_chunk in [model]: + if ( + model_chunk.ddp_config.overlap_param_gather + and model_chunk.ddp_config.use_megatron_fsdp + ): + model_chunk.start_param_sync(sync_and_return=True) + if config.timers is not None: config.timers('forward-backward').stop()