@@ -21,15 +21,16 @@ struct Hamerly <: AbstractKMeansAlg end
21
21
function kmeans! (alg:: Hamerly , containers, X, k;
22
22
n_threads = Threads. nthreads (),
23
23
k_init = " k-means++" , max_iters = 300 ,
24
- tol = 1e-6 , verbose = false , init = nothing )
24
+ tol = eltype (X)( 1e-6 ) , verbose = false , init = nothing )
25
25
nrow, ncol = size (X)
26
26
centroids = init == nothing ? smart_init (X, k, n_threads, init= k_init). centroids : deepcopy (init)
27
27
28
28
@parallelize n_threads ncol chunk_initialize (alg, containers, centroids, X)
29
29
30
+ T = eltype (X)
30
31
converged = false
31
32
niters = 0
32
- J_previous = 0.0
33
+ J_previous = zero (T)
33
34
p = containers. p
34
35
35
36
# Update centroids & labels with closest members until convergence
@@ -70,35 +71,36 @@ function kmeans!(alg::Hamerly, containers, X, k;
70
71
# TODO empty placeholder vectors should be calculated
71
72
# TODO Float64 type definitions are too restrictive and should be relaxed,
72
73
# especially during GPU related development
73
- return KmeansResult (centroids, containers. labels, Float64 [], Int[], Float64 [], totalcost, niters, converged)
74
+ return KmeansResult (centroids, containers. labels, T [], Int[], T [], totalcost, niters, converged)
74
75
end
75
76
76
- function create_containers (alg:: Hamerly , k, nrow, ncol, n_threads)
77
+ function create_containers (alg:: Hamerly , X, k, nrow, ncol, n_threads)
78
+ T = eltype (X)
77
79
lng = n_threads + 1
78
- centroids_new = Vector {Array{Float64,2 }} (undef, lng)
79
- centroids_cnt = Vector {Vector{Int }} (undef, lng)
80
+ centroids_new = Vector {Matrix{T }} (undef, lng)
81
+ centroids_cnt = Vector {Vector{T }} (undef, lng)
80
82
81
83
for i = 1 : lng
82
- centroids_new[i] = zeros (nrow, k)
83
- centroids_cnt[i] = zeros (k)
84
+ centroids_new[i] = zeros (T, nrow, k)
85
+ centroids_cnt[i] = zeros (T, k)
84
86
end
85
87
86
88
# Upper bound to the closest center
87
- ub = Vector {Float64 } (undef, ncol)
89
+ ub = Vector {T } (undef, ncol)
88
90
89
91
# lower bound to the second closest center
90
- lb = Vector {Float64 } (undef, ncol)
92
+ lb = Vector {T } (undef, ncol)
91
93
92
94
labels = zeros (Int, ncol)
93
95
94
96
# distance that centroid has moved
95
- p = Vector {Float64 } (undef, k)
97
+ p = Vector {T } (undef, k)
96
98
97
99
# distance from the center to the closest other center
98
- s = Vector {Float64 } (undef, k)
100
+ s = Vector {T } (undef, k)
99
101
100
102
# total_sum_calculation
101
- sum_of_squares = Vector {Float64 } (undef, n_threads)
103
+ sum_of_squares = Vector {T } (undef, n_threads)
102
104
103
105
return (
104
106
centroids_new = centroids_new,
@@ -118,12 +120,13 @@ end
118
120
Initial calculation of all bounds and point labels.
119
121
"""
120
122
function chunk_initialize (alg:: Hamerly , containers, centroids, X, r, idx)
123
+ T = eltype (X)
121
124
centroids_cnt = containers. centroids_cnt[idx]
122
125
centroids_new = containers. centroids_new[idx]
123
126
124
127
@inbounds for i in r
125
128
label = point_all_centers! (containers, centroids, X, i)
126
- centroids_cnt[label] += 1
129
+ centroids_cnt[label] += one (T)
127
130
for j in axes (X, 1 )
128
131
centroids_new[j, label] += X[j, i]
129
132
end
@@ -136,12 +139,13 @@ end
136
139
Calculates minimum distances from centers to each other.
137
140
"""
138
141
function update_containers (:: Hamerly , containers, centroids, n_threads)
142
+ T = eltype (centroids)
139
143
s = containers. s
140
- s .= Inf
144
+ s .= T ( Inf )
141
145
@inbounds for i in axes (centroids, 2 )
142
146
for j in i+ 1 : size (centroids, 2 )
143
147
d = distance (centroids, centroids, i, j)
144
- d = 0.25 * d
148
+ d = T ( 0.25 ) * d
145
149
s[i] = s[i] > d ? d : s[i]
146
150
s[j] = s[j] > d ? d : s[j]
147
151
end
@@ -164,6 +168,7 @@ function chunk_update_centroids(alg::Hamerly, containers, centroids, X, r, idx)
164
168
s = containers. s
165
169
lb = containers. lb
166
170
ub = containers. ub
171
+ T = eltype (X)
167
172
168
173
@inbounds for i in r
169
174
# m ← max(s(a(i))/2, l(i))
@@ -178,8 +183,8 @@ function chunk_update_centroids(alg::Hamerly, containers, centroids, X, r, idx)
178
183
label_new = point_all_centers! (containers, centroids, X, i)
179
184
if label != label_new
180
185
labels[i] = label_new
181
- centroids_cnt[label_new] += 1
182
- centroids_cnt[label] -= 1
186
+ centroids_cnt[label_new] += one (T)
187
+ centroids_cnt[label] -= one (T)
183
188
for j in axes (X, 1 )
184
189
centroids_new[j, label_new] += X[j, i]
185
190
centroids_new[j, label] -= X[j, i]
@@ -199,9 +204,10 @@ function point_all_centers!(containers, centroids, X, i)
199
204
ub = containers. ub
200
205
lb = containers. lb
201
206
labels = containers. labels
207
+ T = eltype (X)
202
208
203
- min_distance = Inf
204
- min_distance2 = Inf
209
+ min_distance = T ( Inf )
210
+ min_distance2 = T ( Inf )
205
211
label = 1
206
212
@inbounds for k in axes (centroids, 2 )
207
213
dist = distance (X, centroids, i, k)
@@ -230,9 +236,10 @@ in `centroids` and `p` respectively.
230
236
function move_centers (:: Hamerly , containers, centroids)
231
237
centroids_new = containers. centroids_new[end ]
232
238
p = containers. p
239
+ T = eltype (centroids)
233
240
234
241
@inbounds for i in axes (centroids, 2 )
235
- d = 0.0
242
+ d = zero (T)
236
243
for j in axes (centroids, 1 )
237
244
d += (centroids[j, i] - centroids_new[j, i])^ 2
238
245
centroids[j, i] = centroids_new[j, i]
@@ -251,6 +258,7 @@ function chunk_update_bounds(alg::Hamerly, containers, r1, r2, pr1, pr2, r, idx)
251
258
ub = containers. ub
252
259
lb = containers. lb
253
260
labels = containers. labels
261
+ T = eltype (containers. ub)
254
262
255
263
# Since bounds are squared distances, `sqrt` is used to make the corresponding estimation, unlike
256
264
# the original paper, where the usual metric is used.
@@ -270,11 +278,11 @@ function chunk_update_bounds(alg::Hamerly, containers, r1, r2, pr1, pr2, r, idx)
270
278
# The same applies to the lower bounds.
271
279
@inbounds for i in r
272
280
label = labels[i]
273
- ub[i] += 2 * sqrt (abs (ub[i] * p[label])) + p[label]
281
+ ub[i] += T ( 2 ) * sqrt (abs (ub[i] * p[label])) + p[label]
274
282
if r1 == label
275
- lb[i] = lb[i] <= pr2 ? 0.0 : lb[i] + pr2 - 2 * sqrt (abs (pr2* lb[i]))
283
+ lb[i] = lb[i] <= pr2 ? zero (T) : lb[i] + pr2 - T ( 2 ) * sqrt (abs (pr2* lb[i]))
276
284
else
277
- lb[i] = lb[i] <= pr1 ? 0.0 : lb[i] + pr1 - 2 * sqrt (abs (pr1* lb[i]))
285
+ lb[i] = lb[i] <= pr1 ? zero (T) : lb[i] + pr1 - T ( 2 ) * sqrt (abs (pr1* lb[i]))
278
286
end
279
287
end
280
288
end
@@ -284,10 +292,10 @@ end
284
292
285
293
Finds maximum and next after maximum arguments.
286
294
"""
287
- function double_argmax (p)
295
+ function double_argmax (p:: AbstractVector{T} ) where T
288
296
r1, r2 = 1 , 1
289
297
d1 = p[1 ]
290
- d2 = - 1.0
298
+ d2 = T ( - Inf )
291
299
for i in 2 : length (p)
292
300
if p[i] > d1
293
301
r2 = r1
0 commit comments