Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 31 additions & 31 deletions src/RegisterWorkerAperturesMismatch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ import RegisterWorkerShell: worker, init!, close!, load_mm_package, workertid

export AperturesMismatch, monitor, monitor!, worker

mutable struct AperturesMismatch{A<:AbstractArray,T,K,N} <: AbstractWorker
mutable struct AperturesMismatch{A <: AbstractArray, T, K, N} <: AbstractWorker
fixed::A
nodes::NTuple{N,K}
maxshift::NTuple{N,Int}
nodes::NTuple{N, K}
maxshift::NTuple{N, Int}
thresh::T
preprocess # likely of type PreprocessSNF, but could be a function
normalization::Symbol
Expand All @@ -23,8 +23,8 @@ mutable struct AperturesMismatch{A<:AbstractArray,T,K,N} <: AbstractWorker
Qs
mmis
tid::Int
dev::Union{Nothing,Int}
cuda_objects::Dict{Symbol,Any}
dev::Union{Nothing, Int}
cuda_objects::Dict{Symbol, Any}
end

workertid(w::AperturesMismatch) = w.tid
Expand All @@ -33,14 +33,14 @@ function load_mm_package(dev)
if dev !== nothing
eval(:(using CUDA, RegisterMismatchCuda))
end
nothing
return nothing
end

function init!(algorithm::AperturesMismatch)
if algorithm.dev !== nothing
cuda_init!(algorithm)
end
nothing
return nothing
end

function cuda_init!(algorithm)
Expand All @@ -57,12 +57,12 @@ function cuda_init!(algorithm)
end
fixed = algorithm.fixed
T = cudatype(eltype(fixed))
d_fixed = CuArray{T}(sdata(fixed))
d_fixed = CuArray{T}(sdata(fixed))
algorithm.cuda_objects[:d_fixed] = d_fixed
algorithm.cuda_objects[:d_moving] = similar(d_fixed)
gridsize = map(length, algorithm.nodes)
aperture_width = default_aperture_width(algorithm.fixed, gridsize)
algorithm.cuda_objects[:cms] = CMStorage{T}(undef, aperture_width, algorithm.maxshift)
return algorithm.cuda_objects[:cms] = CMStorage{T}(undef, aperture_width, algorithm.maxshift)
end

function close!(algorithm::AperturesMismatch)
Expand All @@ -71,7 +71,7 @@ function close!(algorithm::AperturesMismatch)
activate(old_active_context)
end
end
nothing
return nothing
end

"""
Expand Down Expand Up @@ -113,20 +113,20 @@ pre-processing function, but see also `PreprocessSNF`.
```

"""
function AperturesMismatch(fixed, nodes::NTuple{N,K}, maxshift::NTuple{N,<:Integer}, preprocess=identity; normalization=:pixels, thresh_fac=(0.5)^ndims(fixed), thresh=nothing, correctbias::Bool=true, tid=1, dev=nothing) where {K,N}
function AperturesMismatch(fixed, nodes::NTuple{N, K}, maxshift::NTuple{N, <:Integer}, preprocess = identity; normalization = :pixels, thresh_fac = (0.5)^ndims(fixed), thresh = nothing, correctbias::Bool = true, tid = 1, dev = nothing) where {K, N}
gridsize = map(length, nodes)
nimages(fixed) == 1 || error("Register to a single image")
if isnothing(thresh)
thresh = (thresh_fac/prod(gridsize)) * (normalization==:pixels ? length(fixed) : sumabs2(fixed))
thresh = (thresh_fac / prod(gridsize)) * (normalization == :pixels ? length(fixed) : sumabs2(fixed))
end
T = eltype(fixed) <: AbstractFloat ? eltype(fixed) : Float32
# T = Float64 # Ipopt requires Float64
Es = ArrayDecl(Array{T,N}, gridsize)
cs = ArrayDecl(Array{SVector{N,T},N}, gridsize)
Qs = ArrayDecl(Array{similar_type(SMatrix, T, Size(N,N)),N}, gridsize)
mmsize = map(x->2x+1, maxshift)
mmis = ArrayDecl(Array{NumDenom{T},2*N}, (mmsize...,gridsize...))
AperturesMismatch{typeof(fixed),T,K,N}(fixed, nodes, maxshift, T(thresh), preprocess, normalization, correctbias, Es, cs, Qs, mmis, tid, dev, Dict{Symbol,Any}())
Es = ArrayDecl(Array{T, N}, gridsize)
cs = ArrayDecl(Array{SVector{N, T}, N}, gridsize)
Qs = ArrayDecl(Array{similar_type(SMatrix, T, Size(N, N)), N}, gridsize)
mmsize = map(x -> 2x + 1, maxshift)
mmis = ArrayDecl(Array{NumDenom{T}, 2 * N}, (mmsize..., gridsize...))
return AperturesMismatch{typeof(fixed), T, K, N}(fixed, nodes, maxshift, T(thresh), preprocess, normalization, correctbias, Es, cs, Qs, mmis, tid, dev, Dict{Symbol, Any}())
end

