Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 1 addition & 17 deletions examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -623,8 +623,6 @@ int main(int argc, const char* argv[]) {
}
}

bool vae_decode_only = true;

auto load_image_and_update_size = [&](const std::string& path,
SDImageOwner& image,
bool resize_image = true,
Expand All @@ -646,21 +644,18 @@ int main(int argc, const char* argv[]) {
};

if (gen_params.init_image_path.size() > 0) {
vae_decode_only = false;
if (!load_image_and_update_size(gen_params.init_image_path, gen_params.init_image)) {
return 1;
}
}

if (gen_params.end_image_path.size() > 0) {
vae_decode_only = false;
if (!load_image_and_update_size(gen_params.end_image_path, gen_params.end_image)) {
return 1;
}
}

if (gen_params.ref_image_paths.size() > 0) {
vae_decode_only = false;
gen_params.ref_images.clear();
for (auto& path : gen_params.ref_image_paths) {
SDImageOwner ref_image({0, 0, 3, nullptr});
Expand Down Expand Up @@ -735,18 +730,7 @@ int main(int argc, const char* argv[]) {
}
}

if (cli_params.mode == VID_GEN) {
vae_decode_only = false;
}

if (gen_params.hires_enabled &&
(gen_params.resolved_hires_upscaler == SD_HIRES_UPSCALER_MODEL ||
gen_params.resolved_hires_upscaler == SD_HIRES_UPSCALER_LANCZOS ||
gen_params.resolved_hires_upscaler == SD_HIRES_UPSCALER_NEAREST)) {
vae_decode_only = false;
}

sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(vae_decode_only, cli_params.taesd_preview);
sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(cli_params.taesd_preview);

SDImageVec results;
int num_results = 0;
Expand Down
3 changes: 1 addition & 2 deletions examples/common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,7 @@ std::string SDContextParams::to_string() const {
return oss.str();
}

sd_ctx_params_t SDContextParams::to_sd_ctx_params_t(bool vae_decode_only, bool taesd_preview) {
sd_ctx_params_t SDContextParams::to_sd_ctx_params_t(bool taesd_preview) {
embedding_vec.clear();
embedding_vec.reserve(embedding_map.size());
for (const auto& kv : embedding_map) {
Expand Down Expand Up @@ -787,7 +787,6 @@ sd_ctx_params_t SDContextParams::to_sd_ctx_params_t(bool vae_decode_only, bool t
static_cast<uint32_t>(embedding_vec.size()),
photo_maker_path.c_str(),
tensor_type_rules.c_str(),
vae_decode_only,
n_threads,
wtype,
rng_type,
Expand Down
2 changes: 1 addition & 1 deletion examples/common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ struct SDContextParams {
bool validate(SDMode mode);
bool resolve_and_validate(SDMode mode);
std::string to_string() const;
sd_ctx_params_t to_sd_ctx_params_t(bool vae_decode_only, bool taesd_preview);
sd_ctx_params_t to_sd_ctx_params_t(bool taesd_preview);
};

struct SDGenerationParams {
Expand Down
2 changes: 1 addition & 1 deletion examples/server/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ int main(int argc, const char** argv) {
LOG_DEBUG("%s", ctx_params.to_string().c_str());
LOG_DEBUG("%s", default_gen_params.to_string().c_str());

sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(false, false);
sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(false);
SDCtxPtr sd_ctx(new_sd_ctx(&sd_ctx_params));

if (sd_ctx == nullptr) {
Expand Down
1 change: 0 additions & 1 deletion include/stable-diffusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,6 @@ typedef struct {
uint32_t embedding_count;
const char* photo_maker_path;
const char* tensor_type_rules;
bool vae_decode_only;
int n_threads;
enum sd_type_t wtype;
enum rng_type_t rng_type;
Expand Down
2 changes: 1 addition & 1 deletion src/model/vae/ltx_vae.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1426,7 +1426,7 @@ struct LTXVideoVAE : public VAE {
const sd::Tensor<float>& z,
bool decode_graph) override {
if (!decode_graph && decode_only) {
LOG_ERROR("LTX video VAE encode requires encoder weights; create the context with vae_decode_only=false");
LOG_ERROR("LTX video VAE encode requires encoder weights");
return {};
}
sd::Tensor<float> input = z;
Expand Down
66 changes: 11 additions & 55 deletions src/stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,6 @@ class StableDiffusionGGML {
SDBackendManager backend_manager;

SDVersion version;
bool vae_decode_only = false;
bool external_vae_is_invalid = false;

bool circular_x = false;
Expand Down Expand Up @@ -318,7 +317,6 @@ class StableDiffusionGGML {

bool init(const sd_ctx_params_t* sd_ctx_params) {
n_threads = sd_ctx_params->n_threads;
vae_decode_only = sd_ctx_params->vae_decode_only;
offload_params_to_cpu = sd_ctx_params->offload_params_to_cpu;
enable_mmap = sd_ctx_params->enable_mmap;
max_vram = sd_ctx_params->max_vram;
Expand Down Expand Up @@ -560,10 +558,6 @@ class StableDiffusionGGML {
size_t control_net_params_mem_size = 0;
size_t extension_params_mem_size = 0;

if (sd_version_is_control(version)) {
// Might need vae encode for control cond
vae_decode_only = false;
}
bool tae_preview_only = sd_ctx_params->tae_preview_only;
if (version == VERSION_SDXS_512_DS || version == VERSION_SDXS_09) {
tae_preview_only = false;
Expand Down Expand Up @@ -591,7 +585,6 @@ class StableDiffusionGGML {
"model.diffusion_model",
model_manager);
} else if (sd_version_is_pid(version)) {
vae_decode_only = false;
cond_stage_model = std::make_shared<LLMEmbedder>(backend_for(SDBackendModule::TE),
tensor_storage_map,
version,
Expand Down Expand Up @@ -706,15 +699,11 @@ class StableDiffusionGGML {
}
}
} else if (sd_version_is_qwen_image(version)) {
bool enable_vision = false;
if (!vae_decode_only) {
enable_vision = true;
}
cond_stage_model = std::make_shared<LLMEmbedder>(backend_for(SDBackendModule::TE),
tensor_storage_map,
version,
"",
enable_vision,
true,
model_manager);
diffusion_model = std::make_shared<Qwen::QwenImageRunner>(backend_for(SDBackendModule::DIFFUSION),
tensor_storage_map,
Expand All @@ -723,15 +712,11 @@ class StableDiffusionGGML {
sd_ctx_params->qwen_image_zero_cond_t,
model_manager);
} else if (sd_version_is_longcat(version)) {
bool enable_vision = false;
if (!vae_decode_only) {
enable_vision = true;
}
cond_stage_model = std::make_shared<LLMEmbedder>(backend_for(SDBackendModule::TE),
tensor_storage_map,
version,
"",
enable_vision,
true,
model_manager);
diffusion_model = std::make_shared<Flux::FluxRunner>(backend_for(SDBackendModule::DIFFUSION),
tensor_storage_map,
Expand Down Expand Up @@ -827,10 +812,6 @@ class StableDiffusionGGML {
return false;
}

if (sd_version_is_unet_edit(version)) {
vae_decode_only = false;
}

if (high_noise_diffusion_model) {
high_noise_diffusion_model->set_max_graph_vram_bytes(max_graph_vram_bytes);
high_noise_diffusion_model->set_stream_layers_enabled(stream_layers);
Expand All @@ -846,23 +827,23 @@ class StableDiffusionGGML {
return false;
}

auto create_tae = [&]() -> std::shared_ptr<VAE> {
auto create_tae = [&](bool decode_only) -> std::shared_ptr<VAE> {
if (sd_version_is_wan(version) ||
sd_version_is_qwen_image(version) ||
sd_version_is_anima(version) ||
sd_version_is_ltxav(version)) {
return std::make_shared<TinyVideoAutoEncoder>(backend_for(SDBackendModule::VAE),
tensor_storage_map,
"decoder",
vae_decode_only,
decode_only,
version,
model_manager);

} else {
auto model = std::make_shared<TinyImageAutoEncoder>(backend_for(SDBackendModule::VAE),
tensor_storage_map,
"decoder.layers",
vae_decode_only,
decode_only,
version,
model_manager);
return model;
Expand All @@ -884,7 +865,7 @@ class StableDiffusionGGML {
return std::make_shared<LTXVideoVAE>(backend_for(SDBackendModule::VAE),
tensor_storage_map,
"first_stage_model",
vae_decode_only,
false,
version,
model_manager);
} else if (sd_version_is_wan(version) ||
Expand All @@ -893,14 +874,14 @@ class StableDiffusionGGML {
return std::make_shared<WAN::WanVAERunner>(backend_for(SDBackendModule::VAE),
tensor_storage_map,
"first_stage_model",
vae_decode_only,
false,
version,
model_manager);
} else {
auto model = std::make_shared<AutoEncoderKL>(backend_for(SDBackendModule::VAE),
tensor_storage_map,
"first_stage_model",
vae_decode_only,
false,
false,
vae_version,
model_manager);
Expand Down Expand Up @@ -930,7 +911,7 @@ class StableDiffusionGGML {
}
} else if (use_tae && !tae_preview_only) {
LOG_INFO("using TAE for encoding / decoding");
first_stage_model = create_tae();
first_stage_model = create_tae(false);
first_stage_model->set_max_graph_vram_bytes(max_graph_vram_bytes);
if (!register_runner_params("VAE",
first_stage_model,
Expand All @@ -950,7 +931,7 @@ class StableDiffusionGGML {
}
if (use_tae && tae_preview_only) {
LOG_INFO("using TAE for preview");
preview_vae = create_tae();
preview_vae = create_tae(true);
preview_vae->set_max_graph_vram_bytes(max_graph_vram_bytes);
if (!register_runner_params("preview VAE",
preview_vae,
Expand Down Expand Up @@ -1080,13 +1061,6 @@ class StableDiffusionGGML {
ignore_tensors.insert("model.diffusion_model.__32x32__");
ignore_tensors.insert("model.diffusion_model.__index_timestep_zero__");

if (vae_decode_only) {
ignore_tensors.insert("first_stage_model.encoder");
ignore_tensors.insert("first_stage_model.conv1");
ignore_tensors.insert("first_stage_model.quant");
ignore_tensors.insert("tae.encoder");
ignore_tensors.insert("text_encoders.llm.visual.");
}
if (audio_vae_model) {
ignore_tensors.insert("audio_vae.encoder");
}
Expand Down Expand Up @@ -2642,7 +2616,6 @@ void sd_hires_params_init(sd_hires_params_t* hires_params) {

void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
*sd_ctx_params = {};
sd_ctx_params->vae_decode_only = true;
sd_ctx_params->n_threads = sd_get_num_physical_cores();
sd_ctx_params->wtype = SD_TYPE_COUNT;
sd_ctx_params->rng_type = CUDA_RNG;
Expand Down Expand Up @@ -2691,7 +2664,6 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
"control_net_path: %s\n"
"photo_maker_path: %s\n"
"tensor_type_rules: %s\n"
"vae_decode_only: %s\n"
"n_threads: %d\n"
"wtype: %s\n"
"rng_type: %s\n"
Expand Down Expand Up @@ -2730,7 +2702,6 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
SAFE_STR(sd_ctx_params->control_net_path),
SAFE_STR(sd_ctx_params->photo_maker_path),
SAFE_STR(sd_ctx_params->tensor_type_rules),
BOOL_STR(sd_ctx_params->vae_decode_only),
sd_ctx_params->n_threads,
sd_type_name(sd_ctx_params->wtype),
sd_rng_type_name(sd_ctx_params->rng_type),
Expand Down Expand Up @@ -3913,7 +3884,7 @@ static std::optional<ImageGenerationLatents> prepare_image_generation_latents(sd
}
}

if (!control_image_tensor.empty() && !sd_ctx->sd->vae_decode_only) {
if (!control_image_tensor.empty()) {
control_latent = sd_ctx->sd->encode_first_stage(control_image_tensor);
if (control_latent.empty()) {
LOG_ERROR("failed to encode control image");
Expand Down Expand Up @@ -4255,11 +4226,6 @@ static sd::Tensor<float> upscale_hires_latent(sd_ctx_t* sd_ctx,
} else if (request.hires.upscaler == SD_HIRES_UPSCALER_MODEL ||
request.hires.upscaler == SD_HIRES_UPSCALER_LANCZOS ||
request.hires.upscaler == SD_HIRES_UPSCALER_NEAREST) {
if (sd_ctx->sd->vae_decode_only) {
LOG_ERROR("hires %s upscaler requires VAE encoder weights; create the context with vae_decode_only=false",
sd_hires_upscaler_name(request.hires.upscaler));
return {};
}
if (request.hires.upscaler == SD_HIRES_UPSCALER_MODEL && upscaler == nullptr) {
LOG_ERROR("hires model upscaler context is null");
return {};
Expand Down Expand Up @@ -4607,11 +4573,6 @@ static std::optional<ImageGenerationLatents> prepare_video_generation_latents(sd
}

if (!start_image.empty() || !end_image.empty()) {
if (sd_ctx->sd->vae_decode_only) {
LOG_ERROR("LTXAV image conditioning requires VAE encoder weights; create the context with vae_decode_only=false");
return std::nullopt;
}

if (!start_image.empty() && !end_image.empty()) {
LOG_INFO("FLF2V");
} else if (!start_image.empty()) {
Expand Down Expand Up @@ -5076,11 +5037,6 @@ static bool apply_ltxv_refine_image_conditioning(sd_ctx_t* sd_ctx,
sd_vid_gen_params->end_image.data == nullptr) {
return true;
}
if (sd_ctx->sd->vae_decode_only) {
LOG_ERROR("LTXV refine image conditioning requires VAE encoder weights; create the context with vae_decode_only=false");
return false;
}

constexpr float conditioning_strength = 1.f;
int latent_channels = sd_ctx->sd->get_latent_channel();
sd::Tensor<float> video_latent = *latent;
Expand Down
Loading