Skip to content

Commit c67ac46

Browse files
add conversion module
1 parent 477ccc8 commit c67ac46

File tree

1 file changed

+322
-0
lines changed

1 file changed

+322
-0
lines changed

convert_to_ngff.py

Lines changed: 322 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,322 @@
1+
import os, dataclasses, numcodecs, abc, time, dask
2+
from aicsimageio import AICSImage
3+
from aicsimageio.metadata.utils import OME
4+
import numpy as np, cupy as cp
5+
from pathlib import Path
6+
import zarr
7+
from typing import (Union, Iterable)
8+
import warnings
9+
from dask import array as da, bag, delayed
10+
from dask.highlevelgraph import HighLevelGraph
11+
import dask
12+
from dask_cuda import LocalCUDACluster
13+
from rmm.allocators.cupy import rmm_cupy_allocator
14+
import rmm
15+
16+
import itertools
17+
from pathlib import Path
18+
import glob, zarr
19+
from ome_zarr_io.ngff.multiscales import Pyramid, Multimeta
20+
from ome_zarr_io.base.readers import ImageReader
21+
from typing import Callable, Any
22+
from collections import defaultdict
23+
24+
from distributed import LocalCluster, Client
25+
from joblib import delayed as jdel, Parallel, parallel_config
26+
from joblib.externals.loky import get_reusable_executor
27+
get_reusable_executor().shutdown()
28+
29+
30+
def get_regions(array_shape,
31+
region_shape,
32+
as_slices = False
33+
):
34+
assert len(array_shape) == len(region_shape)
35+
steps = []
36+
for i in range(len(region_shape)):
37+
size = array_shape[i]
38+
inc = region_shape[i]
39+
seq = np.arange(0, size, inc)
40+
if size > seq[-1]:
41+
seq = np.append(seq, size)
42+
increments = tuple([(seq[i], seq[i+1]) for i in range(len(seq) - 1)])
43+
tuples = tuple(tuple(item) for item in increments)
44+
if as_slices:
45+
slcs = tuple([slice(*item) for item in tuples])
46+
steps.append(slcs)
47+
else:
48+
steps.append(tuples)
49+
out = list(itertools.product(*steps))
50+
return out
51+
52+
def read_image(file_path: Path | str):
53+
img = AICSImage(file_path)
54+
return img.get_image_dask_data()
55+
56+
def create_zarr_array(directory: Path | str | zarr.Group,
57+
array_name: str,
58+
shape: tuple,
59+
chunks: tuple,
60+
dtype: Any,
61+
overwrite: bool = False,
62+
) -> zarr.Array:
63+
chunks = tuple(np.minimum(shape, chunks))
64+
if not isinstance(directory, zarr.Group):
65+
path = os.path.join(directory, array_name)
66+
dataset = zarr.create(shape=shape,
67+
chunks=chunks,
68+
dtype=dtype,
69+
store=path,
70+
dimension_separator='/',
71+
overwrite=overwrite
72+
)
73+
else:
74+
_ = directory.create(name = array_name,
75+
shape = shape,
76+
chunks = chunks,
77+
dtype = dtype,
78+
dimension_separator='/',
79+
overwrite=overwrite
80+
)
81+
dataset = directory[array_name]
82+
return dataset
83+
84+
def write_single_region(region: da.Array,
85+
dataset: Path | str | zarr.Array,
86+
region_slice: slice = None
87+
):
88+
da.to_zarr(region,
89+
url = dataset,
90+
region = region_slice,
91+
compute=True,
92+
return_stored=True
93+
)
94+
return dataset
95+
96+
def write_regions_sequential(
97+
image_regions: tuple,
98+
region_slices: tuple,
99+
dataset: zarr.Array
100+
):
101+
executor = get_reusable_executor(max_workers=n_jobs,
102+
kill_workers=True,
103+
context='loky')
104+
for region_slice, image_region in zip(region_slices, image_regions):
105+
executor.submit(write_single_region,
106+
region=image_region,
107+
dataset=dataset,
108+
region_slice=region_slice
109+
)
110+
return dataset
111+
112+
def write_regions(
113+
image_regions: tuple,
114+
region_slices: tuple,
115+
dataset: zarr.Array,
116+
client: Client = None
117+
) -> zarr.Array:
118+
if client is None:
119+
n_jobs = 4
120+
else:
121+
n_jobs = client.cluster.workers.__len__()
122+
client.cluster.scale(n_jobs)
123+
client.scatter(image_regions)
124+
client.scatter(region_slices)
125+
with parallel_config(backend = 'loky', n_jobs = n_jobs):
126+
with Parallel() as parallel:
127+
parallel(jdel(write_single_region)(region = image_region,
128+
region_slice = region_slice,
129+
dataset = dataset)
130+
for image_region, region_slice in
131+
zip(image_regions, region_slices)
132+
)
133+
return dataset
134+
135+
def deconvolve_block(img, psf=None, iterations=20):
136+
# Pad PSF with zeros to match image shape
137+
pad_l, pad_r = np.divmod(np.array(img.shape) - np.array(psf.shape), 2)
138+
pad_r += pad_l
139+
psf = np.pad(psf, tuple(zip(pad_l, pad_r)), 'constant', constant_values=0)
140+
# Recenter PSF at the origin
141+
# Needed to ensure PSF doesn't introduce an offset when
142+
# convolving with image
143+
for i in range(psf.ndim):
144+
psf = np.roll(psf, psf.shape[i] // 2, axis=i)
145+
# Convolution requires FFT of the PSF
146+
psf = np.fft.rfftn(psf)
147+
# Perform deconvolution in-place on a copy of the image
148+
# (avoids changing the original)
149+
img_decon = np.copy(img)
150+
for _ in range(iterations):
151+
ratio = img / np.fft.irfftn(np.fft.rfftn(img_decon) * psf)
152+
img_decon *= np.fft.irfftn((np.fft.rfftn(ratio).conj() * psf).conj())
153+
return img_decon
154+
155+
156+
import numpy as np
157+
158+
159+
def gaussian_psf(shape, mean, cov):
160+
"""
161+
Computes an n-dimensional Gaussian function over a grid defined by the given shape.
162+
163+
Parameters:
164+
shape (tuple of int): Shape of the n-dimensional grid (e.g., (height, width, depth)).
165+
mean (float or list-like): Scalar or array-like representing the mean of the Gaussian.
166+
If scalar, it will be applied to all dimensions.
167+
cov (float or list-like): Scalar, 1D array, or 2D array representing the covariance.
168+
- If scalar, creates an isotropic Gaussian.
169+
- If 1D, creates a diagonal covariance matrix.
170+
- If 2D, used directly as the covariance matrix.
171+
172+
Returns:
173+
np.ndarray: An n-dimensional array containing the Gaussian function values.
174+
"""
175+
n = len(shape)
176+
if np.isscalar(mean):
177+
mean = np.full(n, mean)
178+
else:
179+
mean = np.asarray(mean)
180+
if mean.shape[0] != n:
181+
raise ValueError(f"Mean must match the number of dimensions ({n}).")
182+
if np.isscalar(cov):
183+
cov = np.eye(n) * cov
184+
elif np.ndim(cov) == 1:
185+
if len(cov) != n:
186+
raise ValueError(f"Covariance vector length must match the number of dimensions ({n}).")
187+
cov = np.diag(cov)
188+
elif np.ndim(cov) == 2:
189+
cov = np.asarray(cov)
190+
if cov.shape != (n, n):
191+
raise ValueError(f"Covariance matrix must be ({n}, {n}).")
192+
else:
193+
raise ValueError("Covariance must be a scalar, 1D array, or 2D matrix.")
194+
grids = np.meshgrid(*[np.arange(s) for s in shape], indexing='ij')
195+
coords = np.stack(grids, axis=-1) # Shape: (*shape, n)
196+
flat_coords = coords.reshape(-1, n)
197+
det_cov = np.linalg.det(cov)
198+
inv_cov = np.linalg.inv(cov)
199+
if det_cov <= 0:
200+
raise ValueError("Covariance matrix must be positive definite.")
201+
norm_factor = 1 / (np.sqrt((2 * np.pi) ** n * det_cov))
202+
diff = flat_coords - mean
203+
exponent = -0.5 * np.sum(diff @ inv_cov * diff, axis=1)
204+
gaussian_values = norm_factor * np.exp(exponent)
205+
return gaussian_values.reshape(shape)
206+
207+
208+
209+
def richardson_lucy(img: da.Array,
210+
psf: da.Array,
211+
iterations: int = 20,
212+
backend: str = 'cupy'
213+
):
214+
if backend == 'cupy':
215+
img = img.map_blocks(cp.asarray)
216+
psf = psf.map_blocks(cp.asarray)
217+
deconvolved = img.map_overlap(
218+
deconvolve_block,
219+
psf = psf,
220+
iterations = iterations,
221+
meta = img._meta,
222+
depth = tuple(np.array(psf.shape) // 2),
223+
boundary = "periodic"
224+
)
225+
if backend == 'cupy':
226+
deconvolved = deconvolved.map_blocks(cp.asnumpy)
227+
return deconvolved
228+
229+
230+
231+
def to_ngff(arr: da.Array,
232+
output_path: str | Path,
233+
region_shape: tuple = None,
234+
scale: tuple = None,
235+
units: tuple = None,
236+
client: Client = None
237+
) -> zarr.Group:
238+
239+
region_slices = get_regions(arr.shape, region_shape, as_slices = True)
240+
241+
gr = zarr.open_group(output_path, mode='a')
242+
dataset = create_zarr_array(gr,
243+
array_name = '0',
244+
shape = arr.shape,
245+
chunks = chunks,
246+
dtype = arr.dtype,
247+
overwrite = True
248+
)
249+
250+
meta = Multimeta()
251+
meta.parse_axes(axis_order='tczyx',
252+
unit_list = units
253+
)
254+
meta.add_dataset(path = '0',
255+
scale = scale
256+
)
257+
meta.to_ngff(gr)
258+
259+
image_regions = [arr[region_slice] for region_slice in region_slices]
260+
if client is not None:
261+
client.scatter(region_slices)
262+
client.scatter(image_regions)
263+
264+
write_regions(image_regions = image_regions,
265+
region_slices = region_slices,
266+
dataset = dataset,
267+
client = client)
268+
return gr
269+
270+
271+
272+
if __name__ == '__main__':
273+
274+
chunks = (1, 1, 96, 128, 128)
275+
region_shape = (128, 2, 96, 128, 128)
276+
scale = (600, 1, 2, 0.406, 0.406)
277+
units = ('s', 'Channel', 'µm', 'µm', 'µm')
278+
psf = gaussian_psf((1, 1, 12, 16, 16), (1, 1, 6, 8, 8), (1, 1, 12, 16, 16))
279+
psf = da.from_array(psf, chunks = chunks)
280+
281+
n_jobs = 4
282+
threads_per_worker = 1
283+
memory_limit = '3GB'
284+
285+
input_tiff_path_mg = f"/home/oezdemir/data/original/franziska/crop/mG_View1/*"
286+
input_tiff_path_h2b = f"/home/oezdemir/data/original/franziska/crop/H2B_View1/*"
287+
288+
output_zarr_path = f"/home/oezdemir/data/original/franziska/concat.zarr"
289+
290+
t0 = time.time()
291+
292+
paths_mg = sorted(glob.glob(input_tiff_path_mg))
293+
paths_h2b = sorted(glob.glob(input_tiff_path_h2b))
294+
295+
with LocalCluster(n_workers=n_jobs, threads_per_worker=threads_per_worker, memory_limit=memory_limit) as cluster:
296+
cluster.scale(n_jobs)
297+
with Client(cluster) as client:
298+
299+
### Read image collections
300+
imgs_mg = [read_image(path) for path in paths_mg]
301+
imgs_h2b = [read_image(path) for path in paths_h2b]
302+
303+
### Concatenate collections into a single dask array
304+
mg_merged = da.concatenate(imgs_mg, axis = 0) # concatenate along the time dimension
305+
h2b_merged = da.concatenate(imgs_h2b, axis = 0) # concatenate along the time dimension
306+
imgs_merged = da.concatenate((mg_merged, h2b_merged), axis = 1) # concatenate along the channel dimension
307+
308+
### Process merged images
309+
310+
###
311+
to_ngff(imgs_merged,
312+
output_path = output_zarr_path,
313+
region_shape = region_shape,
314+
scale = scale,
315+
units = units,
316+
client = client
317+
)
318+
319+
320+
321+
322+

0 commit comments

Comments
 (0)