Skip to content

Commit f0e36ea

Browse files
authored
Simplify logic in Diagonal-Tridiagonal multiplication (#1338)
This improves performance in small `::Diagonal * ::Tridiagonal` multiplications: ```julia julia> D = Diagonal(1:3); T = Tridiagonal(1:2, 1:3, 1:2); C = similar(T); julia> @Btime mul!($C, $D, $T); 43.435 ns (0 allocations: 0 bytes) # master 14.725 ns (0 allocations: 0 bytes) # this PR ``` This also improves TTFX ```julia julia> @time mul!(C, D, T); 0.192970 seconds (451.55 k allocations: 22.211 MiB, 99.99% compilation time) # master 0.129910 seconds (273.38 k allocations: 13.351 MiB, 99.98% compilation time) # this PR ```
1 parent 7f354f4 commit f0e36ea

File tree

1 file changed

+10
-22
lines changed

1 file changed

+10
-22
lines changed

src/bidiag.jl

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1186,39 +1186,27 @@ function _dibimul!(C, A, B, _add)
11861186
end
11871187
function _dibimul_nonzeroalpha!(C, A, B, _add)
11881188
n = size(A,1)
1189-
if n <= 3
1190-
# For simplicity, use a naive multiplication for small matrices
1191-
# that loops over all elements.
1192-
for I in CartesianIndices(C)
1193-
C[I] += _add(A.diag[I[1]] * B[I[1], I[2]])
1194-
end
1195-
return C
1196-
end
11971189
Ad = A.diag
11981190
Bl = _diag(B, -1)
11991191
Bd = _diag(B, 0)
12001192
Bu = _diag(B, 1)
12011193
@inbounds begin
12021194
# first row of C
1203-
C[1,1] += _add(A[1,1]*B[1,1])
1204-
C[1,2] += _add(A[1,1]*B[1,2])
1205-
# second row of C
1206-
C[2,1] += _add(A[2,2]*B[2,1])
1207-
C[2,2] += _add(A[2,2]*B[2,2])
1208-
C[2,3] += _add(A[2,2]*B[2,3])
1209-
for j in 3:n-2
1195+
C[1,1] += _add(Ad[1]*Bd[1])
1196+
if n >= 2
1197+
C[1,2] += _add(Ad[1]*Bu[1])
1198+
end
1199+
for j in 2:n-1
12101200
Ajj = Ad[j]
12111201
C[j, j-1] += _add(Ajj*Bl[j-1])
12121202
C[j, j ] += _add(Ajj*Bd[j])
12131203
C[j, j+1] += _add(Ajj*Bu[j])
12141204
end
1215-
# row before last of C
1216-
C[n-1,n-2] += _add(A[n-1,n-1]*B[n-1,n-2])
1217-
C[n-1,n-1] += _add(A[n-1,n-1]*B[n-1,n-1])
1218-
C[n-1,n ] += _add(A[n-1,n-1]*B[n-1,n ])
1219-
# last row of C
1220-
C[n,n-1] += _add(A[n,n]*B[n,n-1])
1221-
C[n,n ] += _add(A[n,n]*B[n,n ])
1205+
if n >= 2
1206+
# last row of C
1207+
C[n,n-1] += _add(Ad[n]*Bl[n-1])
1208+
C[n,n ] += _add(Ad[n]*Bd[n])
1209+
end
12221210
end # inbounds
12231211
C
12241212
end

0 commit comments

Comments
 (0)