Skip to content

Commit

Permalink
Backend-legal-ops argument for fx lowering (#3956)
Browse files Browse the repository at this point in the history
Added `backend-legal-ops` argument in `fx.import_and_export` to stop
decomposition of certain torch ops. This PR is based on this
[issue](#3953)
  • Loading branch information
Abhishek-TyRnT authored Jan 21, 2025
1 parent f42c7e4 commit 2cc31d6
Showing 1 changed file with 27 additions and 3 deletions.
30 changes: 27 additions & 3 deletions python/torch_mlir/fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,32 @@ def _module_lowering(
output_type,
torch_mod,
extra_library_file_name=None,
backend_legal_ops=None,
):

if output_type == OutputType.RAW:
if verbose:
print(torch_mod)
return torch_mod
# TODO: pass extra_library_file_name by caller

backend_legal_op_arg_str = ""
if backend_legal_ops is not None:
if not len(backend_legal_ops) == 0:
backend_legal_op_arg_str = "backend-legal-ops=" + ",".join(
backend_legal_ops
)

if extra_library_file_name is None:
extra_library_file_name = ""
option_string = "{extra-library=" + extra_library_file_name + "}"
option_string = (
"{"
+ backend_legal_op_arg_str
+ " extra-library="
+ extra_library_file_name
+ "}"
)

run_pipeline_with_repro_report(
torch_mod,
f"builtin.module(func.func(torch-match-quantized-custom-ops), torchdynamo-export-to-torch-backend-pipeline{option_string})",
Expand All @@ -61,6 +77,7 @@ def export_and_import(
func_name: str = "main",
enable_graph_printing: bool = False,
enable_ir_printing: bool = False,
backend_legal_ops: Optional[list[str]] = None,
**kwargs,
):
context = ir.Context()
Expand Down Expand Up @@ -98,7 +115,10 @@ def export_and_import(
)

return _module_lowering(
enable_ir_printing, OutputType.get(output_type), fx_importer.module
enable_ir_printing,
OutputType.get(output_type),
fx_importer.module,
backend_legal_ops=backend_legal_ops,
)


Expand All @@ -110,6 +130,7 @@ def stateless_fx_import(
model_name: str = "main",
enable_graph_printing: bool = False,
enable_ir_printing: bool = False,
backend_legal_ops: Optional[list[str]] = None,
):
if enable_graph_printing:
gm.print_readable()
Expand All @@ -119,5 +140,8 @@ def stateless_fx_import(
fx_importer = FxImporter(context=context, hooks=hooks)
fx_importer.import_stateless_graph(gm.graph, func_name=model_name)
return _module_lowering(
enable_ir_printing, OutputType.get(output_type), fx_importer.module
enable_ir_printing,
OutputType.get(output_type),
fx_importer.module,
backend_legal_ops=backend_legal_ops,
)

0 comments on commit 2cc31d6

Please sign in to comment.