Skip to content

Commit d5d134e

Browse files
author
Mark Caldwell
committed
feat: automatic VAE-tiling fallback when an untiled decode exceeds the backend buffer limit
VAE decode can hard-fail on integrated / low-VRAM GPUs because the untiled compute buffer exceeds the backend's maximum single-buffer allocation (e.g. Vulkan's suballocation limit) even when total memory is plentiful. sd.cpp already supports tiling that keeps each compute buffer small, but it had to be requested up front with --vae-tiling, so users hit a hard failure one flag away from the working path. Make the fallback automatic and on by default: - sd_tiling_params_t gains a bool auto_tile (appended, so the C ABI stays compatible). In AUTO (the default: --vae-tiling off, auto_tile on) VAE::decode tries the untiled decode and, if its compute buffer can't be allocated, frees it and retries once with tiling. - --vae-tiling stays the original boolean flag (force tiling on); --no-vae-tiling-fallback turns the auto fallback off (hard-fail like before). - GGMLRunner gets an opt-in probe (set_probe_compute_buffer_fits) so AUTO can decline a too-large untiled decode before the backend emits its raw allocation error. On Vulkan it checks each op against the device's real per-buffer limit via ggml_backend_supports_op (the reported max buffer size, not the smaller suballocation block); other backends compare the planned compute buffer against ggml_backend_buft_get_max_size. The reactive output-empty -> tile path still backstops a genuine runtime OOM. - extra_tiling_args gains a max_buffer_size=<bytes> key: in AUTO the fallback also tiles when the planned untiled compute buffer would exceed it, letting a user cap VAE VRAM on any backend.
1 parent bb90bfa commit d5d134e

7 files changed

Lines changed: 148 additions & 9 deletions

File tree

examples/common/common.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -899,7 +899,7 @@ ArgOptions SDGenerationParams::get_options() {
899899
&extra_sample_args},
900900
{"",
901901
"--extra-tiling-args",
902-
"extra VAE tiling args, key=value list. LTX video VAE supports temporal_tile_frames (default: 4), temporal_tile_overlap (default: 1)",
902+
"extra VAE tiling args, key=value list. max_buffer_size (bytes) forces the auto fallback to tile when an untiled VAE compute buffer would exceed it. LTX video VAE supports temporal_tile_frames (default: 4), temporal_tile_overlap (default: 1)",
903903
&extra_tiling_args},
904904
};
905905

