# ruff: noqa: E402
"""
This module converts a Hugging Face transformers Mixtral checkpoint to nanotron format.

Command:
torchrun --nproc_per_node=1 convert_trfrs_to_nanotron.py \
    --model_name hf-internal-testing/Mixtral-tiny \
    --save_path ./pretrained/mixtral
"""
import argparse
import sys
from dataclasses import asdict
from pathlib import Path
from typing import Dict, List

import torch
from config_mixtral import MixtralConfig

sys.path.append(Path(__file__).parent.parent.as_posix())

import json

import nanotron.distributed as dist
from config_mixtral_tiny import CONFIG as CONFIG_NANOTRON
from config_mixtral_tiny import PARALLELISM as PARALLELISM_NANOTRON
from mixtral import MixtralForTraining
from nanotron.models import build_model
from nanotron.parallel.context import ParallelContext
from nanotron.parallel.parameters import NanotronParameter, sanity_check
from nanotron.serialize import save
from nanotron.trainer import mark_tied_parameters
from transformers import MixtralForCausalLM


def get_args():
    parser = argparse.ArgumentParser(description="Convert transformers weights to nanotron weights")
    parser.add_argument("--model_name", type=str, default="hf-internal-testing/Mixtral-tiny")
    parser.add_argument("--save_path", type=str, default="pretrained/Mixtral-7B-v0.1")
    parser.add_argument("--dp", type=int, default=1)
    parser.add_argument("--pp", type=int, default=1)
    parser.add_argument("--tp", type=int, default=1)
    return parser.parse_args()


def permute_for_rotary(tensor, num_heads, per_head_hidden_size, hidden_size):
    return (
        tensor.view(num_heads, 2, per_head_hidden_size // 2, hidden_size)
        .transpose(1, 2)
        .contiguous()
        .view(num_heads * per_head_hidden_size, hidden_size)
    )
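
# A worked example of the permutation above (a sketch, not executed here):
# transformers stores q_proj/k_proj rows for rotary in a "half-split" layout
# (rotate_half pairs row i with row i + d/2 within each head), while this
# conversion appears to expect an interleaved layout. With num_heads=1 and
# per_head_hidden_size=4, the per-head row order [0, 1, 2, 3] becomes
# [0, 2, 1, 3]: view(1, 2, 2, hidden_size) groups the rows into halves,
# transpose(1, 2) interleaves the halves, and the final view flattens them back.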


def get_transformers_weight(
    name: str, ref_module_state_dict: Dict[str, torch.Tensor], ref_module: MixtralForCausalLM, get_grad: bool = False
) -> torch.Tensor:
    """Given a parameter name from the nanotron implementation, return the equivalent tensor from the transformers implementation."""
    config = ref_module.config
    nanotron_prefix = "model."
    assert name.startswith(nanotron_prefix)
    name = name[len(nanotron_prefix) :]

    path = name.split(".")
    path.remove("pp_block")
    name = ".".join(path)

    if get_grad is False:

        def get_tensor(path: str):
            return ref_module_state_dict[path]

        def get_tensors(path: List[str]):
            return [get_tensor(p) for p in path]

    else:

        def get_tensor(path: str):
            weight = ref_module.get_parameter(path)
            return weight.grad

        def get_tensors(path: List[str]):
            return [get_tensor(p) for p in path]

    if name == "token_position_embeddings.token_embedding.weight":
        return get_tensor("model.embed_tokens.weight")

    elif name == "lm_head.weight":
        # This is only used when the embedding and LM head weights are not shared
        return get_tensor("lm_head.weight")

    elif name == "final_layer_norm.weight":
        return get_tensor("model.norm.weight")

    if path[0] == "decoder":
        transformer_path = ["model"] + ["layers"] + [path[1]]

        if path[2] == "attn":
            path[2] = "self_attn"

        if path[2] == "block_sparse_moe":
            if path[3] == "gate":
                return get_tensor(".".join(transformer_path + path[2:4] + path[5:]))

            if path[3] == "experts":
                path.remove("mlp")
                path.remove("module")
                tensor_list = []
                for exp in range(config.num_local_experts):
                    weight = get_tensor(
                        ".".join(transformer_path + ["block_sparse_moe.experts"] + [str(exp)] + path[4:5] + ["weight"])
                    )
                    tensor_list.append(weight)
                # w1/w3 expert weights are stacked along dim 0, w2 along dim 1, then transposed
                return torch.cat(tensor_list, dim=0).T if "w2" not in name else torch.cat(tensor_list, dim=1).T

        if path[3] == "qkv_proj":
            proj_names = ["q_proj", "k_proj", "v_proj"]
            tensor_list = get_tensors(
                [".".join(transformer_path + path[2:3] + [proj_name] + path[4:]) for proj_name in proj_names]
            )
            # Permute q/k
            per_head_hidden_size = config.hidden_size // config.num_attention_heads
            # Permute q
            print(f"Permuting q {tensor_list[0].shape}")
            tensor_list[0] = permute_for_rotary(
                tensor=tensor_list[0],
                num_heads=config.num_attention_heads,
                per_head_hidden_size=per_head_hidden_size,
                hidden_size=config.hidden_size,
            )
            # Permute k
            print(f"Permuting k {tensor_list[1].shape}")
            tensor_list[1] = permute_for_rotary(
                tensor=tensor_list[1],
                num_heads=config.num_key_value_heads,
                per_head_hidden_size=per_head_hidden_size,
                hidden_size=config.hidden_size,
            )
            return torch.cat(tensor_list, dim=0)

        return get_tensor(".".join(transformer_path + path[2:]))

    else:
        raise ValueError(f"Couldn't find transformer equivalent of {name}")
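
# Illustrative name mappings resolved by get_transformers_weight (a sketch; the
# nanotron-side names assume the module layout used by mixtral.py in this example):
#   model.token_position_embeddings.pp_block.token_embedding.weight
#       -> model.embed_tokens.weight
#   model.decoder.0.pp_block.attn.qkv_proj.weight
#       -> q_proj/k_proj/v_proj of model.layers.0.self_attn, with q/k permuted for
#          rotary and the three tensors concatenated along dim 0
#   a block_sparse_moe expert weight such as w1
#       -> model.layers.0.block_sparse_moe.experts.{e}.w1.weight gathered over all
#          experts, concatenated, then transposed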


def initialize_nanotron_model(dtype, parallel_context, parallel_config, model_config):
    model = build_model(
        model_builder=lambda: MixtralForTraining(
            config=model_config,
            parallel_context=parallel_context,
            parallel_config=parallel_config,
            random_states=None,
        ),
        dtype=dtype,
        parallel_context=parallel_context,
        device=torch.device("cpu"),
    )
    return model


def fix_device_map_for_pp(model_config_nanotron, model, parallel_context):
    device_map = {}
    current_pp_rank = dist.get_rank(group=parallel_context.pp_pg)
    device_map["model.embed_tokens"] = (
        model.model.token_position_embeddings.rank
        if current_pp_rank == model.model.token_position_embeddings.rank
        else "meta"
    )
    for i in range(model_config_nanotron.num_hidden_layers):
        device_map[f"model.layers.{i}"] = (
            model.model.decoder[i].rank if current_pp_rank == model.model.decoder[i].rank else "meta"
        )
    device_map["model.norm"] = (
        model.model.final_layer_norm.rank if current_pp_rank == model.model.final_layer_norm.rank else "meta"
    )
    device_map["lm_head"] = model.model.lm_head.rank if current_pp_rank == model.model.lm_head.rank else "meta"
    return device_map
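
# With pp=1 every module is owned by pipeline rank 0, so the returned map looks like
# (a sketch): {"model.embed_tokens": 0, "model.layers.0": 0, ..., "model.norm": 0,
# "lm_head": 0}. Modules owned by another pipeline rank are mapped to "meta" so that
# from_pretrained() does not materialize their weights on this process.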


# lovely_tensors gives compact tensor summaries for the debug prints below
import lovely_tensors as lt

lt.monkey_patch()  # noqa


def convert_trfrs_to_nanotron(
    dp, pp, tp, model_name="hf-internal-testing/Mixtral-tiny", save_path="pretrained/Mixtral-7B-v0.1"
):
    # Check that save_path doesn't exist yet or is empty
    save_path = Path(save_path)
    if not save_path.exists():
        save_path.mkdir(parents=True, exist_ok=True)
    assert len(list(save_path.iterdir())) == 0, f"save_path {save_path} is not empty"

    parallel_config = PARALLELISM_NANOTRON

    parallel_config.dp = dp
    parallel_config.pp = pp
    parallel_config.tp = tp
    parallel_config.expert_parallel_size = 1

    # Initialise all process groups
    parallel_context = ParallelContext(
        data_parallel_size=parallel_config.dp,
        pipeline_parallel_size=parallel_config.pp,
        tensor_parallel_size=parallel_config.tp,
        expert_parallel_size=parallel_config.expert_parallel_size,
    )
    # params
    dtype = torch.bfloat16  # Flash attention doesn't support fp32

    # Initialise nanotron model
    nanotron_model_config = MixtralConfig.from_hf_config(model_name)
    model = initialize_nanotron_model(dtype, parallel_context, parallel_config, nanotron_model_config)

    # Initialise transformers model
    device_map = fix_device_map_for_pp(CONFIG_NANOTRON.model.model_config, model, parallel_context)
    model_ref = MixtralForCausalLM.from_pretrained(model_name, torch_dtype=dtype, device_map=device_map)
    print(model)
    print(model_ref)
    # Copy weights from trfrs to nanotron
    ref_state_dict = model_ref.state_dict()
    for name, param in model.named_parameters():
        print(f"Syncing {name}")
        ref_param = get_transformers_weight(name=name, ref_module_state_dict=ref_state_dict, ref_module=model_ref)

        param_is_tp_sharded = (
            isinstance(param, NanotronParameter)
            and param.is_sharded
            and parallel_context.world_ranks_to_pg[param.get_sharded_info().global_ranks] == parallel_context.tp_pg
        )

        if param_is_tp_sharded:
            # TP-sharded parameters hold only a slice of the full tensor, so copy slice by slice
            sharded_info = param.get_sharded_info()
            # copy param data (not just the reference)
            with torch.no_grad():
                for local_global_slices_pair in sharded_info.local_global_slices_pairs:
                    local_slices = local_global_slices_pair.local_slices
                    global_slices = local_global_slices_pair.global_slices
                    param[local_slices].copy_(ref_param[global_slices])
        else:
            assert (
                ref_param.shape == param.shape
            ), f"Parameter shapes don't match for {name}\n{ref_param.shape} != {param.shape}"
            # copy param data (not just the reference)
            with torch.no_grad():
                param.copy_(ref_param)
        ref_param = None
        # torch.cuda.empty_cache()

    # TODO @nouamanetazi: assert weights are the same
    # Mark tied parameters (e.g. shared embedding / lm_head weights) across parallel groups
    mark_tied_parameters(model=model, parallel_context=parallel_context, parallel_config=parallel_config)

    sanity_check(root_module=model)

    checkpoint_metadata = {
        "last_train_step": 0,
        "consumed_train_samples": 0,
    }
    save(
        config=CONFIG_NANOTRON,
        model=model,
        optimizer=None,
        lr_scheduler=None,
        parallel_context=parallel_context,
        root_folder=save_path,
        should_save_optimizer=False,
        should_save_lr_scheduler=False,
        checkpoint_metadata=checkpoint_metadata,
        sanity_checks=False,
    )

    if dist.get_rank(parallel_context.world_pg) == 0:
        with open(save_path / "model_config.json", mode="w") as fo:
            fo.write(json.dumps(asdict(CONFIG_NANOTRON.model.model_config), indent=4))

    print(f"Model saved to {save_path}")
    print("You can test the model by running the following command:")
    print(f"torchrun --nproc_per_node=1 run_generate.py --ckpt-path {save_path}")


def main():
    args = get_args()
    convert_trfrs_to_nanotron(**vars(args))


if __name__ == "__main__":
    main()