77#include < vector>
88
99#include " common_block.hpp"
10+ #include " diffusion_model.hpp"
1011#include " flux.hpp"
1112#include " rope.hpp"
1213
@@ -518,7 +519,7 @@ namespace Anima {
518519 }
519520 };
520521
521- struct AnimaRunner : public GGMLRunner {
522+ struct AnimaRunner : public DiffusionModelRunner {
522523 public:
523524 std::vector<float > image_pe_vec;
524525 std::vector<float > adapter_q_pe_vec;
@@ -529,7 +530,7 @@ namespace Anima {
529530 ggml_backend_t params_backend,
530531 const String2TensorStorage& tensor_storage_map = {},
531532 const std::string prefix = " model.diffusion_model" )
532- : GGMLRunner (backend, params_backend) {
533+ : DiffusionModelRunner (backend, params_backend, prefix ) {
533534 int64_t num_layers = 0 ;
534535 std::string layer_tag = prefix + " .net.blocks." ;
535536 for (const auto & kv : tensor_storage_map) {
@@ -559,7 +560,7 @@ namespace Anima {
559560 return " anima" ;
560561 }
561562
562- void get_param_tensors (std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) {
563+ void get_param_tensors (std::map<std::string, ggml_tensor*>& tensors, const std::string& prefix) override {
563564 net.get_param_tensors (tensors, prefix + " .net" );
564565 }
565566
@@ -684,6 +685,19 @@ namespace Anima {
684685 };
685686 return restore_trailing_singleton_dims (GGMLRunner::compute<float >(get_graph, n_threads, false ), x.dim ());
686687 }
688+
689+ sd::Tensor<float > compute (int n_threads,
690+ const DiffusionParams& diffusion_params) override {
691+ GGML_ASSERT (diffusion_params.x != nullptr );
692+ GGML_ASSERT (diffusion_params.timesteps != nullptr );
693+ const auto * extra = diffusion_extra_as<AnimaDiffusionExtra>(diffusion_params);
694+ return compute (n_threads,
695+ *diffusion_params.x ,
696+ *diffusion_params.timesteps ,
697+ tensor_or_empty (diffusion_params.context ),
698+ tensor_or_empty (extra->t5_ids ),
699+ tensor_or_empty (extra->t5_weights ));
700+ }
687701 };
688702} // namespace Anima
689703
0 commit comments