Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support float8 dtype storage and deepseek v3 with fp8 inference. #9906

Open
wants to merge 18 commits into
base: develop
Choose a base branch
from
Open
21 changes: 15 additions & 6 deletions paddlenlp/utils/safetensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,18 @@
import mmap
from collections import OrderedDict

import ml_dtypes
import numpy as np

__all__ = [
"fast_safe_open",
"fast_load_file",
]

# Expose ml_dtypes' extended floating-point scalar types under the numpy
# namespace so the dtype table below (and downstream code) can refer to them
# as np.bfloat16 / np.float8_e5m2 / np.float8_e4m3fn.
# NOTE(review): this mutates the global numpy namespace for every importer of
# this module — consider module-level aliases instead; confirm nothing
# downstream depends on these attributes existing on np itself.
np.bfloat16 = ml_dtypes.bfloat16
np.float8_e5m2 = ml_dtypes.float8_e5m2
np.float8_e4m3fn = ml_dtypes.float8_e4m3fn


MAX_HEADER_SIZE = 100 * 1000 * 1000

Expand All @@ -49,16 +54,16 @@
"BOOL": np.bool_,
"U8": np.uint8,
"I8": np.int8,
"F8_E5M2": 1, # no fp8
"F8_E4M3": 1, # no fp8
"F8_E5M2": np.float8_e5m2, # no fp8
"F8_E4M3": np.float8_e4m3fn, # no fp8
"I16": np.int16,
"U16": np.uint16,
"I32": np.int32,
"U32": np.uint32,
"I64": np.int64,
"U64": np.uint64,
"F16": np.float16,
"BF16": 2, # no bf16
"BF16": np.bfloat16, # no bf16
"F32": np.float32,
"F64": np.float64,
}
Expand Down Expand Up @@ -238,9 +243,13 @@ def __getitem__(self, index):
return tensor.reshape(target_shape)

def get(self, *args, **kwargs):
    """Read this tensor's full contents from the underlying buffer file.

    Decodes the raw bytes with ``np.frombuffer`` instead of
    ``readinto(memoryview(...))`` so extended dtypes (bfloat16 / float8
    provided via ml_dtypes) load correctly.

    Returns:
        np.ndarray: a writable array of ``self.dtype`` with ``self.shape``.
    """
    # int() fixes the empty shape [] (scalar) case, where np.prod returns
    # a float 1.0, and normalizes np.int64 to a plain int.
    nbytes = int(np.prod(self.shape)) * np.dtype(self.dtype).itemsize
    # BUGFIX: seek to this tensor's byte offset before reading; relying on
    # the current file position returns the wrong bytes when tensors are
    # accessed out of order (the previous readinto-based code seeked too).
    self.bufferfile.seek(self.start_offset)
    buffer = self.bufferfile.read(nbytes)
    # np.frombuffer yields a read-only view over `buffer`; copy so callers
    # get a writable array, matching the old np.empty-based behavior.
    tensor = np.frombuffer(buffer, dtype=self.dtype).reshape(self.shape).copy()
    return tensor

@property
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@ regex
numpy<=1.26.4
tiktoken
tokenizers>=0.21,<0.22
ml_dtypes
omegaconf
126 changes: 113 additions & 13 deletions tests/transformers/test_safetensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,118 @@
import unittest

import numpy as np
import paddle
from safetensors.numpy import load_file, save_file

from paddlenlp.utils.safetensors import fast_load_file, fast_safe_open

from ..testing_utils import skip_platform

paddle.set_device("cpu")


