Skip to content

Commit 1bc9dcb

Browse files
committed
Qualcomm AI Engine Direct - Enable Lookahead Decoding
Summary:
- Add new eval_mode: lookahead
- Add three arguments: ngram, window, gcap
- Add lhd_token_generator
1 parent aed9c7e commit 1bc9dcb

File tree

12 files changed

+767
-66
lines changed

12 files changed

+767
-66
lines changed

examples/qualcomm/oss_scripts/llama/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ list(
3636
${CMAKE_CURRENT_LIST_DIR}/runner/token_generator.h
3737
${CMAKE_CURRENT_LIST_DIR}/runner/imem_alloc.h
3838
${CMAKE_CURRENT_LIST_DIR}/runner/client_mem.h
39+
${CMAKE_CURRENT_LIST_DIR}/runner/lhd_token_generator.cpp
40+
${CMAKE_CURRENT_LIST_DIR}/runner/lhd_token_generator.h
3941
${CMAKE_CURRENT_LIST_DIR}/runner/rpc_mem.cpp
4042
${CMAKE_CURRENT_LIST_DIR}/runner/rpc_mem.h
4143
${CMAKE_CURRENT_LIST_DIR}/runner/kv_manager.cpp

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 69 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import getpass
1212
import json
1313
import logging
14+
import math
1415
import os
1516
import subprocess
1617
import sys
@@ -90,6 +91,12 @@
9091
logging.getLogger().setLevel(logging.INFO)
9192

9293

94+
def next_power_of_two(n):
95+
if n == 0:
96+
return 1
97+
return 2 ** math.ceil(math.log2(n))
98+
99+
93100
def smart_mask_updater(
94101
ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
95102
):
@@ -531,6 +538,28 @@ def compile(args, pte_filename, tokenizer):
531538
use_i64_token=use_i64_token,
532539
)
533540
)
541+
elif args.model_mode == "lookahead":
542+
llama_instance_list.append(
543+
LlamaModel(
544+
kv_config,
545+
# To get better performance, we round up to the nearest power of 2.
546+
ar_len=next_power_of_two(
547+
(args.window + args.gcap) * (args.ngram - 1)
548+
),
549+
output_new_cache_only=True,
550+
output_cache=True,
551+
use_i64_token=use_i64_token,
552+
)
553+
)
554+
llama_instance_list.append(
555+
LlamaModel(
556+
prefill_config,
557+
ar_len=args.prefill_ar_len,
558+
output_new_cache_only=True,
559+
output_cache=True,
560+
use_i64_token=use_i64_token,
561+
)
562+
)
534563
else:
535564
raise RuntimeError(f"Unknown model_mode: {args.model_mode}.")
536565

@@ -630,8 +659,8 @@ def permute(w, heads):
630659
tokenizer=tokenizer,
631660
custom_annotations=custom_annotations,
632661
)
633-
# If hybrid mode, we store kv output quant_attrs and apply to prefill output quant_attrs later
634-
if i == 0 and args.model_mode == "hybrid":
662+
# If hybrid and lookahead mode, we store kv output quant_attrs and apply to prefill output quant_attrs later
663+
if i == 0 and args.model_mode in ["hybrid", "lookahead"]:
635664
output_indices = 0
636665
for node in llama_instance.llama_graph_module.graph.nodes:
637666
if node.op == "output":
@@ -673,7 +702,7 @@ def permute(w, heads):
673702
shared_buffer=args.shared_buffer,
674703
)
675704
quant_attrs = llama_instance_list[0].get_quant_attrs()
676-
elif args.model_mode == "hybrid":
705+
elif args.model_mode in ["hybrid", "lookahead"]:
677706
sample_inputs_list = [
678707
llama_instace.inputs for llama_instace in llama_instance_list
679708
]
@@ -763,6 +792,8 @@ def inference(args, pte_filename, runtime_tokenizer_path, pre_gen_pte=""):
763792
eval_mode = 0
764793
elif args.model_mode == "hybrid":
765794
eval_mode = 1
795+
elif args.model_mode == "lookahead":
796+
eval_mode = 2
766797
else:
767798
raise RuntimeError(f"Unknown model_mode: {args.model_mode}.")
768799

