Skip to content

Commit 8c48120

Browse files
authored
[Nonlinear] add amontoison's IntDisjointSet type and remove DataStructures.jl (#2885)
1 parent 034a527 commit 8c48120

File tree

6 files changed

+86
-27
lines changed

6 files changed

+86
-27
lines changed

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ version = "1.46.0"
66
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
77
CodecBzip2 = "523fee87-0ab8-5b00-afb7-3ecf72e48cfd"
88
CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"
9-
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
109
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1110
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
1211
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -23,7 +22,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2322
BenchmarkTools = "1"
2423
CodecBzip2 = "0.6, 0.7, 0.8"
2524
CodecZlib = "0.6, 0.7"
26-
DataStructures = "0.18, 0.19"
2725
ForwardDiff = "0.10, 1"
2826
JSON3 = "1"
2927
JSONSchema = "1"

src/FileFormats/MPS/MPS.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ module MPS
88

99
import ..FileFormats
1010
import MathOptInterface as MOI
11-
import DataStructures: OrderedDict
11+
import OrderedCollections: OrderedDict
1212

1313
const IndicatorLessThanTrue{T} =
1414
MOI.Indicator{MOI.ACTIVATE_ON_ONE,MOI.LessThan{T}}

src/Nonlinear/ReverseAD/Coloring/Coloring.jl

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66

77
module Coloring
88

9-
import DataStructures
10-
9+
include("IntDisjointSet.jl")
1110
include("topological_sort.jl")
1211

1312
"""
@@ -154,7 +153,7 @@ function _prevent_cycle(
154153
forbiddenColors,
155154
color,
156155
)
157-
er = DataStructures.find_root!(S, e_idx2)
156+
er = _find_root!(S, e_idx2)
158157
@inbounds first = firstVisitToTree[er]
159158
p = first.source # but this depends on the order?
160159
q = first.target
@@ -172,29 +171,18 @@ function _grow_star(v, w, e_idx, firstNeighbor, color, S)
172171
@inbounds if p != v
173172
firstNeighbor[color[w]] = _Edge(e_idx, v, w)
174173
else
175-
union!(S, e_idx, e.index)
174+
_root_union!(S, e_idx, e.index)
176175
end
177176
return
178177
end
179178

180-
function _merge_trees(eg, eg1, S)
181-
e1 = DataStructures.find_root!(S, eg)
182-
e2 = DataStructures.find_root!(S, eg1)
183-
if e1 != e2
184-
union!(S, eg, eg1)
179+
function _merge_trees(S::_IntDisjointSet, eg::Int, eg1::Int)
180+
if _find_root!(S, eg) != _find_root!(S, eg1)
181+
_root_union!(S, eg, eg1)
185182
end
186183
return
187184
end
188185

189-
# Work-around a deprecation in [email protected]
190-
function _IntDisjointSet(n)
191-
@static if isdefined(DataStructures, :IntDisjointSet)
192-
return DataStructures.IntDisjointSet(n)
193-
else
194-
return DataStructures.IntDisjointSets(n) # COV_EXCL_LINE
195-
end
196-
end
197-
198186
"""
199187
acyclic_coloring(g::UndirectedGraph)
200188
@@ -214,7 +202,6 @@ function acyclic_coloring(g::UndirectedGraph)
214202
firstNeighbor = _Edge[]
215203
firstVisitToTree = fill(_Edge(0, 0, 0), _num_edges(g))
216204
color = fill(0, _num_vertices(g))
217-
# disjoint set forest of edges in the graph
218205
S = _IntDisjointSet(_num_edges(g))
219206
@inbounds for v in 1:_num_vertices(g)
220207
n_neighbor = _num_neighbors(v, g)
@@ -293,7 +280,7 @@ function acyclic_coloring(g::UndirectedGraph)
293280
continue
294281
end
295282
if color[x] == color[v]
296-
_merge_trees(e_idx, e2_idx, S)
283+
_merge_trees(S, e_idx, e2_idx)
297284
end
298285
end
299286
end
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright (c) 2017: Miles Lubin and contributors
2+
# Copyright (c) 2017: Google Inc.
3+
# Copyright (c) 2024: Guillaume Dalle and Alexis Montoison
4+
#
5+
# Use of this source code is governed by an MIT-style license that can be found
6+
# in the LICENSE.md file or at https://opensource.org/licenses/MIT.
7+
8+
# The code in this file was taken from
9+
# https://github.com/gdalle/SparseMatrixColorings.jl/blob/main/src/Forest.jl
10+
#
11+
# It was copied at the suggestion of Alexis in his JuMP-dev 2025 talk.
12+
#
13+
# @odow made minor changes to match MOI coding styles.
14+
#
15+
# x-ref https://github.com/gdalle/SparseMatrixColorings.jl/pull/190
16+
17+
mutable struct _IntDisjointSet
18+
# current number of distinct trees in the S
19+
number_of_trees::Int
20+
# vector storing the index of a parent in the tree for each edge, used in
21+
# union-find operations
22+
parents::Vector{Int}
23+
# vector approximating the depth of each tree to optimize path compression
24+
ranks::Vector{Int}
25+
26+
_IntDisjointSet(n::Integer) = new(n, collect(1:n), zeros(Int, n))
27+
end
28+
29+
function _find_root!(S::_IntDisjointSet, x::Integer)
30+
p = S.parents[x]
31+
if S.parents[p] != p
32+
S.parents[x] = p = _find_root!(S, p)
33+
end
34+
return p
35+
end
36+
37+
function _root_union!(S::_IntDisjointSet, x::Int, y::Int)
38+
rank1, rank2 = S.ranks[x], S.ranks[y]
39+
if rank1 < rank2
40+
x, y = y, x
41+
elseif rank1 == rank2
42+
S.ranks[x] += 1
43+
end
44+
S.parents[y] = x
45+
S.number_of_trees -= 1
46+
return
47+
end

