diff --git a/src/winml/modelkit/commands/build.py b/src/winml/modelkit/commands/build.py index c3ffc660d..a5da9ff28 100644 --- a/src/winml/modelkit/commands/build.py +++ b/src/winml/modelkit/commands/build.py @@ -529,6 +529,12 @@ def build( # Force rebuild winml build -c config.json -m microsoft/resnet-50 -o output/ --rebuild + + # Build with INT8 quantization + winml build -m microsoft/resnet-50 -o output/ --precision int8 + + # Build with mixed precision (INT8 weights, INT8 activations) + winml build -m microsoft/resnet-50 -o output/ --precision w8a8 """ # Merge top-level -v/-q with subcommand-level flags so either position works. verbose, quiet = cli_utils.resolve_verbosity(ctx, verbose, quiet) @@ -540,6 +546,16 @@ def build( if not output_dir and not use_cache: raise click.UsageError("One of --output-dir or --use-cache is required.") + # Validate precision value early for better error messages. + if precision is not None: + from ..config.precision import _is_valid_precision + + if not _is_valid_precision(precision.lower()): + raise click.UsageError( + f"Invalid precision '{precision}'. " + "Expected: auto, fp32, fp16, int4, int8, int16, or w{x}a{y} (e.g., w8a8, w8a16)." + ) + # If ep unspecified, resolve the target device and pick the highest-priority # EP compatible with it. Avoids selecting an EP that does not match the host # hardware -- analyzing for the wrong EP leaves black nodes that block a @@ -615,6 +631,15 @@ def _patch_device(cfg: WinMLBuildConfig) -> None: # and other calibration settings from the existing config. 
cfg.quant.weight_type = resolved_quant.weight_type cfg.quant.activation_type = resolved_quant.activation_type + cfg.quant.mode = resolved_quant.mode + if resolved_quant.mode == "rtn": + cfg.quant.rtn_bits = resolved_quant.rtn_bits + cfg.quant.rtn_block_size = resolved_quant.rtn_block_size + cfg.quant.rtn_symmetric = resolved_quant.rtn_symmetric + cfg.quant.rtn_accuracy_level = resolved_quant.rtn_accuracy_level + # Store the original precision string for multi-pass expansion + if precision: + cfg.precision = precision.lower() if cfg.compile is not None and cfg.compile.ep_config is not None: provider = cfg.compile.ep_config.provider patched = WinMLCompileConfig.for_provider(provider, device=device) @@ -660,6 +685,8 @@ def _patch_device(cfg: WinMLBuildConfig) -> None: # on the key being present, matching the module-mode path which passes # allow_unsupported_nodes explicitly regardless of its value. extra_kwargs["allow_unsupported_nodes"] = allow_unsupported_nodes + if precision: + extra_kwargs["precision"] = precision if isinstance(config_or_configs, list): # ---- MODULE MODE: array config, one build per submodule ---- @@ -1112,10 +1139,11 @@ def _run_quantize_stage( quantized_path: Path, stage_timings: list[tuple[str, float | None]], ) -> Path: - """Run the quantize stage inside a StageLive context (if quant is configured). + """Run the quantize stage (if quant is configured). - Handles QDQ skip detection, shows dataset/calibration/precision details, - and appends timing to stage_timings. + Delegates multi-pass expansion (e.g., w4a16 → [int4, fp16]) entirely to + ``quantize_onnx(precision=...)``. The cmd layer only handles UI display + and the QDQ skip check. Args: config: Build configuration. 
@@ -1133,35 +1161,49 @@ def _run_quantize_stage( if config.quant is None: return current_path - if is_quantized_onnx(current_path): + # QDQ skip check: if model already has QDQ nodes and we're doing static/dynamic + if config.quant.mode in ("static", "dynamic") and is_quantized_onnx(current_path): print_stage_skip(console, "quantize", "(QDQ nodes already present)") stage_timings.append(("Quantize", None)) return current_path - with StageLive("quantize", console) as sl: - wt = config.quant.weight_type - sl.set_status(f"Quantizing ({wt})...") - # Calibration info before blocking call - ds = config.quant.dataset_name or "default" - sl.kv( - "Dataset:", - f"[cyan]{ds}[/cyan] [dim]({config.quant.task or 'unknown'})[/dim]", - ) - sl.kv( - "Calibration:", - f"[cyan]{config.quant.samples}[/cyan] samples" - f" [dim]({config.quant.calibration_method})[/dim]", - ) - # Suppress tqdm/datasets progress bars during quantize - # to keep Live display clean + # Determine stage label from precision/algorithm + precision = config.precision + is_fp16_only = precision and precision.lower() == "fp16" + stage_label = "fp16" if is_fp16_only else "quantize" + stage_name = "FP16" if is_fp16_only else "Quantize" + + with StageLive(stage_label, console) as sl: + # Show status based on what we're about to do + if is_fp16_only: + sl.set_status("Converting to FP16...") + elif config.quant.mode == "rtn": + sl.set_status(f"Quantizing (RTN {config.quant.rtn_bits}-bit)...") + if precision and "a16" in precision.lower(): + sl.detail("[dim]Multi-pass: int4 → fp16[/dim]") + else: + sl.set_status(f"Quantizing ({config.quant.weight_type})...") + ds = config.quant.dataset_name or "default" + sl.kv( + "Dataset:", + f"[cyan]{ds}[/cyan] [dim]({config.quant.task or 'unknown'})[/dim]", + ) + sl.kv( + "Calibration:", + f"[cyan]{config.quant.samples}[/cyan] samples" + f" [dim]({config.quant.calibration_method})[/dim]", + ) + + # Suppress tqdm/datasets progress bars for QDQ calibration _datasets_available = False 
- try: - import datasets + if config.quant.mode in ("static", "dynamic"): + try: + import datasets - datasets.disable_progress_bars() - _datasets_available = True - except ImportError: - pass # datasets package not installed; progress bar suppression not needed + datasets.disable_progress_bars() + _datasets_available = True + except ImportError: + pass t0 = time.monotonic() try: @@ -1169,31 +1211,47 @@ def _run_quantize_stage( model_path=current_path, output_path=quantized_path, config=config.quant, + precision=precision, use_external_data=True, ) finally: if _datasets_available: + import datasets + datasets.enable_progress_bars() + if not quant_result.success: errors = ", ".join(quant_result.errors) if quant_result.errors else "Unknown" sl.set_error(errors) - raise RuntimeError(f"Quantization failed: {errors}") - current_path = quantized_path - _quant_elapsed = time.monotonic() - t0 - sl.set_done(_quant_elapsed) - sl.kv( - "Precision:", - f"[cyan]{config.quant.weight_type}/" - f"{config.quant.activation_type}[/cyan]" - f" [dim](weight/activation)[/dim]", - ) - sl.artifact( - str(quantized_path), - _safe_size(quantized_path), - ) + raise RuntimeError(f"{stage_name} failed: {errors}") + + elapsed = time.monotonic() - t0 + sl.set_done(elapsed) + + # Show algorithm-specific result details + if is_fp16_only: + sl.detail("[dim]I/O types preserved as FP32[/dim]") + elif config.quant.mode == "rtn": + sl.kv( + "Algorithm:", + f"[cyan]RTN[/cyan] [dim](weight-only {config.quant.rtn_bits}-bit)[/dim]", + ) + sl.kv( + "Config:", + f"block_size={config.quant.rtn_block_size}, symmetric={config.quant.rtn_symmetric}", + ) + else: + sl.kv( + "Precision:", + f"[cyan]{config.quant.weight_type}/{config.quant.activation_type}[/cyan]" + f" [dim](weight/activation)[/dim]", + ) + + sl.artifact(str(quantized_path), _safe_size(quantized_path)) sl.blank() - stage_timings.append(("Quantize", _quant_elapsed)) - return current_path + + stage_timings.append((stage_name, elapsed)) + return 
quantized_path def _run_compile_stage( @@ -1366,6 +1424,8 @@ def _name(base: str) -> str: stage_timings.append(("Export", _export_elapsed)) + extra_kwargs.pop("precision", None) # consumed by _patch_device earlier + # ── Optimize stage ─────────────────────────────────────────── current_path, _ = _run_optimize_stage( config=config, @@ -1383,7 +1443,7 @@ def _name(base: str) -> str: # Persist config after autoconf config_path.write_text(json.dumps(config.to_dict(), indent=2)) - # ── Quantize stage ─────────────────────────────────────────── + # ── Quantize stage (handles QDQ + FP16 post-processing) ────── current_path = _run_quantize_stage( config=config, current_path=current_path, @@ -1425,6 +1485,7 @@ def _build_onnx_pipeline( max_iters: int = extra_kwargs.pop("hack_max_optim_iterations", 3) allow_unsupported_nodes: bool = extra_kwargs.pop("allow_unsupported_nodes", False) + extra_kwargs.pop("precision", None) # consumed by _patch_device earlier # ── Validate + setup ───────────────────────────────────────── if not onnx_path.exists(): @@ -1478,7 +1539,7 @@ def _build_onnx_pipeline( config_path.write_text(json.dumps(config.to_dict(), indent=2)) - # ── Quantize stage ─────────────────────────────────────────── + # ── Quantize stage (handles QDQ + FP16 post-processing) ────── current_path = _run_quantize_stage( config=config, current_path=current_path, diff --git a/src/winml/modelkit/commands/quantize.py b/src/winml/modelkit/commands/quantize.py index 234dff7ef..58524b4e4 100644 --- a/src/winml/modelkit/commands/quantize.py +++ b/src/winml/modelkit/commands/quantize.py @@ -38,6 +38,11 @@ console = Console() +def _warn_ignored_calibration_options(ctx: click.Context, reason: str) -> None: + """Warn if the user passed calibration-related CLI options that are ignored.""" + cli_utils.warn_ignored_calibration_options(ctx, reason, console=console) + + @click.command() @click.option( "--model", @@ -49,8 +54,10 @@ @cli_utils.output_option("Output path (default: 
{input}_qdq.onnx)") @cli_utils.precision_option( default=None, - help_text="Quantization precision: auto, int8, int16, or w{x}a{y} where " - "x,y in {8,16} (e.g., w8a8, w8a16, w16a16)", + help_text="Quantization precision: auto, fp16, int4, int8, int16, or w{x}a{y} where " + "x in {4,8,16}, y in {8,16} (e.g., w4a16, w8a8, w8a16). " + "int4/w4a16 uses RTN weight-only quantization; " + "fp16 converts all FP32 tensors to FP16 (no QDQ)", optional_message="Overridden by explicit --weight-type/--activation-type", ) @click.option( @@ -122,11 +129,11 @@ def quantize( quiet: bool, config_file: Path | None, ) -> None: - r"""Quantize ONNX model by inserting QDQ nodes. + r"""Quantize ONNX model by inserting QDQ nodes, RTN weight-only, or convert to FP16. - This command applies static quantization to an ONNX model using calibration - data to determine quantization parameters. The output model contains - QuantizeLinear and DequantizeLinear nodes for quantization-aware inference. + This command applies quantization to an ONNX model. The algorithm is + auto-selected from the precision: int4/w4a16 → RTN weight-only, + int8/int16/w8a8 → static QDQ, fp16 → FP16 conversion. 
\b Examples: @@ -136,9 +143,15 @@ def quantize( # Use precision shorthand (same as --weight-type uint8 --activation-type uint8) winml quantize -m model.onnx --precision int8 + # RTN 4-bit weight-only quantization (no calibration data needed) + winml quantize -m model.onnx --precision int4 + # Int16 quantization winml quantize -m model.onnx --precision int16 + # Convert model to FP16 (no QDQ, full-model conversion) + winml quantize -m model.onnx --precision fp16 + # Custom output path and more samples winml quantize -m model.onnx -o quantized.onnx --samples 100 @@ -174,69 +187,99 @@ def quantize( # Import quantizer (late import to speed up CLI) from ..quant import WinMLQuantizationConfig, quantize_onnx - # Resolve weight/activation types from --precision or explicit flags - resolved_weight, resolved_activation = _resolve_quant_types( - precision, weight_type, activation_type - ) + # ── Build config based on precision ────────────────────────── + precision_lower = precision.lower() if precision else None - # Determine output path - if output is None: - output = model.parent / f"{model.stem}_qdq.onnx" - output.parent.mkdir(parents=True, exist_ok=True) + if precision_lower == "fp16": + # FP16 conversion + _warn_ignored_calibration_options(ctx, "FP16 conversion does not use calibration data.") + if output is None: + output = model.parent / f"{model.stem}_fp16.onnx" + config = WinMLQuantizationConfig(mode="fp16") + label = "FP16 conversion" + + elif precision_lower and _is_weight_only(precision_lower): + # RTN weight-only + from ..config.precision import extract_weight_bits - # Show info + _warn_ignored_calibration_options( + ctx, "RTN weight-only quantization does not use calibration data." 
+ ) + rtn_bits = extract_weight_bits(precision_lower) + if output is None: + output = model.parent / f"{model.stem}_int{rtn_bits}.onnx" + config = WinMLQuantizationConfig(mode="rtn", rtn_bits=rtn_bits) + label = f"RTN {rtn_bits}-bit quantization" + + else: + # QDQ calibrated quantization + resolved_weight, resolved_activation = _resolve_quant_types( + precision, weight_type, activation_type + ) + if output is None: + output = model.parent / f"{model.stem}_qdq.onnx" + config = WinMLQuantizationConfig( + samples=samples, + calibration_method=cast('Literal["minmax", "entropy", "percentile"]', method), + weight_type=cast('Literal["uint8", "int8", "uint16", "int16"]', resolved_weight), + activation_type=cast( + 'Literal["uint8", "int8", "uint16", "int16"]', resolved_activation + ), + per_channel=per_channel, + symmetric=symmetric, + task=task, + model_name=model_name, + ) + label = "Quantization" + + # Display QDQ-specific info + console.print(f"[bold blue]Weight type:[/bold blue] {resolved_weight}") + console.print(f"[bold blue]Activation type:[/bold blue] {resolved_activation}") + console.print(f"[bold blue]Samples:[/bold blue] {samples}") + console.print(f"[bold blue]Method:[/bold blue] {method}") + if config.dataset_name: + _dataset_display = config.dataset_name + elif config.task and config.task != "random": + _dataset_display = f"Default for task '{config.task}'" + else: + _dataset_display = "Random data (synthetic from ONNX I/O specs)" + console.print(f"[bold blue]Dataset:[/bold blue] {_dataset_display}") + + # ── Shared execution: print header, run, report ────────────── + output.parent.mkdir(parents=True, exist_ok=True) console.print(f"[bold blue]Input:[/bold blue] {model}") console.print(f"[bold blue]Output:[/bold blue] {output}") console.print(f"[bold blue]Precision:[/bold blue] {precision or 'auto'}") - console.print(f"[bold blue]Weight type:[/bold blue] {resolved_weight}") - console.print(f"[bold blue]Activation type:[/bold blue] {resolved_activation}") - 
console.print(f"[bold blue]Samples:[/bold blue] {samples}") - console.print(f"[bold blue]Method:[/bold blue] {method}") - - # Create config (output_path is passed separately to API). - # Click's Choice validates these strings at parse time, so cast acknowledges - # the Literal[] contract that mypy can't see through the str return type. - config = WinMLQuantizationConfig( - samples=samples, - calibration_method=cast('Literal["minmax", "entropy", "percentile"]', method), - weight_type=cast('Literal["uint8", "int8", "uint16", "int16"]', resolved_weight), - activation_type=cast('Literal["uint8", "int8", "uint16", "int16"]', resolved_activation), - per_channel=per_channel, - symmetric=symmetric, - task=task, - model_name=model_name, - ) - - # Display dataset info from config - if config.dataset_name: - _dataset_display = config.dataset_name - elif config.task and config.task != "random": - _dataset_display = f"Default for task '{config.task}'" - else: - _dataset_display = "Random data (synthetic from ONNX I/O specs)" - console.print(f"[bold blue]Dataset:[/bold blue] {_dataset_display}") try: - console.print("\n[bold]Running quantization...[/bold]") - result = quantize_onnx(model, output_path=output, config=config) + console.print(f"\n[bold]Running {label.lower()}...[/bold]") + result = quantize_onnx(model, output_path=output, config=config, precision=precision_lower) if result.success: - console.print("\n[bold green]Success![/bold green] Model quantized") + console.print(f"\n[bold green]Success![/bold green] {label} complete") console.print(f"[dim]Output: {result.output_path}[/dim]") - console.print(f"[dim]QDQ nodes inserted: {result.nodes_quantized}[/dim]") + if result.nodes_quantized: + console.print(f"[dim]QDQ nodes inserted: {result.nodes_quantized}[/dim]") console.print(f"[dim]Total time: {result.total_time_seconds:.2f}s[/dim]") else: - console.print("\n[bold red]Quantization failed:[/bold red]") + console.print(f"\n[bold red]{label} failed:[/bold red]") for error 
in result.errors: console.print(f" {error}") - raise click.ClickException("Quantization failed") + raise click.ClickException(f"{label} failed") except click.ClickException: raise except Exception as e: - console.print(f"\n[bold red]Quantization failed:[/bold red] {e}") - logger.exception("Quantization failed") - raise click.ClickException(f"Quantization failed: {e}") from e + console.print(f"\n[bold red]{label} failed:[/bold red] {e}") + logger.exception("%s failed", label) + raise click.ClickException(f"{label} failed: {e}") from e + + +def _is_weight_only(precision: str) -> bool: + """Check if precision requires weight-only (RTN) quantization.""" + from ..config.precision import is_weight_only_precision + + return is_weight_only_precision(precision) def _resolve_quant_types( @@ -254,6 +297,12 @@ def _resolve_quant_types( """ from ..config import is_quantized_precision, resolve_quant_types + if precision and _is_weight_only(precision.lower()): + # Should not reach here — RTN path returns early above. 
+ raise click.BadParameter( + f"'{precision}' is a weight-only precision (use RTN path).", + param_hint="'-p' / '--precision'", + ) if precision and is_quantized_precision(precision): default_w, default_a = resolve_quant_types(precision) elif precision is None or precision.lower() == "auto": diff --git a/src/winml/modelkit/config/__init__.py b/src/winml/modelkit/config/__init__.py index c4142c875..6f4115035 100644 --- a/src/winml/modelkit/config/__init__.py +++ b/src/winml/modelkit/config/__init__.py @@ -33,7 +33,11 @@ ) from .precision import ( PrecisionPolicy, + expand_precision, + extract_activation_bits, + extract_weight_bits, is_quantized_precision, + is_weight_only_precision, resolve_precision, resolve_quant_types, ) @@ -43,10 +47,14 @@ "PrecisionPolicy", "SubmoduleClassNotFoundError", "WinMLBuildConfig", + "expand_precision", + "extract_activation_bits", + "extract_weight_bits", "generate_build_config", "generate_hf_build_config", "generate_onnx_build_config", "is_quantized_precision", + "is_weight_only_precision", "merge_config", "resolve_precision", "resolve_quant_compile_config", diff --git a/src/winml/modelkit/config/build.py b/src/winml/modelkit/config/build.py index e00a4921c..6ca550d15 100644 --- a/src/winml/modelkit/config/build.py +++ b/src/winml/modelkit/config/build.py @@ -136,6 +136,10 @@ class WinMLBuildConfig: compile: WinMLCompileConfig | None = field(default_factory=WinMLCompileConfig) eval: WinMLEvaluationConfig | None = None auto: bool = True + # Original precision string (e.g., "w4a16", "int4", "fp16") used to derive + # the quantization pass sequence via expand_precision(). Set during config + # resolution; None means legacy config without precision info. 
+ precision: str | None = None def __post_init__(self) -> None: # Lazy import: inject into module globals so typing.get_type_hints() @@ -169,6 +173,7 @@ def from_dict(cls, config_dict: dict) -> WinMLBuildConfig: ), eval=eval_cfg, auto=config_dict.get("auto", True), + precision=config_dict.get("precision"), ) def to_dict(self) -> dict: @@ -176,6 +181,8 @@ def to_dict(self) -> dict: result: dict = {} if not self.auto: result["auto"] = False + if self.precision is not None: + result["precision"] = self.precision result.update( { "export": self.export.to_dict() if self.export is not None else None, @@ -225,9 +232,11 @@ def validate(self) -> None: # Exceptions: ONNX builds (export=None) don't need quant.task/model_name # because the ONNX model is pre-exported. Submodule builds (module_path # set) use RandomDataset which only needs the ONNX model_path. + # Algorithms that skip calibration (fp16, rtn, dynamic) also don't + # need task/model_name since they don't generate calibration datasets. if self.quant is not None: is_submodule = bool(self.loader and self.loader.module_path) - needs_quant_ids = not is_onnx_build and not is_submodule + needs_quant_ids = not is_onnx_build and not is_submodule and self.quant.mode == "static" if needs_quant_ids and not self.quant.task: errors.append("quant.task is required when quant is enabled for HF builds") if needs_quant_ids and not self.quant.model_name: @@ -327,7 +336,11 @@ def resolve_quant_compile_config( policy does not require that stage (e.g., CPU with fp32). 
""" from ..sysinfo import resolve_check_device_ep - from .precision import resolve_precision + from .precision import ( + extract_weight_bits, + is_weight_only_precision, + resolve_precision, + ) resolved_device, available_devices, resolved_eps = resolve_check_device_ep(device=device, ep=ep) logger.info( @@ -353,6 +366,15 @@ def resolve_quant_compile_config( quant_config = WinMLQuantizationConfig() quant_config.weight_type = policy.weight_type quant_config.activation_type = policy.activation_type + elif policy.precision == "fp16": + # Pure FP16: no QDQ quantization, only FP16 conversion + quant_config = WinMLQuantizationConfig(mode="fp16") + elif is_weight_only_precision(policy.precision): + # Weight-only (RTN): derive rtn_bits from precision + quant_config = WinMLQuantizationConfig( + mode="rtn", + rtn_bits=extract_weight_bits(policy.precision), + ) # Compile config compile_config = WinMLCompileConfig.for_provider(policy.compile_provider, device=policy.device) @@ -660,7 +682,11 @@ class name. Uses torchinfo to discover submodules and infer # STEP 4.5: Apply device/precision policy (affects quant + compile only) # ========================================================================= from ..sysinfo import resolve_check_device_ep - from .precision import resolve_precision + from .precision import ( + extract_weight_bits, + is_weight_only_precision, + resolve_precision, + ) # ALWAYS detect hardware — even when device="auto" — so we don't # blindly default to QNN on machines without an NPU (#412). @@ -687,10 +713,22 @@ class name. 
Uses torchinfo to discover submodules and infer parent_config.quant = WinMLQuantizationConfig() parent_config.quant.weight_type = policy.weight_type parent_config.quant.activation_type = policy.activation_type + elif policy.precision == "fp16": + # Pure FP16: no QDQ, only FP16 conversion via quantize stage + parent_config.quant = WinMLQuantizationConfig(mode="fp16") + elif policy.precision and is_weight_only_precision(policy.precision): + # RTN weight-only quantization (e.g. int4, w4a16, w4a32) + parent_config.quant = WinMLQuantizationConfig( + mode="rtn", + rtn_bits=extract_weight_bits(policy.precision), + ) else: - # CPU/GPU: precision is float (fp16/fp32) — no quantization + # CPU/GPU: precision is float (fp32) — no quantization parent_config.quant = None + # Store resolved precision for multi-pass expansion + parent_config.precision = policy.precision + # Compile config parent_config.compile = WinMLCompileConfig.for_provider( policy.compile_provider, diff --git a/src/winml/modelkit/config/precision.py b/src/winml/modelkit/config/precision.py index 2558f646b..d686ff463 100644 --- a/src/winml/modelkit/config/precision.py +++ b/src/winml/modelkit/config/precision.py @@ -71,11 +71,18 @@ _VALID_DEVICES = frozenset({"npu", "gpu", "cpu"}) # Named precision presets (non-mixed) -_NAMED_PRECISIONS = frozenset({"auto", "fp32", "fp16", "int8", "int16"}) +_NAMED_PRECISIONS = frozenset({"auto", "fp32", "fp16", "int4", "int8", "int16"}) # Regex for mixed precision: w{weight_bits}a{activation_bits} _MIXED_RE = re.compile(r"^w(\d+)a(\d+)$") +# Valid bit widths for w{x}a{y} validation. +# Weight supports 4-bit (RTN weight-only) plus 8/16-bit (QDQ). +# Activation supports 8/16-bit for QDQ, plus 32-bit (meaning "keep FP32, no +# activation quantization") which is only valid with weight-only (4-bit RTN). 
+_VALID_WEIGHT_BITS = frozenset({4, 8, 16}) +_VALID_ACTIVATION_BITS = frozenset({8, 16, 32}) + def resolve_quant_types(precision: str) -> tuple[QuantType, QuantType]: """Resolve a precision string to (weight_type, activation_type). @@ -94,6 +101,14 @@ def resolve_quant_types(precision: str) -> tuple[QuantType, QuantType]: """ p = precision.lower() + # Weight-only precisions use RTN, not QDQ — caller should use + # is_weight_only_precision() to detect these before calling here. + if is_weight_only_precision(p): + raise ValueError( + f"Precision '{precision}' is weight-only (RTN) — no QDQ quant types. " + "Use is_weight_only_precision() to detect and create RTN config instead." + ) + # Named preset if p in _WEIGHT_TYPE: w, a = _WEIGHT_TYPE[p], _ACTIVATION_TYPE[p] @@ -126,19 +141,24 @@ def resolve_quant_types(precision: str) -> tuple[QuantType, QuantType]: def is_quantized_precision(precision: str) -> bool: """Return True if precision implies quantization (not float). - Only returns True for *supported* precisions — unknown w{x}a{y} bit - widths (e.g., w4a16) return False rather than claiming to be quantized. + Includes both QDQ precisions (int8, int16, w8a8) and weight-only + precisions (int4, w4a16, w4a32) that use RTN. 
""" p = precision.lower() if p in ("fp16", "fp32", "auto"): return False + if p == "int4": + return True if p in _WEIGHT_TYPE: return _WEIGHT_TYPE[p] is not None m = _MIXED_RE.match(p) if not m: return False w_bits, a_bits = int(m.group(1)), int(m.group(2)) - return w_bits in _BITS_TO_WEIGHT_TYPE and a_bits in _BITS_TO_ACTIVATION_TYPE + if w_bits not in _VALID_WEIGHT_BITS or a_bits not in _VALID_ACTIVATION_BITS: + return False + # a_bits=32 (keep FP32) only valid with weight-only (4-bit) RTN + return not (a_bits == 32 and w_bits in _BITS_TO_WEIGHT_TYPE) def _is_valid_precision(precision: str) -> bool: @@ -149,7 +169,134 @@ def _is_valid_precision(precision: str) -> bool: if not m: return False w_bits, a_bits = int(m.group(1)), int(m.group(2)) - return w_bits in _BITS_TO_WEIGHT_TYPE and a_bits in _BITS_TO_ACTIVATION_TYPE + if w_bits not in _VALID_WEIGHT_BITS or a_bits not in _VALID_ACTIVATION_BITS: + return False + # a_bits=32 (keep FP32) only valid with weight-only (4-bit) RTN + return not (a_bits == 32 and w_bits in _BITS_TO_WEIGHT_TYPE) + + +def is_weight_only_precision(precision: str) -> bool: + """Return True if precision implies weight-only quantization (RTN). + + Weight-only precisions use the RTN (Round-To-Nearest) algorithm with + MatMulNBits ops instead of QDQ (QuantizeLinear/DequantizeLinear). + + Rules: + - ``int4`` → weight-only 4-bit RTN (equivalent to ``w4a32``) + - ``w4a32`` → weight 4-bit RTN, activation stays FP32 + - ``w4a16`` → weight 4-bit RTN + FP16 post-processing on activations + - ``w4a8`` → weight 4-bit RTN + 8-bit activation (reserved) + - All other precisions → False (use QDQ or FP16) + + Only returns True for valid precisions — ``w4a4`` returns False because + 4-bit activation is not supported. 
+ """ + p = precision.lower() + if p == "int4": + return True + m = _MIXED_RE.match(p) + if not m: + return False + w_bits, a_bits = int(m.group(1)), int(m.group(2)) + # Must be a valid precision AND have weight bits that are not QDQ-supported + return ( + w_bits not in _BITS_TO_WEIGHT_TYPE + and w_bits in _VALID_WEIGHT_BITS + and a_bits in _VALID_ACTIVATION_BITS + ) + + +def extract_weight_bits(precision: str) -> int: + """Extract weight bit-width from a precision string. + + Used to derive ``rtn_bits`` from the precision (e.g., ``int4`` → 4). + Validates the precision format before extracting. + + Args: + precision: A valid precision string (e.g., ``int4``, ``w4a16``, ``int8``). + + Returns: + Weight bit-width as integer. + + Raises: + ValueError: If precision is invalid or bit-width cannot be extracted. + """ + p = precision.lower() + preset_bits = {"int4": 4, "int8": 8, "int16": 16} + if p in preset_bits: + return preset_bits[p] + m = _MIXED_RE.match(p) + if m: + w_bits, a_bits = int(m.group(1)), int(m.group(2)) + if w_bits not in _VALID_WEIGHT_BITS or a_bits not in _VALID_ACTIVATION_BITS: + raise ValueError( + f"'{precision}' has unsupported bit-widths (weight={w_bits}, activation={a_bits})" + ) + # a_bits=32 only valid with weight-only (4-bit) — reject w8a32, w16a32 + if a_bits == 32 and w_bits in _BITS_TO_WEIGHT_TYPE: + raise ValueError( + f"'{precision}' is invalid: a32 (keep FP32) is only valid with " + "weight-only precisions (4-bit RTN)" + ) + return w_bits + raise ValueError(f"Cannot extract weight bits from '{precision}'") + + +def extract_activation_bits(precision: str) -> int: + """Extract activation bit-width from a precision string. + + For named presets: ``int4`` → 32 (activation stays FP32). + For mixed format: ``w4a16`` → 16, ``w4a32`` → 32. + + Args: + precision: A valid precision string. + + Returns: + Activation bit-width as integer (8, 16, or 32). + + Raises: + ValueError: If activation bits cannot be extracted. 
+ """ + p = precision.lower() + # Named presets: int4 means activation stays FP32 + if p == "int4": + return 32 + m = _MIXED_RE.match(p) + if m: + a_bits = int(m.group(2)) + if a_bits not in _VALID_ACTIVATION_BITS: + raise ValueError(f"'{precision}' has unsupported activation bit-width: {a_bits}") + return a_bits + raise ValueError(f"Cannot extract activation bits from '{precision}'") + + +def expand_precision(precision: str) -> list[str]: + """Expand a composite precision into an ordered list of single-operation passes. + + Only weight-only precisions with FP16 activation (w4a16) expand into + multiple passes. QDQ precisions like w8a16 are a single QDQ operation + (activation=uint16), NOT "int8 then FP16". + + Args: + precision: A precision string (e.g., "w4a16", "int4", "fp16", "int8"). + + Returns: + List of single-pass precision strings in execution order. + + Examples: + >>> expand_precision("w4a16") + ['int4', 'fp16'] + >>> expand_precision("int4") + ['int4'] + >>> expand_precision("fp16") + ['fp16'] + >>> expand_precision("w8a16") + ['w8a16'] + """ + p = precision.lower() + if p == "w4a16": + return ["int4", "fp16"] + return [p] @dataclass @@ -274,8 +421,13 @@ def resolve_precision( if compile_provider == "CPUExecutionProvider": compile_provider = None - # Resolve weight/activation types — supports named presets and w{x}a{y} - if is_quantized_precision(resolved_precision): + # Resolve weight/activation types — supports named presets and w{x}a{y}. + # Weight-only precisions (int4, w4a16) use RTN, not QDQ — they have no + # traditional weight_type/activation_type. The caller (resolve_quant_compile_config) + # inspects PrecisionPolicy.precision to create RTN config. 
+ if is_weight_only_precision(resolved_precision): + weight_type, activation_type = None, None + elif is_quantized_precision(resolved_precision): weight_type, activation_type = resolve_quant_types(resolved_precision) else: weight_type, activation_type = None, None diff --git a/src/winml/modelkit/quant/config.py b/src/winml/modelkit/quant/config.py index b9709cc0e..c42c06934 100644 --- a/src/winml/modelkit/quant/config.py +++ b/src/winml/modelkit/quant/config.py @@ -47,11 +47,20 @@ class WinMLQuantizationConfig: # Custom config config = WinMLQuantizationConfig(samples=100, weight_type="int8") result = quantize_onnx("model.onnx", config) + + # FP16 conversion (pure FP16, no quantization) + config = WinMLQuantizationConfig(mode="fp16") + result = quantize_onnx("model.onnx", config) """ - mode: Literal["qdq", "static", "dynamic"] = "qdq" + # Quantization mode + mode: Literal["static", "dynamic", "rtn", "fp16"] = "static" + # "static" — Calibrated QDQ quantization (requires calibration data) + # "dynamic" — Dynamic quantization (no calibration) [planned, not yet wired] + # "rtn" — Round-To-Nearest weight-only (no calibration, block-wise) + # "fp16" — Pure FP16 conversion only (no quantization) - # Calibration settings + # Calibration settings (static/dynamic) samples: int = 10 calibration_method: Literal["minmax", "entropy", "percentile"] = "minmax" calibration_data: CalibrationDataReader | None = None # None = random data @@ -61,11 +70,11 @@ class WinMLQuantizationConfig: model_name: str | None = None # e.g., "microsoft/resnet-50" dataset_name: str | None = None # Optional: override default dataset - # Quantization types + # Quantization types (static/dynamic) weight_type: Literal["uint8", "int8", "uint16", "int16"] = "uint8" activation_type: Literal["uint8", "int8", "uint16", "int16"] = "uint8" - # Quantization options + # Quantization options (static/dynamic) per_channel: bool = False symmetric: bool = False @@ -78,10 +87,20 @@ class WinMLQuantizationConfig: 
calibration_load_path: Path | None = None calibration_save_path: Path | None = None - # Advanced + # Advanced (static/dynamic) op_types_to_quantize: list[str] | None = None nodes_to_exclude: list[str] | None = None + # RTN-specific settings (only used when mode="rtn") + rtn_bits: int = 4 + rtn_block_size: int = 128 + rtn_symmetric: bool = True + rtn_accuracy_level: int = 0 + + # FP16 conversion settings (only used when mode="fp16") + fp16_keep_io_types: bool = True + fp16_op_block_list: list[str] | None = None + def to_dict(self) -> dict: """Convert to dictionary for serialization. @@ -116,6 +135,14 @@ def to_dict(self) -> dict: result["model_name"] = self.model_name if self.dataset_name is not None: result["dataset_name"] = self.dataset_name + if self.mode == "rtn": + result["rtn_bits"] = self.rtn_bits + result["rtn_block_size"] = self.rtn_block_size + result["rtn_symmetric"] = self.rtn_symmetric + result["rtn_accuracy_level"] = self.rtn_accuracy_level + if self.mode == "fp16": + result["fp16_keep_io_types"] = self.fp16_keep_io_types + result["fp16_op_block_list"] = self.fp16_op_block_list return result @classmethod @@ -128,8 +155,15 @@ def from_dict(cls, data: dict) -> WinMLQuantizationConfig: Returns: WinMLQuantizationConfig instance. """ + # Backward compat: prefer "algorithm" (authoritative in old configs) + # over deprecated "mode" (which defaulted to "qdq"). + # Map legacy "qdq" value to "static". 
+ raw_mode = data["algorithm"] if "algorithm" in data else data.get("mode", "static") + if raw_mode == "qdq": + raw_mode = "static" + return cls( - mode=data.get("mode", "qdq"), + mode=raw_mode, samples=data.get("samples", data.get("calibration_samples", 10)), calibration_method=data.get("calibration_method", "minmax"), task=data.get("task"), @@ -150,6 +184,12 @@ def from_dict(cls, data: dict) -> WinMLQuantizationConfig: ), op_types_to_quantize=data.get("op_types_to_quantize"), nodes_to_exclude=data.get("nodes_to_exclude"), + rtn_bits=data.get("rtn_bits", 4), + rtn_block_size=data.get("rtn_block_size", 128), + rtn_symmetric=data.get("rtn_symmetric", True), + rtn_accuracy_level=data.get("rtn_accuracy_level", 0), + fp16_keep_io_types=data.get("fp16_keep_io_types", True), + fp16_op_block_list=data.get("fp16_op_block_list"), ) diff --git a/src/winml/modelkit/quant/fp16.py b/src/winml/modelkit/quant/fp16.py new file mode 100644 index 000000000..b3a6fe3a3 --- /dev/null +++ b/src/winml/modelkit/quant/fp16.py @@ -0,0 +1,87 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""FP16 conversion utility for ONNX models. + +Provides a single entry point for FP32→FP16 model conversion, used by +the quantizer's ``mode="fp16"`` path. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from onnx import ModelProto + +logger = logging.getLogger(__name__) + + +def convert_to_fp16( + model: ModelProto, + *, + keep_io_types: bool = True, + op_block_list: list[str] | None = None, +) -> ModelProto: + """Convert an ONNX model from FP32 to FP16 precision. + + Uses onnxruntime.transformers.float16.convert_float_to_float16 internally. + No new dependencies — ORT is already a project dependency. 
+ + Note: ORT's converter mutates the model in-place and returns the same object. + + Args: + model: Input ONNX ModelProto (will be mutated in-place by ORT). + keep_io_types: If True, preserve FP32 model inputs/outputs by inserting + Cast nodes at boundaries. Recommended for CPU-safe inference. + op_block_list: Op types to keep in FP32 (e.g., ["LayerNorm", "Softmax"]). + When None, ORT uses its DEFAULT_OP_BLOCK_LIST which includes ops + known to be numerically unsafe in FP16 (e.g., TopK, CumSum, etc.). + + Returns: + The converted model (same object as input due to ORT in-place mutation). + """ + from onnx import TensorProto + from onnxruntime.transformers.float16 import convert_float_to_float16 + + # Skip if model is already FP16 (check floating-point initializer dtypes) + fp32_types = {TensorProto.FLOAT, TensorProto.DOUBLE, TensorProto.BFLOAT16} + initializers = model.graph.initializer + if initializers: + float_inits = [t for t in initializers if t.data_type in fp32_types | {TensorProto.FLOAT16}] + if float_inits and all(t.data_type == TensorProto.FLOAT16 for t in float_inits): + logger.info("Model is already FP16 — skipping conversion.") + return model + + original_nodes = len(model.graph.node) + + logger.info("Converting model to FP16...") + if keep_io_types: + logger.info(" Keeping I/O types as FP32") + if op_block_list: + logger.info(" Keeping ops in FP32: %s", op_block_list) + + converted: ModelProto = convert_float_to_float16( + model, + keep_io_types=keep_io_types, + op_block_list=op_block_list, + ) + + # ORT's converter appends Cast nodes at the end of the node list (for + # keep_io_types), which breaks topological ordering. Re-sort the graph + # using ORT's own topological sort utility. 
+ if keep_io_types: + from onnxruntime.transformers.onnx_model import OnnxModel + + OnnxModel.graph_topological_sort(converted.graph) + + converted_nodes = len(converted.graph.node) + if converted_nodes != original_nodes: + logger.info("FP16 conversion complete: %d -> %d nodes", original_nodes, converted_nodes) + else: + logger.info("FP16 conversion complete: %d nodes", converted_nodes) + + return converted diff --git a/src/winml/modelkit/quant/quantizer.py b/src/winml/modelkit/quant/quantizer.py index c562599de..b21296920 100644 --- a/src/winml/modelkit/quant/quantizer.py +++ b/src/winml/modelkit/quant/quantizer.py @@ -10,11 +10,15 @@ import os import time from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any from .config import QuantizeResult, WinMLQuantizationConfig +if TYPE_CHECKING: + from collections.abc import Callable + + logger = logging.getLogger(__name__) @@ -22,28 +26,331 @@ def quantize_onnx( model_path: str | Path, output_path: str | Path | None = None, config: WinMLQuantizationConfig | None = None, + *, + precision: str | None = None, **kwargs: Any, ) -> QuantizeResult: - """Quantize ONNX model by inserting QDQ nodes. + """Quantize ONNX model with optional multi-pass precision support. + + When ``precision`` is provided (e.g., "w4a16"), the function internally + expands it into sequential passes (e.g., ["int4", "fp16"]) and runs each + in order, managing intermediate files automatically. The caller only + needs to invoke this function once. + + When ``precision`` is None, falls back to single-pass execution based on + ``config.mode``. Args: - model_path: Path to input float32 ONNX model - output_path: Path for output quantized model (defaults to {model_stem}_qdq.onnx) - config: Quantization configuration (uses defaults if None) + model_path: Path to input ONNX model. + output_path: Path for output model (defaults to {model_stem}_qdq.onnx). + config: Quantization configuration (uses defaults if None). 
+ precision: Optional precision string (e.g., "fp16", "int4", "w4a16"). + When set, overrides config.mode routing with multi-pass + expansion logic. Returns: - QuantizeResult with path to quantized model and metrics + QuantizeResult with path to final output model and aggregated metrics. Examples: - # Quick quantize with defaults (10 samples, uint8) - result = quantize_onnx("model.onnx") + # Single-pass RTN int4 + result = quantize_onnx("model.onnx", precision="int4", config=rtn_config) + + # Multi-pass w4a16 (int4 + fp16) + result = quantize_onnx("model.onnx", precision="w4a16", config=rtn_config) - # Quantize with explicit output path - result = quantize_onnx("model.onnx", "model_quantized.onnx") + # Single-pass FP16 only + result = quantize_onnx("model.onnx", precision="fp16") - # Quantize with custom config + # Legacy: use config.mode directly result = quantize_onnx("model.onnx", config=WinMLQuantizationConfig(samples=100)) """ + from ..config.precision import expand_precision + + model_path = Path(model_path) + config = config or WinMLQuantizationConfig() + + if output_path is not None: + output_path = Path(output_path) + else: + output_path = model_path.parent / f"{model_path.stem}_qdq.onnx" + + # If precision is provided, expand and run multi-pass + if precision is not None: + passes = expand_precision(precision) + return _run_multi_pass( + model_path=model_path, + output_path=output_path, + config=config, + passes=passes, + **kwargs, + ) + + # Single-pass: delegate to internal implementation + return _quantize_single_pass( + model_path=model_path, + output_path=output_path, + config=config, + **kwargs, + ) + + +def _run_multi_pass( + *, + model_path: Path, + output_path: Path, + config: WinMLQuantizationConfig, + passes: list[str], + **kwargs: Any, +) -> QuantizeResult: + """Run a sequence of quantization passes, managing intermediate files. + + Each pass produces a QuantizeResult; the final result aggregates timing + from all passes. 
+ """ + current_path = model_path + total_time = 0.0 + all_warnings: list[str] = [] + intermediates: list[Path] = [] + + for step_idx, step_prec in enumerate(passes): + # Determine output for this step + if step_idx == len(passes) - 1: + step_output = output_path + else: + step_output = output_path.parent / ( + f"{output_path.stem}_pass{step_idx}{output_path.suffix}" + ) + intermediates.append(step_output) + + # Build config for this step + step_config = _make_pass_config(step_prec, config) + + result = _quantize_single_pass( + model_path=current_path, + output_path=step_output, + config=step_config, + **kwargs, + ) + + if not result.success: + # Clean up intermediates on failure + _cleanup_intermediates(intermediates) + return result + + total_time += result.total_time_seconds + all_warnings.extend(result.warnings) + current_path = step_output + + # Clean up intermediate files + _cleanup_intermediates(intermediates) + + return QuantizeResult( + success=True, + output_path=output_path, + total_time_seconds=total_time, + nodes_quantized=result.nodes_quantized, + errors=[], + warnings=all_warnings, + ) + + +def _make_pass_config( + step_prec: str, base_config: WinMLQuantizationConfig +) -> WinMLQuantizationConfig: + """Build a config for a single pass based on precision string.""" + if step_prec == "fp16": + return WinMLQuantizationConfig( + mode="fp16", + fp16_keep_io_types=base_config.fp16_keep_io_types, + fp16_op_block_list=base_config.fp16_op_block_list, + ) + # int4, w4a32, or any RTN/QDQ pass — use base config as-is + return base_config + + +def _cleanup_intermediates(intermediates: list[Path]) -> None: + """Remove intermediate pass files and their external data sidecars.""" + for path in intermediates: + if path.exists(): + path.unlink() + ext_data = path.parent / f"{path.name}.data" + if ext_data.exists(): + ext_data.unlink() + + +def _quantize_single_pass( + *, + model_path: Path, + output_path: Path, + config: WinMLQuantizationConfig, + **kwargs: Any, +) -> 
QuantizeResult: + """Run a single quantization pass (FP16, RTN, or QDQ). + + This is the internal workhorse — callers should use ``quantize_onnx()`` + which handles multi-pass expansion and path resolution. + """ + use_external_data: bool = kwargs.pop("use_external_data", True) + + start_time = time.perf_counter() + + # Validate input + if not model_path.exists(): + return QuantizeResult( + success=False, + output_path=None, + errors=[f"Model not found: {model_path}"], + ) + + errors: list[str] = [] + warnings: list[str] = [] + + try: + # Dispatch to the appropriate single-mode handler + _mode_handlers: dict[str, Callable[..., QuantizeResult]] = { + "fp16": _quantize_fp16, + "rtn": _quantize_rtn, + } + handler = _mode_handlers.get(config.mode, _quantize_qdq) + return handler( + model_path=model_path, + output_path=output_path, + config=config, + start_time=start_time, + use_external_data=use_external_data, + errors=errors, + warnings=warnings, + ) + + except Exception: + total_time = time.perf_counter() - start_time + logger.exception("Quantization failed") + + import traceback + + return QuantizeResult( + success=False, + output_path=None, + total_time_seconds=total_time, + errors=[traceback.format_exc()], + warnings=warnings, + ) + + +def _quantize_fp16( + *, + model_path: Path, + output_path: Path, + config: WinMLQuantizationConfig, + start_time: float, + use_external_data: bool, + errors: list[str], + warnings: list[str], +) -> QuantizeResult: + """Run FP16 conversion (no quantization).""" + from ..onnx import load_onnx, save_onnx + from .fp16 import convert_to_fp16 + + if config.calibration_data is not None: + logger.warning( + "calibration_data is set but mode='fp16' — calibration data will be ignored." 
+ ) + + logger.info("Running FP16-only conversion (no quantization)...") + model = load_onnx(model_path, validate=False) + model = convert_to_fp16( + model, + keep_io_types=config.fp16_keep_io_types, + op_block_list=config.fp16_op_block_list, + ) + output_path.parent.mkdir(parents=True, exist_ok=True) + save_onnx(model, output_path, use_external_data=use_external_data) + + total_time = time.perf_counter() - start_time + logger.info( + "FP16 conversion complete: %s -> %s (%.2fs)", + model_path.name, + output_path.name, + total_time, + ) + return QuantizeResult( + success=True, + output_path=output_path, + total_time_seconds=total_time, + errors=errors, + warnings=warnings, + ) + + +def _quantize_rtn( + *, + model_path: Path, + output_path: Path, + config: WinMLQuantizationConfig, + start_time: float, + use_external_data: bool, + errors: list[str], + warnings: list[str], +) -> QuantizeResult: + """Run RTN weight-only quantization.""" + from onnxruntime.quantization.matmul_nbits_quantizer import MatMulNBitsQuantizer + + from ..onnx import save_onnx + + if config.calibration_data is not None: + logger.warning("calibration_data is set but mode='rtn' — calibration data will be ignored.") + + logger.info( + "Running RTN %d-bit weight-only quantization (block_size=%d, symmetric=%s)...", + config.rtn_bits, + config.rtn_block_size, + config.rtn_symmetric, + ) + + accuracy_level = config.rtn_accuracy_level if config.rtn_accuracy_level != 0 else None + + quantizer = MatMulNBitsQuantizer( + model=str(model_path), + bits=config.rtn_bits, + block_size=config.rtn_block_size, + is_symmetric=config.rtn_symmetric, + accuracy_level=accuracy_level, + nodes_to_exclude=config.nodes_to_exclude, + ) + quantizer.process() + + output_path.parent.mkdir(parents=True, exist_ok=True) + quantized_model = quantizer.model.model + + save_onnx(quantized_model, output_path, use_external_data=use_external_data) + + total_time = time.perf_counter() - start_time + logger.info( + "RTN quantization 
complete: %s -> %s (%.2fs)", + model_path.name, + output_path.name, + total_time, + ) + return QuantizeResult( + success=True, + output_path=output_path, + total_time_seconds=total_time, + errors=errors, + warnings=warnings, + ) + + +def _quantize_qdq( + *, + model_path: Path, + output_path: Path, + config: WinMLQuantizationConfig, + start_time: float, + use_external_data: bool, + errors: list[str], + warnings: list[str], +) -> QuantizeResult: + """Run QDQ (static/dynamic) calibrated quantization.""" from onnxruntime.quantization import ( CalibrationMethod, QuantType, @@ -69,201 +376,130 @@ def quantize_onnx( "percentile": CalibrationMethod.Percentile, } - # TODO: Move to global env config - use_external_data: bool = kwargs.pop("use_external_data", True) - - start_time = time.perf_counter() - model_path = Path(model_path) - config = config or WinMLQuantizationConfig() - - # Validate input - if not model_path.exists(): - return QuantizeResult( - success=False, - output_path=None, - errors=[f"Model not found: {model_path}"], - ) + cal_start = time.perf_counter() - # Determine output path - if output_path is not None: - output_path = Path(output_path) + if config.calibration_data is not None: + data_reader = config.calibration_data + logger.info("Using custom calibration data") else: - output_path = model_path.parent / f"{model_path.stem}_qdq.onnx" - - errors: list[str] = [] - warnings: list[str] = [] - - try: - # Create calibration data reader - cal_start = time.perf_counter() + from ..datasets import DatasetCalibrationReader + + task = config.task or "random" + data_reader = DatasetCalibrationReader( + model_name=config.model_name or "random", + task=task, + max_samples=config.samples, + dataset_name=config.dataset_name, + model_path=model_path, + ) + logger.info( + "Using calibration: task=%s, samples=%d", + task, + config.samples, + ) - if config.calibration_data is not None: - # User provided explicit calibration data - data_reader = config.calibration_data - 
logger.info("Using custom calibration data") - else: - # Use DatasetCalibrationReader for all cases: - # - task-aware: auto-selects TextDataset, ImageDataset, etc. - # - fallback: unsupported tasks → RandomDataset (reads ONNX metadata) - # - no task: task="random" → RandomDataset directly - from ..datasets import DatasetCalibrationReader - - task = config.task or "random" - data_reader = DatasetCalibrationReader( - model_name=config.model_name or "random", - task=task, - max_samples=config.samples, - dataset_name=config.dataset_name, - model_path=model_path, - ) - logger.info( - "Using calibration: task=%s, samples=%d", - task, - config.samples, - ) + cal_time = time.perf_counter() - cal_start - cal_time = time.perf_counter() - cal_start + qdq_start = time.perf_counter() - # Apply QDQ quantization - qdq_start = time.perf_counter() + weight_type = weight_type_map[config.weight_type] + activation_type = activation_type_map[config.activation_type] + calibrate_method = calibration_method_map[config.calibration_method] - # Map config to ORT types - weight_type = weight_type_map[config.weight_type] - activation_type = activation_type_map[config.activation_type] - calibrate_method = calibration_method_map[config.calibration_method] + extra_options = { + "ActivationSymmetric": config.symmetric, + "WeightSymmetric": config.symmetric, + } - # Build extra options - extra_options = { - "ActivationSymmetric": config.symmetric, - "WeightSymmetric": config.symmetric, - } + logger.info("Generating QDQ config...") + qdq_config = get_qdq_config( + model_input=str(model_path), + calibration_data_reader=data_reader, + weight_type=weight_type, + activation_type=activation_type, + per_channel=config.per_channel, + calibrate_method=calibrate_method, + op_types_to_quantize=config.op_types_to_quantize, + nodes_to_exclude=config.nodes_to_exclude or [], + extra_options=extra_options, + ) - # Step 1: Generate QDQ config - logger.info("Generating QDQ config...") - qdq_config = get_qdq_config( 
- model_input=str(model_path), - calibration_data_reader=data_reader, - weight_type=weight_type, - activation_type=activation_type, - per_channel=config.per_channel, - calibrate_method=calibrate_method, - op_types_to_quantize=config.op_types_to_quantize, - nodes_to_exclude=config.nodes_to_exclude or [], - extra_options=extra_options, + from ..onnx import capture_metadata, load_onnx, restore_metadata, save_onnx + from .qdq_fix import fix_qdq_dtype_info + + pre_quant_model = load_onnx(model_path, load_weights=False, validate=False) + metadata_snapshot = capture_metadata(pre_quant_model) + del pre_quant_model + + if use_external_data: + qdq_config.use_external_data_format = True + logger.info("Applying quantization...") + # Temporarily change CWD to output directory so ORT's save_model_to_file() + # resolves its CWD-relative os.path.exists() check correctly. + abs_model_input = str(Path(model_path).resolve()) + abs_model_output = str(Path(output_path).resolve()) + # Remove stale output artifacts from a previous build + if output_path.exists(): + output_path.unlink() + stale_sidecar = output_path.parent / f"{output_path.name}.data" + if stale_sidecar.exists(): + stale_sidecar.unlink() + original_cwd = Path.cwd() + try: + os.chdir(output_path.parent) + quantize( + model_input=abs_model_input, + model_output=abs_model_output, + quant_config=qdq_config, ) + finally: + os.chdir(original_cwd) - # Step 2: Capture metadata before ORT quantization (it rebuilds the graph) - from ..onnx import capture_metadata, load_onnx, restore_metadata, save_onnx - from .qdq_fix import fix_qdq_dtype_info - - pre_quant_model = load_onnx(model_path, load_weights=False, validate=False) - metadata_snapshot = capture_metadata(pre_quant_model) - del pre_quant_model - - # Step 3: Apply quantization - if use_external_data: - qdq_config.use_external_data_format = True - logger.info("Applying quantization...") - # Temporarily change CWD to the output directory so that ORT's - # save_model_to_file() — 
which passes a bare filename - # (e.g. "quantized.onnx.data") to onnx.convert_model_to_external_data — - # resolves its CWD-relative os.path.exists() check against the actual - # output directory rather than the process CWD. Without this, a stale - # .onnx.data sidecar in the process CWD from a previous build triggers - # a false-positive FileExistsError even when the output dir is clean. - # Use absolute paths so the chdir does not break relative input/output - # resolution. output_path.parent is guaranteed to exist (caller mkdir). - abs_model_input = str(Path(model_path).resolve()) - abs_model_output = str(Path(output_path).resolve()) - # Remove stale output artifacts from a previous build. ORT/onnx refuse - # to overwrite an existing external-data sidecar (e.g. quantized.onnx.data), - # raising FileExistsError, so we proactively clear them here. - if output_path.exists(): - output_path.unlink() - stale_sidecar = output_path.parent / f"{output_path.name}.data" - if stale_sidecar.exists(): - stale_sidecar.unlink() - original_cwd = Path.cwd() - try: - os.chdir(output_path.parent) - quantize( - model_input=abs_model_input, - model_output=abs_model_output, - quant_config=qdq_config, - ) - finally: - os.chdir(original_cwd) - - qdq_time = time.perf_counter() - qdq_start - - # Post-processing: fix QDQ dtype + shape inference + restore metadata - postproc_start = time.perf_counter() + qdq_time = time.perf_counter() - qdq_start - # Step 4: Load quantized model for post-processing - quantized_model = load_onnx(output_path, validate=False) + # Post-processing: fix QDQ dtype + shape inference + restore metadata + postproc_start = time.perf_counter() - # Step 5: Fix QDQ node dtype info (scale/zero_point may have UNDEFINED types) - logger.info("Fixing QDQ node dtype info...") - fix_result = fix_qdq_dtype_info(quantized_model) - warnings.extend(fix_result.warnings) + quantized_model = load_onnx(output_path, validate=False) - # Step 6: Run shape inference (defensive — propagates 
shapes through QDQ nodes) - # Uses the shared infer_shapes which tries symbolic first (handles - # com.microsoft ops like QLinearConv) then falls back to ONNX standard. - # Does NOT run graph optimization pipes that could break quantized models. - from ..onnx import infer_shapes + logger.info("Fixing QDQ node dtype info...") + fix_result = fix_qdq_dtype_info(quantized_model) + warnings.extend(fix_result.warnings) - logger.info("Running shape inference on quantized model...") - quantized_model = infer_shapes(quantized_model) + from ..onnx import infer_shapes - # Step 7: Restore metadata lost during ORT quantization - if metadata_snapshot.node_count > 0: - logger.info("Restoring metadata from pre-quantization model...") - restore_metadata(quantized_model, metadata_snapshot) + logger.info("Running shape inference on quantized model...") + quantized_model = infer_shapes(quantized_model) - # Step 8: Save the fixed model back - save_onnx(quantized_model, output_path) + if metadata_snapshot.node_count > 0: + logger.info("Restoring metadata from pre-quantization model...") + restore_metadata(quantized_model, metadata_snapshot) - postproc_time = time.perf_counter() - postproc_start + postproc_time = time.perf_counter() - postproc_start - # Count quantized nodes from in-memory model - from ..compiler import QDQ_OP_TYPES + from ..compiler import QDQ_OP_TYPES - nodes_quantized = sum( - 1 for node in quantized_model.graph.node if node.op_type in QDQ_OP_TYPES - ) - - total_time = time.perf_counter() - start_time - - logger.info( - "Quantization complete: %s -> %s (%.2fs)", - model_path.name, - output_path.name, - total_time, - ) + nodes_quantized = sum(1 for node in quantized_model.graph.node if node.op_type in QDQ_OP_TYPES) - return QuantizeResult( - success=True, - output_path=output_path, - calibration_time_seconds=cal_time, - qdq_insertion_time_seconds=qdq_time, - postproc_time_seconds=postproc_time, - total_time_seconds=total_time, - nodes_quantized=nodes_quantized, - 
errors=errors, - warnings=warnings, - ) + save_onnx(quantized_model, output_path, use_external_data=use_external_data) - except Exception: - total_time = time.perf_counter() - start_time - logger.exception("Quantization failed") + total_time = time.perf_counter() - start_time - import traceback + logger.info( + "Quantization complete: %s -> %s (%.2fs)", + model_path.name, + output_path.name, + total_time, + ) - return QuantizeResult( - success=False, - output_path=None, - total_time_seconds=total_time, - errors=[traceback.format_exc()], - warnings=warnings, - ) + return QuantizeResult( + success=True, + output_path=output_path, + calibration_time_seconds=cal_time, + qdq_insertion_time_seconds=qdq_time, + postproc_time_seconds=postproc_time, + total_time_seconds=total_time, + nodes_quantized=nodes_quantized, + errors=errors, + warnings=warnings, + ) diff --git a/src/winml/modelkit/utils/cli.py b/src/winml/modelkit/utils/cli.py index 185751fb3..15eb96570 100644 --- a/src/winml/modelkit/utils/cli.py +++ b/src/winml/modelkit/utils/cli.py @@ -63,6 +63,35 @@ def warn_trust_remote_code() -> None: ) +def warn_ignored_calibration_options( + ctx: click.Context, reason: str, *, console: Console | None = None +) -> None: + """Warn if the user passed calibration-related CLI options that are ignored. + + Checks whether ``--samples``, ``--method``, ``--weight-type``, or + ``--activation-type`` were explicitly provided on the command line and + emits a yellow warning listing the ignored options. + + Args: + ctx: Click context (used to detect explicitly-provided params). + reason: Human-readable explanation (e.g., "FP16 does not use + calibration data."). + console: Optional Rich console for output. Defaults to stderr. 
+ """ + ignored = [] + if is_cli_provided(ctx, "samples"): + ignored.append("--samples") + if is_cli_provided(ctx, "method"): + ignored.append("--method") + if is_cli_provided(ctx, "weight_type"): + ignored.append("--weight-type") + if is_cli_provided(ctx, "activation_type"): + ignored.append("--activation-type") + if ignored: + out = console or _stderr_console + out.print(f"[yellow]Warning:[/yellow] {', '.join(ignored)} ignored — {reason}") + + def model_path_option(required: bool = True) -> Callable[[F], F]: """Add --model option that accepts a local ONNX file path. diff --git a/tests/unit/commands/test_build.py b/tests/unit/commands/test_build.py index 00f54fc23..870a9d766 100644 --- a/tests/unit/commands/test_build.py +++ b/tests/unit/commands/test_build.py @@ -589,20 +589,25 @@ def test_precision_flag_sets_quant(self, tmp_path: Path, mock_run_single_build: assert passed.quant is not None assert passed.quant.weight_type == "uint8" - def test_precision_fp16_clears_quant(self, tmp_path: Path, mock_run_single_build: MagicMock): - """``--precision fp16`` skips quantization even on a quantizing device.""" + def test_precision_fp16_sets_fp16_algorithm( + self, tmp_path: Path, mock_run_single_build: MagicMock + ): + """``--precision fp16`` sets an fp16 algorithm quant config.""" cfg = _make_minimal_config_file(tmp_path) result = _invoke( [*self._base_args(cfg, tmp_path), "--device", "npu", "--precision", "fp16"] ) assert result.exit_code == 0, result.output - assert mock_run_single_build.call_args.kwargs["config"].quant is None + quant = mock_run_single_build.call_args.kwargs["config"].quant + assert quant is not None + assert quant.mode == "fp16" def test_precision_alone_triggers_quant_patch( self, tmp_path: Path, mock_run_single_build: MagicMock ): - """``--precision`` without ``--device`` still patches quant (here, clears it).""" - # Config ships an explicit quant section; fp16 must clear it even though + """``--precision`` without ``--device`` still patches quant 
to the fp16 algorithm.""" + # Config ships an explicit quant section; fp16 must switch it to the + # fp16 algorithm even though # --device was not passed (precision alone triggers the patch path). config = { "loader": {"task": "image-classification"}, @@ -631,7 +636,9 @@ def test_precision_alone_triggers_quant_patch( ] ) assert result.exit_code == 0, result.output - assert mock_run_single_build.call_args.kwargs["config"].quant is None + quant = mock_run_single_build.call_args.kwargs["config"].quant + assert quant is not None + assert quant.mode == "fp16" def test_trust_remote_code_forwarded(self, tmp_path: Path, mock_run_single_build: MagicMock): """``--trust-remote-code`` is forwarded via ``extra_kwargs``.""" diff --git a/tests/unit/commands/test_compile_quantize_flags.py b/tests/unit/commands/test_compile_quantize_flags.py index 3977e1014..b1b22a920 100644 --- a/tests/unit/commands/test_compile_quantize_flags.py +++ b/tests/unit/commands/test_compile_quantize_flags.py @@ -557,7 +557,7 @@ def _invoke(args): @pytest.mark.parametrize( "bad_precision", - ["banana", "w4a16", "int4", "fp64"], + ["banana", "fp64", "w4a4", "w2a8"], ) def test_unknown_precision_rejected(self, tmp_path, bad_precision): model, _ = TestQuantizeCliConfigPrecedence._setup(tmp_path) diff --git a/tests/unit/commands/test_config_cli.py b/tests/unit/commands/test_config_cli.py index 90517015f..1c0051b81 100644 --- a/tests/unit/commands/test_config_cli.py +++ b/tests/unit/commands/test_config_cli.py @@ -466,7 +466,7 @@ def test_invalid_device_rejected(self, bad_device: str) -> None: assert result.exit_code != 0 assert "Traceback (most recent call last)" not in result.output - @pytest.mark.parametrize("bad_precision", ["bf16", "fp64", "int4", "w3a5"]) + @pytest.mark.parametrize("bad_precision", ["bf16", "fp64", "w3a5", "w4a4"]) def test_invalid_precision_rejected(self, bad_precision: str) -> None: """Unknown precision strings must produce a UsageError, not a traceback.""" result = _invoke_config( 
diff --git a/tests/unit/config/test_build.py b/tests/unit/config/test_build.py index 26a25f698..ce7426ccd 100644 --- a/tests/unit/config/test_build.py +++ b/tests/unit/config/test_build.py @@ -1954,17 +1954,17 @@ def _mock_deps( "device,precision,expect_quant,expect_weight,expect_act,expect_compile_provider", [ ("npu", "auto", True, "uint8", "uint16", "qnn"), - ("npu", "fp16", False, None, None, "qnn"), + ("npu", "fp16", True, "uint8", "uint8", "qnn"), # fp16 algorithm quant config ("npu", "int8", True, "uint8", "uint8", "qnn"), - ("gpu", "auto", False, None, None, None), + ("gpu", "auto", True, None, None, None), # auto on gpu -> fp16 algorithm ("gpu", "int8", True, "uint8", "uint8", None), - ("gpu", "fp16", False, None, None, None), - ("cpu", "auto", False, None, None, None), + ("gpu", "fp16", True, None, None, None), # fp16 algorithm quant config + ("cpu", "auto", True, None, None, None), # auto on cpu -> fp16 algorithm ("cpu", "int8", True, "uint8", "uint8", None), ("cpu", "int16", True, "int16", "uint16", None), - ("cpu", "fp16", False, None, None, None), + ("cpu", "fp16", True, None, None, None), # fp16 algorithm quant config # auto device + explicit precision → picks NPU (mock returns npu first) - ("auto", "fp16", False, None, None, "qnn"), + ("auto", "fp16", True, None, None, "qnn"), # fp16 algorithm quant config ("auto", "int8", True, "uint8", "uint8", "qnn"), ("auto", "int16", True, "int16", "uint16", "qnn"), ], @@ -2021,8 +2021,12 @@ def test_config_gen_device_precision( assert result.quant is not None, ( f"Expected quant config for device={device}, precision={precision}" ) - assert result.quant.weight_type == expect_weight - assert result.quant.activation_type == expect_act + if result.quant.mode != "fp16": + assert result.quant.weight_type == expect_weight + assert result.quant.activation_type == expect_act + else: + # FP16 algorithm: quant stage does FP16 conversion, not QDQ + assert result.quant.mode == "fp16" else: assert result.quant is None, ( 
f"Expected no quant for device={device}, precision={precision}" @@ -2033,7 +2037,7 @@ def test_config_gen_device_precision( assert result.compile is not None assert result.compile.ep_config.provider == expect_compile_provider # TODO(#241): assert qdq_config alignment with quant policy - # Currently for_qnn() creates qdq_config even for fp16. + # Currently for_qnn() creates qdq_config even for the fp16 algorithm. # Issue #241 will pass quantize= to for_provider(). else: assert result.compile is None @@ -2205,7 +2209,7 @@ def test_device_npu_produces_qnn(self, tmp_path) -> None: assert data["quant"]["activation_type"] == "uint16" def test_device_gpu_precision_fp16(self, tmp_path) -> None: - """--device gpu --precision fp16 → no quant, compile.provider=dml.""" + """--device gpu --precision fp16 → fp16 algorithm quant config, no compile.""" self._patches["device"] = patch( "winml.modelkit.sysinfo.resolve_check_device_ep", return_value=("gpu", ["gpu", "cpu"], ["DmlExecutionProvider"]), @@ -2217,7 +2221,8 @@ def test_device_gpu_precision_fp16(self, tmp_path) -> None: assert result.exit_code == 0, f"CLI failed: {result.output}" data = json.loads(output_file.read_text()) - assert data["quant"] is None + assert data["quant"] is not None + assert data["quant"]["mode"] == "fp16" assert data["compile"] is None def test_device_cpu_precision_fp32(self, tmp_path) -> None: @@ -2397,7 +2402,7 @@ def test_raw_onnx_full_pipeline(self, tmp_path) -> None: assert config.compile.ep_config.provider == "qnn" def test_raw_onnx_cpu(self, tmp_path) -> None: - """Raw ONNX + device=cpu resolves quant=None and compile=None.""" + """Raw ONNX + device=cpu resolves to an fp16 algorithm quant config, compile=None.""" onnx_file = tmp_path / "model.onnx" onnx_file.write_bytes(b"fake") @@ -2412,7 +2417,8 @@ def test_raw_onnx_cpu(self, tmp_path) -> None: config = generate_onnx_build_config(str(onnx_file), device="cpu") assert config.export is None - assert config.quant is None + assert config.quant is 
not None + assert config.quant.mode == "fp16" assert config.compile is None def test_quantized_onnx_skips_quant(self, tmp_path) -> None: @@ -2739,10 +2745,10 @@ def test_onnx_path_as_pathlib(self, tmp_path) -> None: assert config.export is None def test_auto_device_auto_precision_defaults(self, tmp_path) -> None: - """device=auto + precision=auto (defaults) keeps config defaults. + """device=auto + precision=auto (defaults) resolves to fp16 on CPU. - resolve_quant_compile_config returns (None, None) when both are auto, - so raw ONNX gets quant=None, compile=None. + resolve_check_device_ep returns device="auto" but resolve_precision + resolves the EP to pick a concrete device, yielding an fp16 algorithm quant config. """ onnx_file = tmp_path / "model.onnx" onnx_file.write_bytes(b"fake") @@ -2757,8 +2763,9 @@ def test_auto_device_auto_precision_defaults(self, tmp_path) -> None: ): config = generate_onnx_build_config(str(onnx_file)) - # Both auto -> resolve_precision returns device="auto" -> (None, None) - assert config.quant is None + # EP resolves to CPU, auto-precision=fp16 → fp16 algorithm quant config + assert config.quant is not None + assert config.quant.mode == "fp16" assert config.compile is None def test_compiled_does_not_call_resolve_quant_compile(self, tmp_path) -> None: @@ -2778,7 +2785,7 @@ def test_compiled_does_not_call_resolve_quant_compile(self, tmp_path) -> None: mock_resolve.assert_not_called() def test_raw_onnx_with_gpu(self, tmp_path) -> None: - """Raw ONNX + device=gpu resolves quant=None, compile=dml.""" + """Raw ONNX + device=gpu resolves to an fp16 algorithm quant config, compile=None.""" onnx_file = tmp_path / "model.onnx" onnx_file.write_bytes(b"fake") @@ -2792,8 +2799,10 @@ def test_raw_onnx_with_gpu(self, tmp_path) -> None: ): config = generate_onnx_build_config(str(onnx_file), device="gpu") - # GPU auto-precision is fp16 -> no quantization, no compile (DML has no offline step) - assert config.quant is None + # GPU auto-precision is fp16 
-> fp16 algorithm quant config, no + # compile (DML has no offline step) + assert config.quant is not None + assert config.quant.mode == "fp16" assert config.compile is None def test_ep_override_forwarded(self, tmp_path) -> None: @@ -2831,15 +2840,21 @@ class TestResolveQuantCompileConfig: the HF and ONNX build config paths. """ - def test_auto_auto_returns_none_none(self) -> None: - """device=auto + precision=auto returns (None, None).""" + def test_auto_auto_returns_fp16_algorithm(self) -> None: + """device=auto + precision=auto resolves to an fp16 algorithm quant config. + + When resolve_check_device_ep returns device="auto" but the EP + resolves to a concrete device, resolve_precision picks auto-precision + (fp16 for CPU), yielding an fp16 algorithm quant config. + """ with patch( "winml.modelkit.sysinfo.resolve_check_device_ep", return_value=("auto", ["npu", "gpu", "cpu"], ["CPUExecutionProvider"]), ): quant, compile_cfg = resolve_quant_compile_config() - assert quant is None + assert isinstance(quant, WinMLQuantizationConfig) + assert quant.mode == "fp16" assert compile_cfg is None def test_npu_returns_quant_and_compile(self) -> None: @@ -2856,26 +2871,28 @@ def test_npu_returns_quant_and_compile(self) -> None: assert isinstance(compile_cfg, WinMLCompileConfig) assert compile_cfg.ep_config.provider == "qnn" - def test_gpu_returns_none_quant_and_none_compile(self) -> None: - """device=gpu returns (None, None) — DML has no offline compile step.""" + def test_gpu_returns_fp16_quant_and_none_compile(self) -> None: + """device=gpu returns (fp16 algorithm quant config, None) — auto-precision is fp16.""" with patch( "winml.modelkit.sysinfo.resolve_check_device_ep", return_value=("gpu", ["gpu", "cpu"], ["DmlExecutionProvider"]), ): quant, compile_cfg = resolve_quant_compile_config(device="gpu") - assert quant is None + assert isinstance(quant, WinMLQuantizationConfig) + assert quant.mode == "fp16" assert compile_cfg is None - def test_cpu_returns_none_none(self) -> 
None: - """device=cpu returns (None, None) since CPU has no compile provider.""" + def test_cpu_returns_fp16_quant_and_none_compile(self) -> None: + """device=cpu returns (fp16 algorithm quant config, None) — auto-precision is fp16.""" with patch( "winml.modelkit.sysinfo.resolve_check_device_ep", return_value=("cpu", ["cpu"], ["CPUExecutionProvider"]), ): quant, compile_cfg = resolve_quant_compile_config(device="cpu") - assert quant is None + assert isinstance(quant, WinMLQuantizationConfig) + assert quant.mode == "fp16" assert compile_cfg is None def test_ep_override_changes_provider(self) -> None: diff --git a/tests/unit/config/test_build_onnx.py b/tests/unit/config/test_build_onnx.py index 434bb3661..57805c619 100644 --- a/tests/unit/config/test_build_onnx.py +++ b/tests/unit/config/test_build_onnx.py @@ -220,7 +220,7 @@ def test_raw_onnx_full_pipeline(self, tmp_path) -> None: assert config.compile.ep_config.provider == "qnn" def test_raw_onnx_cpu(self, tmp_path) -> None: - """Raw ONNX + device=cpu resolves quant=None and compile=None.""" + """Raw ONNX + device=cpu resolves to an fp16 algorithm quant config, compile=None.""" onnx_file = tmp_path / "model.onnx" onnx_file.write_bytes(b"fake") @@ -235,7 +235,8 @@ def test_raw_onnx_cpu(self, tmp_path) -> None: config = generate_onnx_build_config(str(onnx_file), device="cpu") assert config.export is None - assert config.quant is None + assert config.quant is not None + assert config.quant.mode == "fp16" assert config.compile is None def test_quantized_onnx_skips_quant(self, tmp_path) -> None: @@ -562,10 +563,9 @@ def test_onnx_path_as_pathlib(self, tmp_path) -> None: assert config.export is None def test_auto_device_auto_precision_defaults(self, tmp_path) -> None: - """device=auto + precision=auto (defaults) keeps config defaults. + """device=auto + precision=auto resolves to fp16 on CPU. - resolve_quant_compile_config returns (None, None) when both are auto, - so raw ONNX gets quant=None, compile=None. 
+ resolve_precision resolves the EP to a concrete device, yielding the fp16 algorithm. """ onnx_file = tmp_path / "model.onnx" onnx_file.write_bytes(b"fake") @@ -580,8 +580,8 @@ def test_auto_device_auto_precision_defaults(self, tmp_path) -> None: ): config = generate_onnx_build_config(str(onnx_file)) - # Both auto -> resolve_precision returns device="auto" -> (None, None) - assert config.quant is None + assert config.quant is not None + assert config.quant.mode == "fp16" assert config.compile is None def test_compiled_does_not_call_resolve_quant_compile(self, tmp_path) -> None: @@ -601,7 +601,7 @@ def test_compiled_does_not_call_resolve_quant_compile(self, tmp_path) -> None: mock_resolve.assert_not_called() def test_raw_onnx_with_gpu(self, tmp_path) -> None: - """Raw ONNX + device=gpu resolves quant=None, compile=None. + """Raw ONNX + device=gpu resolves to an fp16 algorithm quant config, compile=None. DML has enable_ep_context=False so for_provider("dml") returns None — no offline compile step is needed. @@ -619,8 +619,9 @@ def test_raw_onnx_with_gpu(self, tmp_path) -> None: ): config = generate_onnx_build_config(str(onnx_file), device="gpu") - # GPU auto-precision is fp16 -> no quantization; DML has no EPContext step - assert config.quant is None + # GPU auto-precision is fp16 -> fp16 algorithm quant config; DML has no EPContext step + assert config.quant is not None + assert config.quant.mode == "fp16" assert config.compile is None def test_ep_override_forwarded(self, tmp_path) -> None: @@ -660,15 +661,16 @@ class TestResolveQuantCompileConfig: the HF and ONNX build config paths. 
""" - def test_auto_auto_returns_none_none(self) -> None: - """device=auto + precision=auto returns (None, None).""" + def test_auto_auto_returns_fp16_algorithm(self) -> None: + """device=auto + precision=auto resolves to an fp16 algorithm quant config.""" with patch( "winml.modelkit.sysinfo.resolve_check_device_ep", return_value=("auto", ["npu", "gpu", "cpu"], ["CPUExecutionProvider"]), ): quant, compile_cfg = resolve_quant_compile_config() - assert quant is None + assert isinstance(quant, WinMLQuantizationConfig) + assert quant.mode == "fp16" assert compile_cfg is None def test_npu_returns_quant_and_compile(self) -> None: @@ -685,26 +687,28 @@ def test_npu_returns_quant_and_compile(self) -> None: assert isinstance(compile_cfg, WinMLCompileConfig) assert compile_cfg.ep_config.provider == "qnn" - def test_gpu_returns_none_quant_and_none_compile(self) -> None: - """device=gpu returns (None, None) — DML has no EPContext step.""" + def test_gpu_returns_fp16_quant_and_none_compile(self) -> None: + """device=gpu returns (fp16 algorithm quant config, None) — auto-precision is fp16.""" with patch( "winml.modelkit.sysinfo.resolve_check_device_ep", return_value=("gpu", ["gpu", "cpu"], ["DmlExecutionProvider"]), ): quant, compile_cfg = resolve_quant_compile_config(device="gpu") - assert quant is None + assert isinstance(quant, WinMLQuantizationConfig) + assert quant.mode == "fp16" assert compile_cfg is None - def test_cpu_returns_none_none(self) -> None: - """device=cpu returns (None, None) since CPU has no compile provider.""" + def test_cpu_returns_fp16_quant_and_none_compile(self) -> None: + """device=cpu returns (fp16 algorithm quant config, None) — auto-precision is fp16.""" with patch( "winml.modelkit.sysinfo.resolve_check_device_ep", return_value=("cpu", ["cpu"], ["CPUExecutionProvider"]), ): quant, compile_cfg = resolve_quant_compile_config(device="cpu") - assert quant is None + assert isinstance(quant, WinMLQuantizationConfig) + assert quant.mode == "fp16" assert 
compile_cfg is None def test_ep_override_changes_provider(self) -> None: diff --git a/tests/unit/config/test_precision.py b/tests/unit/config/test_precision.py index c41ee4e31..122d15b94 100644 --- a/tests/unit/config/test_precision.py +++ b/tests/unit/config/test_precision.py @@ -17,7 +17,11 @@ import pytest from winml.modelkit.config.precision import ( + expand_precision, + extract_activation_bits, + extract_weight_bits, is_quantized_precision, + is_weight_only_precision, resolve_precision, resolve_quant_types, ) @@ -312,19 +316,24 @@ def test_auto_raises(self) -> None: with pytest.raises(ValueError, match="Unknown precision"): resolve_quant_types("auto") - # ---- Unsupported bit widths ---- - def test_unsupported_weight_bits_raises(self) -> None: - """w4a16 has unsupported weight bit-width 4 -- must raise ValueError.""" - with pytest.raises(ValueError, match="Unsupported weight bit-width 4"): + # ---- Weight-only precision raises (should use RTN, not QDQ) ---- + def test_weight_only_precision_raises(self) -> None: + """w4a16 is weight-only (RTN) — resolve_quant_types must raise.""" + with pytest.raises(ValueError, match=r"weight-only.*RTN"): resolve_quant_types("w4a16") + def test_int4_raises(self) -> None: + """int4 is weight-only (RTN) — resolve_quant_types must raise.""" + with pytest.raises(ValueError, match=r"weight-only.*RTN"): + resolve_quant_types("int4") + def test_unsupported_activation_bits_raises(self) -> None: """w8a4 has unsupported activation bit-width 4 -- must raise ValueError.""" with pytest.raises(ValueError, match="Unsupported activation bit-width 4"): resolve_quant_types("w8a4") - def test_both_bits_unsupported_raises_weight_first(self) -> None: - """w4a4 should raise on weight bits first (checked before activation).""" + def test_both_bits_unsupported_raises(self) -> None: + """w4a4 has unsupported bit-widths — must raise ValueError.""" with pytest.raises(ValueError, match="Unsupported weight bit-width 4"): resolve_quant_types("w4a4") @@ 
-400,11 +409,17 @@ def test_float_and_auto_return_false(self, precision: str) -> None: assert is_quantized_precision(precision) is False # ---- False cases: unsupported bit widths ---- - @pytest.mark.parametrize("precision", ["w4a16", "w8a4", "w4a4", "w2a8", "w8a2"]) + @pytest.mark.parametrize("precision", ["w8a4", "w4a4", "w2a8", "w8a2"]) def test_unsupported_bits_return_false(self, precision: str) -> None: """Unsupported w{x}a{y} bit widths must return False, not True.""" assert is_quantized_precision(precision) is False + # ---- True cases: weight-only ---- + @pytest.mark.parametrize("precision", ["int4", "w4a16", "w4a8"]) + def test_weight_only_return_true(self, precision: str) -> None: + """Weight-only precisions (int4, w4a16) are quantized.""" + assert is_quantized_precision(precision) is True + # ---- False cases: completely invalid ---- @pytest.mark.parametrize("precision", ["garbage", "wXaY", "", "bfloat16", "w0a0"]) def test_invalid_strings_return_false(self, precision: str) -> None: @@ -483,13 +498,21 @@ class TestMixedPrecisionInvalidInputs: @pytest.mark.parametrize( "precision", - ["w4a16", "w4a4", "w2a8"], + ["w4a4", "w2a8"], ) def test_unsupported_mixed_bits_rejected(self, precision: str) -> None: """Unsupported w{x}a{y} bit widths should raise ValueError.""" with pytest.raises(ValueError, match="Unknown precision"): resolve_precision(device="npu", precision=precision) + def test_w4a16_is_valid_weight_only(self) -> None: + """w4a16 is now a valid weight-only precision (RTN).""" + policy = resolve_precision(device="npu", precision="w4a16") + assert policy.precision == "w4a16" + # Weight-only: no traditional weight_type/activation_type + assert policy.weight_type is None + assert policy.activation_type is None + def test_w0a0_rejected(self) -> None: """w0a0 is not a valid precision.""" with pytest.raises(ValueError, match="Unknown precision"): @@ -580,10 +603,17 @@ def test_no_precision_defaults_uint8(self) -> None: # ---- Unsupported precision is 
rejected ---- def test_unsupported_precision_rejected(self) -> None: - """Unsupported precision (w4a16) must raise BadParameter, not silently fall back.""" + """Unsupported precision (w2a8) must raise BadParameter.""" import click with pytest.raises(click.BadParameter, match="not a supported quantization precision"): + self._resolve(precision="w2a8") + + def test_weight_only_precision_rejected(self) -> None: + """Weight-only precision (w4a16) must raise BadParameter (should use RTN path).""" + import click + + with pytest.raises(click.BadParameter, match="weight-only"): self._resolve(precision="w4a16") # ---- Explicit flags override precision ---- @@ -611,3 +641,145 @@ def test_w8a16_case_insensitive(self) -> None: w, a = self._resolve(precision="W8A16") assert w == "uint8" assert a == "uint16" + + +# ============================================================================= +# TestIsWeightOnlyPrecision - RTN detection +# ============================================================================= + + +class TestIsWeightOnlyPrecision: + """Test is_weight_only_precision() function.""" + + @pytest.mark.parametrize("precision", ["int4", "w4a16", "w4a8", "w4a32"]) + def test_weight_only_true(self, precision: str) -> None: + """Weight-only precisions should return True.""" + assert is_weight_only_precision(precision) is True + + @pytest.mark.parametrize("precision", ["int8", "int16", "w8a16", "w8a8", "w16a16"]) + def test_qdq_precisions_false(self, precision: str) -> None: + """QDQ precisions should return False.""" + assert is_weight_only_precision(precision) is False + + @pytest.mark.parametrize("precision", ["w8a32", "w16a32"]) + def test_a32_with_qdq_weight_false(self, precision: str) -> None: + """a32 (keep FP32) is only valid with weight-only (4-bit), not QDQ weights.""" + assert is_weight_only_precision(precision) is False + + @pytest.mark.parametrize("precision", ["fp16", "fp32", "auto"]) + def test_float_precisions_false(self, precision: str) -> None: + 
"""Float precisions should return False.""" + assert is_weight_only_precision(precision) is False + + @pytest.mark.parametrize("precision", ["garbage", "", "bfloat16"]) + def test_invalid_returns_false(self, precision: str) -> None: + """Invalid precision strings should return False.""" + assert is_weight_only_precision(precision) is False + + +# ============================================================================= +# TestExtractWeightBits - bit extraction +# ============================================================================= + + +class TestExtractWeightBits: + """Test extract_weight_bits() function.""" + + @pytest.mark.parametrize( + ("precision", "expected"), + [ + ("int4", 4), + ("int8", 8), + ("int16", 16), + ("w4a16", 4), + ("w4a8", 4), + ("w4a32", 4), + ("w8a8", 8), + ("w8a16", 8), + ("w16a16", 16), + ], + ) + def test_extract_bits(self, precision: str, expected: int) -> None: + """Should extract correct weight bit-width.""" + assert extract_weight_bits(precision) == expected + + @pytest.mark.parametrize("precision", ["fp16", "fp32", "auto", "garbage"]) + def test_invalid_raises(self, precision: str) -> None: + """Non-quantized precisions should raise ValueError.""" + with pytest.raises(ValueError, match=r"Cannot extract weight bits"): + extract_weight_bits(precision) + + @pytest.mark.parametrize("precision", ["w4a4", "w3a8", "w32a8"]) + def test_unsupported_bits_raises(self, precision: str) -> None: + """Precisions with unsupported bit-widths should raise ValueError.""" + with pytest.raises(ValueError, match=r"unsupported bit-widths"): + extract_weight_bits(precision) + + +# ============================================================================= +# TestExtractActivationBits - activation bit extraction +# ============================================================================= + + +class TestExtractActivationBits: + """Test extract_activation_bits() function.""" + + @pytest.mark.parametrize( + ("precision", "expected"), + [ + 
("int4", 32), # int4 preset = w4a32 (activation stays FP32) + ("w4a32", 32), + ("w4a16", 16), + ("w4a8", 8), + ("w8a8", 8), + ("w8a16", 16), + ], + ) + def test_extract_activation_bits(self, precision: str, expected: int) -> None: + """Should extract correct activation bit-width.""" + assert extract_activation_bits(precision) == expected + + @pytest.mark.parametrize("precision", ["fp16", "fp32", "auto", "garbage"]) + def test_invalid_raises(self, precision: str) -> None: + """Non-mixed precisions should raise ValueError.""" + with pytest.raises(ValueError, match=r"Cannot extract activation bits"): + extract_activation_bits(precision) + + def test_unsupported_activation_raises(self) -> None: + """Unsupported activation bit-width should raise ValueError.""" + with pytest.raises(ValueError, match=r"unsupported activation bit-width"): + extract_activation_bits("w4a4") + + +# ============================================================================= +# TestExpandPrecision - Multi-pass precision expansion +# ============================================================================= + + +class TestExpandPrecision: + """Test expand_precision() function.""" + + @pytest.mark.parametrize( + ("precision", "expected"), + [ + ("w4a16", ["int4", "fp16"]), + ("W4A16", ["int4", "fp16"]), + ("int4", ["int4"]), + ("w4a32", ["w4a32"]), + ("fp16", ["fp16"]), + ("int8", ["int8"]), + ("w8a16", ["w8a16"]), + ("w8a8", ["w8a8"]), + ("int16", ["int16"]), + ], + ) + def test_expand_precision(self, precision: str, expected: list[str]) -> None: + """Verify precision expansion produces correct pass sequences.""" + assert expand_precision(precision) == expected + + def test_w4a16_is_only_multi_pass(self) -> None: + """Only w4a16 should produce more than one pass.""" + single_pass_cases = ["int4", "int8", "int16", "fp16", "w4a32", "w8a16", "w8a8"] + for prec in single_pass_cases: + result = expand_precision(prec) + assert len(result) == 1, f"{prec} should be single-pass but got {result}" 
diff --git a/tests/unit/models/auto/test_config.py b/tests/unit/models/auto/test_config.py index da17f80d3..f2fc132cc 100644 --- a/tests/unit/models/auto/test_config.py +++ b/tests/unit/models/auto/test_config.py @@ -98,7 +98,7 @@ def test_default_values(self): config = WinMLQuantizationConfig() - assert config.mode == "qdq" + assert config.mode == "static" assert config.weight_type == "uint8" assert config.samples == 10 assert config.calibration_method == "minmax" @@ -108,13 +108,13 @@ def test_qdq_mode_config(self): from winml.modelkit.quant import WinMLQuantizationConfig config = WinMLQuantizationConfig( - mode="qdq", + mode="static", weight_type="int8", activation_type="int8", calibration_method="minmax", ) - assert config.mode == "qdq" + assert config.mode == "static" assert config.weight_type == "int8" diff --git a/tests/unit/optim/test_fp16.py b/tests/unit/optim/test_fp16.py new file mode 100644 index 000000000..2e63ab193 --- /dev/null +++ b/tests/unit/optim/test_fp16.py @@ -0,0 +1,147 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""FP16 conversion utility tests. + +Tests for winml.modelkit.quant.fp16.convert_to_fp16 which converts +FP32 ONNX models to FP16 precision. 
+ +Following Cardinal Rules: +- CARDINAL RULE #1: No hardcoded model architectures +- CARDINAL RULE #2: All tests use pytest with code-generated results +- CARDINAL RULE #3: Tests must run and pass +""" + +from __future__ import annotations + +import numpy as np +from onnx import ModelProto, TensorProto, helper, numpy_helper + +from winml.modelkit.quant.fp16 import convert_to_fp16 + + +# ============================================================================= +# HELPERS +# ============================================================================= + + +def _build_simple_fp32_model() -> ModelProto: + """Build a simple FP32 model: out = x + weight.""" + x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 4]) + out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [1, 4]) + weight = numpy_helper.from_array(np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32), "weight") + add = helper.make_node("Add", ["x", "weight"], ["out"], name="add") + graph = helper.make_graph([add], "simple", [x], [out], [weight]) + return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + + +def _build_multi_op_fp32_model() -> ModelProto: + """Build a model with multiple ops: out = Relu(x + weight).""" + x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 4]) + out = helper.make_tensor_value_info("out", TensorProto.FLOAT, [1, 4]) + weight = numpy_helper.from_array(np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32), "weight") + add = helper.make_node("Add", ["x", "weight"], ["add_out"], name="add") + relu = helper.make_node("Relu", ["add_out"], ["out"], name="relu") + graph = helper.make_graph([add, relu], "multi_op", [x], [out], [weight]) + return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + + +# ============================================================================= +# CONVERT_TO_FP16 TESTS +# ============================================================================= + + +class TestConvertToFP16: + 
"""Test convert_to_fp16 utility function.""" + + def test_converts_weights_to_fp16(self) -> None: + """FP16 conversion converts float32 initializers to float16.""" + model = _build_simple_fp32_model() + result = convert_to_fp16(model) + + has_fp16 = any(init.data_type == TensorProto.FLOAT16 for init in result.graph.initializer) + assert has_fp16, "Expected at least one FP16 initializer after conversion" + + def test_default_keeps_io_types(self) -> None: + """Default keep_io_types=True preserves FP32 model I/O.""" + model = _build_simple_fp32_model() + result = convert_to_fp16(model, keep_io_types=True) + + for inp in result.graph.input: + assert inp.type.tensor_type.elem_type == TensorProto.FLOAT + for outp in result.graph.output: + assert outp.type.tensor_type.elem_type == TensorProto.FLOAT + + def test_keep_io_types_false_converts_io(self) -> None: + """With keep_io_types=False, model I/O becomes FP16.""" + model = _build_simple_fp32_model() + result = convert_to_fp16(model, keep_io_types=False) + + for inp in result.graph.input: + assert inp.type.tensor_type.elem_type == TensorProto.FLOAT16 + for outp in result.graph.output: + assert outp.type.tensor_type.elem_type == TensorProto.FLOAT16 + + def test_preserves_model_structure(self) -> None: + """FP16 conversion preserves graph structure (node count diff ≤ 2).""" + model = _build_multi_op_fp32_model() + original_count = len(model.graph.node) + result = convert_to_fp16(model, keep_io_types=True) + converted_count = len(result.graph.node) + + assert converted_count - original_count <= 2, ( + f"Node count changed from {original_count} to {converted_count}, " + f"difference {converted_count - original_count} exceeds threshold of 2" + ) + + def test_op_block_list_keeps_ops_in_fp32(self) -> None: + """Ops in block list should remain operating on FP32 data.""" + model = _build_multi_op_fp32_model() + result = convert_to_fp16(model, op_block_list=["Relu"]) + + op_types = [n.op_type for n in result.graph.node] + assert 
"Cast" in op_types, "Expected Cast nodes for blocked ops" + + def test_none_op_block_list_uses_ort_defaults(self) -> None: + """When op_block_list is None, ORT uses its DEFAULT_OP_BLOCK_LIST.""" + model = _build_simple_fp32_model() + # Should not raise — ORT applies its default safety list + result = convert_to_fp16(model, op_block_list=None) + assert result is not None + + def test_skips_already_fp16_model(self) -> None: + """If all floating-point initializers are already FP16, conversion is skipped.""" + # Build a model with FP16 initializers directly + x = helper.make_tensor_value_info("x", TensorProto.FLOAT16, [1, 4]) + out = helper.make_tensor_value_info("out", TensorProto.FLOAT16, [1, 4]) + weight_data = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float16) + weight = numpy_helper.from_array(weight_data, "weight") + add = helper.make_node("Add", ["x", "weight"], ["out"], name="add") + graph = helper.make_graph([add], "fp16_model", [x], [out], [weight]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + + original_nodes = len(model.graph.node) + result = convert_to_fp16(model) + + # Should return the same model unchanged (no Cast nodes inserted) + assert len(result.graph.node) == original_nodes + assert result is model + + def test_skips_fp16_model_with_int_initializers(self) -> None: + """FP16 model with non-float initializers (e.g. 
INT64 shapes) should still skip.""" + x = helper.make_tensor_value_info("x", TensorProto.FLOAT16, [1, 4]) + out = helper.make_tensor_value_info("out", TensorProto.FLOAT16, [1, 4]) + weight_data = np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float16) + weight = numpy_helper.from_array(weight_data, "weight") + # INT64 initializer (e.g., shape tensor) — should be ignored by skip logic + shape_tensor = numpy_helper.from_array(np.array([1, 4], dtype=np.int64), "shape") + add = helper.make_node("Add", ["x", "weight"], ["out"], name="add") + graph = helper.make_graph([add], "fp16_mixed", [x], [out], [weight, shape_tensor]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + + original_nodes = len(model.graph.node) + result = convert_to_fp16(model) + + assert len(result.graph.node) == original_nodes + assert result is model diff --git a/tests/unit/optim/test_optimizer.py b/tests/unit/optim/test_optimizer.py index fe20d8401..1e51af2b7 100644 --- a/tests/unit/optim/test_optimizer.py +++ b/tests/unit/optim/test_optimizer.py @@ -698,7 +698,7 @@ def test_resolve_dependencies_method(self) -> None: def test_registered_pipes_count(self) -> None: """Verify the expected number of pipes are registered.""" Optimizer._initialize_pipes() - # Currently: RewritePipe, ORTGraphPipe, ORTFusionPipe, SurgeryPipe + # Currently: ORTGraphPipe, RewritePipe, ORTFusionPipe, SurgeryPipe assert len(Optimizer.pipes) == 4 def test_registered_pipe_names(self) -> None: