diff --git a/mlx_video/models/ltx_2/video_vae/video_vae.py b/mlx_video/models/ltx_2/video_vae/video_vae.py index bd85086..f5acc9a 100644 --- a/mlx_video/models/ltx_2/video_vae/video_vae.py +++ b/mlx_video/models/ltx_2/video_vae/video_vae.py @@ -122,7 +122,10 @@ def _make_encoder_block( spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_all_res": - out_channels = in_channels * block_config.get("multiplier", 2) + max_channels = block_config.get("max_channels", 1024) + out_channels = min( + in_channels * block_config.get("multiplier", 2), max_channels + ) block = SpaceToDepthDownsample( dims=convolution_dimensions, in_channels=in_channels, @@ -131,7 +134,10 @@ def _make_encoder_block( spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_space_res": - out_channels = in_channels * block_config.get("multiplier", 2) + max_channels = block_config.get("max_channels", 1024) + out_channels = min( + in_channels * block_config.get("multiplier", 2), max_channels + ) block = SpaceToDepthDownsample( dims=convolution_dimensions, in_channels=in_channels, @@ -140,7 +146,10 @@ def _make_encoder_block( spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_time_res": - out_channels = in_channels * block_config.get("multiplier", 2) + max_channels = block_config.get("max_channels", 1024) + out_channels = min( + in_channels * block_config.get("multiplier", 2), max_channels + ) block = SpaceToDepthDownsample( dims=convolution_dimensions, in_channels=in_channels,