diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py
index 9f4e233aad7..75a8a472bf7 100644
--- a/tensorrt_llm/_torch/modules/linear.py
+++ b/tensorrt_llm/_torch/modules/linear.py
@@ -444,16 +444,14 @@ def load_weights_fused_qkv_linear(self, module: Linear,
         copy_weight(module.weight_scale, max(weight_scale))
 
-        q_weight = q_weight.to(module.dtype) * weight_scale[0]
-        k_weight = k_weight.to(module.dtype) * weight_scale[1]
-        v_weight = v_weight.to(module.dtype) * weight_scale[2]
+        # use in-place multiplication and division to avoid extra memory allocation
+        q_weight = q_weight.to(module.dtype).mul_(weight_scale[0])
+        k_weight = k_weight.to(module.dtype).mul_(weight_scale[1])
+        v_weight = v_weight.to(module.dtype).mul_(weight_scale[2])
         fused_weight = torch.cat((q_weight, k_weight, v_weight))
-        if module.weight_scale.device != fused_weight.device:
-            module.weight_scale = Parameter(
-                module.weight_scale.data.to(fused_weight.device))
-        fused_weight = (fused_weight / module.weight_scale).to(
-            torch.float8_e4m3fn)
+        fused_weight = fused_weight.div_(
+            module.weight_scale.to(fused_weight.device)).to(torch.float8_e4m3fn)
         copy_weight(module.weight, fused_weight)
 
         # Load k and v scales, used for NVFP4 KV cache
@@ -486,14 +484,12 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
         gate_weight, up_weight = load_weights_fused_gate_up_helper(
             module, weights)
 
-        gate_weight = gate_weight.to(module.dtype) * weight_scale[0]
-        up_weight = up_weight.to(module.dtype) * weight_scale[1]
+        # use in-place multiplication and division to avoid extra memory allocation
+        gate_weight = gate_weight.to(module.dtype).mul_(weight_scale[0])
+        up_weight = up_weight.to(module.dtype).mul_(weight_scale[1])
         fused_weight = torch.cat((gate_weight, up_weight))
-        if module.weight_scale.device != fused_weight.device:
-            module.weight_scale = Parameter(
-                module.weight_scale.data.to(fused_weight.device))
-        fused_weight = (fused_weight / module.weight_scale).to(
-            torch.float8_e4m3fn)
+        fused_weight = fused_weight.div_(
+            module.weight_scale.to(fused_weight.device)).to(torch.float8_e4m3fn)
         copy_weight(module.weight, fused_weight)
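
For reference, a minimal standalone sketch of the pattern both hunks apply (illustrative only: the shapes, dtypes, and scale values below are assumptions, not the module's real configuration):

import torch

# Simulated per-projection weights and scales, standing in for q/k/v.
q = torch.randn(128, 64, dtype=torch.float16)
k = torch.randn(128, 64, dtype=torch.float16)
v = torch.randn(128, 64, dtype=torch.float16)
scales = [torch.tensor(0.5), torch.tensor(0.25), torch.tensor(0.75)]

# Out-of-place `t.to(dtype) * s` allocates one tensor for the cast and a
# second for the product; chaining `.mul_(s)` instead scales the cast copy
# in place, so only the cast allocation remains.
q = q.to(torch.float32).mul_(scales[0])
k = k.to(torch.float32).mul_(scales[1])
v = v.to(torch.float32).mul_(scales[2])

fused = torch.cat((q, k, v))
# `.div_()` likewise reuses the concatenated buffer before the fp8 cast,
# replacing the separate device-move-and-divide of the old code path.
fused = fused.div_(max(scales).to(fused.device)).to(torch.float8_e4m3fn)

One caveat worth noting: the in-place ops are only safe here because the preceding .to(...) cast materializes a fresh tensor. If the source tensor already had the target dtype and device, .to() would return the same tensor, and mul_() would silently mutate the loaded checkpoint weights.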