【Feature】mooncake trace datasets support #108
base: master
Conversation
Summary of Changes

Hello @GaoHuaZhang, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly expands the benchmarking framework's capabilities by integrating a new dataset type for structured prompt generation, enhancing performance metric analysis through advanced visualizations, and providing more granular control over synthetic dataset creation. It also introduces a mechanism for simulating realistic request traffic using timestamps and refines tokenizer behavior for improved flexibility.
Code Review
This pull request introduces support for the "mooncake trace" dataset, which allows replaying production traces for performance benchmarking. This includes a new dataset loader that generates prompts from trace data and a new feature to visualize metric distributions as HTML plots. The changes also update synthetic dataset generation and include several refactorings and bug fixes across the codebase.
My review has identified several critical issues, primarily related to copy-paste errors that introduce bugs in concurrent data processing and argument passing. There are also high-severity maintainability concerns due to significant code duplication and a very large, complex method for plotting. Additionally, I've found several medium-severity issues, including hardcoded values that should be constants and duplicated content in documentation files. I recommend addressing the critical bugs before merging and considering the refactoring suggestions to improve code quality.
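For context, each line of a Mooncake-style trace describes one request's arrival time and token-length metadata, which the new loader turns into prompts. A representative record, with field names assumed from the publicly released Mooncake traces (the exact schema used in this PR may differ), looks roughly like this:

```python
# Illustrative Mooncake-style trace record (schema assumed from the public
# Mooncake traces; not taken verbatim from this PR):
trace_record = {
    "timestamp": 27482,     # request arrival time offset (ms)
    "input_length": 6955,   # prompt length in tokens
    "output_length": 52,    # completion length in tokens
    "hash_ids": [46, 47],   # prefix-cache block ids shared across requests
}
```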
```python
# Check total_data_count before adding to data_indices
if not self.pressure_mode:
    if not self.total_data_count:
        end_index = cur_index
        break
    self.total_data_count -= 1
# Only add to data_indices after all checks pass
data_indices.append(cur_index)
# Update end_index to next index after successfully adding
end_index = (cur_index + 1) % len(indexes)
```
This block of code is a duplicate of the preceding block (lines 270-278). This copy-paste error will cause self.total_data_count to be decremented twice and the same index to be appended to data_indices twice for each item in the loop. This is a critical bug that needs to be fixed by removing the duplicated block.
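A minimal sketch of the fix, keeping exactly one copy of the check-and-append logic inside the loop (names taken from the diff above):

```python
# Inside the index-selection loop: this block must appear exactly once,
# otherwise total_data_count is decremented twice per item.
if not self.pressure_mode:
    if not self.total_data_count:
        end_index = cur_index
        break
    self.total_data_count -= 1
# Only add to data_indices after all checks pass
data_indices.append(cur_index)
# Advance end_index past the index that was just added
end_index = (cur_index + 1) % len(indexes)
```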
```python
indexes,
token_bucket,
per_worker_data_num[i],
per_worker_data_num[i],
```
The argument per_worker_data_num[i] is passed twice to the run_single_inferencer process. The function signature is run_single_inferencer(..., total_data_num: int, global_index: mp.RawValue = None, ...). The second per_worker_data_num[i] is incorrectly passed in place of global_index, which expects an mp.RawValue object. This is a type mismatch and a critical bug. The duplicated argument should be removed.
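A sketch of what the corrected call site might look like, assuming the caller creates a shared mp.RawValue counter (the surrounding code is not shown in this hunk, so the names here are illustrative):

```python
import multiprocessing as mp

# Shared cross-process counter; a signed int ("i") is assumed here
global_index = mp.RawValue("i", 0)

worker_args = (
    indexes,
    token_bucket,
    per_worker_data_num[i],  # total_data_num: this worker's quota
    global_index,            # global_index: an mp.RawValue, not a repeated int
)
```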
```python
# Get total data count from indexes
total_data_count = len(indexes) if self.pressure else len(indexes) - 1
per_worker_data_num = self._deliver_data_num_for_workers(per_worker_concurrency, total_data_count)
```
```python
acquired = await asyncio.to_thread(token_bucket.acquire, timeout=1)
if not acquired:
    continue
data = await self.wait_get_data(async_queue, stop_event)
```
```python
def _adjust_prompt_length(self, tokenizer, prompt, target_token_length):
    """Adjust prompt text to ensure it encodes to exactly target_token_length tokens.

    Args:
        tokenizer: The tokenizer to use
        prompt: Initial prompt text
        target_token_length: Desired number of tokens

    Returns:
        str: Adjusted prompt text that encodes to exactly target_token_length tokens
    """
    # Encode to check current token count
    encoded = tokenizer.encode(prompt, add_special_tokens=False)
    current_length = len(encoded)

    if current_length == target_token_length:
        return prompt

    # If we need more tokens, append characters
    if current_length < target_token_length:
        # Find a character that adds tokens when appended
        padding_chars = [" ", "A", "B", "C", "D", "E"]
        padding_char = " "

        for char in padding_chars:
            test_text = prompt + char
            test_encoded = tokenizer.encode(test_text, add_special_tokens=False)
            if len(test_encoded) > current_length:
                padding_char = char
                break

        # Add padding until we reach target length
        max_iterations = target_token_length * 2  # Safety limit
        iteration = 0
        while current_length < target_token_length and iteration < max_iterations:
            prompt += padding_char
            encoded = tokenizer.encode(prompt, add_special_tokens=False)
            new_length = len(encoded)

            # If adding a character doesn't increase token count, try a different approach
            if new_length == current_length:
                # Try adding a word instead
                prompt += " word"
                encoded = tokenizer.encode(prompt, add_special_tokens=False)
                new_length = len(encoded)

            current_length = new_length
            iteration += 1

            if current_length >= target_token_length:
                break

    # If we have too many tokens, use binary search to find the right length
    if current_length > target_token_length:
        # Binary search for the right text length
        left, right = 0, len(prompt)
        best_prompt = prompt
        best_length_diff = abs(current_length - target_token_length)

        # Limit binary search iterations
        max_binary_iterations = 50
        binary_iteration = 0

        while left < right and binary_iteration < max_binary_iterations:
            mid = (left + right) // 2
            if mid == 0:
                break

            test_prompt = prompt[:mid]
            test_encoded = tokenizer.encode(test_prompt, add_special_tokens=False)
            test_length = len(test_encoded)

            if test_length == target_token_length:
                return test_prompt
            elif test_length < target_token_length:
                length_diff = target_token_length - test_length
                if length_diff < best_length_diff:
                    best_prompt = test_prompt
                    best_length_diff = length_diff
                left = mid + 1
            else:
                right = mid

            binary_iteration += 1

        # Fine-tune from the best result
        prompt = best_prompt
        encoded = tokenizer.encode(prompt, add_special_tokens=False)
        current_length = len(encoded)

        # Fine-tune by adding/removing characters one by one
        max_fine_tune = 100  # Safety limit
        fine_tune_iter = 0

        while (
            current_length != target_token_length and fine_tune_iter < max_fine_tune
        ):
            if current_length < target_token_length:
                # Try adding different characters
                for char in [" ", "A", "B", "C"]:
                    test_prompt = prompt + char
                    test_encoded = tokenizer.encode(
                        test_prompt, add_special_tokens=False
                    )
                    if len(test_encoded) == target_token_length:
                        return test_prompt
                    elif len(test_encoded) > target_token_length:
                        break
                # If no single char works, add a space and continue
                prompt += " "
            else:
                # Remove one character from end
                if len(prompt) > 0:
                    prompt = prompt[:-1]
                else:
                    break

            encoded = tokenizer.encode(prompt, add_special_tokens=False)
            current_length = len(encoded)
            fine_tune_iter += 1

    return prompt
```
The method _adjust_prompt_length is duplicated in ais_bench/benchmark/datasets/synthetic.py. Duplicated code is hard to maintain, as bug fixes or improvements need to be applied in multiple places. Consider extracting this method into a shared utility module, for example, in ais_bench/benchmark/datasets/utils/, and import it in both MooncakeTraceDataset and SyntheticDataset.
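One possible shape for that refactor, assuming a new helper module at ais_bench/benchmark/datasets/utils/prompt_utils.py (the path and names are illustrative, not taken from this PR):

```python
# ais_bench/benchmark/datasets/utils/prompt_utils.py (hypothetical module)
def adjust_prompt_length(tokenizer, prompt, target_token_length):
    """Single shared implementation of the padding/binary-search logic above."""
    ...  # body moved verbatim from the method shown in this diff


# In both mooncake_trace.py and synthetic.py:
from ais_bench.benchmark.datasets.utils.prompt_utils import adjust_prompt_length


class MooncakeTraceDataset:
    def _adjust_prompt_length(self, tokenizer, prompt, target_token_length):
        # Thin wrapper so existing call sites keep working
        return adjust_prompt_length(tokenizer, prompt, target_token_length)
```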
```python
# Cache for batch-prefetched data
self._data_cache = []  # Thread-local cache for batch data
self.total_data_count = 0
self.total_data_count = 0
```
```python
from concurrent.futures import ProcessPoolExecutor, as_completed
from functools import partial
```
```python
metrics_to_visualize = [
    m for m in sorted(all_metrics)
    if m not in {"PrefillBatchsize", "DecoderBatchsize"}
]

if not metrics_to_visualize:
    self.logger.warning("No suitable metrics for curve visualization, skipping.")
    return

# Calculate number of subplots needed
num_metrics = len(metrics_to_visualize)
num_stages = len(all_stages)

# Create subplots: one row per metric, one column per stage
fig = make_subplots(
    rows=num_metrics,
    cols=num_stages,
    subplot_titles=[
        f"{metric} - {stage}"
        for metric in metrics_to_visualize
        for stage in sorted(all_stages)
    ],
    vertical_spacing=0.08,
    horizontal_spacing=0.1
)

# Generate density curve for each metric-stage combination
for metric_idx, metric in enumerate(metrics_to_visualize):
    for stage_idx, stage_name in enumerate(sorted(all_stages)):
        row = metric_idx + 1
        col = stage_idx + 1

        # Get data for this metric and stage
        if stage_name not in self.result or metric not in self.result[stage_name]:
            self.logger.debug(f"No data for {metric} in stage {stage_name}, skipping.")
            continue

        data = self.result[stage_name][metric]
        if not data:
            self.logger.debug(f"Empty data for {metric} in stage {stage_name}, skipping.")
            continue

        # Flatten data if it contains numpy arrays
        if isinstance(data, list) and len(data) > 0:
            if isinstance(data[0], np.ndarray):
                flat_data = np.concatenate(data)
            else:
                flat_data = np.array(data)
        else:
            flat_data = np.array(data)

        # Apply unit conversion for specific metrics (ITL, TPOT, TTFT)
        # These metrics need to be multiplied by 1000
        if metric in {"ITL", "TPOT", "TTFT"}:
            flat_data = flat_data * 1000
```
The sets of metric names {"PrefillBatchsize", "DecoderBatchsize"} and {"ITL", "TPOT", "TTFT"} are hardcoded inline. It's better to define them as constants at the class or module level. This improves readability and makes it easier to manage these lists.
For example, you could add these constants to the class:

```python
_METRICS_TO_IGNORE_FOR_CURVES = {"PrefillBatchsize", "DecoderBatchsize"}
_METRICS_TO_CONVERT_UNITS = {"ITL", "TPOT", "TTFT"}
```

and then use them in the method.
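A short sketch of how the method body could consume them (constant names as suggested above):

```python
# Filter using the class-level constant instead of an inline set literal
metrics_to_visualize = [
    m for m in sorted(all_metrics)
    if m not in self._METRICS_TO_IGNORE_FOR_CURVES
]

# ... later, in the per-metric loop:
if metric in self._METRICS_TO_CONVERT_UNITS:
    flat_data = flat_data * 1000
```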
| "Params": { | ||
| "Mean": 256, # Central value: 256 | ||
| "Var": 10, # Variance: 10 | ||
| "Var": 10, # Variance: 10 |
| "Params": { | ||
| "Mean": 256, # 中心值256 | ||
| "Var": 10, # 方差10 | ||
| "Var": 10, # 方差10 |
Thanks for your contribution; we appreciate it a lot. The following instructions will make your pull request healthier and help you get feedback more easily. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers.

PR Type

Related Issue

Fixes #(issue ID) / Relates to #(issue ID)

🔍 Motivation

Please describe the motivation of this PR and the goal you want to achieve through it.

📝 Modification

Please briefly describe what modification is made in this PR.

📐 Associated Test Results

Please provide links to the related test results, such as CI pipelines, test reports, etc.

Does the modification introduce changes that break the backward compatibility of downstream repositories? If so, please describe how it breaks compatibility and how downstream projects should modify their code to stay compatible with this PR.

If the modification introduces performance degradation, please describe its impact and the expected performance improvement.

🌟 Use cases (Optional)

If this PR introduces a new feature, it is better to list some use cases here and update the documentation.

✅ Checklist

Before PR:

After PR:

👥 Collaboration Info

🌟 Useful CI Commands

- /gemini review
- /gemini summary
- /gemini help
- /readthedocs build