
Commit 4ffea82

updates
1 parent 17a4832 commit 4ffea82

2 files changed: +287 -140 lines changed

convert_to_ngff.py

Lines changed: 99 additions & 140 deletions
@@ -9,6 +9,7 @@
 from dask import array as da, bag, delayed
 from dask.highlevelgraph import HighLevelGraph
 import dask
+from dask_image import ndfilters
 from dask_cuda import LocalCUDACluster
 from rmm.allocators.cupy import rmm_cupy_allocator
 import rmm
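
Note: the newly imported `dask_image.ndfilters` supplies the blocked n-dimensional filters referenced later in this diff (`threshold_local`, `gaussian_filter`, `uniform_filter`). A minimal, self-contained usage sketch, with a toy array shape chosen here for illustration rather than taken from this commit:

```python
import dask.array as da
from dask_image import ndfilters

# Toy 5D (t, c, z, y, x) volume, chunked per timepoint and channel
img = da.random.random((2, 2, 16, 64, 64), chunks=(1, 1, 16, 64, 64))

smoothed = ndfilters.gaussian_filter(img, sigma=(0, 0, 1, 1, 1))
local_mean = ndfilters.uniform_filter(img, size=(1, 1, 5, 9, 9))
mask = img > local_mean      # local-mean threshold, still lazy
result = mask.compute()      # materialize the boolean mask
```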
@@ -17,6 +18,8 @@
 from pathlib import Path
 import glob, zarr
 from zarr_parallel_processing.multiscales import Multimeta
+from zarr_parallel_processing import utils
+
 from typing import Callable, Any
 from collections import defaultdict

@@ -95,7 +98,8 @@ def write_single_region(region: da.Array,
 def write_regions_sequential(
                              image_regions: tuple,
                              region_slices: tuple,
-                             dataset: zarr.Array
+                             dataset: zarr.Array,
+                             **kwargs
                              ):
     executor = get_reusable_executor(max_workers=n_jobs,
                                      kill_workers=True,
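
For context, `write_regions_sequential` is built on loky's reusable process pool, and the `**kwargs` added here lets callers pass extra keyword arguments (such as `client`) without breaking the signature. A minimal sketch of the underlying executor pattern, using a hypothetical `write_one` helper rather than the module's own region writer:

```python
import numpy as np
import zarr
from loky import get_reusable_executor

def write_one(dataset, region, region_slice):
    # Write one in-memory region into its slice of the zarr array.
    dataset[region_slice] = region

if __name__ == '__main__':
    data = np.random.rand(64, 64)
    dataset = zarr.open('demo.zarr', mode='w', shape=data.shape,
                        chunks=(32, 32), dtype=data.dtype)
    slices = [np.s_[:32, :], np.s_[32:, :]]
    regions = [data[sl] for sl in slices]
    executor = get_reusable_executor(max_workers=2, kill_workers=True)
    futures = [executor.submit(write_one, dataset, reg, sl)
               for reg, sl in zip(regions, slices)]
    for fut in futures:
        fut.result()  # re-raise any worker-side errors
```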
@@ -131,109 +135,22 @@ def write_regions(
                   )
     return dataset

-def deconvolve_block(img, psf=None, iterations=20):
-    # Pad PSF with zeros to match image shape
-    pad_l, pad_r = np.divmod(np.array(img.shape) - np.array(psf.shape), 2)
-    pad_r += pad_l
-    psf = np.pad(psf, tuple(zip(pad_l, pad_r)), 'constant', constant_values=0)
-    # Recenter PSF at the origin
-    # Needed to ensure PSF doesn't introduce an offset when
-    # convolving with image
-    for i in range(psf.ndim):
-        psf = np.roll(psf, psf.shape[i] // 2, axis=i)
-    # Convolution requires FFT of the PSF
-    psf = np.fft.rfftn(psf)
-    # Perform deconvolution in-place on a copy of the image
-    # (avoids changing the original)
-    img_decon = np.copy(img)
-    for _ in range(iterations):
-        ratio = img / np.fft.irfftn(np.fft.rfftn(img_decon) * psf)
-        img_decon *= np.fft.irfftn((np.fft.rfftn(ratio).conj() * psf).conj())
-    return img_decon
-
-
-import numpy as np
-
-
-def gaussian_psf(shape, mean, cov):
-    """
-    Computes an n-dimensional Gaussian function over a grid defined by the given shape.
-
-    Parameters:
-        shape (tuple of int): Shape of the n-dimensional grid (e.g., (height, width, depth)).
-        mean (float or list-like): Scalar or array-like representing the mean of the Gaussian.
-            If scalar, it will be applied to all dimensions.
-        cov (float or list-like): Scalar, 1D array, or 2D array representing the covariance.
-            - If scalar, creates an isotropic Gaussian.
-            - If 1D, creates a diagonal covariance matrix.
-            - If 2D, used directly as the covariance matrix.
-
-    Returns:
-        np.ndarray: An n-dimensional array containing the Gaussian function values.
-    """
-    n = len(shape)
-    if np.isscalar(mean):
-        mean = np.full(n, mean)
-    else:
-        mean = np.asarray(mean)
-        if mean.shape[0] != n:
-            raise ValueError(f"Mean must match the number of dimensions ({n}).")
-    if np.isscalar(cov):
-        cov = np.eye(n) * cov
-    elif np.ndim(cov) == 1:
-        if len(cov) != n:
-            raise ValueError(f"Covariance vector length must match the number of dimensions ({n}).")
-        cov = np.diag(cov)
-    elif np.ndim(cov) == 2:
-        cov = np.asarray(cov)
-        if cov.shape != (n, n):
-            raise ValueError(f"Covariance matrix must be ({n}, {n}).")
-    else:
-        raise ValueError("Covariance must be a scalar, 1D array, or 2D matrix.")
-    grids = np.meshgrid(*[np.arange(s) for s in shape], indexing='ij')
-    coords = np.stack(grids, axis=-1)  # Shape: (*shape, n)
-    flat_coords = coords.reshape(-1, n)
-    det_cov = np.linalg.det(cov)
-    inv_cov = np.linalg.inv(cov)
-    if det_cov <= 0:
-        raise ValueError("Covariance matrix must be positive definite.")
-    norm_factor = 1 / (np.sqrt((2 * np.pi) ** n * det_cov))
-    diff = flat_coords - mean
-    exponent = -0.5 * np.sum(diff @ inv_cov * diff, axis=1)
-    gaussian_values = norm_factor * np.exp(exponent)
-    return gaussian_values.reshape(shape)
-
-
-
-def richardson_lucy(img: da.Array,
-                    psf: da.Array,
-                    iterations: int = 20,
-                    backend: str = 'cupy'
-                    ):
-    if backend == 'cupy':
-        img = img.map_blocks(cp.asarray)
-        psf = psf.map_blocks(cp.asarray)
-    deconvolved = img.map_overlap(
-        deconvolve_block,
-        psf = psf,
-        iterations = iterations,
-        meta = img._meta,
-        depth = tuple(np.array(psf.shape) // 2),
-        boundary = "periodic"
-    )
-    if backend == 'cupy':
-        deconvolved = deconvolved.map_blocks(cp.asnumpy)
-    return deconvolved



-def to_ngff(arr: da.Array,
-            output_path: str | Path,
-            region_shape: tuple = None,
-            scale: tuple = None,
-            units: tuple = None,
-            client: Client = None
-            ) -> zarr.Group:
+# def threshold_local(img: da.Array)
+
+
+def process_and_save_to_ngff(arr: da.Array,
+                             output_path: str | Path,
+                             region_shape: tuple = None,
+                             scale: tuple = None,
+                             units: tuple = None,
+                             client: Client = None,
+                             parallelize_over_regions = True,
+                             func: Callable = utils.otsu,
+                             **func_params
+                             ) -> zarr.Group:

     region_slices = get_regions(arr.shape, region_shape, as_slices = True)

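For reference, the `deconvolve_block` removed above implemented the standard Richardson-Lucy multiplicative update, with the correlation against the PSF carried out via conjugation in the Fourier domain:

$$\hat{x}^{(k+1)} = \hat{x}^{(k)} \cdot \left[\left(\frac{y}{\hat{x}^{(k)} \ast h}\right) \star h\right]$$

where $y$ is the observed image, $h$ the PSF, $\ast$ convolution, and $\star$ correlation. Likewise, the removed `gaussian_psf` evaluated the multivariate normal density on the voxel grid:

$$f(\mathbf{x}) = \frac{1}{\sqrt{(2\pi)^{n}\,\lvert\Sigma\rvert}}\,\exp\!\left(-\tfrac{1}{2}(\mathbf{x}-\boldsymbol{\mu})^{\top}\Sigma^{-1}(\mathbf{x}-\boldsymbol{\mu})\right)$$
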
@@ -256,64 +173,106 @@ def to_ngff(arr: da.Array,
     meta.to_ngff(gr)

     image_regions = [arr[region_slice] for region_slice in region_slices]
+    # processed_regions = image_regions
+    processed_regions = [func(reg, **func_params) for reg in image_regions]
+
     if client is not None:
         client.scatter(region_slices)
         client.scatter(image_regions)

-    write_regions(image_regions = image_regions,
-                  region_slices = region_slices,
-                  dataset = dataset,
-                  client = client)
+    if not parallelize_over_regions:
+        write_regions(image_regions = processed_regions,
+                      region_slices = region_slices,
+                      dataset = dataset,
+                      client = client)
+    else:
+        write_regions_sequential(image_regions = processed_regions,
+                                 region_slices = region_slices,
+                                 dataset = dataset,
+                                 client = client)
     return gr



-if __name__ == '__main__':
+# if __name__ == '__main__':

-    chunks = (1, 1, 96, 128, 128)
-    region_shape = (128, 2, 96, 128, 128)
-    scale = (600, 1, 2, 0.406, 0.406)
-    units = ('s', 'Channel', 'µm', 'µm', 'µm')
-    psf = gaussian_psf((1, 1, 12, 16, 16), (1, 1, 6, 8, 8), (1, 1, 12, 16, 16))
-    psf = da.from_array(psf, chunks = chunks)
+chunks = (1, 1, 48, 128, 128)
+region_shape = (1, 1, 91, 554, 928)
+scale = (600, 1, 2, 0.406, 0.406)
+units = ('s', 'Channel', 'µm', 'µm', 'µm')
+# psf = gaussian_psf((1, 1, 12, 16, 16), (1, 1, 6, 8, 8), (1, 1, 12, 16, 16))
+# psf = da.from_array(psf, chunks = chunks)

-    n_jobs = 4
-    threads_per_worker = 1
-    memory_limit = '3GB'
+block_size = (1, 1, 5, 9, 9)

-    input_tiff_path_mg = f"/home/oezdemir/data/original/franziska/crop/mG_View1/*"
-    input_tiff_path_h2b = f"/home/oezdemir/data/original/franziska/crop/H2B_View1/*"
+n_jobs = 4
+threads_per_worker = 2
+memory_limit = '8GB'

-    output_zarr_path = f"/home/oezdemir/data/original/franziska/concat.zarr"
+input_tiff_path_mg = f"/home/oezdemir/data/original/franziska/crop/mG_View1/*"
+input_tiff_path_h2b = f"/home/oezdemir/data/original/franziska/crop/H2B_View1/*"

-    t0 = time.time()
+output_zarr_path = f"/home/oezdemir/data/original/franziska/concat.zarr"

-    paths_mg = sorted(glob.glob(input_tiff_path_mg))
-    paths_h2b = sorted(glob.glob(input_tiff_path_h2b))
+t0 = time.time()

-    with LocalCluster(n_workers=n_jobs, threads_per_worker=threads_per_worker, memory_limit=memory_limit) as cluster:
-        cluster.scale(n_jobs)
-        with Client(cluster) as client:
+paths_mg = sorted(glob.glob(input_tiff_path_mg))
+paths_h2b = sorted(glob.glob(input_tiff_path_h2b))

-            ### Read image collections
-            imgs_mg = [read_image(path) for path in paths_mg]
-            imgs_h2b = [read_image(path) for path in paths_h2b]

-            ### Concatenate collections into a single dask array
-            mg_merged = da.concatenate(imgs_mg, axis = 0)  # concatenate along the time dimension
-            h2b_merged = da.concatenate(imgs_h2b, axis = 0)  # concatenate along the time dimension
-            imgs_merged = da.concatenate((mg_merged, h2b_merged), axis = 1)  # concatenate along the channel dimension

-            ### Process merged images
+# imgs_mg = [read_image(path) for path in paths_mg]
+# imgs_h2b = [read_image(path) for path in paths_h2b]
+#
+# ### Concatenate collections into a single dask array
+# mg_merged = da.concatenate(imgs_mg, axis=0)  # concatenate along the time dimension
+# h2b_merged = da.concatenate(imgs_h2b, axis=0)  # concatenate along the time dimension
+# imgs_merged = da.concatenate((mg_merged, h2b_merged), axis=1)  # concatenate along the channel dimension

-            ###
-            to_ngff(imgs_merged,
-                    output_path = output_zarr_path,
-                    region_shape = region_shape,
-                    scale = scale,
-                    units = units,
-                    client = client
-                    )
+# processed_img = da.concatenate([otsu(img, return_thresholded=True) for img in imgs_merged], axis=0)
+
+
+with LocalCluster(processes=True,
+                  nanny=True,
+                  n_workers=n_jobs,
+                  threads_per_worker=threads_per_worker,
+                  memory_limit=memory_limit) as cluster:
+    cluster.scale(n_jobs)
+    with Client(cluster,
+                heartbeat_interval="120s",
+                timeout="600s",
+                ) as client:
+
+        ### Read image collections
+        imgs_mg = [read_image(path) for path in paths_mg]
+        imgs_h2b = [read_image(path) for path in paths_h2b]
+
+        ### Concatenate collections into a single dask array
+        mg_merged = da.concatenate(imgs_mg, axis = 0)  # concatenate along the time dimension
+        h2b_merged = da.concatenate(imgs_h2b, axis = 0)  # concatenate along the time dimension
+        imgs_merged = da.concatenate((mg_merged, h2b_merged), axis = 1)  # concatenate along the channel dimension
+
+        ### Process merged images
+        processed_img = imgs_merged
+        # processed_img = ndfilters.threshold_local(imgs_merged, block_size=block_size, method='mean')
+        # processed_img = ndfilters.gaussian_filter(imgs_merged, sigma = (0.4, 0.4, 1, 1, 1))
+        # filtered = ndfilters.uniform_filter(imgs_merged, size = block_size)
+        # processed_img = imgs_merged > filtered
+        # processed_mg = da.concatenate([utils.mean_threshold(img, return_thresholded=True) for img in imgs_mg], axis = 0)
+        # processed_h2b = da.concatenate([utils.mean_threshold(img, return_thresholded=True) for img in imgs_h2b], axis = 0)
+        # processed_mg = da.concatenate([utils.otsu(img, bincount = 9, return_thresholded=True) for img in imgs_mg], axis = 0)
+        # processed_h2b = da.concatenate([utils.otsu(img, bincount = 9, return_thresholded=True) for img in imgs_h2b], axis = 0)
+        # processed_img = da.concatenate((processed_mg, processed_h2b), axis = 1)  # concatenate along the channel dimension
+
+        process_and_save_to_ngff(processed_img,
+                                 output_path = output_zarr_path,
+                                 region_shape = region_shape,
+                                 scale = scale,
+                                 units = units,
+                                 client = client,
+                                 parallelize_over_regions=False,
+                                 func = utils.otsu,
+                                 )


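The new default `func = utils.otsu` lives in `zarr_parallel_processing.utils`, which is not shown in this diff. Purely as an illustration of what a global Otsu threshold over a dask array can look like, mirroring the `bincount` and `return_thresholded` parameters from the commented-out calls (an assumption about the helper's behavior, not the repository's actual implementation):

```python
import numpy as np
import dask.array as da

def otsu(img: da.Array, bincount: int = 256, return_thresholded: bool = False):
    # Global histogram of the lazy dask array.
    lo, hi = da.compute(img.min(), img.max())
    hist, edges = da.histogram(img, bins=bincount, range=[float(lo), float(hi)])
    hist = hist.compute().astype(float)
    centers = (edges[:-1] + edges[1:]) / 2
    # Classic Otsu: pick the split that maximizes between-class variance.
    w1 = np.cumsum(hist)
    w2 = np.cumsum(hist[::-1])[::-1]
    m1 = np.cumsum(hist * centers) / np.maximum(w1, 1e-12)
    m2 = (np.cumsum((hist * centers)[::-1]) / np.maximum(w2[::-1], 1e-12))[::-1]
    between = w1[:-1] * w2[1:] * (m1[:-1] - m2[1:]) ** 2
    threshold = centers[:-1][np.argmax(between)]
    return (img > threshold) if return_thresholded else threshold
```

With `return_thresholded=True` the returned mask is still lazy, so the thresholding composes with the region-wise writes above and nothing is materialized until the zarr dataset is written.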