Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit c2c7bd8

Browse files
authored Apr 13, 2020
Merge pull request #46 from Arkoniak/elkan_algorithm
Full Elkan implementation and refactoring of MLJ Interface
2 parents a44f676 + 9f8f5b4 commit c2c7bd8

14 files changed

+514
-443
lines changed
 

‎Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ParallelKMeans"
22
uuid = "42b8e9d4-006b-409a-8472-7f34b3fb58af"
33
authors = ["Bernard Brenyah", "Andrey Oskin"]
4-
version = "0.1.1"
4+
version = "0.1.2"
55

66
[deps]
77
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"

‎docs/src/index.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# ParallelKMeans.jl Package
1+
# [ParallelKMeans.jl Package](https://github.com/PyDataBlog/ParallelKMeans.jl)
22

33
```@contents
44
Depth = 4
@@ -59,7 +59,7 @@ git checkout experimental
5959

6060
- [X] Implementation of [Hamerly implementation](https://www.researchgate.net/publication/220906984_Making_k-means_Even_Faster).
6161
- [X] Interface for inclusion in Alan Turing Institute's [MLJModels](https://github.com/alan-turing-institute/MLJModels.jl#who-is-this-repo-for).
62-
- [ ] Full Implementation of Triangle inequality based on [Elkan - 2003 Using the Triangle Inequality to Accelerate K-Means"](https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf).
62+
- [X] Full Implementation of Triangle inequality based on [Elkan - 2003 Using the Triangle Inequality to Accelerate K-Means"](https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf).
6363
- [ ] Implementation of [Geometric methods to accelerate k-means algorithm](http://cs.baylor.edu/~hamerly/papers/sdm2016_rysavy_hamerly.pdf).
6464
- [ ] Native support for tabular data inputs outside of MLJModels' interface.
6565
- [ ] Refactoring and finalization of API design.
@@ -177,6 +177,7 @@ ________________________________________________________________________________
177177

178178
- 0.1.0 Initial release
179179
- 0.1.1 Added interface for MLJ
180+
- 0.1.2 Added Elkan algorithm
180181

181182
## Contributing
182183

‎src/ParallelKMeans.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
module ParallelKMeans
22

33
using StatsBase
4-
using MLJModelInterface
4+
import MLJModelInterface
55
import Base.Threads: @spawn
66
import Distances
77

8+
const MMI = MLJModelInterface
9+
810
include("seeding.jl")
911
include("kmeans.jl")
1012
include("lloyd.jl")
11-
include("light_elkan.jl")
1213
include("hamerly.jl")
14+
include("elkan.jl")
1315
include("mlj_interface.jl")
1416

1517
export kmeans
16-
export Lloyd, LightElkan, Hamerly
18+
export Lloyd, Hamerly, Elkan
1719

1820
end # module

‎src/elkan.jl

Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
"""
2+
Elkan()
3+
4+
Elkan algorithm implementation, based on "Charles Elkan. 2003.
5+
Using the triangle inequality to accelerate k-means.
6+
In Proceedings of the Twentieth International Conference on
7+
International Conference on Machine Learning (ICML’03). AAAI Press, 147–153."
8+
9+
This algorithm provides much faster convergence than the Lloyd algorithm, especially
10+
for high dimensional data.
11+
It can be used directly in `kmeans` function
12+
13+
```julia
14+
X = rand(30, 100_000) # 100_000 random points in 30 dimensions
15+
16+
kmeans(Elkan(), X, 3) # 3 clusters, Elkan algorithm
17+
```
18+
"""
19+
struct Elkan <: AbstractKMeansAlg end
20+
21+
function kmeans!(alg::Elkan, containers, X, k;
22+
n_threads = Threads.nthreads(),
23+
k_init = "k-means++", max_iters = 300,
24+
tol = 1e-6, verbose = false, init = nothing)
25+
nrow, ncol = size(X)
26+
centroids = init == nothing ? smart_init(X, k, n_threads, init=k_init).centroids : deepcopy(init)
27+
28+
update_containers(alg, containers, centroids, n_threads)
29+
@parallelize n_threads ncol chunk_initialize(alg, containers, centroids, X)
30+
31+
converged = false
32+
niters = 0
33+
J_previous = 0.0
34+
35+
# Update centroids & labels with closest members until convergence
36+
while niters < max_iters
37+
niters += 1
38+
# Core iteration
39+
@parallelize n_threads ncol chunk_update_centroids(alg, containers, centroids, X)
40+
41+
# Collect distributed containers (such as centroids_new, centroids_cnt)
42+
# in paper it is step 4
43+
collect_containers(alg, containers, n_threads)
44+
45+
J = sum(containers.ub)
46+
47+
# auxiliary calculation, in paper it's d(c, m(c))
48+
calculate_centroids_movement(alg, containers, centroids)
49+
50+
# lower and upper bounds update, in paper it's steps 5 and 6
51+
@parallelize n_threads ncol chunk_update_bounds(alg, containers, centroids)
52+
53+
# Step 7, final assignment of new centroids
54+
centroids .= containers.centroids_new[end]
55+
56+
if verbose
57+
# Show progress and terminate if J stopped decreasing.
58+
println("Iteration $niters: Jclust = $J")
59+
end
60+
61+
# Check for convergence
62+
if (niters > 1) & (abs(J - J_previous) < (tol * J))
63+
converged = true
64+
break
65+
end
66+
67+
# Step 1 in original paper, calculation of distance d(c, c')
68+
update_containers(alg, containers, centroids, n_threads)
69+
J_previous = J
70+
end
71+
72+
@parallelize n_threads ncol sum_of_squares(containers, X, containers.labels, centroids)
73+
totalcost = sum(containers.sum_of_squares)
74+
75+
# Terminate algorithm with the assumption that K-means has converged
76+
if verbose & converged
77+
println("Successfully terminated with convergence.")
78+
end
79+
80+
# TODO empty placeholder vectors should be calculated
81+
# TODO Float64 type definitions is too restrictive, should be relaxed
82+
# especially during GPU related development
83+
return KmeansResult(centroids, containers.labels, Float64[], Int[], Float64[], totalcost, niters, converged)
84+
end
85+
86+
function create_containers(::Elkan, k, nrow, ncol, n_threads)
87+
lng = n_threads + 1
88+
centroids_new = Vector{Array{Float64,2}}(undef, lng)
89+
centroids_cnt = Vector{Vector{Int}}(undef, lng)
90+
91+
for i = 1:lng
92+
centroids_new[i] = zeros(nrow, k)
93+
centroids_cnt[i] = zeros(k)
94+
end
95+
96+
centroids_dist = Matrix{Float64}(undef, k, k)
97+
98+
# lower bounds
99+
lb = Matrix{Float64}(undef, k, ncol)
100+
101+
# upper bounds
102+
ub = Vector{Float64}(undef, ncol)
103+
104+
# r(x) in original paper, shows whether point distance should be updated
105+
stale = ones(Bool, ncol)
106+
107+
# distance that centroid moved
108+
p = Vector{Float64}(undef, k)
109+
110+
labels = zeros(Int, ncol)
111+
112+
# total_sum_calculation
113+
sum_of_squares = Vector{Float64}(undef, n_threads)
114+
115+
return (
116+
centroids_new = centroids_new,
117+
centroids_cnt = centroids_cnt,
118+
labels = labels,
119+
centroids_dist = centroids_dist,
120+
lb = lb,
121+
ub = ub,
122+
stale = stale,
123+
p = p,
124+
sum_of_squares = sum_of_squares
125+
)
126+
end
127+
128+
function chunk_initialize(::Elkan, containers, centroids, X, r, idx)
129+
ub = containers.ub
130+
lb = containers.lb
131+
centroids_dist = containers.centroids_dist
132+
labels = containers.labels
133+
centroids_new = containers.centroids_new[idx]
134+
centroids_cnt = containers.centroids_cnt[idx]
135+
136+
@inbounds for i in r
137+
min_dist = distance(X, centroids, i, 1)
138+
label = 1
139+
lb[label, i] = min_dist
140+
for j in 2:size(centroids, 2)
141+
# triangular inequality
142+
if centroids_dist[j, label] > min_dist
143+
lb[j, i] = min_dist
144+
else
145+
dist = distance(X, centroids, i, j)
146+
label = dist < min_dist ? j : label
147+
min_dist = dist < min_dist ? dist : min_dist
148+
lb[j, i] = dist
149+
end
150+
end
151+
ub[i] = min_dist
152+
labels[i] = label
153+
centroids_cnt[label] += 1
154+
for j in axes(X, 1)
155+
centroids_new[j, label] += X[j, i]
156+
end
157+
end
158+
end
159+
160+
function update_containers(::Elkan, containers, centroids, n_threads)
161+
# unpack containers for easier manipulations
162+
centroids_dist = containers.centroids_dist
163+
164+
k = size(centroids_dist, 1) # number of clusters
165+
@inbounds for j in axes(centroids_dist, 2)
166+
min_dist = Inf
167+
for i in j + 1:k
168+
d = distance(centroids, centroids, i, j)
169+
centroids_dist[i, j] = d
170+
centroids_dist[j, i] = d
171+
min_dist = min_dist < d ? min_dist : d
172+
end
173+
for i in 1:j - 1
174+
min_dist = min_dist < centroids_dist[j, i] ? min_dist : centroids_dist[j, i]
175+
end
176+
centroids_dist[j, j] = min_dist
177+
end
178+
179+
# TODO: oh, one should be careful here. inequality holds for Euclidean metrics
180+
# not squared Euclidean. So, for Lp norm it should be something like
181+
# centroids_dist = 0.5^p. Should check one more time original paper
182+
centroids_dist .*= 0.25
183+
184+
return centroids_dist
185+
end
186+
187+
function chunk_update_centroids(::Elkan, containers, centroids, X, r, idx)
188+
# unpack
189+
ub = containers.ub
190+
lb = containers.lb
191+
centroids_dist = containers.centroids_dist
192+
labels = containers.labels
193+
stale = containers.stale
194+
centroids_new = containers.centroids_new[idx]
195+
centroids_cnt = containers.centroids_cnt[idx]
196+
197+
@inbounds for i in r
198+
label_old = labels[i]
199+
label = label_old
200+
min_dist = ub[i]
201+
# tighten the loop, exclude points that are very close to the center
202+
min_dist <= centroids_dist[label, label] && continue
203+
for j in axes(centroids, 2)
204+
# tighten the loop once more, exclude far away centers
205+
j == label && continue
206+
min_dist <= lb[j, i] && continue
207+
min_dist <= centroids_dist[j, label] && continue
208+
209+
# one calculation per iteration is enough
210+
if stale[i]
211+
min_dist = distance(X, centroids, i, label)
212+
lb[label, i] = min_dist
213+
ub[i] = min_dist
214+
stale[i] = false
215+
end
216+
217+
if (min_dist > lb[j, i]) | (min_dist > centroids_dist[j, label])
218+
dist = distance(X, centroids, i, j)
219+
lb[j, i] = dist
220+
if dist < min_dist
221+
min_dist = dist
222+
label = j
223+
end
224+
end
225+
end
226+
227+
if label != label_old
228+
labels[i] = label
229+
centroids_cnt[label_old] -= 1
230+
centroids_cnt[label] += 1
231+
for j in axes(X, 1)
232+
centroids_new[j, label_old] -= X[j, i]
233+
centroids_new[j, label] += X[j, i]
234+
end
235+
end
236+
end
237+
end
238+
239+
function collect_containers(alg::Elkan, containers, n_threads)
240+
if n_threads == 1
241+
@inbounds containers.centroids_new[end] .= containers.centroids_new[1] ./ containers.centroids_cnt[1]'
242+
else
243+
@inbounds containers.centroids_new[end] .= containers.centroids_new[1]
244+
@inbounds containers.centroids_cnt[end] .= containers.centroids_cnt[1]
245+
@inbounds for i in 2:n_threads
246+
containers.centroids_new[end] .+= containers.centroids_new[i]
247+
containers.centroids_cnt[end] .+= containers.centroids_cnt[i]
248+
end
249+
250+
@inbounds containers.centroids_new[end] .= containers.centroids_new[end] ./ containers.centroids_cnt[end]'
251+
end
252+
end
253+
254+
function calculate_centroids_movement(alg::Elkan, containers, centroids)
255+
p = containers.p
256+
centroids_new = containers.centroids_new[end]
257+
258+
for i in axes(centroids, 2)
259+
p[i] = distance(centroids, centroids_new, i, i)
260+
end
261+
end
262+
263+
264+
function chunk_update_bounds(alg, containers, centroids, r, idx)
265+
p = containers.p
266+
lb = containers.lb
267+
ub = containers.ub
268+
stale = containers.stale
269+
labels = containers.labels
270+
271+
@inbounds for i in r
272+
for j in axes(centroids, 2)
273+
lb[j, i] = lb[j, i] > p[j] ? lb[j, i] + p[j] - 2*sqrt(abs(lb[j, i]*p[j])) : 0.0
274+
end
275+
stale[i] = true
276+
ub[i] += p[labels[i]] + 2*sqrt(abs(ub[i]*p[labels[i]]))
277+
end
278+
end

‎src/hamerly.jl

Lines changed: 30 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -18,27 +18,14 @@ kmeans(Hamerly(), X, 3) # 3 clusters, Hamerly algorithm
1818
struct Hamerly <: AbstractKMeansAlg end
1919

2020

21-
function kmeans(alg::Hamerly, design_matrix, k;
21+
function kmeans!(alg::Hamerly, containers, X, k;
2222
n_threads = Threads.nthreads(),
2323
k_init = "k-means++", max_iters = 300,
2424
tol = 1e-6, verbose = false, init = nothing)
25-
nrow, ncol = size(design_matrix)
26-
containers = create_containers(alg, k, nrow, ncol, n_threads)
25+
nrow, ncol = size(X)
26+
centroids = init == nothing ? smart_init(X, k, n_threads, init=k_init).centroids : deepcopy(init)
2727

28-
return kmeans!(alg, containers, design_matrix, k, n_threads = n_threads,
29-
k_init = k_init, max_iters = max_iters, tol = tol,
30-
verbose = verbose, init = init)
31-
end
32-
33-
34-
function kmeans!(alg::Hamerly, containers, design_matrix, k;
35-
n_threads = Threads.nthreads(),
36-
k_init = "k-means++", max_iters = 300,
37-
tol = 1e-6, verbose = false, init = nothing)
38-
nrow, ncol = size(design_matrix)
39-
centroids = init == nothing ? smart_init(design_matrix, k, n_threads, init=k_init).centroids : deepcopy(init)
40-
41-
@parallelize n_threads ncol chunk_initialize!(alg, containers, centroids, design_matrix)
28+
@parallelize n_threads ncol chunk_initialize(alg, containers, centroids, X)
4229

4330
converged = false
4431
niters = 0
@@ -48,15 +35,15 @@ function kmeans!(alg::Hamerly, containers, design_matrix, k;
4835
# Update centroids & labels with closest members until convergence
4936
while niters < max_iters
5037
niters += 1
51-
update_containers!(containers, alg, centroids, n_threads)
52-
@parallelize n_threads ncol chunk_update_centroids!(centroids, containers, alg, design_matrix)
38+
update_containers(alg, containers, centroids, n_threads)
39+
@parallelize n_threads ncol chunk_update_centroids(alg, containers, centroids, X)
5340
collect_containers(alg, containers, n_threads)
5441

5542
J = sum(containers.ub)
56-
move_centers!(centroids, containers, alg)
43+
move_centers(alg, containers, centroids)
5744

5845
r1, r2, pr1, pr2 = double_argmax(p)
59-
@parallelize n_threads ncol chunk_update_bounds!(containers, r1, r2, pr1, pr2)
46+
@parallelize n_threads ncol chunk_update_bounds(alg, containers, r1, r2, pr1, pr2)
6047

6148
if verbose
6249
# Show progress and terminate if J stops decreasing as specified by the tolerance level.
@@ -70,10 +57,9 @@ function kmeans!(alg::Hamerly, containers, design_matrix, k;
7057
end
7158

7259
J_previous = J
73-
7460
end
7561

76-
@parallelize n_threads ncol sum_of_squares(containers, design_matrix, containers.labels, centroids)
62+
@parallelize n_threads ncol sum_of_squares(containers, X, containers.labels, centroids)
7763
totalcost = sum(containers.sum_of_squares)
7864

7965
# Terminate algorithm with the assumption that K-means has converged
@@ -144,29 +130,29 @@ function create_containers(alg::Hamerly, k, nrow, ncol, n_threads)
144130
end
145131

146132
"""
147-
chunk_initialize!(alg::Hamerly, containers, centroids, design_matrix, r, idx)
133+
chunk_initialize(alg::Hamerly, containers, centroids, design_matrix, r, idx)
148134
149135
Initial calculation of all bounds and points labeling.
150136
"""
151-
function chunk_initialize!(alg::Hamerly, containers, centroids, design_matrix, r, idx)
137+
function chunk_initialize(alg::Hamerly, containers, centroids, X, r, idx)
152138
centroids_cnt = containers.centroids_cnt[idx]
153139
centroids_new = containers.centroids_new[idx]
154140

155141
@inbounds for i in r
156-
label = point_all_centers!(containers, centroids, design_matrix, i)
142+
label = point_all_centers!(containers, centroids, X, i)
157143
centroids_cnt[label] += 1
158-
for j in axes(design_matrix, 1)
159-
centroids_new[j, label] += design_matrix[j, i]
144+
for j in axes(X, 1)
145+
centroids_new[j, label] += X[j, i]
160146
end
161147
end
162148
end
163149

164150
"""
165-
update_containers!(containers, ::Hamerly, centroids, n_threads)
151+
update_containers(::Hamerly, containers, centroids, n_threads)
166152
167153
Calculates minimum distances from centers to each other.
168154
"""
169-
function update_containers!(containers, ::Hamerly, centroids, n_threads)
155+
function update_containers(::Hamerly, containers, centroids, n_threads)
170156
s = containers.s
171157
s .= Inf
172158
@inbounds for i in axes(centroids, 2)
@@ -180,13 +166,13 @@ function update_containers!(containers, ::Hamerly, centroids, n_threads)
180166
end
181167

182168
"""
183-
chunk_update_centroids!(centroids, containers, alg::Hamerly, design_matrix, r, idx)
169+
chunk_update_centroids(::Hamerly, containers, centroids, X, r, idx)
184170
185171
Detailed description of this function can be found in the original paper. It iterates through
186172
all points and tries to skip some calculation using known upper and lower bounds of distances
187173
from point to centers. If it fails to skip, then it falls back to the generic `point_all_centers!` function.
188174
"""
189-
function chunk_update_centroids!(centroids, containers, alg::Hamerly, design_matrix, r, idx)
175+
function chunk_update_centroids(alg::Hamerly, containers, centroids, X, r, idx)
190176

191177
# unpack containers for easier manipulations
192178
centroids_new = containers.centroids_new[idx]
@@ -203,17 +189,17 @@ function chunk_update_centroids!(centroids, containers, alg::Hamerly, design_mat
203189
if ub[i] > m
204190
# tighten upper bound
205191
label = labels[i]
206-
ub[i] = distance(design_matrix, centroids, i, label)
192+
ub[i] = distance(X, centroids, i, label)
207193
# second bound test
208194
if ub[i] > m
209-
label_new = point_all_centers!(containers, centroids, design_matrix, i)
195+
label_new = point_all_centers!(containers, centroids, X, i)
210196
if label != label_new
211197
labels[i] = label_new
212198
centroids_cnt[label_new] += 1
213199
centroids_cnt[label] -= 1
214-
for j in axes(design_matrix, 1)
215-
centroids_new[j, label_new] += design_matrix[j, i]
216-
centroids_new[j, label] -= design_matrix[j, i]
200+
for j in axes(X, 1)
201+
centroids_new[j, label_new] += X[j, i]
202+
centroids_new[j, label] -= X[j, i]
217203
end
218204
end
219205
end
@@ -222,11 +208,11 @@ function chunk_update_centroids!(centroids, containers, alg::Hamerly, design_mat
222208
end
223209

224210
"""
225-
point_all_centers!(containers, centroids, design_matrix, i)
211+
point_all_centers!(containers, centroids, X, i)
226212
227213
Calculates new labels and upper and lower bounds for all points.
228214
"""
229-
function point_all_centers!(containers, centroids, design_matrix, i)
215+
function point_all_centers!(containers, centroids, X, i)
230216
ub = containers.ub
231217
lb = containers.lb
232218
labels = containers.labels
@@ -235,7 +221,7 @@ function point_all_centers!(containers, centroids, design_matrix, i)
235221
min_distance2 = Inf
236222
label = 1
237223
@inbounds for k in axes(centroids, 2)
238-
dist = distance(design_matrix, centroids, i, k)
224+
dist = distance(X, centroids, i, k)
239225
if min_distance > dist
240226
label = k
241227
min_distance2 = min_distance
@@ -253,12 +239,12 @@ function point_all_centers!(containers, centroids, design_matrix, i)
253239
end
254240

255241
"""
256-
move_centers!(centroids, containers, ::Hamerly)
242+
move_centers(::Hamerly, containers, centroids)
257243
258244
Calculates new positions of centers and distance they have moved. Results are stored
259245
in `centroids` and `p` respectively.
260246
"""
261-
function move_centers!(centroids, containers, ::Hamerly)
247+
function move_centers(::Hamerly, containers, centroids)
262248
centroids_new = containers.centroids_new[end]
263249
p = containers.p
264250

@@ -273,11 +259,11 @@ function move_centers!(centroids, containers, ::Hamerly)
273259
end
274260

275261
"""
276-
chunk_update_bounds!(containers, r1, r2, pr1, pr2, r, idx)
262+
chunk_update_bounds(alg::Hamerly, containers, r1, r2, pr1, pr2, r, idx)
277263
278264
Updates upper and lower bounds of point distance to the centers, with regard to the centers movement.
279265
"""
280-
function chunk_update_bounds!(containers, r1, r2, pr1, pr2, r, idx)
266+
function chunk_update_bounds(alg::Hamerly, containers, r1, r2, pr1, pr2, r, idx)
281267
p = containers.p
282268
ub = containers.ub
283269
lb = containers.lb

‎src/kmeans.jl

Lines changed: 1 addition & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ Allocationless calculation of squared Euclidean distance between vectors X1[:, i1
9393
"""
9494
function distance(X1, X2, i1, i2)
9595
d = 0.0
96+
# TODO: break out of the loop if d is larger than the threshold (known minimum distance)
9697
@inbounds for i in axes(X1, 1)
9798
d += (X1[i, i1] - X2[i, i2])^2
9899
end
@@ -108,18 +109,6 @@ design matrix(x), centroids (centre), and the number of desired groups (k).
108109
109110
A Float type representing the computed metric is returned.
110111
"""
111-
function sum_of_squares(x, labels, centre)
112-
s = 0.0
113-
114-
@inbounds for j in axes(x, 2)
115-
for i in axes(x, 1)
116-
s += (x[i, j] - centre[i, labels[j]])^2
117-
end
118-
end
119-
120-
return s
121-
end
122-
123112
function sum_of_squares(containers, x, labels, centre, r, idx)
124113
s = 0.0
125114

@@ -170,100 +159,3 @@ function kmeans(alg, design_matrix, k;
170159
k_init = k_init, max_iters = max_iters, tol = tol,
171160
verbose = verbose, init = init)
172161
end
173-
174-
175-
"""
176-
Kmeans!(alg::AbstractKMeansAlg, containers, design_matrix, k; n_threads = nthreads(), k_init="k-means++", max_iters=300, tol=1e-6, verbose=false)
177-
178-
Mutable version of `kmeans` function. Definition of arguments and results can be
179-
found in `kmeans`.
180-
181-
Argument `containers` represent algorithm specific containers, such as labels, intermidiate
182-
centroids and so on, which are used during calculations.
183-
"""
184-
function kmeans!(alg, containers, design_matrix, k;
185-
n_threads = Threads.nthreads(),
186-
k_init = "k-means++", max_iters = 300,
187-
tol = 1e-6, verbose = false, init = nothing)
188-
nrow, ncol = size(design_matrix)
189-
centroids = init == nothing ? smart_init(design_matrix, k, n_threads, init=k_init).centroids : deepcopy(init)
190-
191-
converged = false
192-
niters = 0
193-
J_previous = 0.0
194-
195-
# Update centroids & labels with closest members until convergence
196-
197-
while niters < max_iters
198-
niters += 1
199-
200-
update_containers!(containers, alg, centroids, n_threads)
201-
J = update_centroids!(centroids, containers, alg, design_matrix, n_threads)
202-
203-
if verbose
204-
# Show progress and terminate if J stopped decreasing.
205-
println("Iteration $niters: Jclust = $J")
206-
end
207-
208-
# Check for convergence
209-
if (niters > 1) & (abs(J - J_previous) < (tol * J))
210-
converged = true
211-
break
212-
end
213-
214-
J_previous = J
215-
216-
end
217-
218-
totalcost = sum_of_squares(design_matrix, containers.labels, centroids)
219-
220-
# Terminate algorithm with the assumption that K-means has converged
221-
if verbose & converged
222-
println("Successfully terminated with convergence.")
223-
end
224-
225-
# TODO empty placeholder vectors should be calculated
226-
# TODO Float64 type definitions is too restrictive, should be relaxed
227-
# especially during GPU related development
228-
return KmeansResult(centroids, containers.labels, Float64[], Int[], Float64[], totalcost, niters, converged)
229-
end
230-
231-
"""
232-
update_centroids!(centroids, containers, alg, design_matrix, n_threads)
233-
234-
Internal function, used to update centroids by utilizing one of `alg`. It works as
235-
a wrapper of internal `chunk_update_centroids!` function, splitting incoming
236-
`design_matrix` in chunks and combining results together.
237-
"""
238-
function update_centroids!(centroids, containers, alg, design_matrix, n_threads)
239-
ncol = size(design_matrix, 2)
240-
241-
if n_threads == 1
242-
r = axes(design_matrix, 2)
243-
J = chunk_update_centroids!(centroids, containers, alg, design_matrix, r, 1)
244-
245-
centroids .= containers.new_centroids[1] ./ containers.centroids_cnt[1]'
246-
else
247-
ranges = splitter(ncol, n_threads)
248-
249-
waiting_list = Vector{Task}(undef, n_threads - 1)
250-
251-
for i in 1:length(ranges) - 1
252-
waiting_list[i] = @spawn chunk_update_centroids!(centroids, containers,
253-
alg, design_matrix, ranges[i], i + 1)
254-
end
255-
256-
J = chunk_update_centroids!(centroids, containers, alg, design_matrix, ranges[end], 1)
257-
258-
J += sum(fetch.(waiting_list))
259-
260-
for i in 1:length(ranges) - 1
261-
containers.new_centroids[1] .+= containers.new_centroids[i + 1]
262-
containers.centroids_cnt[1] .+= containers.centroids_cnt[i + 1]
263-
end
264-
265-
centroids .= containers.new_centroids[1] ./ containers.centroids_cnt[1]'
266-
end
267-
268-
return J
269-
end

‎src/light_elkan.jl

Lines changed: 0 additions & 150 deletions
This file was deleted.

‎src/lloyd.jl

Lines changed: 94 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,61 @@ Basic algorithm for k-means calculation.
55
"""
66
struct Lloyd <: AbstractKMeansAlg end
77

8+
"""
9+
kmeans!(alg::AbstractKMeansAlg, containers, design_matrix, k; n_threads = nthreads(), k_init="k-means++", max_iters=300, tol=1e-6, verbose=false)
10+
11+
Mutable version of `kmeans` function. Definition of arguments and results can be
12+
found in `kmeans`.
13+
14+
Argument `containers` represents algorithm-specific containers, such as labels, intermediate
15+
centroids and so on, which are used during calculations.
16+
"""
17+
function kmeans!(alg::Lloyd, containers, X, k;
18+
n_threads = Threads.nthreads(),
19+
k_init = "k-means++", max_iters = 300,
20+
tol = 1e-6, verbose = false, init = nothing)
21+
nrow, ncol = size(X)
22+
centroids = init == nothing ? smart_init(X, k, n_threads, init=k_init).centroids : deepcopy(init)
23+
24+
converged = false
25+
niters = 1
26+
J_previous = 0.0
27+
28+
# Update centroids & labels with closest members until convergence
29+
while niters <= max_iters
30+
@parallelize n_threads ncol chunk_update_centroids(alg, containers, centroids, X)
31+
collect_containers(alg, containers, centroids, n_threads)
32+
J = sum(containers.J)
33+
34+
if verbose
35+
# Show progress and terminate if J stopped decreasing.
36+
println("Iteration $niters: Jclust = $J")
37+
end
38+
39+
# Check for convergence
40+
if (niters > 1) & (abs(J - J_previous) < (tol * J))
41+
converged = true
42+
break
43+
end
44+
45+
J_previous = J
46+
niters += 1
47+
end
48+
49+
@parallelize n_threads ncol sum_of_squares(containers, X, containers.labels, centroids)
50+
totalcost = sum(containers.sum_of_squares)
51+
52+
# Terminate algorithm with the assumption that K-means has converged
53+
if verbose & converged
54+
println("Successfully terminated with convergence.")
55+
end
56+
57+
# TODO empty placeholder vectors should be calculated
58+
# TODO Float64 type definitions is too restrictive, should be relaxed
59+
# especially during GPU related development
60+
return KmeansResult(centroids, containers.labels, Float64[], Int[], Float64[], totalcost, niters, converged)
61+
end
62+
863
kmeans(design_matrix, k;
964
n_threads = Threads.nthreads(),
1065
k_init = "k-means++", max_iters = 300, tol = 1e-6,
@@ -17,56 +72,70 @@ kmeans(design_matrix, k;
1772
1873
Internal function for the creation of all necessary intermediate structures.
1974
20-
- `new_centroids` - container which holds new positions of centroids
75+
- `centroids_new` - container which holds new positions of centroids
2176
- `centroids_cnt` - container which holds number of points for each centroid
2277
- `labels` - vector which holds labels of corresponding points
2378
"""
2479
function create_containers(::Lloyd, k, nrow, ncol, n_threads)
25-
new_centroids = Vector{Array{Float64, 2}}(undef, n_threads)
26-
centroids_cnt = Vector{Vector{Int}}(undef, n_threads)
80+
lng = n_threads + 1
81+
centroids_new = Vector{Array{Float64,2}}(undef, lng)
82+
centroids_cnt = Vector{Vector{Int}}(undef, lng)
2783

28-
for i in 1:n_threads
29-
new_centroids[i] = Array{Float64, 2}(undef, nrow, k)
84+
for i in 1:lng
85+
centroids_new[i] = Array{Float64, 2}(undef, nrow, k)
3086
centroids_cnt[i] = Vector{Int}(undef, k)
3187
end
3288

3389
labels = Vector{Int}(undef, ncol)
3490

35-
return (new_centroids = new_centroids, centroids_cnt = centroids_cnt,
36-
labels = labels)
37-
end
91+
J = Vector{Float64}(undef, n_threads)
3892

39-
update_containers!(containers, ::Lloyd, centroids, n_threads) = nothing
93+
# total_sum_calculation
94+
sum_of_squares = Vector{Float64}(undef, n_threads)
4095

41-
function chunk_update_centroids!(centroids, containers, ::Lloyd,
42-
design_matrix, r, idx)
96+
return (centroids_new = centroids_new, centroids_cnt = centroids_cnt,
97+
labels = labels, J = J, sum_of_squares = sum_of_squares)
98+
end
4399

100+
function chunk_update_centroids(::Lloyd, containers, centroids, X, r, idx)
44101
# unpack containers for easier manipulations
45-
new_centroids = containers.new_centroids[idx]
102+
centroids_new = containers.centroids_new[idx]
46103
centroids_cnt = containers.centroids_cnt[idx]
47104
labels = containers.labels
48105

49-
new_centroids .= 0.0
106+
centroids_new .= 0.0
50107
centroids_cnt .= 0
51108
J = 0.0
52109
@inbounds for i in r
53-
min_distance = Inf
110+
min_dist = distance(X, centroids, i, 1)
54111
label = 1
55-
for k in axes(centroids, 2)
56-
distance = 0.0
57-
for j in axes(design_matrix, 1)
58-
distance += (design_matrix[j, i] - centroids[j, k])^2
59-
end
60-
label = min_distance > distance ? k : label
61-
min_distance = min_distance > distance ? distance : min_distance
112+
for j in 2:size(centroids, 2)
113+
dist = distance(X, centroids, i, j)
114+
label = dist < min_dist ? j : label
115+
min_dist = dist < min_dist ? dist : min_dist
62116
end
63117
labels[i] = label
64118
centroids_cnt[label] += 1
65-
for j in axes(design_matrix, 1)
66-
new_centroids[j, label] += design_matrix[j, i]
119+
for j in axes(X, 1)
120+
centroids_new[j, label] += X[j, i]
67121
end
68-
J += min_distance
122+
J += min_dist
69123
end
70124

71-
return J
125+
containers.J[idx] = J
126+
end
127+
128+
function collect_containers(alg::Lloyd, containers, centroids, n_threads)
129+
if n_threads == 1
130+
@inbounds centroids .= containers.centroids_new[1] ./ containers.centroids_cnt[1]'
131+
else
132+
@inbounds containers.centroids_new[end] .= containers.centroids_new[1]
133+
@inbounds containers.centroids_cnt[end] .= containers.centroids_cnt[1]
134+
@inbounds for i in 2:n_threads
135+
containers.centroids_new[end] .+= containers.centroids_new[i]
136+
containers.centroids_cnt[end] .+= containers.centroids_cnt[i]
137+
end
138+
139+
@inbounds centroids .= containers.centroids_new[end] ./ containers.centroids_cnt[end]'
140+
end
72141
end

‎src/mlj_interface.jl

Lines changed: 42 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@ const ParallelKMeans_Desc = "Parallel & lightning fast implementation of all ava
55
# available variants for reference
66
const MLJDICT = Dict(:Lloyd => Lloyd(),
77
:Hamerly => Hamerly(),
8-
:LightElkan => LightElkan())
8+
:Elkan => Elkan())
99

1010
####
1111
#### MODEL DEFINITION
1212
####
1313

14-
mutable struct KMeans <: MLJModelInterface.Unsupervised
14+
mutable struct KMeans <: MMI.Unsupervised
1515
algo::Symbol
1616
k_init::String
1717
k::Int
@@ -24,49 +24,56 @@ mutable struct KMeans <: MLJModelInterface.Unsupervised
2424
end
2525

2626

27-
function KMeans(; algo=:Lloyd, k_init="k-means++",
27+
function KMeans(; algo=:Hamerly, k_init="k-means++",
2828
k=3, tol=1e-6, max_iters=300, copy=true,
2929
threads=Threads.nthreads(), verbosity=0, init=nothing)
3030

3131
model = KMeans(algo, k_init, k, tol, max_iters, copy, threads, verbosity, init)
32-
message = MLJModelInterface.clean!(model)
32+
message = MMI.clean!(model)
3333
isempty(message) || @warn message
3434
return model
3535
end
3636

3737

38-
function MLJModelInterface.clean!(m::KMeans)
39-
warning = ""
38+
function MMI.clean!(m::KMeans)
39+
warning = String[]
4040

4141
if !(m.algo keys(MLJDICT))
42-
warning *= "Unsupported KMeans variant, Defauting to KMeans++ seeding algorithm."
43-
m.algo = :Lloyd
42+
push!(warning, "Unsupported KMeans variant. Defaulting to Hamerly algorithm.")
43+
m.algo = :Hamerly
44+
end
4445

45-
elseif m.k_init != "k-means++"
46-
warning *= "Only `k-means++` or random seeding algorithms are supported. Defaulting to random seeding."
47-
m.k_init = "random"
46+
if !(m.k_init ["k-means++", "random"])
47+
push!(warning, "Only \"k-means++\" or \"random\" seeding algorithms are supported. Defaulting to k-means++ seeding.")
48+
m.k_init = "kmeans++"
49+
end
4850

49-
elseif m.k < 1
50-
warning *= "Number of clusters must be greater than 0. Defaulting to 3 clusters."
51+
if m.k < 1
52+
push!(warning, "Number of clusters must be greater than 0. Defaulting to 3 clusters.")
5153
m.k = 3
54+
end
5255

53-
elseif !(m.tol < 1.0)
54-
warning *= "Tolerance level must be less than 1. Defaulting to tol of 1e-6."
56+
if !(m.tol < 1.0)
57+
push!(warning, "Tolerance level must be less than 1. Defaulting to tol of 1e-6.")
5558
m.tol = 1e-6
59+
end
5660

57-
elseif !(m.max_iters > 0)
58-
warning *= "Number of permitted iterations must be greater than 0. Defaulting to 300 iterations."
61+
if !(m.max_iters > 0)
62+
push!(warning, "Number of permitted iterations must be greater than 0. Defaulting to 300 iterations.")
5963
m.max_iters = 300
64+
end
6065

61-
elseif !(m.threads > 0)
62-
warning *= "Number of threads must be at least 1. Defaulting to all threads available."
66+
if !(m.threads > 0)
67+
push!(warning, "Number of threads must be at least 1. Defaulting to all threads available.")
6368
m.threads = Threads.nthreads()
69+
end
6470

65-
elseif !(m.verbosity (0, 1))
66-
warning *= "Verbosity must be either 0 (no info) or 1 (info requested). Defaulting to 0."
67-
m.verbosity = 0
71+
if !(m.verbosity (0, 1))
72+
push!(warning, "Verbosity must be either 0 (no info) or 1 (info requested). Defaulting to 1.")
73+
m.verbosity = 1
6874
end
69-
return warning
75+
76+
return join(warning, "\n")
7077
end
7178

7279

@@ -78,14 +85,14 @@ end
7885
7986
See also the [package documentation](https://pydatablog.github.io/ParallelKMeans.jl/stable).
8087
"""
81-
function MLJModelInterface.fit(m::KMeans, X)
88+
function MMI.fit(m::KMeans, X)
8289
# convert tabular input data into the matrix model expects. Column assumed as features so input data is permuted
8390
if !m.copy
8491
# permutes dimensions of input table without copying and pass to model
85-
DMatrix = convert(Array{Float64, 2}, MLJModelInterface.matrix(X)')
92+
DMatrix = convert(Array{Float64, 2}, MMI.matrix(X)')
8693
else
8794
# permutes dimensions of input table as a column major matrix from a copy of the data
88-
DMatrix = convert(Array{Float64, 2}, MLJModelInterface.matrix(X, transpose=true))
95+
DMatrix = convert(Array{Float64, 2}, MMI.matrix(X, transpose=true))
8996
end
9097

9198
# lookup available algorithms
@@ -106,7 +113,7 @@ function MLJModelInterface.fit(m::KMeans, X)
106113
end
107114

108115

109-
function MLJModelInterface.fitted_params(model::KMeans, fitresult)
116+
function MMI.fitted_params(model::KMeans, fitresult)
110117
# extract what's relevant from `fitresult`
111118
results, _, _ = fitresult # unpack fitresult
112119
centers = results.centers
@@ -124,15 +131,15 @@ end
124131
#### PREDICT FUNCTION
125132
####
126133

127-
function MLJModelInterface.transform(m::KMeans, fitresult, Xnew)
134+
function MMI.transform(m::KMeans, fitresult, Xnew)
128135
# make predictions/assignments using the learned centroids
129136

130137
if !m.copy
131138
# permutes dimensions of input table without copying and pass to model
132-
DMatrix = convert(Array{Float64, 2}, MLJModelInterface.matrix(Xnew)')
139+
DMatrix = convert(Array{Float64, 2}, MMI.matrix(Xnew)')
133140
else
134141
# permutes dimensions of input table as a column major matrix from a copy of the data
135-
DMatrix = convert(Array{Float64, 2}, MLJModelInterface.matrix(Xnew, transpose=true))
142+
DMatrix = convert(Array{Float64, 2}, MMI.matrix(Xnew, transpose=true))
136143
end
137144

138145
# TODO: Warn users if fitresult is from a `non-converged` fit?
@@ -147,7 +154,7 @@ function MLJModelInterface.transform(m::KMeans, fitresult, Xnew)
147154
centroids = results.centers
148155
distances = Distances.pairwise(Distances.SqEuclidean(), DMatrix, centroids; dims=2)
149156
preds = argmin.(eachrow(distances))
150-
return MLJModelInterface.table(reshape(preds, :, 1), prototype=Xnew)
157+
return MMI.table(reshape(preds, :, 1), prototype=Xnew)
151158
end
152159

153160

@@ -156,7 +163,7 @@ end
156163
####
157164

158165
# TODO 4: metadata for the package and for each of the model interfaces
159-
metadata_pkg.(KMeans,
166+
MMI.metadata_pkg.(KMeans,
160167
name = "ParallelKMeans",
161168
uuid = "42b8e9d4-006b-409a-8472-7f34b3fb58af",
162169
url = "https://github.com/PyDataBlog/ParallelKMeans.jl",
@@ -166,9 +173,9 @@ metadata_pkg.(KMeans,
166173

167174

168175
# Metadata for ParaKMeans model interface
169-
metadata_model(KMeans,
170-
input = MLJModelInterface.Table(MLJModelInterface.Continuous),
171-
output = MLJModelInterface.Table(MLJModelInterface.Count),
176+
MMI.metadata_model(KMeans,
177+
input = MMI.Table(MMI.Continuous),
178+
output = MMI.Table(MMI.Count),
172179
weights = false,
173180
descr = ParallelKMeans_Desc,
174181
path = "ParallelKMeans.KMeans")

‎test/test02_kmeans.jl renamed to ‎test/test02_lloyd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
module TestKMeans
1+
module TestLloyd
22

33
using ParallelKMeans
44
using Test

‎test/test04_elkan.jl

Lines changed: 11 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,64 +1,47 @@
11
module TestElkan
22

33
using ParallelKMeans
4-
using ParallelKMeans: update_containers!
54
using Test
65
using Random
76

8-
@testset "centroid distances" begin
9-
containers = (centroids_dist = Matrix{Float64}(undef, 3, 3), )
10-
centroids = [1.0 2.0 4.0; 2.0 1.0 3.0]
11-
update_containers!(containers, LightElkan(), centroids, 1)
12-
centroids_dist = containers.centroids_dist
13-
@test centroids_dist[1, 2] == centroids_dist[2, 1]
14-
@test centroids_dist[1, 3] == centroids_dist[3, 1]
15-
@test centroids_dist[2, 3] == centroids_dist[3, 2]
16-
@test centroids_dist[1, 2] == 0.5
17-
@test centroids_dist[1, 3] == 2.5
18-
@test centroids_dist[2, 3] == 2.0
19-
@test centroids_dist[1, 1] == 0.5
20-
@test centroids_dist[2, 2] == 0.5
21-
@test centroids_dist[3, 3] == 2.0
22-
end
23-
24-
@testset "basic kmeans" begin
7+
@testset "basic kmeans elkan" begin
258
X = [1. 2. 4.;]
26-
res = kmeans(LightElkan(), X, 1; n_threads = 1, tol = 1e-6, verbose = false)
9+
res = kmeans(Elkan(), X, 1; n_threads = 1, tol = 1e-6, verbose = false)
2710
@test res.assignments == [1, 1, 1]
2811
@test res.centers[1] 2.3333333333333335
2912
@test res.totalcost 4.666666666666666
3013
@test res.converged
3114

32-
res = kmeans(LightElkan(), X, 2; n_threads = 1, init = [1.0 4.0], tol = 1e-6, verbose = false)
15+
res = kmeans(Elkan(), X, 2; n_threads = 1, init = [1.0 4.0], tol = 1e-6, verbose = false)
3316
@test res.assignments == [1, 1, 2]
3417
@test res.centers [1.5 4.0]
3518
@test res.totalcost 0.5
3619
@test res.converged
3720
end
3821

39-
@testset "no convergence yield last result" begin
22+
@testset "elkan no convergence yield last result" begin
4023
X = [1. 2. 4.;]
41-
res = kmeans(LightElkan(), X, 2; n_threads = 1, init = [1.0 4.0], tol = 1e-6, max_iters = 1, verbose = false)
24+
res = kmeans(Elkan(), X, 2; n_threads = 1, init = [1.0 4.0], tol = 1e-6, max_iters = 1, verbose = false)
4225
@test !res.converged
4326
@test res.totalcost 0.5
4427
end
4528

46-
@testset "singlethread linear separation" begin
29+
@testset "elkan singlethread linear separation" begin
4730
Random.seed!(2020)
4831

4932
X = rand(3, 100)
50-
res = kmeans(LightElkan(), X, 3; n_threads = 1, tol = 1e-6, verbose = false)
33+
res = kmeans(Elkan(), X, 3; n_threads = 1, tol = 1e-10, max_iters = 10, verbose = false)
5134

5235
@test res.totalcost 14.16198704459199
53-
@test res.converged
54-
@test res.iterations == 11
36+
@test !res.converged
37+
@test res.iterations == 10
5538
end
5639

57-
@testset "multithread linear separation quasi two threads" begin
40+
@testset "elkan multithread linear separation quasi two threads" begin
5841
Random.seed!(2020)
5942

6043
X = rand(3, 100)
61-
res = kmeans(LightElkan(), X, 3; n_threads = 2, tol = 1e-6, verbose = false)
44+
res = kmeans(Elkan(), X, 3; n_threads = 2, tol = 1e-6, verbose = false)
6245

6346
@test res.totalcost 14.16198704459199
6447
@test res.converged

‎test/test05_hamerly.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module TestHamerly
22

33
using ParallelKMeans
4-
using ParallelKMeans: chunk_initialize!, double_argmax
4+
using ParallelKMeans: chunk_initialize, double_argmax
55
using Test
66
using Random
77

@@ -11,7 +11,7 @@ using Random
1111
nrow, ncol = size(X)
1212
containers = ParallelKMeans.create_containers(Hamerly(), 3, nrow, ncol, 1)
1313

14-
ParallelKMeans.chunk_initialize!(Hamerly(), containers, centroids, X, 1:ncol, 1)
14+
ParallelKMeans.chunk_initialize(Hamerly(), containers, centroids, X, 1:ncol, 1)
1515
@test containers.lb == [18.0, 20.0, 5.0, 5.0]
1616
@test containers.ub == [0.0, 2.0, 0.0, 0.0]
1717
end

‎test/test06_verbose.jl

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using Test
66
using Suppressor
77

88

9-
@testset "Testing verbosity of implementation" begin
9+
@testset "LLoyd: Testing verbosity of implementation" begin
1010
Random.seed!(2020)
1111
X = rand(4, 150)
1212
Random.seed!(2020)
@@ -15,5 +15,22 @@ using Suppressor
1515
@test r == "Iteration 1: Jclust = 46.534795844478815\n"
1616
end
1717

18-
end # module
18+
@testset "Hamerly: Testing verbosity of implementation" begin
19+
Random.seed!(2020)
20+
X = rand(4, 150)
21+
Random.seed!(2020)
22+
# Capture output and compare
23+
r = @capture_out kmeans(Hamerly(), X, 3; n_threads=1, max_iters=1, verbose=true)
24+
@test r == "Iteration 1: Jclust = 46.534795844478815\n"
25+
end
1926

27+
@testset "Elkan: Testing verbosity of implementation" begin
28+
Random.seed!(2020)
29+
X = rand(4, 150)
30+
Random.seed!(2020)
31+
# Capture output and compare
32+
r = @capture_out kmeans(Elkan(), X, 3; n_threads=1, max_iters=1, verbose=true)
33+
@test r == "Iteration 1: Jclust = 46.534795844478815\n"
34+
end
35+
36+
end # module

‎test/test07_mlj_interface.jl

Lines changed: 27 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ using MLJBase
1111
@testset "Test struct construction" begin
1212
model = KMeans()
1313

14-
@test model.algo == :Lloyd
14+
@test model.algo == :Hamerly
1515
@test model.init == nothing
1616
@test model.k == 3
1717
@test model.k_init == "k-means++"
@@ -24,13 +24,13 @@ end
2424

2525

2626
@testset "Test bad struct warings" begin
27-
@test_logs (:warn, "Unsupported KMeans variant, Defauting to KMeans++ seeding algorithm.") ParallelKMeans.KMeans(algo=:Fake)
28-
@test_logs (:warn, "Only `k-means++` or random seeding algorithms are supported. Defaulting to random seeding.") ParallelKMeans.KMeans(k_init="abc")
27+
@test_logs (:warn, "Unsupported KMeans variant. Defaulting to Hamerly algorithm.") ParallelKMeans.KMeans(algo=:Fake)
28+
@test_logs (:warn, "Only \"k-means++\" or \"random\" seeding algorithms are supported. Defaulting to k-means++ seeding.") ParallelKMeans.KMeans(k_init="abc")
2929
@test_logs (:warn, "Number of clusters must be greater than 0. Defaulting to 3 clusters.") ParallelKMeans.KMeans(k=0)
3030
@test_logs (:warn, "Tolerance level must be less than 1. Defaulting to tol of 1e-6.") ParallelKMeans.KMeans(tol=2)
3131
@test_logs (:warn, "Number of permitted iterations must be greater than 0. Defaulting to 300 iterations.") ParallelKMeans.KMeans(max_iters=0)
3232
@test_logs (:warn, "Number of threads must be at least 1. Defaulting to all threads available.") ParallelKMeans.KMeans(threads=0)
33-
@test_logs (:warn, "Verbosity must be either 0 (no info) or 1 (info requested). Defaulting to 0.") ParallelKMeans.KMeans(verbosity=100)
33+
@test_logs (:warn, "Verbosity must be either 0 (no info) or 1 (info requested). Defaulting to 1.") ParallelKMeans.KMeans(verbosity=100)
3434
end
3535

3636

@@ -47,75 +47,62 @@ end
4747
@testset "Test Lloyd model fitting" begin
4848
Random.seed!(2020)
4949
X = table([1 2; 1 4; 1 0; 10 2; 10 4; 10 0])
50-
model = KMeans(k=2)
51-
results = fit(model, X)
52-
53-
@test results[2] == nothing
54-
@test results[end].converged == true
55-
@test results[end].totalcost == 16
56-
end
57-
50+
X_test = table([10 1])
5851

59-
@testset "Test Hamerly model fitting" begin
60-
Random.seed!(2020)
61-
X = table([1 2; 1 4; 1 0; 10 2; 10 4; 10 0])
62-
model = KMeans(algo=:Hamerly, k=2)
52+
model = KMeans(algo = :Lloyd, k=2)
6353
results = fit(model, X)
6454

6555
@test results[2] == nothing
6656
@test results[end].converged == true
6757
@test results[end].totalcost == 16
68-
end
69-
70-
71-
@testset "Test Lloyd fitted params" begin
72-
Random.seed!(2020)
73-
X = table([1 2; 1 4; 1 0; 10 2; 10 4; 10 0])
74-
model = KMeans(k=2)
75-
results = fit(model, X)
7658

7759
params = fitted_params(model, results)
7860
@test params.converged == true
7961
@test params.totalcost == 16
62+
63+
# Use trained model to cluster new data X_test
64+
preds = transform(model, results, X_test)
65+
@test preds[:x1][1] == 2
8066
end
8167

8268

83-
@testset "Test Hamerly fitted params" begin
69+
@testset "Test Hamerly model fitting" begin
8470
Random.seed!(2020)
8571
X = table([1 2; 1 4; 1 0; 10 2; 10 4; 10 0])
72+
X_test = table([10 1])
73+
8674
model = KMeans(algo=:Hamerly, k=2)
8775
results = fit(model, X)
8876

77+
@test results[2] == nothing
78+
@test results[end].converged == true
79+
@test results[end].totalcost == 16
80+
8981
params = fitted_params(model, results)
9082
@test params.converged == true
9183
@test params.totalcost == 16
92-
end
93-
94-
95-
@testset "Test Lloyd transform" begin
96-
Random.seed!(2020)
97-
X = table([1 2; 1 4; 1 0; 10 2; 10 4; 10 0])
98-
X_test = table([10 1])
99-
100-
# Train model using training data X
101-
model = KMeans(k=2)
102-
results = fit(model, X)
10384

10485
# Use trained model to cluster new data X_test
10586
preds = transform(model, results, X_test)
10687
@test preds[:x1][1] == 2
10788
end
10889

109-
110-
@testset "Test Hamerly transform" begin
90+
@testset "Test Elkan model fitting" begin
11191
Random.seed!(2020)
11292
X = table([1 2; 1 4; 1 0; 10 2; 10 4; 10 0])
11393
X_test = table([10 1])
11494

115-
# Train model using training data X
116-
model = KMeans(algo=:Hamerly, k=2)
95+
model = KMeans(algo=:Elkan, k=2)
11796
results = fit(model, X)
11897

98+
@test results[2] == nothing
99+
@test results[end].converged == true
100+
@test results[end].totalcost == 16
101+
102+
params = fitted_params(model, results)
103+
@test params.converged == true
104+
@test params.totalcost == 16
105+
119106
# Use trained model to cluster new data X_test
120107
preds = transform(model, results, X_test)
121108
@test preds[:x1][1] == 2
@@ -133,4 +120,3 @@ end
133120
end
134121

135122
end # module
136-

0 commit comments

Comments
 (0)
Please sign in to comment.