Skip to content

Qualcomm AI Engine Direct - Enable Lookahead Decoding #11437

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions examples/qualcomm/oss_scripts/llama/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ list(
${CMAKE_CURRENT_LIST_DIR}/runner/token_generator.h
${CMAKE_CURRENT_LIST_DIR}/runner/imem_alloc.h
${CMAKE_CURRENT_LIST_DIR}/runner/client_mem.h
${CMAKE_CURRENT_LIST_DIR}/runner/lhd_token_generator.cpp
${CMAKE_CURRENT_LIST_DIR}/runner/lhd_token_generator.h
${CMAKE_CURRENT_LIST_DIR}/runner/rpc_mem.cpp
${CMAKE_CURRENT_LIST_DIR}/runner/rpc_mem.h
${CMAKE_CURRENT_LIST_DIR}/runner/kv_manager.cpp
Expand Down
18 changes: 15 additions & 3 deletions examples/qualcomm/oss_scripts/llama/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
This file provides you the instructions to run LLAMA model with different parameters via Qualcomm HTP backend. We currently support the following models:
1. LLAMA2 Stories 110M
2. LLAMA3.2 1B
3. LLAMA3.2 3B (WIP)
3. LLAMA3.2 3B

We offer the following modes to execute the model:

KV Cache Mode: In KV Cache mode, the model takes in a single previous token and generates the next predicted token along with its KV cache. It is efficient for generating subsequent tokens after the initial prompt.
- KV Cache Mode: In KV Cache mode, the model takes in a single previous token and generates the next predicted token along with its KV cache. It is efficient for generating subsequent tokens after the initial prompt.

Hybrid Mode: Hybrid mode leverages the strengths of both AR-N model and KV cache modes to optimize token generation speed. Initially, it uses AR-N model to efficiently generate the prompt's key-value (KV) cache. Then, the mode switches to KV cache mode, which excels at generating subsequent tokens.
- Hybrid Mode: Hybrid mode leverages the strengths of both AR-N model and KV cache modes to optimize token generation speed. Initially, it uses AR-N model to efficiently generate the prompt's key-value (KV) cache. Then, the mode switches to KV cache mode, which excels at generating subsequent tokens.
- AR-N model: The auto-regression (AR) length determines the number of tokens to consume and the number of logits to produce. Use it to process the prompt and generate the key-value (kv) cache, which serves as a prompt processor in hybrid mode.
- Prompt processing with AR-N model:
<figure>
Expand All @@ -19,6 +19,7 @@ Hybrid Mode: Hybrid mode leverages the strengths of both AR-N model and KV cache
</figcaption>
</figure>

- Lookahead Mode: Lookahead Mode introduces [lookahead decoding](https://arxiv.org/abs/2402.02057) and uses AR-N model to process prompt to enhance token generation speed. While decoding multiple tokens in a single step is infeasible, an LLM can generate multiple guess tokens in parallel. These guess tokens may fit into future parts of the generated sequence. The lookahead decoder generates and verifies these guess tokens, integrating them into the sequence if suitable. In some cases, it can obtain more than one token in a single step. Result is lossless.

## Instructions
### Note
Expand Down Expand Up @@ -127,3 +128,14 @@ You can select the KV Cache update mechanism at runtime by setting the `KV_UPDAT
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --kv_updator ${KV_UPDATER}
```

You can choose the lookahead mode to enhance decoding speed. To use this mode, you need to specify the following parameters:
- `--ngram` (N-gram size): Represents the size of the n-grams used in the lookahead process.
- `--window` (window size): Determines how many future tokens the algorithm attempts to predict in each step.
- `--gcap` (Verification candidates): Represents the maximum number of speculations or candidate n-grams that the algorithm considers in each step for verification. It balances the trade-off between computation efficiency and exploring more possibilities.

For more details, please refer to the paper ["Break the Sequential Dependency of LLM Inference Using Lookahead Decoding"](https://arxiv.org/abs/2402.02057)

```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode lookahead --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --ngram 3 --window 2 --gcap 2
```
75 changes: 69 additions & 6 deletions examples/qualcomm/oss_scripts/llama/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import getpass
import json
import logging
import math
import os
import subprocess
import sys
Expand Down Expand Up @@ -90,6 +91,12 @@
logging.getLogger().setLevel(logging.INFO)


def next_power_of_two(n):
if n == 0:
return 1
return 2 ** math.ceil(math.log2(n))


def smart_mask_updater(
ar_len, atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
):
Expand Down Expand Up @@ -531,6 +538,28 @@ def compile(args, pte_filename, tokenizer):
use_i64_token=use_i64_token,
)
)
elif args.model_mode == "lookahead":
llama_instance_list.append(
LlamaModel(
kv_config,
# To get better performance, we round up to the nearest power of 2.
ar_len=next_power_of_two(
(args.window + args.gcap) * (args.ngram - 1)
),
output_new_cache_only=True,
output_cache=True,
use_i64_token=use_i64_token,
)
)
llama_instance_list.append(
LlamaModel(
prefill_config,
ar_len=args.prefill_ar_len,
output_new_cache_only=True,
output_cache=True,
use_i64_token=use_i64_token,
)
)
else:
raise RuntimeError(f"Unknown model_mode: {args.model_mode}.")

Expand Down Expand Up @@ -630,8 +659,8 @@ def permute(w, heads):
tokenizer=tokenizer,
custom_annotations=custom_annotations,
)
# If hybrid mode, we store kv output quant_attrs and apply to prefill output quant_attrs later
if i == 0 and args.model_mode == "hybrid":
# If hybrid and lookahead mode, we store kv output quant_attrs and apply to prefill output quant_attrs later
if i == 0 and args.model_mode in ["hybrid", "lookahead"]:
output_indices = 0
for node in llama_instance.llama_graph_module.graph.nodes:
if node.op == "output":
Expand Down Expand Up @@ -673,7 +702,7 @@ def permute(w, heads):
shared_buffer=args.shared_buffer,
)
quant_attrs = llama_instance_list[0].get_quant_attrs()
elif args.model_mode == "hybrid":
elif args.model_mode in ["hybrid", "lookahead"]:
sample_inputs_list = [
llama_instace.inputs for llama_instace in llama_instance_list
]
Expand Down Expand Up @@ -759,6 +788,8 @@ def inference(args, pte_filename, runtime_tokenizer_path, pre_gen_pte=""):
eval_mode = 0
elif args.model_mode == "hybrid":
eval_mode = 1
elif args.model_mode == "lookahead":
eval_mode = 2
else:
raise RuntimeError(f"Unknown model_mode: {args.model_mode}.")

