Skip to content

Commit 3e937be

Browse files
committed
feat: support for canceling the ongoing generation
1 parent aaa8a51 commit 3e937be

2 files changed

Lines changed: 58 additions & 0 deletions

File tree

include/stable-diffusion.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,15 @@ SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params);
372372
SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params);
373373
SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params);
374374

375+
enum sd_cancel_mode_t
376+
{
377+
SD_CANCEL_ALL,
378+
SD_CANCEL_NEW_LATENTS,
379+
SD_CANCEL_RESET
380+
};
381+
382+
SD_API void sd_cancel_generation(sd_ctx_t* sd_ctx, enum sd_cancel_mode_t mode);
383+
375384
SD_API void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params);
376385
SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params, int* num_frames_out);
377386

src/stable-diffusion.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
#include "latent-preview.h"
2424
#include "name_conversion.h"
2525

26+
#include <atomic>
27+
2628
const char* model_version_to_str[] = {
2729
"SD 1.x",
2830
"SD 1.x Inpaint",
@@ -99,6 +101,9 @@ void suppress_pp(int step, int steps, float time, void* data) {
99101

100102
/*=============================================== StableDiffusionGGML ================================================*/
101103

104+
static_assert(std::atomic<sd_cancel_mode_t>::is_always_lock_free,
105+
"sd_cancel_mode_t must be lock-free");
106+
102107
class StableDiffusionGGML {
103108
public:
104109
ggml_backend_t backend = nullptr; // general backend
@@ -149,6 +154,8 @@ class StableDiffusionGGML {
149154

150155
std::shared_ptr<Denoiser> denoiser = std::make_shared<CompVisDenoiser>();
151156

157+
std::atomic<sd_cancel_mode_t> cancellation_flag;
158+
152159
StableDiffusionGGML() = default;
153160

154161
~StableDiffusionGGML() {
@@ -164,6 +171,18 @@ class StableDiffusionGGML {
164171
ggml_backend_free(backend);
165172
}
166173

174+
void set_cancel_flag(enum sd_cancel_mode_t flag) {
175+
cancellation_flag.store(flag, std::memory_order_release);
176+
}
177+
178+
void reset_cancel_flag() {
179+
set_cancel_flag(SD_CANCEL_RESET);
180+
}
181+
182+
enum sd_cancel_mode_t get_cancel_flag() {
183+
return cancellation_flag.load(std::memory_order_acquire);
184+
}
185+
167186
void init_backend() {
168187
#ifdef SD_USE_CUDA
169188
LOG_DEBUG("Using CUDA backend");
@@ -1869,6 +1888,12 @@ class StableDiffusionGGML {
18691888
}
18701889

18711890
auto denoise = [&](ggml_tensor* input, float sigma, int step) -> ggml_tensor* {
1891+
enum sd_cancel_mode_t cancel_flag = get_cancel_flag();
1892+
if (cancel_flag != SD_CANCEL_RESET) {
1893+
LOG_DEBUG("cancelling latent decodings");
1894+
return nullptr;
1895+
}
1896+
18721897
auto sd_preview_cb = sd_get_preview_callback();
18731898
auto sd_preview_cb_data = sd_get_preview_callback_data();
18741899
auto sd_preview_mode = sd_get_preview_mode();
@@ -3423,6 +3448,12 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
34233448
img_cond = SDCondition(uncond.c_crossattn, uncond.c_vector, cond.c_concat);
34243449
}
34253450
for (int b = 0; b < batch_count; b++) {
3451+
3452+
if (sd_ctx->sd->get_cancel_flag() != SD_CANCEL_RESET) {
3453+
LOG_ERROR("cancelling generation");
3454+
break;
3455+
}
3456+
34263457
int64_t sampling_start = ggml_time_ms();
34273458
int64_t cur_seed = seed + b;
34283459
LOG_INFO("generating image: %i/%i - seed %" PRId64, b + 1, batch_count, cur_seed);
@@ -3484,6 +3515,12 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
34843515
LOG_INFO("decoding %zu latents", final_latents.size());
34853516
std::vector<struct ggml_tensor*> decoded_images; // collect decoded images
34863517
for (size_t i = 0; i < final_latents.size(); i++) {
3518+
3519+
if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) {
3520+
LOG_ERROR("cancelling latent decodings");
3521+
break;
3522+
}
3523+
34873524
t1 = ggml_time_ms();
34883525
struct ggml_tensor* img = sd_ctx->sd->decode_first_stage(work_ctx, final_latents[i] /* x_0 */);
34893526
// print_ggml_tensor(img);
@@ -3520,6 +3557,16 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
35203557
return result_images;
35213558
}
35223559

3560+
void sd_cancel_generation(sd_ctx_t* sd_ctx, enum sd_cancel_mode_t mode)
3561+
{
3562+
if (sd_ctx && sd_ctx->sd) {
3563+
if (mode < SD_CANCEL_ALL || mode > SD_CANCEL_RESET) {
3564+
mode = SD_CANCEL_ALL;
3565+
}
3566+
sd_ctx->sd->set_cancel_flag(mode);
3567+
}
3568+
}
3569+
35233570
sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params) {
35243571
sd_ctx->sd->vae_tiling_params = sd_img_gen_params->vae_tiling_params;
35253572
int width = sd_img_gen_params->width;
@@ -3542,6 +3589,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
35423589
return nullptr;
35433590
}
35443591

3592+
sd_ctx->sd->reset_cancel_flag();
3593+
35453594
struct ggml_init_params params;
35463595
params.mem_size = static_cast<size_t>(1024 * 1024) * 1024; // 1G
35473596
params.mem_buffer = nullptr;

0 commit comments

Comments
 (0)