Adding Compute Context Length (CCL) specialization #388

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 2 commits into base: main

3 changes: 2 additions & 1 deletion QEfficient/base/modeling_qeff.py
@@ -219,6 +219,7 @@ def _compile(
self,
onnx_path: Optional[str] = None,
compile_dir: Optional[str] = None,
comp_ctx_lengths: Optional[List[int]] = None,
*,
mxint8_kv_cache: bool = False,
specializations: Optional[List[Dict[str, int]]] = None,
@@ -247,7 +248,7 @@
- convert_to_fp16=True -> -convert-to-fp16
"""
if onnx_path is None and self.onnx_path is None:
self.export()
self.export(comp_ctx_lengths)

onnx_path = Path(onnx_path or self.onnx_path)
compile_dir = Path(compile_dir or onnx_path.parent)
10 changes: 10 additions & 0 deletions QEfficient/cloud/infer.py
@@ -102,6 +102,7 @@ def main(
full_batch_size: Optional[int] = None,
prompt_len: int = 32,
ctx_len: int = 128,
comp_ctx_lengths: Optional[List[int]] = None,
generation_len: Optional[int] = None,
mxfp6: bool = False,
mxint8: bool = False,
@@ -183,6 +184,7 @@
_ = qeff_model.compile(
prefill_seq_len=prompt_len,
ctx_len=ctx_len,
comp_ctx_lengths=comp_ctx_lengths,
num_cores=num_cores,
mxfp6_matmul=mxfp6,
aic_enable_depth_first=aic_enable_depth_first,
@@ -205,6 +207,7 @@ def main(
qeff_model=qeff_model,
model_name=model_name,
prompt=prompt,
comp_ctx_lengths=comp_ctx_lengths,
image_url=image_url,
image_path=image_path,
device_group=device_group,
@@ -225,6 +228,7 @@ def main(
prompts=prompt,
device_id=device_group,
prompt=prompt,
comp_ctx_lengths=comp_ctx_lengths,
prompts_txt_file_path=prompts_txt_file_path,
generation_len=generation_len,
)
@@ -253,6 +257,12 @@ def main(
"--prompt-len", "--prompt_len", default=32, type=int, help="Sequence length for text generation."
)
parser.add_argument("--ctx-len", "--ctx_len", default=128, type=int, help="Context length for text generation.")
parser.add_argument(
    "--comp-ctx-lengths",
    "--comp_ctx_lengths",
    type=lambda comp_ctx_lengths: [int(x) for x in comp_ctx_lengths.strip("[]").split(",")],
    help="Compute context lengths for text generation, comma-separated, e.g. 512,1024,2048 or [512,1024,2048].",
)
parser.add_argument(
"--mxfp6",
"--mxfp6_matmul",
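For reference, a minimal sketch of how the `--comp-ctx-lengths` value is parsed by the `type=` lambda above; `parse_ccl` is a hypothetical stand-in for that lambda, and the values are illustrative:

```python
# Hypothetical stand-in for the --comp-ctx-lengths type= lambda above.
parse_ccl = lambda s: [int(x) for x in s.strip("[]").split(",")]

assert parse_ccl("[512,1024,2048]") == [512, 1024, 2048]
assert parse_ccl("512,1024,2048") == [512, 1024, 2048]  # brackets are optional
```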
16 changes: 11 additions & 5 deletions QEfficient/customop/ctx_scatter_gather.py
@@ -115,8 +115,14 @@ def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value:


@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
def CtxGather(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT:
ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[3], axes=[0]))
def CtxGather(
data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32
) -> onnxscript.FLOAT:
# Create a shape tensor based on comp_ctx_len
shape_tensor = ops.Concat(ops.Shape(data)[:2], ops.Reshape(comp_ctx_len, [1]), axis=0)

# Directly use the shape tensor without validation
ctx_indices = ops.Expand(ctx_indices, shape_tensor)
ctx_indices = ops.Unsqueeze(ctx_indices, [-1])
return ops.GatherND(data, ctx_indices, batch_dims=2)

Expand All @@ -127,7 +133,7 @@ class CtxGatherFunc(torch.autograd.Function):
"""

@staticmethod
def forward(data: torch.Tensor, ctx_indices: torch.Tensor):
def forward(data: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int):
batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1)
head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
return data[batch_indices, head_indices, ctx_indices]
@@ -137,5 +143,5 @@ def setup_context(ctx, inputs, outputs):
pass

@staticmethod
def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value:
return g.onnxscript_op(CtxGather, data, ctx_indices).setTypeAs(data)
def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int) -> torch.Value:
return g.onnxscript_op(CtxGather, data, ctx_indices, comp_ctx_len).setTypeAs(data)
13 changes: 8 additions & 5 deletions QEfficient/customop/ctx_scatter_gather_cb.py
@@ -97,11 +97,12 @@ def symbolic(

@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
def CtxGatherCB(
data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32
data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32
) -> onnxscript.FLOAT:
batch_size = ops.Gather(ops.Shape(batch_index), [0])
num_heads = ops.Gather(ops.Shape(data), [1])
ctx_len = ops.Gather(ops.Shape(data), [2])
# Use compute-context-length (CCL) instead of the full context length, so the gather and the subsequent attention computation both operate over CCL positions.
ctx_len = ops.Reshape(comp_ctx_len, [1])

# Expanded shape to create indices
zero = ops.Constant(value_ints=[0])
@@ -119,7 +120,7 @@

class CtxGatherFuncCB(torch.autograd.Function):
@staticmethod
def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor):
def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int):
batch_indices = batch_index.view(-1, 1, 1)
head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
return data[batch_indices, head_indices, ctx_indices]
@@ -129,8 +130,10 @@ def setup_context(ctx, inputs, outputs):
pass

@staticmethod
def symbolic(g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value) -> torch.Value:
return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices).setTypeAs(data)
def symbolic(
g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int
) -> torch.Value:
return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices, comp_ctx_len).setTypeAs(data)


@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
83 changes: 82 additions & 1 deletion QEfficient/generation/text_generation_inference.py
@@ -316,6 +316,7 @@ def cloud_ai_100_exec_kv(
prompts_txt_file_path: Optional[str] = None,
device_id: Optional[List[int]] = None,
generation_len: Optional[int] = None,
comp_ctx_lengths: Optional[List[int]] = None,
enable_debug_logs: bool = False,
stream: bool = True,
write_io_dir: Optional[str] = None,
@@ -368,6 +369,7 @@
qpc_path=qpc_path,
device_id=device_id,
ctx_len=ctx_len,
comp_ctx_lengths=comp_ctx_lengths,
enable_debug_logs=enable_debug_logs,
write_io_dir=write_io_dir,
full_batch_size=full_batch_size,
@@ -407,12 +409,14 @@
qpc_path: str,
full_batch_size: Optional[int] = None,
ctx_len: Optional[int] = None,
comp_ctx_lengths: Optional[List[int]] = None,
device_id: Optional[List[int]] = None,
enable_debug_logs: bool = False,
write_io_dir: Optional[str] = None,
is_tlm: Optional[int] = None,
) -> None:
self._ctx_len = ctx_len
self.comp_ctx_lengths = comp_ctx_lengths
self._write_io_dir = write_io_dir
self.is_tlm = is_tlm

@@ -724,6 +728,12 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
batch_lora_ids = [self._prompt_to_lora_id_mapping_prefill.popleft() for i in range(self.batch_size)]
inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1)

inputs["comp_ctx_lengths"] = np.random.rand(
self.comp_ctx_lengths[0] if self.comp_ctx_lengths is not None else self._ctx_len
)
buffers = {"comp_ctx_len_out": np.zeros(1)}
self._session.set_buffers(buffers)

for i in range(num_chunks):
chunk_inputs = inputs.copy()
chunk_inputs["input_ids"] = inputs["input_ids"][
@@ -741,6 +751,19 @@
generation_len,
)

def initialize_ccl(self, decode_inputs):
max_ccl_id = len(self.comp_ctx_lengths) - 1
max_position_id = np.max(decode_inputs["position_ids"])
ccl_id = 1
for i in range(1, len(self.comp_ctx_lengths)):
Contributor: Can we have a reversed list and pop out the last value if max_position_id < self.comp_ctx_lengths[-1]? That way we can avoid the loop.

Author: Why should we check against the last element? Each request can finish at a different position_id, and we need to find the most suitable CCL window to get the best performance. This for loop only happens at the end of a request, and its cost is on the order of len(comp_ctx_lengths), which cannot be more than a few values because of the compiler's limit on the number of specializations. (A standalone sketch of this window selection follows initialize_ccl below.)

if max_position_id < self.comp_ctx_lengths[i]:
ccl_id = i
break
buffers = {"comp_ctx_len_out": np.zeros(1)}
print(f"CCL: {self.comp_ctx_lengths[ccl_id]}")

return buffers, ccl_id, max_ccl_id
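
A standalone sketch of the window selection discussed in the thread above. select_ccl_bucket is a hypothetical helper, not part of the PR; the bucket sizes are illustrative, and unlike initialize_ccl (which starts its scan at index 1) it also considers the first bucket:

```python
import numpy as np

# Hypothetical helper: pick the first compute-context-length bucket that still
# covers the largest position id, falling back to the largest bucket.
def select_ccl_bucket(comp_ctx_lengths, position_ids):
    max_position_id = np.max(position_ids)
    for ccl_id, ccl in enumerate(comp_ctx_lengths):
        if max_position_id < ccl:
            return ccl_id, ccl
    return len(comp_ctx_lengths) - 1, comp_ctx_lengths[-1]

# With illustrative buckets [512, 1024, 2048]:
assert select_ccl_bucket([512, 1024, 2048], np.array([100, 200])) == (0, 512)
assert select_ccl_bucket([512, 1024, 2048], np.array([900])) == (1, 1024)
assert select_ccl_bucket([512, 1024, 2048], np.array([3000])) == (2, 2048)
```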

def run_continuous_batching_decode(self, prompt_queue, generation_len):
"""
Runs continuous batching decode for the given prompt queue and generation length.
@@ -771,6 +794,16 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
# Prepare decode inputs inputs.
decode_inputs = self.prepare_decode_inputs()

if self.comp_ctx_lengths is not None:
list_of_comp_ctx_lengths = [np.zeros(length) for length in self.comp_ctx_lengths]
buffers, ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs)
decode_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths[ccl_id]
self._session.set_buffers(buffers)
else:
decode_inputs["comp_ctx_lengths"] = np.zeros(self._ctx_len)
buffers = {"comp_ctx_len_out": np.zeros(1)}
self._session.set_buffers(buffers)

while prompt_queue or current_decode_ongoing.any():
outputs = self._session.run(decode_inputs)

@@ -808,6 +841,19 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
batch_id_map[decode_batch_id]
]

if self.comp_ctx_lengths is not None:
# Recalculate ccl_id based on position_ids
# Determine the maximum value of position_ids across all batch elements
max_position_id = np.max(decode_inputs["position_ids"])

# Update ccl_id and comp_ctx_lengths based on the maximum position id
ccl_id = 1
for i in range(1, len(self.comp_ctx_lengths)):
if max_position_id < self.comp_ctx_lengths[i]:
ccl_id = i
Contributor: Same as above.

break
decode_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths[ccl_id]

else:
current_decode_ongoing[decode_batch_id] = False
else:
@@ -818,6 +864,12 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
next_token_id[decode_batch_id, -1]
)

if self.comp_ctx_lengths is not None:
# Update ccl_id and comp_ctx_lengths based on the maximum position id
if decode_inputs["position_ids"][decode_batch_id, -1] >= self.comp_ctx_lengths[ccl_id] - 1:
ccl_id = min(ccl_id + 1, max_ccl_id)
decode_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths[ccl_id]

generated_id_current_index[decode_batch_id] += 1

return decode_pause_time
@@ -842,7 +894,25 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform
self._session.set_buffers({"logits": logits_out_placeholder})
finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id
num_token = 0

if self.comp_ctx_lengths is not None:
list_of_comp_ctx_lengths = [np.zeros(length) for length in self.comp_ctx_lengths]
buffers, ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs)
decode_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths[ccl_id]
self._session.set_buffers(buffers)
else:
decode_inputs["comp_ctx_lengths"] = np.zeros(self._ctx_len)
buffers = {"comp_ctx_len_out": np.zeros(1)}
self._session.set_buffers(buffers)

cache_index = np.max(decode_inputs["position_ids"])
for num_token in range(1, generation_len):
if self.comp_ctx_lengths is not None:
if cache_index >= self.comp_ctx_lengths[ccl_id] - 1:
ccl_id = min(ccl_id + 1, max_ccl_id)
decode_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths[ccl_id]

if streamer:
streamer.put(decode_inputs["input_ids"][0])
outputs = self._session.run(decode_inputs)
@@ -854,6 +924,7 @@
# Prepare inputs for next iteration
decode_inputs["input_ids"] = outputs["logits"].argmax(2)
decode_inputs["position_ids"][:, -1] += 1
cache_index += 1
self.generated_ids[:, num_token] = decode_inputs["input_ids"][:, -1]
finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id
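
A standalone sketch of how the run_decode loop above widens the compute window: once the cache index reaches the boundary of the current bucket, the next (larger) CCL is selected, up to the last bucket. The bucket sizes and step count here are illustrative, not from the PR:

```python
# Illustrative only: step through CCL buckets as decode progresses.
comp_ctx_lengths = [512, 1024, 2048]          # illustrative buckets
ccl_id, max_ccl_id = 0, len(comp_ctx_lengths) - 1
cache_index = 400                             # position after prefill (illustrative)

for _ in range(2000):                         # decode steps
    if cache_index >= comp_ctx_lengths[ccl_id] - 1:
        ccl_id = min(ccl_id + 1, max_ccl_id)  # widen the compute window
    # ... run one decode step attending over comp_ctx_lengths[ccl_id] cache positions ...
    cache_index += 1

assert ccl_id == max_ccl_id                   # window has grown to the largest bucket
```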

@@ -901,17 +972,27 @@ def __init__(
qpc_path: str,
full_batch_size: Optional[int] = None,
ctx_len: Optional[int] = None,
comp_ctx_lengths: Optional[List[int]] = None,
device_id: Optional[List[int]] = None,
enable_debug_logs: bool = False,
write_io_dir: Optional[str] = None,
is_tlm: bool = False,
) -> None:
self._qaic_model = QEffTextGenerationBase(
tokenizer, qpc_path, full_batch_size, ctx_len, device_id, enable_debug_logs, write_io_dir, is_tlm
tokenizer,
qpc_path,
full_batch_size,
ctx_len,
comp_ctx_lengths,
device_id,
enable_debug_logs,
write_io_dir,
is_tlm,
)
self._full_batch_size = self._qaic_model.full_batch_size
self._tokenizer = self._qaic_model.tokenizer
self._ctx_len = ctx_len
self.comp_ctx_lengths = comp_ctx_lengths
self._perf_metrics = None
self._prompt_queue = None
self._text_streamer = None
36 changes: 23 additions & 13 deletions QEfficient/transformers/cache_utils.py
@@ -91,6 +91,8 @@ def read_only(self, layer_idx, cache_kwargs):
k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
position_ids = cache_kwargs.get("position_ids")
batch_index = cache_kwargs.get("batch_index", None)
comp_ctx_len = cache_kwargs.get("CCL")

ctx_len = k_out.shape[2]
ctx_indices = torch.arange(ctx_len)[None, None, ...]
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
@@ -101,15 +103,19 @@
else:
invalid_idx_value = 0

ctx_indices = ctx_indices[:, :, :comp_ctx_len]
invalid_mask = ctx_indices > gather_limit

invalid_mask = invalid_mask[:, :, :comp_ctx_len]

ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)

if batch_index is not None:
k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices)
v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices)
k_out = CtxGatherFuncCB.apply(self.key_cache[layer_idx], batch_index, ctx_indices, comp_ctx_len)
v_out = CtxGatherFuncCB.apply(self.value_cache[layer_idx], batch_index, ctx_indices, comp_ctx_len)
else:
k_out = CtxGatherFunc.apply(k_out, ctx_indices)
v_out = CtxGatherFunc.apply(v_out, ctx_indices)

k_out = CtxGatherFunc.apply(self.key_cache[layer_idx], ctx_indices, comp_ctx_len)
v_out = CtxGatherFunc.apply(self.value_cache[layer_idx], ctx_indices, comp_ctx_len)
v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
return k_out, v_out

@@ -144,6 +150,7 @@ def update(
else:
position_ids = cache_kwargs.get("position_ids")
batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value form the kwargs
comp_ctx_len = cache_kwargs.get("CCL")

# Scatter
if batch_index is not None:
@@ -163,26 +170,29 @@
self.value_cache[layer_idx], position_ids, value_states
)

k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]

# Gather
ctx_len = k_out.shape[2]
ctx_len = self.key_cache[layer_idx].shape[2]
ctx_indices = torch.arange(ctx_len)[None, None, ...]
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
invalid_mask = ctx_indices > gather_limit

if torch.onnx.is_in_onnx_export():
invalid_idx_value = torch.iinfo(torch.int32).max
else:
invalid_idx_value = 0

ctx_indices = ctx_indices[:, :, :comp_ctx_len]
invalid_mask = ctx_indices > gather_limit

invalid_mask = invalid_mask[:, :, :comp_ctx_len]

ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)

if batch_index is not None:
k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices)
v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices)
k_out = CtxGatherFuncCB.apply(self.key_cache[layer_idx], batch_index, ctx_indices, comp_ctx_len)
v_out = CtxGatherFuncCB.apply(self.value_cache[layer_idx], batch_index, ctx_indices, comp_ctx_len)
else:
k_out = CtxGatherFunc.apply(k_out, ctx_indices)
v_out = CtxGatherFunc.apply(v_out, ctx_indices)
k_out = CtxGatherFunc.apply(self.key_cache[layer_idx], ctx_indices, comp_ctx_len)
v_out = CtxGatherFunc.apply(self.value_cache[layer_idx], ctx_indices, comp_ctx_len)
v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)

return k_out, v_out
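
To summarize the cache_utils change, a standalone torch sketch of the CCL-limited gather. It assumes a [batch, heads, ctx_len, head_dim] cache layout (which matches the indexing above); ccl_gather is a hypothetical helper, not the PR code, and the shapes in the example are illustrative:

```python
import torch

def ccl_gather(cache, position_ids, comp_ctx_len):
    # cache: [batch, heads, ctx_len, head_dim]; position_ids: [batch, seq_len]
    ctx_indices = torch.arange(cache.shape[2])[None, None, :]
    gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
    ctx_indices = ctx_indices[:, :, :comp_ctx_len]                    # keep only the CCL window
    invalid_mask = ctx_indices > gather_limit                         # slots not written yet
    ctx_indices = torch.where(invalid_mask, torch.zeros_like(ctx_indices), ctx_indices)
    batch_idx = torch.arange(cache.shape[0]).view(-1, 1, 1)
    head_idx = torch.arange(cache.shape[1]).view(1, -1, 1)
    return cache[batch_idx, head_idx, ctx_indices]                    # [batch, heads, comp_ctx_len, head_dim]

# A 1x2x8x4 cache read back with a compute window of 4 positions:
k = torch.randn(1, 2, 8, 4)
out = ccl_gather(k, torch.tensor([[0, 1, 2]]), comp_ctx_len=4)
assert out.shape == (1, 2, 4, 4)
```

Only comp_ctx_len cache positions are gathered and fed to attention, which is where the compute saving of the smaller specializations comes from.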