@@ -836,6 +867,9 @@ def post_process():
836867
"--output_path outputs/outputs.txt",
837868
f"--performance_output_path {performance_output_path}",
838869
f"--kv_updater {'SmartMask' if args.kv_updater == smart_mask_updater else 'ShiftPointer'}",
870+
f"--window {args.window}",
871+
f"--gcap {args.gcap}",
872+
f"--ngram {args.ngram}",
839873
runner_args,
840874
]
841875
)
@@ -975,9 +1009,9 @@ def _build_parser():
9751009

9761010
parser.add_argument(
9771011
"--model_mode",
978-
help="Export and inference kv mode or hybrid mode",
1012+
help="Export and inference kv mode, hybrid mode, or lookahead decoding mode",
9791013
default="kv",
980-
choices=["kv", "hybrid"],
1014+
choices=["kv", "hybrid", "lookahead"],
9811015
type=str,
9821016
)
9831017

@@ -990,7 +1024,7 @@ def _build_parser():
9901024

9911025
parser.add_argument(
9921026
"--prefill_ar_len",
993-
help="The auto-regression (AR) length determines the number of tokens to consume and the number of logits to produce. Use this option to process the prompt and generate the key-value (kv) cache, which serves as a prompt processor for hybrid mode.",
1027+
help="The auto-regression (AR) length determines the number of tokens to consume and the number of logits to produce. Use this option to process the prompt and generate the key-value (kv) cache, which serves as a prompt processor for hybrid and lookahead mode.",
9941028
default=32,
9951029
type=int,
9961030
)
@@ -1011,6 +1045,27 @@ def _build_parser():
10111045
help="Fallback to cpu embedding operator and type of embedding quantization, '<bitwidth>,<groupsize>', e.g., '4,32'.",
10121046
)
10131047

1048+
parser.add_argument(
1049+
"--ngram",
1050+
help="Represents the size of the n-grams used in the lookahead process.",
1051+
default=5,
1052+
type=int,
1053+
)
1054+
1055+
parser.add_argument(
1056+
"--window",
1057+
help="Determines how many future tokens the algorithm attempts to predict in each step.",
1058+
default=8,
1059+
type=int,
1060+
)
1061+
1062+
parser.add_argument(
1063+
"--gcap",
1064+
help="Represents the maximum number of speculations or candidate n-grams that the algorithm considers in each step for verification. It balances the trade-off between computation efficiency and exploring more possibilities.",
1065+
default=8,
1066+
type=int,
1067+
)
1068+
10141069
parser.add_argument("-v", "--verbose", action="store_true")
10151070

10161071
return parser
@@ -1027,6 +1082,14 @@ def export_llama(args) -> None:
10271082
args.max_seq_len >= args.prefill_ar_len
10281083
), "Please ensure max_seq_len is >= prefill_ar_len"
10291084
pte_filename = "hybrid_llama_qnn"
1085+
elif args.model_mode == "lookahead":
1086+
assert (
1087+
args.max_seq_len >= args.prefill_ar_len
1088+
), "Please ensure max_seq_len is >= prefill_ar_len"
1089+
assert args.max_seq_len > next_power_of_two(
1090+
(args.window + args.gcap) * (args.ngram - 1)
1091+
), "Please ensure max_seq_len is > next_power_of_two((args.window + args.gcap) * (args.ngram - 1))"
1092+
pte_filename = "lookahead_llama_qnn"
10301093
else:
10311094
raise RuntimeError(f"Unknown model_mode: {args.model_mode}.")
10321095

examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,24 @@ DEFINE_int32(
5353
DEFINE_int32(
5454
eval_mode,
5555
0,
56-
"0: TokenGenerator(kv) / 1: HybridMode (prefill+kv)");
56+
"0: TokenGenerator(kv) / 1: HybridMode (prefill+kv) / 2: Lookahead Decoding");
5757
DEFINE_string(
5858
kv_updater,
59-
"How to update kv cache. Choose between SmartMask and ShiftPointer",
60-
"SmartMask");
59+
"SmartMask",
60+
"How to update kv cache. Choose between SmartMask and ShiftPointer");
6161
DEFINE_int32(num_iters, 1, "total num of iterations to run.");
62+
DEFINE_int32(
63+
ngram,
64+
0,
65+
"[Lookahead Decoding] Represents the size of the n-grams used in the lookahead process.");
66+
DEFINE_int32(
67+
window,
68+
0,
69+
"[Lookahead Decoding] Determines how many future tokens the algorithm attempts to predict in each step.");
70+
DEFINE_int32(
71+
gcap,
72+
0,
73+
"[Lookahead Decoding] Represents the maximum number of speculations or candidate n-grams that the algorithm considers in each step for verification. It balances the trade-off between computation efficiency and exploring more possibilities.");
6274