test/FileFormats/MPS/MPS.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using Test
1010

1111
import MathOptInterface as MOI
1212
import MathOptInterface.FileFormats: MPS
13-
import DataStructures: OrderedDict
13+
import OrderedCollections: OrderedDict
1414

1515
function runtests()
1616
for name in names(@__MODULE__; all = true)

test/Nonlinear/ReverseAD.jl

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ import LinearAlgebra
1111
import MathOptInterface as MOI
1212
import SparseArrays
1313

14-
const Nonlinear = MOI.Nonlinear
15-
const ReverseAD = Nonlinear.ReverseAD
16-
const Coloring = ReverseAD.Coloring
14+
import MathOptInterface.Nonlinear
15+
import MathOptInterface.Nonlinear.ReverseAD
16+
import MathOptInterface.Nonlinear.ReverseAD.Coloring
1717

1818
function runtests()
1919
for name in names(@__MODULE__; all = true)
@@ -1421,6 +1421,33 @@ function test_hessian_reinterpret_unsafe()
14211421
return
14221422
end
14231423

1424+
function test_IntDisjointSet()
1425+
for case in [
1426+
[(1, 2) => [1, 1, 3], (1, 3) => [1, 1, 1]],
1427+
[(1, 2) => [1, 1, 3], (3, 1) => [1, 1, 1]],
1428+
[(2, 1) => [2, 2, 3], (1, 3) => [2, 2, 2]],
1429+
[(2, 1) => [2, 2, 3], (3, 1) => [3, 2, 3]],
1430+
[(1, 3) => [1, 2, 1], (2, 3) => [1, 2, 2]],
1431+
[(1, 3) => [1, 2, 1], (3, 2) => [1, 1, 1]],
1432+
[(3, 1) => [3, 2, 3], (2, 3) => [3, 3, 3]],
1433+
[(3, 1) => [3, 2, 3], (3, 2) => [3, 3, 3]],
1434+
[(2, 3) => [1, 2, 2], (1, 3) => [1, 2, 1]],
1435+
[(2, 3) => [1, 2, 2], (3, 1) => [2, 2, 2]],
1436+
[(3, 2) => [1, 3, 3], (1, 3) => [3, 3, 3]],
1437+
[(3, 2) => [1, 3, 3], (3, 1) => [3, 3, 3]],
1438+
]
1439+
S = Coloring._IntDisjointSet(3)
1440+
@test Coloring._find_root!.((S,), [1, 2, 3]) == [1, 2, 3]
1441+
@test S.number_of_trees == 3
1442+
for (i, (union, result)) in enumerate(case)
1443+
Coloring._root_union!(S, union[1], union[2])
1444+
@test Coloring._find_root!.((S,), [1, 2, 3]) == result
1445+
@test S.number_of_trees == 3 - i
1446+
end
1447+
end
1448+
return
1449+
end
1450+
14241451
end # module
14251452

14261453
TestReverseAD.runtests()

0 commit comments

Comments
 (0)