2323#include " latent-preview.h"
2424#include " name_conversion.h"
2525
26+ #include < atomic>
27+
2628const char * model_version_to_str[] = {
2729 " SD 1.x" ,
2830 " SD 1.x Inpaint" ,
@@ -99,6 +101,9 @@ void suppress_pp(int step, int steps, float time, void* data) {
99101
100102/* =============================================== StableDiffusionGGML ================================================*/
101103
104+ static_assert (std::atomic<sd_cancel_mode_t >::is_always_lock_free,
105+ " sd_cancel_mode_t must be lock-free" );
106+
102107class StableDiffusionGGML {
103108public:
104109 ggml_backend_t backend = nullptr ; // general backend
@@ -149,6 +154,8 @@ class StableDiffusionGGML {
149154
150155 std::shared_ptr<Denoiser> denoiser = std::make_shared<CompVisDenoiser>();
151156
157+ std::atomic<sd_cancel_mode_t > cancellation_flag;
158+
152159 StableDiffusionGGML () = default ;
153160
154161 ~StableDiffusionGGML () {
@@ -164,6 +171,18 @@ class StableDiffusionGGML {
164171 ggml_backend_free (backend);
165172 }
166173
174+ void set_cancel_flag (enum sd_cancel_mode_t flag) {
175+ cancellation_flag.store (flag, std::memory_order_release);
176+ }
177+
178+ void reset_cancel_flag () {
179+ set_cancel_flag (SD_CANCEL_RESET );
180+ }
181+
182+ enum sd_cancel_mode_t get_cancel_flag () {
183+ return cancellation_flag.load (std::memory_order_acquire);
184+ }
185+
167186 void init_backend () {
168187#ifdef SD_USE_CUDA
169188 LOG_DEBUG (" Using CUDA backend" );
@@ -1869,6 +1888,12 @@ class StableDiffusionGGML {
18691888 }
18701889
18711890 auto denoise = [&](ggml_tensor* input, float sigma, int step) -> ggml_tensor* {
1891+ enum sd_cancel_mode_t cancel_flag = get_cancel_flag ();
1892+ if (cancel_flag != SD_CANCEL_RESET ) {
1893+ LOG_DEBUG (" cancelling latent decodings" );
1894+ return nullptr ;
1895+ }
1896+
18721897 auto sd_preview_cb = sd_get_preview_callback ();
18731898 auto sd_preview_cb_data = sd_get_preview_callback_data ();
18741899 auto sd_preview_mode = sd_get_preview_mode ();
@@ -3423,6 +3448,12 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
34233448 img_cond = SDCondition (uncond.c_crossattn , uncond.c_vector , cond.c_concat );
34243449 }
34253450 for (int b = 0 ; b < batch_count; b++) {
3451+
3452+ if (sd_ctx->sd ->get_cancel_flag () != SD_CANCEL_RESET ) {
3453+ LOG_ERROR (" cancelling generation" );
3454+ break ;
3455+ }
3456+
34263457 int64_t sampling_start = ggml_time_ms ();
34273458 int64_t cur_seed = seed + b;
34283459 LOG_INFO (" generating image: %i/%i - seed %" PRId64, b + 1 , batch_count, cur_seed);
@@ -3484,6 +3515,12 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
34843515 LOG_INFO (" decoding %zu latents" , final_latents.size ());
34853516 std::vector<struct ggml_tensor *> decoded_images; // collect decoded images
34863517 for (size_t i = 0 ; i < final_latents.size (); i++) {
3518+
3519+ if (sd_ctx->sd ->get_cancel_flag () == SD_CANCEL_ALL ) {
3520+ LOG_ERROR (" cancelling latent decodings" );
3521+ break ;
3522+ }
3523+
34873524 t1 = ggml_time_ms ();
34883525 struct ggml_tensor * img = sd_ctx->sd ->decode_first_stage (work_ctx, final_latents[i] /* x_0 */ );
34893526 // print_ggml_tensor(img);
@@ -3520,6 +3557,16 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
35203557 return result_images;
35213558}
35223559
3560+ void sd_cancel_generation (sd_ctx_t * sd_ctx, enum sd_cancel_mode_t mode)
3561+ {
3562+ if (sd_ctx && sd_ctx->sd ) {
3563+ if (mode < SD_CANCEL_ALL || mode > SD_CANCEL_RESET ) {
3564+ mode = SD_CANCEL_ALL ;
3565+ }
3566+ sd_ctx->sd ->set_cancel_flag (mode);
3567+ }
3568+ }
3569+
35233570sd_image_t * generate_image (sd_ctx_t * sd_ctx, const sd_img_gen_params_t * sd_img_gen_params) {
35243571 sd_ctx->sd ->vae_tiling_params = sd_img_gen_params->vae_tiling_params ;
35253572 int width = sd_img_gen_params->width ;
@@ -3542,6 +3589,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
35423589 return nullptr ;
35433590 }
35443591
3592+ sd_ctx->sd ->reset_cancel_flag ();
3593+
35453594 struct ggml_init_params params;
35463595 params.mem_size = static_cast <size_t >(1024 * 1024 ) * 1024 ; // 1G
35473596 params.mem_buffer = nullptr ;
0 commit comments