Skip to content

Commit fd115f4

Browse files
authored
Fewer MulAddMul branches in Diagonal-triangular mul (#1272)
This reduces TTFX (time-to-first-X latency) in `Diagonal`-triangular multiplications.

```julia
julia> using Random, LinearAlgebra

julia> A = rand(4,4); U = UpperTriangular(A); D = Diagonal(A);

julia> @time D * U;
  0.131110 seconds (239.39 k allocations: 12.162 MiB, 99.95% compilation time) # master
  0.102569 seconds (227.44 k allocations: 11.472 MiB, 99.94% compilation time) # this PR
```

If the `Diagonal` is on the right, the TTFX is almost identical, but allocations go down slightly.

```julia
julia> @time U * D;
  0.125025 seconds (221.82 k allocations: 11.252 MiB, 99.95% compilation time) # master
  0.127002 seconds (215.59 k allocations: 10.938 MiB, 12.06% gc time, 99.95% compilation time) # this PR
```
1 parent e30c9c3 commit fd115f4

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

src/diagonal.jl

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -470,19 +470,19 @@ function __muldiag_nonzeroalpha!(out, D::Diagonal, B::UpperOrLowerTriangular, al
470470
for j in axes(B, 2)
471471
# store the diagonal separately for unit triangular matrices
472472
if isunit
473-
@inbounds @stable_muladdmul _modify!(MulAddMul(alpha,beta), D.diag[j] * B[j,j], out, (j,j))
473+
@inbounds _modify_nonzeroalpha!(D.diag[j] * B[j,j], out, (j,j), alpha, beta)
474474
end
475475
# The indices of out corresponding to the stored indices of B
476476
rowrange = _rowrange_tri_stored(B, j)
477477
@inbounds @simd for i in rowrange
478-
@stable_muladdmul _modify!(MulAddMul(alpha,beta), D.diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j))
478+
_modify_nonzeroalpha!(D.diag[i] * B_maybeparent[i,j], out_maybeparent, (i,j), alpha, beta)
479479
end
480480
# Fill the indices of out corresponding to the zeros of B
481481
# we only fill these if out and B don't have matching zeros
482482
if !_has_matching_zeros(out, B)
483483
rowrange = _rowrange_tri_zeros(B, j)
484484
@inbounds @simd for i in rowrange
485-
@stable_muladdmul _modify!(MulAddMul(alpha,beta), D.diag[i] * B[i,j], out, (i,j))
485+
_modify_nonzeroalpha!(D.diag[i] * B[i,j], out, (i,j), alpha, beta)
486486
end
487487
end
488488
end
@@ -511,7 +511,7 @@ function __muldiag_nonzeroalpha!(out, A::UpperOrLowerTriangular, D::Diagonal, al
511511
# we may directly read and write from the parents
512512
out_maybeparent, A_maybeparent = _has_matching_zeros(out, A) ? (parent(out), parent(A)) : (out, A)
513513
for j in axes(A, 2)
514-
dja = @stable_muladdmul MulAddMul(alpha,false)(@inbounds D.diag[j])
514+
dja = @inbounds _djalpha_nonzero(D.diag[j], alpha)
515515
# store the diagonal separately for unit triangular matrices
516516
if isunit
517517
# since alpha is multiplied to the diagonal element of D,
@@ -547,7 +547,7 @@ end
547547
d2 = D2.diag
548548
outd = out.diag
549549
@inbounds @simd for i in eachindex(d1, d2, outd)
550-
@stable_muladdmul _modify!(MulAddMul(alpha,beta), d1[i] * d2[i], outd, i)
550+
_modify_nonzeroalpha!(d1[i] * d2[i], outd, i, alpha, beta)
551551
end
552552
return out
553553
end

0 commit comments

Comments
 (0)