Skip to content

Commit 9838264

Browse files
authored
refactor: simplify ControlNet output caching (#1655)
1 parent 17d70b9 commit 9838264

2 files changed

Lines changed: 32 additions & 51 deletions

File tree

src/core/ggml_extend.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2007,6 +2007,10 @@ struct GGMLRunner {
20072007
}
20082008

20092009
bool copy_cache_tensors_to_cache_buffer(const std::unordered_set<std::string>* cache_keep_names = nullptr) {
2010+
if (cache_tensor_map.empty() && cache_keep_names == nullptr) {
2011+
return true;
2012+
}
2013+
20102014
ggml_context* old_cache_ctx = cache_ctx;
20112015
ggml_backend_buffer_t old_cache_buffer = cache_buffer;
20122016
cache_ctx = nullptr;

src/model/diffusion/control.hpp

Lines changed: 28 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -312,16 +312,17 @@ struct ControlNet : public GGMLRunner {
312312
ControlNetBlock control_net;
313313
std::string weight_prefix;
314314

315-
ggml_backend_buffer_t control_buffer = nullptr;
316-
ggml_context* control_ctx = nullptr;
317315
std::vector<ggml_tensor*> control_outputs_ggml;
318316
ggml_tensor* guided_hint_output_ggml = nullptr;
319317
std::vector<sd::Tensor<float>> controls;
320-
sd::Tensor<float> guided_hint;
321318
bool guided_hint_cached = false;
322319
std::shared_ptr<ModelManager> owned_model_manager;
323320
ggml_backend_t params_backend = nullptr;
324321

322+
static const char* guided_hint_cache_name() {
323+
return "controlnet.guided_hint";
324+
}
325+
325326
ControlNet(ggml_backend_t backend,
326327
ggml_backend_t params_backend_,
327328
const String2TensorStorage& tensor_storage_map = {},
@@ -336,44 +337,12 @@ struct ControlNet : public GGMLRunner {
336337
free_control_ctx();
337338
}
338339

339-
void alloc_control_ctx(std::vector<ggml_tensor*> outs) {
340-
ggml_init_params params;
341-
params.mem_size = static_cast<size_t>(outs.size() * ggml_tensor_overhead()) + 1024 * 1024;
342-
params.mem_buffer = nullptr;
343-
params.no_alloc = true;
344-
control_ctx = ggml_init(params);
345-
346-
control_outputs_ggml.resize(outs.size() - 1);
347-
348-
size_t control_buffer_size = 0;
349-
350-
guided_hint_output_ggml = ggml_dup_tensor(control_ctx, outs[0]);
351-
control_buffer_size += ggml_nbytes(guided_hint_output_ggml);
352-
353-
for (int i = 0; i < outs.size() - 1; i++) {
354-
control_outputs_ggml[i] = ggml_dup_tensor(control_ctx, outs[i + 1]);
355-
control_buffer_size += ggml_nbytes(control_outputs_ggml[i]);
356-
}
357-
358-
control_buffer = ggml_backend_alloc_ctx_tensors(control_ctx, runtime_backend);
359-
360-
LOG_DEBUG("control buffer size %.2fMB", control_buffer_size * 1.f / 1024.f / 1024.f);
361-
}
362-
363340
void free_control_ctx() {
364-
if (control_buffer != nullptr) {
365-
ggml_backend_buffer_free(control_buffer);
366-
control_buffer = nullptr;
367-
}
368-
if (control_ctx != nullptr) {
369-
ggml_free(control_ctx);
370-
control_ctx = nullptr;
371-
}
372341
guided_hint_output_ggml = nullptr;
373342
guided_hint_cached = false;
374-
guided_hint = {};
375343
control_outputs_ggml.clear();
376344
controls.clear();
345+
free_cache_ctx_and_buffer();
377346
}
378347

379348
std::string get_desc() override {
@@ -397,11 +366,17 @@ struct ControlNet : public GGMLRunner {
397366
ggml_tensor* context = make_optional_input(context_tensor);
398367
ggml_tensor* y = make_optional_input(y_tensor);
399368

369+
guided_hint_output_ggml = nullptr;
370+
control_outputs_ggml.clear();
371+
400372
ggml_tensor* guided_hint_input = nullptr;
401-
if (guided_hint_cached && !guided_hint.empty()) {
402-
guided_hint_input = make_input(guided_hint);
403-
hint = nullptr;
404-
} else {
373+
if (guided_hint_cached) {
374+
guided_hint_input = get_cache_tensor_by_name(guided_hint_cache_name());
375+
if (guided_hint_input == nullptr) {
376+
guided_hint_cached = false;
377+
}
378+
}
379+
if (guided_hint_input == nullptr) {
405380
hint = make_input(hint_tensor);
406381
}
407382

@@ -415,13 +390,19 @@ struct ControlNet : public GGMLRunner {
415390
context,
416391
y);
417392

418-
if (control_ctx == nullptr) {
419-
alloc_control_ctx(outs);
393+
if (guided_hint_input == nullptr && !outs.empty()) {
394+
guided_hint_output_ggml = outs[0];
395+
ggml_set_output(guided_hint_output_ggml);
396+
cache(guided_hint_cache_name(), guided_hint_output_ggml);
397+
ggml_build_forward_expand(gf, guided_hint_output_ggml);
420398
}
421399

422-
ggml_build_forward_expand(gf, ggml_cpy(compute_ctx, outs[0], guided_hint_output_ggml));
423-
for (int i = 0; i < outs.size() - 1; i++) {
424-
ggml_build_forward_expand(gf, ggml_cpy(compute_ctx, outs[i + 1], control_outputs_ggml[i]));
400+
control_outputs_ggml.reserve(outs.size() > 0 ? outs.size() - 1 : 0);
401+
for (size_t i = 1; i < outs.size(); i++) {
402+
ggml_tensor* control_output = outs[i];
403+
ggml_set_output(control_output);
404+
ggml_build_forward_expand(gf, control_output);
405+
control_outputs_ggml.push_back(control_output);
425406
}
426407

427408
return gf;
@@ -441,23 +422,19 @@ struct ControlNet : public GGMLRunner {
441422
return build_graph(x, hint, timesteps, context, y);
442423
};
443424

444-
auto compute_result = GGMLRunner::compute<float>(get_graph, n_threads, false, false, false);
425+
auto compute_result = GGMLRunner::compute<float>(get_graph, n_threads, false, false, false, true);
445426
if (!compute_result.has_value()) {
446427
return std::nullopt;
447428
}
448429

449-
if (guided_hint_output_ggml != nullptr) {
450-
guided_hint = restore_trailing_singleton_dims(sd::make_sd_tensor_from_ggml<float>(guided_hint_output_ggml),
451-
4);
452-
}
430+
guided_hint_cached = get_cache_tensor_by_name(guided_hint_cache_name()) != nullptr;
453431
controls.clear();
454432
controls.reserve(control_outputs_ggml.size());
455433
for (ggml_tensor* control : control_outputs_ggml) {
456434
auto control_host = restore_trailing_singleton_dims(sd::make_sd_tensor_from_ggml<float>(control), 4);
457435
GGML_ASSERT(!control_host.empty());
458436
controls.push_back(std::move(control_host));
459437
}
460-
guided_hint_cached = true;
461438
return controls;
462439
}
463440

0 commit comments

Comments
 (0)