Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions include/stable-diffusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,17 @@ SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params);
SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params);
SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params);

enum sd_cancel_mode_t {
// Stop the current generation as soon as possible.
SD_CANCEL_ALL,
// Finish the current image sample, then skip additional batch latents and return completed images.
SD_CANCEL_NEW_LATENTS,
// Clear a pending cancellation request.
SD_CANCEL_RESET
};

SD_API void sd_cancel_generation(sd_ctx_t* sd_ctx, enum sd_cancel_mode_t mode);

SD_API void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params);
SD_API bool generate_video(sd_ctx_t* sd_ctx,
const sd_vid_gen_params_t* sd_vid_gen_params,
Expand Down
134 changes: 130 additions & 4 deletions src/stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@
const char* sd_vae_format_name(enum sd_vae_format_t format);
static SDVersion sd_vae_format_to_version(enum sd_vae_format_t format, SDVersion fallback);

#include <atomic>

const char* model_version_to_str[] = {
"SD 1.x",
"SD 1.x Inpaint",
Expand Down Expand Up @@ -159,6 +161,9 @@ static float get_cache_reuse_threshold(const sd_cache_params_t& params) {

/*=============================================== StableDiffusionGGML ================================================*/

static_assert(std::atomic<sd_cancel_mode_t>::is_always_lock_free,
"sd_cancel_mode_t must be lock-free");

class StableDiffusionGGML {
public:
SDBackendManager backend_manager;
Expand Down Expand Up @@ -222,6 +227,20 @@ class StableDiffusionGGML {
return module_backend;
}

std::atomic<sd_cancel_mode_t> cancellation_flag = SD_CANCEL_RESET;

void set_cancel_flag(enum sd_cancel_mode_t flag) {
cancellation_flag.store(flag, std::memory_order_release);
}

void reset_cancel_flag() {
set_cancel_flag(SD_CANCEL_RESET);
}

enum sd_cancel_mode_t get_cancel_flag() {
return cancellation_flag.load(std::memory_order_acquire);
}

size_t max_graph_vram_bytes_for_module(SDBackendModule module) {
return max_vram_assignment.bytes_for_backend(backend_for(module));
}
Expand Down Expand Up @@ -1941,6 +1960,11 @@ class StableDiffusionGGML {
SamplePreviewContext preview = prepare_sample_preview_context();

auto denoise = [&](const sd::Tensor<float>& x, float sigma, int step) -> sd::guidance::GuiderOutput {
if (get_cancel_flag() == SD_CANCEL_ALL) {
LOG_DEBUG("cancelling generation");
return {};
}

if (step == 1 || step == -1) {
pretty_progress(0, (int)steps, 0);
last_progress_us = ggml_time_us();
Expand Down Expand Up @@ -2963,6 +2987,15 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) {
free(sd_ctx);
}

SD_API void sd_cancel_generation(sd_ctx_t* sd_ctx, enum sd_cancel_mode_t mode) {
if (sd_ctx && sd_ctx->sd) {
if (mode < SD_CANCEL_ALL || mode > SD_CANCEL_RESET) {
mode = SD_CANCEL_ALL;
}
sd_ctx->sd->set_cancel_flag(mode);
}
}

static sd_audio_t* waveform_to_sd_audio(const StableDiffusionGGML* sd,
const sd::Tensor<float>& waveform) {
if (sd == nullptr || waveform.empty()) {
Expand Down Expand Up @@ -4150,15 +4183,29 @@ static std::optional<ImageGenerationEmbeds> prepare_image_generation_embeds(sd_c
static sd_image_t* decode_image_outputs(sd_ctx_t* sd_ctx,
const GenerationRequest& request,
const std::vector<sd::Tensor<float>>& final_latents) {
if (final_latents.size() != static_cast<size_t>(request.batch_count)) {
LOG_ERROR("expected %d latents, got %zu", request.batch_count, final_latents.size());
if (final_latents.empty()) {
LOG_ERROR("no latent images to decode");
return nullptr;
}
if (final_latents.size() > static_cast<size_t>(request.batch_count)) {
LOG_ERROR("expected at most %d latents, got %zu", request.batch_count, final_latents.size());
return nullptr;
}
LOG_INFO("decoding %zu latents", final_latents.size());
if (final_latents.size() < static_cast<size_t>(request.batch_count)) {
LOG_INFO("decoding %zu/%d latents", final_latents.size(), request.batch_count);
} else {
LOG_INFO("decoding %zu latents", final_latents.size());
}
std::vector<sd::Tensor<float>> decoded_images;
int64_t t0 = ggml_time_ms();
int64_t t0 = ggml_time_ms();
bool cancelled = false;

for (size_t i = 0; i < final_latents.size(); i++) {
if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) {
LOG_ERROR("cancelling latent decodings");
cancelled = true;
break;
}
int64_t t1 = ggml_time_ms();
sd::Tensor<float> image = sd_ctx->sd->decode_first_stage(final_latents[i]);
if (image.empty()) {
Expand All @@ -4172,6 +4219,10 @@ static sd_image_t* decode_image_outputs(sd_ctx_t* sd_ctx,

int64_t t4 = ggml_time_ms();
LOG_INFO("decode_first_stage completed, taking %.2fs", (t4 - t0) * 1.0f / 1000);
if (decoded_images.empty()) {
LOG_ERROR(cancelled ? "cancelled before any latent images were decoded" : "no decoded images");
return nullptr;
}

sd_image_t* result_images = (sd_image_t*)calloc(request.batch_count, sizeof(sd_image_t));
if (result_images == nullptr) {
Expand All @@ -4190,6 +4241,11 @@ static sd::Tensor<float> upscale_hires_latent(sd_ctx_t* sd_ctx,
const sd::Tensor<float>& latent,
const GenerationRequest& request,
UpscalerGGML* upscaler) {
if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) {
LOG_ERROR("cancelling hires latent upscale");
return {};
}

auto get_hires_latent_target_shape = [&]() {
std::vector<int64_t> target_shape = latent.shape();
if (target_shape.size() < 2) {
Expand Down Expand Up @@ -4262,6 +4318,10 @@ static sd::Tensor<float> upscale_hires_latent(sd_ctx_t* sd_ctx,
sd_hires_upscaler_name(request.hires.upscaler));
return {};
}
if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) {
LOG_ERROR("cancelling hires image upscale");
return {};
}

sd::Tensor<float> upscaled_tensor;
if (request.hires.upscaler == SD_HIRES_UPSCALER_MODEL) {
Expand Down Expand Up @@ -4298,6 +4358,10 @@ static sd::Tensor<float> upscale_hires_latent(sd_ctx_t* sd_ctx,
upscaled_tensor = sd::ops::clamp(upscaled_tensor, 0.0f, 1.0f);
}

if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) {
LOG_ERROR("cancelling hires latent encode");
return {};
}
sd::Tensor<float> upscaled_latent = sd_ctx->sd->encode_first_stage(upscaled_tensor);
if (upscaled_latent.empty()) {
LOG_ERROR("encode_first_stage failed after hires %s upscale",
Expand Down Expand Up @@ -4362,6 +4426,8 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s
return nullptr;
}

sd_ctx->sd->reset_cancel_flag();

int64_t t0 = ggml_time_ms();
sd_ctx->sd->vae_tiling_params = sd_img_gen_params->vae_tiling_params;
GenerationRequest request(sd_ctx, sd_img_gen_params);
Expand Down Expand Up @@ -4397,6 +4463,18 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s
std::vector<sd::Tensor<float>> final_latents;
int64_t denoise_start = ggml_time_ms();
for (int b = 0; b < request.batch_count; b++) {
sd_cancel_mode_t cancel = sd_ctx->sd->get_cancel_flag();
if (cancel == SD_CANCEL_ALL) {
LOG_ERROR("cancelling generation");
return nullptr;
}
if (cancel == SD_CANCEL_NEW_LATENTS) {
LOG_INFO("cancelling new latent generation, returning %zu/%d completed latents",
final_latents.size(),
request.batch_count);
break;
}

int64_t sampling_start = ggml_time_ms();
int64_t cur_seed = request.seed + b;
LOG_INFO("generating image: %i/%i - seed %" PRId64, b + 1, request.batch_count, cur_seed);
Expand Down Expand Up @@ -4446,12 +4524,24 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s
LOG_INFO("generating %zu latent images completed, taking %.2fs",
final_latents.size(),
(denoise_end - denoise_start) * 1.0f / 1000);
if (final_latents.empty()) {
LOG_ERROR("no latent images generated");
return nullptr;
}

if (request.hires.enabled && request.hires.target_width > 0) {
if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) {
LOG_ERROR("cancelling generation before hires fix");
return nullptr;
}
LOG_INFO("hires fix: upscaling to %dx%d", request.hires.target_width, request.hires.target_height);

std::unique_ptr<UpscalerGGML> hires_upscaler;
if (request.hires.upscaler == SD_HIRES_UPSCALER_MODEL) {
if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) {
LOG_ERROR("cancelling generation before hires model load");
return nullptr;
}
LOG_INFO("hires fix: loading model upscaler from '%s'", request.hires.model_path);
hires_upscaler = std::make_unique<UpscalerGGML>(sd_ctx->sd->n_threads,
false,
Expand Down Expand Up @@ -4485,6 +4575,10 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s
std::vector<sd::Tensor<float>> hires_final_latents;
int64_t hires_denoise_start = ggml_time_ms();
for (int b = 0; b < (int)final_latents.size(); b++) {
if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) {
LOG_ERROR("cancelling generation during hires fix");
return nullptr;
}
int64_t cur_seed = request.seed + b;
sd_ctx->sd->rng->manual_seed(cur_seed);
sd_ctx->sd->sampler_rng->manual_seed(cur_seed);
Expand Down Expand Up @@ -4915,6 +5009,10 @@ static sd_image_t* decode_video_outputs(sd_ctx_t* sd_ctx,
LOG_ERROR("no latent video to decode");
return nullptr;
}
if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) {
LOG_ERROR("cancelling video decode");
return nullptr;
}
sd::Tensor<float> video_latent = final_latent;
if (sd_version_is_ltxav(sd_ctx->sd->version) &&
video_latent.shape()[3] > sd_ctx->sd->get_latent_channel()) {
Expand Down Expand Up @@ -5160,6 +5258,9 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx,
if (audio_out != nullptr) {
*audio_out = nullptr;
}

sd_ctx->sd->reset_cancel_flag();

if (num_frames_out != nullptr) {
*num_frames_out = 0;
}
Expand Down Expand Up @@ -5221,6 +5322,10 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx,
sd::Tensor<float> noise = sd::Tensor<float>::randn_like(x_t, sd_ctx->sd->rng);

if (plan.high_noise_sample_steps > 0) {
if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) {
LOG_ERROR("cancelling generation before high-noise sampling");
return false;
}
LOG_DEBUG("sample(high noise) %dx%dx%d", W, H, T);

int64_t sampling_start = ggml_time_ms();
Expand Down Expand Up @@ -5263,6 +5368,10 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx,
LOG_INFO("sampling(high noise) completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
}

if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) {
LOG_ERROR("cancelling generation before sampling");
return false;
}
LOG_DEBUG("sample %dx%dx%d", W, H, T);
int64_t sampling_start = ggml_time_ms();
sd::Tensor<float> final_latent = sd_ctx->sd->sample(sd_ctx->sd->diffusion_model,
Expand Down Expand Up @@ -5299,6 +5408,10 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx,
LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);

if (latent_upscale_enabled) {
if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) {
LOG_ERROR("cancelling generation before latent upscale");
return false;
}
int64_t upscale_start = ggml_time_ms();
sd::Tensor<float> upscaled_latent = upscale_ltx_spatial_video_latent(sd_ctx,
request.hires.model_path,
Expand Down Expand Up @@ -5358,6 +5471,10 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx,
}
sd::Tensor<float> hires_denoise_mask;
sd::Tensor<float> hires_video_positions;
if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) {
LOG_ERROR("cancelling generation before latent upscale refine");
return false;
}
if (!apply_ltxv_refine_image_conditioning(sd_ctx,
sd_vid_gen_params,
hires_request,
Expand Down Expand Up @@ -5437,6 +5554,10 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx,
if (sd_version_is_ltxav(sd_ctx->sd->version) &&
latents.audio_length > 0 &&
sd_ctx->sd->audio_vae_model != nullptr) {
if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) {
LOG_ERROR("cancelling generation before audio decode");
return false;
}
int64_t audio_latent_decode_start = ggml_time_ms();

auto audio_latent = unpack_ltxav_audio_latent(final_latent,
Expand Down Expand Up @@ -5469,6 +5590,11 @@ SD_API bool generate_video(sd_ctx_t* sd_ctx,
final_latent = sd::ops::slice(final_latent, 2, latents.ref_image_num, final_latent.shape()[2]);
}

if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) {
LOG_ERROR("cancelling generation before video decode");
free_sd_audio(generated_audio);
return false;
}
auto result = decode_video_outputs(sd_ctx, latent_upscale_enabled ? hires_request : request, final_latent, num_frames_out);
if (result == nullptr) {
free_sd_audio(generated_audio);
Expand Down
Loading