|
6 | 6 | import tempfile
|
7 | 7 | import shutil
|
8 | 8 | import tarfile
|
| 9 | +import zipfile |
9 | 10 | import collections
|
10 | 11 |
|
11 | 12 | import numpy as np
|
12 | 13 | import tensorflow
|
13 |
| -from tensorflow.python import pywrap_tensorflow |
| 14 | + |
| 15 | +try: |
| 16 | + from tensorflow.python.pywrap_tensorflow import NewCheckpointReader |
| 17 | +except: |
| 18 | + from tensorflow.compat.v1.train import NewCheckpointReader |
| 19 | + |
14 | 20 | from tensorflow.python.framework import meta_graph
|
15 | 21 | from tensorflow.core.framework import types_pb2
|
16 | 22 |
|
|
35 | 41 | tarfile.open(args.checkpoint).extractall(checkpoint_dir)
|
36 | 42 | files = [os.path.join(checkpoint_dir, d) for d in os.listdir(checkpoint_dir)]
|
37 | 43 | checkpoint_dir = files[0] if os.path.isdir(files[0]) else checkpoint_dir
|
| 44 | +elif args.checkpoint.endswith('.zip'): |
| 45 | + checkpoint_dir = args.tmp |
| 46 | + zipfile.ZipFile(args.checkpoint).extractall(checkpoint_dir) |
| 47 | + files = [os.path.join(checkpoint_dir, d) for d in os.listdir(checkpoint_dir)] |
| 48 | + checkpoint_dir = files[0] if os.path.isdir(files[0]) else checkpoint_dir |
38 | 49 | else:
|
39 | 50 | checkpoint_dir = args.checkpoint
|
40 | 51 |
|
41 |
| -reader = pywrap_tensorflow.NewCheckpointReader(tensorflow.train.latest_checkpoint(checkpoint_dir)) |
| 52 | +reader = NewCheckpointReader(tensorflow.train.latest_checkpoint(checkpoint_dir)) |
42 | 53 | blobs = {k : reader.get_tensor(k) for k in reader.get_variable_to_shape_map()}
|
43 | 54 |
|
44 | 55 | if args.output_path.endswith('.json'):
|
|
52 | 63 | (np.savez if args.output_path[-1] == 'z' else numpy.save)(args.output_path, **blobs)
|
53 | 64 | elif args.output_path.endswith('.pt'):
|
54 | 65 | import torch
|
55 |
| - torch.save({k : torch.from_numpy(blob) for k, blob in blobs.items()}, args.output_path) |
| 66 | + torch.save({k : (torch.as_tensor(blob) if not np.isscalar(blob) else torch.tensor(blob) ) if isinstance(blob, np.ndarray) else blob for k, blob in blobs.items()}, args.output_path) |
56 | 67 |
|
57 | 68 | if args.onnx or args.tensorboard or args.graph:
|
58 | 69 | meta_graph_file = glob.glob(os.path.join(checkpoint_dir, '*.meta'))[0]
|
|
0 commit comments