Fix ONNX FP8 scaling #446
Conversation
Walkthrough

The PR refactors type handling for scale and zero-point values in quantization utilities, shifting from NumPy arrays to ONNX TensorProto objects at the retrieval stage, with type conversions deferred to points of use.

Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~12 minutes

Rationale: Single file with a consistent, homogeneous type-refactoring pattern (NumPy arrays → ONNX TensorProto). Changes follow straightforward substitution logic: parameter types updated, conversions relocated to call sites, and tensor metadata accessed appropriately. Control flow unchanged. Review focuses on verifying type correctness and conversion safety.
@i-riyad could you please take a look at this?
LGTM
Signed-off-by: Purushothaman Saravanan <[email protected]>
Force-pushed from 858238f to 072e837
What does this PR do?
Type of change: Bug fix

Overview:
When converting weight tensors to INT8/FP8, the zero-point array's datatype was previously validated against ONNX datatypes (`onnx.TensorProto.FLOAT8E4M3FN` or `onnx.TensorProto.INT8`). However, since the zero-point array is a NumPy array, weights were always incorrectly scaled to INT8 for FP8 quantization. This PR fixes that issue by checking the `data_type` field from the `onnx.TensorProto` instead of inferring it from the corresponding NumPy arrays.

Usage
# Add a code snippet demonstrating how to use this

Testing
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit