-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathsample.py
More file actions
67 lines (46 loc) · 1.71 KB
/
sample.py
File metadata and controls
67 lines (46 loc) · 1.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import numpy as np
def sample_from_generator(elements, probabilities_li, to_sample, niche_only=False):
    """Draw up to ``to_sample`` distinct entries of ``elements`` without replacement.

    The draw is attempted with ``np.random.choice``; whenever NumPy rejects the
    request with ``ValueError`` (for example because fewer entries have non-zero
    probability than were requested) the requested size is reduced by one and
    the draw is retried. If no size down to 1 succeeds, or ``to_sample`` is not
    positive, nothing is selected.

    Args:
        elements: Candidates to sample from. The sampled values are used
            directly as indices into the returned indicator vector, so they
            must be valid integer indices in ``range(len(elements))``.
        probabilities_li: One weight per entry of ``elements``. It is passed to
            ``np.random.choice`` as-is unless ``niche_only`` is set; the
            caller's object is never modified.
        to_sample: Maximum number of entries to draw.
        niche_only: When truthy, the weights at ``OTHER_TAGS`` are zeroed and
            the remainder renormalised to sum to 1. If nothing remains, a
            uniform distribution over all elements is used instead.

    Returns:
        A tuple ``(sampled_li_bin, sampled_li)``: a float vector of length
        ``len(elements)`` holding 1 at every sampled index, and the array of
        sampled values (empty when nothing could be drawn).
    """
    sampled_li_bin = np.zeros([len(elements)], dtype=float)
    # Private float copy: the masking below must not write into the caller's array.
    probabilities_li = np.array(probabilities_li, dtype=float)

    if niche_only:
        # Done once, before the retry loop, so every retry uses the same
        # distribution (re-masking after the uniform fallback would alter it).
        try:
            # NOTE(review): OTHER_TAGS is not defined in this block; it is
            # presumably a module-level index collection -- confirm.
            probabilities_li[OTHER_TAGS] = 0.0
            total = probabilities_li.sum()
            if total != 0.0:
                probabilities_li = probabilities_li / total
            else:
                probabilities_li = np.full(len(elements), 1.0 / len(elements))
        except Exception as e:
            # Best effort: report the problem and sample with whatever
            # probabilities are available at this point.
            print('Error:', str(e))

    # Default result when no draw succeeds; previously this case left the
    # name unbound and the function crashed on the indexing line below.
    sampled_li = np.empty(0, dtype=int)
    while to_sample > 0:
        try:
            sampled_li = np.random.choice(
                elements, to_sample, p=probabilities_li, replace=False
            )
            break
        except ValueError:
            # Request not satisfiable at this size; try a smaller one.
            to_sample -= 1

    sampled_li_bin[sampled_li] = 1
    return np.asarray(sampled_li_bin), np.asarray(sampled_li)
def sample_from_generator_new(elements, probabilities_li, to_sample, num_elements):
    """Draw up to ``to_sample`` distinct entries of ``elements`` without replacement.

    The weights are normalised to sum to 1 (or replaced by a uniform
    distribution when they sum to zero) and the draw is attempted with
    ``np.random.choice``. Whenever NumPy rejects the request with
    ``ValueError`` the requested size is reduced by one and the draw is
    retried. If no size down to 1 succeeds, or ``to_sample`` is not positive,
    nothing is selected.

    Args:
        elements: Candidates to sample from. The sampled values are used
            directly as indices into the returned indicator vector, so they
            must be valid integer indices in ``range(num_elements)``.
        probabilities_li: One non-normalised weight per entry of ``elements``.
        to_sample: Maximum number of entries to draw.
        num_elements: Length of the returned indicator vector.

    Returns:
        A tuple ``(sampled_li_bin, sampled_li)``: a float vector of length
        ``num_elements`` holding 1 at every sampled index, and the array of
        sampled values (empty when nothing could be drawn).
    """
    sampled_li_bin = np.zeros([num_elements], dtype=float)
    probabilities_li = np.asarray(probabilities_li, dtype=float)

    total = probabilities_li.sum()
    if total != 0.0:
        probabilities_li = probabilities_li / total
    elif probabilities_li.size:
        # Uniform fallback sized like the weight vector: np.random.choice
        # requires p to match the number of candidates, which can differ
        # from num_elements.
        probabilities_li = np.full(
            probabilities_li.shape, 1.0 / probabilities_li.size
        )

    # Default result when no draw succeeds; previously this case left the
    # name unbound and the function crashed on the indexing line below.
    sampled_li = np.empty(0, dtype=int)
    while to_sample > 0:
        try:
            sampled_li = np.random.choice(
                elements, to_sample, p=probabilities_li, replace=False
            )
            break
        except ValueError:
            # Request not satisfiable at this size; try a smaller one.
            to_sample -= 1

    sampled_li_bin[sampled_li] = 1
    return np.asarray(sampled_li_bin), np.asarray(sampled_li)