diff --git a/QEfficient/generation/cloud_infer.py b/QEfficient/generation/cloud_infer.py
index 71cd61188..a8a7e4b24 100644
--- a/QEfficient/generation/cloud_infer.py
+++ b/QEfficient/generation/cloud_infer.py
@@ -47,7 +47,7 @@ def __init__(
         qpc_path: Union[Path, str],
         device_ids: Optional[List[int]] = None,
         activate: bool = True,
-        enable_debug_logs: bool = False,
+        enable_debug_logs: bool = True,
     ):
         """
         Initialise for QAIC inference Session
diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py
index 2dd485a5e..2c86d9ab5 100755
--- a/QEfficient/generation/text_generation_inference.py
+++ b/QEfficient/generation/text_generation_inference.py
@@ -316,7 +316,7 @@ def cloud_ai_100_exec_kv(
     prompts_txt_file_path: Optional[str] = None,
     device_id: Optional[List[int]] = None,
     generation_len: Optional[int] = None,
-    enable_debug_logs: bool = False,
+    enable_debug_logs: bool = True,
     stream: bool = True,
     write_io_dir: Optional[str] = None,
     automation=False,
@@ -408,7 +408,7 @@ def __init__(
         full_batch_size: Optional[int] = None,
         ctx_len: Optional[int] = None,
         device_id: Optional[List[int]] = None,
-        enable_debug_logs: bool = False,
+        enable_debug_logs: bool = True,
         write_io_dir: Optional[str] = None,
         is_tlm: Optional[int] = None,
     ) -> None:
@@ -902,7 +902,7 @@ def __init__(
         full_batch_size: Optional[int] = None,
         ctx_len: Optional[int] = None,
         device_id: Optional[List[int]] = None,
-        enable_debug_logs: bool = False,
+        enable_debug_logs: bool = True,
         write_io_dir: Optional[str] = None,
         is_tlm: bool = False,
     ) -> None:
diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index d1fc61cee..fe467b7d7 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -1017,7 +1017,7 @@ def cloud_ai_100_generate(
         self,
         inputs: torch.Tensor,
         device_ids: List[int],
-        enable_debug_logs: bool = False,
+        enable_debug_logs: bool = True,
         generation_len: int = None,
         streamer: Optional[TextStreamer] = None,
     ) -> np.ndarray:
diff --git a/examples/cpp_execution/text_inference_using_cpp.py b/examples/cpp_execution/text_inference_using_cpp.py
index eadf5e601..5a2764d7c 100644
--- a/examples/cpp_execution/text_inference_using_cpp.py
+++ b/examples/cpp_execution/text_inference_using_cpp.py
@@ -146,7 +146,7 @@ def cloud_ai_100_exec_kv_cpp(
     prompts_txt_file_path: Optional[str] = None,
     device_id: Optional[List[int]] = None,
     generation_len: Optional[int] = None,
-    enable_debug_logs: bool = False,
+    enable_debug_logs: bool = True,
     stream: bool = True,
     full_batch_size: Optional[int] = None,
 ):
diff --git a/scripts/replicate_kv_head/README.md b/scripts/replicate_kv_head/README.md
index a88b194ff..9a1ac9c1e 100644
--- a/scripts/replicate_kv_head/README.md
+++ b/scripts/replicate_kv_head/README.md
@@ -30,4 +30,6 @@ Replace `` with your actual token.
 ### Arguments
 - **--model_name**: Model card name to use (default: “meta-llama/Meta-Llama-3-8B-Instruct”).
 - **--prompt**: Prompt to use for the model (default: “My name is”).
-- **--repeat**: Factor to repeat key-value heads (default: 2).
\ No newline at end of file
+- **--repeat**: Factor to repeat key-value heads (default: 2).
+- **--num_attention_heads**: Number of attention heads (default: None). This is an optional parameter; if not given explicitly, it will be read from config.json.
+- **--hidden_size**: Hidden size (default: None). This is an optional parameter; if not given explicitly, it will be read from config.json.
\ No newline at end of file
diff --git a/scripts/replicate_kv_head/replicate_kv_heads.py b/scripts/replicate_kv_head/replicate_kv_heads.py
index 6edc29771..c8273d602 100644
--- a/scripts/replicate_kv_head/replicate_kv_heads.py
+++ b/scripts/replicate_kv_head/replicate_kv_heads.py
@@ -6,6 +6,7 @@
 # -----------------------------------------------------------------------------
 
 import argparse
+from typing import Optional
 
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -70,16 +71,42 @@ def duplicate_weights_for_linear_layer(
     )
 
 
-def main(args):
+def replicate_kv_heads(
+    model_name: str = "meta-llama/Meta-Llama-3-8B-Instruct",
+    prompt: str = "My name is",
+    repeat: int = 2,
+    full_batch_size: Optional[int] = None,
+    num_hidden_layers: Optional[int] = None,
+    num_attention_heads: Optional[int] = None,
+    hidden_size: Optional[int] = None,
+):
+    """
+    Replicate the KV heads. The function performs the following steps:
+    1. Runs inference with the original model.
+    2. Replicates the KV heads.
+    3. Runs inference on the modified model to validate the changes.
+    4. Exports the modified model to ONNX format.
+
+    ``Mandatory`` Args:
+        :model_name (str): Model card name to use, defaults to ``meta-llama/Meta-Llama-3-8B-Instruct``.
+        :prompt (str): Prompt to use for the model, defaults to ``My name is``.
+        :repeat (int): Factor to repeat key-value heads, defaults to 2.
+    ``Optional`` Args:
+        :full_batch_size (int): Set full batch size to enable continuous batching mode, default is None.
+        :num_hidden_layers (int): Number of hidden layers to use, default is None.
+        :num_attention_heads (int): Number of attention heads; if not passed explicitly, it is picked from config.json.
+        :hidden_size (int): Hidden size to use; if not passed explicitly, it is picked from config.json.
+
+    """
     # Load the model and tokenizer
-    model_name = args.model_name
     model_base_name = model_name.split("/")[-1]
     # Replace quantizers for loading Quantized AWQ/GPTQ models on CPU.
     replace_transformers_quantizers()
     # Prepare kwargs for model loading
     model_kwargs = {"attn_implementation": "eager"}
-    if args.num_hidden_layers:
-        model_kwargs["num_hidden_layers"] = args.num_hidden_layers
+
+    if num_hidden_layers:
+        model_kwargs["num_hidden_layers"] = num_hidden_layers
 
     pretrained_model_name_or_path = login_and_download_hf_lm(model_name)
     model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **model_kwargs)
@@ -87,7 +114,7 @@ def main(args):
     # Undo the effect of replace_transformers_quantizers
     undo_transformers_quantizers()
     tokenizer = AutoTokenizer.from_pretrained(model_name)
-    inputs = tokenizer(args.prompt, return_tensors="pt")
+    inputs = tokenizer(prompt, return_tensors="pt")
 
     # Generate original outputs and tokens
     with torch.inference_mode():
@@ -95,7 +122,6 @@ def main(args):
         orig_tokens = model.generate(**inputs, max_new_tokens=10, num_beams=1, do_sample=False)
 
     # Modify the number of key-value heads
-    repeat = args.repeat
     orig_kv_heads = model.config.num_key_value_heads
     new_kv_heads = repeat * orig_kv_heads
     model.config.num_key_value_heads = new_kv_heads
@@ -103,13 +129,20 @@ def main(args):
     print("Original KV heads:", orig_kv_heads)
     print("Modified KV heads:", new_kv_heads)
 
+    # Check whether hidden size and number of attention heads were passed explicitly
+    if num_attention_heads is None:
+        num_attention_heads = model.config.num_attention_heads
+
+    if hidden_size is None:
+        hidden_size = model.config.hidden_size
+
     # Update the model's attention layers with new key-value heads
     for block in model.model.layers:
         attn = block.self_attn
         attn.num_key_value_heads = new_kv_heads
-        attn.num_key_value_groups = block.self_attn.num_heads // new_kv_heads
-        duplicate_weights_for_linear_layer(attn.k_proj, orig_kv_heads, repeat, attn.head_dim, attn.hidden_size)
-        duplicate_weights_for_linear_layer(attn.v_proj, orig_kv_heads, repeat, attn.head_dim, attn.hidden_size)
+        attn.num_key_value_groups = num_attention_heads // new_kv_heads
+        duplicate_weights_for_linear_layer(attn.k_proj, orig_kv_heads, repeat, attn.head_dim, hidden_size)
+        duplicate_weights_for_linear_layer(attn.v_proj, orig_kv_heads, repeat, attn.head_dim, hidden_size)
 
     # Generate modified outputs and tokens
     with torch.inference_mode():
@@ -126,13 +159,13 @@ def main(args):
     )
 
     # Export the modified model
-    q_model = QEFFAutoModelForCausalLM(model, continuous_batching=(True if args.full_batch_size else False))
+    q_model = QEFFAutoModelForCausalLM(model, continuous_batching=(True if full_batch_size else False))
     export(
         model_name,
         q_model,
         tokenizer=tokenizer,
         onnx_dir_path=f"{model_base_name}-{new_kv_heads}kvheads",
-        full_batch_size=(args.full_batch_size if args.full_batch_size else None),
+        full_batch_size=(full_batch_size if full_batch_size else None),
     )
 
 
@@ -162,6 +195,29 @@ def main(args):
         default=None,
         help="Number of hidden layers to use, default is None",
     )
+    parser.add_argument(
+        "--num_attention_heads",
+        "--num-attention-heads",
+        type=int,
+        default=None,
+        help="Number of attention heads; if not passed explicitly, it will be picked from config.json",
+    )
+    parser.add_argument(
+        "--hidden_size",
+        "--hidden-size",
+        type=int,
+        default=None,
+        help="Hidden size to use; if not passed explicitly, it will be picked from config.json",
+    )
 
     args = parser.parse_args()
-    main(args)
+
+    replicate_kv_heads(
+        model_name=args.model_name,
+        prompt=args.prompt,
+        repeat=args.repeat,
+        full_batch_size=args.full_batch_size,
+        num_hidden_layers=args.num_hidden_layers,
+        num_attention_heads=args.num_attention_heads,
+        hidden_size=args.hidden_size,
+    )
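For reference, a minimal sketch of how the new `replicate_kv_heads` entry point introduced above could be driven directly from Python instead of through the argparse CLI. The bare import assumes `scripts/replicate_kv_head/` is on `sys.path`; the keyword values shown are illustrative, not prescribed settings.

```python
# Usage sketch (assumption: scripts/replicate_kv_head/ is importable, e.g. on sys.path;
# otherwise run the script via its CLI flags as documented in the README above).
from replicate_kv_heads import replicate_kv_heads

replicate_kv_heads(
    model_name="meta-llama/Meta-Llama-3-8B-Instruct",
    prompt="My name is",
    repeat=2,                   # duplicate each KV head 2x
    full_batch_size=None,       # set an int to export in continuous-batching mode
    num_hidden_layers=None,     # None -> keep all layers from the checkpoint
    num_attention_heads=None,   # None -> read from the model's config.json
    hidden_size=None,           # None -> read from the model's config.json
)
```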