[Feature] Support float8 dtype storage and DeepSeek V3 fp8 inference #9906

Open · wants to merge 18 commits into develop
57 changes: 33 additions & 24 deletions paddlenlp/mergekit/merge_model.py
@@ -56,7 +56,7 @@
class MergeModel:
def __init__(self, merge_config):
self.reset_merge_model(merge_config=merge_config)
self.numpy_dtype_map = {"float32": 4, "float16": 2, "uint16": 2}
self.numpy_dtype_map = {"float32": 4, "float16": 2, "uint16": 2, "bfloat16": 2}
self.is_peft = False

def reset_merge_model(self, merge_config=None, merge_param_dict=None):
@@ -151,7 +151,7 @@
)
for key in local_keys:
# Tensor preprocess
is_bf16 = str(state_dict_list[0][key].dtype) == "uint16"
is_bf16 = str(state_dict_list[0][key].dtype) in ["uint16", "bfloat16"]
tensor_list = [state_dict_list[i].pop(key) for i in range(model_num)]
tensor_mem = int(np.prod(tensor_list[0].shape) * self.numpy_dtype_map[str(tensor_list[0].dtype)]) / (
1024**3
@@ -165,10 +165,10 @@
tensor_list = [tensor_split[sp] for tensor_split in tensor_split_list]
if is_bf16:
tensor_list = [
paddle.Tensor(tensor, zero_copy=True).astype("float32") for tensor in tensor_list
paddle.Tensor.__call__(tensor, zero_copy=True).astype("float32") for tensor in tensor_list
]
else:
tensor_list = [paddle.Tensor(tensor, zero_copy=True) for tensor in tensor_list]
tensor_list = [paddle.Tensor.__call__(tensor, zero_copy=True) for tensor in tensor_list]

if self.merge_config.base_model_path is not None:
base_tensor = tensor_list.pop()
tensor_list = [tensor - base_tensor for tensor in tensor_list]
@@ -184,18 +184,20 @@
if self.merge_config.tensor_type == "pd":
if is_bf16:
tensor_list = [
paddle.Tensor(tensor, zero_copy=True).astype("float32") for tensor in tensor_list
paddle.Tensor.__call__(tensor, zero_copy=True).astype("float32") for tensor in tensor_list
]
else:
tensor_list = [paddle.Tensor(tensor, zero_copy=True) for tensor in tensor_list]
tensor_list = [paddle.Tensor.__call__(tensor, zero_copy=True) for tensor in tensor_list]

elif self.merge_config.tensor_type == "np" and is_bf16:
tensor_list = [
paddle.Tensor(tensor, zero_copy=True).astype("float32").numpy() for tensor in tensor_list
paddle.Tensor.__call__(tensor, zero_copy=True).astype("float32").numpy()
for tensor in tensor_list
]

if self.merge_config.base_model_path is not None:
base_tensor = tensor_list.pop()
tensor_list = [tensor - base_tensor for tensor in tensor_list]

merge_tensor = self.merge_method.merge(tensor_list)
if self.merge_config.base_model_path is not None:
merge_tensor += base_tensor
@@ -206,7 +208,9 @@
merge_state_dict[key] = merge_tensor.numpy()
elif self.merge_config.tensor_type == "np" and is_bf16:
# dtype==bfloat16: numpy(float32) -> paddle(float32) -> paddle(bfloat16) -> numpy(uint16)
merge_state_dict[key] = paddle.Tensor(merge_tensor, zero_copy=True).astype("bfloat16").numpy()
merge_state_dict[key] = (
paddle.Tensor.__call__(merge_tensor, zero_copy=True).astype("bfloat16").numpy()
)

# Save safetensor file
save_file(
@@ -389,7 +393,7 @@
dtype = tensor.dtype
# dtype==bfloat16: numpy(uint16) -> paddle(bfloat16) -> paddle(float32) -> numpy(float32)
if tensor.dtype == np.uint16:
tensor = paddle.Tensor(tensor, zero_copy=True).astype("float32").numpy()
tensor = paddle.Tensor.__call__(tensor, zero_copy=True).astype("float32").numpy()

tensor_list.append(tensor)
if self.merge_config.base_model_path is not None:
with fast_safe_open(
@@ -398,14 +402,16 @@
) as w:
base_tensor = w.get_tensor(k)
if base_tensor.dtype == np.uint16:
base_tensor = paddle.Tensor(base_tensor, zero_copy=True).astype("float32").numpy()
base_tensor = paddle.Tensor.__call__(base_tensor, zero_copy=True).astype("float32").numpy()

tensor_list = [tensor - base_tensor for tensor in tensor_list]
merge_state_dict[k] = self.merge_method.merge(tensor_list)
if self.merge_config.base_model_path is not None:
merge_state_dict[k] += base_tensor
# dtype==bfloat16: numpy(float32) -> paddle(float32) -> paddle(bfloat16) -> numpy(uint16)
if dtype == np.uint16:
merge_state_dict[k] = paddle.Tensor(merge_state_dict[k], zero_copy=True).astype("bfloat16").numpy()
merge_state_dict[k] = (
paddle.Tensor.__call__(merge_state_dict[k], zero_copy=True).astype("bfloat16").numpy()
)
save_file(
merge_state_dict,
os.path.join(self.merge_config.output_path, shard_file),
@@ -430,7 +436,7 @@
framework="np",
) as w:
tensor_list.append(w.get_tensor(k))
is_bf16 = str(tensor_list[0].dtype) == "uint16"
is_bf16 = str(tensor_list[0].dtype) in ["uint16", "bfloat16"]

tensor_mem = int(np.prod(tensor_list[0].shape) * self.numpy_dtype_map[str(tensor_list[0].dtype)]) / (
1024**3
)
@@ -443,10 +449,10 @@
tensor_list = [tensor_split[sp] for tensor_split in tensor_split_list]
if is_bf16:
tensor_list = [
paddle.Tensor(tensor, zero_copy=True).astype("float32") for tensor in tensor_list
paddle.Tensor.__call__(tensor, zero_copy=True).astype("float32") for tensor in tensor_list
]
else:
tensor_list = [paddle.Tensor(tensor, zero_copy=True) for tensor in tensor_list]
tensor_list = [paddle.Tensor.__call__(tensor, zero_copy=True) for tensor in tensor_list]

if self.merge_config.base_model_path is not None:
base_tensor = tensor_list.pop()
tensor_list = [tensor - base_tensor for tensor in tensor_list]
@@ -460,9 +466,11 @@
merge_state_dict[k] = np.concatenate(merge_split, axis=0)
else:
if is_bf16:
tensor_list = [paddle.Tensor(tensor, zero_copy=True).astype("float32") for tensor in tensor_list]
tensor_list = [
paddle.Tensor.__call__(tensor, zero_copy=True).astype("float32") for tensor in tensor_list
]
else:
tensor_list = [paddle.Tensor(tensor, zero_copy=True) for tensor in tensor_list]
tensor_list = [paddle.Tensor.__call__(tensor, zero_copy=True) for tensor in tensor_list]

if self.merge_config.base_model_path is not None:
base_tensor = tensor_list.pop()
tensor_list = [tensor - base_tensor for tensor in tensor_list]
@@ -554,10 +562,10 @@
lora_A_tensor = None
if lora_state_dict is not None and lora_A_key in lora_state_dict.keys():
lora_A_tensor, lora_B_tensor = lora_state_dict.pop(lora_A_key), lora_state_dict.pop(lora_B_key)
is_bf16 = tensor.dtype == np.uint16
tensor = paddle.Tensor(tensor, zero_copy=True)
lora_A_tensor = paddle.Tensor(lora_A_tensor, zero_copy=True)
lora_B_tensor = paddle.Tensor(lora_B_tensor, zero_copy=True)
is_bf16 = str(tensor.dtype) in ["uint16", "bfloat16"]
tensor = paddle.Tensor.__call__(tensor, zero_copy=True)
lora_A_tensor = paddle.Tensor.__call__(lora_A_tensor, zero_copy=True)
lora_B_tensor = paddle.Tensor.__call__(lora_B_tensor, zero_copy=True)

if self.is_cpu and is_bf16:
Collaborator (reviewer): What is the reason for switching to the __call__ function here?

Collaborator (author): paddle.Tensor does not support initializing an FP8 tensor, so the interface is switched temporarily to support it.
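For illustration, a minimal sketch of the two construction paths (the bf16 array below is a placeholder; FP8 storage is the case the plain constructor rejects, per the reply above):

import numpy as np
import paddle

# bf16 data read from safetensors arrives as raw uint16 bits
arr = np.zeros([2, 2], dtype=np.uint16)

t_old = paddle.Tensor(arr, zero_copy=True)           # works for bf16/uint16 storage
t_new = paddle.Tensor.__call__(arr, zero_copy=True)  # same result here, and also accepts FP8 storage
print(t_old.dtype, t_new.dtype)                      # both map to paddle bfloat16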

tensor = tensor.astype("float32")
lora_A_tensor = lora_A_tensor.astype("float32")
@@ -685,10 +693,11 @@
if lora_A_key in lora_state_dict.keys():
lora_A_tensor = lora_state_dict[lora_A_key]
lora_B_tensor = lora_state_dict[lora_B_key]
is_bf16 = tensor.dtype == np.uint16
tensor = paddle.Tensor(tensor, zero_copy=True)
lora_A_tensor = paddle.Tensor(lora_A_tensor, zero_copy=True)
lora_B_tensor = paddle.Tensor(lora_B_tensor, zero_copy=True)
is_bf16 = str(tensor.dtype) in ["uint16", "bfloat16"]

tensor = paddle.Tensor.__call__(tensor, zero_copy=True)
lora_A_tensor = paddle.Tensor.__call__(lora_A_tensor, zero_copy=True)
lora_B_tensor = paddle.Tensor.__call__(lora_B_tensor, zero_copy=True)

if self.is_cpu and is_bf16:
tensor = tensor.astype("float32")
lora_A_tensor = lora_A_tensor.astype("float32")
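The bf16 handling throughout merge_model.py follows the round-trip spelled out in the inline comments above; a minimal sketch of that round-trip (shapes and the merge arithmetic are placeholders):

import numpy as np
import paddle

raw = np.zeros([4, 4], dtype=np.uint16)  # bf16 tensor as returned by safetensors

# numpy(uint16) -> paddle(bfloat16) -> paddle(float32): do the merge math in fp32
t = paddle.Tensor.__call__(raw, zero_copy=True).astype("float32")
merged = 0.5 * t + 0.5 * t  # stand-in for the configured merge method

# paddle(float32) -> paddle(bfloat16) -> numpy(uint16): back to the storage dtype
out = merged.astype("bfloat16").numpy()
print(out.dtype)  # uint16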
1 change: 1 addition & 0 deletions paddlenlp/transformers/deepseek_v2/configuration.py
@@ -224,6 +224,7 @@
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.speculate_model_type = speculate_model_type
self.use_fp8 = False

super().__init__(
pad_token_id=pad_token_id,
138 changes: 138 additions & 0 deletions paddlenlp/transformers/deepseek_v2/fp8_linear.py
@@ -0,0 +1,138 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle

original_linear = paddle.nn.functional.linear

from typing import Literal, Optional

# from ..linear_utils import RowParallelLinear as PD_RowParallelLinear
from ..linear_utils import ColumnParallelLinear as PD_ColumnParallelLinear
from ..linear_utils import (
ColumnSequenceParallelLinear as PD_ColumnSequenceParallelLinear,
)
from ..linear_utils import Linear as PD_Linear
from ..linear_utils import RowParallelLinear as PD_RowParallelLinear
from ..linear_utils import RowSequenceParallelLinear as PD_RowSequenceParallelLinear

try:
from .kernel import act_quant, fp8_gemm, weight_dequant
except:
pass


__all__ = [
"Linear",
"ColumnParallelLinear",
"RowParallelLinear",
"ColumnSequenceParallelLinear",
"RowSequenceParallelLinear",
]

gemm_impl: Literal["bf16", "fp8"] = "bf16"
block_size = 128


def fp8_linear(
x: paddle.Tensor, weight: paddle.Tensor, bias: Optional[paddle.Tensor] = None, name=None
) -> paddle.Tensor:
"""
Applies a linear transformation to the incoming data: y = xW + b.
This function supports specialized implementations based on quantization
and tensor formats.

Args:
x (paddle.Tensor): The input tensor.
weight (paddle.Tensor): The weight tensor. It may be quantized and
requires dequantization for certain cases.
bias (Optional[paddle.Tensor]): The bias tensor to be added. Default is None.

Returns:
paddle.Tensor: The result of the linear transformation, which may involve
quantization-aware computations depending on the input parameters.

Notes:
- If `weight` is quantized (e.g., `element_size() == 1`), a dequantized version
is used for computation.
- If `gemm_impl == "bf16"`, dequantization and a `bf16` GEMM operation are applied.
- For other cases, the function applies quantization to `x` and uses `fp8_gemm` for computation.
"""

if paddle.in_dynamic_mode():
if weight.element_size() > 1:
return original_linear(x, weight, bias)
elif gemm_impl == "bf16":
weight = weight_dequant(weight, weight._scale)
return original_linear(x, weight, bias)
else:
x, scale = act_quant(x, block_size)
y = fp8_gemm(x, scale, weight, weight._scale)
if bias is not None:
y += bias
return y
else:
return original_linear(x, weight, bias)


paddle.nn.functional.linear = fp8_linear
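Note that importing this module patches paddle.nn.functional.linear process-wide. Weights stored in dtypes wider than one byte (fp32/fp16/bf16) fall straight through to the original implementation, so existing layers keep their behavior; only FP8 weights (element_size() == 1) take the dequant or fp8_gemm path. A quick sanity sketch, assuming the patch above has been applied:

import paddle
import paddle.nn.functional as F

x = paddle.randn([2, 8], dtype="float32")
w = paddle.randn([8, 4], dtype="float32")  # element_size() == 4, so the FP8 branches are skipped

y = F.linear(x, w)  # routed through fp8_linear -> original_linear
print(y.shape)      # [2, 4]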


def register_scale(self):
if self.weight.element_size() == 1:
in_features, out_features = self.weight.shape
scale_out_features = (out_features + self.block_size - 1) // self.block_size
scale_in_features = (in_features + self.block_size - 1) // self.block_size
self.weight_scale_inv = self.create_parameter(
shape=[scale_in_features, scale_out_features],
attr=self._weight_attr,
dtype="float32",
is_bias=False,
)
self.weight._scale = self.weight_scale_inv
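register_scale attaches one float32 scale per block_size x block_size tile of the weight. The actual act_quant / weight_dequant / fp8_gemm kernels are imported from .kernel and are not part of this diff; the following is only a reference sketch of the assumed dequantization semantics, not the kernel itself:

import paddle

def weight_dequant_reference(w_fp8, scale, block_size=128):
    # w_fp8: [in_features, out_features], 8-bit float storage
    # scale: [ceil(in_features / block_size), ceil(out_features / block_size)], float32
    in_f, out_f = w_fp8.shape
    # broadcast each per-block scale over its block_size x block_size tile
    s = paddle.repeat_interleave(scale, block_size, axis=0)[:in_f]
    s = paddle.repeat_interleave(s, block_size, axis=1)[:, :out_f]
    return w_fp8.astype("float32") * s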

class Linear(PD_Linear):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.block_size = kwargs.get("block_size", 128)
register_scale(self)


class ColumnParallelLinear(PD_ColumnParallelLinear):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.block_size = kwargs.get("block_size", 128)
register_scale(self)


class RowParallelLinear(PD_RowParallelLinear):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.block_size = kwargs.get("block_size", 128)
register_scale(self)


class ColumnSequenceParallelLinear(PD_ColumnSequenceParallelLinear):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.block_size = kwargs.get("block_size", 128)
register_scale(self)


class RowSequenceParallelLinear(PD_RowSequenceParallelLinear):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.block_size = kwargs.get("block_size", 128)
register_scale(self)