function worker(algorithm::AperturesMismatch, img, tindex, mon)
Expand All @@ -136,16 +136,16 @@ function worker(algorithm::AperturesMismatch, img, tindex, mon)
use_cuda = algorithm.dev !== nothing
if use_cuda
device!(CuDevice(algorithm.dev))
d_fixed = algorithm.cuda_objects[:d_fixed]
d_fixed = algorithm.cuda_objects[:d_fixed]
d_moving = algorithm.cuda_objects[:d_moving]
cms = algorithm.cuda_objects[:cms]
cms = algorithm.cuda_objects[:cms]
copyto!(d_moving, moving)
cs = coords_spatial(img)
aperture_centers = aperture_grid(map(d->size(img,d),cs), gridsize)
aperture_centers = aperture_grid(map(d -> size(img, d), cs), gridsize)
mms = allocate_mmarrays(eltype(cms), gridsize, algorithm.maxshift)
mismatch_apertures!(mms, d_fixed, d_moving, aperture_centers, cms; normalization=algorithm.normalization)
mismatch_apertures!(mms, d_fixed, d_moving, aperture_centers, cms; normalization = algorithm.normalization)
else
mms = mismatch_apertures(algorithm.fixed, moving, gridsize, algorithm.maxshift; normalization=algorithm.normalization)
mms = mismatch_apertures(algorithm.fixed, moving, gridsize, algorithm.maxshift; normalization = algorithm.normalization)
end
# displaymismatch(mms, thresh=10)
if algorithm.correctbias
Expand All @@ -154,11 +154,11 @@ function worker(algorithm::AperturesMismatch, img, tindex, mon)
T = eltype(algorithm.Es)
N = length(gridsize)
Es = zeros(T, gridsize...)
cs = Array{SVector{N,T}}(undef, gridsize...)
Qs = Array{similar_type(SMatrix, T, Size(N,N))}(undef, gridsize...)
cs = Array{SVector{N, T}}(undef, gridsize...)
Qs = Array{similar_type(SMatrix, T, Size(N, N))}(undef, gridsize...)
thresh = algorithm.thresh
for i = 1:length(mms)
Es[i], cs[i], Qs[i] = qfit(mms[i], thresh; opt=false)
for i in 1:length(mms)
Es[i], cs[i], Qs[i] = qfit(mms[i], thresh; opt = false)
end
monitor!(mon, :Es, Es)
monitor!(mon, :cs, cs)
Expand All @@ -169,20 +169,20 @@ function worker(algorithm::AperturesMismatch, img, tindex, mon)
gridsize = size(mmis)
coefs1 = first(mmis).data.coefs
result = Array{eltype(coefs1)}(undef, size(coefs1)..., gridsize...)
_copy_mm!(result, mmis, ntuple(_->Colon(), N), CartesianIndices(gridsize))
_copy_mm!(result, mmis, ntuple(_ -> Colon(), N), CartesianIndices(gridsize))
monitor!(mon, :mmis, result)
end
mon
return mon
end

