5151from megatron .bridge .utils .instantiate_utils import InstantiationMode
5252from megatron .bridge .utils .vocab_utils import calculate_padded_vocab_size
5353from megatron .core import parallel_state
54+ from megatron .core .process_groups_config import ProcessGroupCollection
5455from megatron .core .transformer import MegatronModule
5556from megatron .core .transformer .module import Float16Module
5657from megatron .core .transformer .transformer_config import TransformerConfig
@@ -731,6 +732,8 @@ def composed_peft_hook(model: list[MegatronModule]) -> list[MegatronModule]:
731732 pre_wrap_hook .extend ([composed_peft_hook ])
732733
733734 # Model, optimizer, and learning rate.
735+ pg_collection = ProcessGroupCollection .use_mpu_process_groups ()
736+ setattr (megatron_cfg .model , "_pg_collection" , pg_collection )
734737 model = get_model (
735738 megatron_cfg .model ,
736739 megatron_cfg .ddp ,
@@ -739,6 +742,7 @@ def composed_peft_hook(model: list[MegatronModule]) -> list[MegatronModule]:
739742 data_parallel_random_init = megatron_cfg .rng .data_parallel_random_init ,
740743 pre_wrap_hook = pre_wrap_hook ,
741744 mixed_precision_wrapper = mixed_precision_wrapper ,
745+ pg_collection = pg_collection ,
742746 )
743747 if load_optimizer :
744748 optimizer , scheduler = setup_optimizer (
@@ -872,6 +876,7 @@ def setup_reference_model_state(
872876 overlap_param_gather_with_optimizer_step = megatron_cfg .optimizer .overlap_param_gather_with_optimizer_step ,
873877 pre_wrap_hook = megatron_cfg .rng .data_parallel_random_init ,
874878 mixed_precision_wrapper = ref_mixed_precision_wrapper ,
879+ pg_collection = ProcessGroupCollection .use_mpu_process_groups (),
875880 )
876881
877882 print ("Loading the Reference Model" )
@@ -925,6 +930,7 @@ def finalize_megatron_setup(
925930 megatron_cfg .ddp ,
926931 optimizer ,
927932 align_grad_reduce = megatron_cfg .dist .align_grad_reduce ,
933+ pg_collection = ProcessGroupCollection .use_mpu_process_groups (),
928934 )
929935
930936 tokenizer_config = TokenizerConfig (
0 commit comments