Skip to content

Commit

Permalink
rename ir printing logger + add to stablehlo_backend
Browse files Browse the repository at this point in the history
  • Loading branch information
Xida Ren committed Apr 18, 2024
1 parent 40f0102 commit 293a74b
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 12 deletions.
7 changes: 4 additions & 3 deletions projects/pt1/e2e_testing/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,16 @@ def _get_argparse():
def main():
args = _get_argparse().parse_args()

logger = logging.getLogger("e2e_test")
ir_printer = logging.getLogger("ir_printer")
if args.print_ir:
print("WARNING: --print-ir is a work in progress feature.")
print("print-ir: Setting logging level to DEBUG and enabling IR printing.")
print("print-ir: This currently only affects the Linalg-on-Tensors and onnx configs.")
print("print-ir: Work in progress. See https://github.com/llvm/torch-mlir/issues/3172")
logger.setLevel(logging.DEBUG)
ir_printer.setLevel(logging.DEBUG)
else:
logger.setLevel(logging.WARNING)
# disable logging
ir_printer.setLevel(logging.CRITICAL+1)


all_test_unique_names = set(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import logging
logger = logging.getLogger("e2e_test")
ir_printer = logging.getLogger("ir_printer")

from typing import Any

Expand Down Expand Up @@ -34,9 +34,9 @@ def compile(self, program: torch.nn.Module) -> Any:
example_args = convert_annotations_to_placeholders(program.forward)
module = torchscript.compile(
program, example_args, output_type="linalg-on-tensors")
logger.debug("LinalgOnTensorsBackendTestConfig compiled module:")
logger.debug(module)
logger.debug("End LinalgOnTensorsBackendTestConfig compiled module")
ir_printer.debug("LinalgOnTensorsBackendTestConfig compiled module:")
ir_printer.debug(module)
ir_printer.debug("End LinalgOnTensorsBackendTestConfig compiled module")
return self.backend.compile(module)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from torch_mlir.dialects import torch as torch_d
from torch_mlir.ir import Context, Module
import logging
logger = logging.getLogger("e2e_test")
ir_printer = logging.getLogger("ir_printer")

def import_onnx(contents):
# Import the ONNX model proto from the file contents:
Expand Down Expand Up @@ -83,9 +83,9 @@ def __init__(self, backend: OnnxBackend, use_make_fx: bool = False):
def compile(self, program: torch.nn.Module) -> Any:
example_args = convert_annotations_to_placeholders(program.forward)
onnx_module = convert_onnx(program, example_args)
logger.debug("OnnxBackendTestConfig imported module:")
logger.debug(onnx_module)
logger.debug("End OnnxBackendTestConfig imported module")
ir_printer.debug("OnnxBackendTestConfig imported module:")
ir_printer.debug(onnx_module)
ir_printer.debug("End OnnxBackendTestConfig imported module")
compiled_module = self.backend.compile(onnx_module)
return compiled_module

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import logging
ir_printer = logging.getLogger("ir_printer")

from typing import Any

Expand Down Expand Up @@ -31,7 +33,9 @@ def __init__(self, backend: StablehloBackend):
def compile(self, program: torch.nn.Module) -> Any:
example_args = convert_annotations_to_placeholders(program.forward)
module = torchscript.compile(program, example_args, output_type="stablehlo")

ir_printer.debug("StablehloBackendTestConfig compiled module:")
ir_printer.debug(module)
ir_printer.debug("End StablehloBackendTestConfig compiled module")
return self.backend.compile(module)

def run(self, artifact: Any, trace: Trace) -> Trace:
Expand Down

0 comments on commit 293a74b

Please sign in to comment.