Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
570 changes: 459 additions & 111 deletions flux.hpp

Large diffs are not rendered by default.

32 changes: 13 additions & 19 deletions model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1766,7 +1766,6 @@ bool ModelLoader::model_is_unet() {

SDVersion ModelLoader::get_sd_version() {
TensorStorage token_embedding_weight, input_block_weight;
bool input_block_checked = false;

bool has_multiple_encoders = false;
bool is_unet = false;
Expand All @@ -1778,12 +1777,12 @@ SDVersion ModelLoader::get_sd_version() {
bool has_img_emb = false;

for (auto& tensor_storage : tensor_storages) {
if (!(is_xl || is_flux)) {
if (!(is_xl)) {
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
is_flux = true;
if (input_block_checked) {
break;
}
}
if (tensor_storage.name.find("model.diffusion_model.nerf_final_layer_conv.") != std::string::npos) {
return VERSION_CHROMA_RADIANCE;
}
if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) {
return VERSION_SD3;
Expand All @@ -1800,22 +1799,19 @@ SDVersion ModelLoader::get_sd_version() {
if (tensor_storage.name.find("model.diffusion_model.img_emb") != std::string::npos) {
has_img_emb = true;
}
if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos || tensor_storage.name.find("unet.down_blocks.") != std::string::npos) {
if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos ||
tensor_storage.name.find("unet.down_blocks.") != std::string::npos) {
is_unet = true;
if (has_multiple_encoders) {
is_xl = true;
if (input_block_checked) {
break;
}
}
}
if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || tensor_storage.name.find("cond_stage_model.1") != std::string::npos || tensor_storage.name.find("te.1") != std::string::npos) {
if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos ||
tensor_storage.name.find("cond_stage_model.1") != std::string::npos ||
tensor_storage.name.find("te.1") != std::string::npos) {
has_multiple_encoders = true;
if (is_unet) {
is_xl = true;
if (input_block_checked) {
break;
}
}
}
if (tensor_storage.name.find("model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor") != std::string::npos) {
Expand All @@ -1831,12 +1827,10 @@ SDVersion ModelLoader::get_sd_version() {
token_embedding_weight = tensor_storage;
// break;
}
if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight" || tensor_storage.name == "unet.conv_in.weight") {
input_block_weight = tensor_storage;
input_block_checked = true;
if (is_xl || is_flux) {
break;
}
if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" ||
tensor_storage.name == "model.diffusion_model.img_in.weight" ||
tensor_storage.name == "unet.conv_in.weight") {
input_block_weight = tensor_storage;
}
}
if (is_wan) {
Expand Down
7 changes: 6 additions & 1 deletion model.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ enum SDVersion {
VERSION_FLUX_FILL,
VERSION_FLUX_CONTROLS,
VERSION_FLEX_2,
VERSION_CHROMA_RADIANCE,
VERSION_WAN2,
VERSION_WAN2_2_I2V,
VERSION_WAN2_2_TI2V,
Expand Down Expand Up @@ -70,7 +71,11 @@ static inline bool sd_version_is_sd3(SDVersion version) {
}

static inline bool sd_version_is_flux(SDVersion version) {
if (version == VERSION_FLUX || version == VERSION_FLUX_FILL || version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2) {
if (version == VERSION_FLUX ||
version == VERSION_FLUX_FILL ||
version == VERSION_FLUX_CONTROLS ||
version == VERSION_FLEX_2 ||
version == VERSION_CHROMA_RADIANCE) {
return true;
}
return false;
Expand Down
2 changes: 1 addition & 1 deletion qwen_image.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,7 @@ namespace Qwen {

static void load_from_file_and_test(const std::string& file_path) {
// cuda q8: pass
// cuda q8 fa: nan
// cuda q8 fa: pass
// ggml_backend_t backend = ggml_backend_cuda_init(0);
ggml_backend_t backend = ggml_backend_cpu_init();
ggml_type model_data_type = GGML_TYPE_Q8_0;
Expand Down
Loading
Loading