function _copy_mm!(dest, src, colons, R)
for (I, mm) in zip(R, src)
dest[colons..., I] = mm.data.coefs
end
dest
return dest
end

cudatype(::Type{T}) where {T<:Union{Float32,Float64}} = T
cudatype(::Type{T}) where {T <: Union{Float32, Float64}} = T
cudatype(::Any) = Float32

end # module
134 changes: 68 additions & 66 deletions test/aperturedmm.jl
Original file line number Diff line number Diff line change
@@ -1,75 +1,77 @@
workdir = mktempdir()
@testset "apertured mismatch registration" begin
workdir = mktempdir()

### Apertured registration
# Create the data
fixed = testimage("cameraman")
gridsize = (5,5)
ntimes = 4
shift_amplitude = 5
u_dfm = shift_amplitude*randn(2, gridsize..., ntimes)
img = AxisArray(SharedArray{Float64}((size(fixed)..., ntimes)), :y, :x, :time)
nodes = map(d->range(1, stop=size(fixed,d), length=gridsize[d]), (1,2))
tax = timeaxis(img)
for i = 1:ntimes
ϕ_dfm = GridDeformation(u_dfm[:,:,:,i], nodes)
img[tax(i)] = warp(fixed, ϕ_dfm)
end
# Perform the registration
fn = joinpath(workdir, "apertured.jld")
maxshift = (3*shift_amplitude, 3*shift_amplitude)
algorithms = AperturesMismatch[AperturesMismatch(fixed, nodes, maxshift; tid=p) for p in wtids]
prepare_mm_package(algorithms)
mons = monitor(algorithms, (:Es, :cs, :Qs, :mmis))
driver(fn, algorithms, img, mons)

# With preprocessing
fn_pp = joinpath(workdir, "apertured_pp.jld")
pp = PreprocessSNF(0.1, [2,2], [10,10])
algorithms = AperturesMismatch[AperturesMismatch(pp(fixed), nodes, maxshift, pp; tid=p) for p in wtids]
prepare_mm_package(algorithms)
mons = monitor(algorithms, (:Es, :cs, :Qs, :mmis))
driver(fn_pp, algorithms, img, mons)
### Apertured registration
# Create the data
fixed = testimage("cameraman")
gridsize = (5, 5)
ntimes = 4
shift_amplitude = 5
u_dfm = shift_amplitude * randn(2, gridsize..., ntimes)
img = AxisArray(SharedArray{Float64}((size(fixed)..., ntimes)), :y, :x, :time)
nodes = map(d -> range(1, stop = size(fixed, d), length = gridsize[d]), (1, 2))
tax = timeaxis(img)
for i in 1:ntimes
ϕ_dfm = GridDeformation(u_dfm[:, :, :, i], nodes)
img[tax(i)] = warp(fixed, ϕ_dfm)
end
# Perform the registration
fn = joinpath(workdir, "apertured.jld")
maxshift = (3 * shift_amplitude, 3 * shift_amplitude)
algorithms = AperturesMismatch[AperturesMismatch(fixed, nodes, maxshift; tid = p) for p in wtids]
prepare_mm_package(algorithms)
mons = monitor(algorithms, (:Es, :cs, :Qs, :mmis))
driver(fn, algorithms, img, mons)

# using CUDA
if !(haskey(ENV,"CI")&&(ENV["CI"]=="true"))
fn_cuda = joinpath(workdir, "apertured_cuda.jld")
algorithm = AperturesMismatch(pp(fixed), nodes, maxshift, pp; dev=0)
prepare_mm_package(algorithm)
# With preprocessing
fn_pp = joinpath(workdir, "apertured_pp.jld")
pp = PreprocessSNF(0.1, [2, 2], [10, 10])
algorithms = AperturesMismatch[AperturesMismatch(pp(fixed), nodes, maxshift, pp; tid = p) for p in wtids]
prepare_mm_package(algorithms)
mons = monitor(algorithms, (:Es, :cs, :Qs, :mmis))
driver(fn_cuda, algorithms, img, mons)
end
driver(fn_pp, algorithms, img, mons)

