diff --git a/python/torch_mlir/fx.py b/python/torch_mlir/fx.py index 5309f57379f9..192533729d94 100644 --- a/python/torch_mlir/fx.py +++ b/python/torch_mlir/fx.py @@ -29,6 +29,7 @@ def _module_lowering( output_type, torch_mod, extra_library_file_name=None, + backend_legal_ops=None, ): if output_type == OutputType.RAW: @@ -36,9 +37,24 @@ def _module_lowering( print(torch_mod) return torch_mod # TODO: pass extra_library_file_name by caller + + backend_legal_op_arg_str = "" + if backend_legal_ops is not None: + if not len(backend_legal_ops) == 0: + backend_legal_op_arg_str = "backend-legal-ops=" + ",".join( + backend_legal_ops + ) + if extra_library_file_name is None: extra_library_file_name = "" - option_string = "{extra-library=" + extra_library_file_name + "}" + option_string = ( + "{" + + backend_legal_op_arg_str + + " extra-library=" + + extra_library_file_name + + "}" + ) + run_pipeline_with_repro_report( torch_mod, f"builtin.module(func.func(torch-match-quantized-custom-ops), torchdynamo-export-to-torch-backend-pipeline{option_string})", @@ -61,6 +77,7 @@ def export_and_import( func_name: str = "main", enable_graph_printing: bool = False, enable_ir_printing: bool = False, + backend_legal_ops: Optional[list[str]] = None, **kwargs, ): context = ir.Context() @@ -98,7 +115,10 @@ def export_and_import( ) return _module_lowering( - enable_ir_printing, OutputType.get(output_type), fx_importer.module + enable_ir_printing, + OutputType.get(output_type), + fx_importer.module, + backend_legal_ops=backend_legal_ops, ) @@ -110,6 +130,7 @@ def stateless_fx_import( model_name: str = "main", enable_graph_printing: bool = False, enable_ir_printing: bool = False, + backend_legal_ops: Optional[list[str]] = None, ): if enable_graph_printing: gm.print_readable() @@ -119,5 +140,8 @@ def stateless_fx_import( fx_importer = FxImporter(context=context, hooks=hooks) fx_importer.import_stateless_graph(gm.graph, func_name=model_name) return _module_lowering( - enable_ir_printing, OutputType.get(output_type), fx_importer.module + enable_ir_printing, + OutputType.get(output_type), + fx_importer.module, + backend_legal_ops=backend_legal_ops, )