@@ -312,16 +312,17 @@ struct ControlNet : public GGMLRunner {
// Data members of ControlNet as shown by this diff hunk. Lines marked '-' are
// deleted by the change; unmarked lines are unchanged context.
312312 ControlNetBlock control_net;
313313 std::string weight_prefix;
314314
// Removed: backend buffer and ggml context that the (also removed)
// alloc_control_ctx() created and free_control_ctx() released.
315- ggml_backend_buffer_t control_buffer = nullptr ;
316- ggml_context* control_ctx = nullptr ;
// Graph tensors for the control outputs; after this change they are filled
// from outs[1..] while the graph is built and read back after compute.
317315 std::vector<ggml_tensor*> control_outputs_ggml;
// Graph tensor for the guided hint; after this change it is set to outs[0]
// only when no cached guided hint is used, otherwise it stays nullptr.
318316 ggml_tensor* guided_hint_output_ggml = nullptr ;
// Host-side copies of the control outputs, rebuilt on every compute call and
// returned to the caller.
319317 std::vector<sd::Tensor<float >> controls;
// Removed: host-side copy of the guided hint; the change looks the hint up
// through get_cache_tensor_by_name() instead.
320- sd::Tensor<float > guided_hint;
// True when a guided-hint tensor is expected to be available from the cache;
// reset by free_control_ctx() and refreshed after each successful compute.
321318 bool guided_hint_cached = false ;
322319 std::shared_ptr<ModelManager> owned_model_manager;
323320 ggml_backend_t params_backend = nullptr ;
324321
// Returns the key under which the guided-hint tensor is stored with cache()
// and looked up with get_cache_tensor_by_name(). Keeping the literal in one
// place guarantees that the writer and all readers use the same name.
// NOTE(review): the literal as rendered here begins with a space; this may be
// an artifact of how the diff was captured -- confirm against the real file.
322+ static const char * guided_hint_cache_name () {
323+ return " controlnet.guided_hint" ;
324+ }
325+
325326 ControlNet (ggml_backend_t backend,
326327 ggml_backend_t params_backend_,
327328 const String2TensorStorage& tensor_storage_map = {},
@@ -336,44 +337,12 @@ struct ControlNet : public GGMLRunner {
336337 free_control_ctx ();
337338 }
338339
// (Deleted by this diff.) Creates control_ctx and, inside it, one duplicate
// tensor per entry of `outs`: outs[0] becomes guided_hint_output_ggml and
// outs[1..] become control_outputs_ggml. The duplicates are then given
// storage on runtime_backend and the summed byte size is logged.
// NOTE(review): `outs` is taken by value, so the vector is copied per call.
// NOTE(review): outs[0] is read unconditionally and `outs.size () - 1` is an
// unsigned subtraction, so an empty `outs` would wrap around; the replacement
// code in this diff guards the empty case explicitly.
339- void alloc_control_ctx (std::vector<ggml_tensor*> outs) {
340- ggml_init_params params;
// One tensor-overhead slot per output plus 1024 * 1024 extra bytes.
341- params.mem_size = static_cast <size_t >(outs.size () * ggml_tensor_overhead ()) + 1024 * 1024 ;
342- params.mem_buffer = nullptr ;
// Tensor data is not allocated here; it is allocated below via
// ggml_backend_alloc_ctx_tensors().
343- params.no_alloc = true ;
// NOTE(review): the result of ggml_init() is used without a null check.
344- control_ctx = ggml_init (params);
345-
346- control_outputs_ggml.resize (outs.size () - 1 );
347-
// Running total of ggml_nbytes() over all duplicated tensors (logging only).
348- size_t control_buffer_size = 0 ;
349-
350- guided_hint_output_ggml = ggml_dup_tensor (control_ctx, outs[0 ]);
351- control_buffer_size += ggml_nbytes (guided_hint_output_ggml);
352-
// NOTE(review): `int i` is compared against an unsigned size expression.
353- for (int i = 0 ; i < outs.size () - 1 ; i++) {
354- control_outputs_ggml[i] = ggml_dup_tensor (control_ctx, outs[i + 1 ]);
355- control_buffer_size += ggml_nbytes (control_outputs_ggml[i]);
356- }
357-
358- control_buffer = ggml_backend_alloc_ctx_tensors (control_ctx, runtime_backend);
359-
360- LOG_DEBUG (" control buffer size %.2fMB" , control_buffer_size * 1 .f / 1024 .f / 1024 .f );
361- }
362-
// Resets all control-related state held by this object: clears the graph
// tensor pointers, the cached-hint flag and the host-side control tensors.
// The '-' lines are the old cleanup of the dedicated buffer/context and the
// host-side guided hint; the '+' line replaces them with a single call to
// free_cache_ctx_and_buffer(), which is defined outside the visible diff
// (presumably on the GGMLRunner base -- confirm).
363340 void free_control_ctx () {
364- if (control_buffer != nullptr ) {
365- ggml_backend_buffer_free (control_buffer);
366- control_buffer = nullptr ;
367- }
368- if (control_ctx != nullptr ) {
369- ggml_free (control_ctx);
370- control_ctx = nullptr ;
371- }
// These pointers are only cleared, never freed, in this function.
372341 guided_hint_output_ggml = nullptr ;
// Forces the next graph build to take the hint from the input again.
373342 guided_hint_cached = false ;
374- guided_hint = {};
375343 control_outputs_ggml.clear ();
376344 controls.clear ();
345+ free_cache_ctx_and_buffer ();
377346 }
378347
379348 std::string get_desc () override {
@@ -397,11 +366,17 @@ struct ControlNet : public GGMLRunner {
397366 ggml_tensor* context = make_optional_input (context_tensor);
398367 ggml_tensor* y = make_optional_input (y_tensor);
399368
369+ guided_hint_output_ggml = nullptr ;
370+ control_outputs_ggml.clear ();
371+
400372 ggml_tensor* guided_hint_input = nullptr ;
401- if (guided_hint_cached && !guided_hint.empty ()) {
402- guided_hint_input = make_input (guided_hint);
403- hint = nullptr ;
404- } else {
373+ if (guided_hint_cached) {
374+ guided_hint_input = get_cache_tensor_by_name (guided_hint_cache_name ());
375+ if (guided_hint_input == nullptr ) {
376+ guided_hint_cached = false ;
377+ }
378+ }
379+ if (guided_hint_input == nullptr ) {
405380 hint = make_input (hint_tensor);
406381 }
407382
@@ -415,13 +390,19 @@ struct ControlNet : public GGMLRunner {
415390 context,
416391 y);
417392
418- if (control_ctx == nullptr ) {
419- alloc_control_ctx (outs);
393+ if (guided_hint_input == nullptr && !outs.empty ()) {
394+ guided_hint_output_ggml = outs[0 ];
395+ ggml_set_output (guided_hint_output_ggml);
396+ cache (guided_hint_cache_name (), guided_hint_output_ggml);
397+ ggml_build_forward_expand (gf, guided_hint_output_ggml);
420398 }
421399
422- ggml_build_forward_expand (gf, ggml_cpy (compute_ctx, outs[0 ], guided_hint_output_ggml));
423- for (int i = 0 ; i < outs.size () - 1 ; i++) {
424- ggml_build_forward_expand (gf, ggml_cpy (compute_ctx, outs[i + 1 ], control_outputs_ggml[i]));
400+ control_outputs_ggml.reserve (outs.size () > 0 ? outs.size () - 1 : 0 );
401+ for (size_t i = 1 ; i < outs.size (); i++) {
402+ ggml_tensor* control_output = outs[i];
403+ ggml_set_output (control_output);
404+ ggml_build_forward_expand (gf, control_output);
405+ control_outputs_ggml.push_back (control_output);
425406 }
426407
427408 return gf;
@@ -441,23 +422,19 @@ struct ControlNet : public GGMLRunner {
441422 return build_graph (x, hint, timesteps, context, y);
442423 };
443424
444- auto compute_result = GGMLRunner::compute<float >(get_graph, n_threads, false , false , false );
425+ auto compute_result = GGMLRunner::compute<float >(get_graph, n_threads, false , false , false , true );
445426 if (!compute_result.has_value ()) {
446427 return std::nullopt ;
447428 }
448429
449- if (guided_hint_output_ggml != nullptr ) {
450- guided_hint = restore_trailing_singleton_dims (sd::make_sd_tensor_from_ggml<float >(guided_hint_output_ggml),
451- 4 );
452- }
430+ guided_hint_cached = get_cache_tensor_by_name (guided_hint_cache_name ()) != nullptr ;
453431 controls.clear ();
454432 controls.reserve (control_outputs_ggml.size ());
455433 for (ggml_tensor* control : control_outputs_ggml) {
456434 auto control_host = restore_trailing_singleton_dims (sd::make_sd_tensor_from_ggml<float >(control), 4 );
457435 GGML_ASSERT (!control_host.empty ());
458436 controls.push_back (std::move (control_host));
459437 }
460- guided_hint_cached = true ;
461438 return controls;
462439 }
463440
0 commit comments