Commit f61b0f7

llama : reuse compute graphs
ggml-ci
1 parent 4534123 commit f61b0f7

File tree: 8 files changed, +296 -93 lines
src/llama-context.cpp

Lines changed: 56 additions & 58 deletions
@@ -227,8 +227,14 @@ llama_context::llama_context(
 
     LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
 
-    // buffer used to store the computation graph and the tensor meta data
-    buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
+    // buffers used to store the computation graph and the tensor meta data
+    for (auto & res : gf_res) {
+        res.reset(new llm_graph_result());
+        res->reserve(max_nodes);
+    };
+
+    gf_res_reserve.reset(new llm_graph_result());
+    gf_res_reserve->reserve(max_nodes);
 
     // TODO: move these checks to ggml_backend_sched
     // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
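Aside, not part of the diff: the removed buf_compute_meta line sized one metadata buffer for the whole graph using the real ggml helpers ggml_tensor_overhead() and ggml_graph_overhead_custom(). A minimal sketch of that sizing is below; whether llm_graph_result::reserve() computes the size the same way internally is an assumption.

#include "ggml.h"

#include <cstddef>
#include <cstdint>
#include <vector>

// enough meta memory for up to max_nodes tensor structs plus the cgraph bookkeeping
static std::vector<uint8_t> make_meta_buffer(size_t max_nodes) {
    const size_t size = ggml_tensor_overhead()*max_nodes
                      + ggml_graph_overhead_custom(max_nodes, /*grads =*/ false);
    return std::vector<uint8_t>(size);
}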
@@ -388,10 +394,6 @@ ggml_backend_sched_t llama_context::get_sched() const {
     return sched.get();
 }
 
-ggml_context * llama_context::get_ctx_compute() const {
-    return ctx_compute.get();
-}
-
 uint32_t llama_context::n_ctx() const {
     return cparams.n_ctx;
 }
@@ -678,36 +680,40 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
 
-llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
+llm_graph_result_i * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
     if (mctx && !mctx->apply()) {
         LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
         ret = GGML_STATUS_FAILED;
         return nullptr;
     }
 
-    auto * gf = graph_init();
+    gf_res_next()->init();
+
+    auto * gf = gf_res_cur()->get_gf();
     if (!gf) {
         LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
         ret = GGML_STATUS_FAILED;
         return nullptr;
     }
 
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
-    if (!res) {
-        LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
-        ret = GGML_STATUS_FAILED;
-        return nullptr;
-    }
+    const bool can_reuse = graph_build(gf_res_cur(), gf_res_prv(), ubatch, gtype, mctx);
+    if (can_reuse) {
+        LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);
+        gf_res_next()->update(mctx);
+    } else {
+        // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
-    // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
+        ggml_backend_sched_reset(sched.get());
+        ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
-    if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
-        LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
-        ret = GGML_STATUS_ALLOC_FAILED;
-        return nullptr;
+        if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
+            LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
+            ret = GGML_STATUS_ALLOC_FAILED;
+            return nullptr;
+        }
     }
 
-    res->set_inputs(&ubatch);
+    gf_res_cur()->set_inputs(&ubatch);
 
     const auto status = graph_compute(gf, ubatch.n_tokens > 1);
     if (status != GGML_STATUS_SUCCESS) {
@@ -718,7 +724,7 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
 
     ret = GGML_STATUS_SUCCESS;
 
-    return res;
+    return gf_res_cur();
 }
 
 int llama_context::encode(const llama_batch & batch_inp) {
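Aside, not part of the diff: a minimal, self-contained sketch of the control flow that process_ubatch() now follows. All names below (GraphSlot, Context, build(), process()) are hypothetical stand-ins, not llama.cpp API; the point is only the shape of the logic: rotate to the next slot, build while comparing against the previous slot, and go through the scheduler reset/alloc path only when the graph could not be reused.

#include <array>
#include <cstdio>
#include <string>

// hypothetical stand-in for llm_graph_result: remembers what was built and
// whether backend buffers are currently valid for it
struct GraphSlot {
    std::string topology;
    bool        allocated = false;
};

struct Context {
    std::array<GraphSlot, 2> slots; // ping-pong pair, like gf_res in the diff
    int cur = 0;

    GraphSlot & next()     { cur = cur == 0 ? 1 : 0; return slots[cur]; }
    GraphSlot & current()  { return slots[cur]; }
    GraphSlot & previous() { return slots[(cur + 1) % 2]; }

    // build into the current slot; report whether the previous slot already
    // holds an identical, allocated graph that can simply be reused
    bool build(const std::string & topo) {
        current().topology = topo;
        return previous().allocated && previous().topology == topo;
    }

    void process(const std::string & topo) {
        next(); // pick the slot we will build into

        const bool can_reuse = build(topo);
        if (can_reuse) {
            std::printf("reusing previous graph (%s)\n", topo.c_str());
        } else {
            std::printf("rebuilding and reallocating graph (%s)\n", topo.c_str());
            // stands in for ggml_backend_sched_reset() + ggml_backend_sched_alloc_graph()
        }
        current().allocated = true; // fresh allocation, or inherited from the previous slot
    }
};

int main() {
    Context ctx;
    ctx.process("decode: n_tokens=1"); // first ubatch: must allocate
    ctx.process("decode: n_tokens=1"); // same shape:   reused
    ctx.process("decode: n_tokens=8"); // shape change: rebuilt
}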
@@ -767,6 +773,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
 
     n_outputs = n_tokens;
 
+    // TODO: when resetting the scheduler, clear prev graph buffers
+    gf_res_next()->init();
     ggml_backend_sched_reset(sched.get());
     ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
@@ -778,7 +786,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     cparams.causal_attn = false;
 
     ggml_status status;
-    const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
+    const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
 
     cparams.causal_attn = causal_attn_org;
 
@@ -846,7 +854,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
 
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
-    ggml_backend_sched_reset(sched.get());
+    //ggml_backend_sched_reset(sched.get());
 
     // TODO: hacky solution
     if (model.arch == LLM_ARCH_T5 && t_embd) {
@@ -1005,11 +1013,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
         n_outputs = n_outputs_new;
     }
 
-    ggml_backend_sched_reset(sched.get());
-    ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
-
     ggml_status status;
-    const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
+    const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
 
     if (!res) {
         // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@@ -1192,7 +1197,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
-    ggml_backend_sched_reset(sched.get());
+    //ggml_backend_sched_reset(sched.get());
 
     return 0;
 }
@@ -1279,18 +1284,6 @@ int32_t llama_context::graph_max_nodes() const {
     return std::max<int32_t>(65536, 5*model.n_tensors());
 }
 
-ggml_cgraph * llama_context::graph_init() {
-    ggml_init_params params = {
-        /*.mem_size   =*/ buf_compute_meta.size(),
-        /*.mem_buffer =*/ buf_compute_meta.data(),
-        /*.no_alloc   =*/ true,
-    };
-
-    ctx_compute.reset(ggml_init(params));
-
-    return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
-}
-
 ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
     LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
 
@@ -1301,6 +1294,10 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
         LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
     }
 
+    // TODO: when resetting the scheduler, clear prev graph buffers
+    gf_res_next()->init();
+    ggml_backend_sched_reset(sched.get());
+
     // store the n_outputs as it is, and restore it afterwards
     // TODO: not sure if needed, might simplify in the future by removing this
     const auto save_n_outputs = this->n_outputs;
@@ -1310,17 +1307,13 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
     llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
 
-    auto * gf = graph_init();
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
+    gf_res_reserve->init();
+    auto * gf = gf_res_reserve->get_gf();
 
-    this->n_outputs = save_n_outputs;
-
-    if (!res) {
-        LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__);
-        return nullptr;
-    }
+    const bool can_reuse = graph_build(gf_res_reserve.get(), nullptr, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
+    GGML_ASSERT(!can_reuse); // cannot reuse reserve graphs
 
-    ggml_backend_sched_reset(sched.get());
+    this->n_outputs = save_n_outputs;
 
     // initialize scheduler with the specified graph
     if (!ggml_backend_sched_reserve(sched.get(), gf)) {
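Aside, not part of the diff: the GGML_ASSERT(!can_reuse) above holds because the reserve path passes nullptr for the previous result. A tiny hypothetical sketch of that contract (ResultSketch and build_sketch are illustration names, not llama.cpp API):

#include <cassert>
#include <string>

struct ResultSketch {
    std::string topology;
};

// reuse is only possible when a previous result exists and matches the new topology
static bool build_sketch(ResultSketch & cur, const ResultSketch * prv, const std::string & topo) {
    cur.topology = topo;
    return prv != nullptr && prv->topology == topo;
}

int main() {
    ResultSketch reserve_res;
    const bool can_reuse = build_sketch(reserve_res, /*prv =*/ nullptr, "worst-case ubatch");
    assert(!can_reuse); // mirrors GGML_ASSERT(!can_reuse) in graph_reserve()
    return 0;
}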
@@ -1331,15 +1324,17 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     return gf;
 }
 
-llm_graph_result_ptr llama_context::graph_build(
-        ggml_context * ctx,
-        ggml_cgraph * gf,
+bool llama_context::graph_build(
+        llm_graph_result_i * gf_res_cur,
+        llm_graph_result_i * gf_res_prv,
         const llama_ubatch & ubatch,
         llm_graph_type gtype,
         const llama_memory_context_i * mctx) {
     return model.build_graph(
         {
-            /*.ctx         =*/ ctx,
+            /*.ctx         =*/ gf_res_cur->get_ctx(),
+            /*.gf_res_cur  =*/ static_cast<llm_graph_result *>(gf_res_cur),
+            /*.gf_res_prv  =*/ static_cast<llm_graph_result *>(gf_res_prv),
             /*.arch        =*/ model.arch,
             /*.hparams     =*/ model.hparams,
             /*.cparams     =*/ cparams,
@@ -1352,7 +1347,7 @@ llm_graph_result_ptr llama_context::graph_build(
             /*.cross       =*/ &cross,
             /*.n_outputs   =*/ n_outputs,
             /*.cb          =*/ graph_get_cb(),
-        }, gf, gtype);
+        }, gtype);
 }
 
 ggml_status llama_context::graph_compute(
@@ -2064,8 +2059,11 @@ void llama_context::opt_epoch_iter(
             break;
         }
 
-        auto * gf = graph_init();
-        auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
+        gf_res_cur()->init();
+        auto * gf = gf_res_cur()->get_gf();
+
+        const bool can_reuse = graph_build(gf_res_cur(), nullptr, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
+        GGML_ASSERT(!can_reuse); // cannot reuse optimization graphs
 
         struct ggml_context * ctx_compute_opt;
         {
@@ -2078,10 +2076,10 @@ void llama_context::opt_epoch_iter(
            };
            ctx_compute_opt = ggml_init(params);
        }
-        ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
+        ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, gf_res_cur()->get_tokens(), gf_res_cur()->get_logits());
         ggml_opt_alloc(opt_ctx, train);
 
-        res->set_inputs(&ubatch);
+        gf_res_cur()->set_inputs(&ubatch);
         {
             struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
             GGML_ASSERT(labels->ne[1] == n_ubatch);

src/llama-context.h

Lines changed: 28 additions & 13 deletions
@@ -35,8 +35,6 @@ struct llama_context {
 
     ggml_backend_sched_t get_sched() const;
 
-    ggml_context * get_ctx_compute() const;
-
     uint32_t n_ctx() const;
     uint32_t n_ctx_per_seq() const;
     uint32_t n_batch() const;
@@ -96,7 +94,7 @@ struct llama_context {
     // if memory_context is provided, it will be applied first to the context's memory
     // ret contains the status of the graph computation
     // returns nullptr only if ret != GGML_STATUS_SUCCESS
-    llm_graph_result_ptr process_ubatch(
+    llm_graph_result_i * process_ubatch(
            const llama_ubatch & ubatch,
            llm_graph_type gtype,
            llama_memory_context_i * mctx,
@@ -190,19 +188,17 @@ struct llama_context {
 public:
     int32_t graph_max_nodes() const;
 
-    // zero-out inputs and create the ctx_compute for the compute graph
-    ggml_cgraph * graph_init();
-
     // returns the result of ggml_backend_sched_graph_compute_async execution
     ggml_status graph_compute(ggml_cgraph * gf, bool batched);
 
     // reserve a graph with a dummy ubatch of the specified size
     ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
 
 private:
-    llm_graph_result_ptr graph_build(
-            ggml_context * ctx,
-            ggml_cgraph * gf,
+    // true - can reuse prev graph
+    bool graph_build(
+            llm_graph_result_i * gf_res_cur,
+            llm_graph_result_i * gf_res_prv,
             const llama_ubatch & ubatch,
             llm_graph_type gtype,
             const llama_memory_context_i * mctx);
@@ -258,8 +254,6 @@ struct llama_context {
     ggml_backend_t backend_cpu = nullptr;
     std::vector<ggml_backend_ptr> backends;
 
-    ggml_context_ptr ctx_compute;
-
     // training
     ggml_opt_context_t opt_ctx = nullptr;
 
@@ -275,8 +269,29 @@ struct llama_context {
     std::vector<ggml_backend_t> backend_ptrs;
     std::vector<ggml_backend_buffer_type_t> backend_buft;
 
-    // memory buffers used to evaluate the model
-    std::vector<uint8_t> buf_compute_meta;
+    // ==================================
+    // double-buffer for compute graphs
+    // TODO: polish this rough first iteration
+    //
+    std::array<llm_graph_result_ptr, 2> gf_res;
+
+    int gf_res_i = 0;
+
+    llm_graph_result_i * gf_res_next() {
+        gf_res_i = gf_res_i == 0 ? 1 : 0;
+        return gf_res[gf_res_i].get();
+    }
+
+    llm_graph_result_i * gf_res_cur() const {
+        return gf_res[gf_res_i].get();
+    }
+
+    llm_graph_result_i * gf_res_prv() const {
+        return gf_res[(gf_res_i + 1) % 2].get();
+    }
+
+    llm_graph_result_ptr gf_res_reserve;
+    // ==================================
 
     // host buffer for the model output (logits and embeddings)
     ggml_backend_buffer_ptr buf_output;
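Aside, not part of the diff: a minimal, compilable illustration of the two-slot rotation that gf_res_next()/gf_res_cur()/gf_res_prv() implement above. PingPong and its int slots are stand-ins for the two llm_graph_result buffers; the point is that after each next() the previous result stays addressable for comparison.

#include <array>
#include <cassert>

struct PingPong {
    std::array<int, 2> slot{};   // stands in for the two graph-result buffers
    int i = 0;

    int & next() { i = i == 0 ? 1 : 0; return slot[i]; } // advance, then return the new current
    int & cur()  { return slot[i]; }
    int & prv()  { return slot[(i + 1) % 2]; }
};

int main() {
    PingPong pp;
    pp.next() = 1;          // first graph lands in slot 1
    pp.next() = 2;          // second graph lands in slot 0
    assert(pp.cur() == 2);
    assert(pp.prv() == 1);  // the previous graph is still addressable for comparison
    pp.next() = 3;          // third graph overwrites slot 1
    assert(pp.prv() == 2);
    return 0;
}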