@@ -1084,6 +1084,12 @@ ArgOptions SDGenerationParams::get_options() {
10841084
"process vae in tiles to reduce memory usage",
10851085
true,
10861086
&vae_tiling_params.enabled},
1087+
{"",
1088+
"--no-vae-tiling-fallback",
1089+
"disable the automatic fallback to VAE tiling when an untiled decode would exceed the "
1090+
"backend's max buffer size (fail instead of tiling)",
1091+
false,
1092+
&vae_tiling_params.auto_tile},
10871093
{"",
10881094
"--temporal-tiling",
10891095
"enable temporal tiling for LTX video VAE decode",
@@ -1828,6 +1834,9 @@ bool SDGenerationParams::from_json_str(
18281834
if (tiling_json.contains("enabled") && tiling_json["enabled"].is_boolean()) {
18291835
vae_tiling_params.enabled = tiling_json["enabled"];
18301836
}
1837+
if (tiling_json.contains("auto_tile") && tiling_json["auto_tile"].is_boolean()) {
1838+
vae_tiling_params.auto_tile = tiling_json["auto_tile"];
1839+
}
18311840
if (tiling_json.contains("temporal_tiling") && tiling_json["temporal_tiling"].is_boolean()) {
18321841
vae_tiling_params.temporal_tiling = tiling_json["temporal_tiling"];
18331842
}
@@ -2641,10 +2650,12 @@ std::string build_sdcpp_image_metadata_json(const SDContextParams& ctx_params,
26412650
}
26422651

26432652
if (gen_params.vae_tiling_params.enabled ||
2653+
!gen_params.vae_tiling_params.auto_tile ||
26442654
gen_params.vae_tiling_params.temporal_tiling ||
26452655
!gen_params.extra_tiling_args.empty()) {
26462656
root["vae_tiling"] = {
26472657
{"enabled", gen_params.vae_tiling_params.enabled},
2658+
{"auto_tile", gen_params.vae_tiling_params.auto_tile},
26482659
{"temporal_tiling", gen_params.vae_tiling_params.temporal_tiling},
26492660
{"tile_size_x", gen_params.vae_tiling_params.tile_size_x},
26502661
{"tile_size_y", gen_params.vae_tiling_params.tile_size_y},

examples/common/common.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ struct SDGenerationParams {
227227
int video_frames = 1;
228228
int fps = 16;
229229
float vace_strength = 1.f;
230-
sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f, nullptr};
230+
sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f, nullptr, true}; // auto_tile=true (AUTO)
231231
std::string extra_tiling_args;
232232

233233
std::string pm_id_images_dir;

examples/server/api.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,7 @@ Shared default fields used by both `img_gen` and `vid_gen`:
518518
| `output_format` | `string` |
519519
| `output_compression` | `integer` |
520520

521-
`vae_tiling_params.extra_tiling_args` accepts a key=value list. For LTX video VAE temporal tiling, `temporal_tile_frames` defaults to `4` and `temporal_tile_overlap` defaults to `1`.
521+
`vae_tiling_params.extra_tiling_args` accepts a key=value list. `max_buffer_size` (bytes) forces the automatic tiling fallback when an untiled VAE compute buffer would exceed it. For LTX video VAE temporal tiling, `temporal_tile_frames` defaults to `4` and `temporal_tile_overlap` defaults to `1`.
522522

523523
`img_gen`-specific default fields:
524524

include/stable-diffusion.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,14 +153,15 @@ enum lora_apply_mode_t {
153153
};
154154

155155
typedef struct {
156-
bool enabled;
156+
bool enabled; // true => always tile (ON)
157157
bool temporal_tiling;
158158
int tile_size_x;
159159
int tile_size_y;
160160
float target_overlap;
161161
float rel_size_x;
162162
float rel_size_y;
163163
const char* extra_tiling_args;
164+
bool auto_tile; // AUTO (default): tile only when an untiled VAE decode would exceed the backend's max buffer size
164165
} sd_tiling_params_t;
165166

166167
typedef struct {

src/core/ggml_extend.hpp

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1705,11 +1705,21 @@ struct GGMLRunner {
17051705

17061706
ggml_context* compute_ctx = nullptr;
17071707
ggml_gallocr* compute_allocr = nullptr;
1708+
// Set when alloc_compute_buffer() deliberately defers to tiling (probe found the
1709+
// untiled buffer exceeds the backend max); lets callers skip the failure error.
1710+
bool compute_buffer_deferred_to_tiling = false;
17081711

17091712
size_t max_graph_vram_bytes = 0;
17101713
bool stream_layers_enabled = false;
17111714
size_t observed_max_effective_budget_ = 0;
17121715

1716+
// When set, alloc_compute_buffer measures the planned compute buffer (no alloc)
1717+
// and bails if it exceeds the backend max, so VAE AUTO can fall back to tiling.
1718+
bool probe_compute_buffer_fits_ = false;
1719+
// Optional user cap (bytes): also fall back to tiling if the planned compute
1720+
// buffer would exceed this, regardless of the backend limit. 0 = no cap.
1721+
size_t probe_max_bytes_ = 0;
1722+
17131723
std::shared_ptr<WeightAdapter> weight_adapter = nullptr;
17141724
std::weak_ptr<RunnerWeightManager> weight_manager;
17151725
std::unordered_set<const ggml_tensor*> kept_compute_param_tensor_set;
@@ -1978,10 +1988,66 @@ struct GGMLRunner {
19781988
}
19791989

19801990
bool alloc_compute_buffer(ggml_cgraph* gf) {
1991+
compute_buffer_deferred_to_tiling = false;
19811992
if (compute_allocr != nullptr) {
19821993
return true;
19831994
}
1984-
compute_allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(runtime_backend));
1995+
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(runtime_backend);
1996+
1997+
if (probe_compute_buffer_fits_) {
1998+
// Probe whether an untiled decode fits before allocating; if not, defer to
1999+
// tiling here instead of letting the real reserve below fail with a raw
2000+
// backend error. A genuine runtime OOM still surfaces below as a backstop.
2001+
if (probe_max_bytes_ > 0) {
2002+
// User-requested cap (extra_tiling_args max_buffer_size): tile when the
2003+
// planned untiled buffer would exceed it, on any backend.
2004+
ggml_gallocr* probe = ggml_gallocr_new(buft);
2005+
size_t sizes[1] = {0};
2006+
ggml_gallocr_reserve_n_size(probe, gf, nullptr, nullptr, sizes);
2007+
ggml_gallocr_free(probe);
2008+
if (sizes[0] > probe_max_bytes_) {
2009+
LOG_DEBUG("%s: untiled compute buffer %.2f MB exceeds requested max_buffer_size %.2f MB; deferring to tiling",
2010+
get_desc().c_str(),
2011+
sizes[0] / 1024.0 / 1024.0,
2012+
probe_max_bytes_ / 1024.0 / 1024.0);
2013+
compute_buffer_deferred_to_tiling = true;
2014+
return false;
2015+
}
2016+
}
2017+
if (sd_backend_is(runtime_backend, "Vulkan")) {
2018+
// supports_op rejects any op larger than the device's real max buffer
2019+
// size, which is the true per-buffer limit -- unlike buft_get_max_size,
2020+
// which on Vulkan reports only the ~1 GB suballocation block.
2021+
for (int i = 0; i < ggml_graph_n_nodes(gf); ++i) {
2022+
ggml_tensor* op = ggml_graph_node(gf, i);
2023+
if (!ggml_backend_supports_op(runtime_backend, op)) {
2024+
LOG_DEBUG("%s: untiled compute op %.2f MB exceeds backend support; deferring to tiling",
2025+
get_desc().c_str(),
2026+
ggml_nbytes(op) / 1024.0 / 1024.0);
2027+
compute_buffer_deferred_to_tiling = true;
2028+
return false;
2029+
}
2030+
}
2031+
} else {
2032+
size_t max_size = ggml_backend_buft_get_max_size(buft);
2033+
if (max_size > 0) {
2034+
ggml_gallocr* probe = ggml_gallocr_new(buft);
2035+
size_t sizes[1] = {0};
2036+
ggml_gallocr_reserve_n_size(probe, gf, nullptr, nullptr, sizes);
2037+
ggml_gallocr_free(probe);
2038+
if (sizes[0] > max_size) {
2039+
LOG_DEBUG("%s: untiled compute buffer %.2f MB exceeds backend max single buffer %.2f MB; deferring to tiling",
2040+
get_desc().c_str(),
2041+
sizes[0] / 1024.0 / 1024.0,
2042+
max_size / 1024.0 / 1024.0);
2043+
compute_buffer_deferred_to_tiling = true;
2044+
return false;
2045+
}
2046+
}
2047+
}
2048+
}
2049+
2050+
compute_allocr = ggml_gallocr_new(buft);
19852051

19862052
if (!ggml_gallocr_reserve(compute_allocr, gf)) {
19872053
// failed to allocate the compute buffer
@@ -2432,7 +2498,11 @@ struct GGMLRunner {
24322498
GraphWeightDoneGuard graph_weight_done_guard(this, &params_to_prepare);
24332499

24342500
if (!alloc_compute_buffer(gf)) {
2435-
LOG_ERROR("%s alloc compute buffer failed", get_desc().c_str());
2501+
// compute_buffer_deferred_to_tiling: alloc_compute_buffer declined a too-large
2502+
// untiled buffer on purpose (VAE AUTO will retry with tiling) -- not a real error.
2503+
if (!compute_buffer_deferred_to_tiling) {
2504+
LOG_ERROR("%s alloc compute buffer failed", get_desc().c_str());
2505+
}
24362506
return std::nullopt;
24372507
}
24382508
struct ComputeBufferGuard {
@@ -2822,6 +2892,15 @@ struct GGMLRunner {
28222892
void set_stream_layers_enabled(bool enabled) {
28232893
stream_layers_enabled = enabled;
28242894
}
2895+
2896+
// When enabled, the next compute() measures its planned compute buffer and
2897+
// declines to allocate (returning failure) if it would exceed the backend's
2898+
// max single-buffer size, instead of attempting the allocation and emitting
2899+
// the backend's raw error. See probe_compute_buffer_fits_.
2900+
void set_probe_compute_buffer_fits(bool enabled, size_t max_bytes = 0) {
2901+
probe_compute_buffer_fits_ = enabled;
2902+
probe_max_bytes_ = enabled ? max_bytes : 0;
2903+
}
28252904
};
28262905

28272906
class GGMLBlock {

src/model/vae/vae.hpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,55 @@ struct VAE : public GGMLRunner {
199199
"vae decode compute failed while processing a tile",
200200
silent);
201201
} else {
202+
// AUTO: probe the untiled decode buffer first so a too-large one defers to tiling below
203+
// without the backend's raw alloc error; the output.empty() check still backstops a real OOM.
204+
const bool auto_probe = !tiling_params.enabled && tiling_params.auto_tile;
205+
if (auto_probe) {
206+
size_t max_bytes = 0;
207+
if (tiling_params.extra_tiling_args != nullptr) {
208+
for (const auto& [key, value] : parse_key_value_args(tiling_params.extra_tiling_args, "VAE extra tiling arg")) {
209+
if (key == "max_buffer_size") {
210+
max_bytes = strtoull(value.c_str(), nullptr, 10);
211+
}
212+
}
213+
}
214+
set_probe_compute_buffer_fits(true, max_bytes);
215+
}
202216
output = _compute(n_threads, input, true);
217+
if (auto_probe) {
218+
set_probe_compute_buffer_fits(false);
219+
}
220+
if (output.empty() && !tiling_params.enabled && tiling_params.auto_tile) {
221+
// Untiled decode exceeded the backend's per-buffer limit (common on iGPUs, where the
222+
// cap is per-buffer, not total memory) -- fall back to tiling instead of failing.
223+
free_compute_buffer();
224+
if (!silent) {
225+
LOG_WARN("vae: untiled decode buffer exceeded the backend limit; retrying with tiling");
226+
}
227+
sd_tiling_params_t auto_tiling = tiling_params;
228+
auto_tiling.enabled = true; // default tile size (32) via get_tile_sizes
229+
set_tiling_params(auto_tiling);
230+
const int scale_factor = get_scale_factor();
231+
int64_t W = input.shape()[0] * scale_factor;
232+
int64_t H = input.shape()[1] * scale_factor;
233+
float tile_overlap;
234+
int tile_size_x, tile_size_y;
235+
get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, auto_tiling, input.shape()[0], input.shape()[1]);
236+
output = tiled_compute(
237+
input,
238+
n_threads,
239+
static_cast<int>(W),
240+
static_cast<int>(H),
241+
scale_factor,
242+
tile_size_x,
243+
tile_size_y,
244+
tile_overlap,
245+
circular_x,
246+
circular_y,
247+
true,
248+
"vae decode compute failed while processing a tile",
249+
silent);
250+
}
203251
}
204252

205253
free_compute_buffer();

src/stable-diffusion.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ class StableDiffusionGGML {
187187
bool apply_lora_immediately = false;
188188

189189
std::string taesd_path;
190-
sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0, 0, nullptr};
190+
sd_tiling_params_t vae_tiling_params = {false, false, 0, 0, 0.5f, 0, 0, nullptr, true}; // auto_tile=true (AUTO default)
191191
bool enable_mmap = false;
192192
sd::ggml_graph_cut::MaxVramAssignment max_vram_assignment;
193193
bool stream_layers = false;
@@ -2795,7 +2795,7 @@ void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) {
27952795
sd_img_gen_params->batch_count = 1;
27962796
sd_img_gen_params->control_strength = 0.9f;
27972797
sd_img_gen_params->pm_params = {nullptr, 0, nullptr, 20.f};
2798-
sd_img_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f, nullptr};
2798+
sd_img_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f, nullptr, true}; // auto_tile=true (AUTO)
27992799
sd_cache_params_init(&sd_img_gen_params->cache);
28002800
sd_hires_params_init(&sd_img_gen_params->hires);
28012801
}
@@ -2882,7 +2882,7 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) {
28822882
sd_vid_gen_params->fps = 16;
28832883
sd_vid_gen_params->moe_boundary = 0.875f;
28842884
sd_vid_gen_params->vace_strength = 1.f;
2885-
sd_vid_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f, nullptr};
2885+
sd_vid_gen_params->vae_tiling_params = {false, false, 0, 0, 0.5f, 0.0f, 0.0f, nullptr, true}; // auto_tile=true (AUTO)
28862886
sd_vid_gen_params->hires.enabled = false;
28872887
sd_vid_gen_params->hires.upscaler = SD_HIRES_UPSCALER_LATENT;
28882888
sd_vid_gen_params->hires.scale = 2.f;

0 commit comments

Comments
 (0)