From 944a6df611232307349e797f8fff7c479a63b883 Mon Sep 17 00:00:00 2001
From: penguin_wwy <940375606@qq.com>
Date: Sat, 27 Apr 2024 18:27:37 +0800
Subject: [PATCH] Extract the Python APIs in the pt1 dir back to the root
 (#3237)

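For downstream code this is an import-path change only: `OutputType`, `lower_mlir_module`
(formerly the private `_lower_mlir_module`), and `run_pipeline_with_repro_report` now live
in `torch_mlir.compiler_utils` at the repo root instead of under `projects/pt1`. A minimal
sketch of the new canonical imports:

```python
# New import location after this change (previously these came from
# projects/pt1's torch_mlir.torchscript / torch_mlir.compiler_utils).
from torch_mlir.compiler_utils import (
    OutputType,
    lower_mlir_module,
    run_pipeline_with_repro_report,
)
```
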
---
 projects/pt1/python/CMakeLists.txt            |   1 -
 .../pt1/python/torch_mlir/compiler_utils.py   |  75 --------
 projects/pt1/python/torch_mlir/torchscript.py | 105 +----------
 .../configs/fx_importer_backend.py            |  11 +-
 .../configs/torchdynamo.py                    |  10 +-
 .../onnx_backends/linalg_on_tensors.py        |  10 +-
 python/CMakeLists.txt                         |   1 +
 python/torch_mlir/compiler_utils.py           | 176 ++++++++++++++++++
 8 files changed, 201 insertions(+), 188 deletions(-)
 delete mode 100644 projects/pt1/python/torch_mlir/compiler_utils.py
 create mode 100644 python/torch_mlir/compiler_utils.py

diff --git a/projects/pt1/python/CMakeLists.txt b/projects/pt1/python/CMakeLists.txt
index 642b86b50490..443fcc809e2c 100644
--- a/projects/pt1/python/CMakeLists.txt
+++ b/projects/pt1/python/CMakeLists.txt
@@ -20,7 +20,6 @@ declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel
   SOURCES
     torchscript.py
     _dynamo_fx_importer.py
-    compiler_utils.py
     dynamo.py
     _version.py
 )
diff --git a/projects/pt1/python/torch_mlir/compiler_utils.py b/projects/pt1/python/torch_mlir/compiler_utils.py
deleted file mode 100644
index 7792006032af..000000000000
--- a/projects/pt1/python/torch_mlir/compiler_utils.py
+++ /dev/null
@@ -1,75 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-# Also available under a BSD-style license. See LICENSE.
-
-from io import StringIO
-import os
-import sys
-import tempfile
-
-from torch_mlir.passmanager import PassManager
-from torch_mlir.ir import StringAttr
-
-
-def get_module_name_for_debug_dump(module):
-    """Gets a name suitable for a debug dump.
-
-    The name is not guaranteed to be unique.
-    """
-    if not "torch.debug_module_name" in module.operation.attributes:
-        return "UnnammedModule"
-    return StringAttr(module.operation.attributes["torch.debug_module_name"]).value
-
-
-class TorchMlirCompilerError(Exception):
-    pass
-
-def run_pipeline_with_repro_report(module,
-                                   pipeline: str,
-                                   description: str,
-                                   enable_ir_printing: bool = False):
-    """Runs `pipeline` on `module`, with a nice repro report if it fails."""
-    module_name = get_module_name_for_debug_dump(module)
-    try:
-        original_stderr = sys.stderr
-        sys.stderr = StringIO()
-        asm_for_error_report = module.operation.get_asm(
-            large_elements_limit=10, enable_debug_info=True)
-        # Lower module in place to make it ready for compiler backends.
-        with module.context as ctx:
-            pm = PassManager.parse(pipeline)
-            if enable_ir_printing:
-                ctx.enable_multithreading(False)
-                pm.enable_ir_printing()
-            pm.run(module.operation)
-    except Exception as e:
-        # TODO: More robust.
-        # - don't arbitrarily clutter up /tmp. When a test suite has many
-        #   tests, this can be a big disk cost (also, /tmp/ is frequently a
-        #   RAM fs, which increases worries about capacity).
-        # - don't have colliding filenames (hard to do without cluttering
-        #   up /tmp)
-        # - if we do have have colliding filenames, writes should at least
-        #   avoid being racy.
-        filename = os.path.join(tempfile.gettempdir(), module_name + ".mlir")
-        with open(filename, 'w') as f:
-            f.write(asm_for_error_report)
-        debug_options="-mlir-print-ir-after-all -mlir-disable-threading"
-        # Put something descriptive here even if description is empty.
-        description = description or f"{module_name} compile"
-
-        message = f"""\
-            {description} failed with the following diagnostics:
-            {sys.stderr.getvalue()}
-
-            python exception: {e}
-
-            For Torch-MLIR developers, the error can be reproduced with:
-            $ torch-mlir-opt -pass-pipeline='{pipeline}' {filename}
-            Add '{debug_options}' to get the IR dump for debugging purpose.
-            """
-        trimmed_message = '\n'.join([m.lstrip() for m in message.split('\n')])
-        raise TorchMlirCompilerError(trimmed_message) from None
-    finally:
-        sys.stderr = original_stderr
diff --git a/projects/pt1/python/torch_mlir/torchscript.py b/projects/pt1/python/torch_mlir/torchscript.py
index acb487319ae9..508297cfe8f0 100644
--- a/projects/pt1/python/torch_mlir/torchscript.py
+++ b/projects/pt1/python/torch_mlir/torchscript.py
@@ -17,65 +17,15 @@
 from torch_mlir.dynamo import _get_decomposition_table
 from torch.fx.experimental.proxy_tensor import make_fx
 
-from .compiler_utils import run_pipeline_with_repro_report
+from torch_mlir.compiler_utils import (
+    run_pipeline_with_repro_report,
+    OutputType,
+    lower_mlir_module,
+)
 from torch_mlir.jit_ir_importer import ClassAnnotator, ImportOptions, ModuleBuilder
 from torch_mlir.jit_ir_importer.build_tools.library_generator import generate_library
 
 
-class OutputType(Enum):
-    """The kind of output that `torchscript.compile` can produce.
-
-    In MLIR terminology, this describes the mix of dialects that will be
-    produced by the conversion process.
-
-    In user-facing API's, this type can always be passed interchangeably with an
-    appropriate string specifying the output type. The allowed strings are
-    the set of enum vales, allowed to be case insensitive and with `-` allowed
-    in place of `_`. The `OutputType.get` static method can be used to convert
-    from a string to an `OutputType` instance.
-    """
-
-    # This output type consists of `torch` dialect ops that have been converted
-    # maximally to value semantics, decomposed, and shapes have been inferred.
-    TORCH = "torch"
-
-    # The output type contains a mix of `linalg`-on-tensors ops, `scf`, and
-    # `arith` ops (and also `math` and `tm_tensor`). It can be thought of
-    # as taking the `TORCH` output type and lowering it so that tensor
-    # computations are done with `linalg`-on-tensors ops.
-    LINALG_ON_TENSORS = "linalg-on-tensors"
-
-    # This output type consists of `tosa` dialect ops. It can be thought of
-    # as taking the `TORCH` output type and lowering it to TOSA.
-    TOSA = "tosa"
-
-    # This output type consists of `stablehlo` dialect ops. It can be thought of
-    # as taking the `TORCH` output type and lowering it to StableHLO.
-    STABLEHLO = "stablehlo"
-
-    # Raw output of the JIT IR importer. This is not expected to be useful
-    # for end-users, but can be convenient for development or reporting bugs.
-    RAW = "raw"
-
-    @staticmethod
-    def get(spec: Union[str, "OutputType"]) -> "OutputType":
-        """Gets an OutputType from allowed way to specify one.
-
-        Args:
-          spec: An OutputType instance or the case-insensitive name of one of the
-            enum values.
-        Returns:
-          An OutputType instance.
-        """
-        if isinstance(spec, OutputType):
-            return spec
-        spec = spec.upper().replace("-", "_")
-        if spec not in OutputType.__members__:
-            raise ValueError(f"For output_type= argument, expected one of: "
-                             f"{', '.join(OutputType.__members__.keys())}")
-        return OutputType[spec]
-
-
 class TensorPlaceholder:
     """A class that represents a formal parameter of a given shape and dtype.
 
@@ -270,49 +220,6 @@ def _canon_extra_library(extra_library, extra_library_file_name="custom_op_extra
         return ""
 
 
-def _lower_mlir_module(verbose, output_type, module):
-    if verbose:
-        print("\n====================")
-        print("Torch Backend IR")
-        print(module)
-
-    if output_type == OutputType.TORCH:
-        return module
-
-    if output_type == OutputType.TOSA:
-        run_pipeline_with_repro_report(
-            module, "builtin.module(torch-backend-to-tosa-backend-pipeline)",
-            "Lowering Torch Backend IR -> TOSA Backend IR")
-        if verbose:
-            print("\n====================")
-            print("TOSA Backend IR")
-            print(module)
-        return module
-
-    if output_type == OutputType.LINALG_ON_TENSORS:
-        run_pipeline_with_repro_report(
-            module,
-            "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)",
-            "Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR")
-        if verbose:
-            print("\n====================")
-            print("LINALG Backend IR")
-            print(module)
-        return module
-
-    elif output_type == OutputType.STABLEHLO:
-        run_pipeline_with_repro_report(
-            module,
-            "builtin.module(torch-backend-to-stablehlo-backend-pipeline)",
-            "Lowering Torch Backend IR -> StableHLO Backend IR")
-        if verbose:
-            print("\n====================")
-            print("StableHLO Backend IR")
-            print(module)
-        return module
-    raise Exception(f"Unknown OutputType: {output_type}")
-
-
 def compile(model: torch.nn.Module,
             example_args: _example_args,
             output_type: Union[str, "OutputType"] = OutputType.TORCH,
@@ -464,4 +371,4 @@ def compile(model: torch.nn.Module,
         enable_ir_printing=enable_ir_printing,
     )
 
-    return _lower_mlir_module(verbose, output_type, mb.module)
+    return lower_mlir_module(verbose, output_type, mb.module)
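Since `torchscript.compile` keeps its public signature, existing callers are unaffected;
only code reaching into the private `_lower_mlir_module` needs updating. A hedged sketch of
a call through the public API, assuming a trivial toy module for illustration:

```python
import torch
from torch_mlir import torchscript

# Hypothetical toy module, purely for illustration.
class AddOne(torch.nn.Module):
    def forward(self, x):
        return x + 1

# output_type accepts an OutputType member or an equivalent string
# (case-insensitive, with "-" interchangeable with "_").
module = torchscript.compile(
    AddOne(), torch.ones(3), output_type="linalg-on-tensors"
)
print(module)
```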
diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py
index 0d75fe2ad3f0..e45c7b18bb7a 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py
@@ -12,12 +12,13 @@
 from torch.export import ExportedProgram
 
 from torch_mlir import fx
-from torch_mlir.torchscript import (
-    _example_args,
+from torch_mlir.compiler_utils import (
+    run_pipeline_with_repro_report,
+    lower_mlir_module,
     OutputType,
+)
+from torch_mlir.torchscript import (
     BACKEND_LEGAL_OPS,
-    run_pipeline_with_repro_report,
-    _lower_mlir_module,
     _canon_extra_library,
 )
 from torch_mlir_e2e_test.configs.utils import (
@@ -76,7 +77,7 @@ def jit(
         "Lowering TorchFX IR -> Torch Backend IR",
     )
 
-    return _lower_mlir_module(verbose, output_type, mlir_module)
+    return lower_mlir_module(verbose, output_type, mlir_module)
 
 
 class FxImporterTestConfig(TestConfig):
diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py b/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py
index bdc410741cae..13f4d3df863f 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py
@@ -15,14 +15,16 @@
     set_model_name,
 )
 
+from torch_mlir.compiler_utils import (
+    run_pipeline_with_repro_report,
+    lower_mlir_module,
+    OutputType,
+)
 from torch_mlir._dynamo_fx_importer import import_fx_graph_as_func
 from torch_mlir.dynamo import _get_decomposition_table
 from torch_mlir.torchscript import (
     _example_args,
-    OutputType,
     BACKEND_LEGAL_OPS,
-    run_pipeline_with_repro_report,
-    _lower_mlir_module,
     _canon_extra_library,
 )
 from torch_mlir_e2e_test.configs.utils import (
@@ -148,7 +150,7 @@ def my_aot_autograd_backend(gm: torch.fx.GraphModule,
             "Lowering TorchFX IR -> Torch Backend IR",
         )
 
-    return _lower_mlir_module(verbose, output_type, mlir_module)
+    return lower_mlir_module(verbose, output_type, mlir_module)
 
 
 class TorchDynamoTestConfig(TestConfig):
diff --git a/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py b/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py
index 449e6bb40f01..fcd1efb3f4d6 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py
@@ -4,11 +4,13 @@
 # Also available under a BSD-style license. See LICENSE.
 
 
-from torch_mlir.compiler_utils import run_pipeline_with_repro_report
+from torch_mlir.compiler_utils import (
+    run_pipeline_with_repro_report,
+    lower_mlir_module,
+    OutputType,
+)
 from torch_mlir.ir import *
 from torch_mlir.passmanager import *
-from torch_mlir.torchscript import OutputType
-from torch_mlir.torchscript import _lower_mlir_module
 
 from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
 
@@ -58,7 +60,7 @@ def compile(self, imported_module: Module):
             "Lowering TorchFX IR -> Torch Backend IR",
         )
 
-        imported_module = _lower_mlir_module(False, OutputType.LINALG_ON_TENSORS, imported_module)
+        imported_module = lower_mlir_module(False, OutputType.LINALG_ON_TENSORS, imported_module)
         compiled_module = self.refbackend.compile(imported_module)
         return compiled_module
 
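All three updated call sites follow the same two-step pattern: `run_pipeline_with_repro_report`
to bring the imported IR to the Torch backend contract, then `lower_mlir_module` to pick the
backend dialect. A sketch of that pattern in isolation (the input module and the pipeline
string are assumptions for illustration; real call sites choose their own):

```python
from torch_mlir.compiler_utils import (
    OutputType,
    lower_mlir_module,
    run_pipeline_with_repro_report,
)

def lower_to_linalg(imported_module, verbose=False):
    # Step 1: normalize the imported IR to the Torch backend contract.
    # (Illustrative pipeline; on failure this raises TorchMlirCompilerError
    # with a torch-mlir-opt repro command.)
    run_pipeline_with_repro_report(
        imported_module,
        "builtin.module(torch-lower-to-backend-contract)",
        "Lowering imported IR -> Torch Backend IR",
    )
    # Step 2: lower to the requested backend dialect (internally runs the
    # torch-backend-to-* pipeline with the same repro reporting).
    return lower_mlir_module(verbose, OutputType.LINALG_ON_TENSORS, imported_module)
```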
diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt
index e52135599864..76cdbcca41eb 100644
--- a/python/CMakeLists.txt
+++ b/python/CMakeLists.txt
@@ -43,6 +43,7 @@ declare_mlir_python_sources(TorchMLIRPythonSources.PublicAPI
   ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}"
   ADD_TO_PARENT TorchMLIRPythonSources
   SOURCES
+    compiler_utils.py
     fx.py
     extras/fx_decomp_util.py
 )
diff --git a/python/torch_mlir/compiler_utils.py b/python/torch_mlir/compiler_utils.py
new file mode 100644
index 000000000000..6416b88aab5f
--- /dev/null
+++ b/python/torch_mlir/compiler_utils.py
@@ -0,0 +1,176 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+from enum import Enum
+from io import StringIO
+import os
+import sys
+import tempfile
+from typing import Union
+
+from torch_mlir.passmanager import PassManager
+from torch_mlir.ir import StringAttr
+
+
+def get_module_name_for_debug_dump(module):
+    """Gets a name suitable for a debug dump.
+
+    The name is not guaranteed to be unique.
+    """
+    if "torch.debug_module_name" not in module.operation.attributes:
+        return "UnnamedModule"
+    return StringAttr(module.operation.attributes["torch.debug_module_name"]).value
+
+
+class TorchMlirCompilerError(Exception):
+    pass
+
+def run_pipeline_with_repro_report(module,
+                                   pipeline: str,
+                                   description: str,
+                                   enable_ir_printing: bool = False):
+    """Runs `pipeline` on `module`, with a nice repro report if it fails."""
+    module_name = get_module_name_for_debug_dump(module)
+    original_stderr = sys.stderr
+    try:
+        sys.stderr = StringIO()
+        asm_for_error_report = module.operation.get_asm(
+            large_elements_limit=10, enable_debug_info=True)
+        # Lower module in place to make it ready for compiler backends.
+        with module.context as ctx:
+            pm = PassManager.parse(pipeline)
+            if enable_ir_printing:
+                ctx.enable_multithreading(False)
+                pm.enable_ir_printing()
+            pm.run(module.operation)
+    except Exception as e:
+        # TODO: More robust.
+        # - don't arbitrarily clutter up /tmp. When a test suite has many
+        #   tests, this can be a big disk cost (also, /tmp/ is frequently a
+        #   RAM fs, which increases worries about capacity).
+        # - don't have colliding filenames (hard to do without cluttering
+        #   up /tmp)
+        # - if we do have colliding filenames, writes should at least
+        #   avoid being racy.
+        filename = os.path.join(tempfile.gettempdir(), module_name + ".mlir")
+        with open(filename, 'w') as f:
+            f.write(asm_for_error_report)
+        debug_options = "-mlir-print-ir-after-all -mlir-disable-threading"
+        # Put something descriptive here even if description is empty.
+        description = description or f"{module_name} compile"
+
+        message = f"""\
+            {description} failed with the following diagnostics:
+            {sys.stderr.getvalue()}
+
+            python exception: {e}
+
+            For Torch-MLIR developers, the error can be reproduced with:
+            $ torch-mlir-opt -pass-pipeline='{pipeline}' {filename}
+            Add '{debug_options}' to get the IR dump for debugging purposes.
+            """
+        trimmed_message = '\n'.join([m.lstrip() for m in message.split('\n')])
+        raise TorchMlirCompilerError(trimmed_message) from None
+    finally:
+        sys.stderr = original_stderr
+
+
+class OutputType(Enum):
+    """The kind of output that Torch-MLIR compilation can produce.
+
+    In MLIR terminology, this describes the mix of dialects that will be
+    produced by the conversion process.
+
+    This type can be passed interchangeably with an appropriate string:
+    the allowed strings are the set of enum values, case-insensitive and
+    with `-` allowed in place of `_`. The `OutputType.get` static method
+    converts from a string to an `OutputType` instance.
+    """
+
+    # Output torch dialect. When converting from FX, this will be immediately
+    # after the import from FX to MLIR. When converting from TorchScript,
+    # this will come after some cleanup passes that attempt to de-alias,
+    # decompose, and infer shapes. These should be roughly the same level of
+    # abstraction since those steps are done within PyTorch itself
+    # when coming directly from Dynamo/FX.
+    TORCH = "torch"
+
+    # The output type contains a mix of `linalg`-on-tensors ops, `scf`, and
+    # `arith` ops (and also `math` and `tm_tensor`). It can be thought of
+    # as taking the `TORCH` output type and lowering it so that tensor
+    # computations are done with `linalg`-on-tensors ops.
+    LINALG_ON_TENSORS = "linalg-on-tensors"
+
+    # This output type consists of `tosa` dialect ops. It can be thought of
+    # as taking the `TORCH` output type and lowering it to TOSA.
+    TOSA = "tosa"
+
+    # This output type consists of `stablehlo` dialect ops. It can be thought of
+    # as taking the `TORCH` output type and lowering it to StableHLO.
+    STABLEHLO = "stablehlo"
+
+    # Raw output of the JIT IR importer. This is not expected to be useful
+    # for end-users, but can be convenient for development or reporting bugs.
+    RAW = "raw"
+
+    @staticmethod
+    def get(spec: Union[str, "OutputType"]) -> "OutputType":
+        """Gets an OutputType from an allowed way to specify one.
+
+        Args:
+          spec: An OutputType instance or the case-insensitive name of one of the
+            enum values.
+        Returns:
+          An OutputType instance.
+        """
+        if isinstance(spec, OutputType):
+            return spec
+        spec = spec.upper().replace("-", "_")
+        if spec not in OutputType.__members__:
+            raise ValueError(f"For output_type= argument, expected one of: "
+                             f"{', '.join(OutputType.__members__.keys())}")
+        return OutputType[spec]
+
+
+def lower_mlir_module(verbose, output_type, module):
+    if verbose:
+        print("\n====================")
+        print("Torch Backend IR")
+        print(module)
+
+    if output_type == OutputType.TORCH:
+        return module
+
+    if output_type == OutputType.TOSA:
+        run_pipeline_with_repro_report(
+            module, "builtin.module(torch-backend-to-tosa-backend-pipeline)",
+            "Lowering Torch Backend IR -> TOSA Backend IR")
+        if verbose:
+            print("\n====================")
+            print("TOSA Backend IR")
+            print(module)
+        return module
+
+    if output_type == OutputType.LINALG_ON_TENSORS:
+        run_pipeline_with_repro_report(
+            module,
+            "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)",
+            "Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR")
+        if verbose:
+            print("\n====================")
+            print("LINALG Backend IR")
+            print(module)
+        return module
+
+    if output_type == OutputType.STABLEHLO:
+        run_pipeline_with_repro_report(
+            module,
+            "builtin.module(torch-backend-to-stablehlo-backend-pipeline)",
+            "Lowering Torch Backend IR -> StableHLO Backend IR")
+        if verbose:
+            print("\n====================")
+            print("StableHLO Backend IR")
+            print(module)
+        return module
+    raise Exception(f"Unknown OutputType: {output_type}")
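For reference, `OutputType.get` accepts either an enum member (returned unchanged) or a
case-insensitive string with `-` accepted in place of `_`; unknown names raise `ValueError`.
A short illustration of the normalization it performs:

```python
from torch_mlir.compiler_utils import OutputType

# All of these resolve to the same member.
assert OutputType.get("linalg-on-tensors") is OutputType.LINALG_ON_TENSORS
assert OutputType.get("LINALG_ON_TENSORS") is OutputType.LINALG_ON_TENSORS
assert OutputType.get(OutputType.LINALG_ON_TENSORS) is OutputType.LINALG_ON_TENSORS

# Unknown names raise ValueError listing the allowed values.
try:
    OutputType.get("not-a-backend")
except ValueError as e:
    print(e)
```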