From 510ce59f99e84762f04c079ac7c6d69d5d978cf9 Mon Sep 17 00:00:00 2001 From: Silas Palmer Date: Tue, 28 Apr 2026 21:39:07 +1000 Subject: [PATCH] fix(ltx2): cap encoder channels to match pretrained VAE weights MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The LTX-2.3-dev pretrained VAE encoder caps feature channels at 1024, but `_make_encoder_block` blindly applied `out_channels = in_channels * multiplier`. With the default config the second `compress_all_res` block expanded 1024 → 2048 channels, while its actual conv weight `(128, 3, 3, 3, 1024)` only produces 128 × 8 = 1024 channels after the space-to-depth, causing a broadcast error in `x_conv + x_in`: ValueError: [broadcast_shapes] Shapes (1,1024,1,8,12) and (1,2048,1,8,12) cannot be broadcast. Add an optional `max_channels` (default 1024) to the three `compress_*_res` block builders so `out_channels = min(in_channels * multiplier, max_channels)`. This matches the pretrained weights (conv_out has in=1024) and unblocks I2V / two-stage HQ pipelines that exercise the encoder. Co-Authored-By: Claude Opus 4.7 (1M context) --- mlx_video/models/ltx_2/video_vae/video_vae.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/mlx_video/models/ltx_2/video_vae/video_vae.py b/mlx_video/models/ltx_2/video_vae/video_vae.py index bd85086..f5acc9a 100644 --- a/mlx_video/models/ltx_2/video_vae/video_vae.py +++ b/mlx_video/models/ltx_2/video_vae/video_vae.py @@ -122,7 +122,10 @@ def _make_encoder_block( spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_all_res": - out_channels = in_channels * block_config.get("multiplier", 2) + max_channels = block_config.get("max_channels", 1024) + out_channels = min( + in_channels * block_config.get("multiplier", 2), max_channels + ) block = SpaceToDepthDownsample( dims=convolution_dimensions, in_channels=in_channels, @@ -131,7 +134,10 @@ def _make_encoder_block( spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_space_res": - out_channels = in_channels * block_config.get("multiplier", 2) + max_channels = block_config.get("max_channels", 1024) + out_channels = min( + in_channels * block_config.get("multiplier", 2), max_channels + ) block = SpaceToDepthDownsample( dims=convolution_dimensions, in_channels=in_channels, @@ -140,7 +146,10 @@ def _make_encoder_block( spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_time_res": - out_channels = in_channels * block_config.get("multiplier", 2) + max_channels = block_config.get("max_channels", 1024) + out_channels = min( + in_channels * block_config.get("multiplier", 2), max_channels + ) block = SpaceToDepthDownsample( dims=convolution_dimensions, in_channels=in_channels,