[MergeKit] add log #9948

Open · wants to merge 2 commits into base: develop
2 changes: 2 additions & 0 deletions paddlenlp/mergekit/merge_config.py
@@ -18,6 +18,7 @@
 from typing import List, Optional

 from paddlenlp.utils.env import MERGE_CONFIG_NAME
+from paddlenlp.utils.log import logger


 @dataclass
@@ -140,6 +141,7 @@ def save_pretrained(self, save_directory):
         # save it
         with open(output_path, "w") as writer:
             writer.write(json.dumps(output_dict, indent=2, sort_keys=True))
+        logger.info(f"Merge config file saved in {output_path}.")

     @classmethod
     def from_pretrained(cls, pretrained_model_path, **kwargs):
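
Outside the PaddleNLP tree, the pattern this hunk adds (write the JSON, then log the exact destination) looks like the following minimal sketch; it substitutes the standard logging module for paddlenlp.utils.log.logger, and the save_config helper name is hypothetical:

```python
import json
import logging
import os

logger = logging.getLogger(__name__)

def save_config(config: dict, save_directory: str, name: str = "merge_config.json") -> None:
    # Serialize deterministically (sorted keys) and record where the file went.
    os.makedirs(save_directory, exist_ok=True)
    output_path = os.path.join(save_directory, name)
    with open(output_path, "w") as writer:
        writer.write(json.dumps(config, indent=2, sort_keys=True))
    logger.info(f"Merge config file saved in {output_path}.")
```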
61 changes: 44 additions & 17 deletions paddlenlp/mergekit/merge_model.py
@@ -22,6 +22,7 @@
 import paddle.distributed as dist
 from safetensors import safe_open
 from safetensors.numpy import save_file
+from tqdm.auto import tqdm

 from paddlenlp.peft import LoRAConfig
 from paddlenlp.utils import device_guard
@@ -92,7 +93,8 @@
             self.mergekit()
         else:
             self.mergekit()
-        self.copy_file()
+        if paddle.distributed.get_rank() == 0:
+            self.copy_file()

     def copy_file(self):
         if self.merge_config.copy_file_list is not None:
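
Guarding copy_file behind rank 0 prevents every worker in a distributed launch from copying the same files into the same destination at once. A minimal sketch of the guard, assuming an initialized paddle.distributed environment (the copy_aux_files helper is illustrative):

```python
import os
import shutil

import paddle.distributed as dist

def copy_aux_files(src_path: str, dst_path: str, file_list: list) -> None:
    # Only rank 0 touches the destination; other ranks would otherwise race
    # to write identical files (tokenizer configs, vocab files, and so on).
    if dist.get_rank() != 0:
        return
    for name in file_list:
        src_file = os.path.join(src_path, name)
        if os.path.isfile(src_file):
            shutil.copy2(src_file, os.path.join(dst_path, name))
```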
@@ -106,7 +108,7 @@
                 if os.path.isfile(src_file):
                     shutil.copy2(src_file, dst_file)
                 else:
-                    logger.warning(f"Copy failed: {file} not found in {src_path}")
+                    logger.debug(f"Copy failed: {file} not found in {src_path}")

     def mergekit(self):
         # Check model file type
@@ -129,6 +131,7 @@
             state_dict_list.append(self.get_model_state_dict(model_path, file_type_list[i]))
         if self.merge_config.base_model_path is not None:
             state_dict_list.append(self.get_model_state_dict(self.merge_config.base_model_path, file_type_list[-1]))
+        logger.info("Loaded all model state dicts.")

         if not all(state_dict_list[0].keys() == state_dict.keys() for state_dict in state_dict_list):
             raise ValueError("State dict keys mismatch. Please make sure you load the correct weight file")
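
The all(...) check compares every loaded state dict against the first; if any checkpoint exposes a different set of tensor names, merging cannot proceed. A self-contained version of the same validation (names are illustrative):

```python
import numpy as np

def check_state_dict_keys(state_dict_list: list) -> None:
    # Every checkpoint being merged must expose exactly the same tensor names.
    if not all(state_dict_list[0].keys() == sd.keys() for sd in state_dict_list):
        raise ValueError("State dict keys mismatch. Please make sure you load the correct weight file")

a = {"w1": np.zeros(2), "w2": np.ones(2)}
b = {"w1": np.ones(2), "w2": np.zeros(2)}
check_state_dict_keys([a, b])  # passes; a third dict with extra keys would raise
```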
@@ -149,7 +152,7 @@
             index["metadata"]["total_size"] += int(
                 np.prod(state_dict_list[0][key].shape) * self.numpy_dtype_map[str(state_dict_list[0][key].dtype)]
             )
-        for key in local_keys:
+        for key in tqdm(local_keys, desc="Merging tensor"):
             # Tensor preprocess
             is_bf16 = str(state_dict_list[0][key].dtype) == "uint16"
             tensor_list = [state_dict_list[i].pop(key) for i in range(model_num)]
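
tqdm.auto picks a notebook or console progress bar automatically, so wrapping the key loop costs one import. A toy equivalent of the change (the state dict here is fabricated):

```python
import numpy as np
from tqdm.auto import tqdm

state_dict = {f"layer_{i}.weight": np.zeros((4, 4), dtype=np.float32) for i in range(100)}

for key in tqdm(list(state_dict.keys()), desc="Merging tensor"):
    tensor = state_dict.pop(key)  # pop releases each tensor once it is consumed
    # ... per-tensor merge logic runs here ...
```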
@@ -208,20 +211,24 @@
                 # dtype==bfloat16: numpy(float32) -> paddle(float32) -> paddle(bfloat16) -> numpy(uint16)
                 merge_state_dict[key] = paddle.Tensor(merge_tensor, zero_copy=True).astype("bfloat16").numpy()

+        logger.info("Merged tensors successfully.")
         # Save safetensor file
+        save_file_name = os.path.join(
+            self.merge_config.output_path,
+            f"{self.merge_config.merge_prefix}-{rank+1:05d}-of-{dist.get_world_size():05d}.safetensors",
+        )
         save_file(
             merge_state_dict,
-            os.path.join(
-                self.merge_config.output_path,
-                f"{self.merge_config.merge_prefix}-{rank+1:05d}-of-{dist.get_world_size():05d}.safetensors",
-            ),
+            save_file_name,
             metadata={"format": "np"},
         )
+        logger.info(f"Model weights saved in {save_file_name}.")
         # Save index file & merge config file
         if paddle.distributed.get_rank() == 0:
             save_index_file = os.path.join(self.merge_config.output_path, self.safe_index_name())
             with open(save_index_file, "w", encoding="utf-8") as f:
                 f.write(json.dumps(index, indent=2) + "\n")
+            logger.info(f"Model index file saved in {save_index_file}.")
             self.merge_config.save_pretrained(self.merge_config.output_path)

     def get_model_state_dict(self, model_path, file_type, key_list=None, file=None):
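
Hoisting the output path into save_file_name lets the same string be saved to and then logged, instead of being rebuilt twice. A sketch of the shard-naming convention and the safetensors call (the save_shard wrapper is illustrative; safetensors.numpy.save_file is the API the diff already uses):

```python
import os

import numpy as np
from safetensors.numpy import save_file

def save_shard(merge_state_dict: dict, output_path: str, prefix: str, rank: int, world_size: int) -> str:
    # Shards are 1-indexed per rank, e.g. model-00001-of-00008.safetensors.
    save_file_name = os.path.join(output_path, f"{prefix}-{rank + 1:05d}-of-{world_size:05d}.safetensors")
    save_file(merge_state_dict, save_file_name, metadata={"format": "np"})
    return save_file_name

# save_shard({"w": np.zeros((2, 2), dtype=np.float32)}, ".", "model", rank=0, world_size=1)
```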
@@ -325,7 +332,9 @@
                 file_map[index["weight_map"][key]] = [key]
             else:
                 file_map[index["weight_map"][key]].append(key)
+        logger.info(f"Merging file list: {file_list[positions[rank] : positions[rank + 1]]}")
         for shard_file in file_list[positions[rank] : positions[rank + 1]]:
+            logger.info(f"Start merging tensors in {shard_file}")
             if self.merge_config.tensor_type == "np":
                 self.shard_merge_np(file_map[shard_file], index_list, shard_file)
             else:
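
The positions list comes from divide_positions, which cuts the shard list into near-equal contiguous slices, one per rank, so each process merges a disjoint set of files. A hypothetical reimplementation with the same contract (the real helper lives in PaddleNLP's utilities; this one is for illustration only):

```python
def divide_positions(n_items: int, n_ranks: int) -> list:
    # Boundary indices such that rank r owns items[positions[r] : positions[r + 1]].
    base, extra = divmod(n_items, n_ranks)
    positions = [0]
    for r in range(n_ranks):
        positions.append(positions[-1] + base + (1 if r < extra else 0))
    return positions

files = [f"model-{i + 1:05d}-of-00005.safetensors" for i in range(5)]
positions = divide_positions(len(files), 2)  # [0, 3, 5]
# rank 0 merges files[0:3], rank 1 merges files[3:5]
```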
@@ -373,6 +382,7 @@
             save_index_file = os.path.join(self.merge_config.output_path, self.safe_index_name())
             with open(save_index_file, "w", encoding="utf-8") as f:
                 f.write(json.dumps(index, indent=2) + "\n")
+            logger.info(f"Model index file saved in {save_index_file}.")

     def shard_merge_np(
         self,
@@ -419,7 +429,7 @@
         shard_file,
     ):
         merge_state_dict = {}
-        for k in key_list:
+        for k in tqdm(key_list, desc="Merging tensor"):
             tensor_list = []
             for i, model_path in enumerate(self.merge_config.model_path_list):
                 with fast_safe_open(os.path.join(model_path, index_list[i]["weight_map"][k]), framework="np") as w:
@@ -473,11 +483,14 @@
                 merge_state_dict[k] = merge_tensor.astype("bfloat16").numpy()
             else:
                 merge_state_dict[k] = merge_tensor.numpy()
+        logger.info("Merged tensors successfully.")
+        save_file_name = os.path.join(self.merge_config.output_path, shard_file)
         save_file(
             merge_state_dict,
-            os.path.join(self.merge_config.output_path, shard_file),
+            save_file_name,
             metadata={"format": "np"},
         )
+        logger.info(f"Model weights saved in {save_file_name}.")

     def check_model_path(self, model_path, lora_merge=False):
         if os.path.exists(os.path.join(model_path, self.safe_index_name())):
@@ -534,17 +547,19 @@

     def shard_lora_merge(self, base_index, shard_file, lora_config, file_type_list, key_list=None, file=None):
         merge_state_dict = {}
+        lora_state_dict = self.get_model_state_dict(self.merge_config.lora_model_path, file_type_list[0])
+        logger.info("Loaded LoRA weights successfully.")
         base_state_dict = self.get_model_state_dict(
             self.merge_config.base_model_path, file_type_list[1], key_list=key_list, file=file
         )
-        lora_state_dict = self.get_model_state_dict(self.merge_config.lora_model_path, file_type_list[0])
+        logger.info("Loaded model weights successfully.")
         if not lora_config.rslora:
             scaling = lora_config.lora_alpha / lora_config.r
         else:
             scaling = lora_config.lora_alpha / math.sqrt(lora_config.r)

         model_key_list = list(base_state_dict.keys())
-        for k in model_key_list:
+        for k in tqdm(model_key_list, desc="Merging tensor"):
             if lora_state_dict is not None and k in lora_state_dict.keys():
                 tensor = lora_state_dict.pop(k)
             else:
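
The scaling factor follows the LoRA literature: alpha/r for standard LoRA, and alpha/sqrt(r) for rsLoRA, which keeps the update magnitude stable as the rank grows; each base weight then absorbs A @ B times that factor. A numpy sketch of the merge step with toy shapes (the function is illustrative, but the formula mirrors the diff):

```python
import math

import numpy as np

def merge_lora_weight(base, lora_A, lora_B, lora_alpha: float, r: int, rslora: bool = False):
    # Standard LoRA: alpha / r.  rsLoRA: alpha / sqrt(r).
    scaling = lora_alpha / (math.sqrt(r) if rslora else r)
    return base + lora_A @ lora_B * scaling

base = np.zeros((16, 16), dtype=np.float32)
lora_A = np.random.randn(16, 8).astype(np.float32)  # rank r = 8
lora_B = np.random.randn(8, 16).astype(np.float32)
merged = merge_lora_weight(base, lora_A, lora_B, lora_alpha=16.0, r=8)
```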
@@ -568,11 +583,15 @@
                     tensor += lora_A_tensor @ lora_B_tensor * scaling
                 tensor = tensor.numpy()
             merge_state_dict[k] = tensor
-
+        logger.info("Merged tensors successfully.")
+        save_file_name = os.path.join(self.merge_config.output_path, shard_file)
         save_file(
             merge_state_dict,
-            os.path.join(self.merge_config.output_path, shard_file),
+            save_file_name,
             metadata={"format": "np"},
         )
+        logger.info(f"Model weights saved in {save_file_name}.")

     def merge_safetensor_lora_model(self, file_type_list):
         # Load index
@@ -592,7 +611,9 @@
         file_list = sorted(list(set(base_index["weight_map"].values())))
         if file_type_list[-1] == "safetensors" and len(file_list) >= dist.get_world_size():
             positions = divide_positions(len(file_list), dist.get_world_size())
+            logger.info(f"Merging file list: {file_list[positions[rank] : positions[rank + 1]]}")
             for shard_file in file_list[positions[rank] : positions[rank + 1]]:
+                logger.info(f"Start merging tensors in {shard_file}")
                 self.shard_lora_merge(base_index, shard_file, lora_config, file_type_list, file=shard_file)
             index["weight_map"] = base_index["weight_map"]
         else:
@@ -639,12 +660,15 @@
             save_index_file = os.path.join(self.merge_config.output_path, self.safe_index_name())
             with open(save_index_file, "w", encoding="utf-8") as f:
                 f.write(json.dumps(index, indent=2) + "\n")
+            logger.info(f"Model index file saved in {save_index_file}.")
             self.merge_config.save_pretrained(self.merge_config.output_path)

     def merge_pdparams_lora_model(self, file_type_list):
         # Load & check state dict
         lora_state_dict = self.get_model_state_dict(self.merge_config.lora_model_path, file_type_list[0])
+        logger.info("Loaded LoRA weights successfully.")
         base_state_dict = self.get_model_state_dict(self.merge_config.base_model_path, file_type_list[1])
+        logger.info("Loaded model weights successfully.")
         for key in lora_state_dict.keys():
             if "lora_A" in key:
                 if key.replace("lora_A", "lora_B") not in lora_state_dict.keys():
@@ -675,7 +699,7 @@
         # Merge state dict
         rank = dist.get_rank()
         local_keys = key_list[positions[rank] : positions[rank + 1]]
-        for k in local_keys:
+        for k in tqdm(local_keys, desc="Merging tensor"):
             if k in lora_state_dict.keys():
                 tensor = lora_state_dict[k]
             else:
@@ -701,17 +725,20 @@
             merge_state_dict[k] = tensor

         # Save safetensor file
+        save_file_name = os.path.join(
+            self.merge_config.output_path,
+            f"{self.merge_config.merge_prefix}-{rank+1:05d}-of-{dist.get_world_size():05d}.safetensors",
+        )
         save_file(
             merge_state_dict,
-            os.path.join(
-                self.merge_config.output_path,
-                f"{self.merge_config.merge_prefix}-{rank+1:05d}-of-{dist.get_world_size():05d}.safetensors",
-            ),
+            save_file_name,
             metadata={"format": "np"},
         )
+        logger.info(f"Model weights saved in {save_file_name}.")
         # Save index file & merge config file
         if paddle.distributed.get_rank() == 0:
             save_index_file = os.path.join(self.merge_config.output_path, self.safe_index_name())
             with open(save_index_file, "w", encoding="utf-8") as f:
                 f.write(json.dumps(index, indent=2) + "\n")
+            logger.info(f"Model index file saved in {save_index_file}.")
             self.merge_config.save_pretrained(self.merge_config.output_path)
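
For completeness, the rank-0 index write produces a standard safetensors index: a JSON file mapping each tensor name to the shard that holds it, plus the total parameter byte size. A sketch of the output format (the file name and contents are illustrative of the sharding convention, not copied from the module):

```python
import json
import os

def write_index(output_path: str, weight_map: dict, total_size: int) -> str:
    # weight_map: tensor name -> shard file name; total_size: bytes across all shards.
    index = {"metadata": {"total_size": total_size}, "weight_map": weight_map}
    save_index_file = os.path.join(output_path, "model.safetensors.index.json")
    with open(save_index_file, "w", encoding="utf-8") as f:
        f.write(json.dumps(index, indent=2) + "\n")
    return save_index_file

# write_index(".", {"w": "model-00001-of-00001.safetensors"}, total_size=16)
```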