diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py index b7e42fc09b0..63b1fb5d472 100644 --- a/tensorrt_llm/_torch/model_config.py +++ b/tensorrt_llm/_torch/model_config.py @@ -194,11 +194,11 @@ def enable_flash_mla(self): return False def get_quant_config(self, name: Optional[str] = None) -> QuantConfig: - if name is None or self.per_layer_quant_configs is None: + if name is None or self.quant_config_dict is None: return self.quant_config - if name in self.per_layer_quant_configs: - return self.per_layer_quant_configs[name] + if name in self.quant_config_dict: + return self.quant_config_dict[name] raise ValueError(f'quant config of {name} is not found') diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index 25452b5eaa3..671db6d0adb 100755 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -39,6 +39,7 @@ from tqdm import tqdm from transformers import PretrainedConfig +import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils from tensorrt_llm._ipc_utils import can_access_peer from tensorrt_llm._utils import get_sm_version from tensorrt_llm.functional import PositionEmbeddingType @@ -144,6 +145,44 @@ def __init__(self, model, is_draft_model: bool = False): def load_weights(self, weights: Dict): + def requantize_weight_with_new_scale(weight, weight_scale, old_scale_2, + new_scale_2, device): + """ + Dequantize FP4 weights and requantize with a new scale. + + Args: + weight: FP4 quantized weight tensor 2D [,] + weight_scale: FP8 per-block scaling factors + old_scale_2: original global scale (amax/(448*6)) + new_scale_2: new global scale (amax/(448*6)) + device: target device for computation + + Returns: + (requantized_weight, new_weight_scale) + """ + # Remember original dtype of weight_scale + original_scale_dtype = weight_scale.dtype + original_scale_shape = weight_scale.shape + + # Dequantize + dequant_shape = (weight.shape[0], weight.shape[1] * 2) + weight_dequant = torch.ops.tensorrt_llm.e2m1_and_ufp8sf_scale_to_float_v2( + weight.contiguous(), + weight_scale.flatten().view( + fp4_utils.float4_sf_dtype).contiguous(), old_scale_2, 16, 1, + True).to(dtype=torch.bfloat16).reshape(dequant_shape) + + # Requantize using the new_scale_2 + weight_requant, weight_scale_requant = torch.ops.trtllm.fp4_quantize( + weight_dequant.to(device), + 1.0 / new_scale_2.to(device), + 16, # scaling_vector_size + False) + + # Ensure the returned scale has the same dtype as the input scale + return weight_requant.cpu(), weight_scale_requant.reshape( + original_scale_shape).view(original_scale_dtype).cpu() + def rename_moe_weight(weights: Dict, rename_rules: Dict): result = {} for key, value in weights.items(): @@ -200,6 +239,91 @@ def load_kv_b_proj_and_k_b_proj_trans(module_name: str, return kv_b_proj, k_nope_weight_trans + def load_kv_b_proj_and_k_b_proj_trans_for_fp8_per_tensor( + module_name: str) -> torch.Tensor: + """ + Load kv_b_proj and k_b_proj_trans for FP8 per-tensor quantization. + Similar to load_kv_b_proj_and_k_b_proj_trans but for FP8 weights. 
+ Returns: + kv_b_proj: concatenated weight for context phase + k_b_proj_trans: transposed k weight [num_heads, kv_lora_rank, qk_nope_head_dim] + """ + weight_name = "weight" + local_qk_nope_head_dim = qk_nope_head_dim + local_v_head_dim = v_head_dim + local_kv_lora_rank = kv_lora_rank + + kv_b_proj = weights[f"{module_name}.{weight_name}"][:].unflatten( + 0, + [ + num_heads, + local_qk_nope_head_dim + local_v_head_dim, + ], + ) + + if not self.model_config.mapping.enable_attention_dp: + kv_b_proj = split_matrix_tp(kv_b_proj, tp_size, tp_rank, 0) + k_nope_weight, v_weight = kv_b_proj.split( + [local_qk_nope_head_dim, local_v_head_dim], + dim=1, + ) + weight_divisor = 1 if self.model_config.mapping.enable_attention_dp else tp_size + local_num_heads = num_heads // weight_divisor + + # Transpose k_nope_weight: [num_heads, qk_nope_head_dim, kv_lora_rank] + # -> [num_heads, kv_lora_rank, qk_nope_head_dim] + k_nope_weight_trans = k_nope_weight.transpose(2, 1).contiguous() + + # Concatenate for context phase + kv_b_proj = torch.concat([ + k_nope_weight.reshape(local_num_heads * local_qk_nope_head_dim, + local_kv_lora_rank), + v_weight.reshape(local_num_heads * local_v_head_dim, + local_kv_lora_rank) + ], + dim=0) + + return kv_b_proj, k_nope_weight_trans + + def load_kv_b_proj_and_k_b_proj_for_nvfp4(module_name: str, + weight_name: str) -> torch.Tensor: + if weight_name == "weight": + local_qk_nope_head_dim = qk_nope_head_dim + local_v_head_dim = v_head_dim + local_kv_lora_rank = kv_lora_rank // 2 + elif weight_name == "weight_scale": + local_qk_nope_head_dim = qk_nope_head_dim + local_v_head_dim = v_head_dim + local_kv_lora_rank = kv_lora_rank // 16 + + kv_b_proj = weights[f"{module_name}.{weight_name}"][:].unflatten( + 0, + [ + num_heads, + local_qk_nope_head_dim + local_v_head_dim, + ], + ) + + if not self.model_config.mapping.enable_attention_dp: + kv_b_proj = split_matrix_tp(kv_b_proj, tp_size, tp_rank, 0) + k_nope_weight, v_weight = kv_b_proj.split( + [local_qk_nope_head_dim, local_v_head_dim], + dim=1, + ) + + weight_divisor = 1 if self.model_config.mapping.enable_attention_dp else tp_size + local_num_heads = num_heads // weight_divisor + + kv_b_proj = torch.concat([ + k_nope_weight.reshape(local_num_heads * local_qk_nope_head_dim, + local_kv_lora_rank), + v_weight.reshape(local_num_heads * local_v_head_dim, + local_kv_lora_rank) + ], + dim=0) + + return kv_b_proj, k_nope_weight + def load_kv_b_proj_and_k_b_proj_trans_dequant( module_name: str) -> torch.Tensor: weight_name = "weight" @@ -260,6 +384,24 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor, return k_b_proj, v_b_proj + def split_kv_b_proj_for_nvfp4_or_fp8_per_tensor(kv_b_proj: torch.Tensor) -> torch.Tensor: + local_qk_nope_head_dim = qk_nope_head_dim + local_v_head_dim = v_head_dim + + weight_divisor = 1 if self.model_config.mapping.enable_attention_dp else tp_size + local_num_heads = num_heads // weight_divisor + + k_b_proj, v_b_proj = kv_b_proj.split([ + local_num_heads * local_qk_nope_head_dim, + local_num_heads * local_v_head_dim + ], + dim=0) + k_b_proj = k_b_proj.view( + [local_num_heads, local_qk_nope_head_dim, -1]) + v_b_proj = v_b_proj.view([local_num_heads, local_v_head_dim, -1]) + + return k_b_proj, v_b_proj + is_lite = self.config.q_lora_rank is None num_heads = self.config.num_attention_heads qk_nope_head_dim = self.config.qk_nope_head_dim @@ -288,42 +430,170 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor, self.config.num_hidden_layers) name = '.'.join(names) if names[-1] == "kv_b_proj": - # TODO: remove weight_dequant 
after enabling fp8_bmm
-                        dequant_kv_b_proj = self.model_config.quant_config.is_module_excluded_from_quantization(
-                            names[-1])
-                        if dequant_kv_b_proj:
-                            kv_b_proj, k_b_proj_trans = load_kv_b_proj_and_k_b_proj_trans_dequant(
-                                name)
+                        global_quant_config = self.model_config.get_quant_config()
+                        if global_quant_config.quant_algo == QuantAlgo.MIXED_PRECISION:
+                            # Check if this layer has a specific quant config
+                            if self.model_config.quant_config_dict and name in self.model_config.quant_config_dict:
+                                layer_quant_config = self.model_config.get_quant_config(name)
+                                layer_quant_mode = layer_quant_config.layer_quant_mode
+                            else:
+                                # Layer is not quantized, use global quant mode (will go to else branch)
+                                layer_quant_mode = global_quant_config.layer_quant_mode
                         else:
-                            kv_b_proj, k_b_proj_trans = load_kv_b_proj_and_k_b_proj_trans(
-                                name, is_scale=False)
-                        module.weight.data.copy_(
-                            kv_b_proj.reshape(module.weight.shape))
-
-                        attn_module = all_named_modules[parent_module_name]
-                        _, v_b_proj = split_kv_b_proj(module.weight.data,
-                                                      is_scale=False)
-                        attn_module.v_b_proj = nn.Parameter(v_b_proj,
-                                                            requires_grad=False)
-
-                        attn_module.k_b_proj_trans.data.copy_(
-                            k_b_proj_trans.reshape(
-                                attn_module.k_b_proj_trans.shape))
-
-                        if getattr(module, "weight_scale",
-                                   None) is not None and not dequant_kv_b_proj:
-                            kv_b_proj_scale, k_b_proj_trans_scale = load_kv_b_proj_and_k_b_proj_trans(
-                                name, is_scale=True)
-                            module.weight_scale.copy_(
+                            layer_quant_mode = global_quant_config.layer_quant_mode
+
+                        nvfp4_kv_b_proj = layer_quant_mode.has_nvfp4() and weights[f"{name}.weight"].dtype == fp4_utils.float4_e2m1x2
+                        fp8_per_tensor_kv_b_proj = layer_quant_mode.has_fp8_qdq() and weights[f"{name}.weight"].dtype == torch.float8_e4m3fn
+
+                        if nvfp4_kv_b_proj:
+                            ########### input_scale
+                            module.input_scale.data.copy_(
+                                1.0 / weights[f"{name}.input_scale"])
+                            E2M1_MAX = 6.0
+                            module.inv_input_scale.data.copy_(module.input_scale /
+                                                              E2M1_MAX)
+
+                            ########### alpha
+                            alpha = weights[f"{name}.input_scale"].float(
+                            ) * weights[f"{name}.weight_scale_2"].float()
+                            module.alpha.data.copy_(alpha)
+                            module.scalar_alpha = alpha.item()
+
+                            ########### weights: kv_b_proj and k_b_proj and v_b_proj
+                            # will transpose and copy k_b_proj later
+                            # will copy v_b_proj later
+                            kv_b_proj, k_b_proj = load_kv_b_proj_and_k_b_proj_for_nvfp4(
+                                name, "weight")
+                            _, v_b_proj = split_kv_b_proj_for_nvfp4_or_fp8_per_tensor(module.weight.data)
+                            module.weight.data.copy_(
+                                kv_b_proj.reshape(module.weight.shape))
+
+                            ########### weight_scale: kv_b_proj_scale and k_b_proj_scale and v_b_proj_scale
+                            # load and copy kv_b_proj_scale to module, because it is used in context phase
+                            kv_b_proj_scale, k_b_proj_scale = load_kv_b_proj_and_k_b_proj_for_nvfp4(
+                                name, "weight_scale")
+
+                            _, v_b_proj_scale = split_kv_b_proj_for_nvfp4_or_fp8_per_tensor(
+                                kv_b_proj_scale)
+                            kv_b_proj_scale = torch.ops.trtllm.block_scale_interleave(
+                                kv_b_proj_scale.view(fp4_utils.float4_sf_dtype))
+
+                            module.weight_scale.data.copy_(
                                 kv_b_proj_scale.reshape(module.weight_scale.shape))
-                            attn_module.k_b_proj_trans_scale.copy_(
-                                k_b_proj_trans_scale.reshape(
-                                    attn_module.k_b_proj_trans_scale.shape))
-                            _, v_b_proj_scale = split_kv_b_proj(
-                                module.weight_scale.data, is_scale=True)
-                            attn_module.v_b_proj_scale = nn.Parameter(
-                                v_b_proj_scale, requires_grad=False)
+                            ########### k_b_proj_trans and v_b_proj
+                            k_b_proj_dequant_shape = (k_b_proj.shape[0],
+                                                      k_b_proj.shape[1],
+                                                      k_b_proj.shape[2] * 2)
+                            v_b_proj_dequant_shape = (v_b_proj.shape[0],
+                                                      v_b_proj.shape[1],
v_b_proj.shape[2] * 2) + + # dequantize and transpose k_b_proj + k_b_proj_trans = torch.ops.tensorrt_llm.e2m1_and_ufp8sf_scale_to_float_v2( + k_b_proj.reshape(-1, + k_b_proj.shape[-1]).contiguous(), + k_b_proj_scale.flatten().view( + fp4_utils.float4_sf_dtype).contiguous(), + weights[f"{name}.weight_scale_2"], + 16, + 1, + False, + ).to(dtype=torch.bfloat16).reshape( + k_b_proj_dequant_shape).transpose(2, 1) + + # dequantize v_b_proj + v_b_proj_original_device = v_b_proj.device + v_b_proj = torch.ops.tensorrt_llm.e2m1_and_ufp8sf_scale_to_float_v2( + v_b_proj.reshape( + -1, v_b_proj.shape[-1]).contiguous().cpu(), + v_b_proj_scale.flatten().view( + fp4_utils.float4_sf_dtype).contiguous(), + weights[f"{name}.weight_scale_2"], + 16, + 1, + False, + ).to(dtype=torch.bfloat16, + device=v_b_proj_original_device).reshape( + v_b_proj_dequant_shape) + + # copy BF16 k_b_proj_trans and v_b_proj to attn_module + attn_module = all_named_modules[parent_module_name] + attn_module.k_b_proj_trans.data.copy_( + k_b_proj_trans.reshape( + attn_module.k_b_proj_trans.shape)) + attn_module.v_b_proj = nn.Parameter(v_b_proj, + requires_grad=False) + elif fp8_per_tensor_kv_b_proj: + ##### for fp8 per tensor scaling + # Load weights for kv_b_proj + kv_b_proj, k_b_proj_trans = load_kv_b_proj_and_k_b_proj_trans_for_fp8_per_tensor(name) + module.weight.data.copy_( + kv_b_proj.reshape(module.weight.shape)) + + # Load weights for k_b_proj_trans + attn_module = all_named_modules[parent_module_name] + attn_module.k_b_proj_trans.data.copy_( + k_b_proj_trans.reshape( + attn_module.k_b_proj_trans.shape)) + + # Load weights for v_b_proj + _, v_b_proj = split_kv_b_proj_for_nvfp4_or_fp8_per_tensor(module.weight.data) + attn_module.v_b_proj = nn.Parameter(v_b_proj, + requires_grad=False) + + # Load weight_scale for kv_b_proj, k_b_proj_trans and v_b_proj + if f"{name}.weight_scale" in weights: + weight_scale = weights[f"{name}.weight_scale"] + module.weight_scale.data.copy_(weight_scale) + attn_module.k_b_proj_trans_scale.data.copy_(module.weight_scale) + attn_module.v_b_proj_scale.data.copy_(module.weight_scale) + + # Assume input_scale = 1.0 for kv_b_proj k_b_proj_trans and v_b_proj + module.input_scale.data.fill_(1.0) + module.inv_input_scale.data.fill_(1.0) + attn_module.k_b_proj_trans_input_scale.data.fill_(1.0) + attn_module.v_b_proj_input_scale.data.fill_(1.0) + + else: + ##### for fp8 block scaling + # TODO: remove weight_dequant after enabling fp8_bmm + dequant_kv_b_proj = self.model_config.quant_config.is_module_excluded_from_quantization( + names[-1]) + if dequant_kv_b_proj: + kv_b_proj, k_b_proj_trans = load_kv_b_proj_and_k_b_proj_trans_dequant( + name) + else: + kv_b_proj, k_b_proj_trans = load_kv_b_proj_and_k_b_proj_trans( + name, is_scale=False) + module.weight.data.copy_( + kv_b_proj.reshape(module.weight.shape)) + + attn_module = all_named_modules[parent_module_name] + _, v_b_proj = split_kv_b_proj(module.weight.data, + is_scale=False) + attn_module.v_b_proj = nn.Parameter(v_b_proj, + requires_grad=False) + + attn_module.k_b_proj_trans.data.copy_( + k_b_proj_trans.reshape( + attn_module.k_b_proj_trans.shape)) + + if getattr(module, "weight_scale", + None) is not None and not dequant_kv_b_proj: + kv_b_proj_scale, k_b_proj_trans_scale = load_kv_b_proj_and_k_b_proj_trans( + name, is_scale=True) + module.weight_scale.copy_( + kv_b_proj_scale.reshape(module.weight_scale.shape)) + attn_module.k_b_proj_trans_scale.copy_( + k_b_proj_trans_scale.reshape( + attn_module.k_b_proj_trans_scale.shape)) + + _, v_b_proj_scale = 
split_kv_b_proj( + module.weight_scale.data, is_scale=True) + attn_module.v_b_proj_scale = nn.Parameter( + v_b_proj_scale, requires_grad=False) if attn_module.k_b_proj_trans_dequant is not None: attn_module.k_b_proj_trans_dequant.data.copy_( @@ -346,26 +616,142 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor, ).view(*attn_module.v_b_proj_dequant.shape).to( attn_module.v_b_proj_dequant.dtype)) elif names[-1] == "kv_a_proj_with_mqa": - fused_a = weights[ - f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight"][:] + global_quant_config = self.model_config.get_quant_config() + if global_quant_config.quant_algo == QuantAlgo.MIXED_PRECISION: + # Check if this layer has a specific quant config + if self.model_config.quant_config_dict and name in self.model_config.quant_config_dict: + layer_quant_config = self.model_config.get_quant_config(name) + layer_quant_mode = layer_quant_config.layer_quant_mode + else: + # Layer is not quantized, use global quant mode (will go to else branch) + layer_quant_mode = global_quant_config.layer_quant_mode + else: + layer_quant_mode = global_quant_config.layer_quant_mode + + # Check if both q_a_proj and kv_a_proj_with_mqa are NVFP4 (or only kv_a_proj_with_mqa for lite mode) + kv_a_is_nvfp4 = weights[f"{name}.weight"].dtype == fp4_utils.float4_e2m1x2 if not is_lite: - q_a_proj = weights[ - f"{'.'.join(names[:-1])}.q_a_proj.weight"][:] - fused_a = torch.cat([q_a_proj, fused_a], dim=0) + q_a_is_nvfp4 = weights[f"{'.'.join(names[:-1])}.q_a_proj.weight"].dtype == fp4_utils.float4_e2m1x2 + nvfp4_fused_a = layer_quant_mode.has_nvfp4() and kv_a_is_nvfp4 and q_a_is_nvfp4 + else: + nvfp4_fused_a = layer_quant_mode.has_nvfp4() and kv_a_is_nvfp4 + + if nvfp4_fused_a: + ########### input_scale + kv_a_proj_with_mqa_input_scale = weights[ + f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.input_scale"] + if not is_lite: + q_a_proj_input_scale = weights[ + f"{'.'.join(names[:-1])}.q_a_proj.input_scale"] + assert kv_a_proj_with_mqa_input_scale == q_a_proj_input_scale, "kv_a_proj_with_mqa.input_scale and q_a_proj.input_scale should be the same" + # modelopt ckpt stores amax/(448*6), convert to (448*6)/amax + shared_input_scale = kv_a_proj_with_mqa_input_scale + module.input_scale.data.copy_(1.0 / shared_input_scale) + E2M1_MAX = 6.0 + module.inv_input_scale.data.copy_(module.input_scale / + E2M1_MAX) + ########### weight_scale_2 + need_requant_kv_a_proj_with_mqa = False + need_requant_q_a_proj = False + kv_a_proj_with_mqa_scale_2 = weights[ + f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight_scale_2"] + shared_weight_scale_2 = kv_a_proj_with_mqa_scale_2 + if not is_lite: + q_a_proj_scale_2 = weights[ + f"{'.'.join(names[:-1])}.q_a_proj.weight_scale_2"] + if kv_a_proj_with_mqa_scale_2 < q_a_proj_scale_2: + shared_weight_scale_2 = q_a_proj_scale_2 + need_requant_kv_a_proj_with_mqa = True + elif q_a_proj_scale_2 < kv_a_proj_with_mqa_scale_2: + need_requant_q_a_proj = True + + ########### alpha + alpha = shared_input_scale.float( + ) * shared_weight_scale_2.float() + module.alpha.data.copy_(alpha) + module.scalar_alpha = alpha.item() + + ########### weights + kv_a_proj_with_mqa = weights[ + f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight"][:] - if f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight_scale_inv" in weights: - fused_a_scale = weights[ - f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight_scale_inv"] + if not is_lite: + q_a_proj = weights[ + f"{'.'.join(names[:-1])}.q_a_proj.weight"][:] + + ########### weight_scale + kv_a_proj_with_mqa_scale = weights[ + 
f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight_scale"][:] + kv_a_proj_with_mqa_scale = torch.ops.trtllm.block_scale_interleave( + kv_a_proj_with_mqa_scale.view( + fp4_utils.float4_sf_dtype)) if not is_lite: q_a_proj_scale = weights[ - f"{'.'.join(names[:-1])}.q_a_proj.weight_scale_inv"][:] - fused_a_scale = torch.cat( - [q_a_proj_scale, fused_a_scale], dim=0) + f"{'.'.join(names[:-1])}.q_a_proj.weight_scale"][:] + q_a_proj_scale = torch.ops.trtllm.block_scale_interleave( + q_a_proj_scale.view(fp4_utils.float4_sf_dtype)) + + ########### requantize + if need_requant_kv_a_proj_with_mqa: + # requant kv_a_proj_with_mqa + kv_a_proj_with_mqa, kv_a_proj_with_mqa_scale = requantize_weight_with_new_scale( + kv_a_proj_with_mqa, + kv_a_proj_with_mqa_scale, + kv_a_proj_with_mqa_scale_2, + shared_weight_scale_2, + device=module.weight.device, + ) + if need_requant_q_a_proj: + # requant q_a_proj + q_a_proj, q_a_proj_scale = requantize_weight_with_new_scale( + q_a_proj, + q_a_proj_scale, + q_a_proj_scale_2, + shared_weight_scale_2, + device=module.weight.device) + + ########### fuse and load weights + if not is_lite: + fused_a = torch.cat([q_a_proj, kv_a_proj_with_mqa], + dim=0) + else: + fused_a = kv_a_proj_with_mqa + + # For DeepseekV32 with fuse_a_indexer_k_weight=True: kv_a_proj_with_mqa is oversized + # to include indexer weights, which is filled in post_load_weights. + module.weight.data[0:fused_a.shape[0]].copy_(fused_a) - module.weight_scale.data.copy_(fused_a_scale) - # For DeepseekV32 with fuse_a_indexer_k_weight=True: kv_a_proj_with_mqa is oversized - # to include indexer weights, which is filled in post_load_weights. - module.weight.data[0:fused_a.shape[0]].copy_(fused_a) + ########### fuse weight_scale + if not is_lite: + fused_a_scale = torch.cat( + [q_a_proj_scale, kv_a_proj_with_mqa_scale], + dim=0) + else: + fused_a_scale = kv_a_proj_with_mqa_scale + # For DeepseekV32 with fuse_a_indexer_k_weight=True: kv_a_proj_with_mqa is oversized + # to include indexer weights, which is filled in post_load_weights. + module.weight_scale[0:fused_a_scale.shape[0]].data.copy_(fused_a_scale) + else: + fused_a = weights[ + f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight"][:] + if not is_lite: + q_a_proj = weights[ + f"{'.'.join(names[:-1])}.q_a_proj.weight"][:] + fused_a = torch.cat([q_a_proj, fused_a], dim=0) + + if f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight_scale_inv" in weights: + fused_a_scale = weights[ + f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight_scale_inv"] + if not is_lite: + q_a_proj_scale = weights[ + f"{'.'.join(names[:-1])}.q_a_proj.weight_scale_inv"][:] + fused_a_scale = torch.cat( + [q_a_proj_scale, fused_a_scale], dim=0) + + module.weight_scale.data.copy_(fused_a_scale) + # For DeepseekV32 with fuse_a_indexer_k_weight=True: kv_a_proj_with_mqa is oversized + # to include indexer weights, which is filled in post_load_weights. + module.weight.data[0:fused_a.shape[0]].copy_(fused_a) elif names[-1] in params_map: module_weights = [] for new_name in params_map[names[-1]]: @@ -1386,6 +1772,7 @@ def __init__(self, dtype=config.torch_dtype, skip_create_weights_in_init=model_config. skip_create_weights_in_init, + quant_config=model_config.get_quant_config(), ) else: self.eh_proj = Linear( @@ -1398,6 +1785,7 @@ def __init__(self, reduce_output=True, skip_create_weights_in_init=model_config. 
skip_create_weights_in_init, + quant_config=model_config.get_quant_config(), ) self.shared_head = DeepseekV3MTPHead(model_config) @@ -1432,7 +1820,7 @@ def norm_hidden(): tp_rank = self.model_config.mapping.tp_rank if tp_size > 1 and not (self.model_config.mapping.enable_attention_dp): - hidden_states = torch.chunk(hidden_states, tp_size, dim=-1)[tp_rank] + hidden_states = torch.chunk(hidden_states, tp_size,dim=-1)[tp_rank].contiguous() hidden_states = self.eh_proj(hidden_states) # Input layer norm diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index 7b5c5e429c0..4db2cc8f9e9 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -629,6 +629,20 @@ def mla_custom_op_inplace( latent_cache_gen=latent_cache_gen) +# Global flag to track if auto_deploy ops are loaded +_AUTO_DEPLOY_OPS_LOADED = False + +def _ensure_auto_deploy_ops_loaded(): + """Lazy import auto_deploy custom ops to avoid circular import. Only imports once.""" + global _AUTO_DEPLOY_OPS_LOADED + if not _AUTO_DEPLOY_OPS_LOADED: + try: + import tensorrt_llm._torch.auto_deploy.custom_ops.quant # noqa: F401 + _AUTO_DEPLOY_OPS_LOADED = True + except ImportError: + pass # auto_deploy ops not available + + def fp8_block_scaling_bmm_out( mat1: torch.Tensor, mat2_fp8: torch.Tensor, @@ -964,11 +978,15 @@ def create_weights(self): # k_b_proj_trans's dtype must be consistent with self.kv_b_proj, # which can be modified after __init__ - has_fp8_block_scales = ( + kv_b_proj_has_fp8_block_scales = ( self.kv_b_proj.quant_config and self.kv_b_proj.quant_config.quant_mode.has_fp8_block_scales()) - mla_weight_dtype = torch.float8_e4m3fn if has_fp8_block_scales else self.dtype + kv_b_proj_has_fp8_per_tensor_scales = ( + self.kv_b_proj.quant_config + and self.kv_b_proj.quant_config.quant_mode.has_fp8_qdq()) + + mla_weight_dtype = torch.float8_e4m3fn if kv_b_proj_has_fp8_block_scales or kv_b_proj_has_fp8_per_tensor_scales else self.dtype self.k_b_proj_trans = nn.Parameter( torch.empty( (self.num_heads_tp, self.kv_lora_rank, self.qk_nope_head_dim), @@ -979,7 +997,10 @@ def create_weights(self): self.k_b_proj_trans_dequant = None self.v_b_proj_dequant = None - if has_fp8_block_scales: + self.k_b_proj_trans_input_scale = None + self.v_b_proj_input_scale = None + + if kv_b_proj_has_fp8_block_scales: self.k_b_proj_trans_scale = nn.Parameter( torch.empty( ( @@ -1022,6 +1043,25 @@ def create_weights(self): ), requires_grad=False, ) + elif kv_b_proj_has_fp8_per_tensor_scales: + # weight scales + self.k_b_proj_trans_scale = nn.Parameter( + torch.tensor(1.0, dtype=torch.float32), + requires_grad=False, + ) + self.v_b_proj_scale = nn.Parameter( + torch.tensor(1.0, dtype=torch.float32), + requires_grad=False, + ) + # input scales + self.k_b_proj_trans_input_scale = nn.Parameter( + torch.tensor(1.0, dtype=torch.float32), + requires_grad=False, + ) + self.v_b_proj_input_scale = nn.Parameter( + torch.tensor(1.0, dtype=torch.float32), + requires_grad=False, + ) else: self.k_b_proj_trans_scale = None self.v_b_proj_scale = None @@ -1709,6 +1749,14 @@ def forward_absorption_generation( device=q.device, ) + kv_b_proj_has_fp8_block_scales = ( + self.kv_b_proj.quant_config + and self.kv_b_proj.quant_config.quant_mode.has_fp8_block_scales()) + + kv_b_proj_has_fp8_per_tensor_scales = ( + self.kv_b_proj.quant_config + and self.kv_b_proj.quant_config.quant_mode.has_fp8_qdq()) + rope_stream = self.aux_stream if not has_fp8_kv_cache else None if self.k_b_proj_trans.dtype == 
torch.bfloat16: # [num_heads, num_tokens, self.qk_nope_head_dim] @@ -1731,7 +1779,7 @@ def forward_absorption_generation( rope_stream, ) - elif self.k_b_proj_trans.dtype == torch.float8_e4m3fn: + elif kv_b_proj_has_fp8_block_scales: # [num_heads, num_tokens, self.kv_lora_rank] q_nope_out = fused_q[..., :self.kv_lora_rank].transpose(0, 1) @@ -1751,6 +1799,30 @@ def forward_absorption_generation( self.ln_events[1], rope_stream, ) + elif kv_b_proj_has_fp8_per_tensor_scales: + # [num_heads, num_tokens, self.kv_lora_rank] + q_nope_out = fused_q[..., :self.kv_lora_rank].transpose(0, 1) + + # Ensure auto_deploy ops are loaded + _ensure_auto_deploy_ops_loaded() + + maybe_execute_in_parallel( + lambda: q_nope_out.copy_( + torch.ops.auto_deploy.torch_quant_fp8_bmm( + q_nope.transpose(0, 1), # [num_heads, num_tokens, qk_nope_head_dim] + self.k_b_proj_trans.transpose(1, 2), # [num_heads, qk_nope_head_dim, kv_lora_rank] + input_scale=self.k_b_proj_trans_input_scale, + weight_scale=self.k_b_proj_trans_scale, + ) + ), + lambda: self.mqa.mla_rope_generation( + fused_q, q_pe, latent_cache, attn_metadata, cu_q_seqlens, + cu_kv_seqlens, fmha_scheduler_counter, mla_bmm1_scale, + mla_bmm2_scale, quant_q_buffer), + self.ln_events[0], + self.ln_events[1], + rope_stream, + ) else: raise NotImplementedError( f"Missing bmm impl for dtype: {self.k_b_proj_trans.dtype}.") @@ -1804,7 +1876,7 @@ def forward_absorption_generation( torch.ops.trtllm.bmm_out(attn_out_latent.transpose(0, 1), self.v_b_proj.transpose(1, 2), attn_output.transpose(0, 1)) - elif self.v_b_proj.dtype == torch.float8_e4m3fn: + elif kv_b_proj_has_fp8_block_scales: fp8_block_scaling_bmm_out( attn_out_latent, self.v_b_proj, @@ -1812,6 +1884,19 @@ def forward_absorption_generation( attn_output.transpose(0, 1), self.v_b_proj_dequant, ) + elif kv_b_proj_has_fp8_per_tensor_scales: + # Ensure auto_deploy ops are loaded + _ensure_auto_deploy_ops_loaded() + + # Call FP8 per-tensor BMM and copy result to pre-allocated buffer + attn_output.transpose(0, 1).copy_( + torch.ops.auto_deploy.torch_quant_fp8_bmm( + attn_out_latent.transpose(0, 1), # [num_heads, seq, kv_lora_rank] + self.v_b_proj.transpose(1, 2), # [num_heads, kv_lora_rank, v_head_dim] + input_scale=self.v_b_proj_input_scale, + weight_scale=self.v_b_proj_scale, + ) + ) else: raise NotImplementedError( f"Missing bmm impl for dtype: {self.v_b_proj.dtype}.") @@ -1844,6 +1929,14 @@ def forward_absorption_context( device=q.device, ) + kv_b_proj_has_fp8_block_scales = ( + self.kv_b_proj.quant_config + and self.kv_b_proj.quant_config.quant_mode.has_fp8_block_scales()) + + kv_b_proj_has_fp8_per_tensor_scales = ( + self.kv_b_proj.quant_config + and self.kv_b_proj.quant_config.quant_mode.has_fp8_qdq()) + if self.k_b_proj_trans.dtype == torch.bfloat16: # [num_heads, num_tokens, self.qk_nope_head_dim] q_nope_t = q_nope.transpose(0, 1) @@ -1856,7 +1949,7 @@ def forward_absorption_context( torch.ops.trtllm.bmm_out(q_nope_t, self.k_b_proj_trans.transpose(1, 2), q_nope_out) - elif self.k_b_proj_trans.dtype == torch.float8_e4m3fn: + elif kv_b_proj_has_fp8_block_scales: # [num_heads, num_tokens, self.kv_lora_rank] q_nope_out = fused_q[..., :self.kv_lora_rank].transpose(0, 1) @@ -1867,6 +1960,22 @@ def forward_absorption_context( q_nope_out, self.k_b_proj_trans_dequant, ) + elif kv_b_proj_has_fp8_per_tensor_scales: + # [num_heads, num_tokens, self.kv_lora_rank] + q_nope_out = fused_q[..., :self.kv_lora_rank].transpose(0, 1) + + # Ensure auto_deploy ops are loaded + _ensure_auto_deploy_ops_loaded() + + # Call FP8 
per-tensor BMM and copy result to pre-allocated buffer + q_nope_out.copy_( + torch.ops.auto_deploy.torch_quant_fp8_bmm( + q_nope.transpose(0, 1), # [num_heads, num_tokens, qk_nope_head_dim] + self.k_b_proj_trans.transpose(1, 2), # [num_heads, qk_nope_head_dim, kv_lora_rank] + input_scale=self.k_b_proj_trans_input_scale, + weight_scale=self.k_b_proj_trans_scale, + ) + ) else: raise NotImplementedError( f"Missing bmm impl for dtype: {self.k_b_proj_trans.dtype}.") @@ -1914,7 +2023,7 @@ def forward_absorption_context( torch.ops.trtllm.bmm_out(attn_out_latent.transpose(0, 1), self.v_b_proj.transpose(1, 2), attn_output.transpose(0, 1)) - elif self.v_b_proj.dtype == torch.float8_e4m3fn: + elif kv_b_proj_has_fp8_block_scales: fp8_block_scaling_bmm_out( attn_out_latent, self.v_b_proj, @@ -1922,6 +2031,19 @@ def forward_absorption_context( attn_output.transpose(0, 1), self.v_b_proj_dequant, ) + elif kv_b_proj_has_fp8_per_tensor_scales: + # Ensure auto_deploy ops are loaded + _ensure_auto_deploy_ops_loaded() + + # Call FP8 per-tensor BMM and copy result to pre-allocated buffer + attn_output.transpose(0, 1).copy_( + torch.ops.auto_deploy.torch_quant_fp8_bmm( + attn_out_latent.transpose(0, 1), # [num_heads, seq, kv_lora_rank] + self.v_b_proj.transpose(1, 2), # [num_heads, kv_lora_rank, v_head_dim] + input_scale=self.v_b_proj_input_scale, + weight_scale=self.v_b_proj_scale, + ) + ) else: raise NotImplementedError( f"Missing bmm impl for dtype: {self.v_b_proj.dtype}.")
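Note on the `kv_b_proj` reshapes repeated by the new loaders (`load_kv_b_proj_and_k_b_proj_trans_for_fp8_per_tensor`, `load_kv_b_proj_and_k_b_proj_for_nvfp4`, `split_kv_b_proj_for_nvfp4_or_fp8_per_tensor`): the checkpoint stores `kv_b_proj` as `[num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]`, the absorption path needs the k part transposed per head, and the context path needs the k and v parts stacked back along dim 0. A shape-only sketch in plain PyTorch, with illustrative dimensions and no TP split or quantized dtypes:

```python
import torch

# Illustrative MLA dimensions (DeepSeek-V3-like values, not read from any config here).
num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 128, 128, 128, 512

# Checkpoint layout: [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
kv_b_proj = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

# Split per head into the k_nope and v parts.
per_head = kv_b_proj.unflatten(0, [num_heads, qk_nope_head_dim + v_head_dim])
k_nope, v = per_head.split([qk_nope_head_dim, v_head_dim], dim=1)

# k_b_proj_trans is the per-head transpose used by the absorption BMM:
# [num_heads, kv_lora_rank, qk_nope_head_dim]
k_b_proj_trans = k_nope.transpose(2, 1).contiguous()

# The fused weight handed back for the context phase stacks the two parts as
# [num_heads * qk_nope_head_dim + num_heads * v_head_dim, kv_lora_rank].
fused = torch.cat([
    k_nope.reshape(num_heads * qk_nope_head_dim, kv_lora_rank),
    v.reshape(num_heads * v_head_dim, kv_lora_rank),
], dim=0)

print(k_b_proj_trans.shape, fused.shape)
# torch.Size([128, 512, 128]) torch.Size([32768, 512])
```

The NVFP4 variant only changes the last dimension of this layout: `kv_lora_rank // 2` for the packed FP4 weight and `kv_lora_rank // 16` for the block scales, which is why `load_kv_b_proj_and_k_b_proj_for_nvfp4` switches `local_kv_lora_rank` on `weight_name`.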
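`requantize_weight_with_new_scale` and the fused `q_a_proj`/`kv_a_proj_with_mqa` path both lean on the NVFP4 two-level scaling convention: a per-tensor global scale `weight_scale_2 = amax / (448 * 6)` plus per-16-element block scales stored as FP8. Because the two projections are concatenated into one GEMM, they must share a single global scale, so the loader keeps the larger `weight_scale_2` and requantizes the other projection against it. The sketch below mimics that scale arithmetic with plain float tensors instead of packed FP4 and the `torch.ops.tensorrt_llm`/`torch.ops.trtllm` kernels; names such as `fake_nvfp4_quant` are illustrative only.

```python
import torch

FP8_MAX = 448.0   # max magnitude representable in float8_e4m3fn
E2M1_MAX = 6.0    # max magnitude of an FP4 (E2M1) value
BLOCK = 16        # NVFP4 scaling block size


def global_scale(w: torch.Tensor) -> torch.Tensor:
    # Per-tensor "weight_scale_2" convention from the checkpoint: amax / (448 * 6).
    return w.abs().amax().float() / (FP8_MAX * E2M1_MAX)


def fake_nvfp4_quant(w: torch.Tensor, scale_2: torch.Tensor):
    """Simulate NVFP4 two-level scaling with float tensors (no packing, no rounding)."""
    blocks = w.float().reshape(-1, BLOCK)
    block_scale = blocks.abs().amax(dim=1, keepdim=True) / E2M1_MAX
    # Real checkpoints store the block scale relative to scale_2, as FP8.
    block_scale_rel = (block_scale / scale_2).clamp(min=torch.finfo(torch.float32).tiny)
    q = (blocks / (block_scale_rel * scale_2)).clamp(-E2M1_MAX, E2M1_MAX)
    return q, block_scale_rel


def fake_nvfp4_dequant(q, block_scale_rel, scale_2, shape):
    return (q * block_scale_rel * scale_2).reshape(shape)


torch.manual_seed(0)
q_a_proj = torch.randn(256, 1536)
kv_a_proj = 3.0 * torch.randn(576, 1536)   # larger amax -> larger global scale

s2_q = global_scale(q_a_proj)
s2_kv = global_scale(kv_a_proj)
shared_s2 = torch.maximum(s2_q, s2_kv)     # keep the larger weight_scale_2

# "Requantize with a new scale": dequantize with the old global scale, then
# quantize again against the shared one, mirroring what
# requantize_weight_with_new_scale does with the real FP4 kernels.
q, bs = fake_nvfp4_quant(q_a_proj, s2_q)
dequant = fake_nvfp4_dequant(q, bs, s2_q, q_a_proj.shape)
q2, bs2 = fake_nvfp4_quant(dequant, shared_s2)

recovered = fake_nvfp4_dequant(q2, bs2, shared_s2, q_a_proj.shape)
print(torch.allclose(recovered, q_a_proj, rtol=1e-5, atol=1e-5))  # True up to fp rounding
```

Keeping the larger global scale means the requantized projection loses precision only through its block scales rather than clipping, which is presumably why the loader picks the maximum of `kv_a_proj_with_mqa_scale_2` and `q_a_proj_scale_2` as the shared scale.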
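The new `kv_b_proj_has_fp8_per_tensor_scales` branches in `attention.py` route the absorption BMMs through `torch.ops.auto_deploy.torch_quant_fp8_bmm` with per-tensor `input_scale` and `weight_scale`. As a numerics reference for what those arguments are assumed to mean in this path (the loader fills `input_scale` with 1.0 and `weight_scale` with the checkpoint dequant scale), here is a plain-PyTorch sketch; it is not the kernel implementation, only a sanity check of the per-tensor dequant-then-matmul math.

```python
import torch


def fp8_per_tensor_bmm_reference(
    act: torch.Tensor,           # [B, M, K], bf16/fp32 activations
    weight_fp8: torch.Tensor,    # [B, K, N], float8_e4m3fn
    input_scale: torch.Tensor,   # scalar dequant scale assumed for the activation
    weight_scale: torch.Tensor,  # scalar dequant scale for the weight
) -> torch.Tensor:
    """Assumed contract: quantize the activation with 1/input_scale, dequantize both
    operands with their per-tensor scales, multiply in higher precision."""
    act_fp8 = (act / input_scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    a = act_fp8.to(torch.float32) * input_scale
    b = weight_fp8.to(torch.float32) * weight_scale
    return torch.bmm(a, b).to(act.dtype)


if __name__ == "__main__":
    torch.manual_seed(0)
    w = torch.randn(8, 512, 128)
    w_scale = w.abs().amax() / 448.0                      # per-tensor weight scale
    w_fp8 = (w / w_scale).to(torch.float8_e4m3fn)
    x = torch.randn(8, 32, 512, dtype=torch.bfloat16)
    out = fp8_per_tensor_bmm_reference(x, w_fp8, torch.tensor(1.0), w_scale)
    print(out.shape)  # torch.Size([8, 32, 128])
```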