6375
std::vector<std::string> CollectPrompts(int argc, char** argv) {
6476
// Collect all prompts from command line, example usage:
@@ -111,7 +123,10 @@ int main(int argc, char** argv) {
111123
FLAGS_performance_output_path.c_str(),
112124
FLAGS_temperature,
113125
FLAGS_eval_mode,
114-
FLAGS_kv_updater);
126+
FLAGS_kv_updater,
127+
FLAGS_ngram,
128+
FLAGS_window,
129+
FLAGS_gcap);
115130
auto llama_version = runner.get_llama_version();
116131
std::vector<char> buf;
117132
buf.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char

examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp

Lines changed: 66 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ void KVManager::init_attention_mask(
5151
int32_t ar_len,
5252
int32_t n_past) {
5353
ET_CHECK_MSG(
54-
attention_map.size() == ar_len,
54+
attention_map.size() <= ar_len,
5555
"The size of attention_map (%zu) doesn't match with ar_len (%d)",
5656
attention_map.size(),
5757
ar_len);
@@ -197,9 +197,11 @@ void KVManager::init_cache(IMemAlloc* buffer_manager, int32_t ar_len) {
197197
? 0
198198
: metadata_.max_cache_len - (metadata_.context_len - cur_ar_len_);
199199
v_cache_[layer][head].buffer = single_layer_v_cache +
200-
head * single_head_size_in + cache_gap * metadata_.head_dim;
201-
v_cache_[layer][head].output_buffer =
202-
single_layer_v_cache + (head + 1) * single_head_size_in;
200+
head * metadata_.head_dim * metadata_.context_len +
201+
cache_gap * metadata_.head_dim;
202+
v_cache_[layer][head].output_buffer = single_layer_v_cache +
203+
head * metadata_.head_dim * metadata_.context_len +
204+
single_head_size_in;
203205
}
204206
}
205207
break;
@@ -311,21 +313,29 @@ bool KVManager::update_cache_tensor(
311313
return updated;
312314
}
313315

