@@ -371,31 +371,11 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i {
371371// along with the input tensors, the object also provides commonly used outputs tensors, such as logits, embeddings, etc.
372372// these are used by the llama_context to extact the relevant data, based on the compute parameters
373373
374- // TODO: this interface seems redundant - remove it
375- class llm_graph_result_i {
376- public:
377- virtual ~llm_graph_result_i () = default ;
378-
379- virtual ggml_tensor * get_tokens () const = 0;
380- virtual ggml_tensor * get_logits () const = 0;
381- virtual ggml_tensor * get_embd () const = 0;
382- virtual ggml_tensor * get_embd_pooled () const = 0;
383-
384- virtual ggml_cgraph * get_gf () = 0;
385- virtual ggml_context * get_ctx () = 0;
386-
387- virtual void reset () = 0;
388-
389- virtual void set_inputs (const llama_ubatch * ubatch) = 0;
390-
391- virtual bool can_reuse (const llm_graph_params & params) = 0;
392- };
393-
394- using llm_graph_result_ptr = std::unique_ptr<llm_graph_result_i>;
395-
396374// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
397375using llm_graph_cb = std::function<void (const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
398376
377+ class llm_graph_result ;
378+
399379struct llm_graph_params {
400380 llm_arch arch = LLM_ARCH_UNKNOWN;
401381
@@ -418,8 +398,7 @@ struct llm_graph_params {
418398
419399 llm_graph_cb cb;
420400
421- // TODO: temporary
422- llm_graph_result_i * res;
401+ llm_graph_result * res;
423402
424403 // return true if the "other" params would result in a graph with the same topology as with the current params
425404 // having the same topology allows us to reuse the graph in some cases
@@ -464,35 +443,37 @@ struct llm_graph_params {
464443 }
465444};
466445
467- class llm_graph_result : public llm_graph_result_i {
446+ class llm_graph_result {
468447public:
469448 llm_graph_result (int64_t max_nodes);
470449
471450 virtual ~llm_graph_result () = default ;
472451
473- ggml_tensor * get_tokens () const override { return t_tokens; }
474- ggml_tensor * get_logits () const override { return t_logits; }
475- ggml_tensor * get_embd () const override { return t_embd; }
476- ggml_tensor * get_embd_pooled () const override { return t_embd_pooled; }
452+ ggml_tensor * get_tokens () const { return t_tokens; }
453+ ggml_tensor * get_logits () const { return t_logits; }
454+ ggml_tensor * get_embd () const { return t_embd; }
455+ ggml_tensor * get_embd_pooled () const { return t_embd_pooled; }
477456
478- ggml_cgraph * get_gf () override { return gf; }
479- ggml_context * get_ctx () override { return ctx_compute.get (); }
457+ ggml_cgraph * get_gf () const { return gf; }
458+ ggml_context * get_ctx () const { return ctx_compute.get (); }
480459
481460 int64_t get_max_nodes () const ;
482461
483- void reset () override ;
462+ void reset ();
484463
485- void set_inputs (const llama_ubatch * ubatch) override ;
464+ void set_inputs (const llama_ubatch * ubatch);
486465
487466 // try to update the existing graph result using the new graph parameters in order to reuse it
488467 // this can only be done if we determine that the resulting graph using the new graph parameters
489468 // would be identical to the existing graph. in that case, we simply have to update the memory
490469 // contexts of the input tensors of the graph and we can reuse it for another computation
491470 // return true if the graph was updated and can be reused
492- bool can_reuse (const llm_graph_params & params) override ;
471+ bool can_reuse (const llm_graph_params & params);
493472
494473 llm_graph_input_i * add_input (llm_graph_input_ptr input);
495474
475+ void set_params (const llm_graph_params & params);
476+
496477 // important graph nodes
497478 ggml_tensor * t_tokens = nullptr ;
498479 ggml_tensor * t_logits = nullptr ;
@@ -510,6 +491,7 @@ class llm_graph_result : public llm_graph_result_i {
510491
511492 int64_t max_nodes;
512493
494+ private:
513495 // keep a copy of the previous graph parameters
514496 // we will use this to determine whether the graph can be reused by comparing them with the new parameters
515497 // note: these are updated after constructing the new graph
@@ -519,6 +501,8 @@ class llm_graph_result : public llm_graph_result_i {
519501 int debug = 0 ;
520502};
521503
504+ using llm_graph_result_ptr = std::unique_ptr<llm_graph_result>;
505+
522506//
523507// llm_graph_context
524508//
@@ -576,6 +560,7 @@ struct llm_graph_context {
576560 llm_graph_result * res;
577561
578562 ggml_context * ctx0 = nullptr ;
563+ ggml_cgraph * gf = nullptr ;
579564
580565 llm_graph_context (const llm_graph_params & params);
581566 virtual ~llm_graph_context () = default ;
@@ -661,7 +646,6 @@ struct llm_graph_context {
661646 //
662647
663648 ggml_tensor * build_attn_mha (
664- ggml_cgraph * gf,
665649 ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens]
666650 ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens]
667651 ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
@@ -674,7 +658,6 @@ struct llm_graph_context {
674658
675659 ggml_tensor * build_attn (
676660 llm_graph_input_attn_no_cache * inp,
677- ggml_cgraph * gf,
678661 ggml_tensor * wo,
679662 ggml_tensor * wo_b,
680663 ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -689,7 +672,6 @@ struct llm_graph_context {
689672
690673 ggml_tensor * build_attn (
691674 llm_graph_input_attn_kv_unified * inp,
692- ggml_cgraph * gf,
693675 ggml_tensor * wo,
694676 ggml_tensor * wo_b,
695677 ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -705,7 +687,6 @@ struct llm_graph_context {
705687 // note: if k_cur or v_cur are not provided, they will not be stored in the memory
706688 ggml_tensor * build_attn (
707689 llm_graph_input_attn_kv_unified_iswa * inp,
708- ggml_cgraph * gf,
709690 ggml_tensor * wo,
710691 ggml_tensor * wo_b,
711692 ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -720,7 +701,6 @@ struct llm_graph_context {
720701
721702 ggml_tensor * build_attn (
722703 llm_graph_input_attn_cross * inp,
723- ggml_cgraph * gf,
724704 ggml_tensor * wo,
725705 ggml_tensor * wo_b,
726706 ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -742,7 +722,6 @@ struct llm_graph_context {
742722 // implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
743723 // `llama_memory_recurrent`
744724 ggml_tensor * build_rs (
745- ggml_cgraph * gf,
746725 ggml_tensor * s,
747726 ggml_tensor * state_copy,
748727 int32_t state_size,
@@ -757,17 +736,15 @@ struct llm_graph_context {
757736
758737 ggml_tensor * build_rs (
759738 llm_graph_input_rs * inp,
760- ggml_cgraph * gf,
761739 ggml_tensor * s,
762740 int32_t state_size,
763741 int32_t n_seqs,
764742 const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const ;
765743
766744 ggml_tensor * build_rwkv_token_shift_load (
767745 llm_graph_input_rs * inp,
768- ggml_cgraph * gf,
769746 const llama_ubatch & ubatch,
770- int il) const ;
747+ int il) const ;
771748
772749 ggml_tensor * build_rwkv_token_shift_store (
773750 ggml_tensor * token_shift,
@@ -784,7 +761,6 @@ struct llm_graph_context {
784761 //
785762
786763 void build_pooling (
787- ggml_cgraph * gf,
788764 ggml_tensor * cls,
789765 ggml_tensor * cls_b,
790766 ggml_tensor * cls_out,