
Commit cf1f594

ezyang authored and facebook-github-bot committed
Hacky support for meta tensor serialization. (pytorch#62192)
Summary:
Pull Request resolved: pytorch#62192

This support is hacky because it doesn't preserve meta tensor storage sharing: if you serialize a model with shared storage (e.g., a tensor and a view on that tensor), the viewing relationship is broken on deserialization and you get two unrelated tensors. The hack is also durable, in the sense that we will be on the hook for supporting `_rebuild_meta_tensor_no_storage` in perpetuity, even if we change our mind about the serialization format.

This unblocks an FB production use case. I didn't add C++ support, to minimize the blast area of this patch.

Signed-off-by: Edward Z. Yang <[email protected]>

Test Plan: Imported from OSS

Reviewed By: zou3519

Differential Revision: D29910535

Pulled By: ezyang

fbshipit-source-id: d98dcdd0108dfc3ae730a071d3c583b6d0281d21
1 parent f0140a8 commit cf1f594
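
A minimal sketch of the storage-sharing limitation described in the summary. This snippet is not part of the commit; it assumes a PyTorch build that includes this patch:

    import io
    import torch

    t = torch.empty(4, 4, device='meta')
    v = t.view(16)                  # v is a view sharing t's (meta) storage

    buf = io.BytesIO()
    torch.save((t, v), buf)
    buf.seek(0)
    t2, v2 = torch.load(buf)

    # Shapes and dtypes survive the round trip, but t2 and v2 come back
    # as independent tensors: the viewing relationship is gone.
    print(t2.size(), v2.size())     # torch.Size([4, 4]) torch.Size([16])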

File tree: 3 files changed (+24, -0 lines)

test/test_serialization.py

Lines changed: 10 additions & 0 deletions

@@ -741,6 +741,16 @@ def test_pathlike_serialization(self):
         torch.save(model, path)
         torch.load(path)

+    def test_meta_serialization(self):
+        big_model = torch.nn.Conv2d(20000, 320000, kernel_size=3, device='meta')
+
+        with BytesIOContext() as f:
+            torch.save(big_model, f)
+            f.seek(0)
+            state = torch.load(f)
+
+        self.assertEqual(state.weight.size(), big_model.weight.size())
+
     def run(self, *args, **kwargs):
         with serialization_method(use_zip=True):
             return super(TestSerialization, self).run(*args, **kwargs)

torch/_tensor.py

Lines changed: 10 additions & 0 deletions

@@ -136,6 +136,16 @@ def _reduce_ex_internal(self, proto):
                 str(self.device),
                 self.requires_grad)
             return (torch._utils._rebuild_mlc_tensor, arg_mlc)
+        if self.device.type == 'meta':
+            # NB: This implementation BREAKS storage sharing.  Current
+            # hypothesis is that no one cares for meta tensors.
+            arg_meta = (
+                self.dtype,
+                tuple(self.size()),
+                self.stride(),
+                self.requires_grad,
+            )
+            return (torch._utils._rebuild_meta_tensor_no_storage, arg_meta)
         if self.is_quantized:
             # quantizer_params can be different type based on torch attribute
             quantizer_params: Union[Tuple[torch.qscheme, float, int], Tuple[Any, Tensor, Tensor, int]]
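
A rough illustration of the dispatch above; for plain `torch.Tensor` objects, `__reduce_ex__` should bottom out in `_reduce_ex_internal` (subclasses may be wrapped differently), so the return value can be inspected directly:

    import torch

    t = torch.empty(2, 3, device='meta')
    fn, args = t.__reduce_ex__(2)

    # Expected: the metadata-only rebuild function and its argument tuple,
    # roughly (torch.float32, (2, 3), (3, 1), False).
    print(fn.__name__)  # _rebuild_meta_tensor_no_storage
    print(args)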

torch/_utils.py

Lines changed: 4 additions & 0 deletions

@@ -185,6 +185,10 @@ def _rebuild_mlc_tensor(data, dtype, device, requires_grad):
     return tensor


+def _rebuild_meta_tensor_no_storage(dtype, size, stride, requires_grad):
+    return torch.empty_strided(size, stride, dtype=dtype, device='meta', requires_grad=requires_grad)
+
+
 def _rebuild_qtensor(storage, storage_offset, size, stride, quantizer_params, requires_grad, backward_hooks):
     qscheme = quantizer_params[0]
     if qscheme == torch.per_tensor_affine:
