Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[rfc] Use logging.getLogger for projects/pt1/e2e_testing #3173

Closed
wants to merge 12 commits into from
43 changes: 29 additions & 14 deletions projects/pt1/e2e_testing/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# Also available under a BSD-style license. See LICENSE.

import argparse
import logging
import re
import sys

Expand Down Expand Up @@ -71,10 +72,11 @@ def _get_argparse():
parser.add_argument("-f", "--filter", default=".*", help="""
Regular expression specifying which tests to include in this run.
""")
parser.add_argument("-v", "--verbose",
default=False,
action="store_true",
help="report test results with additional detail")
parser.add_argument("--log_level", default="WARNING", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], help="set the log level")
parser.add_argument("-d", "--debug", action="store_const", dest="log_level", const="DEBUG", help="set log level to DEBUG for detailed debug output")
parser.add_argument("-v", "--verbose", action="store_const", dest="log_level", const="INFO", help="set log level to INFO to report a more detailed but still user-friendly level of verbosity")
parser.add_argument("-q", "--quiet", action="store_const", dest="log_level", const="ERROR", help="suppress all logs except errors")

parser.add_argument("-s", "--sequential",
default=False,
action="store_true",
Expand All @@ -93,6 +95,18 @@ def _get_argparse():
def main():
args = _get_argparse().parse_args()

logger = logging.getLogger() # use root logger by default. Easy to change later.
logger.setLevel(logging.NOTSET)
handler = logging.StreamHandler(sys.stdout)
handler.setLevel(args.log_level)
if args.log_level != "DEBUG":
fmt = "%(levelname)s: %(message)s"
else:
fmt = "%(levelname)s: %(filename)s:%(lineno)d:\n%(message)s"
formatter = logging.Formatter(fmt)
handler.setFormatter(formatter)
logger.addHandler(handler)

all_test_unique_names = set(
test.unique_name for test in GLOBAL_TEST_REGISTRY)

Expand Down Expand Up @@ -143,31 +157,32 @@ def main():
if args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed is not None:
for arg in args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed:
if arg not in all_test_unique_names:
print(f"ERROR: --crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed argument '{arg}' is not a valid test name")
logger.error(f"--crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed argument '{arg}' is not a valid test name")
sys.exit(1)

# Find the selected tests, and emit a diagnostic if none are found.
tests = [
test for test in available_tests
if re.match(args.filter, test.unique_name)
]
available_tests = [test.unique_name for test in available_tests]
if len(tests) == 0:
print(
f"ERROR: the provided filter {args.filter!r} does not match any tests"
logger.error(
f"the provided filter {args.filter!r} does not match any tests. The available tests are:\n" + "\n\t".join(available_tests)"
)
print("The available tests are:")
for test in available_tests:
print(test.unique_name)
sys.exit(1)

# Run the tests.
results = run_tests(tests, config, args.sequential, args.verbose)
results = run_tests(tests, config, args.sequential,
verbose=logger.level >= logging.INFO)

# Report the test results.
failed = report_results(results, xfail_set, args.verbose, args.config)
failed = report_results(results, xfail_set,
verbose=logger.level >= logging.INFO,
config=args.config)
if args.config == "torchdynamo":
print("\033[91mWarning: the TorchScript based dynamo support is deprecated. "
"The config for torchdynamo is planned to be removed in the future.\033[0m")
logger.warning("the TorchScript based dynamo support is deprecated. "
"The config for torchdynamo is planned to be removed in the future.")
if args.ignore_failures:
sys.exit(0)
sys.exit(1 if failed else 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import logging
logger = logging.getLogger()

from typing import Any

Expand Down Expand Up @@ -32,7 +34,7 @@ def compile(self, program: torch.nn.Module) -> Any:
example_args = convert_annotations_to_placeholders(program.forward)
module = torchscript.compile(
program, example_args, output_type="linalg-on-tensors")

logger.debug("MLIR produced by LinalgOnTensorsBackendTestConfig:\n" + str(module))
return self.backend.compile(module)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from torch_mlir.extras import onnx_importer
from torch_mlir.dialects import torch as torch_d
from torch_mlir.ir import Context, Module

import logging
logger = logging.getLogger()

def import_onnx(contents):
# Import the ONNX model proto from the file contents:
Expand All @@ -39,7 +40,7 @@ def import_onnx(contents):
return m


def convert_onnx(model, inputs):
def convert_onnx(model: torch.nn.Module, inputs):
buffer = io.BytesIO()

# Process the type information so we export with the dynamic shape information
Expand Down Expand Up @@ -82,6 +83,7 @@ def __init__(self, backend: OnnxBackend, use_make_fx: bool = False):
def compile(self, program: torch.nn.Module) -> Any:
example_args = convert_annotations_to_placeholders(program.forward)
onnx_module = convert_onnx(program, example_args)
logger.debug("MLIR produced by OnnxBackendTestConfig:\n" + str(onnx_module))
compiled_module = self.backend.compile(onnx_module)
return compiled_module

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import logging
logger = logging.getLogger()

from typing import Any

Expand Down Expand Up @@ -31,7 +33,7 @@ def __init__(self, backend: StablehloBackend):
def compile(self, program: torch.nn.Module) -> Any:
example_args = convert_annotations_to_placeholders(program.forward)
module = torchscript.compile(program, example_args, output_type="stablehlo")

logger.debug("MLIR produced by StablehloBackendTestConfig:\n" + str(module))
return self.backend.compile(module)

def run(self, artifact: Any, trace: Trace) -> Trace:
Expand Down
15 changes: 9 additions & 6 deletions projects/pt1/python/torch_mlir_e2e_test/reporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
Utilities for reporting the results of the test framework.
"""

from logging import getLogger
logger = getLogger()

from typing import Any, List, Optional, Set

import collections
Expand Down Expand Up @@ -292,17 +295,17 @@ def report_results(results: List[TestResult],
expected_failure = result.unique_name in expected_failures
if expected_failure:
if report.failed:
print(f'XFAIL - "{result.unique_name}"')
logger.info(f'XFAIL - "{result.unique_name}"')
results_by_outcome['XFAIL'].append((result, report))
else:
print(f'XPASS - "{result.unique_name}"')
logger.info(f'XPASS - "{result.unique_name}"')
results_by_outcome['XPASS'].append((result, report))
else:
if not report.failed:
print(f'PASS - "{result.unique_name}"')
logger.info(f'PASS - "{result.unique_name}"')
results_by_outcome['PASS'].append((result, report))
else:
print(f'FAIL - "{result.unique_name}"')
logger.info(f'FAIL - "{result.unique_name}"')
results_by_outcome['FAIL'].append((result, report))

OUTCOME_MEANINGS = collections.OrderedDict()
Expand All @@ -329,8 +332,8 @@ def report_results(results: List[TestResult],
for result, report in results:
print(f' {outcome} - "{result.unique_name}"')
# If the test failed, print the error message.
if outcome == 'FAIL' and verbose:
print(textwrap.indent(report.error_str(), ' ' * 8))
if outcome == 'FAIL':
logger.info(textwrap.indent(report.error_str(), ' ' * 8))

# Print a summary for easy scanning.
print('\nSummary:')
Expand Down
Loading