-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
95 lines (78 loc) · 2.62 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import logging
import torch
import numpy as np
def gather_metrics(metrics):
_metrics = sorted(metrics, key=lambda x: x[0])
metrics = [x[1] for x in _metrics]
assert isinstance(metrics, list)
res = {k: [] for k in metrics[0].keys()}
for m in metrics:
to_del = []
for k in res:
if k not in m:
to_del.append(k)
for k in to_del:
del res[k]
for m in metrics:
for k in res.keys():
res[k].append(m[k])
res_stats = {}
for k, v in res.items():
res_stats[k] = {"mean": np.mean(v), "std": np.std(v)}
logging.info("Raw metrics: {}".format(res))
return res_stats, metrics
def set_seed(seed, get_state=False, set_torch=True, set_numpy=True):
if get_state:
if set_torch:
torch_state = torch.get_rng_state()
else:
torch_state = None
if set_numpy:
numpy_state = np.random.get_state()
else:
numpy_state = None
else:
torch_state, numpy_state = None, None
if set_torch:
torch.manual_seed(seed)
if set_numpy:
np.random.seed(seed)
if get_state:
return (torch_state, numpy_state)
def restore_state(states, set_torch=True, set_numpy=True):
if set_torch:
torch.set_rng_state(states[0])
if set_numpy:
np.random.set_state(states[1])
class seed_scope:
def __init__(self, seed, set_torch=True, set_numpy=True):
self.seed = seed
self.set_torch = set_torch
self.set_numpy = set_numpy
def __enter__(self):
self.prev_state = set_seed(self.seed, get_state=True,
set_numpy=self.set_numpy,
set_torch=self.set_torch)
def __exit__(self, type, value, traceback):
restore_state(self.prev_state,
set_torch=self.set_torch,
set_numpy=self.set_numpy)
class RandomSplitter:
def __init__(self, splits, num, seed=None):
assert np.isclose(np.sum(splits), 1)
self.num = num
idx = list(range(num))
with seed_scope(seed=seed, set_torch=False):
np.random.shuffle(idx)
cnt = 0
self.idx_splits = []
for s in splits[:-1]:
length = int(s * num)
self.idx_splits.append(idx[cnt: cnt+length])
cnt += length
self.idx_splits.append(idx[cnt:])
def split(self, x, split_id=None):
if split_id is not None:
return x[self.idx_splits[split_id]]
else:
return [x[_idx] for _idx in self.idx_splits]