Commit ad749d3

committed Mar 11, 2024
convert trfrs to nanotron
1 parent 2568943 commit ad749d3

2 files changed (+285, −1 lines)

.pre-commit-config.yaml (+1, −1)
@@ -33,4 +33,4 @@ repos:
   - id: codespell
     args:
       - -w
-      - --ignore-words-list=nd,reacher,thist,ths,magent,ba,fo
+      - --ignore-words-list=nd,reacher,thist,ths,magent,ba,fo,doesnt
convert_trfrs_to_nanotron.py (new file, +284)
@@ -0,0 +1,284 @@
# ruff: noqa: E402
"""
This module converts transformers Mixtral weights to nanotron weights.

Command:
    torchrun --nproc_per_node=1 convert_trfrs_to_nanotron.py \
        --model_name hf-internal-testing/Mixtral-tiny \
        --save_path ./pretrained/mixtral
"""
import argparse
import sys
from dataclasses import asdict
from pathlib import Path
from typing import Dict, List

import torch
from config_mixtral import MixtralConfig

sys.path.append(Path(__file__).parent.parent.as_posix())

import json

import nanotron.distributed as dist
from config_mixtral_tiny import CONFIG as CONFIG_NANOTRON
from config_mixtral_tiny import PARALLELISM as PARALLELISM_NANOTRON
from mixtral import MixtralForTraining
from nanotron.models import build_model
from nanotron.parallel.context import ParallelContext
from nanotron.parallel.parameters import NanotronParameter, sanity_check
from nanotron.serialize import save
from nanotron.trainer import mark_tied_parameters
from transformers import MixtralForCausalLM


def get_args():
    parser = argparse.ArgumentParser(description="Convert transformers weights to nanotron weights")
    parser.add_argument("--model_name", type=str, default="hf-internal-testing/Mixtral-tiny")
    parser.add_argument("--save_path", type=str, default="pretrained/Mixtral-7B-v0.1")
    parser.add_argument("--dp", type=int, default=1)
    parser.add_argument("--pp", type=int, default=1)
    parser.add_argument("--tp", type=int, default=1)
    return parser.parse_args()


def permute_for_rotary(tensor, num_heads, per_head_hidden_size, hidden_size):
    return (
        tensor.view(num_heads, 2, per_head_hidden_size // 2, hidden_size)
        .transpose(1, 2)
        .contiguous()
        .view(num_heads * per_head_hidden_size, hidden_size)
    )


def get_transformers_weight(
    name: str, ref_module_state_dict: Dict[str, torch.Tensor], ref_module: MixtralForCausalLM, get_grad: bool = False
) -> torch.Tensor:
    """From our nanotron implementation, we get the equivalent tensor in the transformers implementation"""
    config = ref_module.config
    nanotron_prefix = "model."
    assert name.startswith(nanotron_prefix)
    name = name[len(nanotron_prefix) :]

    path = name.split(".")
    path.remove("pp_block")
    name = ".".join(path)

    if get_grad is False:

        def get_tensor(path: str):
            return ref_module_state_dict[path]

        def get_tensors(path: List[str]):
            return [get_tensor(p) for p in path]

    else:

        def get_tensor(path: str):
            weight = ref_module.get_parameter(path)
            return weight.grad

        def get_tensors(path: List[str]):
            return [get_tensor(p) for p in path]

    if name == "token_position_embeddings.token_embedding.weight":
        return get_tensor("model.embed_tokens.weight")

    elif name == "lm_head.weight":
        # This is only used when weights are not shared
        return get_tensor("lm_head.weight")

    elif name == "final_layer_norm.weight":
        return get_tensor("model.norm.weight")

    if path[0] == "decoder":
        transformer_path = ["model"] + ["layers"] + [path[1]]

        if path[2] == "attn":
            path[2] = "self_attn"

        if path[2] == "block_sparse_moe":
            if path[3] == "gate":
                return get_tensor(".".join(transformer_path + path[2:4] + path[5:]))

            if path[3] == "experts":
                path.remove("mlp")
                path.remove("module")
                tensor_list = []
                for exp in range(config.num_local_experts):
                    weight = get_tensor(
                        ".".join(transformer_path + ["block_sparse_moe.experts"] + [str(exp)] + path[4:5] + ["weight"])
                    )
                    tensor_list.append(weight)
                return torch.cat(tensor_list, dim=0).T if "w2" not in name else torch.cat(tensor_list, dim=1).T

        if path[3] == "qkv_proj":
            proj_names = ["q_proj", "k_proj", "v_proj"]
            tensor_list = get_tensors(
                [".".join(transformer_path + path[2:3] + [proj_name] + path[4:]) for proj_name in proj_names]
            )
            # Permute q/k
            per_head_hidden_size = config.hidden_size // config.num_attention_heads
            # Permute q
            print(f"Permuting q {tensor_list[0].shape}")
            tensor_list[0] = permute_for_rotary(
                tensor=tensor_list[0],
                num_heads=config.num_attention_heads,
                per_head_hidden_size=per_head_hidden_size,
                hidden_size=config.hidden_size,
            )
            # Permute k
            print(f"Permuting k {tensor_list[1].shape}")
            tensor_list[1] = permute_for_rotary(
                tensor=tensor_list[1],
                num_heads=config.num_key_value_heads,
                per_head_hidden_size=per_head_hidden_size,
                hidden_size=config.hidden_size,
            )
            return torch.cat(tensor_list, dim=0)

        return get_tensor(".".join(transformer_path + path[2:]))

    else:
        raise ValueError(f"Couldn't find transformer equivalent of {name}")


def initialize_nanotron_model(dtype, parallel_context, parallel_config, model_config):
    model = build_model(
        model_builder=lambda: MixtralForTraining(
            config=model_config,
            parallel_context=parallel_context,
            parallel_config=parallel_config,
            random_states=None,
        ),
        dtype=dtype,
        parallel_context=parallel_context,
        device=torch.device("cpu"),
    )
    return model


def fix_device_map_for_pp(model_config_nanotron, model, parallel_context):
    device_map = {}
    current_pp_rank = dist.get_rank(group=parallel_context.pp_pg)
    device_map["model.embed_tokens"] = (
        model.model.token_position_embeddings.rank
        if current_pp_rank == model.model.token_position_embeddings.rank
        else "meta"
    )
    for i in range(model_config_nanotron.num_hidden_layers):
        device_map[f"model.layers.{i}"] = (
            model.model.decoder[i].rank if current_pp_rank == model.model.decoder[i].rank else "meta"
        )
    device_map["model.norm"] = (
        model.model.final_layer_norm.rank if current_pp_rank == model.model.final_layer_norm.rank else "meta"
    )
    device_map["lm_head"] = model.model.lm_head.rank if current_pp_rank == model.model.lm_head.rank else "meta"
    return device_map


import lovely_tensors as lt

lt.monkey_patch()  # noqa


def convert_trfrs_to_nanotron(dp, pp, tp, model_name="huggyllama/llama-7b", save_path="pretrained/llama-7b"):
    # check save_path doesnt exist or is empty
    save_path = Path(save_path)
    if not save_path.exists():
        save_path.mkdir(parents=True, exist_ok=True)
    assert len(list(save_path.iterdir())) == 0, f"save_path {save_path} is not empty"

    parallel_config = PARALLELISM_NANOTRON

    parallel_config.dp = dp
    parallel_config.pp = pp
    parallel_config.tp = tp
    parallel_config.expert_parallel_size = 1

    # Initialise all process groups
    parallel_context = ParallelContext(
        data_parallel_size=parallel_config.dp,
        pipeline_parallel_size=parallel_config.pp,
        tensor_parallel_size=parallel_config.tp,
        expert_parallel_size=parallel_config.expert_parallel_size,
    )
    # params
    dtype = torch.bfloat16  # Flash attention doesn't support fp32

    # Initialise nanotron model
    nanotron_model_config = MixtralConfig.from_hf_config(model_name)
    model = initialize_nanotron_model(dtype, parallel_context, parallel_config, nanotron_model_config)

    # Initialise transformers model
    device_map = fix_device_map_for_pp(CONFIG_NANOTRON.model.model_config, model, parallel_context)
    model_ref = MixtralForCausalLM.from_pretrained(model_name, torch_dtype=dtype, device_map=device_map)
    print(model)
    print(model_ref)
    # Copy weights from trfrs to nanotron
    ref_state_dict = model_ref.state_dict()
    for name, param in model.named_parameters():
        print(f"Syncing {name}")
        ref_param = get_transformers_weight(name=name, ref_module_state_dict=ref_state_dict, ref_module=model_ref)

        param_is_tp_sharded = (
            isinstance(param, NanotronParameter)
            and param.is_sharded
            and parallel_context.world_ranks_to_pg[param.get_sharded_info().global_ranks] == parallel_context.tp_pg
        )

        if param_is_tp_sharded:
            sharded_info = param.get_sharded_info()
            # copy param data (not just the reference)
            with torch.no_grad():
                for local_global_slices_pair in sharded_info.local_global_slices_pairs:
                    local_slices = local_global_slices_pair.local_slices
                    global_slices = local_global_slices_pair.global_slices
                    param[local_slices].copy_(ref_param[global_slices])
        else:
            assert (
                ref_param.shape == param.shape
            ), f"Parameter shapes don't match for {name}\n{ref_param.shape} != {param.shape}"
            # copy param data (not just the reference)
            with torch.no_grad():
                param.copy_(ref_param)
        ref_param = None
        # torch.cuda.empty_cache()

    # TODO @nouamanetazi: assert weights are the same
    # Marks parameters as NanotronParameters
    mark_tied_parameters(model=model, parallel_context=parallel_context, parallel_config=parallel_config)

    sanity_check(root_module=model)

    checkpoint_metadata = {
        "last_train_step": 0,
        "consumed_train_samples": 0,
    }
    save(
        config=CONFIG_NANOTRON,
        model=model,
        optimizer=None,
        lr_scheduler=None,
        parallel_context=parallel_context,
        root_folder=save_path,
        should_save_optimizer=False,
        should_save_lr_scheduler=False,
        checkpoint_metadata=checkpoint_metadata,
        sanity_checks=False,
    )

    if dist.get_rank(parallel_context.world_pg) == 0:
        with open(save_path / "model_config.json", mode="w") as fo:
            fo.write(json.dumps(asdict(CONFIG_NANOTRON.model.model_config), indent=4))

    print(f"Model saved to {save_path}")
    print("You can test the model by running the following command:")
    print(f"torchrun --nproc_per_node=1 run_generate.py --ckpt-path {save_path}")


def main():
    args = get_args()
    convert_trfrs_to_nanotron(**vars(args))


if __name__ == "__main__":
    main()
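A note on the name mapping: get_transformers_weight translates nanotron parameter names into their transformers counterparts by stripping the "model." prefix and the "pp_block" wrapper, renaming "attn" to "self_attn", looking up the fused qkv_proj as the concatenation of q/k/v, and stacking the per-expert MoE weights. The sketch below lists a few example pairs; the exact nanotron names and the 8-expert count are inferred from that logic, not printed by the script.

# Illustrative only: example nanotron -> transformers name pairs, inferred
# from get_transformers_weight above (the names are not emitted by the script).
EXAMPLE_NAME_MAP = {
    "model.token_position_embeddings.pp_block.token_embedding.weight": "model.embed_tokens.weight",
    "model.final_layer_norm.pp_block.weight": "model.norm.weight",
    # qkv_proj is built by concatenating q/k/v (q and k rows permuted for rotary):
    "model.decoder.0.pp_block.attn.qkv_proj.weight": [
        "model.layers.0.self_attn.q_proj.weight",
        "model.layers.0.self_attn.k_proj.weight",
        "model.layers.0.self_attn.v_proj.weight",
    ],
    # expert weights are stacked over num_local_experts (assumed 8 here), then transposed:
    "model.decoder.0.pp_block.block_sparse_moe.experts.mlp.w1.module.weight": [
        f"model.layers.0.block_sparse_moe.experts.{e}.w1.weight" for e in range(8)
    ],
}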

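The rotary permutation applied to q and k can be checked in isolation. A minimal sketch with tiny made-up shapes (1 head, head size 4, hidden size 1), mirroring the view/transpose chain in permute_for_rotary:

# Illustrative sketch, not part of the commit: show how permute_for_rotary
# reorders rows, turning the half-split rotary layout into interleaved pairs.
import torch

w = torch.arange(4, dtype=torch.float32).view(4, 1)  # rows 0, 1, 2, 3
permuted = (
    w.view(1, 2, 2, 1)   # (num_heads, 2, per_head_hidden_size // 2, hidden_size)
    .transpose(1, 2)     # (num_heads, per_head_hidden_size // 2, 2, hidden_size)
    .contiguous()
    .view(4, 1)          # back to (num_heads * per_head_hidden_size, hidden_size)
)
print(permuted.flatten().tolist())  # [0.0, 2.0, 1.0, 3.0]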
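For tensor-parallel parameters, the copy loop writes this rank's global slice of the reference weight into its local shard. A toy sketch of that pattern, assuming a hypothetical 1-D weight split across tp=2 with this rank owning the second half:

# Toy sketch, not part of the commit: the slice-copy pattern used for
# tensor-parallel-sharded NanotronParameters, with made-up slices.
import torch

full = torch.arange(8.0)                                # "ref_param" from transformers
local = torch.zeros(4)                                  # this rank's shard
local_slices, global_slices = slice(0, 4), slice(4, 8)  # rank 1 owns elements 4..7
with torch.no_grad():
    local[local_slices].copy_(full[global_slices])
print(local.tolist())  # [4.0, 5.0, 6.0, 7.0]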