Skip to content

Commit 0c9915b

Browse files
committed
code reuse
1 parent 38bb077 commit 0c9915b

File tree

2 files changed

+87
-81
lines changed

2 files changed

+87
-81
lines changed

hyperengine/examples/4_1_word2vec_embedding.py

Lines changed: 3 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -20,19 +20,10 @@ def __init__(self, **params):
2020
self._num_skips = params.get('num_skips', 2)
2121
self._skip_window = params.get('skip_window', 1)
2222

23-
self._step = 0
24-
self._index = 0
25-
self._epochs_completed = 0
26-
self._just_completed = False
27-
2823
self._vocabulary = None
2924
self._dictionary = None
3025
self._data = None
3126

32-
@property
33-
def size(self):
34-
return len(self._data)
35-
3627
@property
3728
def vocabulary(self):
3829
return self._vocabulary
@@ -41,35 +32,14 @@ def vocabulary(self):
4132
def vocab_size(self):
4233
return self._vocab_size
4334

44-
@property
45-
def step(self):
46-
return self._step
47-
48-
@property
49-
def index(self):
50-
return self._index
51-
52-
@property
53-
def epochs_completed(self):
54-
return self._epochs_completed
55-
56-
@property
57-
def just_completed(self):
58-
return self._just_completed
59-
60-
def reset_counters(self):
61-
self._step = 0
62-
self._index = 0
63-
self._epochs_completed = 0
64-
self._just_completed = False
65-
6635
def build(self):
6736
hype.util.debug('Building the data provider')
6837
words = get_text8('temp-text8/data')
6938
self._vocabulary = [('UNK', None)] + Counter(words).most_common(self._vocab_size - 1)
7039
self._vocabulary = np.array([word for word, _ in self._vocabulary])
7140
self._dictionary = {word: code for code, word in enumerate(self._vocabulary)}
7241
self._data = np.array([self._dictionary.get(word, 0) for word in words])
42+
self._size = len(self._data)
7343

