32 changes: 32 additions & 0 deletions ais_bench/benchmark/configs/models/mf_models/mf_model.py
@@ -0,0 +1,32 @@
from ais_bench.benchmark.models import MindFormerModel

models = [
dict(
attr="local", # local or service
type=MindFormerModel, # use this for transformers < 4.33.0; the model is loaded with AutoModelForCausalLM.from_pretrained first, falling back to AutoModel.from_pretrained on failure
abbr='mindformer-model',
path='THUDM/chatglm-6b', # path to the model dir; the current value is just an example
checkpoint='THUDM/your_checkpoint', # path to the checkpoint file; the current value is just an example
yaml_cfg_file='THUDM/your.yaml', # path to the MindFormers YAML config file; the current value is just an example
tokenizer_path='THUDM/chatglm-6b', # path to the tokenizer dir; the current value is just an example
model_kwargs=dict( # model kwargs, see huggingface.co/docs/transformers/v4.50.0/en/model_doc/auto#transformers.AutoModel.from_pretrained
device_map='npu',
),
tokenizer_kwargs=dict( # tokenizer kwargs, see huggingface.co/docs/transformers/v4.50.0/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase
padding_side='right',
),
generation_kwargs=dict( # generation (post-processing) kwargs, see huggingface.co/docs/transformers/main_classes/text_generation
temperature = 0.5,
top_k = 10,
top_p = 0.95,
do_sample = True,
seed = None,
repetition_penalty = 1.03,
),
run_cfg = dict(num_gpus=1, num_procs=1), # multi-card / multi-node multi-card settings; the task is launched with torchrun

medium

The comment here mentions torchrun, but for MindFormer models the task is actually launched with msrun (see the implementation in openicl_infer.py). This is misleading; consider updating the comment to msrun, or to a more generic description of the launcher that is actually used.

Suggested change
run_cfg = dict(num_gpus=1, num_procs=1), # multi-card / multi-node multi-card settings; the task is launched with torchrun
run_cfg = dict(num_gpus=1, num_procs=1), # multi-card / multi-node multi-card settings; the task is launched with msrun

max_out_len=100, # maximum number of output tokens
batch_size=2, # batch size for each inference step
max_seq_len=2048,
batch_padding=True,
)
]
3 changes: 2 additions & 1 deletion ais_bench/benchmark/models/__init__.py
@@ -14,4 +14,5 @@
from ais_bench.benchmark.models.api_models.triton_api import TritonCustomAPIStream # noqa: F401
from ais_bench.benchmark.models.api_models.tgi_api import TGICustomAPIStream # noqa: F401
from ais_bench.benchmark.models.api_models.vllm_custom_api_chat import VllmMultiturnAPIChatStream # noqa: F401
from ais_bench.benchmark.models.local_models.vllm_offline_vl import VLLMOfflineVLModel
from ais_bench.benchmark.models.local_models.vllm_offline_vl import VLLMOfflineVLModel
from ais_bench.benchmark.models.local_models.mindformers_model import MindFormerModel
306 changes: 306 additions & 0 deletions ais_bench/benchmark/models/local_models/mindformers_model.py
@@ -0,0 +1,306 @@
import os, sys

medium

To improve readability, consider following the PEP 8 style guide and importing each module on its own line.

Suggested change
import os, sys
import os
import sys

from typing import Dict, List, Optional, Union

import numpy as np
import torch

medium

torch is imported but never used anywhere in this file. To keep the code clean and avoid an unnecessary dependency, consider removing the unused import.

import transformers

from ais_bench.benchmark.models.base import BaseModel
from ais_bench.benchmark.models.base_api import APITemplateParser
from ais_bench.benchmark.registry import MODELS
from ais_bench.benchmark.utils.logging import get_logger
from ais_bench.benchmark.utils.prompt import PromptList

from mindspore import Tensor, Model
from mindformers import MindFormerConfig, build_context
from mindformers.models import build_network
from mindformers.core.parallel_config import build_parallel_config
from mindformers.utils.load_checkpoint_utils import get_load_path_after_hf_convert
from mindformers.trainer.utils import transform_and_load_checkpoint

PromptType = Union[PromptList, str, dict]


class MultiTokenEOSCriteria(transformers.StoppingCriteria):
"""Criteria to stop on the specified multi-token sequence."""

def __init__(
self,
sequence: str,
tokenizer: transformers.PreTrainedTokenizer,
batch_size: int,
):
self.done_tracker = [False] * batch_size
self.sequence = sequence
self.sequence_ids = tokenizer.encode(sequence,
add_special_tokens=False)
self.sequence_id_len = len(self.sequence_ids)
self.tokenizer = tokenizer

def __call__(self, input_ids, scores, **kwargs) -> bool:
# compare the last len(stop) tokens
lookback_ids_batch = input_ids[:, -self.sequence_id_len:]
lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
for i, done in enumerate(self.done_tracker):
if done:
continue
self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
return False not in self.done_tracker


def drop_error_generation_kwargs(generation_kwargs: dict) -> dict:
for key in ['is_synthetic', 'batch_size', 'do_performance']:
if key in generation_kwargs:
generation_kwargs.pop(key)
return generation_kwargs


@MODELS.register_module()
class MindFormerModel(BaseModel):

def __init__(self,
path: str,
checkpoint: Optional[str] = None,
yaml_cfg_file: Optional[str] = None,
batch_size: int = 1,
max_seq_len: int = 2048,
tokenizer_path: Optional[str] = None,
tokenizer_kwargs: dict = dict(),
tokenizer_only: bool = False,
generation_kwargs: dict = dict(),
meta_template: Optional[Dict] = None,
extract_pred_after_decode: bool = False,
batch_padding: bool = False,
pad_token_id: Optional[int] = None,
mode: str = 'none',
use_fastchat_template: bool = False,
end_str: Optional[str] = None,
**kwargs):
super().__init__(path=path,
max_seq_len=max_seq_len,
tokenizer_only=tokenizer_only,
meta_template=meta_template)
self.logger = get_logger()
self.batch_size = batch_size
self.pad_token_id = pad_token_id
self.pretrained_model_path = path
if mode not in ['none', 'mid']:
raise ValueError(f"mode must be 'none' or 'mid', but got {mode}")
self.mode = mode
if not yaml_cfg_file:
raise ValueError('`yaml_cfg_file` is required for MindFormerModel')
self.config = MindFormerConfig(yaml_cfg_file)

high

yaml_cfg_fileNone 时,MindFormerConfig(yaml_cfg_file) 可能会失败。__init__ 函数签名允许 yaml_cfg_fileNone,因此在调用 MindFormerConfig 之前应该添加一个检查,以确保 yaml_cfg_file 不是 None,或者在 yaml_cfg_fileNone 时进行适当的错误处理。

For example:

if not yaml_cfg_file:
    raise ValueError('`yaml_cfg_file` is required for MindFormerModel')
self.config = MindFormerConfig(yaml_cfg_file)

self.checkpoint = checkpoint
self._load_tokenizer(path=path,
tokenizer_path=tokenizer_path,
tokenizer_kwargs=tokenizer_kwargs)
self.batch_padding = batch_padding
self.extract_pred_after_decode = extract_pred_after_decode
if not tokenizer_only:
self._load_model(self.config, self.batch_size, self.max_seq_len)
self.generation_kwargs = generation_kwargs
self.use_fastchat_template = use_fastchat_template
self.end_str = end_str

def _load_tokenizer(self, path: str, tokenizer_path: Optional[str],
tokenizer_kwargs: dict):
from transformers import AutoTokenizer, GenerationConfig

DEFAULT_TOKENIZER_KWARGS = dict(padding_side='left', truncation_side='left', trust_remote_code=True)

medium

DEFAULT_TOKENIZER_KWARGS is a constant and is better defined at class level (inside MindFormerModel but outside __init__) rather than inside the _load_tokenizer method. This improves readability and maintainability and makes its scope explicit.

For example:

class MindFormerModel(BaseModel):
    DEFAULT_TOKENIZER_KWARGS = dict(...)

    def __init__(self, ...):
        ...

    def _load_tokenizer(self, ...):
        kwargs = self.DEFAULT_TOKENIZER_KWARGS.copy()
        ...

kwargs = DEFAULT_TOKENIZER_KWARGS.copy()
kwargs.update(tokenizer_kwargs)

load_path = tokenizer_path if tokenizer_path else path
self.tokenizer = AutoTokenizer.from_pretrained(load_path, **kwargs)

pad_token_id = self.pad_token_id

# A patch for some models without pad_token_id
if pad_token_id is not None:
if self.tokenizer.pad_token_id is None:
self.logger.debug(f'Using {pad_token_id} as pad_token_id')
elif self.tokenizer.pad_token_id != pad_token_id:
self.logger.warning(f'pad_token_id is not consistent. Using {pad_token_id} as pad_token_id')
self.tokenizer.pad_token_id = pad_token_id
return
if self.tokenizer.pad_token_id is not None:
return
self.logger.warning('pad_token_id is not set for the tokenizer.')

try:
generation_config = GenerationConfig.from_pretrained(path)
except Exception:

medium

Using except Exception: is too broad; it catches every kind of exception and can mask unexpected errors. Consider catching more specific exceptions, such as OSError or the specific exceptions transformers raises when loading fails, to make the code more robust and maintainable. A sketch is shown below.
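For example (a sketch only, assuming OSError and ValueError cover the expected failure modes of GenerationConfig.from_pretrained here; adjust to the exceptions actually observed):

try:
    generation_config = GenerationConfig.from_pretrained(path)
except (OSError, ValueError):  # config file missing or unparsable; anything else still propagates
    generation_config = None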

generation_config = None

if generation_config and generation_config.pad_token_id is not None:
self.logger.warning(f'Using {generation_config.pad_token_id} as pad_token_id.')
self.tokenizer.pad_token_id = generation_config.pad_token_id
return
if self.tokenizer.eos_token_id is not None:
self.logger.warning(f'Using eos_token_id {self.tokenizer.eos_token_id} as pad_token_id.')
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
return
raise ValueError('pad_token_id is not set for this tokenizer. Please set `pad_token_id={PAD_TOKEN_ID}` in model_cfg.')

def _set_config_from_yaml(self):
if self.checkpoint is not None:
self.config.load_checkpoint = self.checkpoint
elif self.checkpoint is None and self.config.load_checkpoint is None:
self.config.load_checkpoint = self.path
self.config.model.pretrained_model_dir = self.pretrained_model_path
self.config.model.model_config.seq_length = self.max_seq_len
build_context(self.config)
build_parallel_config(self.config)

def _load_model(self, config, batch_size, max_seq_len):

self._set_config_from_yaml()
try:
self.model = build_network(
config.model,
default_args={
"parallel_config": config.parallel_config,
"moe_config": config.moe_config
})
self.logger.info("..........Network Built Successfully..........")
self.model.set_train(False)
config.load_checkpoint = get_load_path_after_hf_convert(config, self.model)
self.logger.info(f"load checkpoint path : {config.load_checkpoint}")
run_mode = config.get("run_mode", None)
if run_mode == "predict":
self.model.load_weights(config.load_checkpoint)
else:
model = Model(self.model)
input_ids = Tensor(np.ones((batch_size, max_seq_len), dtype=np.int32))
infer_data = self.model.prepare_inputs_for_predict_layout(input_ids)
transform_and_load_checkpoint(config, model, self.model, infer_data, do_eval=True)

critical

The function transform_and_load_checkpoint is called here, but it is neither defined in this file nor imported, which would cause a NameError at runtime. Please import it from mindformers.checkpoint.checkpoint.

Suggested change
transform_and_load_checkpoint(config, model, self.model, infer_data, do_eval=True)
from mindformers.checkpoint.checkpoint import transform_and_load_checkpoint
transform_and_load_checkpoint(config, model, self.model, infer_data, do_eval=True)


self.logger.info("..........Checkpoint Load Successfully..........")
except ValueError as e:
raise ValueError('Failed to load MindFormers model, please check configuration') from e


def generate(self,
inputs: List[str],
max_out_len: int,
min_out_len: Optional[int] = None,
stopping_criteria: List[str] = [],
**kwargs) -> List[str]:
"""Generate results given a list of inputs.

Args:
inputs (List[str]): A list of strings.
max_out_len (int): The maximum length of the output.
min_out_len (Optional[int]): The minimum length of the output.

Returns:
List[str]: A list of generated strings.
"""
generation_kwargs = kwargs.copy()
generation_kwargs.update(self.generation_kwargs)

messages = list(inputs)
batch_size = len(messages)
prompt_char_lens = None

if self.extract_pred_after_decode:
prompt_char_lens = [len(text) for text in messages]

if self.use_fastchat_template:
try:
from fastchat.model import get_conversation_template
except ModuleNotFoundError:
raise ModuleNotFoundError(
'Fastchat is not implemented. You can use '
"'pip install \"fschat[model_worker,webui]\"' "
'to implement fastchat.')
for idx, text in enumerate(messages):
conv = get_conversation_template('vicuna')

medium

The conversation template is hard-coded to 'vicuna' here. This limits flexibility, since different models may need different conversation templates. Consider making the template name configurable, e.g. via generation_kwargs or a new __init__ parameter, so that a wider range of models is supported; a sketch follows.
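For example (a sketch only; fastchat_template is a hypothetical new parameter, not part of this PR):

# in __init__, with a default that preserves the current behaviour
self.fastchat_template = kwargs.pop('fastchat_template', 'vicuna')

# in generate(), replacing the hard-coded name
conv = get_conversation_template(self.fastchat_template)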

conv.append_message(conv.roles[0], text)
conv.append_message(conv.roles[1], None)
messages[idx] = conv.get_prompt()
if self.mode == 'mid':
assert len(messages) == 1

medium

assert is intended for checking internal invariants during development, not for validating runtime input, and assertions are stripped when Python runs with the -O flag. If the messages list does not have length 1 in 'mid' mode, the assert crashes the program without a clear error message. Consider an explicit check that raises ValueError with an informative message.

Suggested change
assert len(messages) == 1
if len(messages) != 1:
raise ValueError(f"The 'mid' mode only supports a batch size of 1, but got {len(messages)}")

tokens = self.tokenizer(messages, padding=False, truncation=False, return_tensors='np')
input_ids = tokens['input_ids']
if input_ids.shape[-1] > self.max_seq_len:
input_ids = np.concatenate([input_ids[:, : self.max_seq_len // 2], input_ids[:, - self.max_seq_len // 2:]], axis=-1)
tokens = {'input_ids': input_ids}
else:
tokenize_kwargs = dict(
padding=True,
truncation=True,
max_length=self.max_seq_len,
return_tensors='np'
)
tokens = self.tokenizer(messages, **tokenize_kwargs)

input_ids = tokens['input_ids']
if len(messages) > 1:
attention_mask = tokens.get('attention_mask')
prompt_token_lens = (
attention_mask.sum(axis=1).astype(int).tolist()
if attention_mask is not None else
[input_ids.shape[1]] * batch_size
)
else:
prompt_token_lens = [len(ids) for ids in input_ids]

input_ids_tensor = Tensor(input_ids)

if min_out_len is not None:
generation_kwargs['min_new_tokens'] = min_out_len
generation_kwargs['max_new_tokens'] = max_out_len
generation_kwargs.setdefault('top_k', 1)
generation_kwargs.setdefault('return_dict_in_generate', False)

origin_stopping_criteria = list(stopping_criteria)
if stopping_criteria:
if self.tokenizer.eos_token is not None:
stopping_criteria = stopping_criteria + [
self.tokenizer.eos_token
]
stopping_list = transformers.StoppingCriteriaList([
*[
MultiTokenEOSCriteria(sequence, self.tokenizer,
input_ids_tensor.shape[0])
for sequence in stopping_criteria
],
])
generation_kwargs['stopping_criteria'] = stopping_list

generation_kwargs = drop_error_generation_kwargs(generation_kwargs)

outputs = self.model.generate(input_ids=input_ids_tensor,
**generation_kwargs)

if isinstance(outputs, dict):
outputs = outputs.get('sequences', outputs)
Comment on lines +277 to +278

high

The line outputs = outputs.get('sequences', outputs) is risky: if outputs is a dict that has no 'sequences' key, get returns the default value outputs, i.e. the dict itself. The next statement [seq.tolist() for seq in outputs] then iterates over the dict's keys and raises an AttributeError. Consider a safer version that handles a missing 'sequences' key explicitly.

Suggested change
if isinstance(outputs, dict):
outputs = outputs.get('sequences', outputs)
if isinstance(outputs, dict):
outputs = outputs.get('sequences')
if outputs is None:
raise ValueError("Model output dictionary is missing 'sequences' key.")

if outputs is None:
raise ValueError("Model output dictionary is missing 'sequences' key.")

sequences = [seq.tolist() for seq in outputs]

if not self.extract_pred_after_decode:
sequences = [
seq[prompt_len:]
for seq, prompt_len in zip(sequences, prompt_token_lens)
]

decodeds = [
self.tokenizer.decode(seq, skip_special_tokens=True)
for seq in sequences
]

if self.extract_pred_after_decode and prompt_char_lens is not None:
decodeds = [
text[length:]
for text, length in zip(decodeds, prompt_char_lens)
]

if self.end_str:
decodeds = [text.split(self.end_str)[0] for text in decodeds]
if origin_stopping_criteria:
for token in origin_stopping_criteria:
decodeds = [text.split(token)[0] for text in decodeds]
Comment on lines +301 to +305

medium

The current post-processing for end_str and origin_stopping_criteria rebuilds the decodeds list inside a loop (decodeds = [...]), which is inefficient when decodeds or the list of stop tokens is large. Consider combining all stop tokens and handling each decoded string in a single pass to improve performance.

Suggested change
if self.end_str:
decodeds = [text.split(self.end_str)[0] for text in decodeds]
if origin_stopping_criteria:
for token in origin_stopping_criteria:
decodeds = [text.split(token)[0] for text in decodeds]
all_stop_tokens = []
if self.end_str:
all_stop_tokens.append(self.end_str)
all_stop_tokens.extend(origin_stopping_criteria)
if all_stop_tokens:
for i in range(len(decodeds)):
text = decodeds[i]
for token in all_stop_tokens:
text = text.split(token)[0]
decodeds[i] = text

Comment on lines +303 to +305

medium

This code removes the possible stop tokens (origin_stopping_criteria) with repeated loops, each of which builds a new list; this is inefficient when there are many stop tokens or the decoded texts are long. It can be optimized into a single pass that finds the first stop token occurring in each text and truncates there.

Suggested change
if origin_stopping_criteria:
for token in origin_stopping_criteria:
decodeds = [text.split(token)[0] for text in decodeds]
if origin_stopping_criteria:
new_decodeds = []
for text in decodeds:
first_stop_pos = len(text)
for token in origin_stopping_criteria:
pos = text.find(token)
if pos != -1:
first_stop_pos = min(first_stop_pos, pos)
new_decodeds.append(text[:first_stop_pos])
decodeds = new_decodeds

Comment on lines +301 to +305

high

The current handling of end_str and origin_stopping_criteria (repeated loops plus split) is inefficient and can produce incorrect results when stop tokens contain one another or appear in a different order. Consider refactoring this logic to find the first occurring stop token in a single pass and truncate there, which is both more efficient and more robust.

Suggested change
if self.end_str:
decodeds = [text.split(self.end_str)[0] for text in decodeds]
if origin_stopping_criteria:
for token in origin_stopping_criteria:
decodeds = [text.split(token)[0] for text in decodeds]
all_stop_tokens = ([self.end_str] if self.end_str else []) + origin_stopping_criteria
if all_stop_tokens:
new_decodeds = []
for text in decodeds:
min_index = len(text)
for token in all_stop_tokens:
if not token:
continue
idx = text.find(token)
if idx != -1:
min_index = min(min_index, idx)
new_decodeds.append(text[:min_index])
decodeds = new_decodeds

return decodeds