diff --git a/test/stablehlo/test_export_fx_passes.py b/test/stablehlo/test_export_fx_passes.py index 82650997316..580b756e177 100644 --- a/test/stablehlo/test_export_fx_passes.py +++ b/test/stablehlo/test_export_fx_passes.py @@ -212,7 +212,8 @@ def forward(self, x, dim, weight, bias, eps): after_decomp_out = native_layer_norm_impl(*args) self.assertTrue( torch.allclose(before_decomp_out, after_decomp_out, atol=1e-6)) - ep = export(m, args, dynamic_shapes=dynamic_shapes) + ep_training = export(m, args, dynamic_shapes=dynamic_shapes) + ep = ep_training.run_decompositions({}) decompose_dynamic_native_layer_norm(ep.graph_module) ep.graph_module.recompile() self.assertFalse('aten.native_layer_norm' in ep.graph_module.code)