314-
void KVManager::update_cache(int32_t ar_len, int32_t n_past, int32_t n_update) {
316+
void KVManager::update_cache(
317+
int32_t ar_len,
318+
int32_t n_past,
319+
int32_t n_update,
320+
const std::vector<bool>& selected) {
315321
ET_CHECK_MSG(
316322
cur_ar_len_ == ar_len,
317323
"Current AR length (%d) is not matched with target AR length (%d). Please rearrange cache first.",
318324
cur_ar_len_,
319325
ar_len);
320326
for (int layer = 0; layer < metadata_.num_layers; ++layer) {
321327
for (int head = 0; head < metadata_.num_heads; ++head) {
322-
update_key(k_cache_[layer][head], n_past, n_update);
323-
update_value(v_cache_[layer][head], n_past, n_update);
328+
update_key(k_cache_[layer][head], n_past, n_update, selected);
329+
update_value(v_cache_[layer][head], n_past, n_update, selected);
324330
}
325331
}
326332
}
327333

328-
void KVManager::update_key(KVCache& k_cache, int32_t n_past, int32_t n_update) {
334+
void KVManager::update_key(
335+
KVCache& k_cache,
336+
int32_t n_past,
337+
int32_t n_update,
338+
const std::vector<bool>& selected) {
329339
uint8_t* write_ptr = k_cache.buffer;
330340
uint8_t* read_ptr = k_cache.output_buffer;
331341
const int32_t copy_size = n_update * sizeof(uint8_t);
@@ -340,22 +350,35 @@ void KVManager::update_key(KVCache& k_cache, int32_t n_past, int32_t n_update) {
340350
write_ptr += iter_size + past_size;
341351
if (kv_updater_ == KVManagerMode::SMART_MASK)
342352
write_ptr += past_size;
343-
344-
for (int i = 0; i < n_iter; ++i) {
345-
std::memcpy(write_ptr, read_ptr, copy_size);
346-
write_ptr += iter_size;
347-
read_ptr += out_size;
353+
if (selected.empty()) {
354+
for (int i = 0; i < n_iter; ++i) {
355+
std::memcpy(write_ptr, read_ptr, copy_size);
356+
write_ptr += iter_size;
357+
read_ptr += out_size;
358+
}
359+
} else {
360+
std::vector<int32_t> true_indices(n_update);
361+
for (int i = 0, j = 0; i < selected.size() && j < n_update; ++i) {
362+
if (selected[i]) {
363+
true_indices[j++] = i;
364+
}
365+
}
366+
for (int i = 0; i < n_iter; ++i) {
367+
auto wp = write_ptr, rp = read_ptr;
368+
for (auto ind : true_indices) {
369+
*wp++ = rp[ind];
370+
}
371+
write_ptr += iter_size;
372+
read_ptr += out_size;
373+
}
348374
}
349375
}
350376

351377
void KVManager::update_value(
352378
KVCache& v_cache,
353379
int32_t n_past,
354-
int32_t n_update) {
355-
// Value cache doesn't need to copy for SHIFT_POINTER mode
356-
if (kv_updater_ == KVManagerMode::SHIFT_POINTER)
357-
return;
358-
380+
int32_t n_update,
381+
const std::vector<bool>& selected) {
359382
uint8_t* write_ptr = v_cache.buffer;
360383
uint8_t* read_ptr = v_cache.output_buffer;
361384
const int32_t copy_size = n_update * metadata_.head_dim * sizeof(uint8_t);
@@ -364,7 +387,31 @@ void KVManager::update_value(
364387
if (kv_updater_ == KVManagerMode::SMART_MASK)
365388
write_ptr += past_size;
366389

367-
std::memcpy(write_ptr, read_ptr, copy_size);
390+
// Update the value cache for lookahead decoding in SHIFT_POINTER mode
391+
if (kv_updater_ == KVManagerMode::SHIFT_POINTER) {
392+
read_ptr += past_size;
393+
write_ptr = read_ptr;
394+
}
395+
396+
if (selected.empty()) {
397+
// In general, value cache doesn't need to copy for SHIFT_POINTER mode
398+
if (kv_updater_ == KVManagerMode::SHIFT_POINTER)
399+
return;
400+
std::memcpy(write_ptr, read_ptr, copy_size);
401+
} else {
402+
int32_t update_times = n_update;
403+
auto wp = write_ptr, rp = read_ptr;
404+
for (auto sel : selected) {
405+
if (sel) {
406+
std::memcpy(wp, rp, metadata_.head_dim * sizeof(uint8_t));
407+
wp += metadata_.head_dim;
408+
update_times--;
409+
if (update_times == 0)
410+
break;
411+
}
412+
rp += metadata_.head_dim;
413+
}
414+
}
368415
}
369416

370417
} // namespace example

examples/qualcomm/oss_scripts/llama/runner/kv_manager.h

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,13 @@ class KVManager {
120120
* @param ar_len Length of input tokens.
121121
* @param n_past Number of past elements in the cache.
122122
* @param n_update Number of elements to be updated.
123+
* @param selected Indicate which position to be updated
123124
*/
124-
void update_cache(int32_t ar_len, int32_t n_past, int32_t n_update);
125+
void update_cache(
126+
int32_t ar_len,
127+
int32_t n_past,
128+
int32_t n_update,
129+
const std::vector<bool>& selected);
125130

126131
const std::vector<std::vector<KVCache>>& get_k_cache_() const {
127132
return k_cache_;
@@ -138,8 +143,16 @@ class KVManager {
138143
// Helper functions to rearrange and update key and value caches
139144
void rearrange_key(KVCache& k_cache, int32_t ar_len_dst);
140145
void rearrange_value(KVCache& v_cache, int32_t ar_len_dst);
141-
void update_key(KVCache& k_cache, int32_t n_past, int32_t n_update);
142-
void update_value(KVCache& v_cache, int32_t n_past, int32_t n_update);
146+
void update_key(
147+
KVCache& k_cache,
148+
int32_t n_past,
149+
int32_t n_update,
150+
const std::vector<bool>& selected);
151+
void update_value(
152+
KVCache& v_cache,
153+
int32_t n_past,
154+
int32_t n_update,
155+
const std::vector<bool>& selected);
143156
KVManagerMode kv_updater_;
144157

145158
// metadata

0 commit comments

Comments (0)