fns = [fn, fn_pp]
if (@isdefined fn_cuda)&&isfile(fn_cuda)
push!(fns, fn_cuda)
end
for fname in fns
jldopen(fname, "r") do file
dEs, dcs, dQs, dmmis = file["Es"], file["cs"], file["Qs"], file["mmis"]
for d in (dEs, dcs, dQs, dmmis)
@test eltype(d) == Float32
# using CUDA
if !(haskey(ENV, "CI") && (ENV["CI"] == "true"))
fn_cuda = joinpath(workdir, "apertured_cuda.jld")
algorithm = AperturesMismatch(pp(fixed), nodes, maxshift, pp; dev = 0)
prepare_mm_package(algorithm)
mons = monitor(algorithms, (:Es, :cs, :Qs, :mmis))
driver(fn_cuda, algorithms, img, mons)
end

fns = [fn, fn_pp]
if (@isdefined fn_cuda) && isfile(fn_cuda)
push!(fns, fn_cuda)
end
for fname in fns
jldopen(fname, "r") do file
dEs, dcs, dQs, dmmis = file["Es"], file["cs"], file["Qs"], file["mmis"]
for d in (dEs, dcs, dQs, dmmis)
@test eltype(d) == Float32
end
@test size(dEs) == (gridsize..., ntimes)
@test size(dcs) == (2, gridsize..., ntimes)
@test size(dQs) == (2, 2, gridsize..., ntimes)
innersize = map(x -> 2x + 1, maxshift)
@test size(dmmis) == (2, innersize..., gridsize..., ntimes)
end
@test size(dEs) == (gridsize..., ntimes)
@test size(dcs) == (2, gridsize..., ntimes)
@test size(dQs) == (2, 2, gridsize..., ntimes)
innersize = map(x->2x+1, maxshift)
@test size(dmmis) == (2, innersize..., gridsize..., ntimes)
end
end

cs, Qs, mmis = jldopen(fn, mmaparrays=true) do file
read(file, "cs"), read(file, "Qs"), read(file, "mmis")
end
ϕs, mismatch = fixed_λ(Float64.(collect(cs)), Float64.(collect(Qs)), nodes, AffinePenalty(nodes, 0.001), Float64.(collect(mmis)); λt=1e-5)
for t = 1:nimages(img)
moving = img[tax(t)]
warped = warp(moving, ϕs[t])
r_m = ratio(mismatch0(fixed, moving), 0)
r_w = ratio(mismatch0(fixed, warped), 0)
@test r_w < r_m
end
cs, Qs, mmis = jldopen(fn, mmaparrays = true) do file
read(file, "cs"), read(file, "Qs"), read(file, "mmis")
end
ϕs, mismatch = fixed_λ(Float64.(collect(cs)), Float64.(collect(Qs)), nodes, AffinePenalty(nodes, 0.001), Float64.(collect(mmis)); λt = 1.0e-5)
for t in 1:nimages(img)
moving = img[tax(t)]
warped = warp(moving, ϕs[t])
r_m = ratio(mismatch0(fixed, moving), 0)
r_w = ratio(mismatch0(fixed, warped), 0)
@test r_w < r_m
end

cs = Qs = mmis = 0 # since we're mmapping, we'd better free these before deleting files
GC.gc()
cs = Qs = mmis = 0 # since we're mmapping, we'd better free these before deleting files
GC.gc()

rm(workdir, recursive=true)
rm(workdir, recursive = true)
end
2 changes: 1 addition & 1 deletion test/internals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ end

@testset "AperturesMismatch constructor argument validation" begin
fixed = rand(Float32, 20, 20)
nodes = (range(1, stop=20, length=4), range(1, stop=20, length=4))
nodes = (range(1, stop = 20, length = 4), range(1, stop = 20, length = 4))
# Vector maxshift should not match the NTuple{N,<:Integer} signature
@test_throws MethodError AperturesMismatch(fixed, nodes, [3, 3])
end
Loading
Loading