From 7e4ea205b70dd6ce098f878e4a301409ca1f3a67 Mon Sep 17 00:00:00 2001 From: Chenhe Gu Date: Tue, 3 Feb 2026 10:03:06 +0800 Subject: [PATCH 1/2] add bridge mode support for distributed weight update --- .../update_weight_from_distributed.py | 83 ++++++++++++++++--- 1 file changed, 72 insertions(+), 11 deletions(-) diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py index a8e50e0e4..149fbc9bf 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py @@ -13,7 +13,7 @@ from slime.utils.distributed_utils import get_gloo_group, init_process_group -from ..megatron_to_hf import convert_to_hf +from ..megatron_to_hf import convert_to_hf, postprocess_hf_param from .common import all_gather_param, named_params_and_buffers @@ -42,6 +42,14 @@ def __init__( self.weight_version = 0 self._model_update_groups = None + self._bridge = None + if getattr(args, "megatron_to_hf_mode", "raw") == "bridge": + from megatron.bridge import AutoBridge + + import slime_plugins.megatron_bridge # noqa: F401 + + self._bridge = AutoBridge.from_hf_pretrained(args.hf_checkpoint, trust_remote_code=True) + def connect_rollout_engines( self, rollout_engines: Sequence[ActorHandle], rollout_engine_lock: ActorHandle ) -> None: @@ -90,6 +98,28 @@ def update_weights(self) -> None: ) dist.barrier(group=get_gloo_group()) + if getattr(self.args, "megatron_to_hf_mode", "raw") == "bridge": + self._update_weights_with_bridge() + else: + self._update_weights_raw() + + dist.barrier(group=get_gloo_group()) + if dist.get_rank() == 0: + # int4/fp4 post_process + if self.quantization_config and self.quantization_config["quant_method"] in ["compressed-tensors"]: + post_process_weights( + restore_weights_before_load=False, + post_process_quantization=True, + rollout_engines=self.rollout_engines, + ) + ray.get([engine.continue_generation.remote() for engine in self.rollout_engines]) + dist.barrier(group=get_gloo_group()) + + def _update_weights_raw(self) -> None: + """ + manual TP gather + convert_to_hf. + Non-expert (TP) → expert (EP) separately. + """ buffer_size = 0 converted_named_tensors = [] # non expert params @@ -119,17 +149,48 @@ def update_weights(self) -> None: if named_tensors: self._update_expert_bucket_weights_from_distributed(named_tensors, pbar=pbar) - dist.barrier(group=get_gloo_group()) - if dist.get_rank() == 0: - # int4/fp4 post_process - if self.quantization_config and self.quantization_config["quant_method"] in ["compressed-tensors"]: - post_process_weights( - restore_weights_before_load=False, - post_process_quantization=True, - rollout_engines=self.rollout_engines, + def _update_weights_with_bridge(self) -> None: + """ + Bridge mode: let Bridge handle PP/TP/EP gather + conversion. + Only PP source rank (DP=TP=0) broadcasts to rollout engines. 
+ """ + from slime.utils import megatron_bridge_utils + + pbar = tqdm(desc=f"[{self._group_name}] Update weights (bridge)") if self._is_pp_src_rank else None + + buffer_size = 0 + converted_named_tensors = [] + + with megatron_bridge_utils.patch_megatron_model(self.model): + # Iterate through weights - all ranks participate in each iteration + for hf_name, tensor, megatron_name in self._bridge.export_hf_weights( + self.model, + cpu=False, + show_progress=False, + ): + # Only PP source rank accumulates and broadcasts + if not self._is_pp_src_rank: + continue + + tensor = postprocess_hf_param( + args=self.args, + megatron_param_name=megatron_name, + hf_param_name=hf_name, + param=tensor, ) - ray.get([engine.continue_generation.remote() for engine in self.rollout_engines]) - dist.barrier(group=get_gloo_group()) + + tensor_size = tensor.numel() * tensor.element_size() + + if buffer_size + tensor_size > self.args.update_weight_buffer_size and converted_named_tensors: + self._update_bucket_weights_from_distributed(converted_named_tensors, pbar=pbar) + converted_named_tensors = [] + buffer_size = 0 + + converted_named_tensors.append((hf_name, tensor)) + buffer_size += tensor_size + + if self._is_pp_src_rank and converted_named_tensors: + self._update_bucket_weights_from_distributed(converted_named_tensors, pbar=pbar) def _update_weight_from_distributed( self, From 865939fcb224769211847871aa193e3b2b61ed3e Mon Sep 17 00:00:00 2001 From: Chenhe Gu Date: Fri, 6 Feb 2026 16:33:44 +0800 Subject: [PATCH 2/2] update --- slime/backends/megatron_utils/checkpoint.py | 6 +----- slime/backends/megatron_utils/model.py | 5 ++--- slime/backends/megatron_utils/model_provider.py | 4 ++-- .../update_weight/hf_weight_iterator_bridge.py | 11 +++-------- .../update_weight/update_weight_from_distributed.py | 11 ++--------- slime/utils/megatron_bridge_utils.py | 10 ++++++++++ 6 files changed, 20 insertions(+), 27 deletions(-) diff --git a/slime/backends/megatron_utils/checkpoint.py b/slime/backends/megatron_utils/checkpoint.py index 8c7cd5317..12ed10c8d 100644 --- a/slime/backends/megatron_utils/checkpoint.py +++ b/slime/backends/megatron_utils/checkpoint.py @@ -128,14 +128,10 @@ def _is_megatron_checkpoint(path: str | Path) -> bool: def _load_checkpoint_hf(ddp_model, optimizer, args, load_path: str): assert args.megatron_to_hf_mode == "bridge", "Only bridge mode is supported for loading HF checkpoint" - from megatron.bridge import AutoBridge - - import slime_plugins.megatron_bridge # noqa: F401 - logger.info(f"Load checkpoint from HuggingFace model into Megatron (path={load_path})") with megatron_bridge_utils.patch_megatron_model(ddp_model): - bridge = AutoBridge.from_hf_pretrained(args.hf_checkpoint, trust_remote_code=True) + bridge = megatron_bridge_utils.get_bridge(args.hf_checkpoint) bridge.load_hf_weights(ddp_model) # Copied from Megatron-core :: load_checkpoint (with simplifications) diff --git a/slime/backends/megatron_utils/model.py b/slime/backends/megatron_utils/model.py index fc497046e..b3c8c8116 100644 --- a/slime/backends/megatron_utils/model.py +++ b/slime/backends/megatron_utils/model.py @@ -725,15 +725,14 @@ def save_hf_model(args, rollout_id: int, model: Sequence[DDP]) -> None: ) try: - from megatron.bridge import AutoBridge - from slime.utils.megatron_bridge_utils import patch_megatron_model + from slime.utils.megatron_bridge_utils import get_bridge, patch_megatron_model path = Path(args.save_hf.format(rollout_id=rollout_id)) if should_log: logger.info(f"Saving model in HuggingFace format to {path}") 
- bridge = AutoBridge.from_hf_pretrained(args.hf_checkpoint, trust_remote_code=True) + bridge = get_bridge(args.hf_checkpoint) path.mkdir(parents=True, exist_ok=True) diff --git a/slime/backends/megatron_utils/model_provider.py b/slime/backends/megatron_utils/model_provider.py index 31db8b0da..d54f310ca 100644 --- a/slime/backends/megatron_utils/model_provider.py +++ b/slime/backends/megatron_utils/model_provider.py @@ -81,9 +81,9 @@ def wrapped_model_provider( return wrapped_model_provider if args.megatron_to_hf_mode == "bridge": - from megatron.bridge import AutoBridge + from slime.utils.megatron_bridge_utils import get_bridge - bridge = AutoBridge.from_hf_pretrained(args.hf_checkpoint, trust_remote_code=True) + bridge = get_bridge(args.hf_checkpoint) provider = bridge.to_megatron_provider(load_weights=False) # TODO: we should not manually set this... provider.tensor_model_parallel_size = args.tensor_model_parallel_size diff --git a/slime/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py b/slime/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py index 7c6ac6401..17913685b 100644 --- a/slime/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py +++ b/slime/backends/megatron_utils/update_weight/hf_weight_iterator_bridge.py @@ -12,20 +12,15 @@ class HfWeightIteratorBridge(HfWeightIteratorBase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - from megatron.bridge import AutoBridge - - import slime_plugins.megatron_bridge # noqa: F401 - - self._bridge = AutoBridge.from_hf_pretrained(self.args.hf_checkpoint, trust_remote_code=True) - def get_hf_weight_chunks(self, megatron_local_weights): # TODO support quantization (e.g. modify megatron-bridge to provide megatron param name) + bridge = megatron_bridge_utils.get_bridge(self.args.hf_checkpoint) renamed_megatron_local_weights = {strip_param_name_prefix(k): v for k, v in megatron_local_weights.items()} with megatron_bridge_utils.patch_megatron_model(self.model): - conversion_tasks = self._bridge.get_conversion_tasks(self.model) + conversion_tasks = bridge.get_conversion_tasks(self.model) conversion_tasks = _process_conversion_tasks(conversion_tasks, renamed_megatron_local_weights) - named_weights = self._bridge.export_hf_weights(self.model, cpu=False, conversion_tasks=conversion_tasks) + named_weights = bridge.export_hf_weights(self.model, cpu=False, conversion_tasks=conversion_tasks) named_weights = ( ( diff --git a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py index 149fbc9bf..e00847b7b 100644 --- a/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py +++ b/slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py @@ -42,14 +42,6 @@ def __init__( self.weight_version = 0 self._model_update_groups = None - self._bridge = None - if getattr(args, "megatron_to_hf_mode", "raw") == "bridge": - from megatron.bridge import AutoBridge - - import slime_plugins.megatron_bridge # noqa: F401 - - self._bridge = AutoBridge.from_hf_pretrained(args.hf_checkpoint, trust_remote_code=True) - def connect_rollout_engines( self, rollout_engines: Sequence[ActorHandle], rollout_engine_lock: ActorHandle ) -> None: @@ -161,9 +153,10 @@ def _update_weights_with_bridge(self) -> None: buffer_size = 0 converted_named_tensors = [] + bridge = megatron_bridge_utils.get_bridge(self.args.hf_checkpoint) with 
megatron_bridge_utils.patch_megatron_model(self.model): # Iterate through weights - all ranks participate in each iteration - for hf_name, tensor, megatron_name in self._bridge.export_hf_weights( + for hf_name, tensor, megatron_name in bridge.export_hf_weights( self.model, cpu=False, show_progress=False, diff --git a/slime/utils/megatron_bridge_utils.py b/slime/utils/megatron_bridge_utils.py index 9e5f065cd..533e1588a 100644 --- a/slime/utils/megatron_bridge_utils.py +++ b/slime/utils/megatron_bridge_utils.py @@ -1,4 +1,5 @@ from contextlib import contextmanager +from functools import lru_cache try: from megatron.core.utils import unwrap_model @@ -6,6 +7,15 @@ unwrap_model = None +@lru_cache(maxsize=1) +def get_bridge(hf_checkpoint: str): + """Create or return cached AutoBridge instance. Bridge is stateless (only holds + architecture metadata), so a single cached instance per hf_checkpoint is safe.""" + from megatron.bridge import AutoBridge + + return AutoBridge.from_hf_pretrained(hf_checkpoint, trust_remote_code=True) + + @contextmanager def patch_megatron_model(model): unwrapped_model = unwrap_model(model)[0]
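
For context on the new get_bridge helper: a minimal, self-contained sketch of the functools.lru_cache behaviour it relies on. get_bridge_stub and the checkpoint paths below are hypothetical stand-ins for illustration only; the real helper builds an AutoBridge via AutoBridge.from_hf_pretrained.

    from functools import lru_cache

    @lru_cache(maxsize=1)
    def get_bridge_stub(hf_checkpoint: str) -> object:
        # Stand-in for AutoBridge.from_hf_pretrained(hf_checkpoint, ...); returns a
        # fresh object so cache hits are observable.
        return object()

    a = get_bridge_stub("/ckpts/model-a")  # first call: builds and caches
    b = get_bridge_stub("/ckpts/model-a")  # same path: the cached instance is reused
    assert a is b
    c = get_bridge_stub("/ckpts/model-b")  # maxsize=1: a new path evicts the old entry
    assert c is not a

maxsize=1 appears sufficient here because every call site touched by this patch passes args.hf_checkpoint, which stays constant within a process, so the HF config is read once and the same bridge is reused for checkpoint loading, HF saves, and weight updates.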
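
A hedged sketch of how the bridge-mode export path from PATCH 1/2 composes with the get_bridge helper from PATCH 2/2. iter_hf_weights, args, and megatron_model are hypothetical placeholders for objects the training actor already holds; only the calls that appear in the diffs above (get_bridge, patch_megatron_model, and export_hf_weights with its three-item yields) are assumed.

    from slime.utils import megatron_bridge_utils

    def iter_hf_weights(args, megatron_model):
        # Every rank must drive this loop: export_hf_weights performs the PP/TP/EP
        # gathers internally, so non-source ranks iterate too (in the patch they
        # simply `continue` instead of buffering the result).
        bridge = megatron_bridge_utils.get_bridge(args.hf_checkpoint)
        with megatron_bridge_utils.patch_megatron_model(megatron_model):
            for hf_name, tensor, megatron_name in bridge.export_hf_weights(
                megatron_model,
                cpu=False,
                show_progress=False,
            ):
                yield hf_name, tensor, megatron_name

In _update_weights_with_bridge itself, only the PP source rank keeps the yielded tensors, runs postprocess_hf_param on each, and buckets them up to update_weight_buffer_size before broadcasting to the rollout engines.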