7444
if hype.util.is_debug_logged():
7545
hype.util.debug('Total words in text: %dM' % (len(words) / 1000000))
@@ -89,7 +59,7 @@ def _generate_batch(self, batch_size, num_skips, skip_window):
8959
span = 2 * skip_window + 1 # [ skip_window target skip_window ]
9060
buffer = deque(maxlen=span)
9161
for _ in range(span):
92-
buffer.append(self._data[self._index])
62+
buffer.append(self._data[self._index_in_epoch])
9363
self._inc_index()
9464
for i in range(batch_size // num_skips):
9565
target = skip_window # target label at the center of the buffer
@@ -100,20 +70,10 @@ def _generate_batch(self, batch_size, num_skips, skip_window):
10070
targets_to_avoid.append(target)
10171
batch[i * num_skips + j] = buffer[skip_window]
10272
labels[i * num_skips + j, 0] = buffer[target]
103-
buffer.append(self._data[self._index])
73+
buffer.append(self._data[self._index_in_epoch])
10474
self._inc_index()
10575
return batch, labels
10676

107-
def _inc_index(self):
108-
next = self._index + 1
109-
if next >= len(self._data):
110-
self._index = 0
111-
self._epochs_completed += 1
112-
self._just_completed = True
113-
else:
114-
self._index = next
115-
self._just_completed = False
116-
11777

11878
def word2vec_model(params):
11979
# Input data.

hyperengine/model/data_set.py

Lines changed: 84 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -11,76 +11,122 @@ class DataProvider(object):
1111

1212

1313
class IterableDataProvider(DataProvider):
14+
def __init__(self):
15+
super(IterableDataProvider, self).__init__()
16+
self._size = 0
17+
self._step = 0
18+
self._epochs_completed = 0
19+
self._index_in_epoch = 0
20+
self._just_completed = False
21+
1422
@property
1523
def size(self):
16-
raise NotImplementedError
24+
"""
25+
Data size (number of rows)
26+
"""
27+
return self._size
28+
29+
@property
30+
def step(self):
31+
"""
32+
The number of batches processed
33+
"""
34+
return self._step
1735

1836
@property
1937
def index(self):
20-
raise NotImplementedError
38+
"""
39+
Total index of input rows (over all epochs)
40+
"""
41+
return self._epochs_completed * self._size + self._index_in_epoch
42+
43+
@property
44+
def index_in_epoch(self):
45+
"""
46+
The index of input rows in a current epoch
47+
"""
48+
return self._index_in_epoch
2149

2250
@property
2351
def epochs_completed(self):
24-
raise NotImplementedError
52+
"""
53+
A number of completed epochs
54+
"""
55+
return self._epochs_completed
2556

2657
@property
2758
def just_completed(self):
28-
raise NotImplementedError
59+
"""
60+
Whether the previous epoch was just completed
61+
"""
62+
return self._just_completed
2963

3064
def reset_counters(self):
31-
raise NotImplementedError
65+
"""
66+
Resets all counters.
67+
"""
68+
self._step = 0
69+
self._epochs_completed = 0
70+
self._index_in_epoch = 0
71+
self._just_completed = False
3272

3373
def next_batch(self, batch_size):
74+
"""
75+
Returns the next `batch_size` examples from this data set.
76+
"""
3477
raise NotImplementedError
3578

79+
def _inc_index(self):
80+
index = self._index_in_epoch + 1
81+
if index >= self._size:
82+
self._index_in_epoch = 0
83+
self._epochs_completed += 1
84+
self._just_completed = True
85+
else:
86+
self._index_in_epoch = index
87+
self._just_completed = False
88+
3689

37-
class DataSet(object):
90+
class DataSet(IterableDataProvider):
3891
"""
39-
A labeled data set. Both examples and labels are stored as numpy arrays.
92+
A labeled data set. Both inputs and labels are stored as numpy arrays in memory.
4093
"""
4194

4295
def __init__(self, x, y):
96+
super(DataSet, self).__init__()
97+
4398
x = np.array(x)
4499
y = np.array(y)
45100
assert x.shape[0] == y.shape[0]
46101

47-
self.size = x.shape[0]
48-
self.x = x
49-
self.y = y
50-
self.step = 0
51-
self.epochs_completed = 0
52-
self.index_in_epoch = 0
53-
self.just_completed = False
102+
self._size = x.shape[0]
103+
self._x = x
104+
self._y = y
54105

55106
@property
56-
def index(self):
57-
return self.epochs_completed * self.size + self.index_in_epoch
107+
def x(self):
108+
return self._x
58109

59-
def reset_counters(self):
60-
self.step = 0
61-
self.epochs_completed = 0
62-
self.index_in_epoch = 0
63-
self.just_completed = False
110+
@property
111+
def y(self):
112+
return self._y
64113

65114
def next_batch(self, batch_size):
66-
"""
67-
Return the next `batch_size` examples from this data set.
68-
"""
69-
if self.just_completed:
70-
permutation = np.arange(self.size)
115+
if self._just_completed:
116+
permutation = np.arange(self._size)
71117
np.random.shuffle(permutation)
72-
self.x = self.x[permutation]
73-
self.y = self.y[permutation]
74-
75-
self.step += 1
76-
start = self.index_in_epoch
77-
self.index_in_epoch += batch_size
78-
end = min(self.index_in_epoch, self.size)
79-
if self.index_in_epoch >= self.size:
80-
self.index_in_epoch = 0
81-
self.just_completed = end == self.size
82-
self.epochs_completed += int(self.just_completed)
83-
return self.x[start:end], self.y[start:end]
118+
self._x = self._x[permutation]
119+
self._y = self._y[permutation]
120+
121+
self._step += 1
122+
start = self._index_in_epoch
123+
self._index_in_epoch += batch_size
124+
end = min(self._index_in_epoch, self._size)
125+
if self._index_in_epoch >= self._size:
126+
self._index_in_epoch = 0
127+
self._just_completed = end == self._size
128+
self._epochs_completed += int(self._just_completed)
129+
return self._x[start:end], self._y[start:end]
84130

85131

86132
def merge_data_sets(ds1, ds2):

0 commit comments

Comments (0)