def enhanced_to_tensor(tensor):
    """Convert a numpy array to a paddle Tensor, including dtypes that
    ``paddle.to_tensor`` cannot ingest directly (bfloat16 and the float8
    types provided by ml_dtypes).

    bfloat16 arrays are passed through as their uint16 bit pattern; float8
    arrays are copied in as int8 bit patterns, then re-labeled with the
    matching paddle float8 dtype by sharing the underlying storage.
    """
    if tensor.dtype == np.bfloat16:
        # paddle accepts the raw uint16 bit pattern for bf16 tensors.
        return paddle.to_tensor(tensor.view(np.uint16))
    # The two float8 variants differ only in the target paddle dtype, so
    # handle them with one code path instead of two duplicated branches.
    for np_fp8, paddle_fp8 in (
        (np.float8_e5m2, paddle.float8_e5m2),
        (np.float8_e4m3fn, paddle.float8_e4m3fn),
    ):
        if tensor.dtype == np_fp8:
            bits = paddle.to_tensor(tensor.view(np.int8))
            out = paddle.empty(bits.shape, dtype=paddle_fp8)
            # Re-label the int8 storage as fp8 without copying.
            out.get_tensor()._share_data_with(bits.get_tensor())
            return out
    return paddle.to_tensor(tensor)


class EextendDtypeNumpySafe(unittest.TestCase):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extend

def setUp(self):
    """Prepare the (shape, dtype) specs used to build random weight maps."""
    super().setUp()
    self.weight_map = {}
    # The dtype recorded here is a placeholder: get_target_dtype casts every
    # tensor to the dtype under test.
    float32_shapes = [[10, 1, 10], [1, 1, 10], [1, 1, 1, 10], [10, 10]]
    self.tensors = [(shape, "float32") for shape in float32_shapes]
    self.tensors.append(([8], "float16"))
    self.tensors.append(([5, 5, 5], "int32"))

def get_target_dtype(self, dtype="float32"):
    """Build a dict of random arrays, one per spec in ``self.tensors``,
    all cast to *dtype* (the per-spec dtype is ignored)."""
    weight_map = {}
    for idx, (shape, _spec_dtype) in enumerate(self.tensors):
        values = (np.random.random(shape) * 100).astype(dtype)
        weight_map[f"weight_{idx}"] = values
    return weight_map

def get_paddle_target_dtype(self, dtype="float32"):
    """Build the random weight map and convert every entry to a paddle Tensor."""
    numpy_map = self.get_target_dtype(dtype)
    return {name: enhanced_to_tensor(array) for name, array in numpy_map.items()}

@skip_platform("win32", "cygwin")
def test_save_load_file_paddle(self):
    """Round-trip paddle tensors of extended dtypes through safetensors.

    BUGFIX: the assertion loop previously iterated ``self.weight_map``,
    which setUp leaves empty, so no values were ever compared. It now
    iterates the locally built ``weight_map``.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        for dtype in ["bfloat16", "float8_e5m2", "float8_e4m3fn"]:
            weight_map = self.get_paddle_target_dtype(dtype)
            path = os.path.join(tmpdirname, "test.safetensors")
            # save_file needs numpy arrays, so convert tensors back first.
            shard = {}
            for k in list(weight_map.keys()):
                if isinstance(weight_map[k], paddle.Tensor):
                    shard[k] = weight_map[k].cpu().numpy()
                else:
                    shard[k] = weight_map[k]

            save_file(shard, path, metadata={"format": "np"})
            sf_load = load_file(path)
            fs_sf_load = fast_load_file(path)

            for k, v in weight_map.items():
                # NOTE(review): paddle.allclose's boolean result is not
                # asserted here — verify it supports bf16/fp8 inputs and
                # wrap in self.assertTrue(...) once confirmed.
                paddle.allclose(v, enhanced_to_tensor(sf_load[k]))
                paddle.allclose(v, enhanced_to_tensor(fs_sf_load[k]))

@skip_platform("win32", "cygwin")
def test_save_load_file(self):
    """fast_load_file must agree with the reference loader for extended dtypes.

    BUGFIX: the assertion loop previously iterated ``self.weight_map``,
    which setUp leaves empty, so the comparisons never ran. It now
    iterates the locally built ``weight_map``.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        for dtype in ["bfloat16", "float8_e4m3fn", "float8_e5m2"]:
            weight_map = self.get_target_dtype(dtype)
            path = os.path.join(tmpdirname, "test.safetensors")
            save_file(weight_map, path, metadata={"format": "np"})
            sf_load = load_file(path)
            fs_sf_load = fast_load_file(path)
            for k, v in weight_map.items():
                np.testing.assert_equal(v, sf_load[k])
                np.testing.assert_equal(v, fs_sf_load[k])

