|
19 | 19 |
|
20 | 20 | from executorch.exir._serialize._flatbuffer import _flatc_compile, _flatc_decompile
|
21 | 21 | from executorch.exir._serialize._program import _insert_flatbuffer_header
|
22 |
| -from executorch.exir._serialize.data_serializer import DataPayload, DataSerializer |
| 22 | +from executorch.exir._serialize.data_serializer import ( |
| 23 | + DataEntry, |
| 24 | + DataPayload, |
| 25 | + DataSerializer, |
| 26 | +) |
23 | 27 |
|
24 | 28 | from executorch.exir._serialize.padding import aligned_size, pad_to, padding_required
|
25 | 29 |
|
|
34 | 38 | # endian.
|
35 | 39 | _HEADER_BYTEORDER: Literal["little"] = "little"
|
36 | 40 |
|
| 41 | +# Current version. Keep in sync with c++ version number in serialize. |
| 42 | +_FLAT_TENSOR_VERSION: int = 0 |
| 43 | + |
37 | 44 |
|
38 | 45 | def _serialize_to_flatbuffer(flat_tensor: FlatTensor) -> Cord:
|
39 | 46 | """Serializes a FlatTensor to a flatbuffer and returns the serialized data."""
|
@@ -320,7 +327,7 @@ def serialize(
|
320 | 327 | # Create FlatTensor, which describes of the contents of the file and
|
321 | 328 | # points to all the data segments. It will be serialized to flatbuffer.
|
322 | 329 | flat_tensor = FlatTensor(
|
323 |
| - version=0, # Keep in sync with c++ version number in serialize.h |
| 330 | + version=_FLAT_TENSOR_VERSION, |
324 | 331 | segments=data_segments,
|
325 | 332 | named_data=named_data,
|
326 | 333 | )
|
def deserialize(self, blob: Cord) -> DataPayload:
    """
    Deserializes a flat_tensor blob into a list of tensor metadata and tensors.

    Args:
        blob: The serialized flat_tensor file contents.

    Returns:
        A DataPayload holding the raw segment buffers and a mapping from
        entry key to its DataEntry metadata.

    Raises:
        RuntimeError: If the flat_tensor header is missing or invalid.
        NotImplementedError: If the file's version does not match the
            version this serializer understands.
    """

    data = bytes(blob)

    # Read header. Verify that it's valid.
    # NOTE(review): the header is parsed starting at byte 8 — presumably the
    # file begins with an 8-byte flatbuffer preamble before the extended
    # header; confirm against the layout emitted by serialize().
    header = FlatTensorHeader.from_bytes(data[8:])
    if not header.is_valid():
        raise RuntimeError(
            "Flat tensor header is invalid. File is likely incorrect format or corrupt."
        )

    # Deserialize the flatbuffer region, which contains the segment offsets
    # and the tensor metadata.
    flat_tensor_bytes = data[0 : header.flatbuffer_offset + header.flatbuffer_size]
    flat_tensor = _deserialize_to_flat_tensor(flat_tensor_bytes)

    # Verify that this is a supported version.
    if flat_tensor.version != _FLAT_TENSOR_VERSION:
        raise NotImplementedError(
            f"Flat tensor file reports unsupported version {flat_tensor.version}. Expected {_FLAT_TENSOR_VERSION}."
        )

    # Extract the raw bytes of each data segment, relative to the segment
    # base offset recorded in the header.
    buffers = [
        data[
            header.segment_base_offset
            + segment.offset : header.segment_base_offset
            + segment.offset
            + segment.size
        ]
        for segment in flat_tensor.segments
    ]

    # Build the named-data entries up front rather than mutating the payload
    # after construction. Alignment is not recorded in the serialized format,
    # so 1 (no alignment constraint) is used.
    named_data = {
        entry.key: DataEntry(
            buffer_index=entry.segment_index,
            alignment=1,
            tensor_layout=entry.tensor_layout,
        )
        for entry in flat_tensor.named_data
    }

    return DataPayload(
        buffers=buffers,
        named_data=named_data,
    )
0 commit comments