From dbeba688806dab84b65fec82e3088eac98466d32 Mon Sep 17 00:00:00 2001
From: GaoHuaZhang <1484391106@qq.com>
Date: Thu, 22 Jan 2026 15:21:07 +0800
Subject: [PATCH 1/2] adapt mooncake trace dataset

---
 ais_bench/benchmark/cli/utils.py | 3 +-
 .../mooncake_trace/mooncake_trace_gen.py | 40 ++
 ais_bench/benchmark/datasets/__init__.py | 3 +-
 ais_bench/benchmark/datasets/base.py | 30 +-
 .../benchmark/datasets/mooncake_trace.py | 589 ++++++++++++++++++
 ais_bench/benchmark/models/output.py | 2 +-
 .../benchmark/openicl/icl_dataset_reader.py | 6 +
 .../icl_inferencer/icl_base_api_inferencer.py | 49 +-
 .../icl_inferencer/icl_gen_inferencer.py | 5 +-
 .../benchmark/tasks/openicl_api_infer.py | 52 +-
 ais_bench/benchmark/tasks/utils.py | 27 +-
 .../benchmark/utils/file/load_tokenizer.py | 4 +-
 .../benchmark/utils/logging/error_codes.py | 1 +
 .../test_icl_base_api_inferencer.py | 15 +-
 .../icl_inferencer/test_icl_gen_inferencer.py | 135 +++-
 tests/UT/tasks/test_openicl_api_infer.py | 201 +++++-
 tests/UT/tasks/test_utils.py | 122 ++++
 17 files changed, 1168 insertions(+), 116 deletions(-)
 create mode 100644 ais_bench/benchmark/configs/datasets/mooncake_trace/mooncake_trace_gen.py
 create mode 100644 ais_bench/benchmark/datasets/mooncake_trace.py

diff --git a/ais_bench/benchmark/cli/utils.py b/ais_bench/benchmark/cli/utils.py
index 2aa699fe..01ff50e5 100644
--- a/ais_bench/benchmark/cli/utils.py
+++ b/ais_bench/benchmark/cli/utils.py
@@ -7,7 +7,8 @@ from ais_bench.benchmark.utils.logging.error_codes import UTILS_CODES
 
 DATASETS_NEED_MODELS = ["ais_bench.benchmark.datasets.synthetic.SyntheticDataset",
-                        "ais_bench.benchmark.datasets.sharegpt.ShareGPTDataset"]
+                        "ais_bench.benchmark.datasets.sharegpt.ShareGPTDataset",
+                        "ais_bench.benchmark.datasets.mooncake_trace.MooncakeTraceDataset"]
 MAX_NUM_WORKERS = int(os.cpu_count() * 0.8)
 DEFAULT_PRESSURE_TIME = 15
 MAX_PRESSURE_TIME = 60 * 60 * 24 # 24 hours
 
diff --git a/ais_bench/benchmark/configs/datasets/mooncake_trace/mooncake_trace_gen.py b/ais_bench/benchmark/configs/datasets/mooncake_trace/mooncake_trace_gen.py
new file mode 100644
index 00000000..57d5c073
--- /dev/null
+++ b/ais_bench/benchmark/configs/datasets/mooncake_trace/mooncake_trace_gen.py
@@ -0,0 +1,40 @@
+from ais_bench.benchmark.openicl.icl_prompt_template import PromptTemplate
+from ais_bench.benchmark.openicl.icl_retriever import ZeroRetriever
+from ais_bench.benchmark.openicl.icl_inferencer import GenInferencer
+from ais_bench.benchmark.datasets import MooncakeTraceDataset, MooncakeTraceEvaluator
+
+
+mooncake_trace_reader_cfg = dict[str, list[str] | str](
+    input_columns=["prompt", "timestamp", "max_out_len"],
+    output_column="answer"
+)
+
+
+mooncake_trace_infer_cfg = dict(
+    prompt_template=dict(
+        type=PromptTemplate,
+        template="{prompt}"
+    ),
+    retriever=dict(type=ZeroRetriever),
+    inferencer=dict(type=GenInferencer)
+)
+
+mooncake_trace_eval_cfg = dict(
+    evaluator=dict(type=MooncakeTraceEvaluator)
+)
+
+mooncake_trace_datasets = [
+    dict(
+        abbr='mooncake-trace',
+        type=MooncakeTraceDataset,
+        path='',  # dataset path; a relative path is resolved against the source root, absolute paths are also supported
+        generated_prompts_path='',  # cache path for generated prompts; a relative path is resolved against the source root, absolute paths are also supported
+        random_seed=None,
+        fixed_schedule_auto_offset=False,
+        fixed_schedule_start_offset=0,
+        fixed_schedule_end_offset=-1,
+        reader_cfg=mooncake_trace_reader_cfg,
+        infer_cfg=mooncake_trace_infer_cfg,
+        eval_cfg=mooncake_trace_eval_cfg
+    )
+]
\ No newline at end of file
diff --git a/ais_bench/benchmark/datasets/__init__.py b/ais_bench/benchmark/datasets/__init__.py
index 28b65005..1581a2af 100644
--- a/ais_bench/benchmark/datasets/__init__.py +++ b/ais_bench/benchmark/datasets/__init__.py @@ -29,7 +29,7 @@ from ais_bench.benchmark.datasets.vocalsound import * from ais_bench.benchmark.datasets.lambada import * # noqa: F401, F403 from ais_bench.benchmark.datasets.lcsts import * # noqa: F401, F403 -from ais_bench.benchmark.datasets.siqa import * # noqa: F401, F403 +from ais_bench.benchmark.datasets.siqa import * # noqa: F401, F403 from ais_bench.benchmark.datasets.xsum import * # noqa: F401, F403 from ais_bench.benchmark.datasets.sharegpt import * from ais_bench.benchmark.datasets.mtbench import * @@ -52,3 +52,4 @@ from ais_bench.benchmark.datasets.videomme import * from ais_bench.benchmark.datasets.mmstar import * # noqa: F401, F403 from ais_bench.benchmark.datasets.dapo_math import * # noqa: F401, F403 +from ais_bench.benchmark.datasets.mooncake_trace import * # noqa: F401, F403 diff --git a/ais_bench/benchmark/datasets/base.py b/ais_bench/benchmark/datasets/base.py index f6e157e9..de062a5d 100644 --- a/ais_bench/benchmark/datasets/base.py +++ b/ais_bench/benchmark/datasets/base.py @@ -11,9 +11,6 @@ disable_progress_bar() # disable mapping progress bar, preventing terminal interface contamination -logger = AISLogger() - - class BaseDataset: def __init__(self, @@ -22,25 +19,26 @@ def __init__(self, n: int = 1, **kwargs): # Validate k and n parameters + self.logger = AISLogger() max_k = max(k) if isinstance(k, List) else k if max_k > n: raise ParameterValueError( DSET_CODES.INVALID_REPEAT_FACTOR, f"Maximum value of `k` ({max_k}) must be less than or equal to `n` ({n})" ) - + self.abbr = kwargs.pop('abbr', 'dataset') - - logger.debug(f"Loading dataset: {self.abbr}") + + self.logger.debug(f"Loading dataset: {self.abbr}") self.dataset = self.load(**kwargs) - logger.debug(f"Dataset loaded successfully, initializing reader") + self.logger.debug(f"Dataset loaded successfully, initializing reader") self._init_reader(**reader_cfg) self.repeated_dataset(self.abbr, n) # this process will update self.dataset and self.reader.dataset def _init_reader(self, **kwargs): self.reader = DatasetReader(self.dataset, **kwargs) - + def repeated_dataset(self, abbr: str, n: int): # Create repeated indices in batches to avoid generating an oversized index list at once @@ -52,7 +50,7 @@ def create_repeated_indices(length: int, n: int, batch_size: int = 10000) -> Lis batch_indices = [i for i in range(start, end) for _ in range(n)] indices.extend(batch_indices) return indices - + if isinstance(self.reader.dataset, Dataset): # Add metadata fields (use batching for efficiency) base_size = len(self.reader.dataset) @@ -65,14 +63,14 @@ def create_repeated_indices(length: int, n: int, batch_size: int = 10000) -> Lis writer_batch_size=writer_batch_size, load_from_cache_file=False ) - + # Safely generate indices orig_len = len(dataset) indices = create_repeated_indices(orig_len, n, batch_size=index_gen_batch_size) - + # Achieve sample duplication through index selection self.reader.dataset = dataset.select(indices) - + else: # Handle DatasetDict cases new_dict = DatasetDict() @@ -88,14 +86,14 @@ def create_repeated_indices(length: int, n: int, batch_size: int = 10000) -> Lis writer_batch_size=writer_batch_size, load_from_cache_file=False ) - + orig_len = len(mapped_ds) indices = create_repeated_indices(orig_len, n, batch_size=index_gen_batch_size) - + new_dict[key] = mapped_ds.select(indices) - + self.reader.dataset = new_dict - + self.dataset = self.reader.dataset diff --git 
a/ais_bench/benchmark/datasets/mooncake_trace.py b/ais_bench/benchmark/datasets/mooncake_trace.py new file mode 100644 index 00000000..4f2f6189 --- /dev/null +++ b/ais_bench/benchmark/datasets/mooncake_trace.py @@ -0,0 +1,589 @@ +import json +import os +import random +import hashlib +from pathlib import Path +from typing import Any +from concurrent.futures import ThreadPoolExecutor + +import numpy as np +from datasets import Dataset + +from ais_bench.benchmark.registry import LOAD_DATASET +from ais_bench.benchmark.datasets.utils.datasets import get_data_path +from ais_bench.benchmark.datasets.base import BaseDataset +from ais_bench.benchmark.utils.file.load_tokenizer import load_tokenizer +from ais_bench.benchmark.openicl.icl_evaluator.icl_base_evaluator import BaseEvaluator +from ais_bench.benchmark.utils.logging.error_codes import DSET_CODES +from ais_bench.benchmark.utils.logging.exceptions import ( + AISBenchDataContentError, + AISBenchRuntimeError, + ParameterValueError, +) + +# ============================================================================ +# 1. RNG Management System +# ============================================================================ + +class RNGManager: + """Random number generator manager""" + + def __init__(self, root_seed: int | None): + self._root_seed = root_seed + if root_seed is not None: + # Set global random seed (defensive measure) + random.seed(root_seed) + np_seed = (root_seed ^ (root_seed >> 32)) & 0xFFFFFFFF + np.random.seed(np_seed) + + def derive(self, identifier: str) -> random.Random: + """Derive a child RNG from an identifier""" + if self._root_seed is not None: + # Deterministic derivation: use SHA-256 hash + seed_string = f"{self._root_seed}:{identifier}" + hash_bytes = hashlib.sha256(seed_string.encode("utf-8")).digest() + child_seed = int.from_bytes(hash_bytes[:8], byteorder="big") + return random.Random(child_seed) + else: + # Non-deterministic: use system random + return random.Random() + + +_rng_manager: RNGManager | None = None + + +def init_rng(seed: int | None): + """Initialize global RNG manager""" + global _rng_manager + _rng_manager = RNGManager(seed) + + +def derive_rng(identifier: str) -> random.Random: + """Derive a child RNG""" + if _rng_manager is None: + raise AISBenchRuntimeError( + DSET_CODES.UNKNOWN_ERROR, + "RNG manager not initialized. Call init_rng() first." + ) + return _rng_manager.derive(identifier) + + +# ============================================================================ +# 2. Corpus Loading +# ============================================================================ + +DEFAULT_CORPUS_FILE = "assets/shakespeare.txt" +MAX_CHARS_PER_CHUNK = 10_000 + + +def initialize_corpus(tokenizer, corpus_path: Path) -> list[int]: + """ + Load and tokenize corpus + + Uses a character-based chunking strategy to ensure identical chunk boundaries + across different machines. 
+ """ + with open(corpus_path, encoding="utf-8") as f: + lines = f.readlines() + + # Preprocessing: filter empty lines + non_empty_lines = [line.strip() for line in lines if line.strip()] + + def tokenize_chunk(chunk: list[str]) -> list[int]: + """Tokenize a text chunk""" + text = " ".join(chunk) + tokens = tokenizer.encode(text, add_special_tokens=False) # Returns token ID list + return tokens + + # Character-based chunking (deterministic chunking) + chunks = [] + buffer = [] + char_count = 0 + + for line in non_empty_lines: + buffer.append(line) + char_count += len(line) + + if char_count >= MAX_CHARS_PER_CHUNK: + chunks.append(buffer) + buffer = [] + char_count = 0 + + # Add remaining lines as the last chunk + if buffer: + chunks.append(buffer) + + # Multi-threaded tokenization (thread count doesn't affect reproducibility + # because chunking is deterministic) + num_threads = min(os.cpu_count() or 4, 8) + + with ThreadPoolExecutor(max_workers=num_threads) as executor: + tokenized_chunks = list(executor.map(tokenize_chunk, chunks)) + + # Flatten all tokens + tokenized_corpus = [ + token for chunk in tokenized_chunks for token in chunk + ] + + return tokenized_corpus + + +# ============================================================================ +# 3. PromptGenerator +# ============================================================================ + +class PromptGenerator: + """Prompt generator""" + + def __init__(self, tokenizer, tokenized_corpus: list[int], root_seed: int | None = None, block_size: int = 512): + self.tokenizer = tokenizer + self._tokenized_corpus = tokenized_corpus + self._corpus_size = len(tokenized_corpus) + + # Initialize RNG (for corpus sampling) + self._corpus_rng = derive_rng("dataset.prompt.corpus") + + # Hash ID cache: hash_id -> token list + self._cache: dict[int, list[int]] = {} + + # Block size (default 512 tokens) + self.block_size = block_size + + def generate( + self, + mean: int | None = None, + stddev: int | None = None, + hash_ids: list[int] | None = None, + ) -> str: + """ + Main entry point for generating prompts + + Args: + mean: Target number of tokens (if using hash_ids, this is the total token count) + stddev: Standard deviation (usually 0 for hash_ids mode) + hash_ids: Hash ID list for cache reuse + + Returns: + Generated prompt text + """ + if hash_ids: + return self._generate_cached_prompt( + mean, hash_ids, self.block_size + ) + + # No hash_ids: sample token count using normal distribution + num_tokens = self._sample_num_tokens(mean, stddev) + return self.generate_prompt(num_tokens) + + def generate_prompt(self, num_tokens: int) -> str: + """Generate a prompt with specified number of tokens""" + tokens = self._sample_tokens(num_tokens) + return self.tokenizer.decode(tokens, skip_special_tokens=False) + + def _sample_tokens(self, num_tokens: int) -> list[int]: + """ + Sample specified number of tokens from corpus + + Uses circular sampling: if beyond corpus end, continue from the beginning. 
+ """ + if num_tokens > self._corpus_size: + # If requested token count exceeds corpus size, return entire corpus + return self._tokenized_corpus.copy() + + # Randomly select starting position + start_idx = self._corpus_rng.randrange(self._corpus_size) + + end_idx = start_idx + num_tokens + prompt_tokens = self._tokenized_corpus[start_idx:end_idx] + + # If beyond corpus end, continue from the beginning + if end_idx > self._corpus_size: + prompt_tokens += self._tokenized_corpus[: end_idx - self._corpus_size] + + return prompt_tokens + + def _sample_num_tokens(self, mean: int | None, stddev: int | None) -> int: + """Sample token count from normal distribution""" + if mean is None: + raise ParameterValueError( + DSET_CODES.MISSING_REQUIRED_PARAM, + "mean must be provided" + ) + + if stddev is None or stddev == 0: + return mean + + # Sample using normal distribution (ensure positive integer) + length_rng = derive_rng("dataset.prompt.length") + while True: + value = int(length_rng.gauss(mean, stddev)) + if value > 0: + return value + + def _generate_cached_prompt( + self, + num_tokens: int, + hash_ids: list[int], + block_size: int, + ) -> str: + """ + Generate prompt based on hash_ids (using cache mechanism) + + Each hash_id corresponds to a token block. If hash_id is in cache, + reuse cached tokens; otherwise generate new tokens and cache them. + + Args: + num_tokens: Total number of tokens + hash_ids: Hash ID list + block_size: Number of tokens per hash block (default 512) + + Returns: + Generated prompt text + """ + final_prompt: list[int] = [] + current_block_size = block_size + + # Calculate size of the last block + final_block_size = num_tokens - ((len(hash_ids) - 1) * block_size) + + # Validate parameters + if final_block_size <= 0 or block_size < final_block_size: + raise ParameterValueError( + DSET_CODES.INVALID_PARAM_VALUE, + f"Input length: {num_tokens}, Hash IDs: {hash_ids}, Block size: {block_size} " + f"are not compatible. Final block size: {final_block_size} must be > 0 and <= {block_size}." + ) + + # Process each hash_id + for index, hash_id in enumerate(hash_ids): + # Last hash_id uses remaining tokens + if index == len(hash_ids) - 1: + current_block_size = final_block_size + + # If hash_id not in cache, generate and cache + if hash_id not in self._cache: + prompt_tokens: list[int] = [] + + # If tokenizer supports block separator, insert BOS/EOS token + # This ensures different blocks won't merge + block_separation_token_id = getattr( + self.tokenizer, 'block_separation_token_id', None + ) + + if block_separation_token_id is not None: + prompt_tokens.append(block_separation_token_id) + prompt_tokens += self._sample_tokens(current_block_size - 1) + else: + prompt_tokens += self._sample_tokens(current_block_size) + + # Cache token list + self._cache[hash_id] = prompt_tokens + + # Reuse cached tokens + final_prompt.extend(self._cache[hash_id]) + + # Decode to text (don't skip special tokens, preserve block separator) + return self.tokenizer.decode(final_prompt, skip_special_tokens=False) + + +# ============================================================================ +# 4. 
Mooncake Trace Data Model +# ============================================================================ + +class MooncakeTrace: + """Mooncake trace data model""" + + def __init__(self, data: dict[str, Any]): + # Support input_text field: if input_text exists, input_length and hash_ids become optional + self.input_text = data.get("input_text") + + if self.input_text is None: + # If no input_text, input_length must exist + if "input_length" not in data or data["input_length"] is None: + raise ParameterValueError( + DSET_CODES.MISSING_REQUIRED_PARAM, + "Either 'input_text' or 'input_length' must be provided" + ) + self.input_length = data["input_length"] + self.hash_ids = data.get("hash_ids") + else: + # If input_text exists, input_length and hash_ids become optional (will be ignored) + self.input_length = data.get("input_length") + self.hash_ids = data.get("hash_ids") + + self.output_length = data.get("output_length") + self.timestamp = data.get("timestamp") + + +def load_mooncake_trace(filename: str) -> list[MooncakeTrace]: + """ + Load Mooncake trace data from JSONL file + + Returns: + List of trace data + """ + traces = [] + + with open(filename, encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + + traces.append(MooncakeTrace(json.loads(line))) + + return traces + + +def _process_timestamps( + traces: list[MooncakeTrace], + auto_offset: bool = False, + start_offset: int = 0, + end_offset: int = -1, +) -> list[MooncakeTrace]: + """ + Process timestamps: apply auto offset, start offset, and end offset + + Args: + traces: Original trace data list + auto_offset: Whether to auto-offset timestamps (make first timestamp 0) + start_offset: Start offset in milliseconds, filter out timestamps less than this offset + end_offset: End offset in milliseconds, -1 means no limit, >=0 filters out timestamps greater than this offset + + Returns: + Processed trace data list + """ + if not traces: + return traces + + # Filter traces without timestamps (keep them but they won't participate in timestamp processing) + traces_with_timestamp = [(i, trace) for i, trace in enumerate(traces) if trace.timestamp is not None] + traces_without_timestamp = [(i, trace) for i, trace in enumerate(traces) if trace.timestamp is None] + + if not traces_with_timestamp: + # If no timestamps, return original list directly + return traces + + # Extract timestamp list + timestamps = [trace.timestamp for _, trace in traces_with_timestamp] + + # 1. Apply auto_offset: make first timestamp 0 + if auto_offset: + min_timestamp = min(timestamps) + if min_timestamp > 0: + # Subtract minimum from all timestamps + for _, trace in traces_with_timestamp: + trace.timestamp = trace.timestamp - min_timestamp + # Update timestamps list + timestamps = [t - min_timestamp for t in timestamps] + + # 2. Apply start_offset and end_offset filtering + filtered_indices = [] + for idx, (original_idx, trace) in enumerate(traces_with_timestamp): + timestamp = trace.timestamp + # Check if within range + if timestamp < start_offset: + continue + if end_offset >= 0 and timestamp > end_offset: + continue + filtered_indices.append(original_idx) + + # 3. 
Build result list: keep filtered traces and traces without timestamps + result_traces = [] + filtered_set = set(filtered_indices) + without_timestamp_set = {i for i, _ in traces_without_timestamp} + for i, trace in enumerate(traces): + if i in filtered_set or i in without_timestamp_set: + result_traces.append(trace) + + return result_traces + + +# ============================================================================ +# 5. MooncakeTraceDataset +# ============================================================================ + +@LOAD_DATASET.register_module() +class MooncakeTraceDataset(BaseDataset): + def load( + self, + path, + model_path, + random_seed=None, + generated_prompts_path="", + fixed_schedule_auto_offset=False, + fixed_schedule_start_offset=0, + fixed_schedule_end_offset=-1, + ): + """ + Load Mooncake trace dataset + + Args: + path: Path to JSONL file containing hashid and trace data + model_path: Model path for loading tokenizer + random_seed: Random seed + generated_prompts_path: Path to generated prompt cache file, will be reused if exists + fixed_schedule_auto_offset: Whether to auto-offset timestamps (make first timestamp 0), default False + fixed_schedule_start_offset: Start offset in milliseconds, default 0 + fixed_schedule_end_offset: End offset in milliseconds, -1 means no limit, default -1 + + Returns: + Dataset: Dataset containing prompt, timestamp, and max_out_len fields + """ + # Parameter validation + if fixed_schedule_end_offset >= 0 and fixed_schedule_start_offset > fixed_schedule_end_offset: + raise ParameterValueError( + DSET_CODES.INVALID_PARAM_VALUE, + f"fixed_schedule_start_offset ({fixed_schedule_start_offset}) must be <= " + f"fixed_schedule_end_offset ({fixed_schedule_end_offset})" + ) + + path = get_data_path(path) + self.logger.info(f"Loading mooncake trace dataset from: {path}") + + # Process generated_prompts_path: consider fixed_schedule parameters to generate unique cache filename + if not generated_prompts_path: + # Generate default cache file path: append _generated_prompts to original filename + dir_name = os.path.dirname(path) + base_name = os.path.basename(path) + name_without_ext, ext = os.path.splitext(base_name) + + # If fixed_schedule parameters are used, include them in cache filename + cache_suffix = "_generated_prompts" + if fixed_schedule_auto_offset or fixed_schedule_start_offset != 0 or fixed_schedule_end_offset != -1: + schedule_params = [] + if fixed_schedule_auto_offset: + schedule_params.append("auto") + if fixed_schedule_start_offset != 0: + schedule_params.append(f"start{fixed_schedule_start_offset}") + if fixed_schedule_end_offset != -1: + schedule_params.append(f"end{fixed_schedule_end_offset}") + cache_suffix += "_" + "_".join(schedule_params) + + generated_prompts_path = os.path.join( + dir_name, f"{name_without_ext}{cache_suffix}{ext}" + ) + else: + generated_prompts_path = get_data_path(generated_prompts_path) + + self.logger.info(f"Generated prompts cache path: {generated_prompts_path}") + + # Check if cache file exists, if exists load directly + if os.path.exists(generated_prompts_path): + self.logger.warning(f"Found existing generated prompts file, loading from: {generated_prompts_path}, if you want to regenerate the prompts, please delete the file and run again.") + dataset_list = [] + with open(generated_prompts_path, "r", encoding="utf-8") as f: + for line in f: + if line.strip(): + dataset_list.append(json.loads(line.strip())) + self.logger.info(f"Successfully loaded {len(dataset_list)} items from cache file") 
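For orientation, the two JSONL shapes handled here look roughly like the sketch below (field names follow the code in this patch; the file name and the concrete values are invented):

import json

# One source trace record per line: either `input_length` (optionally with
# `hash_ids`, each id naming a token block of `block_size`, 512 by default)
# or a literal `input_text`. Timestamps are in milliseconds.
trace_records = [
    {"timestamp": 0, "input_length": 1024, "output_length": 128, "hash_ids": [0, 1]},
    {"timestamp": 500, "input_text": "What is the capital of France?", "output_length": 32},
]
# One generated/cache record per line, matching what load() writes out below.
cache_records = [
    {"prompt": "<text sampled from the corpus>", "max_out_len": 128, "answer": "mock_answer", "timestamp": 0},
]
with open("mooncake_trace_example.jsonl", "w", encoding="utf-8") as f:
    for rec in trace_records:
        f.write(json.dumps(rec, ensure_ascii=False) + "\n")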
+ return Dataset.from_list(dataset_list) + + # If file doesn't exist, need to generate dataset + self.logger.info(f"Cache file not found, generating prompts from source file") + + # 1. Initialize RNG system + init_rng(random_seed) + + # 2. Load tokenizer + self.logger.info(f"Loading tokenizer from: {model_path}") + tokenizer = load_tokenizer(model_path) + self.logger.info(f"Tokenizer loaded successfully, vocab_size: {tokenizer.vocab_size}") + + # 3. Load and tokenize corpus + # Try to find corpus file from multiple possible locations + corpus_path = None + possible_paths = [ + Path(__file__).parent.parent.parent / "third_party/aiperf/assets/shakespeare.txt", + ] + + for p in possible_paths: + if p.exists(): + corpus_path = p + break + + if corpus_path is None: + # If not found, try to copy from aiperf or use absolute path + # Here we use a fallback: if file not found, create a simple message + raise AISBenchDataContentError( + DSET_CODES.FILE_NOT_FOUND, + f"Corpus file not found. Please ensure {DEFAULT_CORPUS_FILE} exists in " + f"{[str(p) for p in possible_paths]}" + ) + + self.logger.info(f"Loading corpus from: {corpus_path}") + tokenized_corpus = initialize_corpus(tokenizer, corpus_path) + self.logger.info(f"Corpus loaded successfully, {len(tokenized_corpus)} tokens") + + # 4. Create PromptGenerator + prompt_generator = PromptGenerator(tokenizer, tokenized_corpus, root_seed=random_seed) + + # 5. Load Mooncake trace data + trace_data = load_mooncake_trace(path) + self.logger.info(f"Loaded {len(trace_data)} traces from source file") + + # 6. Process timestamps (apply fixed_schedule parameters) + if fixed_schedule_auto_offset or fixed_schedule_start_offset != 0 or fixed_schedule_end_offset != -1: + original_count = len(trace_data) + trace_data = _process_timestamps( + trace_data, + auto_offset=fixed_schedule_auto_offset, + start_offset=fixed_schedule_start_offset, + end_offset=fixed_schedule_end_offset, + ) + self.logger.info( + f"Applied timestamp processing: {original_count} -> {len(trace_data)} traces " + f"(auto_offset={fixed_schedule_auto_offset}, " + f"start_offset={fixed_schedule_start_offset}, " + f"end_offset={fixed_schedule_end_offset})" + ) + + # 7. 
Convert to prompts + prompts = [] + input_text_count = 0 + generated_count = 0 + + for trace in trace_data: + # Check if input_text field exists + if trace.input_text is not None: + # Directly use input_text as prompt + prompt = trace.input_text + input_text_count += 1 + else: + # Use PromptGenerator to generate prompt + prompt = prompt_generator.generate( + mean=trace.input_length, + stddev=0, # Mooncake trace usually doesn't use standard deviation + hash_ids=trace.hash_ids or [], + ) + generated_count += 1 + + item = { + "prompt": prompt, + "max_out_len": trace.output_length or 0, + "answer": "mock_answer", + } + # If timestamp exists, add to result + if trace.timestamp is not None: + item["timestamp"] = trace.timestamp + else: + item["timestamp"] = 0 # Default value + + prompts.append(item) + + if input_text_count > 0: + self.logger.info(f"Used input_text for {input_text_count} traces, generated prompts for {generated_count} traces") + self.logger.info(f"Generated {len(prompts)} prompts, saving to cache file: {generated_prompts_path}") + # Save to cache file + prompts.sort(key=lambda x: x["timestamp"]) + with open(generated_prompts_path, "w", encoding="utf-8") as f: + for item in prompts: + f.write(json.dumps(item, ensure_ascii=False) + "\n") + + self.logger.info(f"Successfully saved generated prompts to: {generated_prompts_path}") + + return Dataset.from_list(prompts) + + +class MooncakeTraceEvaluator(BaseEvaluator): + pass diff --git a/ais_bench/benchmark/models/output.py b/ais_bench/benchmark/models/output.py index 53041fdc..f0676ae3 100644 --- a/ais_bench/benchmark/models/output.py +++ b/ais_bench/benchmark/models/output.py @@ -17,7 +17,7 @@ def __init__(self, perf_mode: bool = False) -> None: self.extra_perf_data: dict = {} self.extra_details_data: dict = {} self.input: list | str = None - self.uuid: str = "" + self.uuid: str = "" # A unique identifier for each case: # In multi-turn dialogue scenarios, all turns of the same sample share the same uuid. 
# In pass@k scenarios, the same sample is sampled k times and each run receives a distinct uuid diff --git a/ais_bench/benchmark/openicl/icl_dataset_reader.py b/ais_bench/benchmark/openicl/icl_dataset_reader.py index 8dd63173..2fccdd35 100644 --- a/ais_bench/benchmark/openicl/icl_dataset_reader.py +++ b/ais_bench/benchmark/openicl/icl_dataset_reader.py @@ -107,6 +107,12 @@ def get_max_out_len(self): else: return None + def get_timestamp(self): + if "timestamp" in self.input_columns and "timestamp" in self.dataset['test'].features: + return self.dataset['test']['timestamp'] + else: + return None + def load_partial_dataset( dataset: Dataset, diff --git a/ais_bench/benchmark/openicl/icl_inferencer/icl_base_api_inferencer.py b/ais_bench/benchmark/openicl/icl_inferencer/icl_base_api_inferencer.py index d2fb8c58..94e074e7 100644 --- a/ais_bench/benchmark/openicl/icl_inferencer/icl_base_api_inferencer.py +++ b/ais_bench/benchmark/openicl/icl_inferencer/icl_base_api_inferencer.py @@ -32,6 +32,21 @@ DEFAULT_DATA_FETCH_SIZE_FACTOR = 0.1 # default data fetch size factor is 0.1 of batch size +class ApiInferencerConfig: + def __init__(self, global_index: mp.RawValue, global_lock: mp.Lock, use_timestamp: bool = False, total_data_count: int = 0): + self.global_index = global_index + self.global_lock = global_lock + self.use_timestamp = use_timestamp + self.total_data_count = total_data_count + + def to_dict(self): + return { + "global_index": self.global_index, + "global_lock": self.global_lock, + "use_timestamp": self.use_timestamp, + "total_data_count": self.total_data_count, + } + class BaseApiInferencer(BaseInferencer): """Base Inferencer class for all evaluation Inferencer. @@ -65,15 +80,14 @@ def __init__( # Cache for batch-prefetched data self._data_cache = [] # Thread-local cache for batch data self.total_data_count = 0 + self.use_timestamp = False - def set_global_index(self, global_index: mp.RawValue): - self.global_index = global_index - - def set_global_lock(self, global_lock: mp.Lock): - self.global_lock = global_lock - - def set_data_count(self, data_num: int): - self.total_data_count = data_num + def set_config(self, config: ApiInferencerConfig): + for key, value in config.to_dict().items(): + if hasattr(self, key): + setattr(self, key, value) + else: + self.logger.debug(f"Unknown config key: {key}, skip") def _monitor_status_thread( self, @@ -277,7 +291,6 @@ def _get_single_data( end_index = (cur_index + 1) % len(indexes) # Update global index self.global_index.value = end_index - # Prefetch all data in the batch batch_data = [] for data_index in data_indices: @@ -393,7 +406,7 @@ async def _worker_loop( num_workers = self.batch_size if self.batch_size and self.batch_size > 0 else 1 # Limit maximum concurrency - semaphore = asyncio.Semaphore(num_workers) if num_workers else None + semaphore = asyncio.Semaphore(num_workers) if num_workers and not self.use_timestamp else None # Reuse session to improve concurrency connector = aiohttp.TCPConnector(limit=num_workers + 1) timeout = aiohttp.ClientTimeout(total=get_request_time_out()) @@ -417,6 +430,7 @@ async def limited_request_func(data): ICLI_CODES.CONCURRENCY_NOT_SET_IN_PRESSEURE_MODE, f"Concurrency not set in pressure mode, please set `batch_size` in model config", ) + return async with semaphore: # Pressure mode: continuously send requests until pressure_time if self.pressure_mode: @@ -454,18 +468,20 @@ async def limited_request_func(data): acquired = await asyncio.to_thread(token_bucket.acquire, timeout=1) if not acquired: continue + data 
= await self.wait_get_data(async_queue, stop_event) else: - # Slightly limit RR when no token to avoid high CPU usage causing TTFT accumulation - await asyncio.sleep(BLOCK_INTERVAL) - - data = await self.wait_get_data(async_queue, stop_event) + data = await self.wait_get_data(async_queue, stop_event) + if data: + await asyncio.sleep(BLOCK_INTERVAL) - # data == None -> sentinel if data is None: await asyncio.wait_for(async_queue.put(None), timeout=1) break # Call user-provided async request tasks.append(asyncio.create_task(limited_request_func(data))) + if len(tasks) > num_workers: + self.logger.warning(f"Process[{os.getpid()}] concurrency ({len(tasks)}) exceeds limit ({num_workers}). " + "Consider increasing `WORKERS_NUM` or unset `--debug` for better performance.") # Pressure mode: exit when max concurrency is reached if self.pressure_mode: if len(tasks) >= num_workers: # max concurrency is reached @@ -475,8 +491,7 @@ async def limited_request_func(data): self.logger.warning( f"Pressure mode: process {os.getpid()} exited before entering a stable state " "because the pressure timeout was hit. Consider increase the `request_rate` " - "in the model config, or increasing " - "`WORKERS_NUM` in global_consts.py to enhance concurrency." + "in the model config, or increasing `WORKERS_NUM` in global_consts.py and unset `--debug` to enhance concurrency." ) stop_event.set() break diff --git a/ais_bench/benchmark/openicl/icl_inferencer/icl_gen_inferencer.py b/ais_bench/benchmark/openicl/icl_inferencer/icl_gen_inferencer.py index edfdb2c7..1ef0ed65 100644 --- a/ais_bench/benchmark/openicl/icl_inferencer/icl_gen_inferencer.py +++ b/ais_bench/benchmark/openicl/icl_inferencer/icl_gen_inferencer.py @@ -6,7 +6,6 @@ from typing import List, Optional import aiohttp -from tqdm import tqdm from ais_bench.benchmark.models.output import RequestOutput from ais_bench.benchmark.registry import ICL_INFERENCERS @@ -165,4 +164,8 @@ def get_data_list( for index, max_out_len in enumerate(max_out_lens): data_list[index]["max_out_len"] = max_out_len if max_out_len else self.model.max_out_len + timestamps = retriever.dataset_reader.get_timestamp() + if timestamps is not None: + for index, timestamp in enumerate(timestamps): + data_list[index]["timestamp"] = timestamp / 1000 # ms to s return data_list \ No newline at end of file diff --git a/ais_bench/benchmark/tasks/openicl_api_infer.py b/ais_bench/benchmark/tasks/openicl_api_infer.py index 60e3e2bf..df07f24c 100644 --- a/ais_bench/benchmark/tasks/openicl_api_infer.py +++ b/ais_bench/benchmark/tasks/openicl_api_infer.py @@ -13,6 +13,8 @@ from mmengine.config import Config, ConfigDict from collections import defaultdict +from mmengine.utils.dl_utils import time_counter + from ais_bench.benchmark.global_consts import WORKERS_NUM from ais_bench.benchmark.registry import ICL_INFERENCERS, TASKS, ICL_RETRIEVERS from ais_bench.benchmark.tasks.base import BaseTask, TaskStateManager @@ -23,7 +25,7 @@ TokenProducer, format_dict_as_table, ) -from ais_bench.benchmark.openicl.icl_inferencer.icl_base_api_inferencer import BaseApiInferencer +from ais_bench.benchmark.openicl.icl_inferencer.icl_base_api_inferencer import BaseApiInferencer, ApiInferencerConfig from ais_bench.benchmark.utils.core.abbr import task_abbr_from_cfg, merge_dataset_abbr_from_cfg from ais_bench.benchmark.utils.config import build_dataset_from_cfg from ais_bench.benchmark.utils.logging.error_codes import TINFER_CODES @@ -44,9 +46,8 @@ def run_single_inferencer( max_concurrency: int, indexes: Dict, token_bucket: 
BoundedSemaphore, - total_data_num: int, - global_index: mp.RawValue = None, - global_lock: mp.Lock = None, + api_inferencer_config: ApiInferencerConfig, + ): """Run a single inferencer that reads samples from shared memory. @@ -62,12 +63,9 @@ def run_single_inferencer( """ inferencer_cfg["model_cfg"] = model_cfg inferencer_cfg["batch_size"] = max_concurrency - inferencer = ICL_INFERENCERS.build(inferencer_cfg) - inferencer.set_global_index(global_index) - inferencer.set_global_lock(global_lock) - inferencer.set_data_count(total_data_num) + inferencer: BaseApiInferencer = ICL_INFERENCERS.build(inferencer_cfg) # pressure mode each process has a copy of the data list - + inferencer.set_config(api_inferencer_config) inferencer.inference_with_shm( shm_name, message_shm_name, @@ -311,10 +309,7 @@ def _run_debug( dataset_shm: Shared memory containing dataset message_shm: Shared memory for message passing indexes: Indexes for data - data_index_value: Value for data index token_bucket: Token bucket for rate limiting - global_index: Global index for data - global_lock: Global lock for data """ if self.concurrency > CONCURRENCY_PER_PROCESS: self.logger.warning( @@ -323,7 +318,7 @@ def _run_debug( ) else: self.logger.info(f"Debug mode, run with concurrency: {self.concurrency}") - self.inferencer.set_data_count(len(indexes)) + self.inferencer.total_data_count = len(indexes) self.inferencer.inference_with_shm(dataset_shm.name, message_shm.name, indexes, token_bucket) def _run_multi_process( @@ -338,7 +333,6 @@ def _run_multi_process( Args: dataset_shm: Shared memory containing dataset indexes: Indexes for data - data_index_value: Value for data index token_bucket: Token bucket for rate limiting message_shms: List to store message shared memory objects (mutated) @@ -375,9 +369,12 @@ def _run_multi_process( concurrency, indexes, token_bucket, - per_worker_data_num[i], - global_index, - global_lock, + ApiInferencerConfig( + global_index=global_index, + global_lock=global_lock, + use_timestamp=self.inferencer.use_timestamp, + total_data_count=per_worker_data_num[i], + ), ), ) @@ -400,6 +397,18 @@ def _run_multi_process( self._cleanup_shms(message_shm) return processes + def _get_time_stamps(self, data_list: List): + """Get timestamps from data_list. + """ + timestamps = [] + for data in data_list: + if data.get("timestamp", None) is not None: + timestamps.append(data["timestamp"]) + if len(timestamps) > 0: + self.inferencer.use_timestamp = True + self.logger.warning("Found timestamps in datasets, use timestamps for request delay, `request_rate` config will be ignored!") + return timestamps + def warm_up(self, data_list: List, task_state_manager: TaskStateManager): """Warm up the inferencer. 
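As a quick illustration of what the `_get_time_stamps` helper above yields (toy values, not taken from the patch):

data_list = [
    {"prompt": "q1", "timestamp": 0.0},
    {"prompt": "q2"},                      # no timestamp -> not collected
    {"prompt": "q3", "timestamp": 1.5},
]
timestamps = [d["timestamp"] for d in data_list if d.get("timestamp") is not None]
assert timestamps == [0.0, 1.5]            # and `use_timestamp` flips to True on the inferencer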
@@ -459,17 +468,22 @@ def run(self, task_state_manager: TaskStateManager): if len(data_list) == 0: self.logger.warning(f"Get no data to infer, task finished") return + + # get timestamps from data_list + timestamps = self._get_time_stamps(data_list) + self.warm_up(data_list, task_state_manager) dataset_size, dataset_shm, indexes = self._dump_dataset_to_share_memory(data_list, global_indexes) # In pressure mode, treat the first `concurrency` requests as the dataset size if self.pressure: request_num = self.concurrency else: - request_num = dataset_size + request_num = len(global_indexes) # Create token producer token_producer = TokenProducer( self.model_cfg.pop("request_rate", 0), + timestamps, self.model_cfg.pop("traffic_cfg", {}), request_num, self.task_mode, @@ -679,4 +693,4 @@ def parse_args(): end_time = time.perf_counter() logger.info(f"Api infer task time elapsed: {end_time - start_time:.2f}s") task_state_manager.update_task_state({"status": "finish"}) - manager_t.join() + manager_t.join() \ No newline at end of file diff --git a/ais_bench/benchmark/tasks/utils.py b/ais_bench/benchmark/tasks/utils.py index bba230a6..bcd90ca4 100644 --- a/ais_bench/benchmark/tasks/utils.py +++ b/ais_bench/benchmark/tasks/utils.py @@ -1,7 +1,6 @@ import os import time import struct -from collections import OrderedDict from typing import Dict, List, Any from multiprocessing import Event, shared_memory, BoundedSemaphore @@ -376,6 +375,7 @@ class TokenProducer: def __init__( self, request_rate: float, + time_stamps: List[float], traffic_cfg: ConfigDict, request_num: int = None, mode: str = "infer", @@ -384,6 +384,7 @@ def __init__( """ Args: request_rate: Desired request rate (RPS) used to pace requests. + time_stamps: List of timestamps in seconds. traffic_cfg: Traffic configuration controlling ramp-up and burstiness. request_num: Total number of requests to schedule when known. pressure_mode: If True, after generating the first `request_num` tokens @@ -407,7 +408,7 @@ def __init__( self.burstiness = 1.0 self.work_dir = work_dir # When request_rate < 0.1, treat as infinite (no pacing applied here) - if self.request_rate < FINAL_RPS_MINIMUM_THRESHOLD: + if self.request_rate < FINAL_RPS_MINIMUM_THRESHOLD and not time_stamps: self.token_bucket = None if self.pressure_mode: self.logger.warning("Pressure mode with no request rate applied, concurrency will increase rapidly") @@ -417,10 +418,14 @@ def __init__( for _ in range(request_num + 1): self.token_bucket.acquire() + self.interval_lists = [] + # If timestamps are provided, use them directly + if time_stamps: + self.interval_lists = time_stamps[:] + return + # If `traffic_cfg` is provided, pre-generate `interval_lists` for ramp-up; after # exhausting it, fall back to gamma-distributed intervals based on request_rate. 
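Conceptually, when `interval_lists` holds trace timestamps the producer replays them as absolute offsets (in seconds) from the start of the run; a simplified, self-contained model of that behaviour, not the shipped `produce_token`:

import time

def replay(offsets_s, release):
    """Release one token per offset, where each offset is seconds since start."""
    start = time.perf_counter()
    for offset in offsets_s:
        delay = offset - (time.perf_counter() - start)
        if delay > 0:
            time.sleep(delay)
        release()

replay([0.0, 0.5, 1.5], lambda: print("token released"))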
- self.interval_lists = [] - # if traffic_cfg: self.burstiness = float(traffic_cfg.get("burstiness", self.burstiness)) ramp_up_strategy = traffic_cfg.get("ramp_up_strategy") ramp_up_start_rps = traffic_cfg.get("ramp_up_start_rps") @@ -632,11 +637,15 @@ def produce_token(self, stop_evt: Event, per_pid_shms: Dict[int, shared_memory.S interval_index = 0 theta = 1.0 / (self.request_rate * self.burstiness) - start_time = time.perf_counter() + start_time = time.perf_counter() + self.interval_lists[0] while not stop_evt.is_set(): if interval_index < len(self.interval_lists): interval = self.interval_lists[interval_index] + current_time = time.perf_counter() + sleep_interval = interval - (current_time - start_time) + if sleep_interval > 0: + time.sleep(sleep_interval) try: self.token_bucket.release() except ValueError as e: @@ -645,11 +654,11 @@ def produce_token(self, stop_evt: Event, per_pid_shms: Dict[int, shared_memory.S wait_interval = np.random.gamma(shape=self.burstiness, scale=theta) time.sleep(wait_interval) continue - current_time = time.perf_counter() - sleep_interval = interval - (current_time - start_time) - if sleep_interval > 0: - time.sleep(sleep_interval) interval_index += 1 + elif not self.pressure_mode: + self.token_bucket.release() # realse None token to avoid deadlock + interval = np.random.gamma(shape=self.burstiness, scale=theta) + time.sleep(interval) else: try: # After first batch requests are sent, subsequent requests diff --git a/ais_bench/benchmark/utils/file/load_tokenizer.py b/ais_bench/benchmark/utils/file/load_tokenizer.py index f15eb01a..aba0ad00 100644 --- a/ais_bench/benchmark/utils/file/load_tokenizer.py +++ b/ais_bench/benchmark/utils/file/load_tokenizer.py @@ -47,7 +47,7 @@ class AISTokenizer: def __init__(self, tokenizer_path: str): self.tokenizer = load_tokenizer(tokenizer_path) - def encode(self, prompt: list) -> Tuple[float, List[int]]: + def encode(self, prompt: list, add_special_tokens: bool = True) -> Tuple[float, List[int]]: """Encode a string into tokens, measuring processing time.""" if isinstance(prompt, list): try: @@ -64,7 +64,7 @@ def encode(self, prompt: list) -> Tuple[float, List[int]]: else: logger.debug(f"Prompt: {prompt} is not a list or string.") return [] - tokens = self.tokenizer.encode(messages) + tokens = self.tokenizer.encode(messages, add_special_tokens=add_special_tokens) return tokens def decode(self, tokens: List[int]) -> Tuple[List[float], str]: diff --git a/ais_bench/benchmark/utils/logging/error_codes.py b/ais_bench/benchmark/utils/logging/error_codes.py index 5becaa1c..3520613e 100644 --- a/ais_bench/benchmark/utils/logging/error_codes.py +++ b/ais_bench/benchmark/utils/logging/error_codes.py @@ -274,6 +274,7 @@ class DSET_CODES: # Parameter related errors INVALID_REPEAT_FACTOR = BaseErrorCode("DSET-PARAM-002", ErrorModule.DATASET, ErrorType.PARAM, 2, "invalid repeat factor") # docs coverd INVALID_PARAM_VALUE = BaseErrorCode("DSET-PARAM-004", ErrorModule.DATASET, ErrorType.PARAM, 4, "invalid parameter value") # docs coverd + MISSING_REQUIRED_PARAM = BaseErrorCode("DSET-PARAM-005", ErrorModule.DATASET, ErrorType.PARAM, 5, "missing required parameter") # docs coverd # Dependency related errors EVALUATION_LIBRARY_NOT_INSTALLED = BaseErrorCode("DSET-DEPENDENCY-002", ErrorModule.DATASET, ErrorType.DEPENDENCY, 2, "evaluation library not installed") # docs coverd diff --git a/tests/UT/openicl/icl_inferencer/test_icl_base_api_inferencer.py b/tests/UT/openicl/icl_inferencer/test_icl_base_api_inferencer.py index 9a65a3de..5e4c9cff 100644 
--- a/tests/UT/openicl/icl_inferencer/test_icl_base_api_inferencer.py +++ b/tests/UT/openicl/icl_inferencer/test_icl_base_api_inferencer.py @@ -132,7 +132,13 @@ def test_get_single_data(self, m_abbr, m_build): """测试_get_single_data方法从共享内存读取数据""" m_build.return_value = DummyModel() inf = ConcreteApiInferencer(model_cfg={}) - inf.set_data_count(1) + # 使用set_config设置total_data_count + from ais_bench.benchmark.openicl.icl_inferencer.icl_base_api_inferencer import ApiInferencerConfig + import multiprocessing as mp + global_index = mp.RawValue('i', 0) + global_lock = mp.Lock() + config = ApiInferencerConfig(global_index=global_index, global_lock=global_lock, use_timestamp=False, total_data_count=1) + inf.set_config(config) test_data = {"test": "data"} pickled_data = pickle.dumps(test_data) @@ -159,7 +165,12 @@ def test_get_single_data_wait_flag(self, m_abbr, m_build): m_build.return_value = DummyModel() inf = ConcreteApiInferencer(model_cfg={}) # 设置 total_data_count,否则在非 pressure 模式下会提前返回 None - inf.set_data_count(2) + from ais_bench.benchmark.openicl.icl_inferencer.icl_base_api_inferencer import ApiInferencerConfig + import multiprocessing as mp + global_index = mp.RawValue('i', 0) + global_lock = mp.Lock() + config = ApiInferencerConfig(global_index=global_index, global_lock=global_lock, use_timestamp=False, total_data_count=2) + inf.set_config(config) test_data1 = {"test": "data1"} test_data2 = {"test": "data2"} diff --git a/tests/UT/openicl/icl_inferencer/test_icl_gen_inferencer.py b/tests/UT/openicl/icl_inferencer/test_icl_gen_inferencer.py index 9fd98d70..906cad22 100644 --- a/tests/UT/openicl/icl_inferencer/test_icl_gen_inferencer.py +++ b/tests/UT/openicl/icl_inferencer/test_icl_gen_inferencer.py @@ -10,7 +10,11 @@ class DummyDataset: def __init__(self): - self.reader = type("R", (), {"output_column": "label", "get_max_out_len": lambda self=None: [5, 6]})() + self.reader = type("R", (), { + "output_column": "label", + "get_max_out_len": lambda self=None: [5, 6], + "get_timestamp": lambda self=None: None + })() self.train = Dataset.from_dict({"text": ["t0", "t1"], "label": [0, 1]}) self.test = Dataset.from_dict({"text": ["a", "b"], "label": [0, 1]}) self.abbr = "abbrd" @@ -79,7 +83,7 @@ def test_batch_inference_async(self, m_abbr, m_build): """测试GenInferencer的batch_inference方法调用model.generate和report_cache_info_sync""" m_build.return_value = DummyModel() inf = GenInferencer(model_cfg={}, batch_size=1) - + datum = { "index": [[0], [1]], "prompt": [["p0"], ["p1"]], @@ -87,12 +91,12 @@ def test_batch_inference_async(self, m_abbr, m_build): "max_out_len": [8, 8], "gold": [["g0"], ["g1"]], } - + try: inf.batch_inference(datum) except Exception as e: self.fail(f"batch_inference raised {type(e).__name__}: {e}") - + self.assertTrue(hasattr(DummyModel, 'generate')) @mock.patch("ais_bench.benchmark.openicl.icl_inferencer.icl_base_inferencer.build_model_from_cfg") @@ -101,7 +105,7 @@ def test_batch_inference_with_is_api(self, m_abbr, m_build): """测试GenInferencer在is_api=True时使用列表形式的max_out_len调用generate""" m_build.return_value = DummyModel(is_api=True) inf = GenInferencer(model_cfg={}, batch_size=1) - + datum = { "index": [[0], [1]], "prompt": [["p0"], ["p1"]], @@ -109,10 +113,10 @@ def test_batch_inference_with_is_api(self, m_abbr, m_build): "max_out_len": [8, 8], "gold": [["g0"], ["g1"]], } - + inf.model.generate = mock.Mock(return_value=["out1", "out2"]) inf.output_handler.report_cache_info_sync = mock.Mock(return_value=True) - + inf.batch_inference(datum) inf.model.generate.assert_called_once() @@ 
-124,14 +128,14 @@ def test_do_request(self, m_abbr, m_build): inf = GenInferencer(model_cfg={}, batch_size=1) inf.status_counter = DummyStatusCounter() inf.output_handler.report_cache_info = mock.AsyncMock(return_value=True) - + async def async_generate(inputs, max_out_len, output=None, session=None, **kwargs): if output: output.success = True output.content = "test_output" - + inf.model.generate = async_generate - + data = { "index": 0, "prompt": "test_input", @@ -139,13 +143,13 @@ async def async_generate(inputs, max_out_len, output=None, session=None, **kwarg "max_out_len": 10, "gold": "test_gold", } - + async def run_test(): token_bucket = mock.Mock() session = mock.Mock() await inf.do_request(data, token_bucket, session) inf.output_handler.report_cache_info.assert_called_once() - + asyncio.run(run_test()) @mock.patch("ais_bench.benchmark.openicl.icl_inferencer.icl_base_inferencer.build_model_from_cfg") @@ -153,17 +157,17 @@ async def run_test(): def test_do_request_failure(self, m_abbr, m_build): """测试GenInferencer的do_request在输出失败时仍报告缓存信息""" m_build.return_value = DummyModel() - + async def failed_generate(inputs, max_out_len, output=None, session=None, **kwargs): if output: output.success = False output.content = None - + inf = GenInferencer(model_cfg={}, batch_size=1) inf.model.generate = failed_generate inf.status_counter = DummyStatusCounter() inf.output_handler.report_cache_info = mock.AsyncMock(return_value=True) - + data = { "index": 0, "prompt": "test_input", @@ -171,13 +175,13 @@ async def failed_generate(inputs, max_out_len, output=None, session=None, **kwar "max_out_len": 10, "gold": "test_gold", } - + async def run_test(): token_bucket = mock.Mock() session = mock.Mock() await inf.do_request(data, token_bucket, session) inf.output_handler.report_cache_info.assert_called_once() - + asyncio.run(run_test()) @mock.patch("ais_bench.benchmark.openicl.icl_inferencer.icl_base_inferencer.build_model_from_cfg") @@ -188,29 +192,116 @@ def test_do_request_without_gold(self, m_abbr, m_build): inf = GenInferencer(model_cfg={}, batch_size=1) inf.status_counter = DummyStatusCounter() inf.output_handler.report_cache_info = mock.AsyncMock(return_value=True) - + async def async_generate(inputs, max_out_len, output=None, session=None, **kwargs): if output: output.success = True output.content = "test_output" - + inf.model.generate = async_generate - + data = { "index": 0, "prompt": "test_input", "data_abbr": "test", "max_out_len": 10, } - + async def run_test(): token_bucket = mock.Mock() session = mock.Mock() await inf.do_request(data, token_bucket, session) inf.output_handler.report_cache_info.assert_called_once() - + asyncio.run(run_test()) + @mock.patch("ais_bench.benchmark.openicl.icl_inferencer.icl_base_inferencer.build_model_from_cfg") + @mock.patch("ais_bench.benchmark.openicl.icl_inferencer.icl_base_inferencer.model_abbr_from_cfg", return_value="mabbr") + def test_get_data_list_with_timestamps(self, m_abbr, m_build): + """测试timestamp字段读取和转换(毫秒转秒)""" + m_build.return_value = DummyModel() + inf = GenInferencer(model_cfg={}, batch_size=1) + + # 创建带timestamp的dataset reader + class DummyDatasetWithTimestamp: + def __init__(self): + self.reader = type("R", (), { + "output_column": "label", + "get_max_out_len": lambda self=None: [5, 6], + "get_timestamp": lambda self=None: [1000, 2000] # 毫秒单位 + })() + self.train = Dataset.from_dict({"text": ["t0", "t1"], "label": [0, 1]}) + self.test = Dataset.from_dict({"text": ["a", "b"], "label": [0, 1]}) + self.abbr = "abbrd" + + r = 
DummyRetriever(DummyDatasetWithTimestamp()) + data_list = inf.get_data_list(r) + + # 验证timestamp被添加并转换为秒 + self.assertEqual(len(data_list), 2) + self.assertIn("timestamp", data_list[0]) + self.assertEqual(data_list[0]["timestamp"], 1.0) # 1000ms = 1.0s + self.assertEqual(data_list[1]["timestamp"], 2.0) # 2000ms = 2.0s + + @mock.patch("ais_bench.benchmark.openicl.icl_inferencer.icl_base_inferencer.build_model_from_cfg") + @mock.patch("ais_bench.benchmark.openicl.icl_inferencer.icl_base_inferencer.model_abbr_from_cfg", return_value="mabbr") + def test_get_data_list_without_timestamps(self, m_abbr, m_build): + """测试没有timestamp的情况""" + m_build.return_value = DummyModel() + inf = GenInferencer(model_cfg={}, batch_size=1) + + # 使用默认的DummyDataset(没有get_timestamp方法或返回None) + r = DummyRetriever(DummyDataset()) + data_list = inf.get_data_list(r) + + # 验证数据列表正常生成,但没有timestamp字段 + self.assertEqual(len(data_list), 2) + # 如果没有timestamp,data_list中不应该有timestamp字段 + # 注意:如果get_timestamp返回None,则不会添加timestamp字段 + self.assertNotIn("timestamp", data_list[0]) + + @mock.patch("ais_bench.benchmark.openicl.icl_inferencer.icl_base_inferencer.build_model_from_cfg") + @mock.patch("ais_bench.benchmark.openicl.icl_inferencer.icl_base_inferencer.model_abbr_from_cfg", return_value="mabbr") + def test_get_data_list_timestamp_conversion(self, m_abbr, m_build): + """测试timestamp转换(毫秒转秒)的正确性""" + m_build.return_value = DummyModel() + inf = GenInferencer(model_cfg={}, batch_size=1) + + # 创建带不同timestamp值的dataset reader + class DummyDatasetWithVariousTimestamps: + def __init__(self): + self.reader = type("R", (), { + "output_column": "label", + "get_max_out_len": lambda self=None: [5, 6, 7], + "get_timestamp": lambda self=None: [500, 1500, 3000] # 毫秒单位 + })() + self.train = Dataset.from_dict({"text": ["t0", "t1", "t2"], "label": [0, 1, 2]}) + self.test = Dataset.from_dict({"text": ["a", "b", "c"], "label": [0, 1, 2]}) + self.abbr = "abbrd" + + # 需要更新DummyRetriever以支持3个数据项 + class DummyRetriever3: + def __init__(self, dataset): + self.dataset = dataset + self.dataset_reader = dataset.reader + def retrieve(self): + return [[0], [1], [2]] # 返回3个索引 + def generate_ice(self, idx_list): + return "ICE" + def generate_prompt_for_generate_task(self, idx, ice, gen_field_replace_token=""): + return f"P{idx}|{ice}" + def get_gold_ans(self): + return ["g0", "g1", "g2"] + + r = DummyRetriever3(DummyDatasetWithVariousTimestamps()) + data_list = inf.get_data_list(r) + + # 验证timestamp正确转换 + self.assertEqual(len(data_list), 3) + self.assertEqual(data_list[0]["timestamp"], 0.5) # 500ms = 0.5s + self.assertEqual(data_list[1]["timestamp"], 1.5) # 1500ms = 1.5s + self.assertEqual(data_list[2]["timestamp"], 3.0) # 3000ms = 3.0s + if __name__ == '__main__': unittest.main() diff --git a/tests/UT/tasks/test_openicl_api_infer.py b/tests/UT/tasks/test_openicl_api_infer.py index 49dbab55..280487e7 100644 --- a/tests/UT/tasks/test_openicl_api_infer.py +++ b/tests/UT/tasks/test_openicl_api_infer.py @@ -542,34 +542,59 @@ def test_run_multi_process(self, mock_logger_class, mock_create_shm, mock_proces task = self._create_task() task.logger = mock_logger task.concurrency = 600 # 大于CONCURRENCY_PER_PROCESS + # 设置inferencer,因为_run_multi_process需要使用task.inferencer.use_timestamp + task.inferencer = MagicMock() + task.inferencer.use_timestamp = False - # Mock shared memory - from multiprocessing import shared_memory, BoundedSemaphore - dataset_shm = shared_memory.SharedMemory(create=True, size=100) - message_shm = shared_memory.SharedMemory(create=True, size=100) - 
mock_create_shm.return_value = message_shm - - # Mock process - mock_process = MagicMock() - mock_process.pid = 12345 - mock_process_class.return_value = mock_process - - indexes = {0: (0, 0, 100)} - token_bucket = BoundedSemaphore(10) - message_shms = {} + # Mock _get_workers_num 返回有效值,确保 _deliver_concurrency_for_workers 不返回空列表 + with patch.object(task, '_get_workers_num', return_value=3): + # Mock shared memory + from multiprocessing import shared_memory, BoundedSemaphore + dataset_shm = shared_memory.SharedMemory(create=True, size=100) + message_shm = shared_memory.SharedMemory(create=True, size=100) + mock_create_shm.return_value = message_shm + + # Mock process - 使用一个计数器来为每个进程分配不同的pid + process_counter = [0] + def create_mock_process(*args, **kwargs): + mock_p = MagicMock() + current_idx = process_counter[0] + process_counter[0] += 1 + # 确保pid在start()之后设置(模拟真实行为) + def mock_start(): + mock_p.pid = 12345 + current_idx + mock_p.start = mock_start + # 初始pid为None,start()后才会设置 + mock_p.pid = None + return mock_p + mock_process_class.side_effect = create_mock_process + + indexes = {0: (0, 0, 100)} + token_bucket = BoundedSemaphore(10) + message_shms = {} - try: - processes = task._run_multi_process(dataset_shm, indexes, token_bucket, message_shms) + try: + processes = task._run_multi_process(dataset_shm, indexes, token_bucket, message_shms) - # 验证创建了process - self.assertGreater(len(processes), 0) - # 验证message_shms被更新 - self.assertGreater(len(message_shms), 0) - finally: - dataset_shm.close() - dataset_shm.unlink() - message_shm.close() - message_shm.unlink() + # 验证创建了process + self.assertGreater(len(processes), 0) + # 验证message_shms被更新 + self.assertGreater(len(message_shms), 0) + finally: + dataset_shm.close() + dataset_shm.unlink() + # 清理message_shms中的共享内存 + for shm in message_shms.values(): + try: + shm.close() + shm.unlink() + except: + pass + try: + message_shm.close() + message_shm.unlink() + except: + pass @patch('ais_bench.benchmark.tasks.openicl_api_infer.AISLogger') def test_init_with_repeat_gt_1(self, mock_logger_class): @@ -1038,6 +1063,132 @@ def test_run_multi_process_empty_concurrency(self, mock_logger_class): dataset_shm.close() dataset_shm.unlink() + @patch('ais_bench.benchmark.tasks.openicl_api_infer.AISLogger') + def test_get_time_stamps_extract(self, mock_logger_class): + """测试从data_list中提取timestamps""" + mock_logger = MagicMock() + mock_logger_class.return_value = mock_logger + + task = self._create_task() + task.logger = mock_logger + + # Mock inferencer + mock_inferencer = MagicMock() + mock_inferencer.use_timestamp = False + task.inferencer = mock_inferencer + + data_list = [ + {"prompt": "test1", "timestamp": 1.0}, + {"prompt": "test2", "timestamp": 2.0}, + {"prompt": "test3", "timestamp": 3.0}, + ] + + timestamps = task._get_time_stamps(data_list) + + self.assertEqual(timestamps, [1.0, 2.0, 3.0]) + + @patch('ais_bench.benchmark.tasks.openicl_api_infer.AISLogger') + def test_get_time_stamps_set_use_timestamp(self, mock_logger_class): + """测试use_timestamp标志设置""" + mock_logger = MagicMock() + mock_logger_class.return_value = mock_logger + + task = self._create_task() + task.logger = mock_logger + + # Mock inferencer + mock_inferencer = MagicMock() + mock_inferencer.use_timestamp = False + task.inferencer = mock_inferencer + + data_list = [ + {"prompt": "test1", "timestamp": 1.0}, + ] + + task._get_time_stamps(data_list) + + # 验证use_timestamp被设置为True + self.assertTrue(mock_inferencer.use_timestamp) + + @patch('ais_bench.benchmark.tasks.openicl_api_infer.AISLogger') + 
def test_get_time_stamps_warning(self, mock_logger_class): + """测试警告信息输出""" + mock_logger = MagicMock() + mock_logger_class.return_value = mock_logger + + task = self._create_task() + task.logger = mock_logger + + # Mock inferencer + mock_inferencer = MagicMock() + mock_inferencer.use_timestamp = False + task.inferencer = mock_inferencer + + data_list = [ + {"prompt": "test1", "timestamp": 1.0}, + ] + + task._get_time_stamps(data_list) + + # 验证记录了警告日志 + mock_logger.warning.assert_called() + warning_call = str(mock_logger.warning.call_args) + self.assertIn("Found timestamps in datasets", warning_call) + self.assertIn("request_rate", warning_call) + + @patch('ais_bench.benchmark.tasks.openicl_api_infer.AISLogger') + def test_get_time_stamps_empty(self, mock_logger_class): + """测试没有timestamp的情况""" + mock_logger = MagicMock() + mock_logger_class.return_value = mock_logger + + task = self._create_task() + task.logger = mock_logger + + # Mock inferencer + mock_inferencer = MagicMock() + mock_inferencer.use_timestamp = False + task.inferencer = mock_inferencer + + data_list = [ + {"prompt": "test1"}, + {"prompt": "test2"}, + ] + + timestamps = task._get_time_stamps(data_list) + + # 应该返回空列表 + self.assertEqual(timestamps, []) + # use_timestamp应该保持为False + self.assertFalse(mock_inferencer.use_timestamp) + + @patch('ais_bench.benchmark.tasks.openicl_api_infer.AISLogger') + def test_get_time_stamps_partial(self, mock_logger_class): + """测试部分数据有timestamp的情况""" + mock_logger = MagicMock() + mock_logger_class.return_value = mock_logger + + task = self._create_task() + task.logger = mock_logger + + # Mock inferencer + mock_inferencer = MagicMock() + mock_inferencer.use_timestamp = False + task.inferencer = mock_inferencer + + data_list = [ + {"prompt": "test1", "timestamp": 1.0}, + {"prompt": "test2"}, # 没有timestamp + {"prompt": "test3", "timestamp": 3.0}, + ] + + timestamps = task._get_time_stamps(data_list) + + # 应该只提取有timestamp的数据 + self.assertEqual(timestamps, [1.0, 3.0]) + # 因为有timestamp,use_timestamp应该被设置为True + self.assertTrue(mock_inferencer.use_timestamp) + if __name__ == '__main__': unittest.main() diff --git a/tests/UT/tasks/test_utils.py b/tests/UT/tasks/test_utils.py index 7e4734df..983eda2a 100644 --- a/tests/UT/tasks/test_utils.py +++ b/tests/UT/tasks/test_utils.py @@ -297,6 +297,7 @@ def test_init_with_normal_rate(self, mock_logger_class): producer = TokenProducer( request_rate=self.request_rate, + time_stamps=[], traffic_cfg=self.traffic_cfg, request_num=self.request_num, mode="pressure" if self.pressure_mode else "infer" @@ -315,6 +316,7 @@ def test_init_with_low_rate(self, mock_logger_class): producer = TokenProducer( request_rate=0.05, + time_stamps=[], traffic_cfg=self.traffic_cfg, request_num=self.request_num, mode="pressure" if self.pressure_mode else "infer" @@ -338,6 +340,7 @@ def test_init_with_ramp_up_strategy(self, mock_logger_class): producer = TokenProducer( request_rate=self.request_rate, + time_stamps=[], traffic_cfg=traffic_cfg, request_num=self.request_num, mode="pressure" if self.pressure_mode else "infer" @@ -364,6 +367,7 @@ def test_init_with_invalid_ramp_up_strategy(self, mock_logger_class): with self.assertRaises(ParameterValueError) as context: TokenProducer( request_rate=self.request_rate, + time_stamps=[], traffic_cfg=traffic_cfg, request_num=self.request_num, mode="infer" @@ -386,6 +390,7 @@ def test_generate_interval_lists_linear(self, mock_logger_class): producer = TokenProducer( request_rate=self.request_rate, + time_stamps=[], traffic_cfg=traffic_cfg, 
request_num=10, mode="pressure" if self.pressure_mode else "infer" @@ -410,6 +415,7 @@ def test_generate_interval_lists_exponential(self, mock_logger_class): producer = TokenProducer( request_rate=self.request_rate, + time_stamps=[], traffic_cfg=traffic_cfg, request_num=10, mode="pressure" if self.pressure_mode else "infer" @@ -425,6 +431,7 @@ def test_produce_token_without_token_bucket(self, mock_logger_class): producer = TokenProducer( request_rate=0.05, # 低速率,token_bucket为None + time_stamps=[], traffic_cfg=self.traffic_cfg, request_num=self.request_num, mode="pressure" if self.pressure_mode else "infer" @@ -444,6 +451,7 @@ def test_produce_token_with_token_bucket(self, mock_logger_class): producer = TokenProducer( request_rate=self.request_rate, + time_stamps=[], traffic_cfg=self.traffic_cfg, request_num=10, # 小数量用于快速测试 mode="pressure" if self.pressure_mode else "infer" @@ -474,6 +482,7 @@ def test_produce_token_with_exception(self, mock_logger_class): producer = TokenProducer( request_rate=self.request_rate, + time_stamps=[], traffic_cfg=self.traffic_cfg, request_num=10, mode="pressure" if self.pressure_mode else "infer" @@ -936,6 +945,7 @@ def test_produce_token_else_branch_exception(self, mock_logger_class): producer = TokenProducer( request_rate=self.request_rate, + time_stamps=[], traffic_cfg=self.traffic_cfg, request_num=10, mode="pressure" if self.pressure_mode else "infer" @@ -980,6 +990,7 @@ def test_produce_token_sleep_interval_negative(self, mock_logger_class): producer = TokenProducer( request_rate=100, # 高频率 + time_stamps=[], traffic_cfg=self.traffic_cfg, request_num=10, mode="pressure" if self.pressure_mode else "infer" @@ -1013,6 +1024,117 @@ def run_produce(): except: pass + @patch('ais_bench.benchmark.tasks.utils.AISLogger') + def test_init_with_timestamps(self, mock_logger_class): + """测试timestamps参数传递""" + mock_logger = MagicMock() + mock_logger_class.return_value = mock_logger + + time_stamps = [1.0, 2.0, 3.0, 4.0, 5.0] + producer = TokenProducer( + request_rate=10.0, + time_stamps=time_stamps, + traffic_cfg=self.traffic_cfg, + request_num=5, + mode="infer" + ) + + # 验证interval_lists被设置为timestamps的副本 + self.assertEqual(producer.interval_lists, time_stamps) + # 验证是副本,不是引用 + self.assertIsNot(producer.interval_lists, time_stamps) + + @patch('ais_bench.benchmark.tasks.utils.AISLogger') + def test_init_timestamps_as_interval_lists(self, mock_logger_class): + """测试timestamps存在时直接使用作为interval_lists""" + mock_logger = MagicMock() + mock_logger_class.return_value = mock_logger + + time_stamps = [0.5, 1.0, 1.5, 2.0] + producer = TokenProducer( + request_rate=20.0, # 这个值应该被忽略 + time_stamps=time_stamps, + traffic_cfg=self.traffic_cfg, + request_num=4, + mode="infer" + ) + + # 验证interval_lists等于timestamps + self.assertEqual(producer.interval_lists, time_stamps) + # 验证token_bucket被创建(因为request_rate >= 0.1) + self.assertIsNotNone(producer.token_bucket) + + @patch('ais_bench.benchmark.tasks.utils.AISLogger') + def test_init_timestamps_ignore_request_rate(self, mock_logger_class): + """测试timestamps存在时忽略request_rate配置""" + mock_logger = MagicMock() + mock_logger_class.return_value = mock_logger + + time_stamps = [1.0, 2.0, 3.0] + # 使用一个不同的request_rate,应该被忽略 + producer = TokenProducer( + request_rate=100.0, # 这个值应该被忽略,因为提供了timestamps + time_stamps=time_stamps, + traffic_cfg=self.traffic_cfg, + request_num=3, + mode="infer" + ) + + # 验证interval_lists是timestamps,而不是根据request_rate生成的 + self.assertEqual(producer.interval_lists, time_stamps) + # 验证长度匹配 + 
self.assertEqual(len(producer.interval_lists), len(time_stamps)) + + @patch('ais_bench.benchmark.tasks.utils.AISLogger') + def test_produce_token_with_timestamps(self, mock_logger_class): + """测试token释放时间符合timestamps""" + mock_logger = MagicMock() + mock_logger_class.return_value = mock_logger + + time_stamps = [0.1, 0.2, 0.3] # 使用较短的时间间隔以便测试 + producer = TokenProducer( + request_rate=10.0, + time_stamps=time_stamps, + traffic_cfg=self.traffic_cfg, + request_num=3, + mode="infer" + ) + + # 验证interval_lists被正确设置 + self.assertEqual(producer.interval_lists, time_stamps) + + # 测试produce_token会使用这些timestamps + # 注意:实际的produce_token测试需要更复杂的设置(共享内存等) + # 这里主要验证interval_lists被正确设置 + self.assertEqual(len(producer.interval_lists), 3) + self.assertEqual(producer.interval_lists[0], 0.1) + self.assertEqual(producer.interval_lists[1], 0.2) + self.assertEqual(producer.interval_lists[2], 0.3) + + @patch('ais_bench.benchmark.tasks.utils.AISLogger') + def test_produce_token_timing_accuracy(self, mock_logger_class): + """测试token bucket机制控制请求发送时间的准确性""" + mock_logger = MagicMock() + mock_logger_class.return_value = mock_logger + + time_stamps = [0.05, 0.1, 0.15] # 使用较短的时间间隔 + producer = TokenProducer( + request_rate=10.0, + time_stamps=time_stamps, + traffic_cfg=self.traffic_cfg, + request_num=3, + mode="infer" + ) + + # 验证token_bucket被创建 + self.assertIsNotNone(producer.token_bucket) + # 验证interval_lists被正确设置 + self.assertEqual(producer.interval_lists, time_stamps) + + # 验证interval_lists的长度与request_num匹配(或与timestamps长度匹配) + # 当提供timestamps时,interval_lists应该等于timestamps + self.assertEqual(len(producer.interval_lists), len(time_stamps)) + if __name__ == '__main__': unittest.main() From e0242161c813454ef12ddc70fbb6c78aa0e7c3d4 Mon Sep 17 00:00:00 2001 From: GaoHuaZhang <1484391106@qq.com> Date: Thu, 22 Jan 2026 15:27:38 +0800 Subject: [PATCH 2/2] append extra UT --- tests/UT/tasks/test_openicl_api_infer.py | 127 ----------------------- tests/UT/tasks/test_utils.py | 112 -------------------- 2 files changed, 239 deletions(-) diff --git a/tests/UT/tasks/test_openicl_api_infer.py b/tests/UT/tasks/test_openicl_api_infer.py index 280487e7..9dd85f99 100644 --- a/tests/UT/tasks/test_openicl_api_infer.py +++ b/tests/UT/tasks/test_openicl_api_infer.py @@ -1063,133 +1063,6 @@ def test_run_multi_process_empty_concurrency(self, mock_logger_class): dataset_shm.close() dataset_shm.unlink() - @patch('ais_bench.benchmark.tasks.openicl_api_infer.AISLogger') - def test_get_time_stamps_extract(self, mock_logger_class): - """测试从data_list中提取timestamps""" - mock_logger = MagicMock() - mock_logger_class.return_value = mock_logger - - task = self._create_task() - task.logger = mock_logger - - # Mock inferencer - mock_inferencer = MagicMock() - mock_inferencer.use_timestamp = False - task.inferencer = mock_inferencer - - data_list = [ - {"prompt": "test1", "timestamp": 1.0}, - {"prompt": "test2", "timestamp": 2.0}, - {"prompt": "test3", "timestamp": 3.0}, - ] - - timestamps = task._get_time_stamps(data_list) - - self.assertEqual(timestamps, [1.0, 2.0, 3.0]) - - @patch('ais_bench.benchmark.tasks.openicl_api_infer.AISLogger') - def test_get_time_stamps_set_use_timestamp(self, mock_logger_class): - """测试use_timestamp标志设置""" - mock_logger = MagicMock() - mock_logger_class.return_value = mock_logger - - task = self._create_task() - task.logger = mock_logger - - # Mock inferencer - mock_inferencer = MagicMock() - mock_inferencer.use_timestamp = False - task.inferencer = mock_inferencer - - data_list = [ - {"prompt": "test1", "timestamp": 
1.0}, - ] - - task._get_time_stamps(data_list) - - # 验证use_timestamp被设置为True - self.assertTrue(mock_inferencer.use_timestamp) - - @patch('ais_bench.benchmark.tasks.openicl_api_infer.AISLogger') - def test_get_time_stamps_warning(self, mock_logger_class): - """测试警告信息输出""" - mock_logger = MagicMock() - mock_logger_class.return_value = mock_logger - - task = self._create_task() - task.logger = mock_logger - - # Mock inferencer - mock_inferencer = MagicMock() - mock_inferencer.use_timestamp = False - task.inferencer = mock_inferencer - - data_list = [ - {"prompt": "test1", "timestamp": 1.0}, - ] - - task._get_time_stamps(data_list) - - # 验证记录了警告日志 - mock_logger.warning.assert_called() - warning_call = str(mock_logger.warning.call_args) - self.assertIn("Found timestamps in datasets", warning_call) - self.assertIn("request_rate", warning_call) - - @patch('ais_bench.benchmark.tasks.openicl_api_infer.AISLogger') - def test_get_time_stamps_empty(self, mock_logger_class): - """测试没有timestamp的情况""" - mock_logger = MagicMock() - mock_logger_class.return_value = mock_logger - - task = self._create_task() - task.logger = mock_logger - - # Mock inferencer - mock_inferencer = MagicMock() - mock_inferencer.use_timestamp = False - task.inferencer = mock_inferencer - - data_list = [ - {"prompt": "test1"}, - {"prompt": "test2"}, - ] - - timestamps = task._get_time_stamps(data_list) - - # 应该返回空列表 - self.assertEqual(timestamps, []) - # use_timestamp应该保持为False - self.assertFalse(mock_inferencer.use_timestamp) - - @patch('ais_bench.benchmark.tasks.openicl_api_infer.AISLogger') - def test_get_time_stamps_partial(self, mock_logger_class): - """测试部分数据有timestamp的情况""" - mock_logger = MagicMock() - mock_logger_class.return_value = mock_logger - - task = self._create_task() - task.logger = mock_logger - - # Mock inferencer - mock_inferencer = MagicMock() - mock_inferencer.use_timestamp = False - task.inferencer = mock_inferencer - - data_list = [ - {"prompt": "test1", "timestamp": 1.0}, - {"prompt": "test2"}, # 没有timestamp - {"prompt": "test3", "timestamp": 3.0}, - ] - - timestamps = task._get_time_stamps(data_list) - - # 应该只提取有timestamp的数据 - self.assertEqual(timestamps, [1.0, 3.0]) - # 因为有timestamp,use_timestamp应该被设置为True - self.assertTrue(mock_inferencer.use_timestamp) - - if __name__ == '__main__': unittest.main() diff --git a/tests/UT/tasks/test_utils.py b/tests/UT/tasks/test_utils.py index 983eda2a..24c84838 100644 --- a/tests/UT/tasks/test_utils.py +++ b/tests/UT/tasks/test_utils.py @@ -1024,118 +1024,6 @@ def run_produce(): except: pass - @patch('ais_bench.benchmark.tasks.utils.AISLogger') - def test_init_with_timestamps(self, mock_logger_class): - """测试timestamps参数传递""" - mock_logger = MagicMock() - mock_logger_class.return_value = mock_logger - - time_stamps = [1.0, 2.0, 3.0, 4.0, 5.0] - producer = TokenProducer( - request_rate=10.0, - time_stamps=time_stamps, - traffic_cfg=self.traffic_cfg, - request_num=5, - mode="infer" - ) - - # 验证interval_lists被设置为timestamps的副本 - self.assertEqual(producer.interval_lists, time_stamps) - # 验证是副本,不是引用 - self.assertIsNot(producer.interval_lists, time_stamps) - - @patch('ais_bench.benchmark.tasks.utils.AISLogger') - def test_init_timestamps_as_interval_lists(self, mock_logger_class): - """测试timestamps存在时直接使用作为interval_lists""" - mock_logger = MagicMock() - mock_logger_class.return_value = mock_logger - - time_stamps = [0.5, 1.0, 1.5, 2.0] - producer = TokenProducer( - request_rate=20.0, # 这个值应该被忽略 - time_stamps=time_stamps, - traffic_cfg=self.traffic_cfg, - request_num=4, - 
mode="infer" - ) - - # 验证interval_lists等于timestamps - self.assertEqual(producer.interval_lists, time_stamps) - # 验证token_bucket被创建(因为request_rate >= 0.1) - self.assertIsNotNone(producer.token_bucket) - - @patch('ais_bench.benchmark.tasks.utils.AISLogger') - def test_init_timestamps_ignore_request_rate(self, mock_logger_class): - """测试timestamps存在时忽略request_rate配置""" - mock_logger = MagicMock() - mock_logger_class.return_value = mock_logger - - time_stamps = [1.0, 2.0, 3.0] - # 使用一个不同的request_rate,应该被忽略 - producer = TokenProducer( - request_rate=100.0, # 这个值应该被忽略,因为提供了timestamps - time_stamps=time_stamps, - traffic_cfg=self.traffic_cfg, - request_num=3, - mode="infer" - ) - - # 验证interval_lists是timestamps,而不是根据request_rate生成的 - self.assertEqual(producer.interval_lists, time_stamps) - # 验证长度匹配 - self.assertEqual(len(producer.interval_lists), len(time_stamps)) - - @patch('ais_bench.benchmark.tasks.utils.AISLogger') - def test_produce_token_with_timestamps(self, mock_logger_class): - """测试token释放时间符合timestamps""" - mock_logger = MagicMock() - mock_logger_class.return_value = mock_logger - - time_stamps = [0.1, 0.2, 0.3] # 使用较短的时间间隔以便测试 - producer = TokenProducer( - request_rate=10.0, - time_stamps=time_stamps, - traffic_cfg=self.traffic_cfg, - request_num=3, - mode="infer" - ) - - # 验证interval_lists被正确设置 - self.assertEqual(producer.interval_lists, time_stamps) - - # 测试produce_token会使用这些timestamps - # 注意:实际的produce_token测试需要更复杂的设置(共享内存等) - # 这里主要验证interval_lists被正确设置 - self.assertEqual(len(producer.interval_lists), 3) - self.assertEqual(producer.interval_lists[0], 0.1) - self.assertEqual(producer.interval_lists[1], 0.2) - self.assertEqual(producer.interval_lists[2], 0.3) - - @patch('ais_bench.benchmark.tasks.utils.AISLogger') - def test_produce_token_timing_accuracy(self, mock_logger_class): - """测试token bucket机制控制请求发送时间的准确性""" - mock_logger = MagicMock() - mock_logger_class.return_value = mock_logger - - time_stamps = [0.05, 0.1, 0.15] # 使用较短的时间间隔 - producer = TokenProducer( - request_rate=10.0, - time_stamps=time_stamps, - traffic_cfg=self.traffic_cfg, - request_num=3, - mode="infer" - ) - - # 验证token_bucket被创建 - self.assertIsNotNone(producer.token_bucket) - # 验证interval_lists被正确设置 - self.assertEqual(producer.interval_lists, time_stamps) - - # 验证interval_lists的长度与request_num匹配(或与timestamps长度匹配) - # 当提供timestamps时,interval_lists应该等于timestamps - self.assertEqual(len(producer.interval_lists), len(time_stamps)) - - if __name__ == '__main__': unittest.main()