Skip to content

Commit 24340c0

Browse files
committed
refactor: simplify diffusion model runner params
1 parent 92dc726 commit 24340c0

14 files changed

Lines changed: 408 additions & 946 deletions

src/anima.hpp

Lines changed: 17 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
#include <vector>
88

99
#include "common_block.hpp"
10+
#include "diffusion_model.hpp"
1011
#include "flux.hpp"
1112
#include "rope.hpp"
1213

@@ -518,7 +519,7 @@ namespace Anima {
518519
}
519520
};
520521

521-
struct AnimaRunner : public GGMLRunner {
522+
struct AnimaRunner : public DiffusionModelRunner {
522523
public:
523524
std::vector<float> image_pe_vec;
524525
std::vector<float> adapter_q_pe_vec;
@@ -529,7 +530,7 @@ namespace Anima {
529530
ggml_backend_t params_backend,
530531
const String2TensorStorage& tensor_storage_map = {},
531532
const std::string prefix = "model.diffusion_model")
532-
: GGMLRunner(backend, params_backend) {
533+
: DiffusionModelRunner(backend, params_backend, prefix) {
533534
int64_t num_layers = 0;
534535
std::string layer_tag = prefix + ".net.blocks.";
535536
for (const auto& kv : tensor_storage_map) {
@@ -559,7 +560,7 @@ namespace Anima {
559560
return "anima";
560561
}
561562

562-
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) {
563+
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string& prefix) override {
563564
net.get_param_tensors(tensors, prefix + ".net");
564565
}
565566

@@ -684,6 +685,19 @@ namespace Anima {
684685
};
685686
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), x.dim());
686687
}
688+
689+
sd::Tensor<float> compute(int n_threads,  // DiffusionParams overload: unpacks the params and forwards to the compute() overload that takes the individual arguments
690+
const DiffusionParams& diffusion_params) override {
691+
GGML_ASSERT(diffusion_params.x != nullptr);  // x and timesteps are dereferenced in the call below, so both must be set
692+
GGML_ASSERT(diffusion_params.timesteps != nullptr);
693+
const auto* extra = diffusion_extra_as<AnimaDiffusionExtra>(diffusion_params);  // NOTE(review): extra is dereferenced below without a null check — confirm diffusion_extra_as cannot return nullptr here
694+
return compute(n_threads,
695+
*diffusion_params.x,
696+
*diffusion_params.timesteps,
697+
tensor_or_empty(diffusion_params.context),  // context, t5_ids and t5_weights are passed through tensor_or_empty rather than asserted like x and timesteps
698+
tensor_or_empty(extra->t5_ids),
699+
tensor_or_empty(extra->t5_weights));
700+
}
687701
};
688702
} // namespace Anima
689703

src/conditioner.hpp

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -102,7 +102,6 @@ struct ConditionerParams {
102102
int clip_skip = -1;
103103
int width = -1;
104104
int height = -1;
105-
int adm_in_channels = -1;
106105
bool zero_out_masked = false;
107106
int num_input_imgs = 0; // for photomaker
108107
const std::vector<sd::Tensor<float>>* ref_images = nullptr; // for qwen image edit
@@ -502,7 +501,6 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
502501
int clip_skip,
503502
int width,
504503
int height,
505-
int adm_in_channels = -1,
506504
bool zero_out_masked = false) {
507505
int64_t t0 = ggml_time_ms();
508506
sd::Tensor<float> hidden_states; // [n_token, hidden_size] or [n_token, hidden_size + hidden_size2]
@@ -588,7 +586,8 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
588586

589587
sd::Tensor<float> vec;
590588
if (sd_version_is_sdxl(version)) {
591-
int out_dim = 256;
589+
int out_dim = 256;
590+
int adm_in_channels = 2816;
592591
GGML_ASSERT(!pooled.empty());
593592
vec = sd::Tensor<float>({adm_in_channels});
594593
vec.fill_(0.0f);
@@ -647,7 +646,6 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
647646
conditioner_params.clip_skip,
648647
conditioner_params.width,
649648
conditioner_params.height,
650-
conditioner_params.adm_in_channels,
651649
conditioner_params.zero_out_masked);
652650
return std::make_tuple(cond, clsm);
653651
}
@@ -674,7 +672,6 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
674672
conditioner_params.clip_skip,
675673
conditioner_params.width,
676674
conditioner_params.height,
677-
conditioner_params.adm_in_channels,
678675
conditioner_params.zero_out_masked);
679676
}
680677
};

0 commit comments

Comments
 (0)