Skip to content

Support of Grok1 Model #373

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 14 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions QEfficient/base/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from transformers import AutoConfig

from QEfficient.base.modeling_qeff import QEFFBaseModel
from QEfficient.transformers.modeling_utils import MODEL_CLASS_MAPPING
from QEfficient.transformers.modeling_utils import EXTERNAL_MODEL_CLASS_MAPPING, MODEL_CLASS_MAPPING
from QEfficient.utils import login_and_download_hf_lm


Expand All @@ -40,9 +40,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs) ->
"""
Downloads HuggingFace model if already doesn't exist locally, returns QEFFAutoModel object based on type of model.
"""
config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)

class_name = MODEL_CLASS_MAPPING.get(config.__class__.__name__, None)
class_name = (
MODEL_CLASS_MAPPING.get(config.__class__.__name__, None)
or EXTERNAL_MODEL_CLASS_MAPPING[config.__class__.__name__]
)
if class_name:
module = __import__("QEfficient.transformers.models.modeling_auto")
model_class = getattr(module, class_name)
Expand Down
6 changes: 5 additions & 1 deletion QEfficient/base/pytorch_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def mutate(cls, original_module: nn.Module, parent_module: nn.Module):
raise NotImplementedError("Please implement your own method by inheriting this class")


class ModuleMethodMapperTransform(PytorchTransform):
class ExternalModuleMapperTransform(PytorchTransform):
"""
Serves as base class for any transform that want to map a particular method of a class to a new method implementation.
"""
Expand All @@ -109,6 +109,10 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
):
for orig_method_name, mapped_method in repl_method_map.items():
setattr(module, orig_method_name, MethodType(mapped_method, module))

if hasattr(module, "__qeff_init__"):
module.__qeff_init__()

transformed = True

return model, transformed
Expand Down
9 changes: 9 additions & 0 deletions QEfficient/cloud/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def main(
allow_mxint8_mdp_io: bool = False,
enable_qnn: Optional[bool] = False,
qnn_config: Optional[str] = None,
trust_remote_code: Optional[bool] = False,
**kwargs,
) -> None:
"""
Expand Down Expand Up @@ -140,6 +141,7 @@ def main(
:allow_mxint8_mdp_io (bool): Allows MXINT8 compression of MDP IO traffic. ``Defaults to False.``
:enable_qnn (bool): Enables QNN Compilation. ``Defaults to False.``
:qnn_config (str): Path of QNN Config parameters file. ``Defaults to None.``
:trust_remote_code (bool): Trust remote code execution. ``Defaults to False.``
:kwargs: Pass any compiler option as input. Any flag that is supported by `qaic-exec` can be passed. Params are converted to flags as below:
-allocator_dealloc_delay=1 -> -allocator-dealloc-delay=1
-qpc_crc=True -> -qpc-crc
Expand All @@ -164,6 +166,7 @@ def main(
hf_token=hf_token,
full_batch_size=full_batch_size,
local_model_dir=local_model_dir,
trust_remote_code=trust_remote_code,
)

image_path = kwargs.pop("image_path", None)
Expand Down Expand Up @@ -264,6 +267,12 @@ def main(
action="store_true",
help="Compress constant MatMul weights to MXFP6 E2M3, default is no compression",
)
parser.add_argument(
"--trust_remote_code",
action="store_true",
default=False,
help="Enable trusting remote code when loading models. Default is False; set to True by passing this flag.",
)
parser.add_argument(
"--mxint8",
"--mxint8_kv_cache",
Expand Down
2 changes: 2 additions & 0 deletions QEfficient/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,8 @@ def build_model_class_mapping(auto_model_class, qeff_class_name):
}


EXTERNAL_MODEL_CLASS_MAPPING = {"Grok1Config": "QEFFAutoModelForCausalLM"}

MODEL_CLASS_MAPPING = {
**build_model_class_mapping(mapping.AutoModelForCausalLM, "QEFFAutoModelForCausalLM"),
**build_model_class_mapping(mapping.AutoModelForImageTextToText, "QEFFAutoModelForImageTextToText"),
Expand Down
7 changes: 7 additions & 0 deletions QEfficient/transformers/models/grok_1/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------

Loading
Loading