from dask import array as da, bag, delayed
from dask.highlevelgraph import HighLevelGraph
import dask
+ from dask_image import ndfilters
from dask_cuda import LocalCUDACluster
from rmm.allocators.cupy import rmm_cupy_allocator
import rmm
from pathlib import Path
import glob, zarr
from zarr_parallel_processing.multiscales import Multimeta
+ from zarr_parallel_processing import utils
+
from typing import Callable, Any
from collections import defaultdict
@@ -95,7 +98,8 @@ def write_single_region(region: da.Array,
def write_regions_sequential(
    image_regions: tuple,
    region_slices: tuple,
-     dataset: zarr.Array
+     dataset: zarr.Array,
+     **kwargs
    ):
    executor = get_reusable_executor(max_workers=n_jobs,
                                     kill_workers=True,
@@ -131,109 +135,22 @@ def write_regions(
    )
    return dataset

- def deconvolve_block(img, psf=None, iterations=20):
-     # Pad PSF with zeros to match image shape
-     pad_l, pad_r = np.divmod(np.array(img.shape) - np.array(psf.shape), 2)
-     pad_r += pad_l
-     psf = np.pad(psf, tuple(zip(pad_l, pad_r)), 'constant', constant_values=0)
-     # Recenter PSF at the origin
-     # Needed to ensure PSF doesn't introduce an offset when
-     # convolving with image
-     for i in range(psf.ndim):
-         psf = np.roll(psf, psf.shape[i] // 2, axis=i)
-     # Convolution requires FFT of the PSF
-     psf = np.fft.rfftn(psf)
-     # Perform deconvolution in-place on a copy of the image
-     # (avoids changing the original)
-     img_decon = np.copy(img)
-     for _ in range(iterations):
-         ratio = img / np.fft.irfftn(np.fft.rfftn(img_decon) * psf)
-         img_decon *= np.fft.irfftn((np.fft.rfftn(ratio).conj() * psf).conj())
-     return img_decon
-
-
- import numpy as np
-
-
- def gaussian_psf(shape, mean, cov):
-     """
-     Computes an n-dimensional Gaussian function over a grid defined by the given shape.
-
-     Parameters:
-         shape (tuple of int): Shape of the n-dimensional grid (e.g., (height, width, depth)).
-         mean (float or list-like): Scalar or array-like representing the mean of the Gaussian.
-             If scalar, it will be applied to all dimensions.
-         cov (float or list-like): Scalar, 1D array, or 2D array representing the covariance.
-             - If scalar, creates an isotropic Gaussian.
-             - If 1D, creates a diagonal covariance matrix.
-             - If 2D, used directly as the covariance matrix.
-
-     Returns:
-         np.ndarray: An n-dimensional array containing the Gaussian function values.
-     """
-     n = len(shape)
-     if np.isscalar(mean):
-         mean = np.full(n, mean)
-     else:
-         mean = np.asarray(mean)
-         if mean.shape[0] != n:
-             raise ValueError(f"Mean must match the number of dimensions ({n}).")
-     if np.isscalar(cov):
-         cov = np.eye(n) * cov
-     elif np.ndim(cov) == 1:
-         if len(cov) != n:
-             raise ValueError(f"Covariance vector length must match the number of dimensions ({n}).")
-         cov = np.diag(cov)
-     elif np.ndim(cov) == 2:
-         cov = np.asarray(cov)
-         if cov.shape != (n, n):
-             raise ValueError(f"Covariance matrix must be ({n}, {n}).")
-     else:
-         raise ValueError("Covariance must be a scalar, 1D array, or 2D matrix.")
-     grids = np.meshgrid(*[np.arange(s) for s in shape], indexing='ij')
-     coords = np.stack(grids, axis=-1)  # Shape: (*shape, n)
-     flat_coords = coords.reshape(-1, n)
-     det_cov = np.linalg.det(cov)
-     inv_cov = np.linalg.inv(cov)
-     if det_cov <= 0:
-         raise ValueError("Covariance matrix must be positive definite.")
-     norm_factor = 1 / (np.sqrt((2 * np.pi) ** n * det_cov))
-     diff = flat_coords - mean
-     exponent = -0.5 * np.sum(diff @ inv_cov * diff, axis=1)
-     gaussian_values = norm_factor * np.exp(exponent)
-     return gaussian_values.reshape(shape)
-
-
-
- def richardson_lucy(img: da.Array,
-                     psf: da.Array,
-                     iterations: int = 20,
-                     backend: str = 'cupy'
-                     ):
-     if backend == 'cupy':
-         img = img.map_blocks(cp.asarray)
-         psf = psf.map_blocks(cp.asarray)
-     deconvolved = img.map_overlap(
-         deconvolve_block,
-         psf=psf,
-         iterations=iterations,
-         meta=img._meta,
-         depth=tuple(np.array(psf.shape) // 2),
-         boundary="periodic"
-     )
-     if backend == 'cupy':
-         deconvolved = deconvolved.map_blocks(cp.asnumpy)
-     return deconvolved


- def to_ngff(arr: da.Array,
-             output_path: str | Path,
-             region_shape: tuple = None,
-             scale: tuple = None,
-             units: tuple = None,
-             client: Client = None
-             ) -> zarr.Group:
+ # def threshold_local(img: da.Array)
+
+
+ def process_and_save_to_ngff(arr: da.Array,
+                              output_path: str | Path,
+                              region_shape: tuple = None,
+                              scale: tuple = None,
+                              units: tuple = None,
+                              client: Client = None,
+                              parallelize_over_regions=True,
+                              func: Callable = utils.otsu,
+                              **func_params
+                              ) -> zarr.Group:

    region_slices = get_regions(arr.shape, region_shape, as_slices=True)
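(For reference: get_regions itself is not part of this diff. The sketch below shows what such a region-slicing helper is assumed to do, namely tile arr.shape into region_shape-sized tuples of slices; the helper actually defined in this repository may differ.)

import itertools

def get_regions(shape, region_shape, as_slices=True):
    # One range of start offsets per axis, stepping by the region size.
    starts_per_axis = [range(0, dim, step) for dim, step in zip(shape, region_shape)]
    regions = []
    for starts in itertools.product(*starts_per_axis):
        # Clamp each slice to the array bounds so edge regions may be smaller.
        region = tuple(slice(start, min(start + step, dim))
                       for start, step, dim in zip(starts, region_shape, shape))
        regions.append(region)
    # With as_slices=False, return plain (start, stop) pairs instead of slices.
    return regions if as_slices else [tuple((s.start, s.stop) for s in r) for r in regions]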
@@ -256,64 +173,106 @@ def to_ngff(arr: da.Array,
    meta.to_ngff(gr)

    image_regions = [arr[region_slice] for region_slice in region_slices]
+     # processed_regions = image_regions
+     processed_regions = [func(reg, **func_params) for reg in image_regions]
+
    if client is not None:
        client.scatter(region_slices)
        client.scatter(image_regions)

-     write_regions(image_regions=image_regions,
-                   region_slices=region_slices,
-                   dataset=dataset,
-                   client=client)
+     if not parallelize_over_regions:
+         write_regions(image_regions=processed_regions,
+                       region_slices=region_slices,
+                       dataset=dataset,
+                       client=client)
+     else:
+         write_regions_sequential(image_regions=processed_regions,
+                                  region_slices=region_slices,
+                                  dataset=dataset,
+                                  client=client)
    return gr

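(The new default func, utils.otsu, is imported from zarr_parallel_processing.utils but not shown in this diff. Below is a minimal sketch of one plausible shape for it, matching the call sites used here — a bincount parameter and a return_thresholded flag yielding a lazy boolean mask; the actual implementation may differ.)

import numpy as np
from dask import array as da

def otsu(img: da.Array, bincount: int = 256, return_thresholded: bool = False):
    # Histogram of the lazy array; min/max are computed once up front.
    lo, hi = da.compute(img.min(), img.max())
    hist, edges = da.histogram(img, bins=bincount, range=[float(lo), float(hi)])
    hist = hist.compute().astype(np.float64)
    edges = np.asarray(edges)
    centers = 0.5 * (edges[:-1] + edges[1:])
    # Classic Otsu: maximise between-class variance over all candidate cuts.
    weight1 = np.cumsum(hist)
    weight2 = np.cumsum(hist[::-1])[::-1]
    mean1 = np.cumsum(hist * centers) / np.maximum(weight1, 1e-12)
    mean2 = (np.cumsum((hist * centers)[::-1]) / np.maximum(weight2[::-1], 1e-12))[::-1]
    between_var = weight1[:-1] * weight2[1:] * (mean1[:-1] - mean2[1:]) ** 2
    threshold = centers[:-1][np.argmax(between_var)]
    if return_thresholded:
        return img > threshold  # lazy boolean dask array
    return threshold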
- if __name__ == '__main__':
+ # if __name__ == '__main__':

-     chunks = (1, 1, 96, 128, 128)
-     region_shape = (128, 2, 96, 128, 128)
-     scale = (600, 1, 2, 0.406, 0.406)
-     units = ('s', 'Channel', 'µm', 'µm', 'µm')
-     psf = gaussian_psf((1, 1, 12, 16, 16), (1, 1, 6, 8, 8), (1, 1, 12, 16, 16))
-     psf = da.from_array(psf, chunks=chunks)
+ chunks = (1, 1, 48, 128, 128)
+ region_shape = (1, 1, 91, 554, 928)
+ scale = (600, 1, 2, 0.406, 0.406)
+ units = ('s', 'Channel', 'µm', 'µm', 'µm')
+ # psf = gaussian_psf((1, 1, 12, 16, 16), (1, 1, 6, 8, 8), (1, 1, 12, 16, 16))
+ # psf = da.from_array(psf, chunks = chunks)

-     n_jobs = 4
-     threads_per_worker = 1
-     memory_limit = '3GB'
+ block_size = (1, 1, 5, 9, 9)
-     input_tiff_path_mg = f"/home/oezdemir/data/original/franziska/crop/mG_View1/*"
-     input_tiff_path_h2b = f"/home/oezdemir/data/original/franziska/crop/H2B_View1/*"
+ n_jobs = 4
+ threads_per_worker = 2
+ memory_limit = '8GB'

-     output_zarr_path = f"/home/oezdemir/data/original/franziska/concat.zarr"
+ input_tiff_path_mg = f"/home/oezdemir/data/original/franziska/crop/mG_View1/*"
+ input_tiff_path_h2b = f"/home/oezdemir/data/original/franziska/crop/H2B_View1/*"

-     t0 = time.time()
+ output_zarr_path = f"/home/oezdemir/data/original/franziska/concat.zarr"

-     paths_mg = sorted(glob.glob(input_tiff_path_mg))
-     paths_h2b = sorted(glob.glob(input_tiff_path_h2b))
+ t0 = time.time()

-     with LocalCluster(n_workers=n_jobs, threads_per_worker=threads_per_worker, memory_limit=memory_limit) as cluster:
-         cluster.scale(n_jobs)
-         with Client(cluster) as client:
+ paths_mg = sorted(glob.glob(input_tiff_path_mg))
+ paths_h2b = sorted(glob.glob(input_tiff_path_h2b))
-             ### Read image collections
-             imgs_mg = [read_image(path) for path in paths_mg]
-             imgs_h2b = [read_image(path) for path in paths_h2b]

-             ### Concatenate collections into a single dask array
-             mg_merged = da.concatenate(imgs_mg, axis=0)  # concatenate along the time dimension
-             h2b_merged = da.concatenate(imgs_h2b, axis=0)  # concatenate along the time dimension
-             imgs_merged = da.concatenate((mg_merged, h2b_merged), axis=1)  # concatenate along the channel dimension

-             ### Process merged images
+ # imgs_mg = [read_image(path) for path in paths_mg]
+ # imgs_h2b = [read_image(path) for path in paths_h2b]
+ #
+ # ### Concatenate collections into a single dask array
+ # mg_merged = da.concatenate(imgs_mg, axis=0) # concatenate along the time dimension
+ # h2b_merged = da.concatenate(imgs_h2b, axis=0) # concatenate along the time dimension
+ # imgs_merged = da.concatenate((mg_merged, h2b_merged), axis=1) # concatenate along the channel dimension

-             ###
-             to_ngff(imgs_merged,
-                     output_path=output_zarr_path,
-                     region_shape=region_shape,
-                     scale=scale,
-                     units=units,
-                     client=client
-                     )
+ # processed_img = da.concatenate([otsu(img, return_thresholded=True) for img in imgs_merged], axis=0)
+
+
+ with LocalCluster(processes=True,
+                   nanny=True,
+                   n_workers=n_jobs,
+                   threads_per_worker=threads_per_worker,
+                   memory_limit=memory_limit) as cluster:
+     cluster.scale(n_jobs)
+     with Client(cluster,
+                 heartbeat_interval="120s",
+                 timeout="600s",
+                 ) as client:
+
+         ### Read image collections
+         imgs_mg = [read_image(path) for path in paths_mg]
+         imgs_h2b = [read_image(path) for path in paths_h2b]
+
+         ### Concatenate collections into a single dask array
+         mg_merged = da.concatenate(imgs_mg, axis=0)  # concatenate along the time dimension
+         h2b_merged = da.concatenate(imgs_h2b, axis=0)  # concatenate along the time dimension
+         imgs_merged = da.concatenate((mg_merged, h2b_merged), axis=1)  # concatenate along the channel dimension
+
+         ### Process merged images
+         processed_img = imgs_merged
+         # processed_img = ndfilters.threshold_local(imgs_merged, block_size=block_size, method='mean')
+         # processed_img = ndfilters.gaussian_filter(imgs_merged, sigma = (0.4, 0.4, 1, 1, 1))
+         # filtered = ndfilters.uniform_filter(imgs_merged, size = block_size)
+         # processed_img = imgs_merged > filtered
+         # processed_mg = da.concatenate([utils.mean_threshold(img, return_thresholded=True) for img in imgs_mg], axis = 0)
+         # processed_h2b = da.concatenate([utils.mean_threshold(img, return_thresholded=True) for img in imgs_h2b], axis = 0)
+         # processed_mg = da.concatenate([utils.otsu(img, bincount = 9, return_thresholded=True) for img in imgs_mg], axis = 0)
+         # processed_h2b = da.concatenate([utils.otsu(img, bincount = 9, return_thresholded=True) for img in imgs_h2b], axis = 0)
+         # processed_img = da.concatenate((processed_mg, processed_h2b), axis = 1) # concatenate along the channel dimension
+
+         process_and_save_to_ngff(processed_img,
+                                  output_path=output_zarr_path,
+                                  region_shape=region_shape,
+                                  scale=scale,
+                                  units=units,
+                                  client=client,
+                                  parallelize_over_regions=False,
+                                  func=utils.otsu,
+                                  )
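(The commented-out ndfilters alternative above — a uniform_filter followed by a comparison — can be packaged as a small lazy helper. This is only a sketch of that local-mean thresholding idea; the function name and default block_size are illustrative and not part of this commit.)

from dask import array as da
from dask_image import ndfilters

def local_mean_threshold(img: da.Array, block_size=(1, 1, 5, 9, 9)) -> da.Array:
    # Local mean via a uniform filter; the comparison yields a lazy boolean
    # mask that is evaluated chunk by chunk when the regions are written out.
    local_mean = ndfilters.uniform_filter(img, size=block_size)
    return img > local_mean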