-
Notifications
You must be signed in to change notification settings - Fork 75
Introducing @reduce for group level reduction #379
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
fbc0ced
addedwarp and block reduce
brabreda 899bc98
Merge branch 'JuliaGPU:release-0.8' into release-0.8
brabreda a34ac1a
added op to groupreduce
brabreda 66163f4
add reduction macro
brabreda ab3d6d1
add reduction macro
brabreda db024ed
added reduce file
brabreda File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
struct Config{ | ||
THREADS_PER_WARP, # size of warp | ||
THREADS_PER_BLOCK # size of blocks | ||
} | ||
end | ||
|
||
@inline function Base.getproperty(conf::Type{Config{ THREADS_PER_WARP, THREADS_PER_BLOCK}}, sym::Symbol) where { THREADS_PER_WARP, THREADS_PER_BLOCK} | ||
if sym == :threads_per_warp | ||
THREADS_PER_WARP | ||
elseif sym == :threads_per_block | ||
THREADS_PER_BLOCK | ||
else | ||
# fallback for nothing | ||
getfield(conf, sym) | ||
end | ||
end | ||
|
||
# TODO: make variable block size possible | ||
# TODO: figure out where to place this | ||
# reduction functionality for a group | ||
@inline function __reduce(__ctx__ , op, val, neutral, ::Type{T}) where {T} | ||
threads = KernelAbstractions.@groupsize()[1] | ||
threadIdx = KernelAbstractions.@index(Local) | ||
|
||
# shared mem for a complete reduction | ||
shared = KernelAbstractions.@localmem(T, 1024) | ||
@inbounds shared[threadIdx] = val | ||
|
||
# perform the reduction | ||
d = 1 | ||
while d < threads | ||
KernelAbstractions.@synchronize() | ||
index = 2 * d * (threadIdx-1) + 1 | ||
@inbounds if index <= threads | ||
other_val = if index + d <= threads | ||
shared[index+d] | ||
else | ||
neutral | ||
end | ||
shared[index] = op(shared[index], other_val) | ||
end | ||
d *= 2 | ||
end | ||
|
||
# load the final value on the first thread | ||
if threadIdx == 1 | ||
val = @inbounds shared[threadIdx] | ||
end | ||
|
||
# every thread will return the reduced value of the group | ||
return val | ||
end |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.