6 changes: 1 addition & 5 deletions slime/backends/megatron_utils/checkpoint.py
@@ -128,14 +128,10 @@ def _is_megatron_checkpoint(path: str | Path) -> bool:
 
 def _load_checkpoint_hf(ddp_model, optimizer, args, load_path: str):
     assert args.megatron_to_hf_mode == "bridge", "Only bridge mode is supported for loading HF checkpoint"
-    from megatron.bridge import AutoBridge
-
-    import slime_plugins.megatron_bridge  # noqa: F401
-
     logger.info(f"Load checkpoint from HuggingFace model into Megatron (path={load_path})")
 
     with megatron_bridge_utils.patch_megatron_model(ddp_model):
-        bridge = AutoBridge.from_hf_pretrained(args.hf_checkpoint, trust_remote_code=True)
+        bridge = megatron_bridge_utils.get_bridge(args.hf_checkpoint)
         bridge.load_hf_weights(ddp_model)
 
 # Copied from Megatron-core :: load_checkpoint (with simplifications)
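The refactor in this file is the pattern repeated throughout the PR: call sites stop constructing their own bridge via `AutoBridge.from_hf_pretrained(...)` and instead go through the shared `slime.utils.megatron_bridge_utils.get_bridge` helper (defined at the bottom of this diff). A condensed sketch of the resulting load path, with an illustrative function name:

```python
from slime.utils import megatron_bridge_utils

def load_hf_into_megatron_sketch(ddp_model, args):
    """Illustrative only: the bridge-based load path after this PR."""
    # Before this PR, each call site built its own bridge:
    #   from megatron.bridge import AutoBridge
    #   bridge = AutoBridge.from_hf_pretrained(args.hf_checkpoint, trust_remote_code=True)
    # After, every call site shares one cached instance:
    with megatron_bridge_utils.patch_megatron_model(ddp_model):
        bridge = megatron_bridge_utils.get_bridge(args.hf_checkpoint)
        bridge.load_hf_weights(ddp_model)  # stream HF weights into the Megatron model
```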
5 changes: 2 additions & 3 deletions slime/backends/megatron_utils/model.py
@@ -725,15 +725,14 @@ def save_hf_model(args, rollout_id: int, model: Sequence[DDP]) -> None:
     )
 
     try:
-        from megatron.bridge import AutoBridge
-        from slime.utils.megatron_bridge_utils import patch_megatron_model
+        from slime.utils.megatron_bridge_utils import get_bridge, patch_megatron_model
 
         path = Path(args.save_hf.format(rollout_id=rollout_id))
 
         if should_log:
             logger.info(f"Saving model in HuggingFace format to {path}")
 
-        bridge = AutoBridge.from_hf_pretrained(args.hf_checkpoint, trust_remote_code=True)
+        bridge = get_bridge(args.hf_checkpoint)
 
         path.mkdir(parents=True, exist_ok=True)
 
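The rest of `save_hf_model` is collapsed in this view. Below is a rough sketch of how the cached bridge could drive the HF export; note that `save_hf_pretrained` is an assumption about the Megatron-Bridge API and is not taken from this diff:

```python
from pathlib import Path
from slime.utils.megatron_bridge_utils import get_bridge, patch_megatron_model

def save_hf_model_sketch(args, model, rollout_id):
    """Hypothetical continuation of save_hf_model; the export call below is an
    assumption about the Megatron-Bridge API, not part of this diff."""
    path = Path(args.save_hf.format(rollout_id=rollout_id))
    path.mkdir(parents=True, exist_ok=True)
    bridge = get_bridge(args.hf_checkpoint)
    with patch_megatron_model(model):
        bridge.save_hf_pretrained(model, path)  # assumed AutoBridge export method
```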
4 changes: 2 additions & 2 deletions slime/backends/megatron_utils/model_provider.py
@@ -81,9 +81,9 @@ def wrapped_model_provider(
         return wrapped_model_provider
 
     if args.megatron_to_hf_mode == "bridge":
-        from megatron.bridge import AutoBridge
+        from slime.utils.megatron_bridge_utils import get_bridge
 
-        bridge = AutoBridge.from_hf_pretrained(args.hf_checkpoint, trust_remote_code=True)
+        bridge = get_bridge(args.hf_checkpoint)
         provider = bridge.to_megatron_provider(load_weights=False)
         # TODO: we should not manually set this...
         provider.tensor_model_parallel_size = args.tensor_model_parallel_size
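The collapsed part of this hunk presumably continues to copy parallelism settings onto the provider, which is what the TODO refers to. A minimal sketch of that pattern; the `pipeline_model_parallel_size` field is an assumption mirroring Megatron's config and is not shown in this diff:

```python
from slime.utils.megatron_bridge_utils import get_bridge

def build_bridge_provider_sketch(args):
    """Illustrative only: derive a Megatron model provider from the HF config."""
    bridge = get_bridge(args.hf_checkpoint)
    provider = bridge.to_megatron_provider(load_weights=False)
    # The diff sets TP explicitly; other parallel sizes presumably follow the
    # same pattern (field name below is an assumption, not shown in the diff).
    provider.tensor_model_parallel_size = args.tensor_model_parallel_size
    provider.pipeline_model_parallel_size = args.pipeline_model_parallel_size
    return provider
```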
Additional file (name not shown in this view):
@@ -12,20 +12,15 @@ class HfWeightIteratorBridge(HfWeightIteratorBase):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-        from megatron.bridge import AutoBridge
-
-        import slime_plugins.megatron_bridge  # noqa: F401
-
-        self._bridge = AutoBridge.from_hf_pretrained(self.args.hf_checkpoint, trust_remote_code=True)
-
     def get_hf_weight_chunks(self, megatron_local_weights):
         # TODO support quantization (e.g. modify megatron-bridge to provide megatron param name)
+        bridge = megatron_bridge_utils.get_bridge(self.args.hf_checkpoint)
         renamed_megatron_local_weights = {strip_param_name_prefix(k): v for k, v in megatron_local_weights.items()}
         with megatron_bridge_utils.patch_megatron_model(self.model):
-            conversion_tasks = self._bridge.get_conversion_tasks(self.model)
+            conversion_tasks = bridge.get_conversion_tasks(self.model)
             conversion_tasks = _process_conversion_tasks(conversion_tasks, renamed_megatron_local_weights)
 
-            named_weights = self._bridge.export_hf_weights(self.model, cpu=False, conversion_tasks=conversion_tasks)
+            named_weights = bridge.export_hf_weights(self.model, cpu=False, conversion_tasks=conversion_tasks)
 
             named_weights = (
                 (
Additional file (name not shown in this view):
@@ -13,7 +13,7 @@
 
 from slime.utils.distributed_utils import get_gloo_group, init_process_group
 
-from ..megatron_to_hf import convert_to_hf
+from ..megatron_to_hf import convert_to_hf, postprocess_hf_param
 from .common import all_gather_param, named_params_and_buffers
 
 
@@ -90,6 +90,28 @@ def update_weights(self) -> None:
         )
         dist.barrier(group=get_gloo_group())
 
+        if getattr(self.args, "megatron_to_hf_mode", "raw") == "bridge":
+            self._update_weights_with_bridge()
+        else:
+            self._update_weights_raw()
+
+        dist.barrier(group=get_gloo_group())
+        if dist.get_rank() == 0:
+            # int4/fp4 post_process
+            if self.quantization_config and self.quantization_config["quant_method"] in ["compressed-tensors"]:
+                post_process_weights(
+                    restore_weights_before_load=False,
+                    post_process_quantization=True,
+                    rollout_engines=self.rollout_engines,
+                )
+            ray.get([engine.continue_generation.remote() for engine in self.rollout_engines])
+        dist.barrier(group=get_gloo_group())
+
+    def _update_weights_raw(self) -> None:
+        """
+        manual TP gather + convert_to_hf.
+        Non-expert (TP) → expert (EP) separately.
+        """
         buffer_size = 0
         converted_named_tensors = []
         # non expert params
@@ -119,17 +141,49 @@ def update_weights(self) -> None:
             if named_tensors:
                 self._update_expert_bucket_weights_from_distributed(named_tensors, pbar=pbar)
 
-        dist.barrier(group=get_gloo_group())
-        if dist.get_rank() == 0:
-            # int4/fp4 post_process
-            if self.quantization_config and self.quantization_config["quant_method"] in ["compressed-tensors"]:
-                post_process_weights(
-                    restore_weights_before_load=False,
-                    post_process_quantization=True,
-                    rollout_engines=self.rollout_engines,
+    def _update_weights_with_bridge(self) -> None:
+        """
+        Bridge mode: let Bridge handle PP/TP/EP gather + conversion.
+        Only PP source rank (DP=TP=0) broadcasts to rollout engines.
+        """
+        from slime.utils import megatron_bridge_utils
+
+        pbar = tqdm(desc=f"[{self._group_name}] Update weights (bridge)") if self._is_pp_src_rank else None
+
+        buffer_size = 0
+        converted_named_tensors = []
+
+        bridge = megatron_bridge_utils.get_bridge(self.args.hf_checkpoint)
+        with megatron_bridge_utils.patch_megatron_model(self.model):
+            # Iterate through weights - all ranks participate in each iteration
+            for hf_name, tensor, megatron_name in bridge.export_hf_weights(
+                self.model,
+                cpu=False,
+                show_progress=False,
+            ):
+                # Only PP source rank accumulates and broadcasts
+                if not self._is_pp_src_rank:
+                    continue
+
+                tensor = postprocess_hf_param(
+                    args=self.args,
+                    megatron_param_name=megatron_name,
+                    hf_param_name=hf_name,
+                    param=tensor,
                 )
-            ray.get([engine.continue_generation.remote() for engine in self.rollout_engines])
-        dist.barrier(group=get_gloo_group())
 
+                tensor_size = tensor.numel() * tensor.element_size()
+
+                if buffer_size + tensor_size > self.args.update_weight_buffer_size and converted_named_tensors:
+                    self._update_bucket_weights_from_distributed(converted_named_tensors, pbar=pbar)
+                    converted_named_tensors = []
+                    buffer_size = 0
+
+                converted_named_tensors.append((hf_name, tensor))
+                buffer_size += tensor_size
+
+        if self._is_pp_src_rank and converted_named_tensors:
+            self._update_bucket_weights_from_distributed(converted_named_tensors, pbar=pbar)
+
     def _update_weight_from_distributed(
         self,
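The bucketing in `_update_weights_with_bridge` accumulates `(hf_name, tensor)` pairs until the next tensor would push the bucket past `update_weight_buffer_size`, flushes, and repeats, with a final flush after the loop; non-source PP ranks still drive `export_hf_weights` (presumably so collectives inside the bridge stay aligned) but skip accumulation via `continue`. A stripped-down sketch of the flush pattern, with `send_bucket` standing in for `_update_bucket_weights_from_distributed`:

```python
import torch

def bucketed_flush(named_tensors, max_bytes, send_bucket):
    """Accumulate (name, tensor) pairs and flush whenever adding the next tensor
    would push the bucket past max_bytes; mirrors the accumulation loop above,
    with send_bucket as a stand-in sender."""
    bucket, size = [], 0
    for name, tensor in named_tensors:
        nbytes = tensor.numel() * tensor.element_size()
        if size + nbytes > max_bytes and bucket:
            send_bucket(bucket)  # broadcast one full bucket
            bucket, size = [], 0
        bucket.append((name, tensor))
        size += nbytes
    if bucket:  # flush the final partial bucket
        send_bucket(bucket)

# Example with dummy tensors: each fp32 tensor is 4096 bytes, so the second one
# overflows the 4096-byte budget and the two land in separate buckets.
weights = [("w1", torch.zeros(1024)), ("w2", torch.zeros(1024))]
bucketed_flush(weights, max_bytes=4096, send_bucket=lambda b: print([n for n, _ in b]))
```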
10 changes: 10 additions & 0 deletions slime/utils/megatron_bridge_utils.py
@@ -1,11 +1,21 @@
 from contextlib import contextmanager
+from functools import lru_cache
 
 try:
     from megatron.core.utils import unwrap_model
 except ImportError:
     unwrap_model = None
 
 
+@lru_cache(maxsize=1)
+def get_bridge(hf_checkpoint: str):
+    """Create or return cached AutoBridge instance. Bridge is stateless (only holds
+    architecture metadata), so a single cached instance per hf_checkpoint is safe."""
+    from megatron.bridge import AutoBridge
+
+    return AutoBridge.from_hf_pretrained(hf_checkpoint, trust_remote_code=True)
+
+
 @contextmanager
 def patch_megatron_model(model):
     unwrapped_model = unwrap_model(model)[0]
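Since `get_bridge` is wrapped in `functools.lru_cache`, repeated calls with the same checkpoint path return the same `AutoBridge` object, which is what lets the call sites above request a bridge freely instead of threading a shared instance around. A small illustration (the checkpoint path is a placeholder):

```python
from slime.utils.megatron_bridge_utils import get_bridge

ckpt = "/path/to/hf_checkpoint"  # placeholder path
b1 = get_bridge(ckpt)
b2 = get_bridge(ckpt)
assert b1 is b2  # lru_cache returns the same AutoBridge instance; HF config is read once

# Note: with maxsize=1, requesting a different checkpoint path evicts the previous
# entry, which should be harmless as long as a run only ever uses one hf_checkpoint.
```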