Skip to content

Commit cbf9219

Browse files
authored
fix: strip trailing latent channels for preview decode (#1548)
1 parent 8cf55a3 commit cbf9219

1 file changed

Lines changed: 14 additions & 20 deletions

File tree

src/stable-diffusion.cpp

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1610,23 +1610,18 @@ class StableDiffusionGGML {
16101610
std::function<void(int, int, sd_image_t*, bool, void*)> step_callback,
16111611
void* step_callback_data,
16121612
bool is_noisy) {
1613+
bool is_video = preview_latent_tensor_is_video(latents);
1614+
uint32_t dim = is_video ? static_cast<uint32_t>(latents.shape()[3]) : static_cast<uint32_t>(latents.shape()[2]);
1615+
int channels = get_latent_channel();
1616+
auto _latents = channels != dim ? is_video ? sd::ops::slice(latents, 3, 0, channels)
1617+
: sd::ops::slice(latents, 2, 0, channels)
1618+
: latents;
16131619
if (preview_mode == PREVIEW_PROJ) {
1614-
sd::Tensor<float> _latents = latents;
16151620
int patch_sz = 1;
16161621
const float(*latent_rgb_proj)[3] = nullptr;
16171622
float* latent_rgb_bias = nullptr;
1618-
bool is_video = preview_latent_tensor_is_video(latents);
1619-
uint32_t dim = is_video ? static_cast<uint32_t>(latents.shape()[3]) : static_cast<uint32_t>(latents.shape()[2]);
1620-
if (version == VERSION_LTXAV) {
1621-
if (is_video) {
1622-
_latents = sd::ops::slice(_latents, 3, 0, 128);
1623-
} else {
1624-
_latents = sd::ops::slice(_latents, 2, 0, 128);
1625-
}
1626-
dim = 128;
1627-
}
16281623

1629-
if (dim == 128) {
1624+
if (channels == 128) {
16301625
if (sd_version_uses_flux2_vae(version)) {
16311626
latent_rgb_proj = flux2_latent_rgb_proj;
16321627
latent_rgb_bias = flux2_latent_rgb_bias;
@@ -1638,15 +1633,15 @@ class StableDiffusionGGML {
16381633
LOG_WARN("No latent to RGB projection known for this model");
16391634
return;
16401635
}
1641-
} else if (dim == 48) {
1636+
} else if (channels == 48) {
16421637
if (sd_version_is_wan(version)) {
16431638
latent_rgb_proj = wan_22_latent_rgb_proj;
16441639
latent_rgb_bias = wan_22_latent_rgb_bias;
16451640
} else {
16461641
LOG_WARN("No latent to RGB projection known for this model");
16471642
return;
16481643
}
1649-
} else if (dim == 16) {
1644+
} else if (channels == 16) {
16501645
if (sd_version_is_sd3(version)) {
16511646
latent_rgb_proj = sd3_latent_rgb_proj;
16521647
latent_rgb_bias = sd3_latent_rgb_bias;
@@ -1660,7 +1655,7 @@ class StableDiffusionGGML {
16601655
LOG_WARN("No latent to RGB projection known for this model");
16611656
return;
16621657
}
1663-
} else if (dim == 4) {
1658+
} else if (channels == 4) {
16641659
if (sd_version_is_sdxl(version)) {
16651660
latent_rgb_proj = sdxl_latent_rgb_proj;
16661661
latent_rgb_bias = sdxl_latent_rgb_bias;
@@ -1671,8 +1666,8 @@ class StableDiffusionGGML {
16711666
LOG_WARN("No latent to RGB projection known for this model");
16721667
return;
16731668
}
1674-
} else if (dim != 3) {
1675-
LOG_WARN("No latent to RGB projection known for this model");
1669+
} else if (channels != 3) {
1670+
LOG_WARN("No latent to RGB projection known for this model (dim = %d)", dim);
16761671
return;
16771672
}
16781673

@@ -1697,14 +1692,13 @@ class StableDiffusionGGML {
16971692
if (preview_mode == PREVIEW_VAE || preview_mode == PREVIEW_TAE) {
16981693
sd::Tensor<float> vae_latents;
16991694
sd::Tensor<float> decoded;
1700-
bool is_video = preview_latent_tensor_is_video(latents);
17011695
if (preview_vae) {
17021696
preview_vae->set_temporal_tiling_enabled(vae_tiling_params.temporal_tiling);
1703-
vae_latents = preview_vae->diffusion_to_vae_latents(latents);
1697+
vae_latents = preview_vae->diffusion_to_vae_latents(_latents);
17041698
decoded = preview_vae->decode(n_threads, vae_latents, vae_tiling_params, is_video, circular_x, circular_y, true);
17051699
} else {
17061700
first_stage_model->set_temporal_tiling_enabled(vae_tiling_params.temporal_tiling);
1707-
vae_latents = first_stage_model->diffusion_to_vae_latents(latents);
1701+
vae_latents = first_stage_model->diffusion_to_vae_latents(_latents);
17081702
decoded = first_stage_model->decode(n_threads, vae_latents, vae_tiling_params, is_video, circular_x, circular_y, true);
17091703
}
17101704
if (decoded.empty()) {

0 commit comments

Comments
 (0)