Expand Down Expand Up @@ -832,6 +863,9 @@ def post_process():
"--output_path outputs/outputs.txt",
f"--performance_output_path {performance_output_path}",
f"--kv_updater {'SmartMask' if args.kv_updater == smart_mask_updater else 'ShiftPointer'}",
f"--window {args.window}",
f"--gcap {args.gcap}",
f"--ngram {args.ngram}",
runner_args,
]
)
Expand Down Expand Up @@ -971,9 +1005,9 @@ def _build_parser():

parser.add_argument(
"--model_mode",
help="Export and inference kv mode or hybrid mode",
help="Export and inference kv mode, hybrid mode, or lookahead decoding mode",
default="kv",
choices=["kv", "hybrid"],
choices=["kv", "hybrid", "lookahead"],
type=str,
)

Expand All @@ -986,7 +1020,7 @@ def _build_parser():

parser.add_argument(
"--prefill_ar_len",
help="The auto-regression (AR) length determines the number of tokens to consume and the number of logits to produce. Use this option to process the prompt and generate the key-value (kv) cache, which serves as a prompt processor for hybrid mode.",
help="The auto-regression (AR) length determines the number of tokens to consume and the number of logits to produce. Use this option to process the prompt and generate the key-value (kv) cache, which serves as a prompt processor for hybrid and lookahead mode.",
default=32,
type=int,
)
Expand All @@ -1007,6 +1041,27 @@ def _build_parser():
help="Fallback to cpu embedding operator and type of embedding quantization, '<bitwidth>,<groupsize>', e.g., '4,32'.",
)

