Skip to content

Commit ed74577

Browse files
authored
feat: --stream-layers for streaming weights from CPU during generation (#1576)
1 parent 7948df8 commit ed74577

12 files changed

Lines changed: 692 additions & 12 deletions

examples/common/common.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,10 @@ ArgOptions SDContextParams::get_options() {
438438
};
439439

440440
options.bool_options = {
441+
{"",
442+
"--stream-layers",
443+
"enable residency+prefetch streaming on top of --max-vram (no effect without --max-vram; defaults to false)",
444+
true, &stream_layers},
441445
{"",
442446
"--force-sdxl-vae-conv-scale",
443447
"force use of conv scale on sdxl vae",
@@ -720,6 +724,7 @@ std::string SDContextParams::to_string() const {
720724
<< " sampler_rng_type: " << sd_rng_type_name(sampler_rng_type) << ",\n"
721725
<< " offload_params_to_cpu: " << (offload_params_to_cpu ? "true" : "false") << ",\n"
722726
<< " max_vram: " << max_vram << ",\n"
727+
<< " stream_layers: " << (stream_layers ? "true" : "false") << ",\n"
723728
<< " backend: \"" << backend << "\",\n"
724729
<< " params_backend: \"" << params_backend << "\",\n"
725730
<< " enable_mmap: " << (enable_mmap ? "true" : "false") << ",\n"
@@ -800,6 +805,7 @@ sd_ctx_params_t SDContextParams::to_sd_ctx_params_t(bool vae_decode_only, bool f
800805
qwen_image_zero_cond_t,
801806
str_to_vae_format(vae_format),
802807
max_vram,
808+
stream_layers,
803809
backend.c_str(),
804810
params_backend.c_str(),
805811
};

examples/common/common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ struct SDContextParams {
113113
rng_type_t sampler_rng_type = RNG_TYPE_COUNT;
114114
bool offload_params_to_cpu = false;
115115
float max_vram = 0.f;
116+
bool stream_layers = false;
116117
std::string backend;
117118
std::string params_backend;
118119
bool enable_mmap = false;

include/stable-diffusion.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ typedef struct {
222222
bool qwen_image_zero_cond_t;
223223
enum sd_vae_format_t vae_format;
224224
float max_vram; // GiB budget for graph-cut segmented param offload (0 = disabled, -1 = auto free VRAM minus 1 GiB)
225+
bool stream_layers; // Enable residency+prefetch streaming on top of --max-vram (no effect without --max-vram)
225226
const char* backend;
226227
const char* params_backend;
227228
} sd_ctx_params_t;

src/conditioner.hpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ struct Conditioner {
118118
virtual void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors) = 0;
119119
virtual size_t get_params_buffer_size() = 0;
120120
virtual void set_max_graph_vram_bytes(size_t max_vram_bytes) {}
121+
virtual void set_stream_layers_enabled(bool enabled) {}
121122
virtual void set_flash_attention_enabled(bool enabled) = 0;
122123
virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) {}
123124
virtual std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(int n_threads,
@@ -210,6 +211,13 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
210211
}
211212
}
212213

214+
void set_stream_layers_enabled(bool enabled) override {
215+
text_model->set_stream_layers_enabled(enabled);
216+
if (sd_version_is_sdxl(version)) {
217+
text_model2->set_stream_layers_enabled(enabled);
218+
}
219+
}
220+
213221
void set_flash_attention_enabled(bool enabled) override {
214222
text_model->set_flash_attention_enabled(enabled);
215223
if (sd_version_is_sdxl(version)) {
@@ -843,6 +851,18 @@ struct SD3CLIPEmbedder : public Conditioner {
843851
}
844852
}
845853

854+
void set_stream_layers_enabled(bool enabled) override {
855+
if (clip_l) {
856+
clip_l->set_stream_layers_enabled(enabled);
857+
}
858+
if (clip_g) {
859+
clip_g->set_stream_layers_enabled(enabled);
860+
}
861+
if (t5) {
862+
t5->set_stream_layers_enabled(enabled);
863+
}
864+
}
865+
846866
void set_flash_attention_enabled(bool enabled) override {
847867
if (clip_l) {
848868
clip_l->set_flash_attention_enabled(enabled);
@@ -1200,6 +1220,15 @@ struct FluxCLIPEmbedder : public Conditioner {
12001220
}
12011221
}
12021222

1223+
void set_stream_layers_enabled(bool enabled) override {
1224+
if (clip_l) {
1225+
clip_l->set_stream_layers_enabled(enabled);
1226+
}
1227+
if (t5) {
1228+
t5->set_stream_layers_enabled(enabled);
1229+
}
1230+
}
1231+
12031232
void set_flash_attention_enabled(bool enabled) override {
12041233
if (clip_l) {
12051234
clip_l->set_flash_attention_enabled(enabled);
@@ -1434,6 +1463,12 @@ struct T5CLIPEmbedder : public Conditioner {
14341463
}
14351464
}
14361465

1466+
void set_stream_layers_enabled(bool enabled) override {
1467+
if (t5) {
1468+
t5->set_stream_layers_enabled(enabled);
1469+
}
1470+
}
1471+
14371472
void set_flash_attention_enabled(bool enabled) override {
14381473
if (t5) {
14391474
t5->set_flash_attention_enabled(enabled);
@@ -1617,6 +1652,10 @@ struct AnimaConditioner : public Conditioner {
16171652
llm->set_max_graph_vram_bytes(max_vram_bytes);
16181653
}
16191654

1655+
void set_stream_layers_enabled(bool enabled) override {
1656+
llm->set_stream_layers_enabled(enabled);
1657+
}
1658+
16201659
void set_flash_attention_enabled(bool enabled) override {
16211660
llm->set_flash_attention_enabled(enabled);
16221661
}
@@ -1765,6 +1804,10 @@ struct LLMEmbedder : public Conditioner {
17651804
llm->set_max_graph_vram_bytes(max_vram_bytes);
17661805
}
17671806

1807+
void set_stream_layers_enabled(bool enabled) override {
1808+
llm->set_stream_layers_enabled(enabled);
1809+
}
1810+
17681811
void set_flash_attention_enabled(bool enabled) override {
17691812
llm->set_flash_attention_enabled(enabled);
17701813
}

0 commit comments

Comments
 (0)