Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
name = "RegisterWorkerAperturesMismatch"
uuid = "30e56b64-2659-11e9-2fbf-0524297743d8"
authors = ["Tim Holy <tim.holy@gmail.com>"]
version = "1.0.0"
authors = ["Tim Holy <tim.holy@gmail.com>"]

[deps]
CoordinateTransformations = "150eb455-5306-5404-9cee-2592286d6298"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CoordinateTransformations = "150eb455-5306-5404-9cee-2592286d6298"
ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
RegisterCore = "67712758-55e7-5c3c-8e85-dda1d7758434"
RegisterDeformation = "c19381b7-cf49-59d7-881c-50dfbd227eaf"
RegisterDriver = "935ac36e-2656-11e9-1e3b-cbaa636797af"
RegisterFit = "36121b08-3789-5198-aff2-59a3443d9b59"
RegisterMismatch = "3c0dd727-6833-5f5d-a1e8-c0d421935c74"
RegisterMismatchCommon = "abb2e897-52bf-5d28-a379-6ca321e3b878"
Expand All @@ -22,8 +21,10 @@ SharedArrays = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[compat]
CoordinateTransformations = "0.5, 0.6"
Aqua = "0.8"
CUDA = "3, 4, 5"
CoordinateTransformations = "0.5, 0.6"
ExplicitImports = "1"
ImageCore = "0.8.1, 0.9, 0.10"
Interpolations = "0.12, 0.13, 0.14, 0.15"
RegisterCore = "1"
Expand All @@ -36,19 +37,23 @@ RegisterMismatchCuda = "1"
RegisterOptimize = "1"
RegisterPenalty = "1"
RegisterWorkerShell = "1"
SharedArrays = "1"
StaticArrays = "0.12, 1"
julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
ImageAxes = "2803e5a7-5153-5ecf-9a86-9b4c37f5f5ac"
ImageMagick = "6218d12a-5da1-5696-b52f-db25d2ecc6d1"
JLD = "4138dd39-2aa7-5051-a626-17a0bb65d9c8"
PaddedViews = "5432bcbf-9aad-5242-b902-cca2824c8663"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
RegisterDriver = "935ac36e-2656-11e9-1e3b-cbaa636797af"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestImages = "5e47fb64-e119-507b-a336-dd2b206d9990"

[targets]
test = ["Test", "AxisArrays", "ImageAxes", "Distributed", "ImageMagick", "JLD", "PaddedViews", "Pkg", "TestImages"]
test = ["Aqua", "ExplicitImports", "RegisterDriver", "Test", "AxisArrays", "ImageAxes", "Distributed", "ImageMagick", "JLD", "PaddedViews", "Pkg", "TestImages"]
81 changes: 73 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,75 @@
# RegisterWorkerAperturesMismatch.jl