parser.add_argument(
"--ngram",
help="Represents the size of the n-grams used in the lookahead process.",
default=5,
type=int,
)

parser.add_argument(
"--window",
help="Determines how many future tokens the algorithm attempts to predict in each step.",
default=8,
type=int,
)

parser.add_argument(
"--gcap",
help="Represents the maximum number of speculations or candidate n-grams that the algorithm considers in each step for verification. It balances the trade-off between computation efficiency and exploring more possibilities.",
default=8,
type=int,
)

parser.add_argument("-v", "--verbose", action="store_true")

return parser
Expand All @@ -1023,6 +1078,14 @@ def export_llama(args) -> None:
args.max_seq_len >= args.prefill_ar_len
), "Please ensure max_seq_len is >= prefill_ar_len"
pte_filename = "hybrid_llama_qnn"
elif args.model_mode == "lookahead":
assert (
args.max_seq_len >= args.prefill_ar_len
), "Please ensure max_seq_len is >= prefill_ar_len"
assert args.max_seq_len > next_power_of_two(
(args.window + args.gcap) * (args.ngram - 1)
), "Please ensure max_seq_len is > next_power_of_two((args.window + args.gcap) * (args.ngram - 1))"
pte_filename = "lookahead_llama_qnn"
else:
raise RuntimeError(f"Unknown model_mode: {args.model_mode}.")

Expand Down
23 changes: 19 additions & 4 deletions examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,24 @@ DEFINE_int32(
DEFINE_int32(
eval_mode,
0,
"0: TokenGenerator(kv) / 1: HybridMode (prefill+kv)");
"0: TokenGenerator(kv) / 1: HybridMode (prefill+kv) / 2: Lookahead Decoding");
DEFINE_string(
kv_updater,
"How to update kv cache. Choose between SmartMask and ShiftPointer",
"SmartMask");
"SmartMask",
"How to update kv cache. Choose between SmartMask and ShiftPointer");
DEFINE_int32(num_iters, 1, "total num of iterations to run.");
DEFINE_int32(
ngram,
0,
"[Lookahead Decoding] Represents the size of the n-grams used in the lookahead process.");
DEFINE_int32(
window,
0,
"[Lookahead Decoding] Determines how many future tokens the algorithm attempts to predict in each step.");
DEFINE_int32(
gcap,
0,
"[Lookahead Decoding] Represents the maximum number of speculations or candidate n-grams that the algorithm considers in each step for verification. It balances the trade-off between computation efficiency and exploring more possibilities.");

std::vector<std::string> CollectPrompts(int argc, char** argv) {
// Collect all prompts from command line, example usage:
Expand Down Expand Up @@ -111,7 +123,10 @@ int main(int argc, char** argv) {
FLAGS_performance_output_path.c_str(),
FLAGS_temperature,
FLAGS_eval_mode,
FLAGS_kv_updater);
FLAGS_kv_updater,
FLAGS_ngram,
FLAGS_window,
FLAGS_gcap);
auto llama_version = runner.get_llama_version();
std::vector<char> buf;
buf.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char
Expand Down
85 changes: 66 additions & 19 deletions examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ void KVManager::init_attention_mask(
int32_t ar_len,
int32_t n_past) {
ET_CHECK_MSG(
attention_map.size() == ar_len,
attention_map.size() <= ar_len,
"The size of attention_map (%zu) doesn't match with ar_len (%d)",
attention_map.size(),
ar_len);
Expand Down Expand Up @@ -197,9 +197,11 @@ void KVManager::init_cache(IMemAlloc* buffer_manager, int32_t ar_len) {
? 0
: metadata_.max_cache_len - (metadata_.context_len - cur_ar_len_);
v_cache_[layer][head].buffer = single_layer_v_cache +
head * single_head_size_in + cache_gap * metadata_.head_dim;
v_cache_[layer][head].output_buffer =
single_layer_v_cache + (head + 1) * single_head_size_in;
head * metadata_.head_dim * metadata_.context_len +
cache_gap * metadata_.head_dim;
v_cache_[layer][head].output_buffer = single_layer_v_cache +
head * metadata_.head_dim * metadata_.context_len +
single_head_size_in;
}
}
break;
Expand Down Expand Up @@ -311,21 +313,29 @@ bool KVManager::update_cache_tensor(
return updated;
}

