Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
a0c7d4a
extract BaseMegatronMapper
lostkevin Sep 29, 2025
f305c7f
rename MegatronMapper
lostkevin Sep 29, 2025
63c74ac
make update_mapping only called in '_inner_map_*'
lostkevin Sep 29, 2025
3983fa1
extract Megatron Mapper for VLM
lostkevin Sep 29, 2025
852df21
clean docstring
lostkevin Sep 29, 2025
508d5c7
fix pylint
lostkevin Sep 29, 2025
ff21e81
fix pylint
lostkevin Sep 29, 2025
a622940
remove src_arch
lostkevin Sep 29, 2025
46fe4cf
make some mapping functions be fully configurable
lostkevin Sep 29, 2025
e6f7aa2
init commit
lostkevin Sep 30, 2025
1d4ed89
test qwen2_5_vl
lostkevin Oct 9, 2025
c262631
fix pylint
lostkevin Oct 9, 2025
663dc6b
Merge branch 'dev/rm_policy_model' into dev/add_qwen3_next
lostkevin Oct 9, 2025
cd83710
fix issues when PP > 1
lostkevin Oct 9, 2025
e4c6e1f
Merge remote-tracking branch 'origin/main' into dev/rm_policy_model
lostkevin Oct 9, 2025
632712d
Merge branch 'dev/rm_policy_model' into dev/add_qwen3_next
lostkevin Oct 9, 2025
44801fc
passing a copy to avoid inplace modification on fp32 logits
lostkevin Oct 9, 2025
b56aa20
fix issue
lostkevin Oct 9, 2025
b271a7f
Merge branch 'dev/rm_policy_model' into dev/add_qwen3_next
lostkevin Oct 9, 2025
5e4008d
Merge remote-tracking branch 'origin/main' into dev/rm_policy_model
lostkevin Oct 10, 2025
c088e38
Merge branch 'dev/rm_policy_model' into dev/add_qwen3_next
lostkevin Oct 10, 2025
682a144
add draft version of qwen3-vl
lostkevin Oct 13, 2025
8ab696d
Merge remote-tracking branch 'origin/main' into dev/add_qwen3_next
lostkevin Oct 13, 2025
0fb9e2b
add how to build image for qwen3-next
jerryli1981 Oct 14, 2025
900dcac
fix param_sync
lostkevin Oct 14, 2025
805a016
Add SGLANG PATCH to README
jerryli1981 Oct 14, 2025
da34f05
fix readme and scripts
jerryli1981 Oct 15, 2025
6f6f3a8
fix memory_pool.py overwrite in readme
jerryli1981 Oct 16, 2025
942adc4
Merge branch 'dev/add_qwen3_next' of github.com:lostkevin/ChatLearn i…
jerryli1981 Oct 16, 2025
5509f13
demo
lostkevin Oct 20, 2025
9cd419a
demo update
lostkevin Oct 20, 2025
65f0195
fix wandb logging
jerryli1981 Oct 20, 2025
0bcaeae
demo update
lostkevin Oct 20, 2025
4431113
fix convergence issue
jerryli1981 Oct 21, 2025
bde6361
fix convergence issue
jerryli1981 Oct 21, 2025
8a20d07
update readme
jerryli1981 Oct 21, 2025
d12316f
fix pylint
lostkevin Oct 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion chatlearn/algorithm/grpo_utils/megatron_policy_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def train_step(self, data_list: List[Dict[str, Any]], **kwargs):
num_zeros_in_grad,
self.stats,
{},
"policy_trainer",
"",
self._metric_list,
)

