@@ -11,76 +11,122 @@ class DataProvider(object):
11
11
12
12
13
13
class IterableDataProvider(DataProvider):
    """
    Data provider that tracks iteration progress through its rows.

    Keeps counters for the number of rows, processed batches, completed
    epochs and the position inside the current epoch. `next_batch` is left
    for subclasses to implement.
    """

    def __init__(self):
        super(IterableDataProvider, self).__init__()
        # No rows until something assigns `_size`.
        self._size = 0
        # Progress counters, all starting from a clean state.
        self._step = 0
        self._index_in_epoch = 0
        self._epochs_completed = 0
        self._just_completed = False

    @property
    def size(self):
        """
        Number of rows held by the provider.
        """
        return self._size

    @property
    def step(self):
        """
        Value of the batch counter.
        """
        return self._step

    @property
    def index(self):
        """
        Absolute row index counted across all epochs.
        """
        rows_in_finished_epochs = self._epochs_completed * self._size
        return rows_in_finished_epochs + self._index_in_epoch

    @property
    def index_in_epoch(self):
        """
        Row index inside the epoch currently in progress.
        """
        return self._index_in_epoch

    @property
    def epochs_completed(self):
        """
        How many epochs have been completed so far.
        """
        return self._epochs_completed

    @property
    def just_completed(self):
        """
        True when the most recent advance finished an epoch.
        """
        return self._just_completed

    def reset_counters(self):
        """
        Puts every progress counter back to its initial value (`_size` is kept).
        """
        self._just_completed = False
        self._index_in_epoch = 0
        self._epochs_completed = 0
        self._step = 0

    def next_batch(self, batch_size):
        """
        Returns the next `batch_size` examples; must be provided by subclasses.
        """
        raise NotImplementedError

    def _inc_index(self):
        """
        Moves the in-epoch index forward by one row, rolling over to a new
        epoch once the index reaches `_size`.
        """
        next_row = self._index_in_epoch + 1
        if next_row < self._size:
            # Still inside the current epoch.
            self._index_in_epoch = next_row
            self._just_completed = False
            return

        # Reached the end of the data: wrap around and count the epoch.
        self._index_in_epoch = 0
        self._epochs_completed += 1
        self._just_completed = True
88
+
36
89
37
class DataSet(IterableDataProvider):
    """
    A labeled data set. Both inputs and labels are stored as numpy arrays in memory.

    Batches are served sequentially. After a call that completes an epoch,
    the next call reshuffles inputs and labels with one shared permutation
    before slicing the first batch of the new epoch.
    """

    def __init__(self, x, y):
        """
        Stores `x` (inputs) and `y` (labels) as numpy arrays.

        Both arguments are converted with `np.array`; their first axes must
        have the same length, otherwise an AssertionError is raised.
        """
        super(DataSet, self).__init__()

        x = np.array(x)
        y = np.array(y)
        # Raised explicitly instead of via `assert` so the check is not
        # stripped under `python -O`; the exception type stays the same.
        if x.shape[0] != y.shape[0]:
            raise AssertionError(
                "x and y must have the same number of rows, got %d and %d"
                % (x.shape[0], y.shape[0]))

        self._size = x.shape[0]
        self._x = x
        self._y = y

    @property
    def x(self):
        """
        The input array (in its current, possibly shuffled, order).
        """
        return self._x

    @property
    def y(self):
        """
        The label array (in its current, possibly shuffled, order).
        """
        return self._y

    def next_batch(self, batch_size):
        """
        Returns the next `batch_size` examples from this data set.

        The result is a tuple `(x_batch, y_batch)`. The last batch of an
        epoch is truncated at the end of the data, so it may hold fewer than
        `batch_size` rows. Raises ValueError when `batch_size` is smaller
        than 1.
        """
        # A zero batch never advances and a negative one would push the
        # in-epoch index below zero, corrupting every later slice.
        if batch_size < 1:
            raise ValueError(
                "batch_size must be a positive integer, got %r" % (batch_size,))

        if self._just_completed:
            # The previous call finished an epoch: reshuffle rows, keeping
            # inputs and labels aligned through a shared permutation.
            permutation = np.arange(self._size)
            np.random.shuffle(permutation)
            self._x = self._x[permutation]
            self._y = self._y[permutation]

        self._step += 1
        start = self._index_in_epoch
        self._index_in_epoch += batch_size
        end = min(self._index_in_epoch, self._size)
        if self._index_in_epoch >= self._size:
            self._index_in_epoch = 0
        # `end` only equals the size when this batch reached the last row.
        self._just_completed = end == self._size
        self._epochs_completed += int(self._just_completed)
        return self._x[start:end], self._y[start:end]
84
130
85
131
86
132
def merge_data_sets (ds1 , ds2 ):
0 commit comments