Skip to content

Commit 815e38c

Browse files
committed
perf: async prefetch of next segment's params during compute
1 parent 19bdfe2 commit 815e38c

1 file changed

Lines changed: 164 additions & 18 deletions

File tree

src/core/ggml_extend.hpp

Lines changed: 164 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1699,6 +1699,12 @@ struct GGMLRunner {
16991699
ggml_backend_buffer_t partial_runtime_params_buffer = nullptr;
17001700
std::vector<std::pair<ggml_tensor*, ggml_tensor*>> partial_offload_pairs;
17011701

1702+
// Next segment's params prefetched during current segment's compute.
1703+
ggml_context* pending_offload_ctx = nullptr;
1704+
ggml_backend_buffer_t pending_runtime_params_buffer = nullptr;
1705+
std::vector<std::pair<ggml_tensor*, ggml_tensor*>> pending_offload_pairs;
1706+
uint64_t pending_param_signature = 0;
1707+
17021708
// Params kept on the runtime backend across streaming segments.
17031709
ggml_context* resident_offload_ctx = nullptr;
17041710
std::vector<std::pair<ggml_tensor*, ggml_tensor*>> resident_offload_pairs;
@@ -2159,36 +2165,66 @@ struct GGMLRunner {
21592165
return true;
21602166
}
21612167

2162-
bool offload_partial_params(const std::vector<ggml_tensor*>& tensors) {
2163-
restore_partial_params();
2164-
if (params_backend == runtime_backend) {
2165-
return true;
2166-
}
2167-
if (tensors.empty()) {
2168-
return true;
2168+
// Computes a 64-bit signature for a list of tensor pointers by XOR-folding
// each pointer value multiplied by a fixed 64-bit odd constant.
// An empty list yields 0.
// NOTE(review): XOR is commutative, so the result does not depend on element
// order, and a pointer that appears twice cancels itself out. The visible
// callers pass the deduplicated output of dedup_runtime_params(). The value is
// a hash, not a proof of equality — confirm that a signature match alone is
// acceptable evidence that two parameter sets are identical.
static uint64_t param_signature(const std::vector<ggml_tensor*>& tensors) {
2169+
uint64_t h = 0;
2170+
for (ggml_tensor* t : tensors) {
2171+
// Hash the pointer identity only; tensor contents are not inspected.
h ^= reinterpret_cast<uintptr_t>(t) * 0x9E3779B97F4A7C15ull;
21692172
}
2170-
GGML_ASSERT(!params_on_runtime_backend);
2171-
GGML_ASSERT(partial_runtime_params_buffer == nullptr);
2173+
return h;
2174+
}
21722175

2173-
std::vector<ggml_tensor*> unique_tensors;
2174-
std::unordered_set<ggml_tensor*> seen_tensors;
2176+
// Filters `tensors` into `unique_tensors`, keeping the first occurrence of
// each pointer in its original order.
// Skipped entries: null pointers, pointers found in resident_param_set, and
// pointers already emitted during this call.
// `unique_tensors` is appended to (it is not cleared here), so callers are
// expected to pass an empty vector.
void dedup_runtime_params(const std::vector<ggml_tensor*>& tensors,
2177+
std::vector<ggml_tensor*>& unique_tensors) {
2178+
std::unordered_set<ggml_tensor*> seen;
21752179
unique_tensors.reserve(tensors.size());
2176-
seen_tensors.reserve(tensors.size());
2180+
seen.reserve(tensors.size());
21772181
for (ggml_tensor* tensor : tensors) {
21782182
if (tensor == nullptr) {
21792183
continue;
21802184
}
21812185
// Tensors in the resident set are excluded from the result.
if (resident_param_set.find(tensor) != resident_param_set.end()) {
21822186
continue;
21832187
}
2184-
if (seen_tensors.insert(tensor).second) {
2188+
// insert().second is true only the first time a pointer is seen.
if (seen.insert(tensor).second) {
21852189
unique_tensors.push_back(tensor);
21862190
}
21872191
}
2192+
}
2193+
2194+
bool offload_partial_params(const std::vector<ggml_tensor*>& tensors) {
2195+
if (params_backend == runtime_backend) {
2196+
restore_pending_params();
2197+
restore_partial_params();
2198+
return true;
2199+
}
2200+
if (tensors.empty()) {
2201+
restore_pending_params();
2202+
restore_partial_params();
2203+
return true;
2204+
}
2205+
2206+
std::vector<ggml_tensor*> unique_tensors;
2207+
dedup_runtime_params(tensors, unique_tensors);
21882208
if (unique_tensors.empty()) {
2209+
restore_pending_params();
2210+
restore_partial_params();
2211+
return true;
2212+
}
2213+
2214+
// Fast path: if the prefetch already loaded these exact params, just
2215+
// swap the original tensors onto the pending buffer (no extra H2D).
2216+
if (pending_runtime_params_buffer != nullptr &&
2217+
pending_param_signature == param_signature(unique_tensors)) {
2218+
restore_partial_params();
2219+
promote_pending_to_partial();
21892220
return true;
21902221
}
21912222

2223+
restore_pending_params();
2224+
restore_partial_params();
2225+
GGML_ASSERT(!params_on_runtime_backend);
2226+
GGML_ASSERT(partial_runtime_params_buffer == nullptr);
2227+
21922228
ggml_init_params params;
21932229
params.mem_size = std::max<size_t>(1, unique_tensors.size()) * ggml_tensor_overhead();
21942230
params.mem_buffer = nullptr;
@@ -2303,6 +2339,95 @@ struct GGMLRunner {
23032339
}
23042340
}
23052341

2342+
// Stages copies of `tensors` in a newly allocated runtime-backend buffer so
// they can later be adopted by promote_pending_to_partial().
//
// Any previously staged state is released first. The input is filtered
// through dedup_runtime_params() before staging. The original tensors are not
// re-pointed by this function; only their contents are copied into the
// staged duplicates.
//
// Returns true when there is nothing to stage or staging succeeded, and false
// when the runtime backend buffer could not be allocated (the pending state
// is left empty in that case).
bool offload_pending_params(const std::vector<ggml_tensor*>& tensors) {
    restore_pending_params();
    if (params_backend == runtime_backend || tensors.empty()) {
        return true;
    }

    std::vector<ggml_tensor*> staged_sources;
    dedup_runtime_params(tensors, staged_sources);
    if (staged_sources.empty()) {
        return true;
    }
    const size_t num_staged = staged_sources.size();

    // no_alloc context: it only needs room for the tensor headers; the
    // tensors themselves are allocated on the runtime backend below.
    ggml_init_params ctx_params = {};
    ctx_params.mem_size         = std::max<size_t>(1, num_staged) * ggml_tensor_overhead();
    ctx_params.mem_buffer       = nullptr;
    ctx_params.no_alloc         = true;

    pending_offload_ctx = ggml_init(ctx_params);
    GGML_ASSERT(pending_offload_ctx != nullptr);

    // Create one same-shaped, same-named duplicate per source tensor.
    pending_offload_pairs.reserve(num_staged);
    for (ggml_tensor* source : staged_sources) {
        GGML_ASSERT(source->view_src == nullptr);
        ggml_tensor* staged = ggml_dup_tensor(pending_offload_ctx, source);
        ggml_set_name(staged, source->name);
        pending_offload_pairs.emplace_back(source, staged);
    }

    pending_runtime_params_buffer = ggml_backend_alloc_ctx_tensors(pending_offload_ctx, runtime_backend);
    if (pending_runtime_params_buffer == nullptr) {
        LOG_DEBUG("%s alloc pending runtime params backend buffer failed, num_tensors = %zu",
                  get_desc().c_str(),
                  pending_offload_pairs.size());
        // The buffer is null here, so this frees the context, drops the
        // pairs and leaves the signature at 0.
        restore_pending_params();
        return false;
    }
    ggml_backend_buffer_set_usage(pending_runtime_params_buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);

    // Fill the staged duplicates; the originals keep their current buffers
    // until promote_pending_to_partial() swaps them.
    for (const auto& entry : pending_offload_pairs) {
        ggml_backend_tensor_copy(entry.first, entry.second);
    }

    pending_param_signature = param_signature(staged_sources);
    return true;
}
2393+
2394+
void restore_pending_params() {
2395+
pending_offload_pairs.clear();
2396+
if (pending_runtime_params_buffer != nullptr) {
2397+
ggml_backend_buffer_free(pending_runtime_params_buffer);
2398+
pending_runtime_params_buffer = nullptr;
2399+
}
2400+
if (pending_offload_ctx != nullptr) {
2401+
ggml_free(pending_offload_ctx);
2402+
pending_offload_ctx = nullptr;
2403+
}
2404+
pending_param_signature = 0;
2405+
}
2406+
2407+
// Caller must have already restore_partial_params()ed.
2408+
void promote_pending_to_partial() {
2409+
GGML_ASSERT(partial_runtime_params_buffer == nullptr);
2410+
GGML_ASSERT(partial_offload_ctx == nullptr);
2411+
GGML_ASSERT(partial_offload_pairs.empty());
2412+
2413+
for (auto& pair : pending_offload_pairs) {
2414+
ggml_tensor* tensor = pair.first;
2415+
ggml_tensor* offload_tensor = pair.second;
2416+
std::swap(tensor->buffer, offload_tensor->buffer);
2417+
std::swap(tensor->data, offload_tensor->data);
2418+
std::swap(tensor->extra, offload_tensor->extra);
2419+
}
2420+
2421+
partial_offload_ctx = pending_offload_ctx;
2422+
partial_runtime_params_buffer = pending_runtime_params_buffer;
2423+
partial_offload_pairs = std::move(pending_offload_pairs);
2424+
2425+
pending_offload_ctx = nullptr;
2426+
pending_runtime_params_buffer = nullptr;
2427+
pending_offload_pairs.clear();
2428+
pending_param_signature = 0;
2429+
}
2430+
23062431
bool offload_resident_params(const std::vector<ggml_tensor*>& tensors) {
23072432
if (params_backend == runtime_backend) {
23082433
return true;
@@ -2631,7 +2756,8 @@ struct GGMLRunner {
26312756
const std::vector<ggml_tensor*>& runtime_param_tensors,
26322757
bool preserve_backend_tensor_data_map,
26332758
bool no_return = false,
2634-
const std::unordered_set<std::string>* cache_keep_names = nullptr) {
2759+
const std::unordered_set<std::string>* cache_keep_names = nullptr,
2760+
const std::function<void()>& prefetch_cb = {}) {
26352761
int64_t t_execute_begin = ggml_time_ms();
26362762
const bool use_partial_param_offload = !runtime_param_tensors.empty();
26372763
int64_t t_offload_begin = ggml_time_ms();
@@ -2676,9 +2802,14 @@ struct GGMLRunner {
26762802
}
26772803

26782804
int64_t t_compute_begin = ggml_time_ms();
2679-
ggml_status status = ggml_backend_graph_compute(runtime_backend, gf);
2680-
int64_t t_compute_end = ggml_time_ms();
2805+
ggml_status status = ggml_backend_graph_compute_async(runtime_backend, gf);
2806+
if (prefetch_cb) {
2807+
prefetch_cb();
2808+
}
2809+
ggml_backend_synchronize(runtime_backend);
2810+
int64_t t_compute_end = ggml_time_ms();
26812811
if (status != GGML_STATUS_SUCCESS) {
2812+
restore_pending_params();
26822813
LOG_ERROR("%s compute failed: %s", get_desc().c_str(), ggml_status_to_string(status));
26832814
if (free_compute_buffer_immediately) {
26842815
free_compute_buffer();
@@ -2955,15 +3086,29 @@ struct GGMLRunner {
29553086

29563087
ggml_context* segment_graph_ctx = nullptr;
29573088
ggml_cgraph* segment_graph = sd::ggml_graph_cut::build_segment_graph(gf, segment, &segment_graph_ctx);
2958-
auto segment_output = execute_graph<T>(segment_graph,
3089+
3090+
std::function<void()> prefetch_cb;
3091+
if (!is_last) {
3092+
const auto& next_segment = plan.segments[seg_idx + 1];
3093+
auto next_params = sd::ggml_graph_cut::runtime_param_tensors(gf, next_segment, get_desc().c_str());
3094+
if (!next_params.empty()) {
3095+
prefetch_cb = [this, next_params = std::move(next_params)]() {
3096+
offload_pending_params(next_params);
3097+
};
3098+
}
3099+
}
3100+
3101+
auto segment_output = execute_graph<T>(segment_graph,
29593102
n_threads,
29603103
/*free_compute_buffer_immediately=*/true,
29613104
sd::ggml_graph_cut::runtime_param_tensors(gf, segment, get_desc().c_str()),
29623105
/*preserve_backend_tensor_data_map=*/true,
29633106
/*no_return=*/!is_last || no_return,
2964-
&future_cut_names);
3107+
&future_cut_names,
3108+
prefetch_cb);
29653109
ggml_free(segment_graph_ctx);
29663110
if (!segment_output.has_value()) {
3111+
restore_pending_params();
29673112
free_cache_ctx_and_buffer();
29683113
free_compute_buffer();
29693114
free_compute_ctx();
@@ -3081,6 +3226,7 @@ struct GGMLRunner {
30813226

30823227
void free_params_buffer() {
30833228
// Restore swapped resident params before freeing their backing buffer.
3229+
restore_pending_params();
30843230
restore_resident_params();
30853231
if (params_buffer != nullptr) {
30863232
ggml_backend_buffer_free(params_buffer);

0 commit comments

Comments
 (0)