Expand Down
22 changes: 12 additions & 10 deletions chatlearn/algorithm/grpo_utils/megatron_utils/train_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,32 +113,32 @@ def training_log(
if is_last_rank():

for key in loss_dict:
iter_dict[f"{name}/{key}"] = loss_dict[key]
consumed_train_samples_dict[f"{name}/" + key + " vs samples"] = loss_dict[
iter_dict[f"{key}"] = loss_dict[key]
consumed_train_samples_dict[key + " vs samples"] = loss_dict[
key
]

if grad_norm is not None:
iter_dict[f"{name}/" + "grad_norm"] = grad_norm
consumed_train_samples_dict[f"{name}/" + "grad-norm vs samples"] = grad_norm
iter_dict["grad_norm"] = grad_norm
consumed_train_samples_dict["grad-norm vs samples"] = grad_norm

if more_grad_norm is not None:
for k in more_grad_norm:
iter_dict[f"{name}/{k}" + " grad_norm"] = more_grad_norm[k]
consumed_train_samples_dict[f"{name}/{k}" + " grad-norm vs samples"] = (
iter_dict[f"{k}" + " grad_norm"] = more_grad_norm[k]
consumed_train_samples_dict[f"{k}" + " grad-norm vs samples"] = (
more_grad_norm[k]
)

if params_norm is not None:
iter_dict[f"{name}/" + "params-norm"] = params_norm
consumed_train_samples_dict[f"{name}/" + "params-norm vs samples"] = (
iter_dict["params-norm"] = params_norm
consumed_train_samples_dict["params-norm vs samples"] = (
params_norm
)

elapsed_time = 0
elapsed_time_per_iteration = elapsed_time / total_iterations
if args.log_timers_to_tensorboard:
iter_dict[f"{name}/" + "iteration-time"] = elapsed_time_per_iteration
iter_dict["iteration-time"] = elapsed_time_per_iteration

log_string = " iteration {:8d}/infinity |".format(iteration)
log_string += " consumed samples: {:12d} |".format(args.consumed_train_samples)
Expand Down Expand Up @@ -561,9 +561,11 @@ def forward_step(data_iterator, model, *, is_training: bool=False, is_packing: b
'input_ids': inputs["all_tokens"],
'position_ids': inputs["all_token_position_ids"],
'labels': inputs["labels"] if not is_training else None,
'packed_seq_params': inputs['packed_seq_params'] if is_packing else None
}

if is_packing:
kwargs.update({'packed_seq_params': inputs['packed_seq_params']})

if 'pixel_values' in inputs:
kwargs.update({
'vision_data': inputs["pixel_values"],
Expand Down
13 changes: 13 additions & 0 deletions chatlearn/configs/megatron_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class MegatronModelArchitectureConfig(BaseConfig):
default=1000000,
metadata={"help": "Base to use for rotary positional embeddings"},
)
rotary_percent: float = 1.0
group_query_attention: bool = field(
default=False, metadata={"help": "Use group-query attention."}
)
Expand Down Expand Up @@ -245,6 +246,11 @@ class MegatronModelArchitectureConfig(BaseConfig):
freeze_VP: bool = field(
default=False, metadata={"help": "Freeze vision projection layers"}
)

hybrid_override_pattern: Optional[str] = None
is_hybrid_model: bool = False
apply_layernorm_1p: bool = False

def _post_init_impl(self):
if self.moe_aux_loss_coeff == 0:
self.moe_router_load_balancing_type = 'none'
Expand Down Expand Up @@ -329,6 +335,12 @@ class MegatronConfig(BaseConfig):
}
)

use_expandable_segments: bool = field(
default=False, metadata={"help": "Whether to use expandable_segments in PYTORCH_CUDA_ALLOC_CONF, \
avoid big reseverd memory in ref and policy trainer worker, expandable_segments should be False \
while in parameter sync for efficiency"}
)

def _validate_impl(self):
assert self.num_gpu > 0, "Megatron-Core requires at least one GPU"
assert self.num_gpu % self.num_replica == 0, \
Expand Down Expand Up @@ -443,6 +455,7 @@ class MegatronPolicyTrainerConfig(PolicyTrainerConfig, MegatronConfig):
"help": "Load model for finetuning. Do not load optimizer or rng state from checkpoint and set iteration to 0."
},
)
distributed_timeout_minutes: int = 10

def _validate_impl(self):
assert self.calculate_per_token_loss, "Per-Token-Loss is required for Training."
14 changes: 4 additions & 10 deletions chatlearn/models/megatron_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import re
from dataclasses import fields

import inspect
import torch

try:
Expand Down Expand Up @@ -123,6 +122,8 @@ def model_setup(self):
"""
:meta private:
"""
if self.module_args.use_expandable_segments:
torch.cuda.memory._set_allocator_settings("expandable_segments:True")
super().model_setup()

# TODO: we may need to let setup return model, optimizer and opt_param_scheduler
Expand Down Expand Up @@ -255,17 +256,10 @@ def map_local_param_name_to_global(self):
self.global_name_to_local_name = {}
# NOTE: this regex is for model with TEGroupedGEMM
# SequentialMLP or GroupedMLP is not supported
regex = re.compile(r"(.*)decoder.layers\.(\d+)\.([a-z0-9_.]+)([\._])([a-z]+)([0-9]*)")
regex = re.compile(r"(.*)decoder.layers\.(\d+)\.([a-zA-Z0-9_.]+)([\._])([a-zA-Z]+)([0-9]*)")
for vp_stage, model_chunk in enumerate(self.model):
model_config = unwrap_model(model_chunk).config
if 'vp_stage' in inspect.signature(get_transformer_layer_offset).parameters:
offset = get_transformer_layer_offset(model_config, vp_stage=vp_stage)
else:
if len(self.model) > 1:
mpu.set_virtual_pipeline_model_parallel_rank(vp_stage)
offset = get_transformer_layer_offset(model_config)
if len(self.model) > 1:
mpu.set_virtual_pipeline_model_parallel_rank(None)
offset = get_transformer_layer_offset(model_config, vp_stage=vp_stage)
if model_config.num_moe_experts is not None:
ep_rank = mpu.get_expert_model_parallel_rank()
ep_size = mpu.get_expert_model_parallel_world_size()
Expand Down
12 changes: 12 additions & 0 deletions chatlearn/models/sglang_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,12 @@ def generate(self, query: List[Dict], is_eval: bool) -> List[Dict]:
self.flush_cache()
return outputs

def dump_parameters(self, dump_path_root):
    """Dump the engine's parameters as a sharded model under ``dump_path_root``.

    The target directory is created if it does not exist yet.
    ``self.onload()`` is called before saving and ``self.offload()`` is
    always called afterwards, including when saving raises.

    Args:
        dump_path_root: Directory forwarded as ``path`` to
            ``self.llm.save_sharded_model``.

    Raises:
        Any exception raised by ``os.makedirs``, ``self.onload``,
        ``self.llm.save_sharded_model`` or ``self.offload``; nothing is
        swallowed here.
    """
    os.makedirs(dump_path_root, exist_ok=True)
    self.onload()
    try:
        self.llm.save_sharded_model(path=dump_path_root, pattern=None, max_size=None)
    finally:
        # Pair every successful onload() with an offload(), even if the save fails.
        self.offload()

def update_weights_from_ipc_handles(self, reduce_data):
gathered_data = None
if self.is_engine():
Expand Down Expand Up @@ -725,6 +731,12 @@ async def generate(self, query: List[Dict], is_eval: bool) -> List[Dict]:
)
return outputs

