Skip to content

Commit

Permalink
Run torch backend pipeline & Fix code
Browse files Browse the repository at this point in the history
  • Loading branch information
penguin-wwy committed Apr 23, 2024
1 parent 580d4c0 commit 1947d12
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 8 deletions.
22 changes: 18 additions & 4 deletions python/torch_mlir/fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,28 @@
)


def _simplify_lowering(torch_mod, output_type, verbose):
def _simplify_lowering(verbose, output_type, torch_mod, backend_legal_ops=None,
extra_library_file_name=None):
# TODO: pass backend_legal_ops/extra_library_file_name by caller
if backend_legal_ops is None:
backend_legal_ops = []
if extra_library_file_name is None:
extra_library_file_name = ""
option_string = ("{backend-legal-ops=" + ",".join(backend_legal_ops) +
" extra-library=" + extra_library_file_name + "}")
run_pipeline_with_repro_report(
torch_mod,
f"builtin.module(torch-simplification-pipeline)",
"Simplification pipeline for torch dialect",
enable_ir_printing=verbose,
)
return lower_mlir_module(verbose, torch_mod, output_type)
run_pipeline_with_repro_report(
torch_mod,
f"builtin.module(torch-function-to-torch-backend-pipeline{option_string})",
"Lowering TorchFX IR -> Torch Backend IR",
enable_ir_printing=verbose,
)
return lower_mlir_module(verbose, output_type, torch_mod)


def export_and_import(
Expand Down Expand Up @@ -74,7 +88,7 @@ def export_and_import(
print(fx_importer.module)
return fx_importer.module
else:
return _simplify_lowering(fx_importer.module, output_type, verbose)
return _simplify_lowering(verbose, output_type, fx_importer.module)


def stateless_fx_import(
Expand All @@ -98,4 +112,4 @@ def stateless_fx_import(
print(fx_importer.module)
return fx_importer.module
else:
return _simplify_lowering(fx_importer.module, output_type, verbose)
return _simplify_lowering(verbose, output_type, fx_importer.module)
2 changes: 1 addition & 1 deletion test/python/fx_importer/basic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def forward(self):
return torch.full([], False, dtype=torch.bool, layout=torch.strided, device='cpu',
pin_memory=False)

m = fx.export_and_import(Basic(), func_name="test_full", enable_graph_printing=True)
m = fx.export_and_import(Basic(), func_name="test_full")
run_pipeline_with_repro_report(
m,
f"builtin.module(torch-simplification-pipeline)",
Expand Down
2 changes: 2 additions & 0 deletions test/python/fx_importer/v2.3/auto_functionalized.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def forward(self, x):
# assert (
# AssertionError: Current active mode <torch._subclasses.functional_tensor.FunctionalTensorMode object at 0x7a1106504fd0> not registered
decomposition_table=[],
output_type="raw",
)
# CHECK: %[[TIED:.*]] = torch.operator "torch.torch_mlir_test.inplace_modify"({{.*}}) : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
# CHECK: torch.aten.mul.Tensor %[[TIED]], %[[TIED]]
Expand Down Expand Up @@ -85,6 +86,7 @@ def forward(self, x):
# assert (
# AssertionError: Current active mode <torch._subclasses.functional_tensor.FunctionalTensorMode object at 0x7a1106504fd0> not registered
decomposition_table=[],
output_type="raw",
)
# CHECK: %[[TIED:.*]]:2 = torch.operator "torch.torch_mlir_test.inplace_modify_calc"(%0) : (!torch.vtensor<[3,4],f32>) -> (!torch.vtensor<[3,4],f32>, !torch.vtensor<[3,4],f32>)
# CHECK: torch.aten.mul.Tensor %[[TIED]]#1, %[[TIED]]#0
Expand Down
8 changes: 7 additions & 1 deletion test/python/fx_importer/v2.3/mutation_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def forward(self, x, y):
torch.randn(3, 4),
torch.randn(3, 4),
experimental_support_mutation=True,
output_type="raw",
)
print(m)
m.operation.verify()
Expand All @@ -99,7 +100,10 @@ def forward(self, x):
return x * self.buffer

m = fx.export_and_import(
Basic(), torch.randn(3, 4), experimental_support_mutation=True
Basic(),
torch.randn(3, 4),
experimental_support_mutation=True,
output_type="raw",
)
print(m)
m.operation.verify()
Expand Down Expand Up @@ -145,6 +149,7 @@ def forward(self, x):
torch.randn(3, 4),
experimental_support_mutation=True,
hooks=ExternalBufferHooks(),
output_type="raw",
)
print(m)
m.operation.verify()
Expand All @@ -168,6 +173,7 @@ def forward(self, x):
Basic(),
torch.randn(3, 4),
experimental_support_mutation=True,
output_type="raw",
)
except NotImplementedError as e:
print("EXPECTED ERROR:", str(e))
4 changes: 2 additions & 2 deletions test/python/fx_importer/v2.3/special_forms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def forward(self, x):
# CHECK: torch.aten.clone %arg0, %none
return torch.ops.aten.lift_fresh_copy.default(x)

m = fx.export_and_import(Basic(), torch.randn(3, 4))
m = fx.export_and_import(Basic(), torch.randn(3, 4), output_type="raw")
print(m)


Expand All @@ -70,5 +70,5 @@ def forward(self, x):
# CHECK: return %[[NONE]], %[[RES]]#1, %[[RES]]#2
return torch.ops.torch_mlir_test.multi_return(x)

m = fx.export_and_import(Basic(), torch.randn(3, 4))
m = fx.export_and_import(Basic(), torch.randn(3, 4), output_type="raw")
print(m)

0 comments on commit 1947d12

Please sign in to comment.