Skip to content

Commit ace8c16

Browse files
authored
Implement generic_mul! for Julia 1.12 (#639)
1 parent a2ca0a6 commit ace8c16

File tree

2 files changed

+55
-4
lines changed

2 files changed

+55
-4
lines changed

src/host/linalg.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,21 @@ end
417417
function LinearAlgebra.generic_matmatmul_wrapper!(C::AbstractGPUMatrix{T}, tA::AbstractChar, tB::AbstractChar, A::AbstractGPUVecOrMat{T}, B::AbstractGPUVecOrMat{T}, alpha::Number, beta::Number, val::LinearAlgebra.BlasFlag.SyrkHerkGemm) where {T}
418418
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, alpha, beta)
419419
end
420+
# Julia 1.12 introduced generic_mul! for scalar * array operations
421+
function LinearAlgebra.generic_mul!(C::AbstractGPUVecOrMat, X::AbstractGPUVecOrMat, s::Number, alpha::Number, beta::Number)
422+
if length(C) != length(X)
423+
throw(DimensionMismatch(lazy"first array has length $(length(C)) which does not match the length of the second, $(length(X))."))
424+
end
425+
@. C = X * s * alpha + C * beta
426+
return C
427+
end
428+
function LinearAlgebra.generic_mul!(C::AbstractGPUVecOrMat, s::Number, X::AbstractGPUVecOrMat, alpha::Number, beta::Number)
429+
if length(C) != length(X)
430+
throw(DimensionMismatch(lazy"first array has length $(length(C)) which does not match the length of the second, $(length(X))."))
431+
end
432+
@. C = s * X * alpha + C * beta
433+
return C
434+
end
420435
end
421436

422437
function generic_trimatmul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Function, A::AbstractGPUMatrix{T}, B::AbstractGPUVecOrMat{S}) where {T,S,R}
@@ -730,7 +745,7 @@ function LinearAlgebra.rotate!(x::AbstractGPUArray, y::AbstractGPUArray, c::Numb
730745
@inbounds xi = x[i]
731746
@inbounds yi = y[i]
732747
@inbounds x[i] = s*yi + c *xi
733-
@inbounds y[i] = c*yi - conj(s)*xi
748+
@inbounds y[i] = c*yi - conj(s)*xi
734749
end
735750
rotate_kernel!(get_backend(x))(x, y, c, s; ndrange = size(x))
736751
return x, y

test/testsuite/linalg.jl

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -256,17 +256,17 @@
256256
B = Diagonal(b)
257257
A = Diagonal(a)
258258
mul!(C, A, B)
259-
@test collect(C.diag) collect(A.diag) .* collect(B.diag)
259+
@test collect(C.diag) collect(A.diag) .* collect(B.diag)
260260
a = AT(diagm(rand(elty, n)))
261261
b = AT(diagm(rand(elty, n)))
262262
C = Diagonal(d)
263263
mul!(C, a, b)
264-
@test collect(C) Diagonal(collect(a) * collect(b))
264+
@test collect(C) Diagonal(collect(a) * collect(b))
265265
a = transpose(AT(diagm(rand(elty, n))))
266266
b = adjoint(AT(diagm(rand(elty, n))))
267267
C = Diagonal(d)
268268
mul!(C, a, b)
269-
@test collect(C) Diagonal(collect(a) * collect(b))
269+
@test collect(C) Diagonal(collect(a) * collect(b))
270270
end
271271
end
272272

@@ -303,6 +303,42 @@
303303
end
304304
end
305305

306+
@testset "mul! + UniformScaling" begin
307+
for elty in (Float32, ComplexF32)
308+
n = 128
309+
s = rand(elty)
310+
I_s = UniformScaling(s)
311+
312+
# Test vector operations
313+
a = AT(rand(elty, n))
314+
b = AT(rand(elty, n))
315+
b_copy = copy(b)
316+
317+
# Test mul!(a, I*s, b) - should compute a = s * b
318+
mul!(a, I_s, b)
319+
@test collect(a) s .* collect(b_copy)
320+
321+
# Test mul!(a, b, s) - should compute a = b * s
322+
a = AT(rand(elty, n))
323+
mul!(a, b, s)
324+
@test collect(a) collect(b_copy) .* s
325+
326+
# Test matrix operations
327+
A = AT(rand(elty, n, n))
328+
B = AT(rand(elty, n, n))
329+
B_copy = copy(B)
330+
331+
# Test mul!(A, I*s, B)
332+
mul!(A, I_s, B)
333+
@test collect(A) s .* collect(B_copy)
334+
335+
# Test mul!(A, B, s)
336+
A = AT(rand(elty, n, n))
337+
mul!(A, B, s)
338+
@test collect(A) collect(B_copy) .* s
339+
end
340+
end
341+
306342
@testset "lmul! and rmul!" for (a,b) in [((3,4),(4,3)), ((3,), (1,3)), ((1,3), (3))], T in eltypes
307343
@test compare(rmul!, AT, rand(T, a), Ref(rand(T)))
308344
@test compare(lmul!, AT, Ref(rand(T)), rand(T, b))

0 commit comments

Comments
 (0)