async def dump_parameters(self, dump_path_root):
    """Dump the engine's parameters as a sharded model under ``dump_path_root``.

    Coroutine variant: ``self.onload()`` and ``self.offload()`` are awaited.
    The target directory is created if it does not exist yet, and the
    offload is always awaited after a successful onload, including when
    saving raises.

    Args:
        dump_path_root: Directory forwarded as ``path`` to
            ``self.llm.save_sharded_model``.

    Raises:
        Any exception raised by ``os.makedirs``, ``self.onload``,
        ``self.llm.save_sharded_model`` or ``self.offload``; nothing is
        swallowed here.
    """
    os.makedirs(dump_path_root, exist_ok=True)
    await self.onload()
    try:
        self.llm.save_sharded_model(path=dump_path_root, pattern=None, max_size=None)
    finally:
        # Pair every successful onload() with an offload(), even if the save fails.
        await self.offload()

async def generate_per_request(self, query: Dict, is_eval: bool) -> Dict:
outputs = None
if self.is_engine():
Expand Down
2 changes: 1 addition & 1 deletion chatlearn/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ def _resume_from_data_checkpoint(self):
def dump_parameters(self, dump_path):
for _, model in enumerate(self.models):
replic_0 = model.replicas[0]
if isinstance(replic_0, DistVLLMActor):
if isinstance(replic_0, (DistVLLMActor, DistSGLangActor)):
future.wait(replic_0.engine.dump_parameters.remote(dump_path))

def save_checkpoint(self, episode_id):
Expand Down
30 changes: 20 additions & 10 deletions chatlearn/synchronizer/mappers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,30 @@
def get_mapper_name(src_model: 'DistModel', dst_model: 'DistModel'):
src_type = src_model.runtime_args.train_backend
dst_type = dst_model.runtime_args.rollout_backend
if src_type == 'megatron' and dst_type == 'vllm':
return "MegatronVLLMMapper"
elif src_type == 'megatron' and dst_type == 'sglang':
return "MegatronSGLangMapper"
else:
raise NotImplementedError(f"Unsupported src/dst model combination: {src_type}-{dst_type}")
model_type = src_model.runtime_args.model_type # llm or vlm

mapping = {
'llm-megatron-vllm': "MegatronVLLMMapper-LLM",
'llm-megatron-sglang': "MegatronSGLangMapper-LLM",
'vlm-megatron-vllm': "MegatronVLLMMapper-VLM",
'vlm-megatron-sglang': "MegatronSGLangMapper-VLM",
}
key = f'{model_type}-{src_type}-{dst_type}'
if key not in mapping:
raise NotImplementedError(f"Unsupported src/dst model combination: {key}")
return mapping[key]


def name_to_mapper_cls(mapper_name: str):
# pylint: disable=import-outside-toplevel
from .mapping_helpers import VLLM_HELPERS, HF_HELPERS
if mapper_name in ["MegatronVLLMMapper", "MegatronSGLangMapper"]:
from .mapper import MegatronMapper
helper_mappings = {"MegatronVLLMMapper": VLLM_HELPERS, "MegatronSGLangMapper": HF_HELPERS}
return partial(MegatronMapper, mapper_config=helper_mappings[mapper_name])
if mapper_name in ["MegatronVLLMMapper-LLM", "MegatronSGLangMapper-LLM"]:
from .megatron_llm_mapper import MegatronLLMMapper
helper_mappings = {"MegatronVLLMMapper-LLM": VLLM_HELPERS, "MegatronSGLangMapper-LLM": HF_HELPERS}
return partial(MegatronLLMMapper, mapper_config=helper_mappings[mapper_name])
elif mapper_name in ["MegatronVLLMMapper-VLM", "MegatronSGLangMapper-VLM"]:
from .megatron_vlm_mapper import MegatronVLMMapper
helper_mappings = {"MegatronVLLMMapper-VLM": VLLM_HELPERS, "MegatronSGLangMapper-VLM": HF_HELPERS}
return partial(MegatronVLMMapper, mapper_config=helper_mappings[mapper_name])
else:
raise ValueError(f"Unrecognized Mapper {mapper_name}")
Loading