Fix batch norm folding in prepare_pt2e for multiple conv->BN chains sharing the same conv weights
#2795
Conversation
Dr. CI: ✅ No failures as of commit 95e561e with merge base 72b35bf. Artifacts and rendered test results: hud.pytorch.org/pr/pytorch/ao/2795
@@ -22,6 +22,26 @@ def forward(self, x):
        return x


class ToyCNNModel(nn.Module):
nit: better to be more specific I think, like ConvWithSharedWeightInExportedModel
torchao/quantization/pt2e/utils.py
Outdated
)

fused_convs_weight_nodes.add(conv_weight_node)
del fused_convs_weight_nodes
nit: probably no need to delete this explicitly
see some nit comments inline
Fixed, thanks.
@subhankarpal has imported this pull request. If you are a Meta employee, you can view this in D80546723.
@subhankarpal no need for manual import, our oncall will import this in diff train
Summary
For a model with multiple conv->BN chains that share the same conv weights, prepare_pt2e applies batch norm folding once per chain, so the shared weights are folded repeatedly and the prepared model ends up with incorrect conv weights. This fix ensures batch norm folding is performed only once per unique set of conv weights, preventing repeated folds and preserving correct model behavior.
Test plan
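The failure mode described in the summary can be illustrated with a small, self-contained sketch. This is not the torchao implementation or the PR's test: `fold_scale`, `fold_all`, and the dictionary layout are hypothetical names chosen for illustration, the weights are plain Python floats rather than graph nodes, and both chains are assumed to use the same BN statistics (as when one conv+BN module is called twice). It only shows why folding the BN scale into a shared weight once per chain compounds the scale, and how tracking already-fused weights in a set avoids that.

```python
import math


def fold_scale(gamma, running_var, eps=1e-5):
    # Per-channel factor that BN folding multiplies into the conv weight.
    return gamma / math.sqrt(running_var + eps)


def fold_all(chains, weights, skip_already_fused):
    """Fold BN into conv weights for each (weight_name, bn) chain.

    `weights` maps a weight name to a list of floats and is updated in place.
    When `skip_already_fused` is True, a weight is folded at most once.
    """
    fused = set()  # names of conv weights that were already folded
    for weight_name, bn in chains:
        if skip_already_fused and weight_name in fused:
            continue
        scale = fold_scale(bn["gamma"], bn["running_var"])
        weights[weight_name] = [w * scale for w in weights[weight_name]]
        fused.add(weight_name)
    return weights


# Two conv->BN chains that reference the same conv weight (hypothetical values).
bn = {"gamma": 2.0, "running_var": 1.0}
chains = [("conv.weight", bn), ("conv.weight", bn)]

naive = fold_all(chains, {"conv.weight": [1.0, -0.5]}, skip_already_fused=False)
guarded = fold_all(chains, {"conv.weight": [1.0, -0.5]}, skip_already_fused=True)
scale = fold_scale(bn["gamma"], bn["running_var"])
```

In this sketch `naive` ends up scaled by `scale ** 2` (folded twice), while `guarded` is scaled by `scale` once, which is the behavior the summary describes as correct.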