void KVManager::update_cache(int32_t ar_len, int32_t n_past, int32_t n_update) {
void KVManager::update_cache(
int32_t ar_len,
int32_t n_past,
int32_t n_update,
const std::vector<bool>& selected) {
ET_CHECK_MSG(
cur_ar_len_ == ar_len,
"Current AR length (%d) is not matched with target AR length (%d). Please rearrange cache first.",
cur_ar_len_,
ar_len);
for (int layer = 0; layer < metadata_.num_layers; ++layer) {
for (int head = 0; head < metadata_.num_heads; ++head) {
update_key(k_cache_[layer][head], n_past, n_update);
update_value(v_cache_[layer][head], n_past, n_update);
update_key(k_cache_[layer][head], n_past, n_update, selected);
update_value(v_cache_[layer][head], n_past, n_update, selected);
}
}
}

void KVManager::update_key(KVCache& k_cache, int32_t n_past, int32_t n_update) {
void KVManager::update_key(
KVCache& k_cache,
int32_t n_past,
int32_t n_update,
const std::vector<bool>& selected) {
uint8_t* write_ptr = k_cache.buffer;
uint8_t* read_ptr = k_cache.output_buffer;
const int32_t copy_size = n_update * sizeof(uint8_t);
Expand All @@ -340,22 +350,35 @@ void KVManager::update_key(KVCache& k_cache, int32_t n_past, int32_t n_update) {
write_ptr += iter_size + past_size;
if (kv_updater_ == KVManagerMode::SMART_MASK)
write_ptr += past_size;

for (int i = 0; i < n_iter; ++i) {
std::memcpy(write_ptr, read_ptr, copy_size);
write_ptr += iter_size;
read_ptr += out_size;
if (selected.empty()) {
for (int i = 0; i < n_iter; ++i) {
std::memcpy(write_ptr, read_ptr, copy_size);
write_ptr += iter_size;
read_ptr += out_size;
}
} else {
std::vector<int32_t> true_indices(n_update);
for (int i = 0, j = 0; i < selected.size() && j < n_update; ++i) {
if (selected[i]) {
true_indices[j++] = i;
}
}
for (int i = 0; i < n_iter; ++i) {
auto wp = write_ptr, rp = read_ptr;
for (auto ind : true_indices) {
*wp++ = rp[ind];
}
write_ptr += iter_size;
read_ptr += out_size;
}
}
}

void KVManager::update_value(
KVCache& v_cache,
int32_t n_past,
int32_t n_update) {
// Value cache doesn't need to copy for SHIFT_POINTER mode
if (kv_updater_ == KVManagerMode::SHIFT_POINTER)
return;

int32_t n_update,
const std::vector<bool>& selected) {
uint8_t* write_ptr = v_cache.buffer;
uint8_t* read_ptr = v_cache.output_buffer;
const int32_t copy_size = n_update * metadata_.head_dim * sizeof(uint8_t);
Expand All @@ -364,7 +387,31 @@ void KVManager::update_value(
if (kv_updater_ == KVManagerMode::SMART_MASK)
write_ptr += past_size;

std::memcpy(write_ptr, read_ptr, copy_size);
// Update the value cache for lookahead decoding in SHIFT_POINTER mode
if (kv_updater_ == KVManagerMode::SHIFT_POINTER) {
read_ptr += past_size;
write_ptr = read_ptr;
}

if (selected.empty()) {
// In general, value cache doesn't need to copy for SHIFT_POINTER mode
if (kv_updater_ == KVManagerMode::SHIFT_POINTER)
return;
std::memcpy(write_ptr, read_ptr, copy_size);
} else {
int32_t update_times = n_update;
auto wp = write_ptr, rp = read_ptr;
for (auto sel : selected) {
if (sel) {
std::memcpy(wp, rp, metadata_.head_dim * sizeof(uint8_t));
wp += metadata_.head_dim;
update_times--;
if (update_times == 0)
break;
}
rp += metadata_.head_dim;
}
}
}

} // namespace example
Loading
Loading