Skip to content

Commit 22d8047

Browse files
committed
...
1 parent eb517b0 commit 22d8047

File tree

2 files changed

+17
-4
lines changed

2 files changed

+17
-4
lines changed

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,5 +64,7 @@ python3 tfcheckpoint2pytorch.py --checkpoint w2l_plus_large_mp.tar.gz --onnx w2l
6464
--ignoreattr Toutput_types --ignoreattr output_shapes --ignoreattr output_types --ignoreattr predicate --ignoreattr f --ignoreattr dtypes \
6565
--input_name 'IteratorGetNext:0' --input_shape -1 -1 64 --input_dtype half \
6666
--output_name 'ForwardPass/fully_connected_ctc_decoder/logits:0'
67-
67+
68+
# download slot-attention_object_discovery.pt from https://console.cloud.google.com/storage/browser/gresearch/slot-attention
69+
python3 tfcheckpoint2pytorch.py --checkpoint slot-attention_object_discovery.zip -o slot-attention_object_discovery.pt
6870
```

tfcheckpoint2pytorch.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,17 @@
66
import tempfile
77
import shutil
88
import tarfile
9+
import zipfile
910
import collections
1011

1112
import numpy as np
1213
import tensorflow
13-
from tensorflow.python import pywrap_tensorflow
14+
15+
try:
16+
from tensorflow.python.pywrap_tensorflow import NewCheckpointReader
17+
except:
18+
from tensorflow.compat.v1.train import NewCheckpointReader
19+
1420
from tensorflow.python.framework import meta_graph
1521
from tensorflow.core.framework import types_pb2
1622

@@ -35,10 +41,15 @@
3541
tarfile.open(args.checkpoint).extractall(checkpoint_dir)
3642
files = [os.path.join(checkpoint_dir, d) for d in os.listdir(checkpoint_dir)]
3743
checkpoint_dir = files[0] if os.path.isdir(files[0]) else checkpoint_dir
44+
elif args.checkpoint.endswith('.zip'):
45+
checkpoint_dir = args.tmp
46+
zipfile.ZipFile(args.checkpoint).extractall(checkpoint_dir)
47+
files = [os.path.join(checkpoint_dir, d) for d in os.listdir(checkpoint_dir)]
48+
checkpoint_dir = files[0] if os.path.isdir(files[0]) else checkpoint_dir
3849
else:
3950
checkpoint_dir = args.checkpoint
4051

41-
reader = pywrap_tensorflow.NewCheckpointReader(tensorflow.train.latest_checkpoint(checkpoint_dir))
52+
reader = NewCheckpointReader(tensorflow.train.latest_checkpoint(checkpoint_dir))
4253
blobs = {k : reader.get_tensor(k) for k in reader.get_variable_to_shape_map()}
4354

4455
if args.output_path.endswith('.json'):
@@ -52,7 +63,7 @@
5263
(np.savez if args.output_path[-1] == 'z' else numpy.save)(args.output_path, **blobs)
5364
elif args.output_path.endswith('.pt'):
5465
import torch
55-
torch.save({k : torch.from_numpy(blob) for k, blob in blobs.items()}, args.output_path)
66+
torch.save({k : (torch.as_tensor(blob) if not np.isscalar(blob) else torch.tensor(blob) ) if isinstance(blob, np.ndarray) else blob for k, blob in blobs.items()}, args.output_path)
5667

5768
if args.onnx or args.tensorboard or args.graph:
5869
meta_graph_file = glob.glob(os.path.join(checkpoint_dir, '*.meta'))[0]

0 commit comments

Comments
 (0)