diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py
index 119ec599f..ef5da3408 100644
--- a/QEfficient/base/modeling_qeff.py
+++ b/QEfficient/base/modeling_qeff.py
@@ -219,6 +219,7 @@ def _compile(
         self,
         onnx_path: Optional[str] = None,
         compile_dir: Optional[str] = None,
+        comp_ctx_lengths: Optional[List[int]] = None,
         *,
         mxint8_kv_cache: bool = False,
         specializations: Optional[List[Dict[str, int]]] = None,
@@ -247,7 +248,7 @@ def _compile(
                 - convert_to_fp16=True -> -convert-to-fp16
         """
         if onnx_path is None and self.onnx_path is None:
-            self.export()
+            self.export(comp_ctx_lengths)
 
         onnx_path = Path(onnx_path or self.onnx_path)
         compile_dir = Path(compile_dir or onnx_path.parent)
diff --git a/QEfficient/cloud/infer.py b/QEfficient/cloud/infer.py
index 30e67344a..ac1f00a57 100644
--- a/QEfficient/cloud/infer.py
+++ b/QEfficient/cloud/infer.py
@@ -102,6 +102,7 @@ def main(
     full_batch_size: Optional[int] = None,
     prompt_len: int = 32,
     ctx_len: int = 128,
+    comp_ctx_lengths: Optional[List[int]] = None,
     generation_len: Optional[int] = None,
     mxfp6: bool = False,
     mxint8: bool = False,
@@ -183,6 +184,7 @@ def main(
     _ = qeff_model.compile(
         prefill_seq_len=prompt_len,
         ctx_len=ctx_len,
+        comp_ctx_lengths=comp_ctx_lengths,
         num_cores=num_cores,
         mxfp6_matmul=mxfp6,
         aic_enable_depth_first=aic_enable_depth_first,
@@ -209,6 +211,7 @@ def main(
             qeff_model=qeff_model,
             model_name=model_name,
             prompt=prompt,
+            comp_ctx_lengths=comp_ctx_lengths,
             image_url=image_url,
             image_path=image_path,
             device_group=device_group,
@@ -229,6 +232,7 @@ def main(
             prompts=prompt,
             device_id=device_group,
             prompt=prompt,
+            comp_ctx_lengths=comp_ctx_lengths,
             prompts_txt_file_path=prompts_txt_file_path,
             generation_len=generation_len,
         )
@@ -257,6 +261,12 @@ def main(
         "--prompt-len", "--prompt_len", default=32, type=int, help="Sequence length for text generation."
     )
     parser.add_argument("--ctx-len", "--ctx_len", default=128, type=int, help="Context length for text generation.")
+    parser.add_argument(
+        "--comp_ctx_lengths",
+        "--comp_ctx_lengths",
+        type=lambda comp_ctx_lengths: [int(x) for x in comp_ctx_lengths.strip("[]").split(",")],
+        help="Compute Context length for text generation (comma-separated) e.g. [512,1024,2048]  ",
+    )
     parser.add_argument(
         "--mxfp6",
         "--mxfp6_matmul",
diff --git a/QEfficient/customop/ctx_scatter_gather.py b/QEfficient/customop/ctx_scatter_gather.py
index 570df0cf5..79761e02d 100644
--- a/QEfficient/customop/ctx_scatter_gather.py
+++ b/QEfficient/customop/ctx_scatter_gather.py
@@ -115,8 +115,14 @@ def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> tor
 
 
 @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
-def CtxGather(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT:
-    ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[3], axes=[0]))
+def CtxGather(
+    data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32
+) -> onnxscript.FLOAT:
+    # Build the target shape (batch, num_heads, comp_ctx_len) from the first two dims of data
+    shape_tensor = ops.Concat(ops.Shape(data)[:2], ops.Reshape(comp_ctx_len, [1]), axis=0)
+
+    # Expand the gather indices to that shape so only the first comp_ctx_len positions are gathered
+    ctx_indices = ops.Expand(ctx_indices, shape_tensor)
     ctx_indices = ops.Unsqueeze(ctx_indices, [-1])
     return ops.GatherND(data, ctx_indices, batch_dims=2)
 
@@ -127,7 +133,7 @@ class CtxGatherFunc(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(data: torch.Tensor, ctx_indices: torch.Tensor):
+    def forward(data: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int):
         batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1)
         head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
         return data[batch_indices, head_indices, ctx_indices]
@@ -137,5 +143,5 @@ def setup_context(ctx, inputs, outputs):
         pass
 
     @staticmethod
-    def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value:
-        return g.onnxscript_op(CtxGather, data, ctx_indices).setTypeAs(data)
+    def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int) -> torch.Value:
+        return g.onnxscript_op(CtxGather, data, ctx_indices, comp_ctx_len).setTypeAs(data)
diff --git a/QEfficient/customop/ctx_scatter_gather_cb.py b/QEfficient/customop/ctx_scatter_gather_cb.py
index e4408829d..15b0847aa 100644
--- a/QEfficient/customop/ctx_scatter_gather_cb.py
+++ b/QEfficient/customop/ctx_scatter_gather_cb.py
@@ -97,11 +97,12 @@ def symbolic(
 
 @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
 def CtxGatherCB(
-    data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32
+    data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32
 ) -> onnxscript.FLOAT:
     batch_size = ops.Gather(ops.Shape(batch_index), [0])
     num_heads = ops.Gather(ops.Shape(data), [1])
-    ctx_len = ops.Gather(ops.Shape(data), [2])
+    # Use the compute-context-length (CCL) instead of the full context length, so the gather and the attention computed on it cover only CCL positions.
+    ctx_len = ops.Reshape(comp_ctx_len, [1])
 
     # Expanded shape to create indices
     zero = ops.Constant(value_ints=[0])
@@ -119,7 +120,7 @@ def CtxGatherCB(
 
 class CtxGatherFuncCB(torch.autograd.Function):
     @staticmethod
-    def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor):
+    def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int):
         batch_indices = batch_index.view(-1, 1, 1)
         head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
         return data[batch_indices, head_indices, ctx_indices]
@@ -129,8 +130,10 @@ def setup_context(ctx, inputs, outputs):
         pass
 
     @staticmethod
-    def symbolic(g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value) -> torch.Value:
-        return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices).setTypeAs(data)
+    def symbolic(
+        g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int
+    ) -> torch.Value:
+        return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices, comp_ctx_len).setTypeAs(data)
 
 
 @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py
index 2dd485a5e..249d966b4 100755
--- a/QEfficient/generation/text_generation_inference.py
+++ b/QEfficient/generation/text_generation_inference.py
@@ -316,6 +316,7 @@ def cloud_ai_100_exec_kv(
     prompts_txt_file_path: Optional[str] = None,
     device_id: Optional[List[int]] = None,
     generation_len: Optional[int] = None,
+    comp_ctx_lengths: Optional[List[int]] = None,
     enable_debug_logs: bool = False,
     stream: bool = True,
     write_io_dir: Optional[str] = None,
@@ -368,6 +369,7 @@ def cloud_ai_100_exec_kv(
         qpc_path=qpc_path,
         device_id=device_id,
         ctx_len=ctx_len,
+        comp_ctx_lengths=comp_ctx_lengths,
         enable_debug_logs=enable_debug_logs,
         write_io_dir=write_io_dir,
         full_batch_size=full_batch_size,
@@ -407,12 +409,14 @@ def __init__(
         qpc_path: str,
         full_batch_size: Optional[int] = None,
         ctx_len: Optional[int] = None,
+        comp_ctx_lengths: Optional[List[int]] = None,
         device_id: Optional[List[int]] = None,
         enable_debug_logs: bool = False,
         write_io_dir: Optional[str] = None,
         is_tlm: Optional[int] = None,
     ) -> None:
         self._ctx_len = ctx_len
+        self.comp_ctx_lengths = comp_ctx_lengths
         self._write_io_dir = write_io_dir
         self.is_tlm = is_tlm
 
@@ -724,6 +728,12 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
                 batch_lora_ids = [self._prompt_to_lora_id_mapping_prefill.popleft() for i in range(self.batch_size)]
                 inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1)
 
+        inputs["comp_ctx_lengths"] = np.random.rand(
+            self.comp_ctx_lengths[0] if self.comp_ctx_lengths is not None else self._ctx_len
+        )
+        buffers = {"comp_ctx_len_out": np.zeros(1)}
+        self._session.set_buffers(buffers)
+
         for i in range(num_chunks):
             chunk_inputs = inputs.copy()
             chunk_inputs["input_ids"] = inputs["input_ids"][
@@ -741,6 +751,19 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
             generation_len,
         )
 
+    def initialize_ccl(self, decode_inputs):
+        max_ccl_id = len(self.comp_ctx_lengths) - 1
+        max_position_id = np.max(decode_inputs["position_ids"])
+        ccl_id = 1  # comp_ctx_lengths[0] is the prefill CCL; decode buckets start at index 1
+        for i in range(1, len(self.comp_ctx_lengths)):
+            if max_position_id < self.comp_ctx_lengths[i]:
+                ccl_id = i
+                break
+        buffers = {"comp_ctx_len_out": np.zeros(1)}
+        print(f"CCL: {self.comp_ctx_lengths[ccl_id]}")
+
+        return buffers, ccl_id, max_ccl_id
+
     def run_continuous_batching_decode(self, prompt_queue, generation_len):
         """
         Runs continuous batching decode for the given prompt queue and generation length.
@@ -771,6 +794,16 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
         # Prepare decode inputs inputs.
         decode_inputs = self.prepare_decode_inputs()
 
+        if self.comp_ctx_lengths is not None:
+            list_of_comp_ctx_lengths = [np.zeros(length) for length in self.comp_ctx_lengths]
+            buffers, ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs)
+            decode_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths[ccl_id]
+            self._session.set_buffers(buffers)
+        else:
+            decode_inputs["comp_ctx_lengths"] = np.zeros(self._ctx_len)
+            buffers = {"comp_ctx_len_out": np.zeros(1)}
+            self._session.set_buffers(buffers)
+
         while prompt_queue or current_decode_ongoing.any():
             outputs = self._session.run(decode_inputs)
 
@@ -808,6 +841,19 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
                                 batch_id_map[decode_batch_id]
                             ]
 
+                        if self.comp_ctx_lengths is not None:
+                            ### Recalculate ccl_id based on position_ids ###
+                            # Determine the maximum value of position_ids across all batch elements
+                            max_position_id = np.max(decode_inputs["position_ids"])
+
+                            # Update ccl_id and comp_ctx_lengths based on the maximum position id
+                            ccl_id = 1
+                            for i in range(1, len(self.comp_ctx_lengths)):
+                                if max_position_id < self.comp_ctx_lengths[i]:
+                                    ccl_id = i
+                                    break
+                            decode_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths[ccl_id]
+
                     else:
                         current_decode_ongoing[decode_batch_id] = False
                 else:
@@ -818,6 +864,12 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
                         next_token_id[decode_batch_id, -1]
                     )
 
+                    if self.comp_ctx_lengths is not None:
+                        # Update ccl_id and comp_ctx_lengths based on the maximum position id
+                        if decode_inputs["position_ids"][decode_batch_id, -1] >= self.comp_ctx_lengths[ccl_id] - 1:
+                            ccl_id = min(ccl_id + 1, max_ccl_id)
+                            decode_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths[ccl_id]
+
                     generated_id_current_index[decode_batch_id] += 1
 
         return decode_pause_time
@@ -842,7 +894,25 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform
             self._session.set_buffers({"logits": logits_out_placeholder})
         finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id
         num_token = 0
+
+        if self.comp_ctx_lengths is not None:
+            list_of_comp_ctx_lengths = [np.zeros(length) for length in self.comp_ctx_lengths]
+            buffers, ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs)
+            decode_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths[ccl_id]
+            self._session.set_buffers(buffers)
+        else:
+            decode_inputs["comp_ctx_lengths"] = np.zeros(self._ctx_len)
+            buffers = {"comp_ctx_len_out": np.zeros(1)}
+            self._session.set_buffers(buffers)
+
+        cache_index = np.max(decode_inputs["position_ids"])
         for num_token in range(1, generation_len):
+            if self.comp_ctx_lengths is not None:
+                if cache_index >= self.comp_ctx_lengths[ccl_id] - 1:
+                    # Move to the next CCL bucket once the cache index reaches the current bucket's limit.
+                    ccl_id = min(ccl_id + 1, max_ccl_id)
+                    decode_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths[ccl_id]
+
             if streamer:
                 streamer.put(decode_inputs["input_ids"][0])
             outputs = self._session.run(decode_inputs)
@@ -854,6 +924,7 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform
             # Prepare inputs for next iteration
             decode_inputs["input_ids"] = outputs["logits"].argmax(2)
             decode_inputs["position_ids"][:, -1] += 1
+            cache_index += 1
             self.generated_ids[:, num_token] = decode_inputs["input_ids"][:, -1]
             finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id
 
@@ -901,17 +972,27 @@ def __init__(
         qpc_path: str,
         full_batch_size: Optional[int] = None,
         ctx_len: Optional[int] = None,
+        comp_ctx_lengths: Optional[List[int]] = None,
         device_id: Optional[List[int]] = None,
         enable_debug_logs: bool = False,
         write_io_dir: Optional[str] = None,
         is_tlm: bool = False,
     ) -> None:
         self._qaic_model = QEffTextGenerationBase(
-            tokenizer, qpc_path, full_batch_size, ctx_len, device_id, enable_debug_logs, write_io_dir, is_tlm
+            tokenizer,
+            qpc_path,
+            full_batch_size,
+            ctx_len,
+            comp_ctx_lengths,
+            device_id,
+            enable_debug_logs,
+            write_io_dir,
+            is_tlm,
         )
         self._full_batch_size = self._qaic_model.full_batch_size
         self._tokenizer = self._qaic_model.tokenizer
         self._ctx_len = ctx_len
+        self.comp_ctx_lengths = comp_ctx_lengths
         self._perf_metrics = None
         self._prompt_queue = None
         self._text_streamer = None
diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py
index f9d529038..4b1e243e5 100644
--- a/QEfficient/transformers/cache_utils.py
+++ b/QEfficient/transformers/cache_utils.py
@@ -91,6 +91,8 @@ def read_only(self, layer_idx, cache_kwargs):
         k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
         position_ids = cache_kwargs.get("position_ids")
         batch_index = cache_kwargs.get("batch_index", None)
+        comp_ctx_len = cache_kwargs.get("CCL")
+
         ctx_len = k_out.shape[2]
         ctx_indices = torch.arange(ctx_len)[None, None, ...]
         gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
@@ -101,15 +103,19 @@ def read_only(self, layer_idx, cache_kwargs):
         else:
             invalid_idx_value = 0
 
+        # Restrict the gather to the first comp_ctx_len (CCL) positions of the cache so the
+        # attention computed on the gathered states only covers the compute context length.
+        ctx_indices = ctx_indices[:, :, :comp_ctx_len]
+        invalid_mask = ctx_indices > gather_limit
+
         ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
 
         if batch_index is not None:
-            k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices)
-            v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices)
+            k_out = CtxGatherFuncCB.apply(self.key_cache[layer_idx], batch_index, ctx_indices, comp_ctx_len)
+            v_out = CtxGatherFuncCB.apply(self.value_cache[layer_idx], batch_index, ctx_indices, comp_ctx_len)
         else:
-            k_out = CtxGatherFunc.apply(k_out, ctx_indices)
-            v_out = CtxGatherFunc.apply(v_out, ctx_indices)
-
+            k_out = CtxGatherFunc.apply(self.key_cache[layer_idx], ctx_indices, comp_ctx_len)
+            v_out = CtxGatherFunc.apply(self.value_cache[layer_idx], ctx_indices, comp_ctx_len)
         v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
         return k_out, v_out
 
@@ -144,6 +150,7 @@ def update(
         else:
             position_ids = cache_kwargs.get("position_ids")
             batch_index = cache_kwargs.get("batch_index", None)  # Check and fetch batch index value form the kwargs
+            comp_ctx_len = cache_kwargs.get("CCL")
 
             # Scatter
             if batch_index is not None:
@@ -163,26 +170,29 @@ def update(
                     self.value_cache[layer_idx], position_ids, value_states
                 )
 
-            k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
-
             # Gather
-            ctx_len = k_out.shape[2]
+            ctx_len = self.key_cache[layer_idx].shape[2]
             ctx_indices = torch.arange(ctx_len)[None, None, ...]
             gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
-            invalid_mask = ctx_indices > gather_limit
 
             if torch.onnx.is_in_onnx_export():
                 invalid_idx_value = torch.iinfo(torch.int32).max
             else:
                 invalid_idx_value = 0
 
+            # Restrict the gather to the first comp_ctx_len (CCL) positions of the cache so the
+            # attention computed on the gathered states only covers the compute context length.
+            ctx_indices = ctx_indices[:, :, :comp_ctx_len]
+            invalid_mask = ctx_indices > gather_limit
+
             ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
+
             if batch_index is not None:
-                k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices)
-                v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices)
+                k_out = CtxGatherFuncCB.apply(self.key_cache[layer_idx], batch_index, ctx_indices, comp_ctx_len)
+                v_out = CtxGatherFuncCB.apply(self.value_cache[layer_idx], batch_index, ctx_indices, comp_ctx_len)
             else:
-                k_out = CtxGatherFunc.apply(k_out, ctx_indices)
-                v_out = CtxGatherFunc.apply(v_out, ctx_indices)
+                k_out = CtxGatherFunc.apply(self.key_cache[layer_idx], ctx_indices, comp_ctx_len)
+                v_out = CtxGatherFunc.apply(self.value_cache[layer_idx], ctx_indices, comp_ctx_len)
             v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
 
         return k_out, v_out
diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py
index dae783361..946e15215 100644
--- a/QEfficient/transformers/models/llama/modeling_llama.py
+++ b/QEfficient/transformers/models/llama/modeling_llama.py
@@ -5,6 +5,7 @@
 #
 # -----------------------------------------------------------------------------
 
+from dataclasses import dataclass
 from typing import Callable, List, Optional, Tuple, Union
 
 import torch
@@ -29,6 +30,16 @@
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
 
 
+@dataclass
+class QEffBaseModelOutputWithPast(BaseModelOutputWithPast):
+    comp_ctx_len_out: Optional[torch.LongTensor] = None
+
+
+@dataclass
+class QEffCausalLMOutputWithPast(CausalLMOutputWithPast):
+    comp_ctx_len_out: Optional[torch.LongTensor] = None
+
+
 class QEffLlamaRotaryEmbedding(LlamaRotaryEmbedding):
     """
     Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
@@ -130,6 +141,7 @@ def forward(
         attention_mask: Optional[torch.Tensor],
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
+        comp_ctx_lengths: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
@@ -154,8 +166,15 @@ def forward(
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
+            attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]  # truncate mask to the CCL window
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids}
+            cache_kwargs = {
+                "sin": sin,
+                "cos": cos,
+                "batch_index": batch_index,
+                "position_ids": position_ids,
+                "CCL": attention_mask.shape[-1],
+            }
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
         attention_interface: Callable = eager_attention_forward
@@ -188,6 +207,7 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
+        comp_ctx_lengths: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
@@ -204,6 +224,7 @@ def forward(
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
+            comp_ctx_lengths=comp_ctx_lengths,
             batch_index=batch_index,
             output_attentions=output_attentions,
             use_cache=use_cache,
@@ -241,6 +262,7 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
+        comp_ctx_lengths: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
@@ -294,6 +316,7 @@ def forward(
                 attention_mask=causal_mask,
                 position_ids=position_ids,
                 past_key_value=past_key_values,
+                comp_ctx_lengths=comp_ctx_lengths,
                 batch_index=batch_index,
                 output_attentions=output_attentions,
                 use_cache=use_cache,
@@ -315,11 +338,13 @@ def forward(
         if return_legacy_cache:
             past_key_values = past_key_values.to_legacy_cache()
 
-        output = BaseModelOutputWithPast(
+        comp_ctx_len_out = comp_ctx_lengths[comp_ctx_lengths.shape[-1] - 1 :]  # last CCL value, exposed as an output
+        output = QEffBaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=past_key_values if use_cache else None,
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
+            comp_ctx_len_out=comp_ctx_len_out,
         )
         return output if return_dict else output.to_tuple()
 
@@ -337,6 +362,7 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        comp_ctx_lengths: Optional[torch.LongTensor] = None,
         batch_index: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
@@ -360,6 +386,7 @@ def forward(
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            comp_ctx_lengths=comp_ctx_lengths,
             batch_index=batch_index,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
@@ -377,10 +404,12 @@ def forward(
         logits = self.lm_head(hidden_states)
         logits = logits.float()
 
-        return CausalLMOutputWithPast(
+        comp_ctx_len_out = comp_ctx_lengths[comp_ctx_lengths.shape[-1] - 1 :]
+        return QEffCausalLMOutputWithPast(
             loss=None,
             logits=logits,
             past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
+            comp_ctx_len_out=comp_ctx_len_out,
         )
diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index f181ee5eb..ed70143ae 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -1326,6 +1326,7 @@ def __init__(
         self.continuous_batching = continuous_batching
         self.model, transformed = SpDTransform.apply(self.model, qaic_config, **kwargs)
         self.is_tlm = transformed
+        self.comp_ctx_lengths = kwargs.pop("comp_ctx_lengths", None)
 
     @property
     def model_name(self) -> str:
@@ -1388,6 +1389,9 @@ def from_pretrained(
 
         kv_offload = kwargs.pop("kv_offload", None)
 
+        comp_ctx_lengths = kwargs.pop("comp_ctx_lengths", None)
+        cls.comp_ctx_lengths = comp_ctx_lengths
+
         kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
         model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
         if qaic_config is not None:
@@ -1422,7 +1426,7 @@ def model_hash(self) -> str:
     def get_model_config(self) -> dict:
         return self.model.config.__dict__
 
-    def export(self, export_dir: Optional[str] = None) -> str:
+    def export(self, comp_ctx_lengths: Optional[List[int]] = None, export_dir: Optional[str] = None) -> str:
         """
         Exports the model to ``ONNX`` format using ``torch.onnx.export``.
 
@@ -1442,10 +1446,12 @@ def export(self, export_dir: Optional[str] = None) -> str:
             "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64),
             "position_ids": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1),
             "past_key_values": [[] for _ in range(self.num_layers)],
+            "comp_ctx_lengths": torch.randint(0, 100, (40,), dtype=torch.long),
         }
         dynamic_axes = {
             "input_ids": {0: "batch_size", 1: "seq_len"},
             "position_ids": {0: "batch_size", 1: "seq_len"},
+            "comp_ctx_lengths": {0: "comp_ctx_lengths"},
         }
         if len(kv_cache_shape) == 3:  # For GPTBigCode arch the pkv is 3d
             pkv_dynamic_axes = {
@@ -1485,6 +1491,7 @@ def build_prefill_specialization(
         self,
         prefill_seq_len: int = 32,
         ctx_len: int = 128,
+        comp_ctx_lengths: Optional[int] = None,
         batch_size: int = 1,
         kv_cache_batch_size: Optional[int] = None,
         full_batch_size: Optional[int] = None,
@@ -1495,6 +1502,8 @@ def build_prefill_specialization(
             "ctx_len": ctx_len,
             "num_logits_to_keep": 1 if self.is_tlm else None,
         }
+        spec["comp_ctx_lengths"] = comp_ctx_lengths
+
         if self.continuous_batching:
             spec["full_batch_size"] = kv_cache_batch_size
         else:
@@ -1507,6 +1516,7 @@ def build_decode_specialization(
         self,
         prefill_seq_len: int = 32,
         ctx_len: int = 128,
+        comp_ctx_lengths: Optional[int] = None,
         batch_size: int = 1,
         kv_cache_batch_size: Optional[int] = None,
         full_batch_size: Optional[int] = None,
@@ -1520,6 +1530,7 @@ def build_decode_specialization(
             "ctx_len": ctx_len,
             "num_logits_to_keep": (num_speculative_tokens + 1) if self.is_tlm else None,
         }
+        spec["comp_ctx_lengths"] = comp_ctx_lengths
         if self.continuous_batching:
             spec["full_batch_size"] = kv_cache_batch_size
         else:
@@ -1575,6 +1586,9 @@ def compile(
         Returns:
             :str: Path of the compiled ``qpc`` package.
         """
+        if self.comp_ctx_lengths is None:
+            self.comp_ctx_lengths = self.__class__.comp_ctx_lengths
+
         # --- Validation ---
         if prefill_only is not None and not isinstance(prefill_only, bool):
             raise TypeError("`prefill_only` must be a boolean.")
@@ -1598,17 +1612,43 @@ def compile(
         specializations = []
 
         if prefill_only is None or prefill_only or prefill_seq_len == 1:
+            ctx_for_specialization = self.comp_ctx_lengths[0] if self.comp_ctx_lengths is not None else ctx_len
             specializations.append(
                 self.build_prefill_specialization(
-                    prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size
+                    prefill_seq_len, ctx_len, ctx_for_specialization, batch_size, kv_cache_batch_size, full_batch_size
                 )
             )
+
         if prefill_only is None or not prefill_only:
-            decode_spec = self.build_decode_specialization(
-                prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size, num_speculative_tokens
-            )
-            if decode_spec:
-                specializations.append(decode_spec)
+            if self.comp_ctx_lengths is not None:
+                # Build one decode specialization per CCL bucket; comp_ctx_lengths[0] is reserved for prefill.
+                for i in range(1, len(self.comp_ctx_lengths)):
+                    # Each bucket caps how many KV-cache positions are attended to during decode.
+                    decode_spec = self.build_decode_specialization(
+                        prefill_seq_len,
+                        ctx_len,
+                        self.comp_ctx_lengths[i],
+                        batch_size,
+                        kv_cache_batch_size,
+                        full_batch_size,
+                        num_speculative_tokens,
+                    )
+                    if decode_spec:
+                        specializations.append(decode_spec)
+
+            else:
+                # Without CCL, build a single decode specialization that uses ctx_len as the compute context length.
+                decode_spec = self.build_decode_specialization(
+                    prefill_seq_len,
+                    ctx_len,
+                    ctx_len,
+                    batch_size,
+                    kv_cache_batch_size,
+                    full_batch_size,
+                    num_speculative_tokens,
+                )
+                if decode_spec:
+                    specializations.append(decode_spec)
 
         # --- Compilation ---
         kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
@@ -1622,6 +1662,7 @@ def compile(
         qpc_path = self._compile(
             onnx_path=onnx_path,
             compile_dir=compile_dir,
+            comp_ctx_lengths=self.comp_ctx_lengths,
             compile_only=True,
             retained_state=True,
             specializations=specializations,
@@ -1673,6 +1714,7 @@ def generate(
                 tokenizer,
                 self.qpc_path,
                 prompt=prompts,
+                comp_ctx_lengths=self.comp_ctx_lengths,
                 device_id=device_id,
                 generation_len=generation_len,
                 is_tlm=self.is_tlm,
diff --git a/examples/compute_context_length.py b/examples/compute_context_length.py
new file mode 100644
index 000000000..2bc22a586
--- /dev/null
+++ b/examples/compute_context_length.py
@@ -0,0 +1,51 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+## This example runs a model with static or continuous batching using different Compute-Context-Length (CCL) inputs. ##
+
+from transformers import AutoTokenizer
+
+from QEfficient import QEFFAutoModelForCausalLM
+
+## The optional comp_ctx_lengths argument takes a list of compute context lengths (CCLs); with comp_ctx_lengths=None the model runs with the default context length. ##
+##       - The first entry in the list is the CCL used during prefill. ##
+##       - During decode, the CCL is chosen from the list based on the position id / cache index: generation starts from the smallest bucket that covers the prompt and moves to the next larger bucket whenever the cache index passes the current one. ##
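+##       For example, with comp_ctx_lengths=[128, 256, 512] and a 200-token prompt, decode starts with the 256 bucket and switches to 512 once the cache index reaches 255. ##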
+comp_ctx_lengths = [128, 256, 512]  # set to None to disable CCL and use the default context length
+
+model_name = "meta-llama/Llama-3.2-1B-Instruct"
+model = QEFFAutoModelForCausalLM.from_pretrained(
+    model_name, continuous_batching=True, comp_ctx_lengths=comp_ctx_lengths
+)
+# model = QEFFAutoModelForCausalLM.from_pretrained(model_name, comp_ctx_lengths=comp_ctx_lengths)
+
+# Compile the model for continuous or static batching; continuous batching requires full_batch_size.
+model.compile(
+    prefill_seq_len=128,
+    ctx_len=512,
+    num_cores=16,
+    num_devices=1,
+    full_batch_size=4,
+    mxfp6_matmul=True,
+    mxint8_kv_cache=True,
+)
+# model.compile(prefill_seq_len=128, ctx_len=512, num_cores=16, num_devices=1,batch_size=4,mxfp6_matmul=True,mxint8_kv_cache=True)
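+# With comp_ctx_lengths=[128, 256, 512], compilation builds a prefill specialization with CCL=128 and decode specializations with CCL=256 and CCL=512.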
+
+# Create the tokenizer and call model.generate with the input prompts. The comp_ctx_lengths list stored on the model is used during decode to pick the most efficient compute context length.
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model.generate(
+    prompts=[
+        "What are some healthy foods to include in a balanced diet?",
+        "What is a nutritious meal that can keep you energized throughout the day?",
+        "What are some fun and relaxing activities to do over the weekend?",
+        "What's your favorite hobby?",
+    ],
+    tokenizer=tokenizer,
+)
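+
+# The same CCL feature is also exposed on the command line via the new --comp-ctx-lengths flag in QEfficient/cloud/infer.py.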
diff --git a/tests/transformers/test_compute_context_length.py b/tests/transformers/test_compute_context_length.py
new file mode 100644
index 000000000..a68f68141
--- /dev/null
+++ b/tests/transformers/test_compute_context_length.py
@@ -0,0 +1,176 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+import copy
+import os
+from time import perf_counter
+
+import onnx
+import pytest
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+
+from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM
+
+configs = [
+    # name, max_position_embeddings, num_hidden_layers, num_attention_heads, hidden_size, intermediate_size, vocab_size, additional_params
+    ("gpt2", 256, 2, 4, 128, 512, 127, {}),
+    ("codegen", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}),
+    ("falcon", 256, 2, 4, 128, 512, 127, {}),
+    ("gptj", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}),
+    ("llama", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
+    ("mistral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
+    ("mixtral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
+    ("mpt", 256, 2, 4, 128, 512, 127, {}),
+    ("phi", 256, 2, 4, 128, 512, 127, {}),
+    ("phi3", 256, 2, 4, 128, 512, 127, {"pad_token_id": 0}),
+    ("qwen2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
+    ("starcoder2", 256, 2, 4, 128, 512, 127, {}),
+    ("granite", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
+]
+
+configs = [
+    AutoConfig.for_model(
+        model_name,
+        max_position_embeddings=max_position_embeddings,
+        num_hidden_layers=num_hidden_layers,
+        num_attention_heads=num_attention_heads,
+        hidden_size=hidden_size,
+        intermediate_size=intermediate_size,
+        vocab_size=vocab_size,
+        **additional_params,
+    )
+    for (
+        model_name,
+        max_position_embeddings,
+        num_hidden_layers,
+        num_attention_heads,
+        hidden_size,
+        intermediate_size,
+        vocab_size,
+        additional_params,
+    ) in configs
+]
+config_ids = [x.model_type for x in configs]
+
+model_kwargs = {"attn_implementation": "eager"}
+
+
+@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"])
+def test_causal_lm_unsupported(cb):
+    model = AutoModelForCausalLM.from_config(AutoConfig.for_model("opt"))
+    with pytest.warns():
+        QEFFAutoModelForCausalLM(model, cb)
+
+
+@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"])
+@pytest.mark.parametrize("config", configs, ids=config_ids)
+def test_causal_lm_init(config, cb):
+    model = AutoModelForCausalLM.from_config(config, **model_kwargs)
+    qeff_model = QEFFAutoModelForCausalLM(model, cb)
+    with pytest.raises(TypeError):
+        QEFFAutoModelForCausalLM(AutoModel.from_config(config, **model_kwargs), cb)
+    assert qeff_model.model.__class__.__name__.startswith("QEff")
+
+
+@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"])
+@pytest.mark.parametrize("config", configs, ids=config_ids)
+def test_causal_lm_pretrained(config, cb, tmp_path):
+    model = AutoModelForCausalLM.from_config(config, **model_kwargs)
+    model.save_pretrained(tmp_path)
+
+    qeff_model = QEFFAutoModelForCausalLM.from_pretrained(tmp_path, cb)
+    assert qeff_model.model.__class__.__name__.startswith("QEff")
+
+
+@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"])
+@pytest.mark.parametrize("config", configs, ids=config_ids)
+def test_causal_lm_hash(config, cb):
+    hash_0_0 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), cb).model_hash
+    hash_0_1 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), cb).model_hash
+
+    assert hash_0_0 == hash_0_1
+
+    cfg1 = copy.deepcopy(config)
+    cfg1.num_hidden_layers -= 1
+    hash_1_0 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(cfg1, **model_kwargs), cb).model_hash
+    cfg2 = copy.deepcopy(config)
+    cfg2.num_hidden_layers -= 1
+    hash_1_1 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(cfg2, **model_kwargs), cb).model_hash
+    assert hash_1_0 == hash_1_1
+
+    assert hash_0_0 != hash_1_0
+
+    if cb:
+        hash_0_no_cb = QEFFAutoModelForCausalLM(
+            AutoModelForCausalLM.from_config(config, **model_kwargs), False
+        ).model_hash
+        assert hash_0_0 != hash_0_no_cb
+
+
+@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"])
+@pytest.mark.parametrize("config", configs, ids=config_ids)
+def test_causal_lm_export(config, cb, tmp_path):
+    model = AutoModelForCausalLM.from_config(config, **model_kwargs)
+    qeff_model = QEFFAutoModelForCausalLM(model, cb)
+    comp_ctx_lengths = [512, 1024, 2048]
+    qeff_model.export(comp_ctx_lengths, tmp_path)
+    model_path = tmp_path.with_name(tmp_path.name + "-" + qeff_model.model_hash)
+    assert model_path.is_dir()
+    assert qeff_model.onnx_path.is_file()
+    assert qeff_model.onnx_path.relative_to(model_path).parts == (qeff_model.model_name + ".onnx",)
+
+    # Check if the KV-cache inputs and outputs are created
+    onnx_model = onnx.load(qeff_model.onnx_path, load_external_data=False)
+    retained_output_names = {
+        x.name[: -len("_RetainedState")] for x in onnx_model.graph.output if x.name.endswith("_RetainedState")
+    }
+    assert retained_output_names.issubset({x.name for x in onnx_model.graph.input})
+
+    # Check if there is no re-export
+    start = perf_counter()
+    qeff_model.export(comp_ctx_lengths, tmp_path)
+    end = perf_counter()
+    export_time = end - start
+    assert export_time < 2.0
+
+
+@pytest.fixture
+def tmp_cache(tmp_path, monkeypatch):
+    monkeypatch.setattr("QEfficient.base.modeling_qeff.QEFF_HOME", tmp_path)
+    yield tmp_path
+
+
+@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"])
+@pytest.mark.parametrize("config", configs, ids=config_ids)
+def test_causal_lm_compile(config, cb, tmp_cache):
+    model = AutoModelForCausalLM.from_config(config, **model_kwargs)
+    comp_ctx_lengths = [8, 12, 16]
+    qeff_model = QEFFAutoModelForCausalLM(model, cb, comp_ctx_lengths=comp_ctx_lengths)
+    compile_params = {"prefill_seq_len": 8, "ctx_len": 16}
+    if cb:
+        compile_params["full_batch_size"] = 32
+        compile_params["batch_size"] = 8
+    qeff_model.compile(**compile_params)
+    model_path = tmp_cache / (qeff_model.model_name + "-" + qeff_model.model_hash)
+
+    # Check if ONNX is exported properly
+    assert model_path.is_dir()
+    assert qeff_model.onnx_path.is_file()
+    assert qeff_model.onnx_path.relative_to(model_path).parts == (qeff_model.model_name + ".onnx",)
+
+    # Check if QPC is compiled properly
+    assert qeff_model.qpc_path.is_dir()
+    assert (qeff_model.qpc_path / "programqpc.bin").is_file()
+    assert qeff_model.qpc_path.relative_to(tmp_cache).parts[0] == qeff_model.model_name + "-" + qeff_model.model_hash
+
+    # Check if there is no re-compilation
+    start = perf_counter()
+    qeff_model.compile(**compile_params)
+    end = perf_counter()
+    compile_time = end - start
+    assert compile_time < 2.0
+    assert os.path.isfile(os.path.join(os.path.dirname(qeff_model.qpc_path), "qconfig.json"))