-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathonlineaverage.cpp
79 lines (67 loc) · 1.54 KB
/
onlineaverage.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
// TODO: This stores both the running sums and the cached averages, i.e. twice
// as much memory as strictly needed; the averages could be derived on demand.

/// Running mean of a stream of (position, label) samples.
///
/// Each position is an array of `dim` doubles and each label is a scalar.
/// Individual samples, or the totals of another accumulator of the same
/// dimension, can be added and removed. After every add/remove the cached
/// averages in `pos` and `lbl` are recomputed from the sums and `count`.
///
/// Ownership: `posSum` and `pos` are heap arrays owned by this object and
/// released in the destructor; callers must not delete them. Copying makes
/// a deep copy. Moving transfers the arrays and leaves the source empty
/// (dim == 0, null arrays).
class OnlineAverage {
public:
    int dim;
    double* posSum;     ///< Per-coordinate running sum of positions (owned).
    double* pos;        ///< Per-coordinate average, posSum[i] / count (owned).
    double lblSum = 0;  ///< Running sum of labels.
    double lbl = 0;     ///< Average label, lblSum / count.
    size_t count = 0;   ///< Number of samples currently accumulated.

    /// Creates an empty accumulator for positions with `dim_` coordinates.
    /// All sums and averages start at zero.
    OnlineAverage(int dim_) : dim(dim_), posSum(nullptr), pos(nullptr) {
        // Allocate through unique_ptr so the first array is released if the
        // second allocation throws (the destructor would not run then).
        std::unique_ptr<double[]> sums(new double[dim_]());
        std::unique_ptr<double[]> avgs(new double[dim_]());
        posSum = sums.release();
        pos = avgs.release();
    }

    /// Deep copy: the new object owns its own arrays with the same contents.
    OnlineAverage(const OnlineAverage& other)
        : dim(other.dim), posSum(nullptr), pos(nullptr),
          lblSum(other.lblSum), lbl(other.lbl), count(other.count) {
        std::unique_ptr<double[]> sums(new double[dim]);
        std::unique_ptr<double[]> avgs(new double[dim]);
        std::copy_n(other.posSum, dim, sums.get());
        std::copy_n(other.pos, dim, avgs.get());
        posSum = sums.release();
        pos = avgs.release();
    }

    /// Steals `other`'s arrays and leaves it as an empty, zero-dimension
    /// accumulator that is still safe to destroy or assign to.
    OnlineAverage(OnlineAverage&& other) noexcept
        : dim(other.dim), posSum(other.posSum), pos(other.pos),
          lblSum(other.lblSum), lbl(other.lbl), count(other.count) {
        other.dim = 0;
        other.posSum = nullptr;
        other.pos = nullptr;
        other.lblSum = 0.0;
        other.lbl = 0.0;
        other.count = 0;
    }

    /// Copy- and move-assignment via copy-and-swap. `other` is already a
    /// copy (or a moved-into temporary), so swapping it in is safe and the
    /// old arrays are freed when `other` goes out of scope.
    OnlineAverage& operator=(OnlineAverage other) noexcept {
        swapWith(other);
        return *this;
    }

    ~OnlineAverage() {
        delete[] pos;
        delete[] posSum;
    }

    /// Adds one sample.
    /// @param pos_ pointer to at least `dim` coordinates (only read).
    /// @param lbl_ label of the sample.
    void add(const double* pos_, double lbl_) {
        ++count;
        for (int i = 0; i < dim; i++) {
            posSum[i] += pos_[i];
        }
        lblSum += lbl_;
        updateAverage();
    }

    /// Adds all samples accumulated in `o` (its count and sums).
    /// `o` must have the same dimension as this accumulator.
    void add(const OnlineAverage* o) {
        assert(o->dim == dim);
        count += o->count;
        for (int i = 0; i < dim; i++) {
            posSum[i] += o->posSum[i];
        }
        lblSum += o->lblSum;
        updateAverage();
    }

    /// Removes all samples accumulated in `o` (its count and sums).
    /// `o` must have the same dimension and must not hold more samples than
    /// this accumulator, otherwise the unsigned `count` would wrap around.
    void remove(const OnlineAverage* o) {
        assert(o->dim == dim);
        assert(o->count <= count);
        count -= o->count;
        for (int i = 0; i < dim; i++) {
            posSum[i] -= o->posSum[i];
        }
        lblSum -= o->lblSum;
        updateAverage();
    }

    /// Removes one previously added sample.
    /// @param pos_ pointer to at least `dim` coordinates (only read).
    /// @param lbl_ label of the sample.
    void remove(const double* pos_, double lbl_) {
        assert(count > 0);
        --count;
        for (int i = 0; i < dim; i++) {
            posSum[i] -= pos_[i];
        }
        lblSum -= lbl_;
        updateAverage();
    }

    /// Clears all samples: count, sums and averages become zero.
    void reset() {
        count = 0;
        lblSum = 0.0;
        lbl = 0.0;
        std::fill_n(pos, dim, 0.0);
        std::fill_n(posSum, dim, 0.0);
    }

    /// Pointer to the `dim` cached average coordinates (owned by this object).
    double* position() {
        return pos;
    }

    /// Read-only access to the cached average coordinates.
    const double* position() const {
        return pos;
    }

    /// Cached average label.
    double label() const {
        return lbl;
    }

    /// Recomputes `pos` and `lbl` from the sums and `count`.
    /// With no samples the averages are set to zero (the state reset()
    /// produces) instead of the NaN/inf a division by zero would give.
    void updateAverage() {
        if (count == 0) {
            std::fill_n(pos, dim, 0.0);
            lbl = 0.0;
            return;
        }
        const double n = static_cast<double>(count);
        for (int i = 0; i < dim; i++) {
            pos[i] = posSum[i] / n;
        }
        lbl = lblSum / n;
    }

private:
    /// Exchanges every member with `other`; used by operator=.
    void swapWith(OnlineAverage& other) noexcept {
        std::swap(dim, other.dim);
        std::swap(posSum, other.posSum);
        std::swap(pos, other.pos);
        std::swap(lblSum, other.lblSum);
        std::swap(lbl, other.lbl);
        std::swap(count, other.count);
    }
};