Commit b4fd26c

Extend threaded macro to use shared memory
1 parent: a2bf168

File tree: 6 files changed (+600 lines, -48 lines)

Project.toml

Lines changed: 3 additions & 1 deletion

@@ -7,6 +7,7 @@ version = "0.6.7"
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
 LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
+StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

 [weakdeps]
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
@@ -17,9 +18,10 @@ ClimaCommsCUDAExt = "CUDA"
 ClimaCommsMPIExt = "MPI"

 [compat]
-CUDA = "3, 4, 5"
 Adapt = "3, 4"
+CUDA = "3, 4, 5"
 Logging = "1.9.4"
 LoggingExtras = "1.1.0"
 MPI = "0.20.18"
+StaticArrays = "1.9"
 julia = "1.9"
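For reference, CUDA and MPI stay in [weakdeps], so the CUDA code in this commit lives in an extension that only loads when CUDA.jl is present; StaticArrays joins as a hard dependency. A minimal sketch of that loading behavior (an illustrative session, not part of the commit):

```julia
using ClimaComms   # loads ClimaComms and its hard deps, now including StaticArrays
using CUDA         # loading the weak dep activates the ClimaCommsCUDAExt extension
device = ClimaComms.CUDADevice()   # device type whose methods the extension defines
```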

docs/Manifest.toml

Lines changed: 23 additions & 6 deletions

@@ -1,6 +1,6 @@
 # This file is machine-generated - editing it directly is not advised

-julia_version = "1.11.0"
+julia_version = "1.11.4"
 manifest_format = "2.0"
 project_hash = "d60839f726bd9115791d1a0807a21b61938765a9"

@@ -19,13 +19,11 @@ deps = ["LinearAlgebra", "Requires"]
 git-tree-sha1 = "50c3c56a52972d78e8be9fd135bfb91c9574c140"
 uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 version = "4.1.1"
+weakdeps = ["StaticArrays"]

     [deps.Adapt.extensions]
     AdaptStaticArraysExt = "StaticArrays"

-    [deps.Adapt.weakdeps]
-    StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
-
 [[deps.ArgTools]]
 uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
 version = "1.1.2"
@@ -39,10 +37,10 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
 version = "1.11.0"

 [[deps.ClimaComms]]
-deps = ["Adapt", "Logging", "LoggingExtras"]
+deps = ["Adapt", "Logging", "LoggingExtras", "StaticArrays"]
 path = ".."
 uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
-version = "0.6.5"
+version = "0.6.7"

     [deps.ClimaComms.extensions]
     ClimaCommsCUDAExt = "CUDA"
@@ -361,6 +359,25 @@ version = "1.11.0"
 uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
 version = "1.11.0"

+[[deps.StaticArrays]]
+deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"]
+git-tree-sha1 = "0feb6b9031bd5c51f9072393eb5ab3efd31bf9e4"
+uuid = "90137ffa-7385-5640-81b9-e52037218182"
+version = "1.9.13"
+
+    [deps.StaticArrays.extensions]
+    StaticArraysChainRulesCoreExt = "ChainRulesCore"
+    StaticArraysStatisticsExt = "Statistics"
+
+    [deps.StaticArrays.weakdeps]
+    ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+    Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+
+[[deps.StaticArraysCore]]
+git-tree-sha1 = "192954ef1208c7019899fbf8049e717f92959682"
+uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
+version = "1.4.3"
+
 [[deps.StyledStrings]]
 uuid = "f489334b-da3d-4c2e-b8f0-e476e12c162b"
 version = "1.11.0"

docs/src/apis.md

Lines changed: 6 additions & 1 deletion

@@ -28,13 +28,18 @@ ClimaComms.device
 ClimaComms.device_functional
 ClimaComms.array_type
 ClimaComms.allowscalar
-ClimaComms.@threaded
 ClimaComms.@time
 ClimaComms.@elapsed
 ClimaComms.@assert
 ClimaComms.@sync
 ClimaComms.@cuda_sync
 Adapt.adapt_structure(::Type{<:AbstractArray}, ::ClimaComms.AbstractDevice)
+ClimaComms.@threaded
+ClimaComms.@interdependent
+ClimaComms.InterdependentIteratorData
+ClimaComms.@sync_interdependent
+ClimaComms.synchronize_gpu_threads
+ClimaComms.static_shared_memory_array
 ```

 ## Contexts
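Of the six new docstring entries, the CUDA extension in this commit (next file) pins down concrete signatures for two: `ClimaComms.synchronize_gpu_threads(device)` and `ClimaComms.static_shared_memory_array(device, T, dims...)`. Below is a hedged sketch of a kernel body combining them; the enclosing `@threaded`/`@interdependent` launch syntax is not shown in this commit excerpt, so everything other than those two calls is an assumption:

```julia
# Sketch only: a block-local sum built from the two new device functions.
# `block_index`, `thread_index`, and `n_threads` are hypothetical stand-ins
# for whatever iteration state the @threaded/@interdependent macros supply.
function block_sum!(out, vals, device, block_index, thread_index, n_threads)
    # One statically sized scratch buffer per block, shared by its threads
    # (dimensions must be compile-time constants, as with CuStaticSharedArray).
    scratch = ClimaComms.static_shared_memory_array(device, Float64, 256)
    scratch[thread_index] = vals[thread_index, block_index]
    # Barrier: every thread in the block writes before any thread reads.
    ClimaComms.synchronize_gpu_threads(device)
    if thread_index == 1
        total = zero(eltype(scratch))
        for i in 1:n_threads
            total += scratch[i]
        end
        out[block_index] = total
    end
    return nothing
end
```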

ext/ClimaCommsCUDAExt.jl

Lines changed: 165 additions & 18 deletions

@@ -5,6 +5,7 @@ import CUDA
 import Adapt
 import ClimaComms
 import ClimaComms: CUDADevice, threaded
+import ClimaComms: OneInterdependentItem, MultipleInterdependentItems

 function ClimaComms._assign_device(::CUDADevice, rank_number)
     CUDA.device!(rank_number % CUDA.ndevices())
@@ -50,17 +51,29 @@ ClimaComms.elapsed(f::F, ::CUDADevice, args...; kwargs...) where {F} =
 ClimaComms.assert(::CUDADevice, cond::C, text::T) where {C, T} =
     isnothing(text) ? (CUDA.@cuassert cond()) : (CUDA.@cuassert cond() text())

-# TODO: Generalize all of the following code to multi-dimensional thread blocks
-# and multiple iterators.
+ClimaComms.synchronize_gpu_threads(::CUDADevice) = CUDA.sync_threads()

-# The number of threads in the kernel being executed by the calling thread.
-threads_in_kernel() = CUDA.blockDim().x * CUDA.gridDim().x
+ClimaComms.static_shared_memory_array(
+    ::CUDADevice,
+    ::Type{T},
+    dims...,
+) where {T} = CUDA.CuStaticSharedArray(T, dims)
+
+# Number of blocks in kernel being executed and index of calling thread's block.
+blocks_in_kernel() = CUDA.gridDim().x
+block_idx_in_kernel() = CUDA.blockIdx().x

-# The index of the calling thread, which is between 1 and threads_in_kernel().
-thread_index() =
+# Number of threads in each block of kernel being executed and index of calling
+# thread within its block.
+threads_in_block() = CUDA.blockDim().x
+thread_idx_in_block() = CUDA.threadIdx().x
+
+# Total number of threads in kernel being executed and index of calling thread.
+threads_in_kernel() = CUDA.blockDim().x * CUDA.gridDim().x
+thread_idx_in_kernel() =
     (CUDA.blockIdx().x - 1) * CUDA.blockDim().x + CUDA.threadIdx().x

-# The maximum number of blocks that can fit on the GPU used for this kernel.
+# Maximum number of blocks that can fit on the GPU used for this kernel.
 grid_size_limit(kernel) = CUDA.attribute(
     CUDA.device(kernel.fun.mod.ctx),
     CUDA.DEVICE_ATTRIBUTE_MAX_GRID_DIM_X,
@@ -73,22 +86,22 @@ block_size_limit(max_threads_in_block::Int, _) = max_threads_in_block
 block_size_limit(::Val{:auto}, kernel) =
     CUDA.launch_configuration(kernel.fun).threads

-function threaded(f::F, ::CUDADevice, ::Val, itr; block_size) where {F}
+function threaded(f::F, device::CUDADevice, ::Val, itr; block_size) where {F}
     length(itr) > 0 || return nothing
     Base.require_one_based_indexing(itr)

-    function call_f_once_from_thread()
-        item_index = thread_index()
-        item_index <= length(itr) && @inbounds f(itr[item_index])
+    function thread_function()
+        itr_index = thread_idx_in_kernel()
+        itr_index <= length(itr) && @inbounds f(itr[itr_index])
         return nothing
     end
-    kernel = CUDA.@cuda launch=false call_f_once_from_thread()
+    kernel = CUDA.@cuda launch=false thread_function()
     max_blocks = grid_size_limit(kernel)
     max_threads_in_block = block_size_limit(block_size, kernel)

     # If there are too many items, coarsen by the smallest possible amount.
     length(itr) <= max_blocks * max_threads_in_block ||
-        return threaded(f, CUDADevice(), 1, itr)
+        return threaded(f, device, 1, itr; block_size)

     threads_in_block = min(max_threads_in_block, length(itr))
     blocks = cld(length(itr), threads_in_block)
@@ -102,17 +115,18 @@ function threaded(
     itr;
     block_size,
 ) where {F}
-    min_items_in_thread > 0 || throw(ArgumentError("`coarsen` is not positive"))
+    min_items_in_thread > 0 ||
+        throw(ArgumentError("integer `coarsen` value must be positive"))
     length(itr) > 0 || return nothing
     Base.require_one_based_indexing(itr)

     # Maximize memory coalescing with a "grid-stride loop"; for reference, see
     # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops
-    call_f_multiple_times_from_thread() =
-        for item_index in thread_index():threads_in_kernel():length(itr)
-            @inbounds f(itr[item_index])
+    coarsened_thread_function() =
+        for itr_index in thread_idx_in_kernel():threads_in_kernel():length(itr)
+            @inbounds f(itr[itr_index])
         end
-    kernel = CUDA.@cuda launch=false call_f_multiple_times_from_thread()
+    kernel = CUDA.@cuda launch=false coarsened_thread_function()
     max_blocks = grid_size_limit(kernel)
     max_threads_in_block = block_size_limit(block_size, kernel)

@@ -129,4 +143,137 @@ function threaded(
     CUDA.@sync kernel(; blocks, threads = threads_in_block)
 end

+function threaded(
+    f::F,
+    device::CUDADevice,
+    ::Union{Val, NTuple{2, Val}},
+    independent_itr,
+    interdependent_itr;
+    block_size,
+) where {F}
+    length(independent_itr) > 0 || return nothing
+    length(interdependent_itr) > 0 || return nothing
+    Base.require_one_based_indexing(independent_itr)
+    Base.require_one_based_indexing(interdependent_itr)
+
+    function two_itr_thread_function()
+        block_index = block_idx_in_kernel()
+        thread_index = thread_idx_in_block()
+        (
+            block_index <= length(independent_itr) &&
+            thread_index <= length(interdependent_itr)
+        ) && @inbounds f(
+            independent_itr[block_index],
+            OneInterdependentItem(interdependent_itr[thread_index], device),
+        )
+        return nothing
+    end
+    kernel = CUDA.@cuda launch=false two_itr_thread_function()
+    max_blocks = grid_size_limit(kernel)
+    max_threads_in_block = block_size_limit(block_size, kernel)
+
+    # If there are too many items, coarsen by the smallest possible amount.
+    (
+        length(independent_itr) <= max_blocks &&
+        length(interdependent_itr) <= max_threads_in_block
+    ) || return threaded(
+        f,
+        device,
+        (1, 1),
+        independent_itr,
+        interdependent_itr;
+        block_size,
+    )
+
+    blocks = length(independent_itr)
+    threads_in_block = length(interdependent_itr)
+    CUDA.@sync kernel(; blocks, threads = threads_in_block)
+end
+
+# Use a default coarsen value of 1 for either iterator when a value is needed.
+threaded(
+    f::F,
+    device::CUDADevice,
+    min_independent_items_in_thread::Int,
+    independent_itr,
+    interdependent_itr;
+    block_size,
+) where {F} = threaded(
+    f,
+    device,
+    (min_independent_items_in_thread, 1),
+    independent_itr,
+    interdependent_itr;
+    block_size,
+)
+threaded(
+    f::F,
+    device::CUDADevice,
+    min_items_in_thread::Tuple{Val, Int},
+    independent_itr,
+    interdependent_itr;
+    block_size,
+) where {F} = threaded(
+    f,
+    device,
+    (1, min_items_in_thread[2]),
+    independent_itr,
+    interdependent_itr;
+    block_size,
+)
+
+function threaded(
+    f::F,
+    device::CUDADevice,
+    min_items_in_thread::NTuple{2, Int},
+    independent_itr,
+    interdependent_itr;
+    block_size,
+) where {F}
+    (min_items_in_thread[1] > 0 && min_items_in_thread[2] > 0) ||
+        throw(ArgumentError("all integer `coarsen` values must be positive"))
+    length(independent_itr) > 0 || return nothing
+    length(interdependent_itr) > 0 || return nothing
+    Base.require_one_based_indexing(independent_itr)
+    Base.require_one_based_indexing(interdependent_itr)
+
+    # Maximize memory coalescing with a "grid-stride loop" (reference is above).
+    function coarsened_two_itr_thread_function()
+        independent_itr_indices =
+            block_idx_in_kernel():blocks_in_kernel():length(independent_itr)
+        interdependent_itr_indices =
+            thread_idx_in_block():threads_in_block():length(interdependent_itr)
+        for independent_itr_index in independent_itr_indices
+            @inbounds f(
+                independent_itr[independent_itr_index],
+                MultipleInterdependentItems(
+                    interdependent_itr,
+                    interdependent_itr_indices,
+                    device,
+                ),
+            )
+        end
+    end
+    kernel = CUDA.@cuda launch=false coarsened_two_itr_thread_function()
+    max_blocks = grid_size_limit(kernel)
+    max_threads_in_block = block_size_limit(block_size, kernel)
+
+    # If there are too many items to use the specified coarsening, increase it
+    # by the smallest possible amount.
+    max_required_blocks = cld(length(independent_itr), min_items_in_thread[1])
+    max_required_threads_in_block =
+        cld(length(interdependent_itr), min_items_in_thread[2])
+    items_in_thread = (
+        max_required_blocks <= max_blocks ? min_items_in_thread[1] :
+        cld(length(independent_itr), max_blocks),
+        max_required_threads_in_block <= max_threads_in_block ?
+        min_items_in_thread[2] :
+        cld(length(interdependent_itr), max_threads_in_block),
+    )
+
+    blocks = cld(length(independent_itr), items_in_thread[1])
+    threads_in_block = cld(length(interdependent_itr), items_in_thread[2])
+    CUDA.@sync kernel(; blocks, threads = threads_in_block)
+end
+
 end
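The two-iterator methods above map the independent iterator onto blocks and the interdependent iterator onto threads within a block, which is what allows `f` to coordinate interdependent items through shared memory and barriers. A standalone CUDA.jl illustration of that launch pattern (invented for this sketch, not the package's code):

```julia
using CUDA

# One block per column (independent items), one thread per row (interdependent
# items); threads in a block cooperate via static shared memory and
# sync_threads(), the two primitives the extension wraps above.
function column_sum_kernel!(out, data, n_rows)
    col = CUDA.blockIdx().x     # independent item handled by this block
    row = CUDA.threadIdx().x    # interdependent item handled by this thread
    scratch = CUDA.CuStaticSharedArray(Float32, 256)
    scratch[row] = row <= n_rows ? data[row, col] : 0.0f0
    CUDA.sync_threads()         # what ClimaComms.synchronize_gpu_threads wraps
    if row == 1                 # one thread per block combines the results
        total = 0.0f0
        for i in 1:n_rows
            total += scratch[i]
        end
        out[col] = total
    end
    return nothing
end

data = CUDA.rand(Float32, 200, 1000)
out = CUDA.zeros(Float32, 1000)
# blocks = length of the independent iterator, threads = length of the
# interdependent iterator, mirroring the uncoarsened `threaded` method above.
CUDA.@cuda blocks=1000 threads=200 column_sum_kernel!(out, data, 200)
```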
