
Fix batch norm folding in prepare_pt2e for multiple conv->BN chains sharing the same conv weights #2795


Merged
2 commits merged into main on Aug 19, 2025

Conversation

subhankarpal
Contributor

Summary

For a model with multiple conv->BN chains that share the same conv weights, prepare_pt2e applies batch norm folding once per chain, so the shared weights are rewritten multiple times and the prepared model ends up with incorrect conv weights. This fix ensures batch norm folding is performed only once per unique conv weight node, preventing repeated folds and preserving correct model behavior.
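For illustration, a minimal sketch of the kind of model that can hit this path: calling the same conv and BN modules twice in forward means the exported graph contains two conv->BN chains whose conv nodes reference a single shared weight node. The class name and exact structure below are illustrative, not the PR's actual test model.

```python
import torch
import torch.nn as nn


class ChunkedConvBN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
        self.bn = nn.BatchNorm2d(8)

    def forward(self, x):
        # The same conv/BN modules run on each chunk, so the exported graph
        # has two conv->BN chains sharing one conv weight node. Folding BN
        # into that weight once per chain rewrites the shared weight twice.
        a, b = torch.chunk(x, 2, dim=0)
        return torch.cat([self.bn(self.conv(a)), self.bn(self.conv(b))], dim=0)
```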

Test plan

python test/quantization/pt2e/test_quantize_pt2e.py -k test_chunked_bn_fusion
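For context, a sketch of the guard the diff below introduces, assuming an FX-style fold pass: record each conv weight node once it has been folded, and skip any later conv->BN pair that reuses it. Only the name fused_convs_weight_nodes appears in the PR diff; the surrounding function and parameter names are hypothetical stand-ins.

```python
def fold_bn_once_per_weight(model, get_conv_bn_pairs, fold_bn_into_conv):
    """Fold each BN into its preceding conv, never rewriting a weight twice.

    `get_conv_bn_pairs` and `fold_bn_into_conv` are hypothetical hooks for
    the real pass internals; only `fused_convs_weight_nodes` is from the diff.
    """
    fused_convs_weight_nodes = set()
    for conv_node, bn_node in get_conv_bn_pairs(model):
        conv_weight_node = conv_node.args[1]  # weight input of the conv node
        if conv_weight_node in fused_convs_weight_nodes:
            # Already folded via another chain sharing this weight; folding
            # again would apply the BN scale/shift a second time.
            continue
        fold_bn_into_conv(model, conv_node, bn_node)
        fused_convs_weight_nodes.add(conv_weight_node)
```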


pytorch-bot bot commented Aug 19, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2795

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 95e561e with merge base 72b35bf:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Aug 19, 2025
@@ -22,6 +22,26 @@ def forward(self, x):
return x


class ToyCNNModel(nn.Module):
Contributor

nit: better to be more specific I think, like ConvWithSharedWeightInExportedModel

)

fused_convs_weight_nodes.add(conv_weight_node)
del fused_convs_weight_nodes
Contributor

nit: probably no need to delete this explicitly

Contributor

@jerryzh168 jerryzh168 left a comment

see some nit comments inline

@jerryzh168 jerryzh168 added the topic: improvement (use this tag if this PR is an improvement that doesn't fit into any of the other categories) and pt2e_quant (pt2 export quantization: prepare_pt2e, convert_pt2e) labels on Aug 19, 2025
@subhankarpal
Contributor Author

> see some nit comments inline

Fixed, thanks.

@subhankarpal subhankarpal merged commit 9473060 into main Aug 19, 2025
18 checks passed
@facebook-github-bot
Contributor

@subhankarpal has imported this pull request. If you are a Meta employee, you can view this in D80546723.

@jerryzh168
Contributor

@subhankarpal no need for manual import, our oncall will import this in diff train
