Skip to content

Commit d30c609

Browse files
authored
utils: safetensors: dont slice data on torch level (Comfy-Org#12266)
Torch has alignment enforcement when viewing with data type changes but only relative to itself. Do all tensor constructions straight off the memory-view individually so pytorch doesnt see an alignment problem. The is needed for handling misaligned safetensors weights, which are reasonably common in third party models. This limits usage of this safetensors loader to GPU compute only as CPUs kernnel are very likely to bus error. But it works for dynamic_vram, where we really dont want to take a deep copy and we always use GPU copy_ which disentangles the misalignment.
1 parent 5087f1d commit d30c609

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

comfy/utils.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,22 +82,26 @@ def scalar(*args, **kwargs):
8282
def load_safetensors(ckpt):
8383
f = open(ckpt, "rb")
8484
mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
85+
mv = memoryview(mapping)
8586

8687
header_size = struct.unpack("<Q", mapping[:8])[0]
8788
header = json.loads(mapping[8:8+header_size].decode("utf-8"))
8889

89-
with warnings.catch_warnings():
90-
#We are working with read-only RAM by design
91-
warnings.filterwarnings("ignore", message="The given buffer is not writable")
92-
data_area = torch.frombuffer(mapping, dtype=torch.uint8)[8 + header_size:]
90+
mv = mv[8 + header_size:]
9391

9492
sd = {}
9593
for name, info in header.items():
9694
if name == "__metadata__":
9795
continue
9896

9997
start, end = info["data_offsets"]
100-
sd[name] = data_area[start:end].view(_TYPES[info["dtype"]]).view(info["shape"])
98+
if start == end:
99+
sd[name] = torch.empty(info["shape"], dtype =_TYPES[info["dtype"]])
100+
else:
101+
with warnings.catch_warnings():
102+
#We are working with read-only RAM by design
103+
warnings.filterwarnings("ignore", message="The given buffer is not writable")
104+
sd[name] = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"])
101105

102106
return sd, header.get("__metadata__", {}),
103107

0 commit comments

Comments
 (0)