From cdbc2f804de97727dec7dc641861f11c04f52eee Mon Sep 17 00:00:00 2001
From: Tim Besard
Date: Mon, 20 Feb 2023 10:37:03 +0100
Subject: [PATCH] Avoid cartesian iteration where possible.

---
 src/host/broadcast.jl | 47 +++++++++++++++++++++++++++++++++++--------
 src/host/math.jl      |  2 +-
 2 files changed, 40 insertions(+), 9 deletions(-)

diff --git a/src/host/broadcast.jl b/src/host/broadcast.jl
index 9a44fb83..0ba3bf96 100644
--- a/src/host/broadcast.jl
+++ b/src/host/broadcast.jl
@@ -51,22 +51,53 @@ end
     bc′ = Broadcast.preprocess(dest, bc)
 
     # grid-stride kernel
-    function broadcast_kernel(ctx, dest, bc′, nelem)
-        i = 0
-        while i < nelem
-            i += 1
-            I = @cartesianidx(dest, i)
-            @inbounds dest[I] = bc′[I]
+    function broadcast_kernel(ctx, dest, ::Val{Is}, bc′, nelem) where Is
+        j = 0
+        while j < nelem
+            j += 1
+
+            i = @linearidx(dest, j)
+
+            # cartesian indexing is slow, so avoid it if possible
+            if isa(IndexStyle(dest), IndexCartesian) || isa(IndexStyle(bc′), IndexCartesian)
+                # this performs an integer division, which is expensive. to make it possible
+                # for the compiler to optimize it away, we put the iterator in the type
+                # domain so that the indices are available at compile time. note that LLVM
+                # only seems to replace pow2 divisions (with bitshifts), but other back-ends
+                # may be smarter and replace arbitrary divisions by bit operations.
+                #
+                # also see maleadt/StaticCartesian.jl, which implements this in Julia,
+                # but does not result in an additional speed-up on tested back-ends.
+                #
+                # in addition, we use @inbounds to avoid bounds checks, but we also need to
+                # inform the compiler about the bounds that we are assuming. this is done
+                # using the assume intrinsic, and in case of Metal yields an 8x speed-up.
+                assume(1 <= i <= length(Is))
+                I = @inbounds Is[i]
+            end
+
+            val = if isa(IndexStyle(bc′), IndexCartesian)
+                @inbounds bc′[I]
+            else
+                @inbounds bc′[i]
+            end
+
+            if isa(IndexStyle(dest), IndexCartesian)
+                @inbounds dest[I] = val
+            else
+                @inbounds dest[i] = val
+            end
         end
         return
     end
     elements = length(dest)
     elements_per_thread = typemax(Int)
-    heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, bc′, 1;
+    Is = CartesianIndices(dest)
+    heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, Val(Is), bc′, 1;
                                  elements, elements_per_thread)
     config = launch_configuration(backend(dest), heuristic;
                                   elements, elements_per_thread)
-    gpu_call(broadcast_kernel, dest, bc′, config.elements_per_thread;
+    gpu_call(broadcast_kernel, dest, Val(Is), bc′, config.elements_per_thread;
              threads=config.threads, blocks=config.blocks)
 
     return dest
diff --git a/src/host/math.jl b/src/host/math.jl
index 8d02c97f..cf455d31 100644
--- a/src/host/math.jl
+++ b/src/host/math.jl
@@ -2,7 +2,7 @@
 
 function Base.clamp!(A::AnyGPUArray, low, high)
     gpu_call(A, low, high) do ctx, A, low, high
-        I = @cartesianidx A
+        I = @linearidx A
         A[I] = clamp(A[I], low, high)
         return
     end
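
Note (illustration, not part of the patch): the kernel comment relies on the fact that
lifting CartesianIndices(dest) into the type domain via Val turns the array extents into
compile-time constants, so the integer divisions behind linear-to-cartesian conversion
can be strength-reduced by the compiler. Below is a minimal plain-Julia sketch of that
effect, using hypothetical helper names (runtime_index, typedomain_index) and omitting
the GPU-side assume intrinsic used in the kernel above:

    using Test

    # Hypothetical helpers for illustration: convert a linear index to a cartesian
    # index, once with the iterator passed as a runtime value and once with it
    # lifted into the type domain via Val (as the patched kernel does).
    runtime_index(Is, i) = @inbounds Is[i]
    typedomain_index(::Val{Is}, i) where {Is} = @inbounds Is[i]

    Is = CartesianIndices((4, 8))  # pow2 extents, so divisions can become bit shifts

    # Both compute the same cartesian index for a given linear index.
    @test runtime_index(Is, 7) == typedomain_index(Val(Is), 7) == CartesianIndex(3, 2)

    # Comparing @code_llvm runtime_index(Is, 7) against
    # @code_llvm typedomain_index(Val(Is), 7) should show the integer division
    # replaced by shift/mask arithmetic in the type-domain version, which is the
    # optimization the kernel comment describes.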