27 commits
41dbbf5
[QEff]: Add gpt_oss
vbaddi Aug 6, 2025
eb31daa
nit: update modeling and make transform uniform
vbaddi Aug 7, 2025
8dcc3ad
apirunner change
ochougul Aug 7, 2025
ce7e719
added test along with simplified Hybridcache
ochougul Aug 7, 2025
dedf20a
added test assert
ochougul Aug 7, 2025
e35bfde
nit: update test gpt file
vbaddi Aug 8, 2025
7f6c4f6
nit: update modeling with new decode moe forward
vbaddi Aug 11, 2025
908d649
nit: seperate gate, up projections for MoE
vbaddi Aug 20, 2025
e427267
nit: remove test file and add sample test in config
Oct 15, 2025
5432338
Enable CB for GptOssModel
mamtsing Nov 3, 2025
cb8145f
Fix tests
mamtsing Nov 4, 2025
58c1740
Address review comments
mamtsing Nov 4, 2025
dec0616
prefill only changes for gpt-oss
ochougul Nov 4, 2025
c57e208
fixed mapping
ochougul Nov 5, 2025
7f8416f
added test
ochougul Nov 6, 2025
08ccd20
added test
ochougul Nov 6, 2025
9c8dcae
made example not ugly
ochougul Nov 6, 2025
50db73f
fixed tests
ochougul Nov 6, 2025
00eab98
fixed tests
ochougul Nov 6, 2025
4a43f0b
added new test and fixed failing tests
ochougul Nov 7, 2025
128b2c9
fixed tests
ochougul Nov 10, 2025
ea320ed
fixed kv cache shape
ochougul Nov 10, 2025
fbb85c0
fixed self.onnx_path issue in modeling_qeff
ochougul Nov 11, 2025
2c2abf2
fix formatting error
mamtsing Nov 11, 2025
dbe2495
Merge branch 'main' into prefill+decode_gpt_oss
quic-mamta Nov 11, 2025
fba1ac0
added ffn blocking and num blocks env variables
ochougul Nov 13, 2025
37f3681
include num_ffn_blocks in hash
ochougul Nov 17, 2025
14 changes: 9 additions & 5 deletions QEfficient/__init__.py
@@ -6,17 +6,21 @@
# -----------------------------------------------------------------------------

import os
import warnings

import QEfficient.utils.model_registery # noqa: F401
from QEfficient.utils import custom_format_warning
from QEfficient.utils.logging_utils import logger

# ----------------------------------------------------------------------------- #
# For faster downloads via hf_transfer
# This code is put above import statements as this needs to be executed before
# hf_transfer is imported (will happen on line 15 via leading imports)
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# DO NOT ADD ANY CODE ABOVE THIS LINE
# Please contact maintainers if you must edit this file above this line.
# ----------------------------------------------------------------------------- #
# Placeholder for all non-transformer models registered in QEfficient
import warnings # noqa: I001

import QEfficient.utils.model_registery # noqa: F401
from QEfficient.utils import custom_format_warning
from QEfficient.utils.logging_utils import logger


# custom warning for the better logging experience
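For context on the `__init__.py` reordering above: the hf_transfer fast-download flag is only honoured if the environment variable is set before `huggingface_hub` is imported, because the flag is read into its constants at import time. A minimal sketch of that constraint (the repo id and download call are illustrative, not part of this PR):

```python
import os

# Must be set before huggingface_hub / hf_transfer are imported,
# since the flag is evaluated when huggingface_hub's constants load.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

from huggingface_hub import snapshot_download

# Illustrative download only; any repo id behaves the same way.
snapshot_download("gpt2", allow_patterns=["*.json"])
```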
41 changes: 34 additions & 7 deletions QEfficient/base/modeling_qeff.py
@@ -57,6 +57,8 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None:
super().__init__()
self.model = model
self.hash_params = create_model_params(self, **kwargs)
self.prefill_enabled = False
self.prefill_onnx_path: Optional[str] = None
self.onnx_path: Optional[str] = None
self.qpc_path: Optional[str] = None
self.qpc_session: Optional[QAICInferenceSession] = None
@@ -179,6 +181,7 @@ def _export(
onnx_transform_kwargs: Optional[Dict[str, any]] = None,
export_dir: Optional[str] = None,
offload_pt_weights: bool = True,
prefill_only: Optional[bool] = False,
) -> str:
"""
Export the PyTorch model to ONNX and apply ONNX transforms
@@ -207,7 +210,10 @@

# Return early if ONNX already exists
if onnx_path.is_file():
self.onnx_path = onnx_path
if prefill_only:
self.prefill_onnx_path = onnx_path
else:
self.onnx_path = onnx_path
return onnx_path

# check if the model is in meta state or weights are offloaded
@@ -283,10 +289,29 @@

finally:
shutil.rmtree(tmp_onnx_dir, ignore_errors=True)

self.onnx_path = onnx_path
if prefill_only:
self.prefill_onnx_path = onnx_path
else:
self.onnx_path = onnx_path
return onnx_path

def get_onnx_path(
self,
prefill_only: Optional[bool] = False,
specializations: Optional[List[Dict[str, int]]] = None,
offload_pt_weights: Optional[bool] = True,
):
kwargs = {"offload_pt_weights": offload_pt_weights}
if prefill_only:
if self.prefill_onnx_path is None:
kwargs.update({"prefill_only": prefill_only, "prefill_seq_len": specializations[0].get("seq_len")})
self.export(**kwargs)
return self.prefill_onnx_path
else:
if self.onnx_path is None:
self.export(**kwargs)
return self.onnx_path

@dump_qconfig
def _compile(
self,
@@ -300,6 +325,8 @@
num_speculative_tokens: Optional[int] = None,
enable_qnn: Optional[bool] = False,
qnn_config: Optional[str] = None,
prefill_only: Optional[str] = None,
offload_pt_weights: Optional[bool] = True,
**compiler_options,
) -> str:
"""
@@ -325,10 +352,9 @@

For QNN Compilation path, when enable_qnn is set to True, any parameter passed in compiler_options will be ignored.
"""
if onnx_path is None and self.onnx_path is None:
self.export()

onnx_path = Path(onnx_path or self.onnx_path)
onnx_path = Path(
onnx_path if onnx_path else self.get_onnx_path(prefill_only, specializations, offload_pt_weights)
)
compile_dir = Path(compile_dir or onnx_path.parent)
qpc_path = compile_dir / "qpc"
if not onnx_path.is_file():
@@ -390,6 +416,7 @@ def _compile(
"mdp_ts_num_devices": mdp_ts_num_devices,
"mdp_ts_json": mdp_ts_json,
"num_speculative_tokens": num_speculative_tokens,
"prefill_only": prefill_only,
}
compile_hash = hash_dict_params(compile_hash_params)

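The compile path in `modeling_qeff.py` above no longer calls `export()` directly; it goes through the new `get_onnx_path()`, which keeps separate caches for the prefill-only and regular ONNX artifacts and exports lazily only when the requested flavour is missing. A standalone sketch of that resolution logic, with a stub `export()` standing in for the real one (all names outside the diff are illustrative):

```python
from typing import Dict, List, Optional


class OnnxPathResolverSketch:
    def __init__(self) -> None:
        self.onnx_path: Optional[str] = None
        self.prefill_onnx_path: Optional[str] = None

    def export(self, **kwargs) -> str:
        # Stub for QEFFBaseModel.export(); records the path it "produced".
        path = "model_prefill.onnx" if kwargs.get("prefill_only") else "model.onnx"
        if kwargs.get("prefill_only"):
            self.prefill_onnx_path = path
        else:
            self.onnx_path = path
        return path

    def get_onnx_path(
        self,
        prefill_only: bool = False,
        specializations: Optional[List[Dict[str, int]]] = None,
        offload_pt_weights: bool = True,
    ) -> str:
        kwargs = {"offload_pt_weights": offload_pt_weights}
        if prefill_only:
            if self.prefill_onnx_path is None:
                # As in the diff, the prefill sequence length comes from the
                # first specialization.
                kwargs.update(prefill_only=True, prefill_seq_len=specializations[0].get("seq_len"))
                self.export(**kwargs)
            return self.prefill_onnx_path
        if self.onnx_path is None:
            self.export(**kwargs)
        return self.onnx_path


resolver = OnnxPathResolverSketch()
print(resolver.get_onnx_path(prefill_only=True, specializations=[{"seq_len": 128}]))
```

Note that `_compile` also folds `prefill_only` into the compile hash, so prefill-only and regular QPCs land in distinct compile directories.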
3 changes: 2 additions & 1 deletion QEfficient/peft/auto.py
@@ -245,7 +245,7 @@ def from_pretrained(cls, pretrained_name_or_path: str, *args, **kwargs):
obj = cls._from_pretrained(pretrained_name_or_path, *args, **kwargs)
return obj

def export(self, export_dir: Optional[str] = None) -> str:
def export(self, export_dir: Optional[str] = None, **kwargs) -> str:
"""
Export the model with the active adapter to ONNX format.

@@ -286,6 +286,7 @@ def export(self, export_dir: Optional[str] = None) -> str:
export_kwargs={"do_constant_folding": False}, # To avoid merging adapter weights with base weights
onnx_transform_kwargs={"adapter_name": self.model.active_adapter},
export_dir=export_dir,
**kwargs,
)

def compile(
3 changes: 2 additions & 1 deletion QEfficient/peft/lora/auto.py
@@ -327,7 +327,7 @@ def _init_adapter_model(self):
# load_weight to model
self._load_adapter_weights_to_model()

def export(self, export_dir: Optional[str] = None) -> str:
def export(self, export_dir: Optional[str] = None, **kwargs) -> str:
"""
Export the model with all loaded adapters to ONNX format using ``torch.onnx.export``.

@@ -387,6 +387,7 @@ def export(self, export_dir: Optional[str] = None) -> str:
output_names,
dynamic_axes,
export_dir=export_dir,
**kwargs,
)

def generate(
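Both PEFT `export()` overrides above now forward arbitrary keyword arguments to the base `_export()`. A small hedged sketch of what that enables for a caller; the wrapper is hypothetical, and `offload_pt_weights` is simply an existing `_export` parameter that can now be reached from the public call:

```python
from typing import Any


def export_adapter_model(model: Any, export_dir: str) -> str:
    # With the **kwargs pass-through, base-class export options supplied to the
    # public PEFT export() are no longer dropped on the way to _export().
    return model.export(export_dir=export_dir, offload_pt_weights=False)
```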
31 changes: 31 additions & 0 deletions QEfficient/transformers/cache_utils.py
@@ -594,6 +594,37 @@ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
return legacy_cache

def write_only(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if len(self.key_cache) <= layer_idx:
self.key_cache.append(key_states)
self.value_cache.append(value_states)
k_out, v_out = key_states, value_states
else:
position_ids = cache_kwargs.get("position_ids")
is_sliding_layer = cache_kwargs.get("is_sliding")
_, _, ctx_len, _ = self.key_cache[layer_idx].shape
if is_sliding_layer:
kv_position_ids = torch.arange(ctx_len, dtype=torch.int64).reshape(1, -1)
self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
self.value_cache[layer_idx] = CtxScatterFunc.apply(
self.value_cache[layer_idx], kv_position_ids, value_states
)
else:
kv_position_ids = position_ids

self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
self.value_cache[layer_idx] = CtxScatterFunc.apply(
self.value_cache[layer_idx], kv_position_ids, value_states
)
k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
return k_out, v_out

def update(
self,
key_states: torch.Tensor,
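`write_only()` above fills the KV cache during prefill without building the gathered read view that `update()` returns. A minimal stand-in for the non-sliding path, using a plain tensor write where the diff uses `CtxScatterFunc` (shapes are illustrative):

```python
import torch

batch, heads, ctx_len, head_dim, seq_len = 1, 2, 16, 4, 4
key_cache = torch.zeros(batch, heads, ctx_len, head_dim)      # preallocated cache
key_states = torch.randn(batch, heads, seq_len, head_dim)     # new prefill keys
position_ids = torch.arange(seq_len).reshape(1, -1)           # positions being prefilled

# Scatter the new keys into the cache along the context-length axis.
# For sliding-window layers the diff instead scatters over the whole window,
# using torch.arange(ctx_len) as the positions.
key_cache[:, :, position_ids[0], :] = key_states
assert torch.equal(key_cache[:, :, :seq_len, :], key_states)
```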
3 changes: 3 additions & 0 deletions QEfficient/transformers/modeling_utils.py
@@ -188,6 +188,9 @@
# This is for supporting different seq_len for different layers for Sliding window attn, chunked attn etc.
DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH = {"gemma3", "llama4", "gemma3_text", "llama4_text"}

# This is for supporting different modelling classes specially written for prefill-only model
SPECIALIZED_PREFILL_ONLY_MODEL_ARCH = {"gpt_oss"}

# Define a transformers layers to QEff layers dictionary
# While onboarding new models make sure to add the new layer maps to this dictionary.
TransformersToQEffModulesDict: Dict[Type[nn.Module], Type[nn.Module]] = {
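The new `SPECIALIZED_PREFILL_ONLY_MODEL_ARCH` set registers architectures (currently only `gpt_oss`) that ship dedicated prefill-only modelling classes. A hedged sketch of the kind of dispatch such a registry supports; the helper below is hypothetical, not code from this PR:

```python
SPECIALIZED_PREFILL_ONLY_MODEL_ARCH = {"gpt_oss"}


def uses_specialized_prefill_modeling(model_type: str, prefill_only: bool) -> bool:
    # Hypothetical check: route to the specialized prefill-only modelling
    # classes only when the architecture is registered and prefill_only is set.
    return prefill_only and model_type in SPECIALIZED_PREFILL_ONLY_MODEL_ARCH


assert uses_specialized_prefill_modeling("gpt_oss", prefill_only=True)
assert not uses_specialized_prefill_modeling("llama", prefill_only=True)
```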