Some minor kernel updates #11

Merged · 3 commits · Jan 31, 2025
268 changes: 127 additions & 141 deletions src/distanceMetrics/Housdorff/mainHouseDorffKernel/MainLoopIterations.jl
@@ -1,4 +1,3 @@

"""
this will be invoked after the first-pass kernel, so we already have a working queue and we will work on it
we will finish when mainActiveCounterNext evaluates to 0 at the end of an iteration
@@ -34,149 +33,142 @@ as we will loop here intensively and after each iteration we will need to synch
fp, fn - counts of false positives and false negatives - used to enable early termination
"""
module MainLoopIterations
using CUDA, ..CUDAGpuUtils, Logging,StaticArrays
using ..HFUtils

function mainLoopIterationsKernel(reducedArrays
,metaData
,metadataDims::Tuple{UInt8,UInt8,UInt8}
,resArrays
,resArraysCounters
,datablockDim
,mainQuesCounterArr
,mainWorkQueueArr
,tailCounter
,tailParts::UInt8
,numberOfThreadBlocks::UInt16
,isOddPass
,iterationNumber
,fp,fn)


#needed to manage cooperative groups functions
using KernelAbstractions, Logging, StaticArrays
using ..CUDAGpuUtils, ..HFUtils

@kernel function mainLoopIterationsKernel(
reducedArrays,
metaData,
metadataDims::NTuple{3, UInt8},
resArrays,
resArraysCounters,
datablockDim,
mainQuesCounterArr,
mainWorkQueueArr,
tailCounter,
tailParts::UInt8,
numberOfThreadBlocks::UInt16,
isOddPass,
iterationNumber,
fp,
fn
)
# needed to manage cooperative-groups functions; this_grid/sync_grid are CUDA-specific and require a cooperative launch
grid_handle = this_grid()
#storing intermediate results +2 in order to get the one padding
resShmem = @cuStaticSharedMem(Bool,(34,34,34))
#storing values loaded from analyzed array ...
sourceShmem = @cuStaticSharedMem(Bool,(34,34,34))
#for storing sums for reductions
shmemSum = @cuStaticSharedMem(Float32,35,14) # we need these additional 33rd and 34th spots

# storing intermediate results +2 in order to get the one padding
resShmem = @localmem(Bool, (34, 34, 34))
# storing values loaded from analyzed array ...
sourceShmem = @localmem(Bool, (34, 34, 34))
# for storing sums for reductions
shmemSum = @localmem(Float32, (35, 14))

#coordinates of data in main array
#we will use this to establish whether we should mark the data block as empty or full ...
isMaskFull= false
#here we will store in registers data uploaded from mask for later verification whether we should send it or not
locArr= Int32(0)
# coordinates of data in main array
# we will use this to establish whether we should mark the data block as empty or full ...
isMaskFull = false
# here we will store in registers data uploaded from mask for later verification whether we should send it or not
locArr = Int32(0)
offsetIter = UInt16(0)

#we will store here the data about the blocks that we want to process
toIterWorkQueueShmem = @cuStaticSharedMem(UInt8,32,4)
#indicates where we are in the general work queue at a given moment if we are iterating over the part of the work queue owned by this thread block
positionInLocalWorkQueaue = @cuStaticSharedMem(UInt16,1)
#boolean useful in the iterating over private part of work queue
isAnyBiggerThanZero = @cuStaticSharedMem(Bool,1)
#used for iterating over a tail
tailCounterInShmem= @cuStaticSharedMem(UInt16,1)
workCounterInshmem= @cuStaticSharedMem(UInt16,1)

#loading data to shared memory from global about whether it is an even or odd pass
isOddPassShmem = @cuStaticSharedMem(Bool,1)
iterationNumberShmem = @cuStaticSharedMem(UInt16,1)
#current spot in tail - useful to save where we need to access the tail to get proper block
currentTailPosition = @cuStaticSharedMem(UInt16,1)
#shared memory variables needed to mark whether we are already finished with any dilatation step
goldToBeDilatated = @cuStaticSharedMem(Bool,1)
segmToBeDilatated = @cuStaticSharedMem(Bool,1)
#true when we have more than 0 blocks to analyze in next iteration
workCounterBiggerThan0 = @cuStaticSharedMem(Bool,1)

#resetting
@ifXY 1 1 tailCounterInShmem[1]=0
@ifXY 2 1 workCounterInshmem[1]=0

"""
just making easier to invoke it without passing arguments all the time
"""
macro innersingleDataBlockPass(ispassGoldd::Bool,currMatadataBlockX::UInt8,currMatadataBlockY::UInt8 ,currMatadataBlockZ::UInt8 )
singleDataBlockPass(reducedArrays[ispassGoldd*2+1]
,reducedArrays[ispassGoldd*2+2]
,iterationNumberShmem[]
,((currMatadataBlockX-1)*32)+1
,((currMatadataBlockY-1)*32)+1
,((currMatadataBlockZ-1)*32)+1
,isMaskFull ,resShmem ,locArr ,metaData ,metadataDims ,isPassGold
,currMatadataBlockX
,currMatadataBlockY
,currMatadataBlockZ
,mainQuesCounterArr[!isOddPassShmem[]+1]# we negate those as we want to pass the data about new active blocks to queue of next pass
,mainWorkQueueArr[!isOddPassShmem[]+1]
,resArrays[ispassGoldd+1]
,resArraysCounters[ispassGoldd+1] )

loadDataAtTheBegOfDilatationStep(isOddPassShmem,iterationNumberShmem,iterationNumber,positionInLocalWorkQueaue,workCounterInshmem,mainQuesCounterArr,isAnyBiggerThanZero,goldToBeDilatated,segmToBeDilatated, resArraysCounters )
sync_threads()
#we check first whether the next dilatation step should be done at all; we also establish some shared memory variables to know whether both passes should continue or just one
# checking whether we have already finished; for that we need to check:
# - is the amount of results related to gold mask dilatations equal to the false positives
# - is the amount of results related to the other mask dilatations equal to the false negatives
# - is the amount of the workQueue that we want to analyze now bigger than 0
while(goldToBeDilatated[1] && goldToBeDilatated[1] && workCounterBiggerThan0[1])
while(isAnyBiggerThanZero[])
loadOwnedWorkQueueIntoShmem(mainWorkQueue,mainQuesCounter,toIterWorkQueueShmem,positionInMainWorkQueaue ,numberOfThreadBlocks,tailParts)
#at this point if we have anything in the private part of the work queue we will have it in toIterWorkQueueShmem; if we find some 0 inside, it means the queue is exhausted and we need to look into the tail
@unroll for i in UInt16(1):32# most outer loop is responsible for z dimension
if(shmemIter[i,1]>0)
#we also need to check whether the given dilatation step is already finished - for example it may not make sense to dilate the gold mask while we still need to dilate the other
if( (goldToBeDilatated[1]&&shmemIter[i,4]==1) || (segmToBeDilatated[1]&&shmemIter[i,4]==0) )
innersingleDataBlockPass(shmemIter[i,4]==1,shmemIter[i,1],shmemIter[i,2] ,shmemIter[i,3] )
# we will store here the data about the blocks that we want to process
toIterWorkQueueShmem = @localmem(UInt8, (32, 4))
# indicates where we are in general work queue in given moment if we are iterating over part of work queue owned by this thread block
positionInLocalWorkQueaue = @localmem(UInt16, 1)

# boolean useful in the iterating over private part of work queue
isAnyBiggerThanZero = @localmem(Bool, 1)
# used for iterating over a tail
tailCounterInShmem = @localmem(UInt16, 1)
workCounterInshmem = @localmem(UInt16, 1)

# loading data into shared memory from global memory about whether this is an even or odd pass
isOddPassShmem = @localmem(Bool, 1)
iterationNumberShmem = @localmem(UInt16, 1)
# current spot in tail - useful to save where we need to access the tail to get proper block
currentTailPosition = @localmem(UInt16, 1)
# shared memory variables needed to mark whether we are already finished with any dilatation step
goldToBeDilatated = @localmem(Bool, 1)
segmToBeDilatated = @localmem(Bool, 1)
# true when we have more than 0 blocks to analyze in next iteration
workCounterBiggerThan0 = @localmem(Bool, 1)

# resetting
if @index(Local, Linear) == 1 # local linear indices are 1-based
tailCounterInShmem[1] = 0
workCounterInshmem[1] = 0
end

@synchronize

# main loop logic
while goldToBeDilatated[1] && segmToBeDilatated[1] && workCounterBiggerThan0[1]
while isAnyBiggerThanZero[1]
loadOwnedWorkQueueIntoShmem(mainWorkQueueArr, mainQuesCounterArr, toIterWorkQueueShmem, positionInLocalWorkQueaue, numberOfThreadBlocks, tailParts)
# at this point, if we have anything in the private part of the work queue, it is in toIterWorkQueueShmem; if we find some 0 inside, the queue is exhausted and we need to look into the tail
@unroll for i in UInt16(1):32 # most outer loop is responsible for z dimension
if toIterWorkQueueShmem[i, 1] > 0
# we also need to check whether the given dilatation step is already finished - for example it may not make sense to dilate the gold mask while we still need to dilate the other
if (goldToBeDilatated[1] && toIterWorkQueueShmem[i, 4] == 1) || (segmToBeDilatated[1] && toIterWorkQueueShmem[i, 4] == 0)
innersingleDataBlockPass(toIterWorkQueueShmem[i, 4] == 1, toIterWorkQueueShmem[i, 1], toIterWorkQueueShmem[i, 2], toIterWorkQueueShmem[i, 3])
end
@synchronize
else # at this point we got some 0 in toIterWorkQueueShmem
if @index(Local, Linear) == 1 # sadly only the first thread manages the work queue
isAnyBiggerThanZero[1] = false
end#inner if
end#if
end#for
end#while

# at this moment we have nothing left in the private part of the work queue and we need to check whether there is something in the tail to process
# first load the data of the tail counter once

loadDataNeededForTailAnalysisToShmem(currentTailPosition, tailCounter)
@synchronize
while (currentTailPosition[1] < workCounterInshmem[1])
# below we access the tail of the work queue in a way that will be atomic
if @index(Local, Linear) == 1
toIterWorkQueueShmem[1] = mainWorkQueueArr[isOddPassShmem[1]+1][currentTailPosition[1]]
end
@synchronize
# we also need to check whether the given dilatation step is already finished - for example it may not make sense to dilate the gold mask while we still need to dilate the other
if (goldToBeDilatated[1] && toIterWorkQueueShmem[1,4]==1) || (segmToBeDilatated[1] && toIterWorkQueueShmem[1,4]==0)
innersingleDataBlockPass(toIterWorkQueueShmem[1,4]==1, toIterWorkQueueShmem[1,1], toIterWorkQueueShmem[1,2], toIterWorkQueueShmem[1,3])
end
# reload the tail position data before the next check
loadDataNeededForTailAnalysisToShmem(currentTailPosition, tailCounter)
@synchronize
end
# we set it before sync grid - as we do not need it for this one
if @index(Local, Linear) == 1
# we are just negating it so it should be the same in all blocks
isOddPassShmem[1] = !isOddPassShmem[1]
end

@synchronize
# if we are here it means we have covered all blocks that were marked as active and we need to prepare for the next dilatation step
sync_grid(grid_handle)
# we clear and add negation to !isOddPassShmem because we want to set the previously updated counter to 0
clearBeforeNextDilatation(locArr, resShmem, mainQuesCounterArr[!isOddPassShmem[1] + 1])
prepareForNextDilatationStep(iterationNumber, tailCounter, numberOfThreadBlocks, tailParts, mainQuesCounterArr, isOddPassShmem)
loadDataAtTheBegOfDilatationStep(isOddPassShmem, iterationNumberShmem, iterationNumber, positionInLocalWorkQueaue, workCounterInshmem, mainQuesCounterArr, isAnyBiggerThanZero, goldToBeDilatated, segmToBeDilatated, resArraysCounters)
@synchronize
end#while isNotFinished
end#mainLoopIterationsKernel
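As an aside, the 34×34×34 shared-memory tiles above leave a one-voxel halo around each 32³ data block; a minimal sketch of the implied indexing (the helper toShmemIdx is hypothetical, not part of this module):

# hypothetical helper: interior voxel (x, y, z) of a 32^3 block maps to
# shmem[x+1, y+1, z+1], so neighbors of border voxels stay inside the padded tile
@inline toShmemIdx(x, y, z) = (x + 1, y + 1, z + 1)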
function executeMainLoopIterationsKernel(
reducedArrays,
metaData,
metadataDims,
resArrays,
resArraysCounters,
datablockDim,
mainQuesCounterArr,
mainWorkQueueArr,
tailCounter,
tailParts,
numberOfThreadBlocks,
isOddPass,
iterationNumber,
fp,
fn
)
threads = (32, 32, 1)
blocks = (cld(metadataDims[1], threads[1]), cld(metadataDims[2], threads[2]), cld(metadataDims[3], threads[3]))

# instantiate the kernel for a backend; CPU() here, although the surrounding module targets GPU execution
kernel = mainLoopIterationsKernel(CPU(), threads)
kernel(
reducedArrays,
metaData,
metadataDims,
resArrays,
resArraysCounters,
datablockDim,
mainQuesCounterArr,
mainWorkQueueArr,
tailCounter,
tailParts,
numberOfThreadBlocks,
isOddPass,
iterationNumber,
fp,
fn,
ndrange = blocks .* threads
)
end
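For reference, the instantiate-then-call pattern used above in a self-contained form; copy_kernel! is a toy example, not part of this PR:

using KernelAbstractions

# toy kernel illustrating the KernelAbstractions launch pattern
@kernel function copy_kernel!(dst, src)
    i = @index(Global, Linear)
    @inbounds dst[i] = src[i]
end

src = rand(Float32, 1024)
dst = similar(src)
kernel = copy_kernel!(CPU(), 64)        # backend + workgroup size
kernel(dst, src; ndrange = length(src)) # ndrange is total work items, not block count
KernelAbstractions.synchronize(CPU())
@assert dst == src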

"""
checking whether we have already finished; for that we need to check
@@ -341,10 +333,4 @@ function singleDataBlockPass(analyzedArr
# now let's check whether the block is eligible for further processing - for this we need sums ...
isActiveForNormalPass(isMaskFull, isMaskEmpty,resShmem,currMatadataBlockX,currMatadataBlockY,currMatadataBlockZ,isPassGold,metaData,mainQuesCounter,mainWorkQueue,resArraysCounter)
end#singleDataBlockPass

end # module MainLoopIterations
2 changes: 1 addition & 1 deletion src/distanceMetrics/MeansMahalinobis.jl
@@ -469,7 +469,7 @@ covarianceGlobal (just one column but entries exactly the same as above)
mahalanobisResGlobal - global result of Mahalanobis distance
mahalanobisResSliceWise - slice-wise results of Mahalanobis distance
"""
@kernel function meansMahalanobisKernel(
@kernel function meansMahalinobisKernel(
goldArr, segmArr, numberToLooFor, loopYdim::UInt32, loopXdim::UInt32, loopZdim::UInt32, arrDims::Tuple{UInt32,UInt32,UInt32},
totalXGold, totalYGold, totalZGold, totalCountGold, totalXSegm, totalYSegm, totalZSegm, totalCountSegm,
countPerZGold, countPerZSegm,
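For context, a minimal sketch of the distance that the means and covariance accumulated by this kernel feed into; the helper below is illustrative and not part of the module, with μg and μs standing for the per-mask mean coordinate vectors and Σ for the covariance matrix:

using LinearAlgebra

# illustrative only: Mahalanobis distance between two mean coordinate vectors
# given a covariance matrix
mahalanobis(μg, μs, Σ) = sqrt(dot(μg .- μs, Σ \ (μg .- μs)))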
2 changes: 0 additions & 2 deletions test/ICtest.jl
@@ -12,9 +12,7 @@ arrGold = CUDA.ones(UInt8, (dimx,dimy,dimz) )
arrAlgo = CUDA.ones(UInt8, (dimx,dimy,dimz) )
numberToLooFor = UInt8(1)
argsMain, threadsMain, blocksMain,threadsMean,blocksMean,argsMean, totalNumbOfVoxels=InterClassCorrKernel.prepareInterClassCorrKernel(arrGold ,arrAlgo,numberToLooFor)

globalICC=calculateInterclassCorr(arrGold,arrAlgo,argsMain, threadsMain, blocksMain,threadsMean,blocksMean,argsMean, totalNumbOfVoxels)::Float64


Int64(argsMain[1][1]-dimx*dimy*dimz)
argsMain[2][1]==dimx*dimy*dimz
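A hedged sketch of how these bare comparisons could be expressed as explicit test assertions, assuming (as the expressions suggest) that both counters should equal the total voxel count:

using Test
# assumes both counters should equal the total number of voxels, since both
# input arrays are filled with the looked-for value
@test argsMain[1][1] == dimx * dimy * dimz
@test argsMain[2][1] == dimx * dimy * dimz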
3 changes: 1 addition & 2 deletions test/distanceMetr/MainMahalinobisTest.jl
@@ -1,14 +1,13 @@
using Revise, Parameters, Logging, Test
using CUDA
# includet("../../src/kernels/kernelEvolutions.jl")
includet("../../src/structs/BasicStructs.jl")
includet("../../src/utils/CUDAGpuUtils.jl")
includet("../../src/utils/IterationUtils.jl")
includet("../../src/utils/ReductionUtils.jl")
includet("../../src/utils/MemoryUtils.jl")
includet("../../src/distanceMetrics/MeansMahalinobis.jl")
includet("../../src/distanceMetrics/Mahalanobis.jl")
using ..BasicPreds, ..CUDAGpuUtils , ..Mahalanobis, ..MeansMahalinobis, ..IterationUtils,..ReductionUtils , ..MemoryUtils
using ..CUDAGpuUtils , ..MeansMahalinobis, ..IterationUtils,..ReductionUtils , ..MemoryUtils
nx=512 ; ny=512 ; nz=317
#first we initialize the metrics on CPU so we can modify them more easily
goldBoolCPU= zeros(Float32,nx,ny,nz); #mimics the gold standard mask