diff --git a/python/cutracer/cli.py b/python/cutracer/cli.py index 08de56b..b32b829 100644 --- a/python/cutracer/cli.py +++ b/python/cutracer/cli.py @@ -6,11 +6,11 @@ Provides command-line interface for trace validation and analysis. """ -import argparse import sys from importlib.metadata import PackageNotFoundError, version -from cutracer.validation.cli import _add_validate_args, validate_command +import click +from cutracer.validation.cli import validate_command def _get_package_version() -> str: @@ -21,49 +21,23 @@ def _get_package_version() -> str: return "0+unknown" -def main() -> int: - """Main CLI entry point.""" - pkg_version = _get_package_version() - - parser = argparse.ArgumentParser( - prog="cutraceross", - description="CUTracer: CUDA trace validation and analysis tools", - epilog=( - "Examples:\n" - " cutraceross validate kernel_trace.ndjson\n" - " cutraceross validate kernel_trace.ndjson.zst --verbose\n" - " cutraceross validate trace.log --format text\n" - ), - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - parser.add_argument( - "--version", - action="version", - version=f"%(prog)s {pkg_version}", - help="Show program's version number and exit", - ) +EXAMPLES = """ +Examples: + cutraceross validate kernel_trace.ndjson + cutraceross validate kernel_trace.ndjson.zst --verbose + cutraceross validate trace.log --format text +""" - subparsers = parser.add_subparsers(dest="command", required=True) - # validate subcommand - validate_parser = subparsers.add_parser( - "validate", - help="Validate a CUTracer trace file", - description=( - "Validate a CUTracer trace file.\n\n" - "Checks syntax and schema compliance for NDJSON, Zstd-compressed,\n" - "and text format trace files." - ), - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - _add_validate_args(validate_parser) - validate_parser.set_defaults(func=validate_command) +@click.group(epilog=EXAMPLES) +@click.version_option(version=_get_package_version(), prog_name="cutraceross") +def main() -> None: + """CUTracer: CUDA trace validation and analysis tools.""" + pass - # Parse arguments - args = parser.parse_args() - # Execute command - return args.func(args) +# Register subcommands +main.add_command(validate_command) if __name__ == "__main__": diff --git a/python/cutracer/validation/cli.py b/python/cutracer/validation/cli.py index 5b836c7..61104f0 100644 --- a/python/cutracer/validation/cli.py +++ b/python/cutracer/validation/cli.py @@ -6,50 +6,17 @@ This module provides command-line interface for validating CUTracer trace files. """ -import argparse import json import sys from pathlib import Path from typing import Any +import click + from .json_validator import validate_json_trace from .text_validator import validate_text_trace -def _add_validate_args(parser: argparse.ArgumentParser) -> None: - """Add arguments for the validate subcommand.""" - parser.add_argument( - "file", - type=Path, - help="Path to the trace file to validate", - ) - parser.add_argument( - "--format", - "-f", - choices=["json", "text", "auto"], - default="auto", - help="File format. Default: auto-detect from extension.", - ) - parser.add_argument( - "--quiet", - "-q", - action="store_true", - help="Quiet mode. Only return exit code.", - ) - parser.add_argument( - "--json", - dest="json_output", - action="store_true", - help="Output results in JSON format.", - ) - parser.add_argument( - "--verbose", - "-v", - action="store_true", - help="Verbose output with additional details.", - ) - - def _detect_format(file_path: Path) -> str: """Auto-detect file format from extension.""" suffixes = "".join(file_path.suffixes).lower() @@ -87,43 +54,76 @@ def _format_trace_format(result: dict[str, Any]) -> str: def _print_validation_result(result: dict[str, Any], verbose: bool = False) -> None: """Print validation result in human-readable format.""" if result["valid"]: - print("\u2705 Valid trace file") - print(f" Format: {_format_trace_format(result)}") - print(f" Records: {result['record_count']}") + click.echo("\u2705 Valid trace file") + click.echo(f" Format: {_format_trace_format(result)}") + click.echo(f" Records: {result['record_count']}") if result.get("message_type"): - print(f" Message type: {result['message_type']}") + click.echo(f" Message type: {result['message_type']}") if result.get("file_size"): - print(f" File size: {_format_size(result['file_size'])}") + click.echo(f" File size: {_format_size(result['file_size'])}") if verbose and result.get("compression") == "zstd": - print(" Compression: zstd") + click.echo(" Compression: zstd") else: - print("\u274c Validation failed") + click.echo("\u274c Validation failed") for error in result.get("errors", []): - print(f" {error}") - - -def validate_command(args: argparse.Namespace) -> int: - """Execute the validate subcommand.""" - file_path: Path = args.file - - # Check file exists - if not file_path.exists(): - if not args.quiet: - print(f"Error: File not found: {file_path}", file=sys.stderr) - return 2 + click.echo(f" {error}") + + +@click.command(name="validate") +@click.argument("file", type=click.Path(exists=True, path_type=Path)) +@click.option( + "--format", + "-f", + "file_format", + type=click.Choice(["json", "text", "auto"]), + default="auto", + help="File format. Default: auto-detect from extension.", +) +@click.option( + "--quiet", + "-q", + is_flag=True, + help="Quiet mode. Only return exit code.", +) +@click.option( + "--json", + "json_output", + is_flag=True, + help="Output results in JSON format.", +) +@click.option( + "--verbose", + "-v", + is_flag=True, + help="Verbose output with additional details.", +) +def validate_command( + file: Path, + file_format: str, + quiet: bool, + json_output: bool, + verbose: bool, +) -> None: + """Validate a CUTracer trace file. + + Checks syntax and schema compliance for NDJSON, Zstd-compressed, + and text format trace files. + + FILE is the path to the trace file to validate. + """ + file_path = file # Detect format - file_format = args.format if file_format == "auto": file_format = _detect_format(file_path) if file_format == "unknown": - if not args.quiet: - print( + if not quiet: + click.echo( f"Error: Cannot auto-detect format for {file_path}. " "Use --format to specify.", - file=sys.stderr, + err=True, ) - return 2 + sys.exit(2) # Run validation if file_format == "json": @@ -132,17 +132,17 @@ def validate_command(args: argparse.Namespace) -> int: result = validate_text_trace(file_path) # Handle quiet mode - if args.quiet: - return 0 if result["valid"] else 1 + if quiet: + sys.exit(0 if result["valid"] else 1) # Handle JSON output - if args.json_output: + if json_output: # Convert Path objects to strings for JSON serialization output = {k: str(v) if isinstance(v, Path) else v for k, v in result.items()} - print(json.dumps(output, indent=2)) - return 0 if result["valid"] else 1 + click.echo(json.dumps(output, indent=2)) + sys.exit(0 if result["valid"] else 1) # Human-readable output - _print_validation_result(result, args.verbose) + _print_validation_result(result, verbose) - return 0 if result["valid"] else 1 + sys.exit(0 if result["valid"] else 1) diff --git a/python/pyproject.toml b/python/pyproject.toml index 82653d4..8156873 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -16,6 +16,7 @@ readme = "README.md" license = "MIT" dependencies = [ + "click>=8.0.0", "jsonschema>=4.0.0", "zstandard>=0.20.0", "importlib_resources>=5.0.0; python_version < '3.11'", diff --git a/python/tests/test_cli.py b/python/tests/test_cli.py index f8459d7..53d0b35 100644 --- a/python/tests/test_cli.py +++ b/python/tests/test_cli.py @@ -6,8 +6,8 @@ import sys import unittest from pathlib import Path -from unittest.mock import patch +from click.testing import CliRunner from cutracer.cli import main from cutracer.validation.cli import _detect_format, _format_size @@ -48,159 +48,137 @@ class ValidateCommandTest(unittest.TestCase): def setUp(self): self.test_dir = Path(__file__).parent / "example_inputs" + self.runner = CliRunner() def test_validate_valid_json(self): """Test validating a valid NDJSON file.""" - with patch( - "sys.argv", - ["cutraceross", "validate", str(self.test_dir / "reg_trace_sample.ndjson")], - ): - exit_code = main() - self.assertEqual(exit_code, 0) + result = self.runner.invoke( + main, ["validate", str(self.test_dir / "reg_trace_sample.ndjson")] + ) + self.assertEqual(result.exit_code, 0) def test_validate_valid_json_zst(self): """Test validating a valid Zstd-compressed NDJSON file.""" - with patch( - "sys.argv", - [ - "cutraceross", - "validate", - str(self.test_dir / "reg_trace_sample.ndjson.zst"), - ], - ): - exit_code = main() - self.assertEqual(exit_code, 0) + result = self.runner.invoke( + main, ["validate", str(self.test_dir / "reg_trace_sample.ndjson.zst")] + ) + self.assertEqual(result.exit_code, 0) def test_validate_valid_text(self): """Test validating a valid text log file.""" - with patch( - "sys.argv", - ["cutraceross", "validate", str(self.test_dir / "reg_trace_sample.log")], - ): - exit_code = main() - self.assertEqual(exit_code, 0) + result = self.runner.invoke( + main, ["validate", str(self.test_dir / "reg_trace_sample.log")] + ) + self.assertEqual(result.exit_code, 0) def test_validate_invalid_syntax(self): """Test validating a file with invalid JSON syntax.""" - with patch( - "sys.argv", - ["cutraceross", "validate", str(self.test_dir / "invalid_syntax.ndjson")], - ): - exit_code = main() - self.assertEqual(exit_code, 1) + result = self.runner.invoke( + main, ["validate", str(self.test_dir / "invalid_syntax.ndjson")] + ) + self.assertEqual(result.exit_code, 1) def test_validate_invalid_schema(self): """Test validating a file with schema errors.""" - with patch( - "sys.argv", - ["cutraceross", "validate", str(self.test_dir / "invalid_schema.ndjson")], - ): - exit_code = main() - self.assertEqual(exit_code, 1) + result = self.runner.invoke( + main, ["validate", str(self.test_dir / "invalid_schema.ndjson")] + ) + self.assertEqual(result.exit_code, 1) def test_validate_quiet_mode(self): """Test quiet mode returns only exit code.""" - with patch( - "sys.argv", - [ - "cutraceross", - "validate", - "--quiet", - str(self.test_dir / "reg_trace_sample.ndjson"), - ], - ): - exit_code = main() - self.assertEqual(exit_code, 0) + result = self.runner.invoke( + main, + ["validate", "--quiet", str(self.test_dir / "reg_trace_sample.ndjson")], + ) + self.assertEqual(result.exit_code, 0) + # Quiet mode should produce no output + self.assertEqual(result.output.strip(), "") def test_validate_json_output(self): """Test JSON output format.""" - with patch( - "sys.argv", - [ - "cutraceross", - "validate", - "--json", - str(self.test_dir / "reg_trace_sample.ndjson"), - ], - ): - exit_code = main() - self.assertEqual(exit_code, 0) + result = self.runner.invoke( + main, + ["validate", "--json", str(self.test_dir / "reg_trace_sample.ndjson")], + ) + self.assertEqual(result.exit_code, 0) + # Should contain JSON output + self.assertIn('"valid"', result.output) def test_validate_file_not_found(self): """Test error handling for non-existent file.""" - with patch("sys.argv", ["cutraceross", "validate", "/nonexistent/file.ndjson"]): - exit_code = main() - self.assertEqual(exit_code, 2) + result = self.runner.invoke(main, ["validate", "/nonexistent/file.ndjson"]) + self.assertEqual(result.exit_code, 2) def test_validate_unknown_format(self): """Test error handling for unknown format.""" - with patch( - "sys.argv", - [ - "cutraceross", - "validate", - str(self.test_dir / "reg_trace_sample.ndjson").replace( - ".ndjson", ".unknown" - ), - ], - ): - exit_code = main() - self.assertEqual(exit_code, 2) + # Create a temporary file with unknown extension + unknown_file = self.test_dir / "reg_trace_sample.unknown" + # Copy content from existing file for the test + if not unknown_file.exists(): + import shutil + + shutil.copy(self.test_dir / "reg_trace_sample.ndjson", unknown_file) + + try: + result = self.runner.invoke(main, ["validate", str(unknown_file)]) + self.assertEqual(result.exit_code, 2) + self.assertIn("Cannot auto-detect format", result.output) + finally: + # Cleanup + if unknown_file.exists(): + unknown_file.unlink() def test_validate_explicit_format_json(self): """Test explicit --format json option.""" - with patch( - "sys.argv", + result = self.runner.invoke( + main, [ - "cutraceross", "validate", "--format", "json", str(self.test_dir / "reg_trace_sample.ndjson"), ], - ): - exit_code = main() - self.assertEqual(exit_code, 0) + ) + self.assertEqual(result.exit_code, 0) def test_validate_explicit_format_text(self): """Test explicit --format text option.""" - with patch( - "sys.argv", + result = self.runner.invoke( + main, [ - "cutraceross", "validate", "--format", "text", str(self.test_dir / "reg_trace_sample.log"), ], - ): - exit_code = main() - self.assertEqual(exit_code, 0) + ) + self.assertEqual(result.exit_code, 0) class MainEntryPointTest(unittest.TestCase): """Tests for main entry point.""" + def setUp(self): + self.runner = CliRunner() + def test_version_flag(self): """Test --version flag.""" - with patch("sys.argv", ["cutraceross", "--version"]): - with self.assertRaises(SystemExit) as cm: - main() - self.assertEqual(cm.exception.code, 0) + result = self.runner.invoke(main, ["--version"]) + self.assertEqual(result.exit_code, 0) + self.assertIn("cutraceross", result.output) def test_help_flag(self): """Test --help flag.""" - with patch("sys.argv", ["cutraceross", "--help"]): - with self.assertRaises(SystemExit) as cm: - main() - self.assertEqual(cm.exception.code, 0) + result = self.runner.invoke(main, ["--help"]) + self.assertEqual(result.exit_code, 0) + self.assertIn("validate", result.output) def test_no_command(self): """Test error when no command is provided.""" - with patch("sys.argv", ["cutraceross"]): - with self.assertRaises(SystemExit) as cm: - main() - self.assertNotEqual(cm.exception.code, 0) + result = self.runner.invoke(main, []) + # Click group with required subcommand returns exit code 2 when no command provided + self.assertEqual(result.exit_code, 2) class ModuleEntryPointTest(unittest.TestCase): diff --git a/scripts/parse_instr_hist_trace.py b/scripts/parse_instr_hist_trace.py index 3547a3a..1e30e9c 100644 --- a/scripts/parse_instr_hist_trace.py +++ b/scripts/parse_instr_hist_trace.py @@ -1,11 +1,12 @@ #!/usr/bin/env python3 # Copyright (c) Meta Platforms, Inc. and affiliates. -import argparse + import json import os import re import sys +import click import pandas as pd @@ -362,60 +363,90 @@ def merge_traces( print(f"Successfully merged data and saved to {output_path}") -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Parse and merge trace files from Chrome's tracer and CUTRICER.", - formatter_class=argparse.RawTextHelpFormatter, - ) +@click.command() +@click.option( + "--chrome-trace", + "chrome_trace_input", + type=click.Path(exists=True), + help="Path to the Chrome trace JSON file.", +) +@click.option( + "--cutracer-trace", + "cutracer_trace_input", + type=click.Path(exists=True), + help="Path to the CUTRICER histogram CSV file.", +) +@click.option( + "--cutracer-log", + "cutracer_log_input", + type=click.Path(exists=True), + help="Path to the CUTRICER log file to enable merge mode.", +) +@click.option( + "--kernel-hash", + "kernel_hash_hex", + help="Optional kernel hash (e.g., 0x7fa21c3) to select a specific launch from the log.", +) +@click.option( + "--output", + required=True, + type=click.Path(), + help="Path for the output CSV file.", +) +def main( + chrome_trace_input, + cutracer_trace_input, + cutracer_log_input, + kernel_hash_hex, + output, +): + """Parse and merge trace files from Chrome's tracer and CUTRICER. - parser.add_argument( - "--chrome-trace", - dest="chrome_trace_input", - help="Path to the Chrome trace JSON file.", - ) - parser.add_argument( - "--cutracer-trace", - dest="cutracer_trace_input", - help="Path to the CUTRICER histogram CSV file.", - ) - parser.add_argument( - "--cutracer-log", - dest="cutracer_log_input", - help="Path to the CUTRICER log file to enable merge mode.", - ) - parser.add_argument( - "--kernel-hash", - dest="kernel_hash_hex", - help="Optional kernel hash (e.g., 0x7fa21c3) to select a specific launch from the log.", - ) - parser.add_argument("--output", required=True, help="Path for the output CSV file.") + Supports three modes: - args = parser.parse_args() + \b + 1. Merge mode (requires --cutracer-log, --chrome-trace, and --cutracer-trace): + Merges Chrome trace, CUTRICER histogram, and log data. - # --- Main Logic --- - if args.cutracer_log_input: + \b + 2. Chrome trace only (--chrome-trace): + Parses a Chrome trace JSON file to CSV. + + \b + 3. CUTRICER histogram only (--cutracer-trace): + Parses a CUTRICER histogram CSV file. + """ + if cutracer_log_input: # Merge mode - if not all([args.chrome_trace_input, args.cutracer_trace_input]): - parser.error("--cutracer-log requires --chrome-trace and --cutracer-trace.") + if not all([chrome_trace_input, cutracer_trace_input]): + raise click.UsageError( + "--cutracer-log requires --chrome-trace and --cutracer-trace." + ) merge_traces( - args.chrome_trace_input, - args.cutracer_trace_input, - args.cutracer_log_input, - args.output, - args.kernel_hash_hex, + chrome_trace_input, + cutracer_trace_input, + cutracer_log_input, + output, + kernel_hash_hex, ) - elif args.chrome_trace_input: + elif chrome_trace_input: # Standalone Chrome trace parsing - df = get_chrome_trace_df(args.chrome_trace_input) + df = get_chrome_trace_df(chrome_trace_input) if df is not None: - df.to_csv(args.output, index=False) - print(f"Successfully parsed Chrome trace and saved to {args.output}") - elif args.cutracer_trace_input: + df.to_csv(output, index=False) + print(f"Successfully parsed Chrome trace and saved to {output}") + elif cutracer_trace_input: # Standalone CUTRICER hist parsing - df = get_cutracer_hist_df(args.cutracer_trace_input) + df = get_cutracer_hist_df(cutracer_trace_input) if df is not None: - df.to_csv(args.output, index=False) - print(f"Successfully parsed CUTRICER histogram and saved to {args.output}") + df.to_csv(output, index=False) + print(f"Successfully parsed CUTRICER histogram and saved to {output}") else: - parser.print_help() + raise click.UsageError( + "At least one of --chrome-trace, --cutracer-trace, or --cutracer-log is required." + ) + + +if __name__ == "__main__": + main()