Skip to content

Commit 997f747

Browse files
committed
Updated MLJ Interface with predict function
1 parent 461c3d6 commit 997f747

File tree

2 files changed

+71
-45
lines changed

2 files changed

+71
-45
lines changed

src/mlj_interface.jl

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -94,15 +94,20 @@ function MMI.fit(m::KMeans, verbosity::Int, X)
9494

9595
# fit model and get results
9696
verbose = verbosity > 0 # Display fitting operations if verbosity > 0
97-
fitresult = ParallelKMeans.kmeans(algo, DMatrix, m.k;
97+
result = ParallelKMeans.kmeans(algo, DMatrix, m.k;
9898
n_threads = m.threads, k_init=m.k_init,
9999
max_iters=m.max_iters, tol=m.tol, init=m.init,
100100
verbose=verbose)
101101

102+
cluster_labels = MMI.categorical(1:m.k)
103+
fitresult = (result.centers, cluster_labels, result.converged)
102104
cache = nothing
103-
report = (cluster_centers=fitresult.centers, iterations=fitresult.iterations,
104-
converged=fitresult.converged, totalcost=fitresult.totalcost,
105-
labels=fitresult.assignments)
105+
106+
report = (cluster_centers=result.centers, iterations=result.iterations,
107+
converged=result.converged, totalcost=result.totalcost,
108+
assignments=result.assignments, labels=cluster_labels)
109+
110+
106111
"""
107112
# TODO: warn users about non convergence
108113
if verbose & (!fitresult.converged)
@@ -114,16 +119,8 @@ end
114119

115120

116121
function MMI.fitted_params(model::KMeans, fitresult)
117-
# extract what's relevant from `fitresult`
118-
results, _, _ = fitresult # unpack fitresult
119-
centers = results.centers
120-
converged = results.converged
121-
iters = results.iterations
122-
totalcost = results.totalcost
123-
124-
# then return as a NamedTuple
125-
return (cluster_centers = centers, totalcost = totalcost,
126-
iterations = iters, converged = converged)
122+
# Centroids
123+
return (cluster_centers = fitresult[1], )
127124
end
128125

129126

@@ -143,21 +140,37 @@ function MMI.transform(m::KMeans, fitresult, Xnew)
143140
end
144141

145142
# Warn users if fitresult is from a `non-converged` fit
146-
if !fitresult[end].converged
143+
if !(fitresult[end])
147144
@warn "Failed to converge. Using last assignments to make transformations."
148145
end
149146

150-
# results from fitted model
151-
results = fitresult[1]
152-
153147
# use centroid matrix to assign clusters for new data
154-
centroids = results.centers
148+
centroids = fitresult[1]
155149
distances = Distances.pairwise(Distances.SqEuclidean(), DMatrix, centroids; dims=2)
156-
preds = argmin.(eachrow(distances))
157-
return MMI.table(reshape(preds, :, 1), prototype=Xnew)
150+
#preds = argmin.(eachrow(distances))
151+
return MMI.table(distances, prototype=Xnew)
158152
end
159153

160154

155+
function MMI.predict(m::KMeans, fitresult, Xnew)
156+
locations, cluster_labels, _ = fitresult
157+
158+
Xarray = MMI.matrix(Xnew)
159+
(n, p), k = size(Xarray), m.k
160+
161+
pred = zeros(Int, n)
162+
@inbounds for i 1:n
163+
minv = Inf
164+
for j 1:k
165+
curv = Distances.evaluate(Distances.SqEuclidean(), view(Xarray, i, :), view(locations, :, j))
166+
P = curv < minv
167+
pred[i] = j * P + pred[i] * !P # if P is true --> j
168+
minv = curv * P + minv * !P # if P is true --> curvalue
169+
end
170+
end
171+
return cluster_labels[pred]
172+
end
173+
161174
####
162175
#### METADATA
163176
####
@@ -176,6 +189,7 @@ MMI.metadata_pkg.(KMeans,
176189
MMI.metadata_model(KMeans,
177190
input = MMI.Table(MMI.Continuous),
178191
output = MMI.Table(MMI.Continuous),
192+
target = AbstractArray{<:MMI.Multiclass},
179193
weights = false,
180194
descr = ParallelKMeans_Desc,
181195
path = "ParallelKMeans.KMeans")

test/test07_mlj_interface.jl

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -48,19 +48,23 @@ end
4848
X_test = table([10 1])
4949

5050
model = KMeans(algo = :Lloyd, k=2)
51-
results = fit(model, 0, X)
51+
results, cache, report = fit(model, 0, X)
5252

53-
@test results[2] == nothing
54-
@test results[end].converged == true
55-
@test results[end].totalcost == 16
53+
@test cache == nothing
54+
@test report.converged == true
55+
@test report.totalcost == 16
5656

5757
params = fitted_params(model, results)
58-
@test params.converged == true
59-
@test params.totalcost == 16
58+
@test params.cluster_centers == [1.0 10.0; 2.0 2.0]
6059

6160
# Use trained model to cluster new data X_test
6261
preds = transform(model, results, X_test)
63-
@test preds[:x1][1] == 2
62+
@test preds[:x1][1] == 82.0
63+
@test preds[:x2][1] == 1.0
64+
65+
# Make predictions on new data X_test with fitted params
66+
yhat = predict(model, results, X_test)
67+
@test yhat[1] == 2
6468
end
6569

6670

@@ -69,20 +73,24 @@ end
6973
X = table([1 2; 1 4; 1 0; 10 2; 10 4; 10 0])
7074
X_test = table([10 1])
7175

72-
model = KMeans(algo=:Hamerly, k=2)
73-
results = fit(model, 0, X)
76+
model = KMeans(algo = :Hamerly, k=2)
77+
results, cache, report = fit(model, 0, X)
7478

75-
@test results[2] == nothing
76-
@test results[end].converged == true
77-
@test results[end].totalcost == 16
79+
@test cache == nothing
80+
@test report.converged == true
81+
@test report.totalcost == 16
7882

7983
params = fitted_params(model, results)
80-
@test params.converged == true
81-
@test params.totalcost == 16
84+
@test params.cluster_centers == [1.0 10.0; 2.0 2.0]
8285

8386
# Use trained model to cluster new data X_test
8487
preds = transform(model, results, X_test)
85-
@test preds[:x1][1] == 2
88+
@test preds[:x1][1] == 82.0
89+
@test preds[:x2][1] == 1.0
90+
91+
# Make predictions on new data X_test with fitted params
92+
yhat = predict(model, results, X_test)
93+
@test yhat[1] == 2
8694
end
8795

8896

@@ -91,20 +99,24 @@ end
9199
X = table([1 2; 1 4; 1 0; 10 2; 10 4; 10 0])
92100
X_test = table([10 1])
93101

94-
model = KMeans(algo=:Elkan, k=2)
95-
results = fit(model, 0, X)
102+
model = KMeans(algo = :Elkan, k=2)
103+
results, cache, report = fit(model, 0, X)
96104

97-
@test results[2] == nothing
98-
@test results[end].converged == true
99-
@test results[end].totalcost == 16
105+
@test cache == nothing
106+
@test report.converged == true
107+
@test report.totalcost == 16
100108

101109
params = fitted_params(model, results)
102-
@test params.converged == true
103-
@test params.totalcost == 16
110+
@test params.cluster_centers == [1.0 10.0; 2.0 2.0]
104111

105112
# Use trained model to cluster new data X_test
106113
preds = transform(model, results, X_test)
107-
@test preds[:x1][1] == 2
114+
@test preds[:x1][1] == 82.0
115+
@test preds[:x2][1] == 1.0
116+
117+
# Make predictions on new data X_test with fitted params
118+
yhat = predict(model, results, X_test)
119+
@test yhat[1] == 2
108120
end
109121

110122

@@ -114,7 +126,7 @@ end
114126
X_test = table([10 1])
115127

116128
model = KMeans(k=2, max_iters=1)
117-
results = fit(model, 0, X)
129+
results, cache, report = fit(model, 0, X)
118130

119131
@test_logs (:warn, "Failed to converge. Using last assignments to make transformations.") transform(model, results, X_test)
120132
end

0 commit comments

Comments
 (0)