Skip to content

Commit 88dba5d

Browse files
authored
setindex! with BandIndex (#1259)
Constant-propagation of the band index would allow eliminating the branches in `setindex!` for structured matrices. These parallel the similar `getindex` definitions that already exist.
1 parent 8e6dcfb commit 88dba5d

File tree

6 files changed

+98
-1
lines changed

6 files changed

+98
-1
lines changed

src/bidiag.jl

+13
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,19 @@ end
178178
return A
179179
end
180180

181+
@inline function setindex!(A::Bidiagonal, x, b::BandIndex)
182+
@boundscheck checkbounds(A, b)
183+
if b.band == 0
184+
@inbounds A.dv[b.index] = x
185+
elseif b.band (-1,1) && b.band == _offdiagind(A.uplo)
186+
@inbounds A.ev[b.index] = x
187+
elseif !iszero(x)
188+
throw(ArgumentError(LazyString(lazy"cannot set entry $(to_indices(A, (b,))) off the ",
189+
A.uplo == 'U' ? "upper" : "lower", " bidiagonal band to a nonzero value ", x)))
190+
end
191+
return A
192+
end
193+
181194
Base._reverse(A::Bidiagonal, dims) = reverse!(Matrix(A); dims)
182195
Base._reverse(A::Bidiagonal, ::Colon) = Bidiagonal(reverse(A.dv), reverse(A.ev), A.uplo == 'U' ? :L : :U)
183196

src/diagonal.jl

+10-1
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ zeroslike(::Type{M}, sz::Tuple{Integer, Vararg{Integer}}) where {M<:AbstractMatr
218218
r
219219
end
220220

221-
function setindex!(D::Diagonal, v, i::Int, j::Int)
221+
@inline function setindex!(D::Diagonal, v, i::Int, j::Int)
222222
@boundscheck checkbounds(D, i, j)
223223
if i == j
224224
@inbounds D.diag[i] = v
@@ -228,6 +228,15 @@ function setindex!(D::Diagonal, v, i::Int, j::Int)
228228
return D
229229
end
230230

231+
@inline function setindex!(D::Diagonal, v, b::BandIndex)
232+
@boundscheck checkbounds(D, b)
233+
if b.band == 0
234+
@inbounds D.diag[b.index] = v
235+
elseif !iszero(v)
236+
throw(ArgumentError(lazy"cannot set off-diagonal entry $(to_indices(D, (b,))) to a nonzero value ($v)"))
237+
end
238+
return D
239+
end
231240

232241
## structured matrix methods ##
233242
function Base.replace_in_print_matrix(A::Diagonal,i::Integer,j::Integer,s::AbstractString)

src/tridiag.jl

+26
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,17 @@ Base._reverse!(A::SymTridiagonal, dims::Colon) = (reverse!(A.dv); reverse!(A.ev)
508508
return A
509509
end
510510

511+
@inline function setindex!(A::SymTridiagonal, x, b::BandIndex)
512+
@boundscheck checkbounds(A, b)
513+
if b.band == 0
514+
issymmetric(x) || throw(ArgumentError("cannot set a diagonal entry of a SymTridiagonal to an asymmetric value"))
515+
@inbounds A.dv[b.index] = x
516+
else
517+
throw(ArgumentError(lazy"cannot set off-diagonal entry $(to_indices(A, (b,)))"))
518+
end
519+
return A
520+
end
521+
511522
## Tridiagonal matrices ##
512523
struct Tridiagonal{T,V<:AbstractVector{T}} <: AbstractMatrix{T}
513524
dl::V # sub-diagonal
@@ -775,6 +786,21 @@ end
775786
return A
776787
end
777788

789+
@inline function setindex!(A::Tridiagonal, x, b::BandIndex)
790+
@boundscheck checkbounds(A, b)
791+
if b.band == 0
792+
@inbounds A.d[b.index] = x
793+
elseif b.band == -1
794+
@inbounds A.dl[b.index] = x
795+
elseif b.band == 1
796+
@inbounds A.du[b.index] = x
797+
elseif !iszero(x)
798+
throw(ArgumentError(LazyString(lazy"cannot set entry $(to_indices(A, (b,))) off ",
799+
lazy"the tridiagonal band to a nonzero value ($x)")))
800+
end
801+
return A
802+
end
803+
778804
## structured matrix methods ##
779805
function Base.replace_in_print_matrix(A::Tridiagonal,i::Integer,j::Integer,s::AbstractString)
780806
i==j-1||i==j||i==j+1 ? s : Base.replace_with_centered_mark(s)

test/bidiag.jl

+17
Original file line numberDiff line numberDiff line change
@@ -1205,4 +1205,21 @@ end
12051205
@test rmul!(B, D) == B2
12061206
end
12071207

1208+
@testset "setindex! with BandIndex" begin
1209+
B = Bidiagonal(zeros(3), zeros(2), :U)
1210+
B[LinearAlgebra.BandIndex(0,2)] = 1
1211+
@test B[2,2] == 1
1212+
B[LinearAlgebra.BandIndex(1,1)] = 2
1213+
@test B[1,2] == 2
1214+
@test_throws "cannot set entry $((1,3)) off the upper bidiagonal band" B[LinearAlgebra.BandIndex(2,1)] = 2
1215+
1216+
B = Bidiagonal(zeros(3), zeros(2), :L)
1217+
B[LinearAlgebra.BandIndex(-1,1)] = 2
1218+
@test B[2,1] == 2
1219+
@test_throws "cannot set entry $((3,1)) off the lower bidiagonal band" B[LinearAlgebra.BandIndex(-2,1)] = 2
1220+
1221+
@test_throws BoundsError B[LinearAlgebra.BandIndex(size(B,1),1)]
1222+
@test_throws BoundsError B[LinearAlgebra.BandIndex(0,size(B,1)+1)]
1223+
end
1224+
12081225
end # module TestBidiagonal

test/diagonal.jl

+9
Original file line numberDiff line numberDiff line change
@@ -1489,4 +1489,13 @@ end
14891489
@test !isreal(im*D)
14901490
end
14911491

1492+
@testset "setindex! with BandIndex" begin
1493+
D = Diagonal(zeros(2))
1494+
D[LinearAlgebra.BandIndex(0,2)] = 1
1495+
@test D[2,2] == 1
1496+
@test_throws "cannot set off-diagonal entry $((1,2))" D[LinearAlgebra.BandIndex(1,1)] = 1
1497+
@test_throws BoundsError D[LinearAlgebra.BandIndex(size(D,1),1)]
1498+
@test_throws BoundsError D[LinearAlgebra.BandIndex(0,size(D,1)+1)]
1499+
end
1500+
14921501
end # module TestDiagonal

test/tridiag.jl

+23
Original file line numberDiff line numberDiff line change
@@ -1184,4 +1184,27 @@ end
11841184
@test convert(SymTridiagonal, S) == S
11851185
end
11861186

1187+
@testset "setindex! with BandIndex" begin
1188+
T = Tridiagonal(zeros(3), zeros(4), zeros(3))
1189+
T[LinearAlgebra.BandIndex(0,2)] = 1
1190+
@test T[2,2] == 1
1191+
T[LinearAlgebra.BandIndex(1,2)] = 2
1192+
@test T[2,3] == 2
1193+
T[LinearAlgebra.BandIndex(-1,2)] = 3
1194+
@test T[3,2] == 3
1195+
1196+
@test_throws "cannot set entry $((1,3)) off the tridiagonal band" T[LinearAlgebra.BandIndex(2,1)] = 1
1197+
@test_throws "cannot set entry $((3,1)) off the tridiagonal band" T[LinearAlgebra.BandIndex(-2,1)] = 1
1198+
@test_throws BoundsError T[LinearAlgebra.BandIndex(size(T,1),1)]
1199+
@test_throws BoundsError T[LinearAlgebra.BandIndex(0,size(T,1)+1)]
1200+
1201+
S = SymTridiagonal(zeros(4), zeros(3))
1202+
S[LinearAlgebra.BandIndex(0,2)] = 1
1203+
@test S[2,2] == 1
1204+
1205+
@test_throws "cannot set off-diagonal entry $((1,3))" S[LinearAlgebra.BandIndex(2,1)] = 1
1206+
@test_throws BoundsError S[LinearAlgebra.BandIndex(size(S,1),1)]
1207+
@test_throws BoundsError S[LinearAlgebra.BandIndex(0,size(S,1)+1)]
1208+
end
1209+
11871210
end # module TestTridiagonal

0 commit comments

Comments
 (0)