From 372974043f3a390469387d4d165dbdc62bc90dd2 Mon Sep 17 00:00:00 2001
From: Onkar Chougule
Date: Thu, 2 Oct 2025 20:37:34 +0000
Subject: [PATCH] removed redundancies from QEFFHybridCache

Signed-off-by: Onkar Chougule
---
 QEfficient/transformers/cache_utils.py | 30 ++++++++++++------------------
 1 file changed, 12 insertions(+), 18 deletions(-)

diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py
index 16767fbe2..84c8e5c70 100644
--- a/QEfficient/transformers/cache_utils.py
+++ b/QEfficient/transformers/cache_utils.py
@@ -346,15 +346,15 @@ def update(
             sliding_window_pattern = cache_kwargs.get("sliding_window_pattern")
             is_sliding_layer = torch.tensor(bool((layer_idx + 1) % sliding_window_pattern))
             layer_ctx_len = self.key_cache[layer_idx].shape[2]
-            kv_position_ids = torch.where(
-                (~is_sliding_layer | (position_ids == -1)), position_ids, position_ids % (layer_ctx_len - 1)
-            )
-            kv_position_ids = torch.where(
-                is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1) * 2),
-                (position_ids + 1) % layer_ctx_len,
-                kv_position_ids,
-            )
+            if is_sliding_layer:
+                kv_position_ids = torch.where(position_ids == -1, position_ids, position_ids % (layer_ctx_len - 1))
+
+                kv_position_ids = torch.where(
+                    position_ids.max() >= (layer_ctx_len - 1) * 2, (position_ids + 1) % layer_ctx_len, kv_position_ids
+                )
+            else:
+                kv_position_ids = position_ids
 
             valid_mask = (kv_position_ids != -1).unsqueeze(1).unsqueeze(-1)
             key_states = torch.where(valid_mask == 1, key_states, torch.zeros_like(key_states))
@@ -368,7 +368,7 @@ def update(
             # Original Gather
             ctx_len = self.key_cache[layer_idx].shape[2]
             ctx_indices = torch.arange(ctx_len)[None, None, ...]
-            gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1)
+            gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
             invalid_mask = ctx_indices > gather_limit
             if torch.onnx.is_in_onnx_export():
                 invalid_idx_value = torch.iinfo(torch.int32).max
@@ -376,15 +376,9 @@ def update(
                 invalid_idx_value = 0
             ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
 
-            all_indices = torch.arange(layer_ctx_len) + kv_position_ids.max() + 1
-            rolling_indices = torch.where(all_indices > layer_ctx_len - 1, all_indices % layer_ctx_len, all_indices)
-            final_indices = torch.where(
-                (is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), rolling_indices, ctx_indices
-            )
-            k_out = CtxGatherFunc.apply(k_out, final_indices)
-            v_out = CtxGatherFunc.apply(v_out, final_indices)
-            ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
-            v_out = torch.where((is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out)
+            k_out = CtxGatherFunc.apply(k_out, ctx_indices)
+            v_out = CtxGatherFunc.apply(v_out, ctx_indices)
+            v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
 
             return k_out, v_out
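
Note (not part of the patch): the sketch below restates the sliding-window branch of the refactored position-id mapping as a standalone function, to make the patch easier to review in isolation. The helper name `map_kv_position_ids`, the toy window size, and the sample `position_ids` are assumptions for illustration only; in the actual code this logic runs inside `QEffHybridCache.update` against `self.key_cache[layer_idx]`.

```python
# Minimal, self-contained sketch of the sliding-window kv_position_ids mapping
# introduced in this patch. Names and values here are illustrative, not the
# library API.
import torch


def map_kv_position_ids(position_ids: torch.Tensor, layer_ctx_len: int, is_sliding_layer: bool) -> torch.Tensor:
    if not is_sliding_layer:
        # Full-attention layers keep the request position_ids unchanged.
        return position_ids

    # Wrap valid positions into the fixed-size window; -1 (padding) is kept as -1.
    kv_position_ids = torch.where(position_ids == -1, position_ids, position_ids % (layer_ctx_len - 1))

    # Once the largest position has wrapped past the window twice, shift every
    # position by one slot before taking the modulo (mirrors the patched code).
    kv_position_ids = torch.where(
        position_ids.max() >= (layer_ctx_len - 1) * 2, (position_ids + 1) % layer_ctx_len, kv_position_ids
    )
    return kv_position_ids


if __name__ == "__main__":
    layer_ctx_len = 8  # toy sliding-window size

    # max position 10 < 14, so only the first wrap applies: [1, 2, 3, -1]
    print(map_kv_position_ids(torch.tensor([[8, 9, 10, -1]]), layer_ctx_len, is_sliding_layer=True))

    # max position 17 >= 14, so the shifted wrap applies: [7, 0, 1, 2]
    print(map_kv_position_ids(torch.tensor([[14, 15, 16, 17]]), layer_ctx_len, is_sliding_layer=True))
```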