Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
e8f0d91
feat: add --precision fp16 to optimize, build, and export commands
DingmaomaoBJTU Jun 11, 2026
1fa70d4
refactor: integrate FP16 into quantize stage as post-processing
github-actions[bot] Jun 23, 2026
37f12a4
chore: remove spurious .data files
github-actions[bot] Jun 23, 2026
49d2d43
refactor: remove --precision from export/optimize, add fp16 to quantize
github-actions[bot] Jun 23, 2026
b57617b
feat(build): extend --precision to accept all quantization values
github-actions[bot] Jun 23, 2026
3b4e69f
fix: resolve CodeQL import warnings in fp16 module
github-actions[bot] Jun 23, 2026
75be8d3
fix: resolve rebase conflicts with main
github-actions[bot] Jun 23, 2026
6ac9d7f
feat: warn when calibration options are ignored in FP16 mode
github-actions[bot] Jun 23, 2026
4597e07
fix: skip task/model_name validation for fp16_only quant configs
github-actions[bot] Jun 23, 2026
e882dd5
fix: skip calibration validation for rtn and dynamic algorithms
github-actions[bot] Jun 23, 2026
9e99dac
feat: merge --rtn-bits into --precision (int4/w4a16 auto-selects RTN)
github-actions[bot] Jun 23, 2026
762f2d0
fix: build pipeline RTN routing and MatMulNBitsQuantizer model extrac…
github-actions[bot] Jun 23, 2026
43d27b3
fix: resolve lint warnings (raw regex strings, unused variable)
github-actions[bot] Jun 23, 2026
1183861
fix: resolve mypy type errors and remove duplicate imports
github-actions[bot] Jun 23, 2026
b99fabf
fix: address code review findings
github-actions[bot] Jun 23, 2026
5ea661b
fix: address deep code review findings
github-actions[bot] Jun 23, 2026
3f5b410
feat: support w4a32 precision (equivalent to int4) and w4a16 FP16 pos…
github-actions[bot] Jun 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 84 additions & 2 deletions src/winml/modelkit/commands/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,12 @@ def build(

# Force rebuild
winml build -c config.json -m microsoft/resnet-50 -o output/ --rebuild

# Build with INT8 quantization
winml build -m microsoft/resnet-50 -o output/ --precision int8

# Build with explicit weight/activation precision (INT8 weights, INT8 activations)
winml build -m microsoft/resnet-50 -o output/ --precision w8a8
"""
# Merge top-level -v/-q with subcommand-level flags so either position works.
verbose, quiet = cli_utils.resolve_verbosity(ctx, verbose, quiet)
Expand All @@ -540,6 +546,16 @@ def build(
if not output_dir and not use_cache:
raise click.UsageError("One of --output-dir or --use-cache is required.")

# Validate precision value early for better error messages.
if precision is not None:
from ..config.precision import _is_valid_precision

if not _is_valid_precision(precision.lower()):
raise click.UsageError(
f"Invalid precision '{precision}'. "
"Expected: auto, fp32, fp16, int8, int16, or w{{x}}a{{y}} (e.g., w8a8, w8a16)."
)

# If ep unspecified, resolve the target device and pick the highest-priority
# EP compatible with it. Avoids selecting an EP that does not match the host
# hardware -- analyzing for the wrong EP leaves black nodes that block a
Expand Down Expand Up @@ -615,6 +631,14 @@ def _patch_device(cfg: WinMLBuildConfig) -> None:
# and other calibration settings from the existing config.
cfg.quant.weight_type = resolved_quant.weight_type
cfg.quant.activation_type = resolved_quant.activation_type
cfg.quant.fp16 = resolved_quant.fp16
cfg.quant.fp16_only = resolved_quant.fp16_only
cfg.quant.algorithm = resolved_quant.algorithm
if resolved_quant.algorithm == "rtn":
cfg.quant.rtn_bits = resolved_quant.rtn_bits
cfg.quant.rtn_block_size = resolved_quant.rtn_block_size
cfg.quant.rtn_symmetric = resolved_quant.rtn_symmetric
cfg.quant.rtn_accuracy_level = resolved_quant.rtn_accuracy_level
if cfg.compile is not None and cfg.compile.ep_config is not None:
provider = cfg.compile.ep_config.provider
patched = WinMLCompileConfig.for_provider(provider, device=device)
Expand Down Expand Up @@ -660,6 +684,8 @@ def _patch_device(cfg: WinMLBuildConfig) -> None:
# on the key being present, matching the module-mode path which passes
# allow_unsupported_nodes explicitly regardless of its value.
extra_kwargs["allow_unsupported_nodes"] = allow_unsupported_nodes
if precision:
extra_kwargs["precision"] = precision

if isinstance(config_or_configs, list):
# ---- MODULE MODE: array config, one build per submodule ----
Expand Down Expand Up @@ -1133,6 +1159,59 @@ def _run_quantize_stage(
if config.quant is None:
return current_path

# ── FP16-only fast path (no calibration / QDQ) ───────────────
if config.quant.fp16_only:
with StageLive("fp16", console) as sl:
sl.set_status("Converting to FP16...")
t0 = time.monotonic()
quant_result = quantize_onnx(
model_path=current_path,
output_path=quantized_path,
config=config.quant,
use_external_data=True,
)
if not quant_result.success:
errors = ", ".join(quant_result.errors) if quant_result.errors else "Unknown"
sl.set_error(errors)
raise RuntimeError(f"FP16 conversion failed: {errors}")
current_path = quantized_path
_fp16_elapsed = time.monotonic() - t0
sl.set_done(_fp16_elapsed)
sl.detail("[dim]I/O types preserved as FP32[/dim]")
sl.artifact(str(quantized_path), _safe_size(quantized_path))
sl.blank()
stage_timings.append(("FP16", _fp16_elapsed))
return current_path

# ── RTN weight-only path (no calibration) ────────────────────
if config.quant.algorithm == "rtn":
with StageLive("quantize", console) as sl:
bits = config.quant.rtn_bits
sl.set_status(f"Quantizing (RTN {bits}-bit)...")
t0 = time.monotonic()
quant_result = quantize_onnx(
model_path=current_path,
output_path=quantized_path,
config=config.quant,
use_external_data=True,
)
if not quant_result.success:
errors = ", ".join(quant_result.errors) if quant_result.errors else "Unknown"
sl.set_error(errors)
raise RuntimeError(f"RTN quantization failed: {errors}")
current_path = quantized_path
_rtn_elapsed = time.monotonic() - t0
sl.set_done(_rtn_elapsed)
sl.kv("Algorithm:", f"[cyan]RTN[/cyan] [dim](weight-only {bits}-bit)[/dim]")
sl.kv(
"Config:",
f"block_size={config.quant.rtn_block_size}, symmetric={config.quant.rtn_symmetric}",
)
sl.artifact(str(quantized_path), _safe_size(quantized_path))
sl.blank()
stage_timings.append(("Quantize", _rtn_elapsed))
return current_path

if is_quantized_onnx(current_path):
print_stage_skip(console, "quantize", "(QDQ nodes already present)")
stage_timings.append(("Quantize", None))
Expand Down Expand Up @@ -1366,6 +1445,8 @@ def _name(base: str) -> str:

stage_timings.append(("Export", _export_elapsed))

extra_kwargs.pop("precision", None) # consumed by _patch_device earlier

# ── Optimize stage ───────────────────────────────────────────
current_path, _ = _run_optimize_stage(
config=config,
Expand All @@ -1383,7 +1464,7 @@ def _name(base: str) -> str:
# Persist config after autoconf
config_path.write_text(json.dumps(config.to_dict(), indent=2))

# ── Quantize stage ───────────────────────────────────────────
# ── Quantize stage (handles QDQ + FP16 post-processing) ──────
current_path = _run_quantize_stage(
config=config,
current_path=current_path,
Expand Down Expand Up @@ -1425,6 +1506,7 @@ def _build_onnx_pipeline(

max_iters: int = extra_kwargs.pop("hack_max_optim_iterations", 3)
allow_unsupported_nodes: bool = extra_kwargs.pop("allow_unsupported_nodes", False)
extra_kwargs.pop("precision", None) # consumed by _patch_device earlier

# ── Validate + setup ─────────────────────────────────────────
if not onnx_path.exists():
Expand Down Expand Up @@ -1478,7 +1560,7 @@ def _build_onnx_pipeline(

config_path.write_text(json.dumps(config.to_dict(), indent=2))

# ── Quantize stage ───────────────────────────────────────────
# ── Quantize stage (handles QDQ + FP16 post-processing) ──────
current_path = _run_quantize_stage(
config=config,
current_path=current_path,
Expand Down
142 changes: 136 additions & 6 deletions src/winml/modelkit/commands/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,21 @@
console = Console()


def _warn_ignored_calibration_options(ctx: click.Context, reason: str) -> None:
    """Print one warning naming the calibration-related CLI options being ignored.

    Each calibration option is checked with ``cli_utils.is_cli_provided``; if
    none is reported as provided, nothing is printed.

    Args:
        ctx: Click context of the running command, forwarded to
            ``cli_utils.is_cli_provided`` for every option checked.
        reason: Explanation appended to the warning after the option list.
    """
    # (parameter name, CLI flag) pairs, listed in the order they are reported.
    calibration_options = (
        ("samples", "--samples"),
        ("method", "--method"),
        ("weight_type", "--weight-type"),
        ("activation_type", "--activation-type"),
    )
    flags = [
        flag
        for param_name, flag in calibration_options
        if cli_utils.is_cli_provided(ctx, param_name)
    ]
    if not flags:
        return
    console.print(f"[yellow]Warning:[/yellow] {', '.join(flags)} ignored — {reason}")


@click.command()
@click.option(
"--model",
Expand All @@ -49,8 +64,10 @@
@cli_utils.output_option("Output path (default: {input}_qdq.onnx)")
@cli_utils.precision_option(
default=None,
help_text="Quantization precision: auto, int8, int16, or w{x}a{y} where "
"x,y in {8,16} (e.g., w8a8, w8a16, w16a16)",
help_text="Quantization precision: auto, fp16, int4, int8, int16, or w{x}a{y} where "
"x in {4,8,16}, y in {8,16} (e.g., w4a16, w8a8, w8a16). "
"int4/w4a16 uses RTN weight-only quantization; "
"fp16 converts all FP32 tensors to FP16 (no QDQ)",
optional_message="Overridden by explicit --weight-type/--activation-type",
)
@click.option(
Expand Down Expand Up @@ -122,11 +139,11 @@ def quantize(
quiet: bool,
config_file: Path | None,
) -> None:
r"""Quantize ONNX model by inserting QDQ nodes.
r"""Quantize ONNX model by inserting QDQ nodes, RTN weight-only, or convert to FP16.

This command applies static quantization to an ONNX model using calibration
data to determine quantization parameters. The output model contains
QuantizeLinear and DequantizeLinear nodes for quantization-aware inference.
This command applies quantization to an ONNX model. The algorithm is
auto-selected from the precision: int4/w4a16 → RTN weight-only,
int8/int16/w8a8 → static QDQ, fp16 → FP16 conversion.

\b
Examples:
Expand All @@ -136,9 +153,15 @@ def quantize(
# Use precision shorthand (same as --weight-type uint8 --activation-type uint8)
winml quantize -m model.onnx --precision int8

# RTN 4-bit weight-only quantization (no calibration data needed)
winml quantize -m model.onnx --precision int4

# Int16 quantization
winml quantize -m model.onnx --precision int16

# Convert model to FP16 (no QDQ, full-model conversion)
winml quantize -m model.onnx --precision fp16

# Custom output path and more samples
winml quantize -m model.onnx -o quantized.onnx --samples 100

Expand Down Expand Up @@ -174,6 +197,106 @@ def quantize(
# Import quantizer (late import to speed up CLI)
from ..quant import WinMLQuantizationConfig, quantize_onnx

# ── FP16 fast path ───────────────────────────────────────────
is_fp16 = precision and precision.lower() == "fp16"

if is_fp16:
_warn_ignored_calibration_options(ctx, "FP16 conversion does not use calibration data.")

# Determine output path
if output is None:
output = model.parent / f"{model.stem}_fp16.onnx"
output.parent.mkdir(parents=True, exist_ok=True)

console.print(f"[bold blue]Input:[/bold blue] {model}")
console.print(f"[bold blue]Output:[/bold blue] {output}")
console.print("[bold blue]Precision:[/bold blue] fp16")

config = WinMLQuantizationConfig(fp16=True, fp16_only=True)

try:
console.print("\n[bold]Converting to FP16...[/bold]")
result = quantize_onnx(model, output_path=output, config=config)

if result.success:
console.print("\n[bold green]Success![/bold green] Model converted to FP16")
console.print(f"[dim]Output: {result.output_path}[/dim]")
console.print(f"[dim]Total time: {result.total_time_seconds:.2f}s[/dim]")
else:
console.print("\n[bold red]FP16 conversion failed:[/bold red]")
for error in result.errors:
console.print(f" {error}")
raise click.ClickException("FP16 conversion failed")

except click.ClickException:
raise
except Exception as e:
console.print(f"\n[bold red]FP16 conversion failed:[/bold red] {e}")
logger.exception("FP16 conversion failed")
raise click.ClickException(f"FP16 conversion failed: {e}") from e

return

# ── Weight-only (RTN) path ───────────────────────────────────
from ..config.precision import (
extract_activation_bits,
extract_weight_bits,
is_weight_only_precision,
)

is_rtn = precision and is_weight_only_precision(precision.lower())

if is_rtn:
_warn_ignored_calibration_options(
ctx, "RTN weight-only quantization does not use calibration data."
)

assert precision is not None # guaranteed by is_rtn check
rtn_bits = extract_weight_bits(precision.lower())
a_bits = extract_activation_bits(precision.lower())

# Determine output path
if output is None:
output = model.parent / f"{model.stem}_int{rtn_bits}.onnx"
output.parent.mkdir(parents=True, exist_ok=True)

console.print(f"[bold blue]Input:[/bold blue] {model}")
console.print(f"[bold blue]Output:[/bold blue] {output}")
console.print(f"[bold blue]Precision:[/bold blue] {precision}")
console.print(f"[bold blue]Algorithm:[/bold blue] RTN (weight-only, {rtn_bits}-bit)")

config = WinMLQuantizationConfig(
algorithm="rtn",
rtn_bits=rtn_bits,
fp16=a_bits == 16,
)

try:
console.print(f"\n[bold]Running RTN {rtn_bits}-bit weight-only quantization...[/bold]")
result = quantize_onnx(model, output_path=output, config=config)

if result.success:
console.print(
f"\n[bold green]Success![/bold green] Model quantized (RTN {rtn_bits}-bit)"
)
console.print(f"[dim]Output: {result.output_path}[/dim]")
console.print(f"[dim]Total time: {result.total_time_seconds:.2f}s[/dim]")
else:
console.print("\n[bold red]RTN quantization failed:[/bold red]")
for error in result.errors:
console.print(f" {error}")
raise click.ClickException("RTN quantization failed")

except click.ClickException:
raise
except Exception as e:
console.print(f"\n[bold red]RTN quantization failed:[/bold red] {e}")
logger.exception("RTN quantization failed")
raise click.ClickException(f"RTN quantization failed: {e}") from e

return

# ── QDQ quantization path ────────────────────────────────────
# Resolve weight/activation types from --precision or explicit flags
resolved_weight, resolved_activation = _resolve_quant_types(
precision, weight_type, activation_type
Expand Down Expand Up @@ -253,7 +376,14 @@ def _resolve_quant_types(
Tuple of (weight_type, activation_type).
"""
from ..config import is_quantized_precision, resolve_quant_types
from ..config.precision import is_weight_only_precision

if precision and is_weight_only_precision(precision.lower()):
# Should not reach here — RTN path returns early above.
raise click.BadParameter(
f"'{precision}' is a weight-only precision (use RTN path).",
param_hint="'-p' / '--precision'",
)
if precision and is_quantized_precision(precision):
default_w, default_a = resolve_quant_types(precision)
elif precision is None or precision.lower() == "auto":
Expand Down
6 changes: 6 additions & 0 deletions src/winml/modelkit/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@
)
from .precision import (
PrecisionPolicy,
extract_activation_bits,
extract_weight_bits,
is_quantized_precision,
is_weight_only_precision,
resolve_precision,
resolve_quant_types,
)
Expand All @@ -43,10 +46,13 @@
"PrecisionPolicy",
"SubmoduleClassNotFoundError",
"WinMLBuildConfig",
"extract_activation_bits",
"extract_weight_bits",
"generate_build_config",
"generate_hf_build_config",
"generate_onnx_build_config",
"is_quantized_precision",
"is_weight_only_precision",
"merge_config",
"resolve_precision",
"resolve_quant_compile_config",
Expand Down
Loading
Loading