This package is similar to [RegisterWorkerApertures](https://github.com/HolyLab/RegisterWorkerApertures.jl),
whose documentation you should consult for an overview.
This package differs in that it is targeted at "whole experiment" rather than
"stack-by-stack" registration.
It writes the mismatch data to disk, and then [RegisterOptimize](https://github.com/HolyLab/RegisterOptimize.jl) is used to optimize a time-series of deformations.

However, because the temporal regularization enforces smoothness, and many real-world
data sets have discontinuous movements, this approach is not currently recommended.
[![CI](https://github.com/HolyLab/RegisterWorkerAperturesMismatch.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/HolyLab/RegisterWorkerAperturesMismatch.jl/actions/workflows/CI.yml)
[![codecov](https://codecov.io/gh/HolyLab/RegisterWorkerAperturesMismatch.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/HolyLab/RegisterWorkerAperturesMismatch.jl)
[![Aqua QA](https://juliatesting.github.io/Aqua.jl/dev/assets/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)
[![version](https://img.shields.io/github/v/release/HolyLab/RegisterWorkerAperturesMismatch.jl)](https://github.com/HolyLab/RegisterWorkerAperturesMismatch.jl/releases)

Worker for aperture-based (blocked) image registration using mismatch polynomials.
The image domain is divided into a grid of apertures; a local shift is estimated
for each aperture by fitting a quadratic to the cross-correlation mismatch array.
Results are written to disk for downstream optimization with
[RegisterOptimize](https://github.com/HolyLab/RegisterOptimize.jl).

This package is similar to
[RegisterWorkerApertures](https://github.com/HolyLab/RegisterWorkerApertures.jl)
(see its documentation for a broader overview), but targets **whole-experiment**
rather than stack-by-stack registration.
Because the temporal regularization in `RegisterOptimize` enforces smoothness,
this approach is **not currently recommended** for data sets with discontinuous
movements.

CUDA-accelerated computation is supported via the `dev` keyword argument.

## Installation

This package is registered in the
[HolyLabRegistry](https://github.com/HolyLab/HolyLabRegistry).
Add the registry once, then install normally:

```julia
using Pkg
pkg"registry add https://github.com/HolyLab/HolyLabRegistry.git"
Pkg.add("RegisterWorkerAperturesMismatch")
```

## Usage

```julia
using RegisterWorkerAperturesMismatch

# Define a reference image and aperture grid
fixed = Float32.(reshape(1:64*64, 64, 64)) ./ (64f0 * 64f0)
nodes = (range(1, 64, length=4), range(1, 64, length=4)) # 4×4 aperture grid
maxshift = (5, 5)

# Build the worker and monitoring dict
alg = AperturesMismatch(fixed, nodes, maxshift)
mon = monitor(alg, (:Es, :cs, :Qs))

# Register a moving image
moving = fixed .+ 0.01f0
mon = worker(alg, moving, 1, mon)

# Inspect results
size(mon[:Es]) # (4, 4) — per-aperture mismatch energy
size(mon[:cs]) # (4, 4) — per-aperture shift estimates (SVector{2})
size(mon[:Qs]) # (4, 4) — per-aperture curvature matrices (SMatrix{2,2})
```

For a preprocessing function (applied to both `fixed` and each `moving` frame
before computing the mismatch):

```julia
pp = img -> img ./ (maximum(img) + eps(eltype(img)))
fixed = pp(raw_fixed)
alg = AperturesMismatch(fixed, nodes, maxshift, pp)
```

For CUDA-accelerated registration, pass the device index:

```julia
alg = AperturesMismatch(fixed, nodes, maxshift; dev=0)
```

See the docstrings (`?AperturesMismatch`, `?worker`) for the full keyword reference.
149 changes: 107 additions & 42 deletions src/RegisterWorkerAperturesMismatch.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,42 @@
"""
RegisterWorkerAperturesMismatch

Worker for aperture-based (blocked) image registration using mismatch polynomials.

Divides the image domain into a grid of apertures and estimates a local shift for
each aperture by fitting a quadratic to the mismatch array. Implements the
[`AbstractWorker`](@ref RegisterWorkerShell.AbstractWorker) interface from
`RegisterWorkerShell`; the primary entry points are [`AperturesMismatch`](@ref)
(constructor) and [`worker`](@ref) (single-frame registration).

CUDA-accelerated computation is supported when `dev` is set to a device index.
"""
module RegisterWorkerAperturesMismatch

using ImageCore, CoordinateTransformations, Interpolations, StaticArrays, SharedArrays
using RegisterCore, RegisterDeformation, RegisterFit, RegisterPenalty, RegisterOptimize
using RegisterMismatch, RegisterMismatchCommon
using CoordinateTransformations: CoordinateTransformations
using ImageCore: ImageCore, coords_spatial, nimages
using Interpolations: Interpolations
using RegisterCore: RegisterCore, NumDenom, maxshift
using RegisterDeformation: RegisterDeformation
using RegisterFit: RegisterFit, qfit
using RegisterMismatch: RegisterMismatch, CMStorage, mismatch_apertures!
using RegisterMismatchCommon: RegisterMismatchCommon, allocate_mmarrays, aperture_grid,
correctbias, correctbias!, default_aperture_width, mismatch_apertures
using RegisterOptimize: RegisterOptimize
using RegisterPenalty: RegisterPenalty, interpolate_mm!
using RegisterWorkerShell: RegisterWorkerShell, AbstractWorker, ArrayDecl, getindex_t,
monitor, monitor!
using SharedArrays: SharedArrays, sdata
using StaticArrays: StaticArrays, SMatrix, SVector, Size, similar_type
# Note: RegisterMismatchCuda is loaded dynamically below when dev !== nothing
using RegisterWorkerShell # , RegisterDriver

import RegisterWorkerShell: worker, init!, close!, load_mm_package, workertid

export AperturesMismatch, monitor, monitor!, worker

mutable struct AperturesMismatch{A <: AbstractArray, T, K, N} <: AbstractWorker
struct AperturesMismatch{A <: AbstractArray, T, N} <: AbstractWorker
fixed::A
nodes::NTuple{N, K}
nodes::NTuple{N}
maxshift::NTuple{N, Int}
thresh::T
preprocess # likely of type PreprocessSNF, but could be a function
Expand Down Expand Up @@ -75,45 +99,70 @@ function close!(algorithm::AperturesMismatch)
end

"""

`alg = AperturesMismatch(fixed, nodes, maxshift, [preprocess=identity];
kwargs...)` creates a worker-object for performing "apertured"
(blocked) registration. `fixed` is the reference image, `nodes`
specifies the grid of apertures, `maxshift` represents the largest
shift (in pixels) that will be evaluated, and `preprocess` allows you
to apply a transformation (e.g., filtering) to the `moving` images
before registration; `fixed` should already have any such
transformations applied.

Registration is performed by calling `driver`. You should monitor
`Es`, `cs`, `Qs`, and `mmis`.

## Example

Suppose your images are somewhat noisy, in which case a bit of
smoothing might help considerably. Here we'll illustrate the use of a
pre-processing function, but see also `PreprocessSNF`.

AperturesMismatch(fixed, nodes, maxshift, preprocess=identity;
normalization=:pixels, thresh_fac=0.5^ndims(fixed),
thresh=nothing, correctbias=true, tid=1, dev=nothing)

Create a worker object for aperture-based (blocked) image registration.

`fixed` is the reference image. `nodes` is an `N`-tuple of ranges or
vectors specifying the aperture grid along each spatial dimension.
`maxshift` is an `N`-tuple of integers giving the maximum shift (in
pixels) to evaluate along each dimension. `preprocess` is an optional
function applied to each `moving` image before registration; `fixed`
should already have the same transformation applied.

# Keyword arguments

- `normalization`: `:pixels` (default) normalizes mismatch by the number
of pixels in each aperture; `:intensity` normalizes by image intensity.
- `thresh_fac`: sets the default threshold as `thresh_fac / prod(gridsize)`
times the image norm. Ignored when `thresh` is supplied explicitly.
- `thresh`: minimum mismatch energy required to fit a quadratic. Apertures
below threshold are skipped. Defaults to a value derived from `thresh_fac`.
- `correctbias`: if `true` (default), applies bias correction to the
mismatch arrays before fitting.
- `tid`: worker thread id (default `1`).
- `dev`: CUDA device index (`Int`). If `nothing` (default), runs on CPU.

The returned object is an `AbstractWorker` subtype. Use [`monitor`](@ref)
to create a monitoring dict, then [`worker`](@ref) (or `driver`) to run
registration. The key monitored quantities are:
- `:Es` — per-aperture mismatch energy
- `:cs` — per-aperture shift estimates
- `:Qs` — per-aperture quadratic curvature matrices
- `:mmis` — interpolated mismatch arrays (optional, expensive to store)

# Examples

Basic usage with a 4×4 aperture grid:

```julia
fixed = rand(Float32, 64, 64)
nodes = (range(1, 64, length=4), range(1, 64, length=4))
maxshift = (5, 5)
alg = AperturesMismatch(fixed, nodes, maxshift)
mon = monitor(alg, (:Es, :cs, :Qs))
moving = rand(Float32, 64, 64)
mon = worker(alg, moving, 1, mon)
size(mon[:Es]) # (4, 4)
```
# Raw images are fixed0 and moving0, both two-dimensional
pp = img -> imfilter_gaussian(img, [3, 3])
fixed = pp(fixed0)
# We'll use a 5x7 grid of apertures
nodes = (linspace(1, size(fixed,1), 5), linspace(1, size(fixed,2), 7))
# Allow shifts of up to 30 pixels in any direction
maxshift = (30,30)

# Create the algorithm-object
alg = AperturesMismatch(fixed, nodes, maxshift, pp)

mon = monitor(alg, (:Es, :cs, :Qs, :mmis))

# Run the algorithm
mon = driver(algorithm, moving0, mon)
With a preprocessing function applied to both `fixed` and each `moving` image:

```julia
fixed0 = rand(Float32, 64, 64)
pp = img -> img ./ (maximum(img) + eps(Float32))
fixed = pp(fixed0)
nodes = (range(1, 64, length=5), range(1, 64, length=7))
alg = AperturesMismatch(fixed, nodes, (10, 10), pp)
mon = monitor(alg, (:Es, :cs, :Qs))
moving0 = rand(Float32, 64, 64)
mon = worker(alg, moving0, 1, mon)
size(mon[:cs]) # (5, 7)
```

"""
function AperturesMismatch(fixed, nodes::NTuple{N, K}, maxshift::NTuple{N, <:Integer}, preprocess = identity; normalization = :pixels, thresh_fac = (0.5)^ndims(fixed), thresh = nothing, correctbias::Bool = true, tid = 1, dev = nothing) where {K, N}
function AperturesMismatch(fixed, nodes::NTuple{N}, maxshift::NTuple{N, <:Integer}, preprocess = identity; normalization = :pixels, thresh_fac = (0.5)^ndims(fixed), thresh = nothing, correctbias::Bool = true, tid = 1, dev = nothing) where {N}
gridsize = map(length, nodes)
nimages(fixed) == 1 || error("Register to a single image")
if isnothing(thresh)
Expand All @@ -126,9 +175,25 @@ function AperturesMismatch(fixed, nodes::NTuple{N, K}, maxshift::NTuple{N, <:Int
Qs = ArrayDecl(Array{similar_type(SMatrix, T, Size(N, N)), N}, gridsize)
mmsize = map(x -> 2x + 1, maxshift)
mmis = ArrayDecl(Array{NumDenom{T}, 2 * N}, (mmsize..., gridsize...))
return AperturesMismatch{typeof(fixed), T, K, N}(fixed, nodes, maxshift, T(thresh), preprocess, normalization, correctbias, Es, cs, Qs, mmis, tid, dev, Dict{Symbol, Any}())
return AperturesMismatch{typeof(fixed), T, N}(fixed, nodes, maxshift, T(thresh), preprocess, normalization, correctbias, Es, cs, Qs, mmis, tid, dev, Dict{Symbol, Any}())
end

"""
worker(algorithm::AperturesMismatch, img, tindex, mon) -> mon

Perform aperture-based mismatch registration for a single image frame.

`img` is the source of moving images. If `img` has a time axis, `tindex`
selects the frame; otherwise `img` itself is used directly. `mon` is the
monitoring dict returned by `monitor`; any keys present in `mon` (`:Es`,
`:cs`, `:Qs`, `:mmis`) are updated with the results of this call.

Returns `mon` with updated fields:
- `:Es` — per-aperture mismatch energy (scalar per aperture)
- `:cs` — per-aperture shift estimates (`SVector{N}` per aperture)
- `:Qs` — per-aperture quadratic curvature matrices (`SMatrix{N,N}` per aperture)
- `:mmis` — interpolated mismatch arrays (present only if `:mmis` key was in `mon`)
"""
function worker(algorithm::AperturesMismatch, img, tindex, mon)
moving0 = getindex_t(img, tindex)
moving = algorithm.preprocess(moving0)
Expand Down
5 changes: 3 additions & 2 deletions test/aperturedmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@
fn_cuda = joinpath(workdir, "apertured_cuda.jld")
algorithm = AperturesMismatch(pp(fixed), nodes, maxshift, pp; dev = 0)
prepare_mm_package(algorithm)
mons = monitor(algorithms, (:Es, :cs, :Qs, :mmis))
driver(fn_cuda, algorithms, img, mons)
mons = monitor([algorithm], (:Es, :cs, :Qs, :mmis))
# RegisterMismatchCuda performs some scalar GPU indexing; allow it explicitly
CUDA.@allowscalar driver(fn_cuda, [algorithm], img, mons)
end

fns = [fn, fn_pp]
Expand Down
18 changes: 17 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,25 @@
using SharedArrays, JLD, Test
using SharedArrays, JLD, Test, CUDA
using ImageCore, ImageAxes, TestImages, StaticArrays, PaddedViews
using AxisArrays: AxisArray
using RegisterCore, RegisterOptimize, RegisterDeformation, RegisterPenalty
using RegisterMismatchCommon
using RegisterWorkerAperturesMismatch, RegisterDriver
using Aqua
using ExplicitImports

@testset "Aqua" begin
Aqua.test_all(RegisterWorkerAperturesMismatch;
stale_deps=(; ignore=[:CUDA, :RegisterMismatchCuda]),
deps_compat=(; check_extras=false),
piracies=(; treat_as_own=[RegisterWorkerAperturesMismatch.load_mm_package]),
# AtomixCUDAExt declares __precompile__(false), which is disallowed during
# precompilation on Julia 1.10, causing a spurious persistent_tasks failure.
persistent_tasks=VERSION >= v"1.11")
end

@testset "ExplicitImports" begin
test_explicit_imports(RegisterWorkerAperturesMismatch)
end

nt = 3 # number of time points
wtids = threadids()
Expand Down
Loading