Add debug weight sync feature #525
base: main
Conversation
Summary of Changes
Hello @zianglih, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances the debugging capabilities for weight synchronization processes.
Code Review
This pull request introduces a valuable debug feature for weight synchronization and adds support for mxfp8 quantization. However, critical security vulnerabilities were identified, primarily related to insecure deserialization and command injection. Specifically, the use of torch.load() without weights_only=True in the new debugging utility could lead to arbitrary code execution from malicious checkpoints. Furthermore, several training scripts are vulnerable to command injection due to direct interpolation of user-supplied arguments into shell commands. Beyond these security concerns, suggestions were made to enhance code quality by refining exception handling, reducing code duplication, and cleaning up module exports. Addressing these security issues is paramount.
    )
    return safe_open(path, framework="pt", device="cpu")
if self.fmt == "bin":
    obj = torch.load(path, map_location="cpu")
The use of torch.load() without weights_only=True is insecure as it relies on the pickle module, which can execute arbitrary code during deserialization. An attacker could provide a malicious checkpoint file that, when loaded for debugging or comparison, executes arbitrary commands on the system. It is highly recommended to use weights_only=True to restrict deserialization to safe types.
-obj = torch.load(path, map_location="cpu")
+obj = torch.load(path, map_location="cpu", weights_only=True)
U.exec_command(
    f"huggingface-cli download Qwen/{args.model_name}-FP8 --local-dir /root/models/{args.model_name}-FP8"
)
The model_name argument is directly interpolated into a shell command string without sanitization. This allows for command injection if an attacker can control the model_name parameter. For example, a model_name like ; touch /tmp/pwned would result in the execution of the injected command. Use shlex.quote() to sanitize any variables used in shell commands.
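A minimal sketch of the suggested mitigation, assuming the command string is assembled the same way as in the snippet above; the helper name build_download_cmd is hypothetical, and the quoted string would then be passed to the existing U.exec_command:

import shlex

def build_download_cmd(model_name: str) -> str:
    # Quote the untrusted value so shell metacharacters are treated as literal
    # text rather than executed; shlex.quote wraps the argument safely.
    repo = shlex.quote(f"Qwen/{model_name}-FP8")
    local_dir = shlex.quote(f"/root/models/{model_name}-FP8")
    return f"huggingface-cli download {repo} --local-dir {local_dir}"

# An injection attempt stays inside one quoted argument instead of running `touch`:
print(build_download_cmd("Qwen3-8B; touch /tmp/pwned"))

The same quoting would apply to the convert_hf_to_mxfp8.py invocations below.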
U.exec_command(
    f"python tools/convert_hf_to_mxfp8.py --model-dir /root/models/{args.model_name} --save-dir {mxfp8_path}"
)
-if args.rollout_fp8:
+if args.rollout_fp8 and not use_blackwell_fp8:
     U.exec_command(f"hf download Qwen/{args.model_name}-FP8 --local-dir /root/models/{args.model_name}-FP8")
U.exec_command(
    f"python tools/convert_hf_to_mxfp8.py --model-dir /root/models/{args.model_name} --save-dir {mxfp8_path}"
)
-__all__ = ["remove_padding", "quantize_param", "quantize_params_fp8", "quantize_params_compressed_tensors"]
+__all__ = [
+    "remove_padding",
+    "quantize_param",
The __all__ list includes quantize_param, but this function is not defined or imported in this module. This appears to be a pre-existing issue, but since this block is being modified, it's a good opportunity to correct it. Removing this line will prevent potential NameError exceptions and improve code clarity.
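As a sketch of the suggested cleanup, assuming the remaining entries from the original single-line __all__ are still the intended exports (any additional exports introduced by this PR, such as an mxfp8 variant, would be appended as well):

# Drop the undefined `quantize_param` entry; keep only names that actually exist.
__all__ = [
    "remove_padding",
    "quantize_params_fp8",
    "quantize_params_compressed_tensors",
]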
# experts
expert_pattern = r"mlp.experts\.(.+)\.weight(\d+)"
match = re.match(expert_pattern, rest)
if match:
    rest, expert_idx = match.groups()
    if rest in [
        "linear_fc1",
        "linear_fc2",
    ]:
        quantize_named_params = []
        for converted_name, param in converted_named_params:
            # skip bf16 weight_scale and input_scale
            # TODO: find a clearer way.
            if converted_name.endswith("_scale"):
                continue
            quantize_named_params.extend(_quantize_param(converted_name, param))

        return quantize_named_params

# shared expert
shared_expert_pattern = r"mlp.shared_experts\.(.+)"
match = re.match(shared_expert_pattern, rest)
if match:
    rest = match.groups()[0]
    if rest in [
        "linear_fc1.weight",
        "linear_fc2.weight",
    ]:
        quantize_named_params = []
        for converted_name, param in converted_named_params:
            quantize_named_params.extend(_quantize_param(converted_name, param))

        return quantize_named_params

if rest in [
    "self_attention.linear_proj.weight",
    "self_attention.linear_qkv.weight",
    "mlp.linear_fc1.weight",
    "mlp.linear_fc2.weight",
    # mla
    "self_attention.linear_q_proj.weight",
    "self_attention.linear_q_down_proj.weight",
    "self_attention.linear_q_up_proj.weight",
    "self_attention.linear_kv_down_proj.weight",
    "self_attention.linear_kv_up_proj.weight",
]:
    quantize_named_params = []
    for converted_name, param in converted_named_params:
        quantize_named_params.extend(_quantize_param(converted_name, param))

    return quantize_named_params
There is significant code duplication in how quantization is applied for different layer types (experts, shared experts, and other linear layers). The logic to iterate over converted_named_params and call _quantize_param is repeated.
This could be refactored into a helper function to improve maintainability and readability. For example:
def _apply_quantization(converted_named_params, skip_scales=False):
    quantized_params = []
    for name, param in converted_named_params:
        if skip_scales and name.endswith("_scale"):
            continue
        quantized_params.extend(_quantize_param(name, param))
    return quantized_params

# ... inside quantize_params_mxfp8, you can then determine if quantization is needed
# and call the helper, e.g.:
# if should_quantize:
#     return _apply_quantization(converted_named_params, skip_scales=is_expert_layer)

except Exception:  # pragma: no cover - optional dependency
    safe_open = None
Catching a broad Exception for an optional import can hide other unexpected errors. It's better to catch the specific ImportError that occurs when the optional dependency is not installed.
-except Exception:  # pragma: no cover - optional dependency
-    safe_open = None
+except ImportError:  # pragma: no cover - optional dependency
+    safe_open = None
except Exception as exc:  # pragma: no cover - optional dependency
    logger.warning(
        "Cannot resolve HF repo id %s (huggingface_hub unavailable): %s",
        path_or_repo,
        exc,
    )
Similar to the previous comment, catching a broad Exception for an optional import can mask other issues. It's more precise to catch ImportError here as well.
-except Exception as exc:  # pragma: no cover - optional dependency
-    logger.warning(
-        "Cannot resolve HF repo id %s (huggingface_hub unavailable): %s",
-        path_or_repo,
-        exc,
-    )
+except ImportError as exc:  # pragma: no cover - optional dependency
+    logger.warning(
+        "Cannot resolve HF repo id %s (huggingface_hub unavailable): %s",
+        path_or_repo,
+        exc,
+    )
@HumansAnd
This PR currently depends on #512.
Example: