Skip to content

Commit 01fa766

Browse files
authored
MLJ interface (#140)
1 parent f477a5a commit 01fa766

File tree

19 files changed

+511
-85
lines changed

19 files changed

+511
-85
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
# v0.16.4
2+
3+
Add [MLJ.jl](https://github.com/alan-turing-institute/MLJ.jl) support.
4+
15
# v0.16.3
26

37
New function: `midlife`.

Project.toml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Ripserer"
22
uuid = "aa79e827-bd0b-42a8-9f10-2b302677a641"
33
authors = ["mtsch <[email protected]>"]
4-
version = "0.16.3"
4+
version = "0.16.4"
55

66
[deps]
77
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
@@ -11,6 +11,7 @@ Future = "9fa8497b-333b-5362-9e8d-4d0656e87820"
1111
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
1212
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
1313
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
14+
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
1415
MiniQhull = "978d7f02-9e05-4691-894f-ae31a51d76ca"
1516
PersistenceDiagrams = "90b4794c-894b-4756-a0f8-5efeb5ddf7ae"
1617
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
@@ -25,8 +26,9 @@ DataStructures = "0.17, 0.18"
2526
Distances = "0.8, 0.9, 0.10"
2627
IterTools = "1"
2728
LightGraphs = "1.3.3"
29+
MLJModelInterface = "^0.3.5"
2830
MiniQhull = "0.2"
29-
PersistenceDiagrams = "^0.8.2"
31+
PersistenceDiagrams = "0.9"
3032
ProgressMeter = "1"
3133
RecipesBase = "1"
3234
StaticArrays = "0.12, 1"
@@ -36,10 +38,11 @@ julia = "1"
3638
[extras]
3739
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3840
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
41+
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
3942
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
4043
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
4144
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
4245
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4346

4447
[targets]
45-
test = ["Aqua", "Documenter", "Random", "SafeTestsets", "Suppressor", "Test"]
48+
test = ["Aqua", "Documenter", "MLJBase", "Random", "SafeTestsets", "Suppressor", "Test"]

docs/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@ Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
33
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
44
GLMNet = "8d5ece8b-de18-5317-b113-243142960cc6"
55
GR = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71"
6+
ImageIO = "82e4d734-157c-48bb-816b-45c225c6df19"
67
ImageMagick = "6218d12a-5da1-5696-b52f-db25d2ecc6d1"
78
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
89
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
10+
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
11+
MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661"
912
MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411"
1013
PersistenceDiagrams = "90b4794c-894b-4756-a0f8-5efeb5ddf7ae"
1114
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"

docs/src/api.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,20 @@ Ripserer.Chain
114114
Mod
115115
```
116116

117+
## MLJ.jl Interface
118+
119+
```@docs
120+
Ripserer.RipsPersistentHomology
121+
```
122+
123+
```@docs
124+
Ripserer.AlphaPersistentHomology
125+
```
126+
127+
```@docs
128+
Ripserer.CubicalPersistentHomology
129+
```
130+
117131
## Experimental Features
118132

119133
```@docs

docs/src/examples/cubical.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ curve_plot = plot(curve; legend=false, title="Curve")
2424
# will use a small, 240×240 pixel version of the image. Ripserer should have no problems
2525
# with processing larger images, but this will work well enough for this tutorial.
2626

27-
blackhole_image = load(joinpath(
28-
@__DIR__, "../assets/data/240px-Black_hole_-_Messier_87_crop_max_res.jpg"
29-
))
27+
blackhole_image = load(
28+
joinpath(@__DIR__, "../assets/data/240px-Black_hole_-_Messier_87_crop_max_res.jpg")
29+
)
3030
blackhole_plot = plot(blackhole_image; title="Black Hole")
3131

3232
# To use the image with Ripserer, we have to convert it to grayscale.

docs/src/examples/malaria.jl

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
# # Image Classification With Cubical Filtrations and Persistence Images
1+
# # Image Classification With Cubical Persistent Homology
22

33
# In this example, we will show how to use Ripserer in an image classification
44
# context. Persistent homology is not a predictive algorithm, but it can be used to extract
55
# useful features from data.
66

7+
# ## Setting up
8+
79
using Ripserer
810
using PersistenceDiagrams
911
using Images # also required: ImageIO to read .png files
@@ -104,6 +106,8 @@ plot(plot(dim_1[end]; persistence=true), heatmap(image_1(dim_1[end]); aspect_rat
104106

105107
persims = [[vec(image_0(dim_0[i])); vec(image_1(dim_1[i]))] for i in 1:length(diagrams)]
106108

109+
# ## Fitting A Model
110+
107111
# Now it's time to fit our model. We will use
108112
# [GLMNet.jl](https://github.com/JuliaStats/GLMNet.jl) to fit a regularized linear model.
109113

@@ -137,7 +141,7 @@ nothing; # hide
137141

138142
# Get the classification accuracy.
139143

140-
accuracy = count(predictions .== test_y) / length(test_y)
144+
count(predictions .== test_y) / length(test_y)
141145

142146
# Not half bad considering we haven't touched the images and we left pretty much all
143147
# settings on default.
@@ -158,3 +162,52 @@ plot(
158162

159163
# These correspond to the area we identified at the beginning. Also note that in this case,
160164
# the classifier does not care about ``H_1`` at all.
165+
166+
# ## Using MLJ
167+
168+
# Another, more straightforward way to execute a similar pipeline is to use Ripserer's
169+
# [MLJ.jl](https://github.com/alan-turing-institute/MLJ.jl) integration. We will use a
170+
# random forest classifier for this example.
171+
172+
# We start by loading MLJ and the classifier. Note that
173+
# [MLJDecisionTreeInterface.jl](https://github.com/bensadeghi/DecisionTree.jl) needs to be
174+
# installed for this to work.
175+
176+
using MLJ
177+
tree = @load RandomForestClassifier pkg = "DecisionTree" verbosity = 0
178+
179+
# We create a pipeline of `CubicalPersistentHomology` followed by the classifier. In this
180+
# case, `CubicalPersistentHomology` takes care of both the homology computation and the
181+
# conversion to persistence images.
182+
183+
pipe = @pipeline(CubicalPersistentHomology(), tree)
184+
185+
# We train the pipeline the same way you would fit any other MLJ model. Remember, we need to
186+
# use grayscale versions of images stored in `inputs`.
187+
188+
classes = coerce(classes, Binary)
189+
train, test = partition(eachindex(classes), 0.7; shuffle=true, rng=1337)
190+
mach = machine(pipe, inputs, classes)
191+
fit!(mach; rows=train)
192+
193+
# Next, we predict the classes on the test data and print out the classification accuracy.
194+
195+
yhat = predict_mode(mach, inputs[test])
196+
accuracy(yhat, classes[test])
197+
198+
# The result is quite a bit worse than before. We can try mitigating that by using a
199+
# different vectorizer.
200+
201+
pipe.cubical_persistent_homology.vectorizer = PersistenceCurveVectorizer()
202+
mach = machine(pipe, inputs, classes)
203+
fit!(mach; rows=train)
204+
205+
yhat = predict_mode(mach, inputs[test])
206+
accuracy(yhat, classes[test])
207+
208+
# The result could be improved further by choosing a different model and
209+
# vectorizer. However, this is just a short introduction. Please see the [MLJ.jl
210+
# documentation](https://alan-turing-institute.github.io/MLJ.jl/dev/) for more information
211+
# on model tuning and selection, and the [PersistenceDiagrams.jl
212+
# documentation](https://mtsch.github.io/PersistenceDiagrams.jl/dev/mlj/) for a list of
213+
# vectorizers and their options.

docs/src/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ Ripserer and its companion package
4848
* Various persistence diagram vectorization functions, implemented with persistence images
4949
and persistence curves.
5050
* Easy extensibility through a documented API.
51+
* Integration with [MLJ.jl](https://github.com/alan-turing-institute/MLJ.jl).
5152
* Experimental shortest representative cycle computation.
5253
* Experimental sparse circular coordinate computation.
5354

src/Ripserer.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ using RecipesBase
2525
using StaticArrays
2626
using TupleTools
2727

28+
import MLJModelInterface
29+
2830
# This functionality is imported to avoid having to deal with name clashes. There is no
2931
# piracy involved here.
3032
import LightGraphs: vertices, edges, nv, adjacency_matrix
@@ -51,7 +53,10 @@ export Mod,
5153
ripserer,
5254
reconstruct_cycle,
5355
Partition,
54-
CircularCoordinates
56+
CircularCoordinates,
57+
RipsPersistentHomology,
58+
AlphaPersistentHomology,
59+
CubicalPersistentHomology
5560

5661
include("base/primefield.jl")
5762
include("base/abstractcell.jl")
@@ -77,5 +82,6 @@ include("filtrations/edgecollapse.jl")
7782

7883
include("extra/cycles.jl")
7984
include("extra/circularcoordinates.jl")
85+
include("extra/mlj.jl")
8086

8187
end

src/extra/cycles.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,13 @@ function reconstruct_cycle(
135135
distances=distance_matrix(filtration),
136136
) where {T}
137137
if !hasproperty(interval, :representative)
138-
throw(ArgumentError("interval has no representative! Run `ripserer` with `reps=true`"))
138+
throw(
139+
ArgumentError("interval has no representative! Run `ripserer` with `reps=true`")
140+
)
139141
elseif !(eltype(interval.representative) <: AbstractChainElement{<:AbstractCell{1}})
140-
throw(ArgumentError("cycles can only be reconstructed for 1-dimensional intervals."))
142+
throw(
143+
ArgumentError("cycles can only be reconstructed for 1-dimensional intervals.")
144+
)
141145
elseif !(birth(interval) ≤ _birth_or_value(r) < death(interval))
142146
return simplex_type(filtration, 1)[]
143147
else

0 commit comments

Comments
 (0)