140 changes: 75 additions & 65 deletions benchmarks/microbenchmarks/benchmark_inference.py
@@ -15,6 +15,9 @@

import torch

from benchmarks.microbenchmarks.profiler import (
generate_model_profile,
)
from benchmarks.microbenchmarks.utils import (
BenchmarkConfig,
BenchmarkResult,
@@ -29,70 +32,77 @@

def run(config: BenchmarkConfig) -> BenchmarkResult:
"""Run inference benchmarks"""
clean_caches() # Clean caches

# Create output directory if it doesn't exist
Path(config.output_dir).mkdir(parents=True, exist_ok=True)

base_model, input_data = create_model_and_input(
config.model_type,
config.m,
config.k,
config.n,
high_precision_dtype=config.high_precision_dtype,
device=config.device,
)

# Use quantize_ to apply each quantization function to the model
m_copy = deepcopy(base_model).eval().to(config.device)
ao_base_config = string_to_config(
config.quantization,
config.sparsity,
high_precision_dtype=config.high_precision_dtype,
)

# Check if sparsity is requested and if the device is CUDA (sparsity operations require CUDA)
is_cuda = config.device == "cuda" and torch.cuda.is_available()

if config.sparsity is not None and (
config.quantization is None or "baseline" in config.quantization
):
if is_cuda:
print(f"Applying {config.sparsity} sparsity to model")
sparsify_(m_copy, ao_base_config)
try:
clean_caches() # Clean caches

# Create output directory if it doesn't exist
Path(config.output_dir).mkdir(parents=True, exist_ok=True)

base_model, input_data = create_model_and_input(
config.model_type,
config.m,
config.k,
config.n,
high_precision_dtype=config.high_precision_dtype,
device=config.device,
)

# Use quantize_ to apply each quantization function to the model
m_copy = deepcopy(base_model).eval().to(config.device)
ao_base_config = string_to_config(
config.quantization,
config.sparsity,
high_precision_dtype=config.high_precision_dtype,
)

# Check if sparsity is requested and if the device is CUDA (sparsity operations require CUDA)
is_cuda = config.device == "cuda" and torch.cuda.is_available()

if config.sparsity is not None and (
config.quantization is None or "baseline" in config.quantization
):
if is_cuda:
print(f"Applying {config.sparsity} sparsity to model")
sparsify_(m_copy, ao_base_config)
else:
print(
f"Warning: Skipping {config.sparsity} sparsity as it requires CUDA, but device is {config.device}"
)
elif config.sparsity is None and (
config.quantization is None or "baseline" in config.quantization
):
pass # No quantization or sparsity specified, do nothing
else:
print(
f"Warning: Skipping {config.sparsity} sparsity as it requires CUDA, but device is {config.device}"
print("Quantizing model....")
quantize_(m_copy, ao_base_config)

if config.use_torch_compile:
print("Compiling model....")
m_copy = torch.compile(
m_copy, mode=config.torch_compile_mode, fullgraph=True
)
elif config.sparsity is None and (
config.quantization is None or "baseline" in config.quantization
):
pass # No quantization or sparsity specified, do nothing
else:
print("Quantizing model....")
quantize_(m_copy, ao_base_config)

if config.use_torch_compile:
print("Compiling model....")
m_copy = torch.compile(m_copy, mode=config.torch_compile_mode, fullgraph=True)

# Run benchmarks
result = BenchmarkResult(config=config)

# Benchmark time to run an inference call for quantized model
result.model_inference_time_in_ms = model_inference_time_in_ms(
model=m_copy, input_data=input_data
)

# TODO: Benchmark time using profiler
# Profile dtype model evaluation
# prof_dtype = benchmark_model_op_with_profiler_in_microseconds(m_copy, input_data, quantized_dtype)
# prof_dtype.export_chrome_trace(f"{quantization}_model_{input_data[0].size()[0]}.json") # Save profiling details

# TODO: Benchmark gemm time using cuda graph
# gemm_time = benchmark_torch_function_in_microseconds(gemm_op, *args, **kwargs)

# TODO: Benchmark op with cuda graph
# time = benchmark_op_with_cuda_graph(op, args)

return result

# Run benchmarks
result = BenchmarkResult(config=config)
# Store result in model for memory profiling
m_copy._benchmark_result = result

# Benchmark time to run an inference call for quantized model
result.model_inference_time_in_ms = model_inference_time_in_ms(
model=m_copy, input_data=input_data
)

# Run profiler if enabled
if config.enable_profiler:
print("Running profiler...")
try:
result.profiler_json_path = generate_model_profile(
m_copy, input_data, config.profiler_file_name
)
except Exception as e:
print(f"Error running profiler for {config.name} with error: {e}")

return result
except Exception as e:
print(f"Error in benchmark run: {config.name} with error: {e}")
return None
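
For orientation, a small hypothetical caller-side sketch of the contract this change introduces: run() now returns None on failure, and when enable_profiler is set the returned result carries the path of the saved trace. The field names follow the code above; how a BenchmarkConfig instance is actually constructed is outside this diff, so `config` here is assumed to exist already.

# Hypothetical caller; assumes `config` is an existing BenchmarkConfig with
# enable_profiler=True and a valid profiler_file_name.
from benchmarks.microbenchmarks.benchmark_inference import run

result = run(config)
if result is None:
    print("Benchmark failed; see the error printed above")
elif getattr(result, "profiler_json_path", None):
    print(f"Trace saved to {result.profiler_json_path}; open it in chrome://tracing")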
19 changes: 11 additions & 8 deletions benchmarks/microbenchmarks/benchmark_runner.py
@@ -164,16 +164,19 @@ def run_inference_benchmarks_from_config(configs: List[BenchmarkConfig]) -> None
f"Running: {config.name} for Quantization: {config.quantization} and Sparsity: {config.sparsity}"
)
result = run_inference(config) # Pass the config object directly
results.append(result)
except Exception:
print(f"Error running benchmark {config.name}")
if result is not None: # Only add successful results
results.append(result)
except Exception as e:
print(f"Error running benchmark {config.name} with error: {e}")
continue

# Add results to csv
generate_results_csv(results, configs[0].output_dir)

# Print results
print_results(results)
# Add results to csv if there are any
if results:
generate_results_csv(results, configs[0].output_dir)
# Print results
print_results(results)
else:
print("No benchmark results were collected. All benchmarks failed.")

# TODO: Process results: Speedups:
# 1. For different shapes for same model and quantization
60 changes: 60 additions & 0 deletions benchmarks/microbenchmarks/profiler.py
@@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import os

import torch
from torch.profiler import ProfilerActivity


def generate_model_profile(model, input_data, profile_file_path):
"""Function to benchmark model evaluation with profiling.

Args:
model: The model to profile
input_data: Input data for the model
profile_file_path: Path to save the profiler output

Returns:
profile_file_path
"""
# Create parent directory if it doesn't exist
os.makedirs(os.path.dirname(profile_file_path), exist_ok=True)

# Set up profiler activities based on device
activities = [ProfilerActivity.CPU]
device = next(model.parameters()).device
if device.type == "cuda" and torch.cuda.is_available():
activities.append(ProfilerActivity.CUDA)

# Warm up
with torch.no_grad():
for _ in range(3):
_ = model(input_data)
if device.type == "cuda":
torch.cuda.synchronize()

# Run profiler, recording shapes, stack traces, memory usage, and FLOPs
with torch.profiler.profile(
activities=activities,
record_shapes=True,
with_stack=True,
profile_memory=True,
with_flops=True, # Experimental; might be unreliable for some layers
) as prof:
with torch.no_grad():
for _ in range(3):
_ = model(input_data)
if device.type == "cuda":
torch.cuda.synchronize()

# Save profiling details
prof.export_chrome_trace(profile_file_path)
print(f"Chrome trace saved at: {profile_file_path}")
print("You can now visualize it using:")
print("1. Chrome Trace Viewer: chrome://tracing")
print("2. Perfetto UI: https://ui.perfetto.dev")

return profile_file_path
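
For reference, a minimal standalone sketch of calling generate_model_profile outside the benchmark runner. The import path mirrors the one added to benchmark_inference.py above; the toy linear model, input shape, and output path are assumptions for illustration only.

import torch

from benchmarks.microbenchmarks.profiler import generate_model_profile

# Toy model and input (hypothetical; any eval-mode nn.Module works).
model = torch.nn.Linear(1024, 1024).eval()
input_data = torch.randn(16, 1024)

# The helper creates the parent directory, runs warm-up plus profiled
# iterations, and writes a Chrome trace JSON at the given path.
trace_path = generate_model_profile(
    model,
    input_data,
    "benchmarks/microbenchmarks/results/linear_profile.json",
)
print(trace_path)  # open in chrome://tracing or https://ui.perfetto.dev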
35 changes: 8 additions & 27 deletions benchmarks/microbenchmarks/test/benchmark_config.yml
@@ -2,46 +2,27 @@
benchmark_mode: "inference"
quantization_config_recipe_names:
# Will run a baseline inference for model by default, without quantization for comparison
- "int4wo-32"
- "marlin"
sparsity_config_recipe_names:
- "int8wo"
- "int8dq"
- "float8dq"
- "float8wo"
# sparsity_config_recipe_names:
# Will run a baseline inference for model by default, without sparsity for comparison
- "semi-sparse"
- "block"
# - "semi-sparse"
# - "block"
output_dir: "benchmarks/microbenchmarks/results"
model_params:
- name: "small_bf16_linear"
matrix_shapes:
- name: "custom"
shapes: [
[1024, 1024, 1024], # [m, k, n]
]
high_precision_dtype: "torch.bfloat16"
use_torch_compile: true
torch_compile_mode: "max-autotune"
device: "cuda"
model_type: "linear"

- name: "large_bf16_ln_linear"
matrix_shapes:
- name: "custom"
shapes: [
[2048, 4096, 1024],
[4096, 4096, 1024]
]
high_precision_dtype: "torch.bfloat16"
use_torch_compile: true
torch_compile_mode: "max-autotune"
device: "cuda"
model_type: "ln_linear_sigmoid"

- name: "cpu_fp32_linear"
matrix_shapes:
- name: "custom"
shapes: [
[4096, 4096, 1024]
]
high_precision_dtype: "torch.float32"
use_torch_compile: false
device: "cpu"
model_type: "linear"
enable_profiler: true # Enable profiling for this model
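
As a quick sanity check of the config shape above, a small sketch of loading it with PyYAML; how the actual runner parses this file is not shown in this diff, so this only illustrates the expected keys.

import yaml

with open("benchmarks/microbenchmarks/test/benchmark_config.yml") as f:
    cfg = yaml.safe_load(f)

print(cfg["benchmark_mode"])                      # "inference"
print(cfg["quantization_config_recipe_names"])    # e.g. ["int8wo", "int8dq", "float8dq", "float8wo"]
print(cfg["model_params"][0]["enable_profiler"])  # True for the example model above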