@skip_platform("win32", "cygwin")
def test_dtype_safe_open(self):
    """fast_safe_open slices must match plain numpy slicing for every dtype."""
    # Each entry is the exact index tuple applied to both the numpy array
    # and the lazy safetensors slice.
    slice_cases = [
        (0, Ellipsis),
        (slice(0, 1), Ellipsis),
        (Ellipsis, slice(2, None)),
        (Ellipsis, 1),
        (slice(None, 2), Ellipsis),
        (Ellipsis, slice(None, 4)),
    ]
    with tempfile.TemporaryDirectory() as tmpdirname:
        for dtype in ["float32", "int32", "bfloat16", "float8_e4m3fn", "float8_e5m2"]:
            weight_map = self.get_target_dtype(dtype)
            path = os.path.join(tmpdirname, "test.safetensors")
            save_file(weight_map, path, metadata={"format": "np"})

            with fast_safe_open(path, framework="np") as f:
                for key in f.keys():
                    safe_slice = f.get_slice(key)
                    for case in slice_cases:
                        np.testing.assert_equal(weight_map[key][case], safe_slice[case])


class FastSafetensors(unittest.TestCase):
def setUp(self):
super().setUp()
self.weigth_map = {}
self.weight_map = {}
tensors = [
([10, 1, 10], "float32"),
([1, 1, 10], "float32"),
Expand All @@ -38,34 +139,33 @@ def setUp(self):
]
count = 0
for shape, dtype in tensors:
self.weigth_map[f"weight_{count}"] = (np.random.random(shape) * 100).astype(dtype)
self.weight_map[f"weight_{count}"] = (np.random.random(shape) * 100).astype(dtype)
count += 1
print(self.weigth_map)

@skip_platform("win32", "cygwin")
def test_load_file(self):
    """fast_load_file must agree with the reference safetensors loader.

    (Cleans up the diff residue: keeps the corrected ``weight_map``
    attribute name, dropping the misspelled ``weigth_map`` lines.)
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        path = os.path.join(tmpdirname, "test.safetensors")
        save_file(self.weight_map, path, metadata={"format": "np"})
        sf_load = load_file(path)
        fs_sf_load = fast_load_file(path)
        for k, v in self.weight_map.items():
            np.testing.assert_equal(v, sf_load[k])
            np.testing.assert_equal(v, fs_sf_load[k])

@skip_platform("win32", "cygwin")
def test_safe_open(self):
    """fast_safe_open slices must match plain numpy slicing.

    (Cleans up the diff residue: keeps the corrected ``weight_map``
    attribute name, dropping the misspelled ``weigth_map`` lines.)
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        path = os.path.join(tmpdirname, "test.safetensors")
        save_file(self.weight_map, path, metadata={"format": "np"})

        with fast_safe_open(path, framework="np") as f:
            for key in f.keys():
                safe_slice = f.get_slice(key)
                np.testing.assert_equal(self.weight_map[key][0, ...], safe_slice[0, ...])
                np.testing.assert_equal(self.weight_map[key][0:1, ...], safe_slice[0:1, ...])
                np.testing.assert_equal(self.weight_map[key][..., 2:], safe_slice[..., 2:])
                np.testing.assert_equal(self.weight_map[key][..., 1], safe_slice[..., 1])
                np.testing.assert_equal(self.weight_map[key][:2, ...], safe_slice[:2, ...])
                np.testing.assert_equal(self.weight_map[key][..., :4], safe_slice[..., :4])
Loading