Skip to content

Commit 528dce7

Browse files
authored
Implement python deserialize for flat_tensor (#11779)
### Summary Implement the deserialize function in the Python FlatTensorSerializer class. This allows for loading flatbuffer data from a Python environment. ### Test plan I've added an additional test in `extension/flat_tensor/test/test_serialize.py` to cover the deserialize path.
1 parent b677429 commit 528dce7

File tree

2 files changed

+94
-4
lines changed

2 files changed

+94
-4
lines changed

extension/flat_tensor/serialize/serialize.py

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919

2020
from executorch.exir._serialize._flatbuffer import _flatc_compile, _flatc_decompile
2121
from executorch.exir._serialize._program import _insert_flatbuffer_header
22-
from executorch.exir._serialize.data_serializer import DataPayload, DataSerializer
22+
from executorch.exir._serialize.data_serializer import (
23+
DataEntry,
24+
DataPayload,
25+
DataSerializer,
26+
)
2327

2428
from executorch.exir._serialize.padding import aligned_size, pad_to, padding_required
2529

@@ -34,6 +38,9 @@
3438
# endian.
3539
_HEADER_BYTEORDER: Literal["little"] = "little"
3640

41+
# Current version. Keep in sync with c++ version number in serialize.
42+
_FLAT_TENSOR_VERSION: int = 0
43+
3744

3845
def _serialize_to_flatbuffer(flat_tensor: FlatTensor) -> Cord:
3946
"""Serializes a FlatTensor to a flatbuffer and returns the serialized data."""
@@ -320,7 +327,7 @@ def serialize(
320327
# Create FlatTensor, which describes of the contents of the file and
321328
# points to all the data segments. It will be serialized to flatbuffer.
322329
flat_tensor = FlatTensor(
323-
version=0, # Keep in sync with c++ version number in serialize.h
330+
version=_FLAT_TENSOR_VERSION,
324331
segments=data_segments,
325332
named_data=named_data,
326333
)
@@ -383,4 +390,49 @@ def deserialize(self, blob: Cord) -> DataPayload:
383390
"""
384391
Deserializes a flat_tensor blob into a list of tensor metadata and tensors.
385392
"""
386-
raise NotImplementedError("deserialize_data")
393+
394+
data = bytes(blob)
395+
396+
# Read header. Verify that it's valid.
397+
header = FlatTensorHeader.from_bytes(data[8:])
398+
if not header.is_valid():
399+
raise RuntimeError(
400+
"Flat tensor header is invalid. File is likely incorrect format or corrupt."
401+
)
402+
403+
# Deserialize the flat tensor data, which contains the data offsets and tensor metadata.
404+
flat_tensor_bytes = data[0 : header.flatbuffer_offset + header.flatbuffer_size]
405+
flat_tensor = _deserialize_to_flat_tensor(flat_tensor_bytes)
406+
407+
# Verify that this is a supported version.
408+
if flat_tensor.version != _FLAT_TENSOR_VERSION:
409+
raise NotImplementedError(
410+
f"Flat tensor files reports unsupported version {flat_tensor.version}. Expected {_FLAT_TENSOR_VERSION}."
411+
)
412+
413+
# Extract the buffers.
414+
buffers = [
415+
data[
416+
header.segment_base_offset
417+
+ segment.offset : header.segment_base_offset
418+
+ segment.offset
419+
+ segment.size
420+
]
421+
for segment in flat_tensor.segments
422+
]
423+
424+
payload = DataPayload(
425+
buffers=buffers,
426+
named_data={},
427+
)
428+
429+
# Read the named data entries.
430+
for named_data in flat_tensor.named_data:
431+
entry = DataEntry(
432+
buffer_index=named_data.segment_index,
433+
alignment=1,
434+
tensor_layout=named_data.tensor_layout,
435+
)
436+
payload.named_data[named_data.key] = entry
437+
438+
return payload

extension/flat_tensor/test/test_serialize.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,19 @@
66

77
# pyre-unsafe
88

9+
import dataclasses
910
import math
1011
import unittest
1112

1213
from typing import List, Optional
1314

15+
from executorch.exir._serialize._cord import Cord
16+
1417
from executorch.exir._serialize.data_serializer import (
1518
DataEntry,
1619
DataPayload,
1720
DataSerializer,
1821
)
19-
2022
from executorch.exir._serialize.padding import aligned_size
2123

2224
from executorch.exir.schema import ScalarType
@@ -223,3 +225,39 @@ def test_serialize(self) -> None:
223225
)
224226

225227
self.assertEqual(segments[2].offset + segments[2].size, len(segment_data))
228+
229+
def test_round_trip(self) -> None:
230+
# Serialize and then deserialize the test payload. Make sure it's reconstructed
231+
# properly.
232+
config = FlatTensorConfig()
233+
serializer: DataSerializer = FlatTensorSerializer(config)
234+
235+
# Round trip the data.
236+
serialized_data = bytes(serializer.serialize(TEST_DATA_PAYLOAD))
237+
deserialized_payload = serializer.deserialize(Cord(serialized_data))
238+
239+
# Validate the deserialized payload. Since alignment isn't serialized, we need to
240+
# do this somewhat manually.
241+
for i in range(len(deserialized_payload.buffers)):
242+
self.assertEqual(
243+
TEST_DATA_PAYLOAD.buffers[i],
244+
deserialized_payload.buffers[i],
245+
f"Buffer at index {i} does not match.",
246+
)
247+
248+
self.assertEqual(
249+
TEST_DATA_PAYLOAD.named_data.keys(), deserialized_payload.named_data.keys()
250+
)
251+
252+
SKIP_FIELDS = {"alignment"} # Fields to ignore in comparison.
253+
for key in TEST_DATA_PAYLOAD.named_data.keys():
254+
reference = TEST_DATA_PAYLOAD.named_data[key]
255+
actual = deserialized_payload.named_data[key]
256+
257+
for field in dataclasses.fields(reference):
258+
if field.name not in SKIP_FIELDS:
259+
self.assertEqual(
260+
getattr(reference, field.name),
261+
getattr(actual, field.name),
262+
f"Named data record {key}.{field.name} does not match.",
263+
)

0 commit comments

Comments
 (0)