Skip to content

Commit ce1867c

Browse files
committed
add PiD support
1 parent a8dabf2 commit ce1867c

16 files changed

Lines changed: 1158 additions & 27 deletions

examples/common/common.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,22 @@ const char* const modes_str[] = {
3535
"metadata",
3636
};
3737

38+
// Map a --vae-format option value onto its sd_vae_format_t constant.
//
// The comparison is exact and case-sensitive; the accepted spellings are
// "auto", "flux", "sd3" and "flux2". Any other input yields
// SD_VAE_FORMAT_COUNT, which is not a real format: SDContextParams::validate
// checks for it to reject an unknown value.
static sd_vae_format_t str_to_vae_format(const std::string& value) {
    static const struct {
        const char* name;
        sd_vae_format_t format;
    } known_formats[] = {
        {"auto", SD_VAE_FORMAT_AUTO},
        {"flux", SD_VAE_FORMAT_FLUX},
        {"sd3", SD_VAE_FORMAT_SD3},
        {"flux2", SD_VAE_FORMAT_FLUX2},
    };

    for (const auto& entry : known_formats) {
        if (value == entry.name) {
            return entry.format;
        }
    }

    // No spelling matched: report "invalid" through the sentinel.
    return SD_VAE_FORMAT_COUNT;
}
53+
3854
#if defined(_WIN32)
3955
static std::string utf16_to_utf8(const std::wstring& wstr) {
4056
if (wstr.empty())
@@ -348,6 +364,10 @@ ArgOptions SDContextParams::get_options() {
348364
"--vae",
349365
"path to standalone vae model",
350366
&vae_path},
367+
{"",
368+
"--vae-format",
369+
"VAE latent format override: auto, flux, sd3, or flux2 (default: auto)",
370+
&vae_format},
351371
{"",
352372
"--audio-vae",
353373
"path to standalone LTX audio vae model",
@@ -639,6 +659,11 @@ bool SDContextParams::validate(SDMode mode) {
639659
}
640660
}
641661

662+
if (str_to_vae_format(vae_format) == SD_VAE_FORMAT_COUNT) {
663+
LOG_ERROR("error: vae_format must be 'auto', 'flux', 'sd3', or 'flux2'");
664+
return false;
665+
}
666+
642667
return true;
643668
}
644669

