diff --git a/chatlearn/algorithm/grpo_utils/megatron_policy_trainer.py b/chatlearn/algorithm/grpo_utils/megatron_policy_trainer.py index 2defaca7..7f72abff 100644 --- a/chatlearn/algorithm/grpo_utils/megatron_policy_trainer.py +++ b/chatlearn/algorithm/grpo_utils/megatron_policy_trainer.py @@ -249,7 +249,7 @@ def train_step(self, data_list: List[Dict[str, Any]], **kwargs): num_zeros_in_grad, self.stats, {}, - "policy_trainer", + "", self._metric_list, ) diff --git a/chatlearn/algorithm/grpo_utils/megatron_utils/train_helper.py b/chatlearn/algorithm/grpo_utils/megatron_utils/train_helper.py index f2a33677..d99ad5ad 100644 --- a/chatlearn/algorithm/grpo_utils/megatron_utils/train_helper.py +++ b/chatlearn/algorithm/grpo_utils/megatron_utils/train_helper.py @@ -113,32 +113,32 @@ def training_log( if is_last_rank(): for key in loss_dict: - iter_dict[f"{name}/{key}"] = loss_dict[key] - consumed_train_samples_dict[f"{name}/" + key + " vs samples"] = loss_dict[ + iter_dict[f"{key}"] = loss_dict[key] + consumed_train_samples_dict[key + " vs samples"] = loss_dict[ key ] if grad_norm is not None: - iter_dict[f"{name}/" + "grad_norm"] = grad_norm - consumed_train_samples_dict[f"{name}/" + "grad-norm vs samples"] = grad_norm + iter_dict["grad_norm"] = grad_norm + consumed_train_samples_dict["grad-norm vs samples"] = grad_norm if more_grad_norm is not None: for k in more_grad_norm: - iter_dict[f"{name}/{k}" + " grad_norm"] = more_grad_norm[k] - consumed_train_samples_dict[f"{name}/{k}" + " grad-norm vs samples"] = ( + iter_dict[f"{k}" + " grad_norm"] = more_grad_norm[k] + consumed_train_samples_dict[f"{k}" + " grad-norm vs samples"] = ( more_grad_norm[k] ) if params_norm is not None: - iter_dict[f"{name}/" + "params-norm"] = params_norm - consumed_train_samples_dict[f"{name}/" + "params-norm vs samples"] = ( + iter_dict["params-norm"] = params_norm + consumed_train_samples_dict["params-norm vs samples"] = ( params_norm ) elapsed_time = 0 elapsed_time_per_iteration = elapsed_time / total_iterations if args.log_timers_to_tensorboard: - iter_dict[f"{name}/" + "iteration-time"] = elapsed_time_per_iteration + iter_dict["iteration-time"] = elapsed_time_per_iteration log_string = " iteration {:8d}/infinity |".format(iteration) log_string += " consumed samples: {:12d} |".format(args.consumed_train_samples) @@ -561,9 +561,11 @@ def forward_step(data_iterator, model, *, is_training: bool=False, is_packing: b 'input_ids': inputs["all_tokens"], 'position_ids': inputs["all_token_position_ids"], 'labels': inputs["labels"] if not is_training else None, - 'packed_seq_params': inputs['packed_seq_params'] if is_packing else None } + if is_packing: + kwargs.update({'packed_seq_params': inputs['packed_seq_params']}) + if 'pixel_values' in inputs: kwargs.update({ 'vision_data': inputs["pixel_values"], diff --git a/chatlearn/configs/megatron_config.py b/chatlearn/configs/megatron_config.py index ff40b33b..11c57aac 100644 --- a/chatlearn/configs/megatron_config.py +++ b/chatlearn/configs/megatron_config.py @@ -70,6 +70,7 @@ class MegatronModelArchitectureConfig(BaseConfig): default=1000000, metadata={"help": "Base to use for rotary positional embeddings"}, ) + rotary_percent: float = 1.0 group_query_attention: bool = field( default=False, metadata={"help": "Use group-query attention."} ) @@ -245,6 +246,11 @@ class MegatronModelArchitectureConfig(BaseConfig): freeze_VP: bool = field( default=False, metadata={"help": "Freeze vision projection layers"} ) + + hybrid_override_pattern: Optional[str] = None + is_hybrid_model: bool 
= False + apply_layernorm_1p: bool = False + def _post_init_impl(self): if self.moe_aux_loss_coeff == 0: self.moe_router_load_balancing_type = 'none' @@ -329,6 +335,12 @@ class MegatronConfig(BaseConfig): } ) + use_expandable_segments: bool = field( + default=False, metadata={"help": "Whether to use expandable_segments in PYTORCH_CUDA_ALLOC_CONF, \ + avoid big reseverd memory in ref and policy trainer worker, expandable_segments should be False \ + while in parameter sync for efficiency"} + ) + def _validate_impl(self): assert self.num_gpu > 0, "Megatron-Core requires at least one GPU" assert self.num_gpu % self.num_replica == 0, \ @@ -443,6 +455,7 @@ class MegatronPolicyTrainerConfig(PolicyTrainerConfig, MegatronConfig): "help": "Load model for finetuning. Do not load optimizer or rng state from checkpoint and set iteration to 0." }, ) + distributed_timeout_minutes: int = 10 def _validate_impl(self): assert self.calculate_per_token_loss, "Per-Token-Loss is required for Training." diff --git a/chatlearn/models/megatron_module.py b/chatlearn/models/megatron_module.py index e8ba6d40..2854ebdb 100644 --- a/chatlearn/models/megatron_module.py +++ b/chatlearn/models/megatron_module.py @@ -16,7 +16,6 @@ import re from dataclasses import fields -import inspect import torch try: @@ -123,6 +122,8 @@ def model_setup(self): """ :meta private: """ + if self.module_args.use_expandable_segments: + torch.cuda.memory._set_allocator_settings("expandable_segments:True") super().model_setup() # TODO: we may need to let setup return model, optimizer and opt_param_scheduler @@ -255,17 +256,10 @@ def map_local_param_name_to_global(self): self.global_name_to_local_name = {} # NOTE: this regex is for model with TEGroupedGEMM # SequentialMLP or GroupedMLP is not supported - regex = re.compile(r"(.*)decoder.layers\.(\d+)\.([a-z0-9_.]+)([\._])([a-z]+)([0-9]*)") + regex = re.compile(r"(.*)decoder.layers\.(\d+)\.([a-zA-Z0-9_.]+)([\._])([a-zA-Z]+)([0-9]*)") for vp_stage, model_chunk in enumerate(self.model): model_config = unwrap_model(model_chunk).config - if 'vp_stage' in inspect.signature(get_transformer_layer_offset).parameters: - offset = get_transformer_layer_offset(model_config, vp_stage=vp_stage) - else: - if len(self.model) > 1: - mpu.set_virtual_pipeline_model_parallel_rank(vp_stage) - offset = get_transformer_layer_offset(model_config) - if len(self.model) > 1: - mpu.set_virtual_pipeline_model_parallel_rank(None) + offset = get_transformer_layer_offset(model_config, vp_stage=vp_stage) if model_config.num_moe_experts is not None: ep_rank = mpu.get_expert_model_parallel_rank() ep_size = mpu.get_expert_model_parallel_world_size() diff --git a/chatlearn/models/sglang_module.py b/chatlearn/models/sglang_module.py index dfb91970..37ef4895 100644 --- a/chatlearn/models/sglang_module.py +++ b/chatlearn/models/sglang_module.py @@ -412,6 +412,12 @@ def generate(self, query: List[Dict], is_eval: bool) -> List[Dict]: self.flush_cache() return outputs + def dump_parameters(self, dump_path_root): + os.makedirs(dump_path_root, exist_ok=True) + self.onload() + self.llm.save_sharded_model(path=dump_path_root, pattern=None, max_size=None) + self.offload() + def update_weights_from_ipc_handles(self, reduce_data): gathered_data = None if self.is_engine(): @@ -725,6 +731,12 @@ async def generate(self, query: List[Dict], is_eval: bool) -> List[Dict]: ) return outputs + async def dump_parameters(self, dump_path_root): + os.makedirs(dump_path_root, exist_ok=True) + await self.onload() + 
self.llm.save_sharded_model(path=dump_path_root, pattern=None, max_size=None) + await self.offload() + async def generate_per_request(self, query: Dict, is_eval: bool) -> Dict: outputs = None if self.is_engine(): diff --git a/chatlearn/runtime/engine.py b/chatlearn/runtime/engine.py index 59100992..87d3968e 100644 --- a/chatlearn/runtime/engine.py +++ b/chatlearn/runtime/engine.py @@ -556,7 +556,7 @@ def _resume_from_data_checkpoint(self): def dump_parameters(self, dump_path): for _, model in enumerate(self.models): replic_0 = model.replicas[0] - if isinstance(replic_0, DistVLLMActor): + if isinstance(replic_0, (DistVLLMActor, DistSGLangActor)): future.wait(replic_0.engine.dump_parameters.remote(dump_path)) def save_checkpoint(self, episode_id): diff --git a/chatlearn/synchronizer/mappers/__init__.py b/chatlearn/synchronizer/mappers/__init__.py index d041c331..ff438970 100644 --- a/chatlearn/synchronizer/mappers/__init__.py +++ b/chatlearn/synchronizer/mappers/__init__.py @@ -22,20 +22,30 @@ def get_mapper_name(src_model: 'DistModel', dst_model: 'DistModel'): src_type = src_model.runtime_args.train_backend dst_type = dst_model.runtime_args.rollout_backend - if src_type == 'megatron' and dst_type == 'vllm': - return "MegatronVLLMMapper" - elif src_type == 'megatron' and dst_type == 'sglang': - return "MegatronSGLangMapper" - else: - raise NotImplementedError(f"Unsupported src/dst model combination: {src_type}-{dst_type}") + model_type = src_model.runtime_args.model_type # llm or vlm + + mapping = { + 'llm-megatron-vllm': "MegatronVLLMMapper-LLM", + 'llm-megatron-sglang': "MegatronSGLangMapper-LLM", + 'vlm-megatron-vllm': "MegatronVLLMMapper-VLM", + 'vlm-megatron-sglang': "MegatronSGLangMapper-VLM", + } + key = f'{model_type}-{src_type}-{dst_type}' + if key not in mapping: + raise NotImplementedError(f"Unsupported src/dst model combination: {key}") + return mapping[key] def name_to_mapper_cls(mapper_name: str): # pylint: disable=import-outside-toplevel from .mapping_helpers import VLLM_HELPERS, HF_HELPERS - if mapper_name in ["MegatronVLLMMapper", "MegatronSGLangMapper"]: - from .mapper import MegatronMapper - helper_mappings = {"MegatronVLLMMapper": VLLM_HELPERS, "MegatronSGLangMapper": HF_HELPERS} - return partial(MegatronMapper, mapper_config=helper_mappings[mapper_name]) + if mapper_name in ["MegatronVLLMMapper-LLM", "MegatronSGLangMapper-LLM"]: + from .megatron_llm_mapper import MegatronLLMMapper + helper_mappings = {"MegatronVLLMMapper-LLM": VLLM_HELPERS, "MegatronSGLangMapper-LLM": HF_HELPERS} + return partial(MegatronLLMMapper, mapper_config=helper_mappings[mapper_name]) + elif mapper_name in ["MegatronVLLMMapper-VLM", "MegatronSGLangMapper-VLM"]: + from .megatron_vlm_mapper import MegatronVLMMapper + helper_mappings = {"MegatronVLLMMapper-VLM": VLLM_HELPERS, "MegatronSGLangMapper-VLM": HF_HELPERS} + return partial(MegatronVLMMapper, mapper_config=helper_mappings[mapper_name]) else: raise ValueError(f"Unrecognized Mapper {mapper_name}") diff --git a/chatlearn/synchronizer/mappers/base_megatron_mapper.py b/chatlearn/synchronizer/mappers/base_megatron_mapper.py new file mode 100644 index 00000000..40d00a45 --- /dev/null +++ b/chatlearn/synchronizer/mappers/base_megatron_mapper.py @@ -0,0 +1,278 @@ +# Copyright 2025 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Basic Mapper for Megatron to rollout framework"""
+from collections import defaultdict
+from typing import List, Dict, TYPE_CHECKING, Union, Tuple
+
+from megatron.training.utils import unwrap_model
+
+from chatlearn.configs import PolicyConfig
+from chatlearn.configs.megatron_config import MegatronPolicyTrainerConfig
+from chatlearn.utils.mappings import ShardedTensorInfo
+
+from .mapping_helpers import (
+    process_normal_tensor,
+    process_gate_up_tensor,
+    process_qkv_tensor,
+    process_merged_linear_tensor,
+    process_linear_attn_tensor,
+    VLLM_HELPERS,
+    HF_HELPERS
+)
+
+if TYPE_CHECKING:
+    from megatron.core.transformer.module import MegatronModule as MCoreModule
+    from chatlearn.models.megatron_module import MegatronModule
+
+class BaseMegatronMapper:
+    """BaseMegatronMapper"""
+    def __init__(
+        self,
+        dst_model_config: PolicyConfig,
+        model: 'MegatronModule',
+        *,
+        mapper_config: Union[VLLM_HELPERS, HF_HELPERS] = VLLM_HELPERS,
+    ):
+        """The Base Mapper for Megatron sync. In each remote Megatron Actor,
+        the methods of this class are called to generate the parameter mapping
+        between src and dst.
+
+        Args:
+            dst_model_config (PolicyConfig): The config of the target model to
+                be synchronized
+            model (MegatronModule): The source Megatron Module
+            mapper_config (Union[VLLM_HELPERS, HF_HELPERS]): The mapping mode.
+        """
+        self.model: List['MCoreModule'] = unwrap_model(model.model)
+        self._src_model_config: MegatronPolicyTrainerConfig = model.module_args
+        self._dst_model_config = dst_model_config
+        self._mapper_config = mapper_config
+        self._dst_tp_size = 1 if mapper_config.force_full_model else self._dst_model_config.tensor_model_parallel_size
+        self._src_name_to_metadata: Dict[str, ShardedTensorInfo] = model.get_parameter_metadata(key_type='local_name')
+        self._dst_name_to_metadata: Dict[str, ShardedTensorInfo] = None
+        self._mapping = None
+
+    def generate_sync_mapping(
+        self,
+        dst_name_to_metadata: Dict[str, ShardedTensorInfo]
+    ) -> Dict[ShardedTensorInfo, List[ShardedTensorInfo]]:
+        """Generate the synchronization mapping of this local rank.
+
+        Args:
+            dst_name_to_metadata (Dict[str, ShardedTensorInfo]): mapping from a global
+                parameter name to the corresponding ShardedTensorInfo.
+
+        Returns:
+            Dict[ShardedTensorInfo, List[ShardedTensorInfo]]: The returned dict is the
+            plan covering all local parameters to be synchronized. The mapper ensures
+            that the keys of each mapping type are non-overlapping and can be merged
+            into the full state dict of this rank. In most cases the list of dst shards
+            has length 1, except for GQA with large TP.
+        """
+        self._dst_name_to_metadata = dst_name_to_metadata
+        return self._map_model()
+
+    def dump_sync_mapping(self, folder_path: str, sync_mapping: Dict):
+        """Dump the generated sync mapping to the given folder path in JSON format.
+        Currently a no-op.
+
+        Args:
+            folder_path (str): The folder path to dump the sync mapping.
+            sync_mapping (Dict): The sync mapping to be saved.
+ """ + + def _map_model(self): + """Mapping the local name of src model to global name of dst model + """ + raise NotImplementedError() + + # NOTE: the following function implements the tensor-wise sync mapping + def _inner_map_for_tensor_parallel( + self, + src_key: str, + dst_key: str, + *, + global_expert_id: int=None, + num_experts: int=None, + mapping_type: str='column' + ): + AXES = {'column': 0, 'row': 1} + src_info = self._src_name_to_metadata[src_key] + # NOTE: we should do nothing to bias of RowParallel, call full shape mapping. + if src_info.ndim == 1 and mapping_type == 'row': + return self._inner_map_for_full_shape(src_key, dst_key) + + dst_info = self._dst_name_to_metadata[dst_key] + mapping = {} + for src_meta, dst_meta in process_normal_tensor( + src_info, + self._dst_tp_size, + axis=AXES[mapping_type] + ): + src_meta.param_id, dst_meta.param_id = src_info.param_id, dst_info.param_id + src_meta.dtype, dst_meta.dtype = src_info.dtype, dst_info.dtype + if global_expert_id is not None: + dst_meta = ( + dst_meta + .unsqueeze(offset=global_expert_id, length=num_experts, axis=0) + .refragment(1, axis=0) # 1 is dst EP + ) + mapping[src_meta] = [dst_meta] + self._update_mapping(mapping) + return mapping + + def _inner_map_for_full_shape( + self, + src_key: str, + dst_key: str + ): + src_info = self._src_name_to_metadata[src_key] + dst_info = self._dst_name_to_metadata[dst_key] + results = {src_info.copy(): [dst_info.copy()]} + self._update_mapping(results) + return results + + def _inner_map_for_gate_up_proj(self, src_key: str, dst_key: str, proj_type: str, *, global_expert_id: int=None, num_experts: int=None): + src_info = self._src_name_to_metadata[src_key] + dst_info = self._dst_name_to_metadata[dst_key] + mapping = {} + for src_meta, dst_meta in process_gate_up_tensor( + src_info, + self._dst_tp_size, + proj_type=proj_type + ): + src_meta.param_id, dst_meta.param_id = src_info.param_id, dst_info.param_id + src_meta.dtype, dst_meta.dtype = src_info.dtype, dst_info.dtype + if global_expert_id is not None: + dst_meta = ( + dst_meta + .unsqueeze(offset=global_expert_id, length=num_experts, axis=0) + .refragment(1, axis=0) # 1 is dst EP + ) + mapping[src_meta] = [dst_meta] + self._update_mapping(mapping) + return mapping + + def _inner_map_for_qkv_proj( + self, + src_key: str, + dst_key: str, + proj_type: str, + num_attention_heads: int, + num_query_groups: int, + is_gated_attention: bool=False + ): + src_info = self._src_name_to_metadata[src_key] + dst_info = self._dst_name_to_metadata[dst_key] + mapping = defaultdict(list) + for src_meta, dst_meta in process_qkv_tensor( + src_info, + num_attention_heads, + num_query_groups, + self._dst_tp_size, + proj_type=proj_type, + is_gated_attention=is_gated_attention + ): + src_meta.param_id, dst_meta.param_id = src_info.param_id, dst_info.param_id + src_meta.dtype, dst_meta.dtype = src_info.dtype, dst_info.dtype + mapping[src_meta].append(dst_meta) + self._update_mapping(mapping) + return mapping + + def _inner_map_for_mla_down_proj(self, src_key: str, dst_key: str): + src_info = self._src_name_to_metadata[src_key] + dst_info = self._dst_name_to_metadata[dst_key] + dst_meta = src_info.refragment(1) + dst_meta.param_id = dst_info.param_id + dst_meta.dtype = dst_info.dtype + results = {src_info.copy(): [dst_meta]} + self._update_mapping(results) + return results + + def _inner_map_for_merged_linear( + self, + src_key: str, + dst_key: str, + src_layout: List[Tuple[str, int]], + required_layout: List[str], + *, + global_expert_id: int=None, + 
num_experts: int=None, + axis: int = 0 + ): + src_info = self._src_name_to_metadata[src_key] + dst_info = self._dst_name_to_metadata[dst_key] + mapping = {} + for src_meta, dst_meta in process_merged_linear_tensor( + src_info, + self._dst_tp_size, + src_layout=src_layout, + required_layout=required_layout, + axis=axis + ): + src_meta.param_id, dst_meta.param_id = src_info.param_id, dst_info.param_id + src_meta.dtype, dst_meta.dtype = src_info.dtype, dst_info.dtype + if global_expert_id is not None: + dst_meta = ( + dst_meta + .unsqueeze(offset=global_expert_id, length=num_experts, axis=0) + .refragment(1, axis=0) # 1 is dst EP + ) + mapping[src_meta] = [dst_meta] + self._update_mapping(mapping) + return mapping + + def _inner_map_for_linear_attn( + self, + src_key: str, + dst_key: str, + src_layout: List[Tuple[str, int]], + required_layout: List[str], + *, + global_expert_id: int=None, + num_experts: int=None, + axis: int = 0, + n_groups: int = 1 + ): + src_info = self._src_name_to_metadata[src_key] + dst_info = self._dst_name_to_metadata[dst_key] + mapping = {} + for src_meta, dst_meta in process_linear_attn_tensor( + src_info, + self._dst_tp_size, + n_groups=n_groups, + src_layout=src_layout, + required_layout=required_layout, + axis=axis + ): + src_meta.param_id, dst_meta.param_id = src_info.param_id, dst_info.param_id + src_meta.dtype, dst_meta.dtype = src_info.dtype, dst_info.dtype + if global_expert_id is not None: + dst_meta = ( + dst_meta + .unsqueeze(offset=global_expert_id, length=num_experts, axis=0) + .refragment(1, axis=0) # 1 is dst EP + ) + mapping[src_meta] = [dst_meta] + self._update_mapping(mapping) + return mapping + + def _update_mapping(self, results: Dict[ShardedTensorInfo, List[ShardedTensorInfo]]) -> None: + if self._mapping is None: + self._mapping = defaultdict(list) + for src_meta, dst_metas in results.items(): + self._mapping[src_meta] += dst_metas diff --git a/chatlearn/synchronizer/mappers/mapper.py b/chatlearn/synchronizer/mappers/mapper.py deleted file mode 100644 index 9d003858..00000000 --- a/chatlearn/synchronizer/mappers/mapper.py +++ /dev/null @@ -1,702 +0,0 @@ -# Copyright 2025 Alibaba Group Holding Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== - -"""Mapper for Megatron to vLLM""" -from collections import defaultdict -from typing import List, Dict, Tuple, TYPE_CHECKING, Union - -import inspect -import torch -from torch import nn -from transformers import AutoConfig - -from megatron.core import mpu -from megatron.core.transformer.transformer_layer import get_transformer_layer_offset -from megatron.core.transformer.moe.moe_layer import MoELayer -from megatron.core.transformer.moe.experts import TEGroupedMLP -from megatron.training.utils import unwrap_model - -from chatlearn.configs import PolicyConfig -from chatlearn.configs.megatron_config import MegatronPolicyTrainerConfig -from chatlearn.utils.mappings import ShardedTensorInfo - -from .mapping_helpers import ( - process_normal_tensor, - process_gate_up_tensor, - process_qkv_tensor, - VLLM_HELPERS, - HF_HELPERS -) - -if TYPE_CHECKING: - from megatron.core.models.gpt import GPTModel - from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding - from megatron.core.transformer.transformer_layer import TransformerLayer - from megatron.core.tensor_parallel import ColumnParallelLinear - from megatron.core.transformer.mlp import MLP - from megatron.core.transformer.multi_latent_attention import MLASelfAttention - from megatron.core.transformer.attention import SelfAttention - from chatlearn.models.megatron_module import MegatronModule - -class MegatronMapper: - """MegatronMapper""" - def __init__( - self, - dst_model_config: PolicyConfig, - model: 'MegatronModule', - *, - mapper_config: Union[VLLM_HELPERS, HF_HELPERS] = VLLM_HELPERS, - ): - """The Mapper for Megatron sync. In each remote Megatron Actor, - the method of this class is called to generate the parameter mapping - between src and dst. Currently, the mapper supports mapping - MCore Model to vLLM or HF Model. - - WARNING: The mapper assumes that the weights name of same - submodules in different vLLM models are still same. - - Args: - dst_model_config (PolicyConfig): The config of target model to - be sychronized - model (MegatronModule): The source Megatron Module - mapper_config (Union[VLLM_HELPERS, HF_HELPERS]): The mapping mode. - """ - self.model: List['GPTModel'] = unwrap_model(model.model) - self._src_model_config: MegatronPolicyTrainerConfig = model.module_args - self._dst_model_config = dst_model_config - self._mapper_config = mapper_config - self._dst_tp_size = 1 if mapper_config.force_full_model else self._dst_model_config.tensor_model_parallel_size - self._src_name_to_metadata = model.get_parameter_metadata(key_type='local_name') - self._mapping = None - - def generate_sync_mapping( - self, - dst_name_to_metadata: Dict[str, Tuple[int, torch.dtype]] - ) -> Dict[ShardedTensorInfo, List[ShardedTensorInfo]]: - """ Generate the synchronization mapping of this local rank. - - Args: - dst_name_to_metadata (Dict[str, Tuple[int, torch.dtype]]): mapping a global - parameter name to its param_id and datatype - - Returns: - Dict[ShardedTensorInfo, List[ShardedTensorInfo]]: The return - dict is the plan including all local parameters to be synchronized. The - mapper will ensure that the key of mapping for each mapping type is - non-overlapping and can merge into the full state dict of this rank. - For most cases, the length of dst shards list is 1, except for GQA with - large vLLM TP. 
- """ - self._dst_name_to_metadata = dst_name_to_metadata - return self._map_model() - - def dump_sync_mapping(self, folder_path: str, sync_mapping: Dict): - """dump the generayed sync mapping to the given folder path in JSON format. - - Args: - folder_path (str): The folder path to dump the sync mapping. - sync_mapping (Dict): The sync mapping to be saved. - """ - raise NotImplementedError() - - def _map_vlm_model(self, model: nn.Module, vp_stage: int, layer_offset: int): - dst_language_prefix = self._mapper_config.dst_language_prefix - dst_vision_prefix = self._mapper_config.dst_vision_prefix - dst_lm_head_prefix = self._mapper_config.dst_lm_head_prefix - - if model.pre_process: - self._update_mapping(self._map_preprocess_layer( - model.language_model.embedding, - src_prefix=f"{vp_stage}-language_model.embedding.", - dst_prefix=f"{dst_language_prefix}", - )) - - self._update_mapping(self._inner_map_for_full_shape( - f"{vp_stage}-vision_model.patch_embed.proj.weight", - f"{dst_vision_prefix}patch_embed.proj.weight" - )) - - # vision model decoder - for layer_idx in range(model.vision_model.config.num_layers): - global_layer_id = layer_offset + layer_idx - self._update_mapping(self._map_vision_layer( - model.vision_model.decoder.layers[layer_idx], - src_prefix=f"{vp_stage}-vision_model.decoder.layers.{layer_idx}.", - dst_prefix=f"{dst_vision_prefix}blocks.{global_layer_id}.", - num_attention_heads=model.vision_model.config.num_attention_heads, - num_query_groups=model.vision_model.config.num_query_groups - )) - - # vision model projection - self._update_mapping(self._inner_map_for_full_shape( - f"{vp_stage}-vision_model.decoder.final_layernorm.weight", - f"{dst_vision_prefix}merger.ln_q.weight" - )) - - self._update_mapping(self._inner_map_for_tensor_parallel( - f"{vp_stage}-vision_model.projection.encoder.linear_fc1.weight", - f"{dst_vision_prefix}merger.mlp.0.weight", - mapping_type='column' - )) - - self._update_mapping(self._inner_map_for_tensor_parallel( - f"{vp_stage}-vision_model.projection.encoder.linear_fc1.bias", - f"{dst_vision_prefix}merger.mlp.0.bias", - mapping_type='column' - )) - - self._update_mapping(self._inner_map_for_tensor_parallel( - f"{vp_stage}-vision_model.projection.encoder.linear_fc2.weight", - f"{dst_vision_prefix}merger.mlp.2.weight", - mapping_type='row' - )) - - # bias for row is not slice, so we need to map it to full shape - self._update_mapping(self._inner_map_for_full_shape( - f"{vp_stage}-vision_model.projection.encoder.linear_fc2.bias", - f"{dst_vision_prefix}merger.mlp.2.bias" - )) - - for layer_idx in range(model.language_model.decoder.num_layers_per_pipeline_rank): - global_layer_id = layer_offset + layer_idx - self._update_mapping(self._map_decoder_layer( - model.language_model.decoder.layers[layer_idx], - src_prefix=f"{vp_stage}-language_model.decoder.layers.{layer_idx}.", - dst_prefix=f"{dst_language_prefix}layers.{global_layer_id}.", - )) - - if model.post_process: - self._update_mapping(self._map_norm_layer( - model.language_model.decoder.final_layernorm, - src_prefix=f"{vp_stage}-language_model.decoder.final_layernorm.", - dst_prefix=f"{dst_language_prefix}norm.", - )) - - if model.share_embeddings_and_output_weights and model.pre_process: - self._update_mapping(self._map_postprocess_layer( - model.language_model.embedding, - src_prefix=f"{vp_stage}-language_model.embedding.word_embeddings.", - dst_prefix=f"{dst_lm_head_prefix}", - )) - else: - self._update_mapping(self._map_postprocess_layer( - model.language_model.output_layer, - 
src_prefix=f"{vp_stage}-language_model.output_layer.", - dst_prefix=f"{dst_lm_head_prefix}", - )) - - def _map_llm_model(self, model: nn.Module, vp_stage: int, layer_offset: int): - if model.pre_process: - self._update_mapping(self._map_preprocess_layer( - model.embedding, - src_prefix=f"{vp_stage}-embedding.", - dst_prefix="model.", - )) - - for layer_idx in range(model.decoder.num_layers_per_pipeline_rank): - global_layer_id = layer_offset + layer_idx - self._update_mapping(self._map_decoder_layer( - model.decoder.layers[layer_idx], - src_prefix=f"{vp_stage}-decoder.layers.{layer_idx}.", - dst_prefix=f"model.layers.{global_layer_id}.", - )) - - if model.post_process: - self._update_mapping(self._map_norm_layer( - model.decoder.final_layernorm, - src_prefix=f"{vp_stage}-decoder.final_layernorm.", - dst_prefix="model.norm.", - )) - - if model.share_embeddings_and_output_weights and model.pre_process: - self._update_mapping(self._map_postprocess_layer( - model.embedding, - src_prefix=f"{vp_stage}-embedding.word_embeddings.", - dst_prefix="", - )) - else: - self._update_mapping(self._map_postprocess_layer( - model.output_layer, - src_prefix=f"{vp_stage}-output_layer.", - dst_prefix="", - )) - # NOTE: the following function implements the module-wise sync mapping - def _map_model(self): - """Mapping the local name of src model to global name of - dst model - """ - for vp_stage, model in enumerate(self.model): - if 'vp_stage' in inspect.signature(get_transformer_layer_offset).parameters: - layer_offset = get_transformer_layer_offset(model.config, vp_stage=vp_stage) - else: - if len(self.model) > 1: - mpu.set_virtual_pipeline_model_parallel_rank(vp_stage) - layer_offset = get_transformer_layer_offset(model.config) - if len(self.model) > 1: - mpu.set_virtual_pipeline_model_parallel_rank(None) - - if hasattr(model, 'vision_model'): - model.mtp_process = False - - if model.mtp_process: - raise NotImplementedError("Currently, the mapper does not support MTP") - - if hasattr(model, 'vision_model'): - self._map_vlm_model(model, vp_stage=vp_stage, layer_offset=layer_offset) - else: - # llm model - self._map_llm_model(model, vp_stage=vp_stage, layer_offset=layer_offset) - - mapping = self._mapping - self._mapping = None - - return mapping - - def _map_norm_layer(self, module: nn.Module, src_prefix: str='', dst_prefix: str='', *, is_norm_layer: bool=True): - """If is_norm_layer is True, try to map on all possible keys, - otherwise only map on `layer_norm_weight` and `layer_norm_bias` - """ - mapping = {} - _keynames = { - 'weight': 'weight', - 'bias': 'bias', - 'layer_norm_weight': 'weight', - 'layer_norm_bias': 'bias' - } - possible_keys = ['layer_norm_weight', 'layer_norm_bias'] - if is_norm_layer: - possible_keys += ['weight', 'bias'] - for item in possible_keys: - if getattr(module, item, None) is None or getattr(module, item).numel() == 0: - continue - self._update_mapping(self._inner_map_for_full_shape( - f"{src_prefix}{item}", - f"{dst_prefix}{_keynames[item]}" - )) - return mapping - - def _map_decoder_layer(self, module: 'TransformerLayer', src_prefix: str='', dst_prefix: str=''): - mapping = {} - if self._src_arch.multi_latent_attention: - map_attn_func = self._map_mla_selfattn - norm_layer = module.input_layernorm - norm_src_key = f"{src_prefix}input_layernorm." - norm_dst_key = f"{dst_prefix}input_layernorm." - is_norm_layer = True - else: - map_attn_func = self._map_selfattn - norm_layer = module.self_attention.linear_qkv - norm_src_key = f"{src_prefix}self_attention.linear_qkv." 
- norm_dst_key = f"{dst_prefix}input_layernorm." - is_norm_layer = False - self._update_mapping(map_attn_func(module.self_attention, src_prefix=f"{src_prefix}self_attention.", dst_prefix=f"{dst_prefix}self_attn.")) - self._update_mapping(self._map_norm_layer(norm_layer, norm_src_key, norm_dst_key, is_norm_layer=is_norm_layer)) - - if isinstance(module.mlp, MoELayer): - map_mlp_func = self._map_moe_layer - norm_layer = module.pre_mlp_layernorm - norm_src_key = f"{src_prefix}pre_mlp_layernorm." - norm_dst_key = f"{dst_prefix}post_attention_layernorm." - is_norm_layer = True - else: - map_mlp_func = self._map_mlp - norm_layer = module.mlp.linear_fc1 - norm_src_key = f"{src_prefix}mlp.linear_fc1." - norm_dst_key = f"{dst_prefix}post_attention_layernorm." - is_norm_layer = False - self._update_mapping(map_mlp_func(module.mlp, src_prefix=f"{src_prefix}mlp.", dst_prefix=f"{dst_prefix}mlp.")) - self._update_mapping(self._map_norm_layer(norm_layer, norm_src_key, norm_dst_key, is_norm_layer=is_norm_layer)) - return mapping - - def _map_vision_layer( - self, - module: 'TransformerLayer', - src_prefix: str = '', - dst_prefix: str = '', - num_attention_heads: int = None, - num_query_groups: int = None - ): - mapping = {} - - # module.self_attention - # linear_proj - self._update_mapping(self._inner_map_for_tensor_parallel( - f"{src_prefix}self_attention.linear_proj.weight", - f"{dst_prefix}attn.proj.weight", - mapping_type='row' - )) - - # bias for row is not slice, so we need to map it to full shape - self._update_mapping(self._inner_map_for_full_shape( - f"{src_prefix}self_attention.linear_proj.bias", - f"{dst_prefix}attn.proj.bias" - )) - - # linear_qkv - self._update_mapping(self._inner_map_for_qkv_proj( - f"{src_prefix}self_attention.linear_qkv.weight", - f"{dst_prefix}attn.qkv.weight", - proj_type='qkv_proj', - num_attention_heads=num_attention_heads, - num_query_groups=num_query_groups - )) - if self._src_arch.add_qkv_bias: - self._update_mapping(self._inner_map_for_qkv_proj( - f"{src_prefix}self_attention.linear_qkv.bias", - f"{dst_prefix}attn.qkv.bias", - proj_type='qkv_proj', - num_attention_heads=num_attention_heads, - num_query_groups=num_query_groups - )) - - # linear_qkv_norm - self._update_mapping(self._inner_map_for_full_shape( - f"{src_prefix}self_attention.linear_qkv.layer_norm_weight", - f"{dst_prefix}norm1.weight" - )) - - # module.mlp - self._update_mapping(self._map_mlp(module.mlp, src_prefix=f"{src_prefix}mlp.", dst_prefix=f"{dst_prefix}mlp.", is_vision_block=True)) - - # mlp norm - self._update_mapping(self._inner_map_for_full_shape( - f"{src_prefix}mlp.linear_fc1.layer_norm_weight", - f"{dst_prefix}norm2.weight" - )) - return mapping - - def _map_moe_layer(self, module: 'MoELayer', src_prefix='', dst_prefix=''): - mapping = {} - # router - self._update_mapping(self._inner_map_for_full_shape(f"{src_prefix}router.weight", f"{dst_prefix}gate.weight")) - if module.router.enable_expert_bias: - self._update_mapping(self._inner_map_for_full_shape(f"{src_prefix}router.expert_bias", f"{dst_prefix}gate.e_score_correction_bias")) - - if not module.config.moe_grouped_gemm: - raise NotImplementedError("Parameter Sync w/ MoE SequentialMLP is not supported") - if not isinstance(module.experts, TEGroupedMLP): - raise NotImplementedError("Parameter Sync w/ Legacy GroupedMLP is not supported") - - # experts - self._update_mapping(self._map_group_mlp( - module.experts, - src_prefix=f"{src_prefix}experts.", - dst_prefix=f"{dst_prefix}experts." 
- )) - - # shared experts - if module.shared_experts is not None: - if module.shared_experts.use_shared_expert_gate: - self._update_mapping( - self._inner_map_for_full_shape( - f"{src_prefix}shared_experts.gate_weight", - f"{dst_prefix}shared_expert_gate.weight" - ) - ) - # NOTE: if transformer.config have n_shared_experts, mapping to `shared_experts`, otherwise `shared_expert` - # `shared_experts`: DeepSeek-V2, DeepSeek-V3, etc. - # `shared_expert`: Qwen2-MoE, LLaMA-4, etc. - hf_config = AutoConfig.from_pretrained(self._dst_model_config.load, trust_remote_code=self._dst_model_config.trust_remote_code) - shared_expert_key = 'shared_experts' if hasattr(hf_config, 'n_shared_experts') else 'shared_expert' - self._update_mapping(self._map_mlp( - module.shared_experts, - src_prefix=f"{src_prefix}shared_experts.", - dst_prefix=f"{dst_prefix}{shared_expert_key}." - )) - return mapping - - def _map_mlp(self, module: 'MLP', src_prefix: str='', dst_prefix: str='', is_vision_block=False): - mapping = {} - if not module.config.gated_linear_unit: - raise NotImplementedError("Parameter Sync w/o GatedLinear is not supported") - - dst_names = ['gate_proj', 'up_proj'] - if self._mapper_config.merge_gate_up and not is_vision_block: - dst_names = ['gate_up_proj'] - - for dst_name in dst_names: - self._update_mapping(self._inner_map_for_gate_up_proj( - f"{src_prefix}linear_fc1.weight", - f"{dst_prefix}{dst_name}.weight", - proj_type=dst_name - )) - - if module.config.add_bias_linear: - self._update_mapping(self._inner_map_for_gate_up_proj( - f"{src_prefix}linear_fc1.bias", - f"{dst_prefix}{dst_name}.bias", - proj_type=dst_name - )) - - self._update_mapping(self._inner_map_for_tensor_parallel( - f"{src_prefix}linear_fc2.weight", - f"{dst_prefix}down_proj.weight", - mapping_type='row' - )) - - if module.config.add_bias_linear: - self._update_mapping(self._inner_map_for_full_shape( - f"{src_prefix}linear_fc2.bias", - f"{dst_prefix}down_proj.bias" - )) - return mapping - - def _map_group_mlp(self, module: 'TEGroupedMLP', src_prefix: str='', dst_prefix: str=''): - # pylint: disable=unused-argument - src_ep_rank = mpu.get_expert_model_parallel_rank() - src_ep_size = mpu.get_expert_model_parallel_world_size() - num_experts = self._src_arch.num_experts - global_expert_id_start = num_experts // src_ep_size * src_ep_rank - global_expert_id_end = num_experts // src_ep_size * (src_ep_rank + 1) - mapping = {} - for local_expert_id, global_expert_id in enumerate(range(global_expert_id_start, global_expert_id_end)): - if self._mapper_config.merge_expert: - if not self._mapper_config.merge_gate_up: - raise NotImplementedError("merge_expert w/o merge_gate_up is not implemented.") - self._update_mapping(self._inner_map_for_gate_up_proj( - f"{src_prefix}linear_fc1.weight{local_expert_id}", - f"{dst_prefix}w13_weight", - proj_type='gate_up_proj', - global_expert_id=global_expert_id, - num_experts=num_experts - )) - self._update_mapping(self._inner_map_for_tensor_parallel( - f"{src_prefix}linear_fc2.weight{local_expert_id}", - f"{dst_prefix}w2_weight", - global_expert_id=global_expert_id, - num_experts=num_experts, - mapping_type='row' - )) - else: - if self._mapper_config.merge_gate_up: - raise NotImplementedError("no merge_expert w/ merge_gate_up is not implemented.") - for dst_name in ['gate_proj', 'up_proj']: - self._update_mapping(self._inner_map_for_gate_up_proj( - f"{src_prefix}linear_fc1.weight{local_expert_id}", - f"{dst_prefix}{global_expert_id}.{dst_name}.weight", - proj_type=dst_name, - )) - 
self._update_mapping(self._inner_map_for_tensor_parallel( - f"{src_prefix}linear_fc2.weight{local_expert_id}", - f"{dst_prefix}{global_expert_id}.down_proj.weight", - mapping_type='row' - )) - return mapping - - def _map_mla_selfattn(self, module: 'MLASelfAttention', src_prefix: str='', dst_prefix: str=''): - mapping = {} - if self._src_arch.q_lora_rank is None: - self._update_mapping(self._inner_map_for_tensor_parallel( - f"{src_prefix}linear_q_proj.weight", - f"{dst_prefix}q_proj.weight", - mapping_type='column' - )) - else: - self._update_mapping(self._inner_map_for_mla_down_proj( - f"{src_prefix}linear_q_down_proj.weight", - f"{dst_prefix}q_a_proj.weight", - )) - self._update_mapping(self._inner_map_for_tensor_parallel( - f"{src_prefix}linear_q_up_proj.weight", - f"{dst_prefix}q_b_proj.weight", - mapping_type='column' - )) - if self._src_arch.qk_layernorm: - self._update_mapping( - self._map_norm_layer( - module.linear_q_up_proj, - f"{src_prefix}linear_q_up_proj.", - f"{dst_prefix}q_a_layernorm.", - is_norm_layer=False - ) - ) - self._update_mapping(self._inner_map_for_mla_down_proj( - f"{src_prefix}linear_kv_down_proj.weight", - f"{dst_prefix}kv_a_proj_with_mqa.weight", - )) - self._update_mapping(self._inner_map_for_tensor_parallel( - f"{src_prefix}linear_kv_up_proj.weight", - f"{dst_prefix}kv_b_proj.weight", - mapping_type='column' - )) - if self._src_arch.qk_layernorm: - self._update_mapping( - self._map_norm_layer( - module.linear_kv_up_proj, - f"{src_prefix}linear_kv_up_proj.", - f"{dst_prefix}kv_a_layernorm.", - is_norm_layer=False - ) - ) - self._update_mapping(self._inner_map_for_tensor_parallel( - f"{src_prefix}linear_proj.weight", - f"{dst_prefix}o_proj.weight", - mapping_type='row' - )) - return mapping - - def _map_selfattn(self, module: 'SelfAttention', src_prefix: str='', dst_prefix: str=''): - mapping = {} - if self._src_arch.qk_layernorm: - self._update_mapping(self._map_norm_layer(module.q_layernorm, f"{src_prefix}q_layernorm.", f"{dst_prefix}q_norm.")) - self._update_mapping(self._map_norm_layer(module.k_layernorm, f"{src_prefix}k_layernorm.", f"{dst_prefix}k_norm.")) - - dst_names = ['q_proj', 'k_proj', 'v_proj'] - if self._mapper_config.merge_qkv: - dst_names = ['qkv_proj'] - - for dst_name in dst_names: - self._update_mapping(self._inner_map_for_qkv_proj( - f"{src_prefix}linear_qkv.weight", - f"{dst_prefix}{dst_name}.weight", - proj_type=dst_name, - num_attention_heads = self._src_arch.num_attention_heads, - num_query_groups = self._src_arch.num_query_groups - )) - if self._src_arch.add_qkv_bias: - self._update_mapping(self._inner_map_for_qkv_proj( - f"{src_prefix}linear_qkv.bias", - f"{dst_prefix}{dst_name}.bias", - proj_type=dst_name, - num_attention_heads = self._src_arch.num_attention_heads, - num_query_groups = self._src_arch.num_query_groups - )) - - self._update_mapping(self._inner_map_for_tensor_parallel( - f"{src_prefix}linear_proj.weight", - f"{dst_prefix}o_proj.weight", - mapping_type='row' - )) - return mapping - - def _map_preprocess_layer(self, module: 'LanguageModelEmbedding', src_prefix='', dst_prefix=''): - if module.add_position_embedding: - raise NotImplementedError("learned_absolute embedding is not supported") - return self._inner_map_for_tensor_parallel( - f"{src_prefix}word_embeddings.weight", - f"{dst_prefix}embed_tokens.weight", - mapping_type='column' - ) - - def _map_postprocess_layer(self, module: 'ColumnParallelLinear', src_prefix='', dst_prefix=''): - # pylint: disable=unused-argument - if ( - not 
self._src_arch.untie_embeddings_and_output_weights and - f"{dst_prefix}lm_head.weight" not in self._dst_name_to_metadata - ): - return {} - return self._inner_map_for_tensor_parallel( - f"{src_prefix}weight", - f"{dst_prefix}lm_head.weight", - mapping_type='column' - ) - - # NOTE: the following function implements the tensor-wise sync mapping - def _inner_map_for_tensor_parallel( - self, - src_key: str, - dst_key: str, - *, - global_expert_id: int=None, - num_experts: int=None, - mapping_type: str='column' - ): - AXES = {'column': 0, 'row': 1} - src_info = self._src_name_to_metadata[src_key] - dst_info = self._dst_name_to_metadata[dst_key] - mapping = {} - for src_meta, dst_meta in process_normal_tensor( - src_info, - self._dst_tp_size, - axis=AXES[mapping_type] - ): - src_meta.param_id, dst_meta.param_id = src_info.param_id, dst_info.param_id - src_meta.dtype, dst_meta.dtype = src_info.dtype, dst_info.dtype - if global_expert_id is not None: - dst_meta = ( - dst_meta - .unsqueeze(offset=global_expert_id, length=num_experts, axis=0) - .refragment(1, axis=0) # 1 is dst EP - ) - mapping[src_meta] = [dst_meta] - return mapping - - def _inner_map_for_full_shape( - self, - src_key: str, - dst_key: str - ): - src_info = self._src_name_to_metadata[src_key] - dst_info = self._dst_name_to_metadata[dst_key] - return { - src_info.copy(): [dst_info.copy()] - } - - def _inner_map_for_gate_up_proj(self, src_key: str, dst_key: str, proj_type: str, *, global_expert_id: int=None, num_experts: int=None): - src_info = self._src_name_to_metadata[src_key] - dst_info = self._dst_name_to_metadata[dst_key] - mapping = {} - for src_meta, dst_meta in process_gate_up_tensor( - src_info, - self._dst_tp_size, - proj_type=proj_type - ): - src_meta.param_id, dst_meta.param_id = src_info.param_id, dst_info.param_id - src_meta.dtype, dst_meta.dtype = src_info.dtype, dst_info.dtype - if global_expert_id is not None: - dst_meta = ( - dst_meta - .unsqueeze(offset=global_expert_id, length=num_experts, axis=0) - .refragment(1, axis=0) # 1 is dst EP - ) - mapping[src_meta] = [dst_meta] - return mapping - - def _inner_map_for_qkv_proj(self, src_key: str, dst_key: str, proj_type: str, num_attention_heads: int, num_query_groups: int): - src_info = self._src_name_to_metadata[src_key] - dst_info = self._dst_name_to_metadata[dst_key] - mapping = defaultdict(list) - for src_meta, dst_meta in process_qkv_tensor( - src_info, - num_attention_heads, - num_query_groups, - self._dst_tp_size, - proj_type=proj_type - ): - src_meta.param_id, dst_meta.param_id = src_info.param_id, dst_info.param_id - src_meta.dtype, dst_meta.dtype = src_info.dtype, dst_info.dtype - mapping[src_meta].append(dst_meta) - return mapping - - def _inner_map_for_mla_down_proj(self, src_key: str, dst_key: str): - src_info = self._src_name_to_metadata[src_key] - dst_info = self._dst_name_to_metadata[dst_key] - dst_meta = src_info.refragment(1) - dst_meta.param_id = dst_info.param_id - dst_meta.dtype = dst_info.dtype - return { - src_info.copy(): [dst_meta] - } - - @property - def _src_arch(self): - return self._src_model_config.megatron_model_cfg - - def _update_mapping(self, results: Dict[ShardedTensorInfo, List[ShardedTensorInfo]]): - if self._mapping is None: - self._mapping = defaultdict(list) - for src_meta, dst_metas in results.items(): - self._mapping[src_meta] += dst_metas - return self._mapping diff --git a/chatlearn/synchronizer/mappers/mapping_helpers.py b/chatlearn/synchronizer/mappers/mapping_helpers.py index c8244e1c..b0ba385a 100644 --- 
a/chatlearn/synchronizer/mappers/mapping_helpers.py +++ b/chatlearn/synchronizer/mappers/mapping_helpers.py @@ -19,6 +19,7 @@ from itertools import chain from chatlearn.utils.mappings import ShardedTensorInfo +from chatlearn.utils.utils import slice_data_list_by_index def process_normal_tensor( sharded_info: ShardedTensorInfo, @@ -48,25 +49,6 @@ def process_normal_tensor( ) for tensor_part_info in sharded_info.fragment(dst_tp_size, axis) ] -def _build_gate_up_layout(src_tp_size: int, dst_tp_size: int): - """ - build layout with 2 * lcm(src_tp, dst_tp) chunks - """ - n_chunks = math.lcm(src_tp_size, dst_tp_size) - flatten = lambda x: list(chain.from_iterable(x)) # pylint: disable=unnecessary-lambda-assignment - mcore_layout = flatten([ - [ f"g{c_id + tp_rank * (n_chunks // src_tp_size)}" for c_id in range(n_chunks // src_tp_size) ] + - [ f"u{c_id + tp_rank * (n_chunks // src_tp_size)}" for c_id in range(n_chunks // src_tp_size) ] - for tp_rank in range(src_tp_size) - ]) - - vllm_layout = flatten([ - [ f"g{c_id + tp_rank * (n_chunks // dst_tp_size)}" for c_id in range(n_chunks // dst_tp_size) ] + - [ f"u{c_id + tp_rank * (n_chunks // dst_tp_size)}" for c_id in range(n_chunks // dst_tp_size) ] - for tp_rank in range(dst_tp_size) - ]) - - return mcore_layout, vllm_layout def process_gate_up_tensor( sharded_info: ShardedTensorInfo, @@ -83,74 +65,167 @@ def process_gate_up_tensor( Returns: List[Tuple[ShardedTensorInfo, ...]]: The layout mapping. """ - src_tp_size = sharded_info.axis_fragmentations[0] - - mcore_layout, vllm_layout = _build_gate_up_layout(src_tp_size, dst_tp_size) - mcore_id_to_frags = { - part.global_offset[0]: part.refragment(src_tp_size) - for part in sharded_info.fragment(math.lcm(src_tp_size, dst_tp_size) * 2) - } + gate_up = sharded_info.global_shape[0] if proj_type == 'gate_up_proj': - n_chunks = math.lcm(src_tp_size, dst_tp_size) * 2 - full_dst_info = ShardedTensorInfo.from_global_shape(sharded_info.global_shape) + layout = ['gate', 'up'] + elif proj_type == 'up_proj': + layout = ['up'] else: - n_chunks = math.lcm(src_tp_size, dst_tp_size) - full_dst_info = ShardedTensorInfo.from_global_shape( - (sharded_info.global_shape[0] // 2, ) + sharded_info.global_shape[1:] + layout = ['gate'] + return process_merged_linear_tensor( + sharded_info, + dst_tp_size, + src_layout=[('gate', gate_up // 2), ('up', gate_up // 2)], + required_layout=layout + ) + + +def _build_merged_linear_layout( + layout: List[Tuple[str, int]], + n_chunks: int, + tp_size: int, +) -> List[Tuple[str, int, int]]: + flatten = lambda x: list(chain.from_iterable(x)) # pylint: disable=unnecessary-lambda-assignment + mcore_layout = flatten([ + flatten([ + [ (key, c_id + tp_rank * (n_chunks // tp_size), size // n_chunks) for c_id in range(n_chunks // tp_size) ] + for key, size in layout + ]) + for tp_rank in range(tp_size) + ]) + return mcore_layout + +def process_merged_linear_tensor( + sharded_info: ShardedTensorInfo, + dst_tp_size: int, + src_layout: List[Tuple[str, int]], + required_layout: List[str], + axis: int = 0 +) -> List[Tuple[ShardedTensorInfo, ...]]: + """ + A generalized implementation to resolve mapping on a column-merged linear + """ + src_tp_rank = sharded_info.global_offset[axis] + src_tp_size = sharded_info.axis_fragmentations[axis] + n_chunks = math.lcm(src_tp_size, dst_tp_size) + keyname_to_size = {item[0] : item[1] for item in src_layout} + + src_names = [item[0] for item in src_layout] + if not set(required_layout).issubset(set(src_names)): + raise ValueError(f"Expect all keys of the required 
layout is the subset of source layout {src_names}, but {required_layout}") + + mcore_layout = slice_data_list_by_index(_build_merged_linear_layout( + src_layout, + n_chunks, + src_tp_size + ), (src_tp_rank, src_tp_size)) + + id_to_frags = { + (item[0], item[1]): part + for item, part in zip( + mcore_layout, + sharded_info.chunk(sections=[item[2] for item in mcore_layout], axis=axis) ) + } + full_dst_size = sum(keyname_to_size[name] for name in required_layout) + full_dst_info = ShardedTensorInfo.from_global_shape( + (full_dst_size, ) + sharded_info.global_shape[1:] + ) + + vllm_layout = _build_merged_linear_layout( + [(name, keyname_to_size[name]) for name in required_layout], + n_chunks, + dst_tp_size + ) results = [] - for chunk_idx, dst_part in enumerate(full_dst_info.fragment(n_chunks)): - if proj_type == 'gate_up_proj': - chunk_name = vllm_layout[chunk_idx] - else: - chunk_name = f"{proj_type[:1]}{chunk_idx}" - mcore_idx = mcore_layout.index(chunk_name) - if mcore_idx not in mcore_id_to_frags: + for (name, chunk_id, _), dst_part in zip( + vllm_layout, + full_dst_info.chunk(sections=[item[2] for item in vllm_layout], axis=axis) + ): + if (name, chunk_id) not in id_to_frags: continue results.append(( - mcore_id_to_frags[mcore_idx], - dst_part.refragment(dst_tp_size) + id_to_frags[(name, chunk_id)], + dst_part.refragment(dst_tp_size, axis=axis) )) return __maybe_merge(results) +def process_linear_attn_tensor( + sharded_info: ShardedTensorInfo, + dst_tp_size: int, + n_groups: int, + src_layout: List[Tuple[str, int]], + required_layout: List[str], + axis: int = 0 +) -> List[Tuple[ShardedTensorInfo, ...]]: + if n_groups % dst_tp_size != 0: + raise ValueError("n_groups of linear attn should be divided by tp!") + results = process_merged_linear_tensor( + sharded_info=sharded_info, + dst_tp_size=n_groups, + src_layout=src_layout, + required_layout=required_layout, + axis=axis + ) + return [(item[0], item[1].refragment(dst_tp_size, axis=axis)) for item in results] + def _build_qkv_layout( num_heads: int, num_query_group: int, - dst_tp_size: int + dst_tp_size: int, + is_gated_attention: bool = False ): """Generate a mapping between mcore qkv heads (mix-style qkv) and vllm qkv heads (no mix-style qkv). + is_gated_attention=False: Mcore layout of first dim per tp rank when nh=24, ng=8, tp=4, nq=3: [q q q k v q q q k v], while vLLM: [q q q q q q k k v v] + is_gated_attention=True: + Mcore layout of first dim per tp rank when + nh=48, ng=8, tp=4, nq=3: [q q q g g g k v q q q g g g k v], + while vLLM: [q g q g q g q g q g q g k k v v] + Args: - num_heads (int): The num of attention heads + num_heads (int): The num of attention heads. 
If is_gated_attention is True, the number should be
+            the total number of query heads and gate heads
         num_query_group (int): The num of query groups
         dst_tp_size (int): The dst tensor parallel size
+        is_gated_attention (bool, optional): whether each query head has a corresponding gate head
     """
     flatten = lambda x: list(chain.from_iterable(x)) # pylint: disable=unnecessary-lambda-assignment
+    if is_gated_attention:
+        num_heads //= 2
     nq = num_heads // num_query_group
     mcore_layout = []
     vllm_layout = []
     mcore_layout = flatten([
-        [ f"q{g_id * nq + q_id}" for q_id in range(nq)] + [f"k{g_id}", f"v{g_id}"]
+        [ f"q{g_id * nq + q_id}" for q_id in range(nq)] +
+        ([ f"g{g_id * nq + q_id}" for q_id in range(nq)] if is_gated_attention else []) +
+        [f"k{g_id}", f"v{g_id}"]
         for g_id in range(num_query_group)
     ])
     vllm_nq = num_heads // dst_tp_size
     if dst_tp_size < num_query_group:
         vllm_layout = flatten([
-            [f"q{r_id * vllm_nq + q_id}" for q_id in range(num_heads // dst_tp_size)] +
+            flatten([
+                (f"q{r_id * vllm_nq + q_id}", f"g{r_id * vllm_nq + q_id}") if is_gated_attention else (f"q{r_id * vllm_nq + q_id}", )
+                for q_id in range(num_heads // dst_tp_size)
+            ]) +
             [f"k{g_id + r_id * (num_query_group // dst_tp_size)}" for g_id in range(num_query_group // dst_tp_size)] +
             [f"v{g_id + r_id * (num_query_group // dst_tp_size)}" for g_id in range(num_query_group // dst_tp_size)]
             for r_id in range(dst_tp_size)
         ])
     else:
         vllm_layout = flatten([
-            [f"q{r_id * vllm_nq + q_id}" for q_id in range(num_heads // dst_tp_size)] +
+            flatten([
+                (f"q{r_id * vllm_nq + q_id}", f"g{r_id * vllm_nq + q_id}") if is_gated_attention else (f"q{r_id * vllm_nq + q_id}",)
+                for q_id in range(num_heads // dst_tp_size)
+            ]) +
             [f"k{r_id * num_query_group // dst_tp_size}", f"v{r_id * num_query_group // dst_tp_size}"]
             for r_id in range(dst_tp_size)
         ])
@@ -162,7 +237,8 @@
 def process_qkv_tensor(
     num_heads: int,
     num_query_groups: Optional[int],
     dst_tp_size: int,
-    proj_type: Literal['qkv_proj', 'q_proj', 'k_proj', 'v_proj']
+    proj_type: Literal['qkv_proj', 'q_proj', 'k_proj', 'v_proj'],
+    is_gated_attention: bool = False
 ) -> List[Tuple[ShardedTensorInfo, ...]]:
     """Process qkv weight/bias to generate shard mapping.
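Editor's note: the head-layout remapping in this hunk is easiest to see on a small example. The standalone Python sketch below is illustrative only; it is not part of the patch and not ChatLearn API, it assumes dst_tp divides num_query_groups, and it ignores the gated-attention and dst_tp > num_query_groups cases that _build_qkv_layout also handles.

from itertools import chain

def toy_qkv_layouts(nh: int, ng: int, tp: int):
    """Toy reconstruction of the q/k/v head orderings described in the docstring above."""
    flatten = lambda x: list(chain.from_iterable(x))
    nq = nh // ng
    # Megatron packs heads per query group: [q ... q k v] repeated for each group.
    mcore = flatten([[f"q{g * nq + i}" for i in range(nq)] + [f"k{g}", f"v{g}"] for g in range(ng)])
    # The rollout engine expects per-TP-rank blocks of [q ... q k ... k v ... v].
    q_per_rank, g_per_rank = nh // tp, ng // tp
    dst = flatten([
        [f"q{r * q_per_rank + i}" for i in range(q_per_rank)]
        + [f"k{r * g_per_rank + g}" for g in range(g_per_rank)]
        + [f"v{r * g_per_rank + g}" for g in range(g_per_rank)]
        for r in range(tp)
    ])
    return mcore, dst

# toy_qkv_layouts(24, 8, 4) reproduces the documented example: each Megatron group is
# laid out as [q q q k v], while each rollout rank holds [q q q q q q k k v v].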
@@ -172,16 +248,19 @@ def process_qkv_tensor( num_query_group (int): The number of query groups dst_tp_size (int): The target tensor parallel size proj_type (Literal['qkv_proj', 'q_proj', 'k_proj', 'v_proj']): the projection type + is_gated_attention (bool, optional): whether query heads have corresponding gate head """ if num_query_groups is None: num_query_groups = num_heads + if is_gated_attention: + num_heads *= 2 if num_query_groups % dst_tp_size != 0 and dst_tp_size % num_query_groups != 0: raise ValueError(f"num_query_groups {num_query_groups} must be divisible or multiple by dst_tp_size {dst_tp_size}") head_dim = sharded_info.global_shape[0] // (num_heads + 2 * num_query_groups) src_tp_size = sharded_info.axis_fragmentations[0] src_global_shape = sharded_info.global_shape - mcore_layout, vllm_layout = _build_qkv_layout(num_heads, num_query_groups, dst_tp_size) + mcore_layout, vllm_layout = _build_qkv_layout(num_heads, num_query_groups, dst_tp_size, is_gated_attention=is_gated_attention) mcore_id_to_frags = { part.global_offset[0]: part.refragment(src_tp_size) for part in sharded_info.fragment(num_query_groups * (2 + num_heads // num_query_groups)) @@ -191,16 +270,14 @@ def process_qkv_tensor( n_heads = num_heads + 2 * num_query_groups * max(1, dst_tp_size // num_query_groups) elif proj_type == 'q_proj': n_heads = num_heads + vllm_layout = [item for item in vllm_layout if 'q' in item or 'g' in item] else: n_heads = num_query_groups * max(1, dst_tp_size // num_query_groups) + vllm_layout = [item for item in vllm_layout if proj_type[:1] in item] full_dst_info = ShardedTensorInfo.from_global_shape((n_heads * head_dim, ) + src_global_shape[1:]) results = [] - for head_idx, dst_part in enumerate(full_dst_info.fragment(n_heads)): - if proj_type == 'qkv_proj': - head_name = vllm_layout[head_idx] - else: - head_name = f"{proj_type[:1]}{head_idx}" # q0 / k1 / v2, etc. + for head_name, dst_part in zip(vllm_layout, full_dst_info.fragment(n_heads)): mcore_idx = mcore_layout.index(head_name) if mcore_idx not in mcore_id_to_frags: continue @@ -241,6 +318,7 @@ def __maybe_merge(mappings: List[Tuple[ShardedTensorInfo, ShardedTensorInfo]], a )) return results +# TODO: deprecate these config classes @dataclass(frozen=True) class VLLM_HELPERS: """The mapper configs for vllm""" diff --git a/chatlearn/synchronizer/mappers/megatron_llm_mapper.py b/chatlearn/synchronizer/mappers/megatron_llm_mapper.py new file mode 100644 index 00000000..475dcad3 --- /dev/null +++ b/chatlearn/synchronizer/mappers/megatron_llm_mapper.py @@ -0,0 +1,573 @@ +# Copyright 2025 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+
+"""LLM Mapper for Megatron to rollout framework"""
+from typing import TYPE_CHECKING, Union, Dict
+
+from torch import nn
+from transformers import AutoConfig
+
+from megatron.core import mpu
+from megatron.core.transformer.transformer_layer import get_transformer_layer_offset
+from megatron.core.transformer.moe.moe_layer import MoELayer
+from megatron.core.transformer.moe.experts import TEGroupedMLP
+from megatron.core.transformer.identity_op import IdentityOp
+from megatron.core.transformer.transformer_block import TransformerBlock
+from megatron.core.ssm.mamba_block import MambaStack
+from megatron.core.ssm.mamba_layer import MambaLayer
+
+from chatlearn.configs import PolicyConfig
+
+from .mapping_helpers import (
+    VLLM_HELPERS,
+    HF_HELPERS
+)
+from .metadata import (
+    SelfAttnKeyMapping,
+    MLPKeyMapping,
+    DecoderLayerKeyMapping,
+    LanguageModelKeyMapping,
+    MoELayerKeyMapping,
+    MLASelfAttnKeyMapping
+)
+from .base_megatron_mapper import BaseMegatronMapper
+
+if TYPE_CHECKING:
+    from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
+    from megatron.core.transformer.transformer_layer import TransformerLayer
+    from megatron.core.tensor_parallel import ColumnParallelLinear
+    from megatron.core.transformer.mlp import MLP
+    from megatron.core.transformer.multi_latent_attention import MLASelfAttention
+    from megatron.core.transformer.attention import SelfAttention
+    from chatlearn.models.megatron_module import MegatronModule
+
+class MegatronLLMMapper(BaseMegatronMapper):
+    """MegatronLLMMapper"""
+    def __init__(
+        self,
+        dst_model_config: PolicyConfig,
+        model: 'MegatronModule',
+        *,
+        mapper_config: Union[VLLM_HELPERS, HF_HELPERS] = VLLM_HELPERS,
+    ):
+        """The Mapper for Megatron LLM sync.
+
+        Args:
+            dst_model_config (PolicyConfig): The config of the target model to
+                be synchronized
+            model (MegatronModule): The source Megatron Module
+            mapper_config (Union[VLLM_HELPERS, HF_HELPERS]): The mapping mode.
+ """ + super().__init__(dst_model_config=dst_model_config, model=model, mapper_config=mapper_config) + + # NOTE: the following function implements the module-wise sync mapping + def _build_layer_index_mapping(self, decoder, vp_stage): + """ + Map the local layer index (ranged from 0 to model.decoder.num_layers_per_pipeline_rank) + to the global huggingface layer index + """ + if isinstance(decoder, TransformerBlock): + layer_offset = get_transformer_layer_offset(decoder.config, vp_stage=vp_stage) + num_layers_per_pipeline_rank = decoder.num_layers_per_pipeline_rank + return {n: layer_offset + n for n in range(num_layers_per_pipeline_rank)} + elif isinstance(decoder, MambaStack): + assert vp_stage == 0, "Mamba do not support VPP" + # NOTE: currently we assume MambaLayer just replaces some of Attention + # layout should be: ((mamba | attn) mlp) x n + num_layers_per_pipeline_rank = decoder.num_layers_per_pipeline_rank + layer_offset = num_layers_per_pipeline_rank * mpu.get_pipeline_model_parallel_rank() + return {n: (n + layer_offset) // 2 for n in range(num_layers_per_pipeline_rank)} + else: + raise ValueError(f"Unexpected decoder type: {type(decoder)}") + + def _map_model(self): + """Mapping the local name of src model to global name of + dst model + """ + cfg = LanguageModelKeyMapping( + decoder_layer_cfg=DecoderLayerKeyMapping( + self_attn_cfg=SelfAttnKeyMapping(use_merged_qkv=self._mapper_config.merge_qkv), + mlp_cfg=MLPKeyMapping(use_merged_gate_up=self._mapper_config.merge_gate_up) + ) + ) + for vp_stage, model in enumerate(self.model): + if getattr(model, 'mtp_process', False): + raise NotImplementedError("Currently, the mapper does not support MTP") + + self._map_llm_model( + model, + cfg, + index_mapping=self._build_layer_index_mapping( + model.decoder, + vp_stage + ), + src_prefix=f"{vp_stage}-", + dst_prefix="" + ) + + mapping = self._mapping + self._mapping = None + return mapping + + def _map_llm_model( + self, + model: nn.Module, + cfg: LanguageModelKeyMapping, + index_mapping: Dict[int, int], + src_prefix: str='', + dst_prefix: str='' + ): + if model.pre_process: + self._map_preprocess_layer( + model.embedding, + src_prefix=f"{src_prefix}embedding.", + dst_prefix=f"{dst_prefix}{cfg.word_embeddings}", + ) + + for layer_idx in range(model.decoder.num_layers_per_pipeline_rank): + global_layer_id = index_mapping[layer_idx] + if isinstance(model.decoder.layers[layer_idx], MambaLayer): + self._map_mamba_layer( + model.decoder.layers[layer_idx], + src_prefix=f"{src_prefix}decoder.layers.{layer_idx}.", + dst_prefix=f"{dst_prefix}{cfg.decoder_layer}{global_layer_id}.", + ) + else: + self._map_transformer_layer( + model.decoder.layers[layer_idx], + cfg=cfg.decoder_layer_cfg, + src_prefix=f"{src_prefix}decoder.layers.{layer_idx}.", + dst_prefix=f"{dst_prefix}{cfg.decoder_layer}{global_layer_id}.", + ) + + if model.post_process: + if isinstance(model.decoder, MambaStack): + self._map_norm_layer( + model.decoder.final_norm, + src_prefix=f"{src_prefix}decoder.final_norm.", + dst_prefix=f"{dst_prefix}{cfg.final_layernorm}", + ) + else: + self._map_norm_layer( + model.decoder.final_layernorm, + src_prefix=f"{src_prefix}decoder.final_layernorm.", + dst_prefix=f"{dst_prefix}{cfg.final_layernorm}", + ) + + if model.share_embeddings_and_output_weights and model.pre_process: + self._map_postprocess_layer( + model.embedding, + src_prefix=f"{src_prefix}embedding.word_embeddings.", + dst_prefix=f"{dst_prefix}{cfg.output_layer}", + ) + else: + self._map_postprocess_layer( + model.output_layer, + 
src_prefix=f"{src_prefix}output_layer.", + dst_prefix=f"{dst_prefix}{cfg.output_layer}", + ) + + def _map_norm_layer(self, module: nn.Module, src_prefix: str='', dst_prefix: str='', *, is_norm_layer: bool=True): + """If is_norm_layer is True, try to map on all possible keys, + otherwise only map on `layer_norm_weight` and `layer_norm_bias` + """ + _keynames = { + 'weight': 'weight', + 'bias': 'bias', + 'layer_norm_weight': 'weight', + 'layer_norm_bias': 'bias' + } + possible_keys = ['layer_norm_weight', 'layer_norm_bias'] + if is_norm_layer: + possible_keys += ['weight', 'bias'] + for item in possible_keys: + if getattr(module, item, None) is None or getattr(module, item).numel() == 0: + continue + self._inner_map_for_full_shape( + f"{src_prefix}{item}", + f"{dst_prefix}{_keynames[item]}" + ) + + def _map_transformer_layer(self, module: 'TransformerLayer', cfg: DecoderLayerKeyMapping, src_prefix: str='', dst_prefix: str=''): + submodule_config = module.submodules_config + has_self_attention = submodule_config.self_attention is not IdentityOp + has_mlp = submodule_config.mlp is not IdentityOp + assert has_self_attention or has_mlp, "The TransformerLayer should at least contains one of self_attn or mlp!" + + if has_self_attention: + if module.config.multi_latent_attention: + map_attn_func = self._map_mla_selfattn + norm_layer = module.input_layernorm + norm_src_key = f"{src_prefix}input_layernorm." + is_norm_layer = True + else: + map_attn_func = self._map_selfattn + is_gated_attention = hasattr(module.self_attention, 'linear_qgkv') + if is_gated_attention: + norm_layer = module.self_attention.linear_qgkv + norm_src_key = f"{src_prefix}self_attention.linear_qgkv." + else: + norm_layer = module.self_attention.linear_qkv + norm_src_key = f"{src_prefix}self_attention.linear_qkv." + is_norm_layer = False + map_attn_func( + module.self_attention, + cfg=cfg.self_attn_cfg, + src_prefix=f"{src_prefix}self_attention.", + dst_prefix=f"{dst_prefix}{cfg.self_attn}", + ) + self._map_norm_layer( + norm_layer, + norm_src_key, + dst_prefix=f"{dst_prefix}{cfg.input_layernorm}", + is_norm_layer=is_norm_layer + ) + + if has_mlp: + if isinstance(module.mlp, MoELayer): + map_mlp_func = self._map_moe_layer + norm_layer = module.pre_mlp_layernorm + norm_src_key = f"{src_prefix}pre_mlp_layernorm." + is_norm_layer = True + else: + map_mlp_func = self._map_mlp + norm_layer = module.mlp.linear_fc1 + norm_src_key = f"{src_prefix}mlp.linear_fc1." 
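+                # NOTE: when TE fused LayerNorm+Linear modules are used (the expected case here),
+                # the pre-MLP norm parameters live on linear_fc1 as `layer_norm_weight`/`layer_norm_bias`;
+                # is_norm_layer=False below makes _map_norm_layer pick up only those fused keys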
+ is_norm_layer = False + map_mlp_func( + module.mlp, + cfg=cfg.mlp_cfg, + src_prefix=f"{src_prefix}mlp.", + dst_prefix=f"{dst_prefix}{cfg.mlp}", + ) + self._map_norm_layer( + norm_layer, + norm_src_key, + dst_prefix=f"{dst_prefix}{cfg.pre_mlp_layernorm}", + is_norm_layer=is_norm_layer + ) + + def _map_mamba_layer(self, module, src_prefix='', dst_prefix=''): + # NOTE: the API is experimental as MambaLayer is not general enough currently + self._map_norm_layer( + module.mixer.in_proj, + f"{src_prefix}mixer.in_proj.", + dst_prefix=f"{dst_prefix}input_layernorm.", + is_norm_layer=False + ) + self._map_mamba_mixer( + module.mixer, + src_prefix=f"{src_prefix}mixer.", + dst_prefix=f"{dst_prefix}linear_attn.", + ) + + def _map_mamba_mixer(self, module, src_prefix='', dst_prefix=''): + Nk, Nv, Dk, Dv = ( + module.ngroups, + module.nheads, + module.d_state, + module.headdim + ) + + # in_proj + src_layout = [ + ('z', Dv * Nv), + ('v', Dv * Nv), + ('q', Dk * Nk), + ('k', Dk * Nk), + ('b', Nv), + ('a', Nv) + ] + self._inner_map_for_linear_attn( + f"{src_prefix}in_proj.weight", + f"{dst_prefix}in_proj_qkvz.weight", + src_layout=src_layout, + required_layout=['q', 'k', 'v', 'z'], + n_groups=Nk + ) + self._inner_map_for_linear_attn( + f"{src_prefix}in_proj.weight", + f"{dst_prefix}in_proj_ba.weight", + src_layout=src_layout, + required_layout=['b', 'a'], + n_groups=Nk + ) + # conv1d + src_layout = [ + ('conv_v', Dv * Nv), + ('conv_q', Dk * Nk), + ('conv_k', Dk * Nk), + ] + self._inner_map_for_merged_linear( + f"{src_prefix}conv1d.weight", + f"{dst_prefix}conv1d.weight", + src_layout=src_layout, + required_layout=['conv_q', 'conv_k', 'conv_v'] + ) + self._inner_map_for_tensor_parallel( + f"{src_prefix}dt_bias", + f"{dst_prefix}dt_bias", + mapping_type='column' + ) + + self._inner_map_for_tensor_parallel( + f"{src_prefix}A_log", + f"{dst_prefix}A_log", + mapping_type='column' + ) + if module.D is not None: + raise NotImplementedError() + + self._map_norm_layer( + module.norm, + f"{src_prefix}norm.", + dst_prefix=f"{dst_prefix}norm.", + is_norm_layer=True + ) + self._inner_map_for_tensor_parallel( + f"{src_prefix}out_proj.weight", + f"{dst_prefix}out_proj.weight", + mapping_type='row' + ) + + def _map_moe_layer(self, module: 'MoELayer', cfg: MoELayerKeyMapping, src_prefix='', dst_prefix=''): + # pylint: disable=unused-argument + mapping = {} + # router + self._inner_map_for_full_shape(f"{src_prefix}router.weight", f"{dst_prefix}gate.weight") + if module.router.enable_expert_bias: + self._inner_map_for_full_shape(f"{src_prefix}router.expert_bias", f"{dst_prefix}gate.e_score_correction_bias") + + if not module.config.moe_grouped_gemm: + raise NotImplementedError("Parameter Sync w/ MoE SequentialMLP is not supported") + if not isinstance(module.experts, TEGroupedMLP): + raise NotImplementedError("Parameter Sync w/ Legacy GroupedMLP is not supported") + + # experts + self._map_group_mlp( + module.experts, + src_prefix=f"{src_prefix}experts.", + dst_prefix=f"{dst_prefix}experts." + ) + + # shared experts + if module.shared_experts is not None: + if module.shared_experts.use_shared_expert_gate: + self._inner_map_for_full_shape( + f"{src_prefix}shared_experts.gate_weight", + f"{dst_prefix}shared_expert_gate.weight" + ) + # NOTE: if transformer.config have n_shared_experts, mapping to `shared_experts`, otherwise `shared_expert` + # `shared_experts`: DeepSeek-V2, DeepSeek-V3, etc. + # `shared_expert`: Qwen2-MoE, LLaMA-4, etc. 
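+            # e.g. with the default key mapping, a DeepSeek-style target key looks like
+            #   "model.layers.3.mlp.shared_experts.gate_proj.weight"
+            # while a Qwen2-MoE-style target uses ".mlp.shared_expert." instead
+            # (layer index 3 is purely illustrative)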
+ hf_config = AutoConfig.from_pretrained(self._dst_model_config.load, trust_remote_code=self._dst_model_config.trust_remote_code) + shared_expert_key = 'shared_experts' if hasattr(hf_config, 'n_shared_experts') else 'shared_expert' + self._map_mlp( + module.shared_experts, + cfg=MLPKeyMapping(use_merged_gate_up=self._mapper_config.merge_gate_up), + src_prefix=f"{src_prefix}shared_experts.", + dst_prefix=f"{dst_prefix}{shared_expert_key}." + ) + return mapping + + def _map_mlp( + self, + module: 'MLP', + cfg: MLPKeyMapping, + src_prefix: str='', + dst_prefix: str='', + ): + param_types = ['weight'] + if module.config.add_bias_linear: + param_types = ['weight', 'bias'] + + for param_type in param_types: + if not module.config.gated_linear_unit: + self._inner_map_for_tensor_parallel( + f"{src_prefix}linear_fc1.{param_type}", + f"{dst_prefix}{cfg.up_proj}{param_type}", + mapping_type='column' + ) + else: + dst_names = {'gate_proj': cfg.gate_proj, 'up_proj': cfg.up_proj} + if cfg.use_merged_gate_up: + dst_names = {'gate_up_proj': cfg.gate_up_proj} + + for dst_type, dst_name in dst_names.items(): + self._inner_map_for_gate_up_proj( + f"{src_prefix}linear_fc1.{param_type}", + f"{dst_prefix}{dst_name}{param_type}", + proj_type=dst_type + ) + self._inner_map_for_tensor_parallel( + f"{src_prefix}linear_fc2.{param_type}", + f"{dst_prefix}{cfg.down_proj}{param_type}", + mapping_type='row' + ) + + def _map_group_mlp(self, module: 'TEGroupedMLP', src_prefix: str='', dst_prefix: str=''): + # pylint: disable=unused-argument + src_ep_rank = mpu.get_expert_model_parallel_rank() + src_ep_size = mpu.get_expert_model_parallel_world_size() + num_experts = module.config.num_moe_experts + global_expert_id_start = num_experts // src_ep_size * src_ep_rank + global_expert_id_end = num_experts // src_ep_size * (src_ep_rank + 1) + for local_expert_id, global_expert_id in enumerate(range(global_expert_id_start, global_expert_id_end)): + if self._mapper_config.merge_expert: + if not self._mapper_config.merge_gate_up: + raise NotImplementedError("merge_expert w/o merge_gate_up is not implemented.") + self._inner_map_for_gate_up_proj( + f"{src_prefix}linear_fc1.weight{local_expert_id}", + f"{dst_prefix}w13_weight", + proj_type='gate_up_proj', + global_expert_id=global_expert_id, + num_experts=num_experts + ) + self._inner_map_for_tensor_parallel( + f"{src_prefix}linear_fc2.weight{local_expert_id}", + f"{dst_prefix}w2_weight", + global_expert_id=global_expert_id, + num_experts=num_experts, + mapping_type='row' + ) + else: + if self._mapper_config.merge_gate_up: + raise NotImplementedError("no merge_expert w/ merge_gate_up is not implemented.") + for dst_name in ['gate_proj', 'up_proj']: + self._inner_map_for_gate_up_proj( + f"{src_prefix}linear_fc1.weight{local_expert_id}", + f"{dst_prefix}{global_expert_id}.{dst_name}.weight", + proj_type=dst_name, + ) + self._inner_map_for_tensor_parallel( + f"{src_prefix}linear_fc2.weight{local_expert_id}", + f"{dst_prefix}{global_expert_id}.down_proj.weight", + mapping_type='row' + ) + + def _map_mla_selfattn(self, module: 'MLASelfAttention', cfg: MLASelfAttnKeyMapping, src_prefix: str='', dst_prefix: str=''): + # pylint: disable=unused-argument + if module.config.q_lora_rank is None: + self._inner_map_for_tensor_parallel( + f"{src_prefix}linear_q_proj.weight", + f"{dst_prefix}q_proj.weight", + mapping_type='column' + ) + else: + self._inner_map_for_mla_down_proj( + f"{src_prefix}linear_q_down_proj.weight", + f"{dst_prefix}q_a_proj.weight", + ) + self._inner_map_for_tensor_parallel( + 
f"{src_prefix}linear_q_up_proj.weight", + f"{dst_prefix}q_b_proj.weight", + mapping_type='column' + ) + if module.config.qk_layernorm: + self._map_norm_layer( + module.linear_q_up_proj, + f"{src_prefix}linear_q_up_proj.", + f"{dst_prefix}q_a_layernorm.", + is_norm_layer=False + ) + self._inner_map_for_mla_down_proj( + f"{src_prefix}linear_kv_down_proj.weight", + f"{dst_prefix}kv_a_proj_with_mqa.weight", + ) + self._inner_map_for_tensor_parallel( + f"{src_prefix}linear_kv_up_proj.weight", + f"{dst_prefix}kv_b_proj.weight", + mapping_type='column' + ) + if module.config.qk_layernorm: + self._map_norm_layer( + module.linear_kv_up_proj, + f"{src_prefix}linear_kv_up_proj.", + f"{dst_prefix}kv_a_layernorm.", + is_norm_layer=False + ) + self._inner_map_for_tensor_parallel( + f"{src_prefix}linear_proj.weight", + f"{dst_prefix}o_proj.weight", + mapping_type='row' + ) + + def _map_selfattn( + self, + module: 'SelfAttention', + cfg: SelfAttnKeyMapping, + src_prefix: str='', + dst_prefix: str='' + ): + if module.config.qk_layernorm: + self._map_norm_layer(module.q_layernorm, f"{src_prefix}q_layernorm.", f"{dst_prefix}{cfg.q_layernorm}") + self._map_norm_layer(module.k_layernorm, f"{src_prefix}k_layernorm.", f"{dst_prefix}{cfg.k_layernorm}") + + qkv_dst_names = {'qkv_proj': cfg.qkv_proj} + if not cfg.use_merged_qkv: + qkv_dst_names = {'q_proj': cfg.q_proj, 'k_proj': cfg.k_proj, 'v_proj': cfg.v_proj} + + param_types = ['weight'] + if module.config.add_qkv_bias: + param_types = ['weight', 'bias'] + + # TODO: make better condition + is_gated_attention = hasattr(module, 'linear_qgkv') + for param_type in param_types: + for dst_type, dst_name in qkv_dst_names.items(): + if is_gated_attention: + src_key = f"{src_prefix}linear_qgkv.{param_type}" + else: + src_key = f"{src_prefix}linear_qkv.{param_type}" + self._inner_map_for_qkv_proj( + src_key, + f"{dst_prefix}{dst_name}{param_type}", + proj_type=dst_type, + num_attention_heads = module.config.num_attention_heads, + num_query_groups = module.config.num_query_groups, + is_gated_attention = is_gated_attention + ) + + param_types = ['weight'] + if module.config.add_bias_linear: + param_types = ['weight', 'bias'] + for param_type in param_types: + self._inner_map_for_tensor_parallel( + f"{src_prefix}linear_proj.{param_type}", + f"{dst_prefix}{cfg.out_proj}{param_type}", + mapping_type='row' + ) + + def _map_preprocess_layer(self, module: 'LanguageModelEmbedding', src_prefix='', dst_prefix=''): + if module.add_position_embedding: + raise NotImplementedError("learned_absolute embedding is not supported") + self._inner_map_for_tensor_parallel( + f"{src_prefix}word_embeddings.weight", + f"{dst_prefix}embed_tokens.weight", + mapping_type='column' + ) + + def _map_postprocess_layer(self, module: 'ColumnParallelLinear', src_prefix='', dst_prefix=''): + # pylint: disable=unused-argument + if ( + not self._src_model_config.megatron_model_cfg.untie_embeddings_and_output_weights and + f"{dst_prefix}lm_head.weight" not in self._dst_name_to_metadata + ): + return + self._inner_map_for_tensor_parallel( + f"{src_prefix}weight", + f"{dst_prefix}lm_head.weight", + mapping_type='column' + ) diff --git a/chatlearn/synchronizer/mappers/megatron_vlm_mapper.py b/chatlearn/synchronizer/mappers/megatron_vlm_mapper.py new file mode 100644 index 00000000..44457c96 --- /dev/null +++ b/chatlearn/synchronizer/mappers/megatron_vlm_mapper.py @@ -0,0 +1,143 @@ +# Copyright 2025 Alibaba Group Holding Limited. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Mapper for Megatron to vLLM""" +from typing import TYPE_CHECKING, Union + +from torch import nn + +from chatlearn.configs import PolicyConfig + +from .mapping_helpers import ( + VLLM_HELPERS, + HF_HELPERS +) +from .metadata import ( + SelfAttnKeyMapping, + MLPKeyMapping, + DecoderLayerKeyMapping, + LanguageModelKeyMapping +) + +from .megatron_llm_mapper import MegatronLLMMapper + +if TYPE_CHECKING: + from chatlearn.models.megatron_module import MegatronModule + +class MegatronVLMMapper(MegatronLLMMapper): + """MegatronVLMMapper""" + def __init__( + self, + dst_model_config: PolicyConfig, + model: 'MegatronModule', + *, + mapper_config: Union[VLLM_HELPERS, HF_HELPERS] = VLLM_HELPERS, + ): + """The Mapper for Megatron VLM sync. + + Args: + dst_model_config (PolicyConfig): The config of target model to + be sychronized + model (MegatronModule): The source Megatron Module + mapper_config (Union[VLLM_HELPERS, HF_HELPERS]): The mapping mode. + """ + super().__init__(dst_model_config=dst_model_config, model=model, mapper_config=mapper_config) + + # NOTE: the following function implements the module-wise sync mapping + def _map_model(self): + """Mapping the local name of src model to global name of + dst model + """ + # TODO: clean this config object + cfg = LanguageModelKeyMapping( + word_embeddings=self._mapper_config.dst_language_prefix, + decoder_layer=f"{self._mapper_config.dst_language_prefix}layers.", + decoder_layer_cfg=DecoderLayerKeyMapping( + self_attn_cfg=SelfAttnKeyMapping(use_merged_qkv=self._mapper_config.merge_qkv), + mlp_cfg=MLPKeyMapping(use_merged_gate_up=self._mapper_config.merge_gate_up) + ), + final_layernorm=f"{self._mapper_config.dst_language_prefix}norm.", + output_layer=self._mapper_config.dst_lm_head_prefix + ) + + for vp_stage, model in enumerate(self.model): + if getattr(model, 'mtp_process', False): + raise NotImplementedError("Currently, the mapper does not support MTP") + + if hasattr(model, 'vision_model'): + # assert layer_offset == 0 + self._map_vision_model( + model.vision_model, + src_prefix=f"{vp_stage}-vision_model.", + dst_prefix=self._mapper_config.dst_vision_prefix + ) + + # llm model + self._map_llm_model( + model.language_model, + cfg=cfg, + index_mapping=self._build_layer_index_mapping( + model.language_model.decoder, + vp_stage + ), + src_prefix=f"{vp_stage}-language_model.", + dst_prefix="" + ) + + mapping = self._mapping + self._mapping = None + return mapping + + def _map_vision_model(self, + model: nn.Module, + src_prefix: str = '', + dst_prefix: str = '' + ): + self._inner_map_for_full_shape( + f"{src_prefix}patch_embed.proj.weight", + f"{dst_prefix}patch_embed.proj.weight" + ) + + # vision model decoder + decoder_layer_cfg = DecoderLayerKeyMapping( + input_layernorm='norm1.', + self_attn='attn.', + self_attn_cfg=SelfAttnKeyMapping( + qkv_proj='qkv.', + out_proj='proj.', + use_merged_qkv=True + ), + 
pre_mlp_layernorm='norm2.' + ) + for layer_idx in range(model.config.num_layers): + self._map_transformer_layer( + model.decoder.layers[layer_idx], + decoder_layer_cfg, + src_prefix=f"{src_prefix}decoder.layers.{layer_idx}.", + dst_prefix=f"{dst_prefix}blocks.{layer_idx}.", + ) + + # vision model projection + self._inner_map_for_full_shape( + f"{src_prefix}decoder.final_layernorm.weight", + f"{dst_prefix}merger.ln_q.weight" + ) + mlp_cfg = MLPKeyMapping(up_proj='0.', down_proj='2.') + self._map_mlp( + model.projection.encoder, + mlp_cfg, + src_prefix=f"{src_prefix}projection.encoder.", + dst_prefix=f"{dst_prefix}merger.mlp." + ) diff --git a/chatlearn/synchronizer/mappers/metadata.py b/chatlearn/synchronizer/mappers/metadata.py new file mode 100644 index 00000000..a205b8dd --- /dev/null +++ b/chatlearn/synchronizer/mappers/metadata.py @@ -0,0 +1,52 @@ +# pylint: disable=missing-module-docstring, missing-class-docstring +from typing import Union +from dataclasses import dataclass, field + + +@dataclass +class SelfAttnKeyMapping: + q_layernorm: str = 'q_norm.' + k_layernorm: str = 'k_norm.' + qkv_proj: str = 'qkv_proj.' + q_proj: str = 'q_proj.' + k_proj: str = 'k_proj.' + v_proj: str = 'v_proj.' + out_proj: str = 'o_proj.' + use_merged_qkv: bool = False + +@dataclass +class MLASelfAttnKeyMapping: + # NOTE: currently not used + pass + +@dataclass +class MLPKeyMapping: + gate_proj: str = 'gate_proj.' + up_proj: str = 'up_proj.' + down_proj: str = 'down_proj.' + gate_up_proj: str = 'gate_up_proj.' + use_merged_gate_up: bool = False + +@dataclass +class MoELayerKeyMapping: + # NOTE: currently not used + pass + + +@dataclass +class DecoderLayerKeyMapping: + input_layernorm: str = 'input_layernorm.' # if is_vision_block, norm1. + self_attn: str = 'self_attn.' + self_attn_cfg: Union[SelfAttnKeyMapping, MLASelfAttnKeyMapping] = field(default=SelfAttnKeyMapping) + pre_mlp_layernorm: str = 'post_attention_layernorm.' # if is_vision_block, norm2. + mlp: str = 'mlp.' + mlp_cfg: Union[MLPKeyMapping, MoELayerKeyMapping] = field(default=MLPKeyMapping) + + +@dataclass +class LanguageModelKeyMapping: + word_embeddings: str = 'model.' + decoder_layer: str = 'model.layers.' + decoder_layer_cfg: DecoderLayerKeyMapping = field(default=DecoderLayerKeyMapping) + final_layernorm: str = 'model.norm.' + output_layer: str = '' diff --git a/chatlearn/synchronizer/parameter_sync.py b/chatlearn/synchronizer/parameter_sync.py index 65ea0a75..90d2c59c 100644 --- a/chatlearn/synchronizer/parameter_sync.py +++ b/chatlearn/synchronizer/parameter_sync.py @@ -153,7 +153,7 @@ def generate_global_param_ids(self, model: 'DistModel') -> Dict[str, int]: param_names = set() for res in results: param_names.update(res) - return {name: idx for idx, name in enumerate(param_names)} + return {name: idx for idx, name in enumerate(sorted(param_names))} def validate_sync_mapping( self, diff --git a/chatlearn/synchronizer/planners/tensor_planner.py b/chatlearn/synchronizer/planners/tensor_planner.py index a9fd879e..8ad12c70 100644 --- a/chatlearn/synchronizer/planners/tensor_planner.py +++ b/chatlearn/synchronizer/planners/tensor_planner.py @@ -13,6 +13,7 @@ # limitations under the License. 
# ============================================================================== """Sync parameters""" +import random from copy import deepcopy from collections import defaultdict from typing import Dict, List, Tuple, TYPE_CHECKING @@ -83,17 +84,17 @@ def build_iteration( ) -> List[Dict[int, List[SyncIteration]]]: """Build iterations from unbucketized plan according to the given memory constraints. - + Args: - unbucketized_plan (Dict[int, Dict[Ranks, List[ShardedTensorInfo]]]): + unbucketized_plan (Dict[int, Dict[Ranks, List[ShardedTensorInfo]]]): The unbucketized comm plan. - src_rank_to_gpu_id (Dict[int, int]): map ranks of source model to + src_rank_to_gpu_id (Dict[int, int]): map ranks of source model to physical GPU ID. - dst_rank_to_gpu_id (Dict[int, int]): map ranks of destination model + dst_rank_to_gpu_id (Dict[int, int]): map ranks of destination model to physical GPU ID. - mem_infos (Dict[int, Tuple[int, int]]): The used memory and + mem_infos (Dict[int, Tuple[int, int]]): The used memory and total memory for each physical GPU. - max_memory_fraction (float, optional): The maximum ratio of planner + max_memory_fraction (float, optional): The maximum ratio of planner could use. Defaults to 0.8. Returns: @@ -115,6 +116,9 @@ def build_iteration( continue is_added.add(dst_param.param_id) dst_param_id_to_src_params[dst_param.param_id].append(src_param) + t = list(dst_param_id_to_src_params.keys()) + random.shuffle(t) + dst_param_id_to_src_params = {k: dst_param_id_to_src_params[k] for k in t} src_shard_to_sender = {} for sender, plan_per_rank in unbucketized_plan.items(): diff --git a/chatlearn/utils/mappings/megatron_helpers.py b/chatlearn/utils/mappings/megatron_helpers.py index c81b3547..e9dd8758 100644 --- a/chatlearn/utils/mappings/megatron_helpers.py +++ b/chatlearn/utils/mappings/megatron_helpers.py @@ -36,6 +36,8 @@ VocabParallelEmbedding, ColumnParallelLinear ) + from megatron.core.transformer.moe.shared_experts import SharedExpertMLP + from megatron.core.ssm.mamba_mixer import MambaMixer HAVE_MEGATRON = True except ImportError: HAVE_MEGATRON = False @@ -181,6 +183,47 @@ def _prepare_metadata(prefix: str, module: nn.Module): results['weight'] = ShardedTensorInfo.from_global_shape( tuple(module.weight.shape), dtype=module.weight.dtype ) + elif isinstance(module, MambaMixer): + tp_rank = mpu.get_tensor_model_parallel_rank() + tp_size = mpu.get_tensor_model_parallel_world_size() + L = module.dt_bias.shape[0] + results['dt_bias'] = ShardedTensorInfo( + dtype=module.dt_bias.dtype, + global_shape=(L * tp_size, ), + axis_fragmentations=(tp_size, ), + global_offset=(tp_rank,) + ) + results['A_log'] = ShardedTensorInfo( + dtype=module.A_log.dtype, + global_shape=(L * tp_size, ), + axis_fragmentations=(tp_size, ), + global_offset=(tp_rank,) + ) + if module.D is not None: + raise NotImplementedError() + if module.rmsnorm: + results['norm.weight'] = ShardedTensorInfo( + dtype=module.norm.weight.dtype, + global_shape=(module.norm.weight.shape[0] * tp_size, ), + axis_fragmentations=(tp_size, ), + global_offset=(tp_rank,) + ) + conv_dim, _, d_conv = module.conv1d.weight.shape + results['conv1d.weight'] = ShardedTensorInfo( + dtype=module.conv1d.weight.dtype, + global_shape=(conv_dim * tp_size, 1, d_conv), + axis_fragmentations=(tp_size, 1, 1), + global_offset=(tp_rank, 0, 0) + ) + elif isinstance(module, SharedExpertMLP): + if module.use_shared_expert_gate: + results['gate_weight'] = ShardedTensorInfo( + dtype=module.gate_weight.dtype, + global_shape=(1, module.gate_weight.shape[1]), + 
axis_fragmentations=(1, 1), + global_offset=(0, 0) + ) + return results def build_sharded_info_for_mcore_model( diff --git a/chatlearn/utils/mappings/sharded_tensor_info.py b/chatlearn/utils/mappings/sharded_tensor_info.py index ddf0fd93..6d19cc1e 100644 --- a/chatlearn/utils/mappings/sharded_tensor_info.py +++ b/chatlearn/utils/mappings/sharded_tensor_info.py @@ -155,12 +155,12 @@ def unsqueeze(self, offset:int, length: int, axis: int=0) -> 'ShardedTensorInfo' def index(self, tensor: torch.Tensor) -> torch.Tensor: """Indexing tensor with this ShardedTensorInfo. - will check the shape-related information and ignore + will check the shape-related information and ignore inconsistent datatype. - + Args: tensor (torch.Tensor): tensor to be indexed. - + """ tensor_shape = tensor.shape @@ -248,7 +248,7 @@ def concat(shards: List['ShardedTensorInfo'], axis: int=0) -> Optional['ShardedT Returns: - ShardedTensorInfo: The concatenated shard. If the input list is empty, + ShardedTensorInfo: The concatenated shard. If the input list is empty, returns None. """ if len(shards) == 0: @@ -329,3 +329,35 @@ def __contains__(self, other: 'ShardedTensorInfo'): if si > oi or si + sj < oi + oj: return False return True + + def chunk(self, sections: List[int], axis: int=0) -> List['ShardedTensorInfo']: + """ + Chunk the sharded info on the given axis. + + Args: + sections (List[int]): a list of length for chunking, the total length of this + list should be equal to local_shape on the given axis. + axis (int, optional): The axis to be chunked. Defaults to 0. + """ + local_size = self.local_shape[axis] + assert local_size == sum(sections), f"Failed to chunk {self} on axis {axis}, given sections {sections}" + offset = self.local_offset[axis] + + chunks = [] + for section in sections: + result = self.copy() + result.local_shape = result.local_shape[:axis] + (section, ) + result.local_shape[axis + 1:] + result.local_offset = result.local_offset[:axis] + (offset, ) + result.local_offset[axis + 1:] + offset += section + chunks.append(result) + return chunks + + @property + def offset(self): + """Return the offset of this shard in the global tensor""" + return tuple(l + g * s // a for l, g, s, a in zip( + self.local_offset, + self.global_offset, + self.global_shape, + self.axis_fragmentations + )) diff --git a/chatlearn/utils/megatron_utils.py b/chatlearn/utils/megatron_utils.py index c465d08d..6a72b5ad 100644 --- a/chatlearn/utils/megatron_utils.py +++ b/chatlearn/utils/megatron_utils.py @@ -21,6 +21,9 @@ def update_cfg(cfg): hf_transformer_config = AutoConfig.from_pretrained(cfg.models.policy.load) + if hf_transformer_config.architectures[0] == "Qwen3NextForCausalLM": + return update_qwen3_next_cfg(cfg, hf_transformer_config) + # common cfgs cfg.models.policy_trainer.megatron_model_cfg.attention_dropout = hf_transformer_config.attention_dropout cfg.models.policy_trainer.megatron_model_cfg.num_layers = hf_transformer_config.num_hidden_layers @@ -107,3 +110,55 @@ def update_cfg(cfg): cfg.models.ref_policy.megatron_model_cfg = cfg.models.policy_trainer.megatron_model_cfg return cfg + +def update_qwen3_next_cfg(cfg, hf_transformer_config): + cfg.models.policy_trainer.megatron_model_cfg.attention_dropout = hf_transformer_config.attention_dropout + cfg.models.policy_trainer.megatron_model_cfg.num_layers = hf_transformer_config.num_hidden_layers * 2 + + full_attention_interval = hf_transformer_config.full_attention_interval + hybrid_pattern = ['*-' if (i + 1) % full_attention_interval == 0 else 'M-' for i in 
range(hf_transformer_config.num_hidden_layers)]
+    cfg.models.policy_trainer.megatron_model_cfg.hybrid_override_pattern = ''.join(hybrid_pattern)
+
+    cfg.models.policy_trainer.megatron_model_cfg.is_hybrid_model = True
+
+    cfg.models.policy_trainer.megatron_model_cfg.hidden_size = hf_transformer_config.hidden_size
+    cfg.models.policy_trainer.megatron_model_cfg.num_attention_heads = hf_transformer_config.num_attention_heads
+    cfg.models.policy_trainer.megatron_model_cfg.ffn_hidden_size = hf_transformer_config.intermediate_size
+    cfg.models.policy_trainer.megatron_model_cfg.max_position_embeddings = hf_transformer_config.max_position_embeddings
+    cfg.models.policy_trainer.megatron_model_cfg.add_bias_linear = False
+    cfg.models.policy_trainer.megatron_model_cfg.rotary_base = hf_transformer_config.rope_theta
+    cfg.models.policy_trainer.megatron_model_cfg.rotary_percent = hf_transformer_config.partial_rotary_factor
+    cfg.models.policy_trainer.megatron_model_cfg.norm_epsilon = hf_transformer_config.rms_norm_eps
+    cfg.models.policy_trainer.megatron_model_cfg.untie_embeddings_and_output_weights = not hf_transformer_config.tie_word_embeddings
+    cfg.models.policy_trainer.megatron_model_cfg.vocab_size = hf_transformer_config.vocab_size
+    cfg.models.policy_trainer.megatron_model_cfg.qk_layernorm = True
+
+    cfg.models.policy_trainer.megatron_model_cfg.kv_channels = hf_transformer_config.head_dim
+    cfg.models.policy_trainer.megatron_model_cfg.add_qkv_bias = False
+
+    cfg.models.policy_trainer.megatron_model_cfg.moe_shared_expert_intermediate_size = hf_transformer_config.shared_expert_intermediate_size
+
+    cfg.models.policy_trainer.megatron_model_cfg.group_query_attention = True
+    cfg.models.policy_trainer.megatron_model_cfg.num_query_groups = hf_transformer_config.num_key_value_heads
+
+    cfg.models.policy_trainer.megatron_model_cfg.moe_grouped_gemm = True
+    cfg.models.policy_trainer.megatron_model_cfg.moe_token_dispatcher_type = "alltoall"
+    cfg.models.policy_trainer.megatron_model_cfg.moe_router_topk = hf_transformer_config.num_experts_per_tok
+    cfg.models.policy_trainer.megatron_model_cfg.moe_ffn_hidden_size = hf_transformer_config.moe_intermediate_size
+    cfg.models.policy_trainer.megatron_model_cfg.moe_router_dtype = 'fp64'
+    cfg.models.policy_trainer.megatron_model_cfg.num_experts = hf_transformer_config.num_experts
+    cfg.models.policy_trainer.megatron_model_cfg.apply_layernorm_1p = True
+
+    cfg.models.policy_trainer.megatron_model_cfg.moe_router_load_balancing_type = "none"
+    cfg.models.policy_trainer.megatron_model_cfg.moe_aux_loss_coeff = 0
+    cfg.models.policy_trainer.megatron_model_cfg.moe_permute_fusion = True
+    cfg.models.policy_trainer.megatron_model_cfg.moe_router_fusion = False
+    cfg.models.policy_trainer.megatron_model_cfg.cross_entropy_loss_fusion = True
+    cfg.models.policy_trainer.megatron_model_cfg.moe_shared_expert_overlap = False
+    cfg.models.policy_trainer.megatron_model_cfg.gradient_accumulation_fusion = True
+    cfg.models.policy_trainer.megatron_model_cfg.async_tensor_model_parallel_allreduce = True
+    cfg.models.policy_trainer.distributed_timeout_minutes = 60
+
+    cfg.models.ref_policy.megatron_model_cfg = cfg.models.policy_trainer.megatron_model_cfg
+    return cfg
diff --git a/chatlearn/utils/utils.py b/chatlearn/utils/utils.py
index 1a22152c..7cdacccd 100644
--- a/chatlearn/utils/utils.py
+++ b/chatlearn/utils/utils.py
@@ -326,7 +326,7 @@ def even_slice(total_sample:int, total_slice:int):
     slice_index.append(total_sample)
     return slice_index
 
-def slice_data_list_by_index(batched_input: List[Dict[str, Any]], index):
+def slice_data_list_by_index(batched_input: List[Any], index):
     """
     Slice input data_list by slice index
     """
diff --git a/docs/images/qwen3_next.jpg b/docs/images/qwen3_next.jpg
new file mode 100644
index 00000000..73320551
Binary files /dev/null and b/docs/images/qwen3_next.jpg differ
diff --git a/docs/zh/tutorial/tutorial_grpo_mcore_qwen3_next.md b/docs/zh/tutorial/tutorial_grpo_mcore_qwen3_next.md
new file mode 100644
index 00000000..ba1add59
--- /dev/null
+++ b/docs/zh/tutorial/tutorial_grpo_mcore_qwen3_next.md
@@ -0,0 +1,111 @@
+# End-to-End GRPO Training Based on Mcore
+
+This document is a quick-start guide for GRPO training of Qwen3-Next with the ChatLearn, Mcore, and SGLang frameworks.
+
+## Development Environment Setup
+We recommend building the image on top of nvcr.io/nvidia/pytorch:24.12-py3 in a DSW environment on the PAI platform.
+```bash
+
+# Install SGLang. Note that this removes the PyTorch shipped with the NGC image and reinstalls pytorch==2.8.0
+pip install --no-cache-dir "sglang[all]==0.5.2" -i https://mirrors.aliyun.com/pypi/simple/
+
+# Apply the SGLang patch: download the modified memory_pool.py from
+# https://gist.github.com/lostkevin/9b668c24de6f0e9974c9ad069ef03ed9 and replace the installed copy with it
+cd /usr/local/lib/python3.12/dist-packages/sglang/srt/mem_cache/
+mv memory_pool.py memory_pool.py.bak
+cp /path/to/downloaded/memory_pool.py /usr/local/lib/python3.12/dist-packages/sglang/srt/mem_cache/memory_pool.py
+
+
+# Install ChatLearn dependencies
+pip install transformers==4.57.1 modelscope==1.30.0 tensordict==0.10.0 torchdata==0.11.0 codetiming==1.4.0 blobfile==3.0.0 numpy==1.26.4 accelerate==1.10.0 wandb==0.19.11 datasets==3.6.0 grpcio==1.71.0 omegaconf==2.3.0 hydra-core==1.3.2 msgspec==0.19.0 mathruler==0.1.0 pylatexenc==2.10 langgraph==0.6.6 ray[default]==2.46.0 -i https://mirrors.aliyun.com/pypi/simple/
+
+# Since installing SGLang reinstalls PyTorch, flash attention and apex also need to be reinstalled
+pip uninstall -y flash_attn && pip install https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/csrc/flash-attention/torch2.6.0-cu12x/flash_attn-2.4.2-cp312-cp312-linux_x86_64.whl --no-cache-dir -i https://mirrors.aliyun.com/pypi/simple/
+
+pip uninstall -y apex && pip install https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/csrc/apex/torch2.6.0-cuda12x/apex-0.1-cp312-cp312-linux_x86_64.whl --no-cache-dir -i https://mirrors.aliyun.com/pypi/simple/
+
+
+# Upgrade Transformer Engine
+pip uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-torch
+git clone --recursive https://github.com/NVIDIA/TransformerEngine.git
+cd TransformerEngine
+git submodule update --init --recursive
+git checkout release_v2.7
+export CUDNN_PATH=/usr/local/lib/python3.12/dist-packages/nvidia/cudnn/
+cp /usr/local/lib/python3.12/dist-packages/nvidia/cudnn/include/* /usr/local/cuda/include/
+export NVTE_FRAMEWORK=pytorch
+python setup.py bdist_wheel -vvv
+cd dist
+pip install transformer_engine-2.7.0-cp312-cp312-linux_x86_64.whl --no-cache-dir -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.cloud.aliyuncs.com
+
+# Install the mamba-ssm dependency
+pip install --no-build-isolation "mamba-ssm" -i https://mirrors.aliyun.com/pypi/simple/
+
+# Install the causal-conv1d dependency
+git clone https://github.com/Dao-AILab/causal-conv1d.git
+cd causal-conv1d
+git checkout v1.5.2
+export CAUSAL_CONV1D_FORCE_BUILD=TRUE
+python setup.py bdist_wheel -vvv
+cd dist
+pip install causal_conv1d-1.5.2-cp312-cp312-linux_x86_64.whl --no-cache-dir --no-build-isolation -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.cloud.aliyuncs.com
+
+# Install flash-linear-attention
+pip install --no-build-isolation flash-linear-attention -i https://mirrors.aliyun.com/pypi/simple/
+
+```
+## Code Preparation
+
+```bash
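+# Clone ChatLearn and Pai-Megatron-Patch into the same parent directory;
+# the conversion step below reaches Pai-Megatron-Patch via a relative path (../Pai-Megatron-Patch)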
+git clone https://github.com/alibaba/ChatLearn.git
+git clone --recurse-submodules https://github.com/alibaba/Pai-Megatron-Patch.git
+```
+
+## Data & Model Preparation
+We use the [MATH-lighteval](https://www.modelscope.cn/datasets/AI-ModelScope/MATH-lighteval) dataset as an example.
+```bash
+cd ChatLearn
+# download the dataset
+mkdir -p dataset
+modelscope download --dataset AI-ModelScope/MATH-lighteval --local_dir dataset/MATH-lighteval
+# preprocess dataset
+python chatlearn/data/data_preprocess/math_lighteval.py --input_dir dataset/MATH-lighteval --local_dir dataset/MATH-lighteval
+# download model weights (the conversion step below expects them under pretrained_models/)
+mkdir -p pretrained_models
+modelscope download --model Qwen/Qwen3-Next-80B-A3B-Instruct --local_dir pretrained_models/Qwen3-Next-80B-A3B-Instruct
+
+```
+
+## Model Conversion
+Use the script below to convert the Qwen3-Next model from Huggingface format to MCore format.
+```bash
+CHATLEARN_ROOT=$(pwd)
+cd ../Pai-Megatron-Patch/toolkits/distributed_checkpoints_convertor
+bash scripts/qwen3_next/run_8xH20.sh \
+A3B \
+${CHATLEARN_ROOT}/pretrained_models/Qwen3-Next-80B-A3B-Instruct \
+${CHATLEARN_ROOT}/pretrained_models/Qwen3-Next-80B-A3B-Instruct-to-mcore \
+false \
+true \
+bf16
+
+```
+
+## Qwen3-Next Reinforcement Learning Training and Training Stability Notes
+Run the following command to start GRPO training for Qwen3-Next:
+
+```bash
+cd ${CHATLEARN_ROOT}
+bash scripts/mcore_sglang/train_mcore_sglang_qwen3_next_grpo.sh
+```
+
+After resolving a number of training-stability issues, the evaluation metrics on the validation set continue to improve.
+
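+If the training job fails to start, first verify that the paths assumed by the steps above actually exist. The following is a minimal sketch; adjust it if your directory layout differs:
+
+```bash
+# Pre-flight check: confirm the dataset and checkpoints prepared above are in place
+cd ${CHATLEARN_ROOT}
+ls dataset/MATH-lighteval                                   # preprocessed MATH-lighteval data
+ls pretrained_models/Qwen3-Next-80B-A3B-Instruct            # Huggingface weights
+ls pretrained_models/Qwen3-Next-80B-A3B-Instruct-to-mcore   # converted MCore checkpoint
+```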
+
+