
QuantizationPolicy.fp8_scaled_mm() cannot load ltx-2.3-22b-distilled-fp8.safetensors #205

@BenjiElysium

Description

Summary

The LTX-2.3 distilled FP8 checkpoint (Lightricks/LTX-2.3-fp8/ltx-2.3-22b-distilled-fp8.safetensors) cannot be loaded with QuantizationPolicy.fp8_scaled_mm() on Hopper GPUs due to state dict shape mismatches on transformer block weights. The checkpoint weights appear to be in PyTorch standard layout [out, in], but the FP8Linear layers expect cublas transposed layout [in, out].

Environment

  • GPU: H200 (Hopper)
  • ltx-core / ltx-pipelines: main branch (as of 2026-04-28)
  • TensorRT-LLM: 1.2.1
  • Checkpoint: ltx-2.3-22b-distilled-fp8.safetensors

Reproduction Steps

from ltx_pipelines.distilled import DistilledPipeline
from ltx_core.quantization import QuantizationPolicy

pipeline = DistilledPipeline(
    distilled_checkpoint_path="ltx-2.3-22b-distilled-fp8.safetensors",
    gemma_root="google/gemma-3-12b-it-qat-q4_0-unquantized",
    spatial_upsampler_path="ltx-2.3-spatial-upscaler-x2-1.1.safetensors",
    loras=[],
    quantization=QuantizationPolicy.fp8_scaled_mm(),
)
# RuntimeError: Error(s) in loading state_dict for LTXModel...

Error Message

RuntimeError: Error(s) in loading state_dict for LTXModel:
  size mismatch for transformer_blocks.1.ff.net.0.proj.weight:
    copying a param with shape torch.Size([16384, 4096]) from checkpoint,
    the shape in current model is torch.Size([4096, 16384])

  size mismatch for transformer_blocks.1.audio_ff.net.0.proj.weight:
    copying a param with shape torch.Size([8192, 2048]) from checkpoint,
    the shape in current model is torch.Size([2048, 8192])
  
  [... hundreds more ...]

Analysis

Weight Layout Mismatch

  • Checkpoint stores weights as: [out_features, in_features] (PyTorch standard)
  • Model expects (for FP8Linear): [in_features, out_features] (cublas transposed)

Affected layers across transformer blocks 1-42 (and possibly 0 and 43-47):

  • ff.net.0.proj (feed-forward projections)
  • ff.net.2 (feed-forward output projections)
  • audio_ff.net.* (audio feed-forward layers)
  • Audio-video cross-attention layers
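
To verify the on-disk layout independently of the loader, the shapes can be read straight from the checkpoint metadata. A minimal sketch (key names are copied from the error message above; nothing here is ltx-core-specific):

from safetensors import safe_open

CKPT = "ltx-2.3-22b-distilled-fp8.safetensors"
KEYS = [
    "transformer_blocks.1.ff.net.0.proj.weight",
    "transformer_blocks.1.audio_ff.net.0.proj.weight",
]

with safe_open(CKPT, framework="pt", device="cpu") as f:
    for key in KEYS:
        # get_slice() reads only tensor metadata, so this is cheap even for a 22B checkpoint
        shape = f.get_slice(key).get_shape()
        print(key, shape)  # e.g. [16384, 4096] => PyTorch [out_features, in_features] layout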

Expected Behavior

The FP8_TRANSPOSE_SD_OPS operation should transpose all linear weights during state dict loading, converting them from PyTorch layout to cublas layout. However, this is not happening (or not happening correctly), resulting in the shape mismatches above.
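
For clarity, the behavior we expected from FP8_TRANSPOSE_SD_OPS is roughly the following. This is a hypothetical sketch, not the ltx-core implementation; in particular, the rule for deciding which keys to transpose (2-D tensors whose key ends in ".weight") is an assumption:

def transpose_linear_weights(state_dict: dict) -> dict:
    """Hypothetical sketch of a transposing state dict op.
    The real FP8_TRANSPOSE_SD_OPS presumably knows exactly which keys
    belong to FP8Linear modules; matching 2-D ".weight" tensors here
    is only illustrative."""
    out = {}
    for key, tensor in state_dict.items():
        if key.endswith(".weight") and tensor.ndim == 2:
            # PyTorch layout [out_features, in_features]
            #   -> cublas layout [in_features, out_features]
            out[key] = tensor.t().contiguous()
        else:
            out[key] = tensor
    return out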

Questions for Lightricks

  1. Checkpoint Creation: How was the ltx-2.3-22b-distilled-fp8.safetensors checkpoint created? Were weights transposed during quantization, or are they in standard PyTorch format?

  2. FP8_TRANSPOSE_SD_OPS: Is this operation being applied to the FP8 checkpoint state dict during loading? If so, why are weights still in mismatched shapes?

  3. Version Compatibility: Is there a version mismatch between when the FP8 checkpoint was created and the current ltx-core loader code?

  4. Documentation: The README says "Fp8-scaled-mm should be used with fp8 checkpoints," but it appears the checkpoint and loader are incompatible. Is the checkpoint still experimental or is there a different loading procedure we should use?

Potential Workarounds Tested

  • ❌ Removing block-level exclusions from FP8_TRANSPOSE_SD_OPS (blocks 0, 43-47) — did not resolve mismatch
  • ❌ Using fp8_cast() with the FP8 checkpoint — loads but produces garbled video (FP8 weights with wrong layout + incorrect scales)
  • ✅ Using fp8_cast() with bf16 checkpoint — works correctly (~62s/clip on H200)
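
One further experiment we have not yet run, noted in case it helps triage: rewriting the checkpoint offline so the 2-D transformer-block weights are already in cublas layout before loading. A rough sketch, assuming the mismatch is purely a transpose; it deliberately ignores how the per-tensor FP8 scales and the block-level exclusions (0, 43-47) should be handled:

from safetensors.torch import load_file, save_file

SRC = "ltx-2.3-22b-distilled-fp8.safetensors"
DST = "ltx-2.3-22b-distilled-fp8.transposed.safetensors"  # hypothetical output name

sd = load_file(SRC)  # loads all tensors onto CPU
for key, tensor in sd.items():
    # Illustrative filter only: real code would mirror whatever key set
    # FP8_TRANSPOSE_SD_OPS targets, including its exclusions.
    if key.startswith("transformer_blocks.") and key.endswith(".weight") and tensor.ndim == 2:
        sd[key] = tensor.t().contiguous()
save_file(sd, DST)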

Impact

The FP8 scaled-mm path with the official FP8 distilled checkpoint is currently non-functional on H200. This prevents users from accessing the TRT-LLM cublas_scaled_mm optimization, which is critical for achieving Hopper-class inference speeds comparable to managed services (e.g., ~20s/clip on Replicate).

Attachment

Full reproducible example and debug logs available at:

  • Repo: github.com/user/g-engine (private, can share on request)
  • Logs: Sweep results showing all 5-clip failures with detailed state dict errors

Additional Context: We are benchmarking LTX-2.3 across cloud GPU providers to close a 2.7× speed gap vs. Replicate. The fp8_scaled_mm path with the pre-quantized checkpoint is essential for reaching target performance on H200. Any guidance on resolving this would be greatly appreciated.