@@ -679,6 +704,7 @@ std::string SDContextParams::to_string() const {
679704
<< " high_noise_diffusion_model_path: \"" << high_noise_diffusion_model_path << "\",\n"
680705
<< " embeddings_connectors_path: \"" << embeddings_connectors_path << "\",\n"
681706
<< " vae_path: \"" << vae_path << "\",\n"
707+
<< " vae_format: \"" << vae_format << "\",\n"
682708
<< " audio_vae_path: \"" << audio_vae_path << "\",\n"
683709
<< " taesd_path: \"" << taesd_path << "\",\n"
684710
<< " esrgan_path: \"" << esrgan_path << "\",\n"
@@ -772,6 +798,7 @@ sd_ctx_params_t SDContextParams::to_sd_ctx_params_t(bool vae_decode_only, bool f
772798
chroma_use_t5_mask,
773799
chroma_t5_mask_pad,
774800
qwen_image_zero_cond_t,
801+
str_to_vae_format(vae_format),
775802
max_vram,
776803
backend.c_str(),
777804
params_backend.c_str(),

examples/common/common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ struct SDContextParams {
9494
std::string high_noise_diffusion_model_path;
9595
std::string embeddings_connectors_path;
9696
std::string vae_path;
97+
std::string vae_format = "auto";
9798
std::string audio_vae_path;
9899
std::string taesd_path;
99100
std::string esrgan_path;

include/stable-diffusion.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,14 @@ typedef struct {
168168
const char* path;
169169
} sd_embedding_t;
170170

171+
// VAE latent format selector carried in sd_ctx_params_t::vae_format.
// Kept as a plain enum because it is referred to as `enum sd_vae_format_t`.
enum sd_vae_format_t {
    SD_VAE_FORMAT_AUTO = -1,  // "auto"; NOTE(review): presumably lets the library choose - confirm
    SD_VAE_FORMAT_FLUX = 0,   // "flux"
    SD_VAE_FORMAT_SD3,        // "sd3"
    SD_VAE_FORMAT_FLUX2,      // "flux2"
    SD_VAE_FORMAT_COUNT,      // number of explicit formats; must stay last
};
178+
171179
typedef struct {
172180
const char* model_path;
173181
const char* clip_l_path;
@@ -212,6 +220,7 @@ typedef struct {
212220
bool chroma_use_t5_mask;
213221
int chroma_t5_mask_pad;
214222
bool qwen_image_zero_cond_t;
223+
enum sd_vae_format_t vae_format;
215224
float max_vram; // GiB budget for graph-cut segmented param offload (0 = disabled, -1 = auto free VRAM minus 1 GiB)
216225
const char* backend;
217226
const char* params_backend;

src/conditioner.hpp

Lines changed: 63 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1171,7 +1171,6 @@ struct FluxCLIPEmbedder : public Conditioner {
11711171
return true;
11721172
}
11731173

1174-
11751174
void free_params_buffer() override {
11761175
if (clip_l) {
11771176
clip_l->free_params_buffer();
@@ -1601,8 +1600,8 @@ struct AnimaConditioner : public Conditioner {
16011600

16021601
bool alloc_params_buffer() override {
16031602
if (!llm->alloc_params_buffer()) {
1604-
return false;
1605-
}
1603+
return false;
1604+
}
16061605
return true;
16071606
}
16081607

@@ -1719,13 +1718,17 @@ struct LLMEmbedder : public Conditioner {
17191718
arch = LLM::LLMArch::MINISTRAL_3_3B;
17201719
} else if (sd_version_is_lens(version)) {
17211720
arch = LLM::LLMArch::GPT_OSS_20B;
1721+
} else if (sd_version_is_pid(version)) {
1722+
arch = LLM::LLMArch::GEMMA2_2B;
17221723
} else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE || version == VERSION_FLUX2_KLEIN) {
17231724
arch = LLM::LLMArch::QWEN3;
17241725
}
17251726
if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2 || arch == LLM::LLMArch::MINISTRAL_3_3B) {
17261727
tokenizer = std::make_shared<MistralTokenizer>();
17271728
} else if (arch == LLM::LLMArch::GPT_OSS_20B) {
17281729
tokenizer = std::make_shared<GPTOSSTokenizer>();
1730+
} else if (arch == LLM::LLMArch::GEMMA2_2B) {
1731+
tokenizer = std::make_shared<Gemma2Tokenizer>();
17291732
} else {
17301733
tokenizer = std::make_shared<Qwen2Tokenizer>();
17311734
}
@@ -1743,7 +1746,7 @@ struct LLMEmbedder : public Conditioner {
17431746

17441747
bool alloc_params_buffer() override {
17451748
if (!llm->alloc_params_buffer()) {
1746-
return false;
1749+
return false;
17471750
}
17481751
return true;
17491752
}
@@ -1847,12 +1850,16 @@ struct LLMEmbedder : public Conditioner {
18471850
sd::Tensor<int32_t> input_ids({static_cast<int64_t>(tokens.size())}, tokens);
18481851
sd::Tensor<float> attention_mask;
18491852
if (!mask.empty()) {
1850-
attention_mask = sd::Tensor<float>({static_cast<int64_t>(mask.size()), static_cast<int64_t>(mask.size())});
1853+
attention_mask = sd::Tensor<float>({static_cast<int64_t>(mask.size()), static_cast<int64_t>(mask.size())});
1854+
const float masked_attention_value = -std::numeric_limits<float>::max() / 4.0f;
18511855
for (size_t i1 = 0; i1 < mask.size(); ++i1) {
18521856
for (size_t i0 = 0; i0 < mask.size(); ++i0) {
18531857
float value = 0.0f;
1854-
if (mask[i0] == 0.0f || i0 > i1) {
1855-
value = -INFINITY;
1858+
if (mask[i0] == 0.0f) {
1859+
value += masked_attention_value;
1860+
}
1861+
if (i0 > i1) {
1862+
value += masked_attention_value;
18561863
}
18571864
attention_mask[static_cast<int64_t>(i0 + mask.size() * i1)] = value;
18581865
}
@@ -2126,6 +2133,53 @@ struct LLMEmbedder : public Conditioner {
21262133
prompt_attn_range.second = static_cast<int>(prompt.size());
21272134

21282135
prompt += "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
2136+
} else if (sd_version_is_pid(version)) {
2137+
constexpr int pixeldit_max_length = 300;
2138+
const std::string chi_prompt =
2139+
"Given a user prompt, generate an \"Enhanced prompt\" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:\n"
2140+
"- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.\n"
2141+
"- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n"
2142+
"Here are examples of how to transform or refine prompts:\n"
2143+
"- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.\n"
2144+
"- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\n"
2145+
"Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:\n"
2146+
"User Prompt: ";
2147+
auto chi_tokens = std::get<0>(tokenize(chi_prompt, {0, 0}));
2148+
size_t num_chi_tokens = chi_tokens.size();
2149+
max_length = (int)num_chi_tokens + pixeldit_max_length - 2;
2150+
min_length = max_length;
2151+
2152+
prompt_attn_range.first = static_cast<int>(prompt.size());
2153+
prompt += " " + conditioner_params.text;
2154+
prompt_attn_range.second = static_cast<int>(prompt.size());
2155+
2156+
auto hidden_states = encode_prompt(n_threads,
2157+
prompt,
2158+
prompt_attn_range,
2159+
min_length,
2160+
0,
2161+
image_embeds,
2162+
out_layers,
2163+
0,
2164+
false,
2165+
max_length);
2166+
GGML_ASSERT(!hidden_states.empty());
2167+
2168+
if (hidden_states.shape()[1] > pixeldit_max_length) {
2169+
auto bos = sd::ops::slice(hidden_states, 1, 0, 1);
2170+
auto tail = sd::ops::slice(hidden_states,
2171+
1,
2172+
hidden_states.shape()[1] - (pixeldit_max_length - 1),
2173+
hidden_states.shape()[1]);
2174+
hidden_states = sd::ops::concat(bos, tail, 1);
2175+
}
2176+
2177+
int64_t t1 = ggml_time_ms();
2178+
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
2179+
2180+
SDCondition result;
2181+
result.c_crossattn = std::move(hidden_states);
2182+
return result;
21292183
} else {
21302184
GGML_ABORT("unknown version %d", version);
21312185
}
@@ -2268,10 +2322,10 @@ struct LTXAVEmbedder : public Conditioner {
22682322

22692323
bool alloc_params_buffer() override {
22702324
if (!llm->alloc_params_buffer()) {
2271-
return false;
2325+
return false;
22722326
}
22732327
if (!projector->alloc_params_buffer()) {
2274-
return false;
2328+
return false;
22752329
}
22762330
return true;
22772331
}

src/llm.hpp

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ namespace LLM {
3737
MISTRAL_SMALL_3_2,
3838
MINISTRAL_3_3B,
3939
GEMMA3_12B,
40+
GEMMA2_2B,
4041
GPT_OSS_20B,
4142
ARCH_COUNT,
4243
};
@@ -48,6 +49,7 @@ namespace LLM {
4849
"mistral_small3.2",
4950
"ministral3.3b",
5051
"gemma3_12b",
52+
"gemma2_2b",
5153
"gpt_oss_20b",
5254
};
5355

@@ -900,6 +902,33 @@ namespace LLM {
900902
1.f,
901903
32.f,
902904
1.f);
905+
} else if (arch == LLMArch::GEMMA2_2B) {
906+
q = ggml_rope_ext(ctx->ggml_ctx,
907+
q,
908+
input_pos,
909+
nullptr,
910+
head_dim,
911+
GGML_ROPE_TYPE_NEOX,
912+
8192,
913+
10000.f,
914+
1.f,
915+
0.f,
916+
1.f,
917+
32.f,
918+
1.f);
919+
k = ggml_rope_ext(ctx->ggml_ctx,
920+
k,
921+
input_pos,
922+
nullptr,
923+
head_dim,
924+
GGML_ROPE_TYPE_NEOX,
925+
8192,
926+
10000.f,
927+
1.f,
928+
0.f,
929+
1.f,
930+
32.f,
931+
1.f);
903932
} else if (arch == LLMArch::QWEN3_VL) {
904933
int sections[4] = {24, 20, 20, 0};
905934
q = ggml_rope_multi(ctx->ggml_ctx, q, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_IMROPE, 262144, 5000000.f, 1.f, 0.f, 1.f, 32.f, 1.f);
@@ -957,10 +986,18 @@ namespace LLM {
957986
: arch(params.arch),
958987
sliding_attention(0) {
959988
if (params.arch == LLMArch::GEMMA3_12B) {
960-
post_attention_norm_name = "post_attention_norm";
961-
post_ffw_norm_name = "post_ffw_norm";
989+
post_attention_norm_name = "post_attention_norm"; // attn_post_norm
990+
pre_ffw_norm_name = "post_attention_layernorm"; // ffn_norm
991+
post_ffw_norm_name = "post_ffw_norm"; // ffn_post_norm
992+
} else if (params.arch == LLMArch::GEMMA2_2B) {
993+
post_attention_norm_name = "post_attention_layernorm"; // ffn_norm
994+
pre_ffw_norm_name = "pre_feedforward_layernorm";
995+
post_ffw_norm_name = "post_feedforward_layernorm";
996+
} else if (params.arch == LLMArch::GPT_OSS_20B) {
997+
pre_ffw_norm_name = "post_attention_norm"; // attn_post_norm
998+
} else {
999+
pre_ffw_norm_name = "post_attention_layernorm"; // ffn_norm
9621000
}
963-
pre_ffw_norm_name = params.arch == LLMArch::GPT_OSS_20B ? "post_attention_norm" : "post_attention_layernorm";
9641001

9651002
blocks["self_attn"] = std::make_shared<Attention>(params);
9661003
if (params.arch == LLMArch::GPT_OSS_20B) {
@@ -1447,6 +1484,21 @@ namespace LLM {
14471484
params.rope_thetas = {1000000.f, 10000.f};
14481485
params.rope_scales = {8.f, 1.f};
14491486
params.sliding_attention = {1024, 1024, 1024, 1024, 1024, 0};
1487+
} else if (arch == LLMArch::GEMMA2_2B) {
1488+
params.head_dim = 256;
1489+
params.num_heads = 8;
1490+
params.num_kv_heads = 4;
1491+
params.qkv_bias = false;
1492+
params.qk_norm = false;
1493+
params.rms_norm_eps = 1e-6f;
1494+
params.rms_norm_add = true;
1495+
params.normalize_input = true;
1496+
params.max_position_embeddings = 8192;
1497+
params.mlp_activation = MLPActivation::GELU_TANH;
1498+
params.hidden_size = 2304;
1499+
params.intermediate_size = 9216;
1500+
params.num_layers = 26;
1501+
params.vocab_size = 256000;
14501502
} else if (arch == LLMArch::GPT_OSS_20B) {
14511503
params.head_dim = 64;
14521504
params.num_heads = 64;
@@ -1585,6 +1637,7 @@ namespace LLM {
15851637
params.arch == LLMArch::MINISTRAL_3_3B ||
15861638
params.arch == LLMArch::QWEN3 ||
15871639
params.arch == LLMArch::GEMMA3_12B ||
1640+
params.arch == LLMArch::GEMMA2_2B ||
15881641
params.arch == LLMArch::GPT_OSS_20B) {
15891642
input_pos_vec.resize(n_tokens);
15901643
for (int i = 0; i < n_tokens; ++i) {

src/lora.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ struct LoraModel : public GGMLRunner {
9191
return false;
9292
}
9393

94-
9594
dry_run = false;
9695
model_loader.load_tensors(on_new_tensor_cb, n_threads);
9796

src/ltx_audio_vae.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1069,8 +1069,8 @@ namespace LTXV {
10691069
prefix);
10701070

10711071
if (!ltx_audio_vae->alloc_params_buffer()) {
1072-
LOG_ERROR("ltx audio vae buffer allocation failed");
1073-
return;
1072+
LOG_ERROR("ltx audio vae buffer allocation failed");
1073+
return;
10741074
}
10751075

10761076
std::map<std::string, ggml_tensor*> tensors;

src/model.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,9 @@ SDVersion ModelLoader::get_sd_version() {
432432
tensor_storage.name.find("model.diffusion_model.single_transformer_blocks.") != std::string::npos) {
433433
is_flux = true;
434434
}
435+
if (tensor_storage.name.find("model.diffusion_model.net.lq_proj.latent_proj.0.weight") != std::string::npos) {
436+
return VERSION_PID;
437+
}
435438
if (tensor_storage.name.find("model.diffusion_model.nerf_final_layer_conv.") != std::string::npos) {
436439
return VERSION_CHROMA_RADIANCE;
437440
}

src/model.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ enum SDVersion {
4949
VERSION_ERNIE_IMAGE,
5050
VERSION_LENS,
5151
VERSION_LONGCAT,
52+
VERSION_PID,
5253
VERSION_COUNT,
5354
};
5455

@@ -164,6 +165,13 @@ static inline bool sd_version_is_lens(SDVersion version) {
164165
return false;
165166
}
166167

168+
// Returns true only when `version` is VERSION_PID.
static inline bool sd_version_is_pid(SDVersion version) {
    return version == VERSION_PID;
}
174+
167175
static inline bool sd_version_uses_flux2_vae(SDVersion version) {
168176
if (sd_version_is_flux2(version) || sd_version_is_ernie_image(version) || sd_version_is_lens(version)) {
169177
return true;
@@ -194,7 +202,8 @@ static inline bool sd_version_is_dit(SDVersion version) {
194202
sd_version_is_z_image(version) ||
195203
sd_version_is_ernie_image(version) ||
196204
sd_version_is_lens(version) ||
197-
sd_version_is_longcat(version)) {
205+
sd_version_is_longcat(version) ||
206+
sd_version_is_pid(version)) {
198207
return true;
199208
}
200209
return false;

0 commit comments

Comments
 (0)