@@ -1699,6 +1699,12 @@ struct GGMLRunner {
16991699 ggml_backend_buffer_t partial_runtime_params_buffer = nullptr ;
17001700 std::vector<std::pair<ggml_tensor*, ggml_tensor*>> partial_offload_pairs;
17011701
1702+ // Next segment's params prefetched during current segment's compute.
1703+ ggml_context* pending_offload_ctx = nullptr ; // ggml context holding the duplicated (staged) tensor descriptors
1704+ ggml_backend_buffer_t pending_runtime_params_buffer = nullptr ; // runtime-backend buffer allocated for the staged tensors; nullptr when nothing is staged
1705+ std::vector<std::pair<ggml_tensor*, ggml_tensor*>> pending_offload_pairs; // (original tensor, staged duplicate) pairs
1706+ uint64_t pending_param_signature = 0 ; // param_signature() of the staged tensor list; reset to 0 whenever the staging state is released or promoted
1707+
17021708 // Params kept on the runtime backend across streaming segments.
17031709 ggml_context* resident_offload_ctx = nullptr ;
17041710 std::vector<std::pair<ggml_tensor*, ggml_tensor*>> resident_offload_pairs;
@@ -2159,36 +2165,66 @@ struct GGMLRunner {
21592165 return true ;
21602166 }
21612167
2162- bool offload_partial_params (const std::vector<ggml_tensor*>& tensors) {
2163- restore_partial_params ();
2164- if (params_backend == runtime_backend) {
2165- return true ;
2166- }
2167- if (tensors.empty ()) {
2168- return true ;
2168+ static uint64_t param_signature (const std::vector<ggml_tensor*>& tensors) {
2169+ uint64_t h = 0 ;
2170+ for (ggml_tensor* t : tensors) {
2171+ h ^= reinterpret_cast <uintptr_t >(t) * 0x9E3779B97F4A7C15ull ;
21692172 }
2170- GGML_ASSERT (!params_on_runtime_backend) ;
2171- GGML_ASSERT (partial_runtime_params_buffer == nullptr );
2173+ return h ;
2174+ }
21722175
2173- std::vector<ggml_tensor*> unique_tensors;
2174- std::unordered_set<ggml_tensor*> seen_tensors;
2176+ void dedup_runtime_params (const std::vector<ggml_tensor*>& tensors,
2177+ std::vector<ggml_tensor*>& unique_tensors) {
2178+ std::unordered_set<ggml_tensor*> seen;
21752179 unique_tensors.reserve (tensors.size ());
2176- seen_tensors .reserve (tensors.size ());
2180+ seen .reserve (tensors.size ());
21772181 for (ggml_tensor* tensor : tensors) {
21782182 if (tensor == nullptr ) {
21792183 continue ;
21802184 }
21812185 if (resident_param_set.find (tensor) != resident_param_set.end ()) {
21822186 continue ;
21832187 }
2184- if (seen_tensors .insert (tensor).second ) {
2188+ if (seen .insert (tensor).second ) {
21852189 unique_tensors.push_back (tensor);
21862190 }
21872191 }
2192+ }
2193+
2194+ bool offload_partial_params (const std::vector<ggml_tensor*>& tensors) {
2195+ if (params_backend == runtime_backend) {
2196+ restore_pending_params ();
2197+ restore_partial_params ();
2198+ return true ;
2199+ }
2200+ if (tensors.empty ()) {
2201+ restore_pending_params ();
2202+ restore_partial_params ();
2203+ return true ;
2204+ }
2205+
2206+ std::vector<ggml_tensor*> unique_tensors;
2207+ dedup_runtime_params (tensors, unique_tensors);
21882208 if (unique_tensors.empty ()) {
2209+ restore_pending_params ();
2210+ restore_partial_params ();
2211+ return true ;
2212+ }
2213+
2214+ // Fast path: if the prefetch already loaded these exact params, just
2215+ // swap the original tensors onto the pending buffer (no extra H2D).
2216+ if (pending_runtime_params_buffer != nullptr &&
2217+ pending_param_signature == param_signature (unique_tensors)) {
2218+ restore_partial_params ();
2219+ promote_pending_to_partial ();
21892220 return true ;
21902221 }
21912222
2223+ restore_pending_params ();
2224+ restore_partial_params ();
2225+ GGML_ASSERT (!params_on_runtime_backend);
2226+ GGML_ASSERT (partial_runtime_params_buffer == nullptr );
2227+
21922228 ggml_init_params params;
21932229 params.mem_size = std::max<size_t >(1 , unique_tensors.size ()) * ggml_tensor_overhead ();
21942230 params.mem_buffer = nullptr ;
@@ -2303,6 +2339,95 @@ struct GGMLRunner {
23032339 }
23042340 }
23052341
2342+ bool offload_pending_params (const std::vector<ggml_tensor*>& tensors) {
2343+ restore_pending_params ();
2344+ if (params_backend == runtime_backend) {
2345+ return true ;
2346+ }
2347+ if (tensors.empty ()) {
2348+ return true ;
2349+ }
2350+
2351+ std::vector<ggml_tensor*> unique_tensors;
2352+ dedup_runtime_params (tensors, unique_tensors);
2353+ if (unique_tensors.empty ()) {
2354+ return true ;
2355+ }
2356+
2357+ ggml_init_params params;
2358+ params.mem_size = std::max<size_t >(1 , unique_tensors.size ()) * ggml_tensor_overhead ();
2359+ params.mem_buffer = nullptr ;
2360+ params.no_alloc = true ;
2361+
2362+ pending_offload_ctx = ggml_init (params);
2363+ GGML_ASSERT (pending_offload_ctx != nullptr );
2364+ pending_offload_pairs.reserve (unique_tensors.size ());
2365+
2366+ for (ggml_tensor* tensor : unique_tensors) {
2367+ GGML_ASSERT (tensor->view_src == nullptr );
2368+ ggml_tensor* offload_tensor = ggml_dup_tensor (pending_offload_ctx, tensor);
2369+ ggml_set_name (offload_tensor, tensor->name );
2370+ pending_offload_pairs.push_back ({tensor, offload_tensor});
2371+ }
2372+
2373+ pending_runtime_params_buffer = ggml_backend_alloc_ctx_tensors (pending_offload_ctx, runtime_backend);
2374+ if (pending_runtime_params_buffer == nullptr ) {
2375+ LOG_DEBUG (" %s alloc pending runtime params backend buffer failed, num_tensors = %zu" ,
2376+ get_desc ().c_str (),
2377+ pending_offload_pairs.size ());
2378+ ggml_free (pending_offload_ctx);
2379+ pending_offload_ctx = nullptr ;
2380+ pending_offload_pairs.clear ();
2381+ return false ;
2382+ }
2383+ ggml_backend_buffer_set_usage (pending_runtime_params_buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS );
2384+
2385+ // Original tensors stay pointed at the partial buffer until promote.
2386+ for (auto & pair : pending_offload_pairs) {
2387+ ggml_backend_tensor_copy (pair.first , pair.second );
2388+ }
2389+
2390+ pending_param_signature = param_signature (unique_tensors);
2391+ return true ;
2392+ }
2393+
2394+ void restore_pending_params () {
2395+ pending_offload_pairs.clear ();
2396+ if (pending_runtime_params_buffer != nullptr ) {
2397+ ggml_backend_buffer_free (pending_runtime_params_buffer);
2398+ pending_runtime_params_buffer = nullptr ;
2399+ }
2400+ if (pending_offload_ctx != nullptr ) {
2401+ ggml_free (pending_offload_ctx);
2402+ pending_offload_ctx = nullptr ;
2403+ }
2404+ pending_param_signature = 0 ;
2405+ }
2406+
2407+ // Caller must have already restore_partial_params()ed.
2408+ void promote_pending_to_partial () {
2409+ GGML_ASSERT (partial_runtime_params_buffer == nullptr );
2410+ GGML_ASSERT (partial_offload_ctx == nullptr );
2411+ GGML_ASSERT (partial_offload_pairs.empty ());
2412+
2413+ for (auto & pair : pending_offload_pairs) {
2414+ ggml_tensor* tensor = pair.first ;
2415+ ggml_tensor* offload_tensor = pair.second ;
2416+ std::swap (tensor->buffer , offload_tensor->buffer );
2417+ std::swap (tensor->data , offload_tensor->data );
2418+ std::swap (tensor->extra , offload_tensor->extra );
2419+ }
2420+
2421+ partial_offload_ctx = pending_offload_ctx;
2422+ partial_runtime_params_buffer = pending_runtime_params_buffer;
2423+ partial_offload_pairs = std::move (pending_offload_pairs);
2424+
2425+ pending_offload_ctx = nullptr ;
2426+ pending_runtime_params_buffer = nullptr ;
2427+ pending_offload_pairs.clear ();
2428+ pending_param_signature = 0 ;
2429+ }
2430+
23062431 bool offload_resident_params (const std::vector<ggml_tensor*>& tensors) {
23072432 if (params_backend == runtime_backend) {
23082433 return true ;
@@ -2631,7 +2756,8 @@ struct GGMLRunner {
26312756 const std::vector<ggml_tensor*>& runtime_param_tensors,
26322757 bool preserve_backend_tensor_data_map,
26332758 bool no_return = false ,
2634- const std::unordered_set<std::string>* cache_keep_names = nullptr ) {
2759+ const std::unordered_set<std::string>* cache_keep_names = nullptr ,
2760+ const std::function<void ()>& prefetch_cb = {}) {
26352761 int64_t t_execute_begin = ggml_time_ms ();
26362762 const bool use_partial_param_offload = !runtime_param_tensors.empty ();
26372763 int64_t t_offload_begin = ggml_time_ms ();
@@ -2676,9 +2802,14 @@ struct GGMLRunner {
26762802 }
26772803
26782804 int64_t t_compute_begin = ggml_time_ms ();
2679- ggml_status status = ggml_backend_graph_compute (runtime_backend, gf);
2680- int64_t t_compute_end = ggml_time_ms ();
2805+ ggml_status status = ggml_backend_graph_compute_async (runtime_backend, gf);
2806+ if (prefetch_cb) {
2807+ prefetch_cb ();
2808+ }
2809+ ggml_backend_synchronize (runtime_backend);
2810+ int64_t t_compute_end = ggml_time_ms ();
26812811 if (status != GGML_STATUS_SUCCESS ) {
2812+ restore_pending_params ();
26822813 LOG_ERROR (" %s compute failed: %s" , get_desc ().c_str (), ggml_status_to_string (status));
26832814 if (free_compute_buffer_immediately) {
26842815 free_compute_buffer ();
@@ -2955,15 +3086,29 @@ struct GGMLRunner {
29553086
29563087 ggml_context* segment_graph_ctx = nullptr ;
29573088 ggml_cgraph* segment_graph = sd::ggml_graph_cut::build_segment_graph (gf, segment, &segment_graph_ctx);
2958- auto segment_output = execute_graph<T>(segment_graph,
3089+
3090+ std::function<void ()> prefetch_cb;
3091+ if (!is_last) {
3092+ const auto & next_segment = plan.segments [seg_idx + 1 ];
3093+ auto next_params = sd::ggml_graph_cut::runtime_param_tensors (gf, next_segment, get_desc ().c_str ());
3094+ if (!next_params.empty ()) {
3095+ prefetch_cb = [this , next_params = std::move (next_params)]() {
3096+ offload_pending_params (next_params);
3097+ };
3098+ }
3099+ }
3100+
3101+ auto segment_output = execute_graph<T>(segment_graph,
29593102 n_threads,
29603103 /* free_compute_buffer_immediately=*/ true ,
29613104 sd::ggml_graph_cut::runtime_param_tensors (gf, segment, get_desc ().c_str ()),
29623105 /* preserve_backend_tensor_data_map=*/ true ,
29633106 /* no_return=*/ !is_last || no_return,
2964- &future_cut_names);
3107+ &future_cut_names,
3108+ prefetch_cb);
29653109 ggml_free (segment_graph_ctx);
29663110 if (!segment_output.has_value ()) {
3111+ restore_pending_params ();
29673112 free_cache_ctx_and_buffer ();
29683113 free_compute_buffer ();
29693114 free_compute_ctx ();
@@ -3081,6 +3226,7 @@ struct GGMLRunner {
30813226
30823227 void free_params_buffer () {
30833228 // Restore swapped resident params before freeing their backing buffer.
3229+ restore_pending_params ();
30843230 restore_resident_params ();
30853231 if (params_buffer != nullptr ) {
30863232 ggml_backend_buffer_free (params_buffer);
0 commit comments