From d7434cb1d6145b050f34f1282f178875306315cb Mon Sep 17 00:00:00 2001 From: Gregory Lee Date: Fri, 28 Feb 2025 21:24:23 -0500 Subject: [PATCH 01/14] update vendored binary_fill_holes --- .../skimage/_vendored/_ndimage_morphology.py | 27 ++++++++++++------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/python/cucim/src/cucim/skimage/_vendored/_ndimage_morphology.py b/python/cucim/src/cucim/skimage/_vendored/_ndimage_morphology.py index 23d860bd..2a7af8e5 100644 --- a/python/cucim/src/cucim/skimage/_vendored/_ndimage_morphology.py +++ b/python/cucim/src/cucim/skimage/_vendored/_ndimage_morphology.py @@ -808,7 +808,7 @@ def binary_propagation( ) -def _binary_fill_holes_non_iterative(input, output=None): +def _binary_fill_holes_non_iterative(input, structure=None, output=None): """Non-iterative method for hole filling. This algorithm is based on inverting the input and then using `label` to @@ -832,7 +832,9 @@ def _binary_fill_holes_non_iterative(input, output=None): # assign unique labels the background and holes inverse_binary_mask = ~binary_mask - inverse_labels, _ = _measurements.label(inverse_binary_mask) + inverse_labels, _ = _measurements.label( + inverse_binary_mask, structure=structure + ) # After inversion, what was originally the background will now be the # first foreground label encountered. This is ensured due to the @@ -851,6 +853,13 @@ def _binary_fill_holes_non_iterative(input, output=None): if output is None: output = cupy.ascontiguousarray(temp) else: + # handle output argument as in _binary_erosion + if isinstance(output, cupy.ndarray): + if output.dtype.kind == "c": + raise TypeError("Complex output type not supported") + else: + output = bool + output = _util._get_output(output, input) output[:] = temp return output @@ -901,15 +910,13 @@ def binary_fill_holes( filter_all_axes = axes == tuple(range(input.ndim)) if isinstance(origin, int): origin = (origin,) * len(axes) - if structure is None and all(o == 0 for o in origin) and filter_all_axes: + if all(o == 0 for o in origin) and filter_all_axes: return _binary_fill_holes_non_iterative(input, output=output) - else: - if filter_all_axes: - warnings.warn( - "It is recommended to keep the default structure=None and " - "origin=0, so that a faster non-iterative algorithm can be " - "used." - ) + elif filter_all_axes: + warnings.warn( + "It is recommended to keep the default origin=0 so that a faster " + "non-iterative algorithm can be used." + ) mask = cupy.logical_not(input) tmp = cupy.zeros(mask.shape, bool) inplace = isinstance(output, cupy.ndarray) From 18c87ed367eaa2c27b1c9a6881d60e5bef9f2fbc Mon Sep 17 00:00:00 2001 From: Gregory Lee Date: Sat, 1 Mar 2025 01:15:08 -0500 Subject: [PATCH 02/14] add batch implementation for various basic regionprops The functions introduced here are not being added to the public API. They will be used behind the scenes from `regionprops_table` to enable orders of magnitude faster computation of region properties for all labels in an image. The basic approach here is to compute a property for all labels in an image from a single CUDA kernel call. This is in contrast to the approach from the `RegionProperties` class which first splits the full image into small sub-images corresponding to each region and then loops over these small sub-images, computing the requested property for each small region in turn. That approach is not amenable to good acceleration on the GPU as individual regions are typically small. Provides batch implementation that computes the following properties for all properties in a single kernel call: - bbox - label_filled (creates version of label_image with all holes filled) - num_pixels - num_pixels_filled - num_perimeter_pixels (number of pixels at perimeter of each labeled region) - num_boundary_pixels (number of pixels touching the image boundary for each region) The following properties are simple transformations of the properties above and have negligable additional cost to compute: - area - area_bbox - area_filled - equivalent_diameter_area - equivalent_spherical_perimeter (as in ITK) - extent - perimeter_on_border_ratio (as in ITK) - slice The following split the label image into a list of sub-images or subsets of coordinates where each element in the list corresponds to a label. The background of the label image has value 0 and is not represented in the sequences. Sequence entry `i` corresponds to label `i + 1`. In most cases, these will not be needed as properties are now computed for all regions at once from the labels image, but they are provided for completeness and to match the scikit-image API. - coords - coords_scaled - image (label mask subimages) - image_convex (convex label mask subimages) - image_intensity (intensity_image subimages) - image_filled (subimages of label mask but with holes filled) - label (sequence of integer label ids) Test cases are added that compare the results of these batch computations to results from scikit-image `regionprops_table`. --- .../measure/_regionprops_gpu_basic_kernels.py | 951 ++++++++++++++++++ .../skimage/measure/_regionprops_gpu_utils.py | 79 ++ .../tests/test_regionprops_gpu_kernels.py | 303 ++++++ 3 files changed, 1333 insertions(+) create mode 100644 python/cucim/src/cucim/skimage/measure/_regionprops_gpu_basic_kernels.py create mode 100644 python/cucim/src/cucim/skimage/measure/_regionprops_gpu_utils.py create mode 100644 python/cucim/src/cucim/skimage/measure/tests/test_regionprops_gpu_kernels.py diff --git a/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_basic_kernels.py b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_basic_kernels.py new file mode 100644 index 00000000..098e1d8b --- /dev/null +++ b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_basic_kernels.py @@ -0,0 +1,951 @@ +import math + +import cupy as cp +import numpy as np + +import cucim.skimage._vendored.ndimage as ndi +from cucim.skimage.measure import label +from cucim.skimage.morphology import convex_hull_image + +from ._regionprops_gpu_utils import ( + _get_count_dtype, + _get_min_integer_dtype, + _includes, + _unravel_loop_index, + _unravel_loop_index_declarations, +) + +__all__ = [ + "equivalent_diameter_area", + "equivalent_diameter_area_2d", + "equivalent_diameter_area_3d", + "regionprops_area", + "regionprops_area_bbox", + "regionprops_bbox_coords", + "regionprops_coords", + "regionprops_extent", + "regionprops_image", + "regionprops_num_pixels", + # extra functions for cuCIM not currently in scikit-image + "regionprops_label_filled", + "regionprops_num_boundary_pixels", + "equivalent_spherical_perimeter", + "regionprops_num_perimeter_pixels", +] + + +def _get_bbox_code(uint_t, ndim, array_size): + """ + Notes + ----- + Local variables created: + + - bbox_min : shape (array_size, ndim) + local minimum coordinates across the local set of labels encountered + - bbox_max : shape (array_size, ndim) + local maximum coordinates across the local set of labels encountered + + Output variables written to: + + - bbox : shape (max_label, 2 * ndim) + """ + + # declaration uses external variable: + # labels_size : total number of pixels in the label image + source_pre = f""" + // bounding box variables + {uint_t} bbox_min[{ndim * array_size}]; + {uint_t} bbox_max[{ndim * array_size}] = {{0}}; + // initialize minimum coordinate to array size + for (size_t ii = 0; ii < {ndim * array_size}; ii++) {{ + bbox_min[ii] = labels_size; + }}\n""" + + # op uses external coordinate array variables: + # in_coord[0]...in_coord[ndim - 1] : coordinates + # coordinates in the labeled image at the current index + source_operation = f""" + bbox_min[{ndim}*offset] = min(in_coord[0], bbox_min[{ndim}*offset]); + bbox_max[{ndim}*offset] = max(in_coord[0] + 1, bbox_max[{ndim}*offset]);""" # noqa: E501 + for d in range(ndim): + source_operation += f""" + bbox_min[{ndim}*offset + {d}] = min(in_coord[{d}], bbox_min[{ndim}*offset + {d}]); + bbox_max[{ndim}*offset + {d}] = max(in_coord[{d}] + 1, bbox_max[{ndim}*offset + {d}]);""" # noqa: E501 + + # post_operation uses external variables: + # ii : index into num_pixels array + # lab : label value that corresponds to location ii + # bbox : output with shape (max_label, 2 * ndim) + source_post = f""" + // bounding box outputs + atomicMin(&bbox[(lab - 1)*{2 * ndim}], bbox_min[{ndim}*ii]); + atomicMax(&bbox[(lab - 1)*{2 * ndim} + {ndim}], bbox_max[{ndim}*ii]);""" # noqa: E501 + for d in range(1, ndim): + source_post += f""" + atomicMin(&bbox[(lab - 1)*{2*ndim} + {d}], bbox_min[{ndim}*ii + {d}]); + atomicMax(&bbox[(lab - 1)*{2*ndim} + {d + ndim}], bbox_max[{ndim}*ii + {d}]);""" # noqa: E501 + return source_pre, source_operation, source_post + + +def _get_num_pixels_code(pixels_per_thread, array_size): + """ + Notes + ----- + Local variables created: + + - num_pixels : shape (array_size, ) + The number of pixels encountered per label value + + Output variables written to: + + - counts : shape (max_label,) + """ + pixel_count_dtype = "int8_t" if pixels_per_thread < 256 else "int16_t" + + source_pre = f""" + // num_pixels variables + {pixel_count_dtype} num_pixels[{array_size}] = {{0}};\n""" + + source_operation = """ + num_pixels[offset] += 1;\n""" + + # post_operation requires external variables: + # ii : index into num_pixels array + # lab : label value that corresponds to location ii + # counts : output with shape (max_label,) + source_post = """ + atomicAdd(&counts[lab - 1], num_pixels[ii]);;""" + return source_pre, source_operation, source_post + + +def _get_coord_sums_code(coord_sum_ctype, ndim, array_size): + """ + Notes + ----- + Local variables created: + + - coord_sum : shape (array_size, ndim) + local sum of coordinates across the local set of labels encountered + + Output variables written to: + + - coord_sums : shape (max_label, 2 * ndim) + """ + + source_pre = f""" + {coord_sum_ctype} coord_sum[{ndim * array_size}] = {{0}};\n""" + + # op uses external coordinate array variables: + # in_coord[0]...in_coord[ndim - 1] : coordinates + # coordinates in the labeled image at the current index + source_operation = f""" + coord_sum[{ndim}*offset] += in_coord[0];""" + for d in range(1, ndim): + source_operation += f""" + coord_sum[{ndim}*offset + {d}] += in_coord[{d}];""" + # post_operation uses external variables: + # ii : index into num_pixels array + # lab : label value that corresponds to location ii + # coord_sums : output with shape (max_label, ndim) + source_post = f""" + // bounding box outputs + atomicAdd(&coord_sums[(lab - 1) * {ndim}], coord_sum[{ndim}*ii]);""" + for d in range(1, ndim): + source_post += f""" + atomicAdd(&coord_sums[(lab - 1) * {ndim} + {d}], + coord_sum[{ndim}*ii + {d}]);""" + return source_pre, source_operation, source_post + + +@cp.memoize(for_each_device=True) +def get_bbox_coords_kernel( + ndim, + int32_coords=True, + int32_count=True, + compute_bbox=True, + compute_num_pixels=False, + compute_coordinate_sums=False, + pixels_per_thread=8, +): + coord_dtype = cp.dtype(cp.uint32 if int32_coords else cp.uint64) + if compute_num_pixels: + count_dtype = cp.dtype(cp.uint32 if int32_count else cp.uint64) + if compute_coordinate_sums: + coord_sum_dtype = cp.dtype(cp.uint64) + coord_sum_ctype = "uint64_t" + + array_size = pixels_per_thread + + if coord_dtype.itemsize <= 4: + uint_t = "unsigned int" + else: + uint_t = "unsigned long long" + + if not (compute_bbox or compute_num_pixels or compute_coordinate_sums): + raise ValueError("no computation requested") + + if compute_bbox: + bbox_pre, bbox_op, bbox_post = _get_bbox_code( + uint_t=uint_t, ndim=ndim, array_size=array_size + ) + if compute_num_pixels: + count_pre, count_op, count_post = _get_num_pixels_code( + pixels_per_thread=pixels_per_thread, array_size=array_size + ) + if compute_coordinate_sums: + coord_sums_pre, coord_sums_op, coord_sums_post = _get_coord_sums_code( + coord_sum_ctype=coord_sum_ctype, ndim=ndim, array_size=array_size + ) + # store only counts for label > 0 (label = 0 is the background) + source = f""" + uint64_t start_index = {pixels_per_thread}*i; + """ + if compute_bbox: + source += bbox_pre + if compute_num_pixels: + source += count_pre + if compute_coordinate_sums: + source += coord_sums_pre + + inner_op = "" + if compute_bbox or compute_coordinate_sums: + source += _unravel_loop_index_declarations( + "labels", ndim, uint_t=uint_t + ) + + inner_op += _unravel_loop_index( + "labels", + ndim=ndim, + uint_t=uint_t, + raveled_index="ii", + omit_declarations=True, + ) + if compute_bbox: + inner_op += bbox_op + if compute_num_pixels: + inner_op += count_op + if compute_coordinate_sums: + inner_op += coord_sums_op + + source += f""" + X encountered_labels[{array_size}] = {{0}}; + X current_label; + X prev_label = labels[start_index]; + int offset = 0; + encountered_labels[0] = prev_label; + uint64_t ii_max = min(start_index + {pixels_per_thread}, labels_size); + for (uint64_t ii = start_index; ii < ii_max; ii++) {{ + current_label = labels[ii]; + if (current_label == 0) {{ continue; }} + if (current_label != prev_label) {{ + offset += 1; + prev_label = current_label; + encountered_labels[offset] = current_label; + }} + {inner_op} + }}""" + source += """ + for (size_t ii = 0; ii <= offset; ii++) { + X lab = encountered_labels[ii]; + if (lab != 0) {""" + + if compute_bbox: + source += bbox_post + if compute_num_pixels: + source += count_post + if compute_coordinate_sums: + source += coord_sums_post + source += """ + } + }\n""" + + # print(source) + inputs = "raw X labels, raw uint64 labels_size" + outputs = [] + name = "cucim_" + if compute_bbox: + outputs.append(f"raw {coord_dtype.name} bbox") + name += f"_bbox{ndim}d" + if compute_num_pixels: + outputs.append(f"raw {count_dtype.name} counts") + name += f"_numpix_dtype{count_dtype.char}" + if compute_coordinate_sums: + outputs.append(f"raw {coord_sum_dtype.name} coord_sums") + name += f"_csums_dtype{coord_sum_dtype.char}" + outputs = ", ".join(outputs) + name += f"_batch{pixels_per_thread}" + return cp.ElementwiseKernel( + inputs, outputs, source, preamble=_includes, name=name + ) + + +def regionprops_num_perimeter_pixels( + label_image, + max_label=None, + pixels_per_thread=16, + props_dict=None, +): + """Determine the number of pixels along the perimeter of each labeled + region. + + This is a n-dimensional implementation so in 3D it is the number of pixels + on the surface of the region. + + Writes "num_perimeter_pixels" to `props_dict` if it is not None. + + Notes + ----- + If the labeled regions have holes, the hole edges will be included in this + measurement. If this is not desired, use regionprops_label_filled to fill + the holes and then pass the filled labels image to this function. + + For more accurate perimeter measurements, use `regionprops_perimeter` or + `regionprops_perimeter_crofton` instead. + """ + if max_label is None: + max_label = int(label_image.max()) + # remove non-boundary pixels + binary_label_mask = label_image > 0 + footprint = ndi.generate_binary_structure(label_image.ndim, connectivity=1) + binary_label_mask_eroded = ndi.binary_erosion(binary_label_mask, footprint) + labeled_edges = label_image * ~binary_label_mask_eroded + + num_perimeter_pixels = regionprops_num_pixels( + labeled_edges, + max_label=max_label, + filled=False, + pixels_per_thread=pixels_per_thread, + props_dict=None, + ) + if props_dict is not None: + props_dict["num_perimeter_pixels"] = num_perimeter_pixels + return num_perimeter_pixels + + +def regionprops_num_pixels( + label_image, + max_label=None, + filled=False, + pixels_per_thread=16, + props_dict=None, +): + """Determine the number of pixels in each labeled region. + + The `filled` flag should be used to indicate if the "label_image" has + already had holes filled via `regionprops_label_filled`. + + if filled: + - writes "num_pixels_filled" to `props_dict` + else: + - writes "num_pixels" to `props_dict` + """ + + if max_label is None: + max_label = int(label_image.max()) + num_counts = max_label + num_pixels_prop_name = "num_pixels_filled" if filled else "num_pixels" + + count_dtype, int32_count = _get_count_dtype(label_image.size) + + pixels_kernel = get_bbox_coords_kernel( + int32_count=int32_count, + ndim=label_image.ndim, + compute_bbox=False, + compute_num_pixels=True, + compute_coordinate_sums=False, + pixels_per_thread=pixels_per_thread, + ) + counts = cp.zeros(num_counts, dtype=count_dtype) + + # make a copy if the labels array is not C-contiguous + if not label_image.flags.c_contiguous: + label_image = cp.ascontiguousarray(label_image) + + pixels_kernel( + label_image, + label_image.size, + counts, + size=math.ceil(label_image.size / pixels_per_thread), + ) + if props_dict is not None: + props_dict[num_pixels_prop_name] = counts + return counts + + +def regionprops_area( + label_image, + spacing=None, + max_label=None, + dtype=cp.float32, + filled=False, + pixels_per_thread=16, + props_dict=None, +): + """Determine the area of each labeled region. + + if filled: + - will reuse "num_pixels_filled" from `props_dict` if present + - will write "area_filled" to `props_dict` + else: + - will reuse "num_pixels" from `props_dict` if present + - will write "area" to `props_dict` + """ + if props_dict is None: + props_dict = {} + num_pixels_prop_name = "num_pixels_filled" if filled else "num_pixels" + area_prop_name = "area_filled" if filled else "area" + # integer atomicAdd is faster than floating point so better to convert + # after counting + if num_pixels_prop_name in props_dict: + num_pixels = props_dict[num_pixels_prop_name] + else: + num_pixels = regionprops_num_pixels( + label_image, + max_label=max_label, + filled=filled, + pixels_per_thread=pixels_per_thread, + props_dict=props_dict, + ) + + area = num_pixels.astype(dtype) + if spacing is not None: + if isinstance(spacing, cp.ndarray): + pixel_area = cp.product(spacing) + else: + pixel_area = math.prod(spacing) + area *= pixel_area + + if props_dict is not None: + props_dict[area_prop_name] = area + return area + + +@cp.fuse() +def equivalent_diameter_area_2d(area): + """2d specialization of equivalent_diameter_area.""" + return cp.sqrt(4.0 * area / cp.pi) + + +@cp.fuse() +def equivalent_diameter_area_3d(area): + """3d specialization of equivalent_diameter_area.""" + return cp.cbrt(6.0 * area / cp.pi) + + +@cp.fuse() +def equivalent_diameter_area_nd(area, ndim): + """3d specialization of equivalent_diameter_area.""" + return cp.pow(2.0 * ndim * area / cp.pi, 1.0 / ndim) + + +def equivalent_diameter_area(area, ndim): + """The formula is equivalent to ITK's HyperSphereRadiusFromVolume. + + This will be equal to 2 * GetEquivalentSphericalRadius() from ITK. + + Can be used to compute the "equivalent_diameter_area" property from + the "area" property. + """ + if ndim < 2: + raise ValueError("ndim must be at least 2") + if ndim == 2: + return equivalent_diameter_area_2d(area) + elif ndim == 3: + return equivalent_diameter_area_3d(area) + return equivalent_diameter_area_nd(area, float(ndim)) + + +@cp.fuse() +def equivalent_spherical_perimeter(area, ndim, diameter): + """Equivalent of ITK's GetEquivalentSphericalPerimeter + + Can be used to compute the "equivalent_spherical_perimeter" property from + the "area" property. + """ + return ndim * area / (0.5 * diameter) + + +def regionprops_bbox_coords( + label_image, + max_label=None, + return_slices=False, + pixels_per_thread=16, + props_dict=None, +): + """Determine bounding box coordinates (and slices) of each labeled region. + + Writes "bbox" to `props_dict` + + if return_slices is True: + - writes "slice" to `props_dict` + + Parameters + ---------- + label_image : cp.ndarray + Image containing labels where 0 is the background and sequential + values > 0 are the labels. + max_label : int or None + The maximum label value present in label_image. Will be computed if not + provided. + return_slices : bool, optional + If True, convert the bounding box coordinates array to a list of slice + tuples. + + Returns + ------- + bbox_coords : cp.ndarray + Raw bounding box coordinates array. The first axis is indexed by + ``label - 1``. The second axis has the minimum coordinate for dimension + ``d`` at index ``2*d`` and the maximum for coordinate at dimension + ``d`` at index ``2*d + 1``. Unlike for `bbox_slices`, the maximum + coordinate in `bbox_coords` is **inclusive** (the region's bounding box + includes both the min and max coordinate). + bbox_slices : list[tuple[slice]] or None + Will be None if return_slices is False. To get a mask corresponding to + the ith label, use + ``mask = label_image[bbox_slices[label - 1]] == label`` to get the + region corresponding to the ith bounding box. + """ + if max_label is None: + max_label = int(label_image.max()) + + int32_coords = max(label_image.shape) < 2**32 + coord_dtype = cp.dtype(cp.uint32 if int32_coords else cp.uint64) + + bbox_kernel = get_bbox_coords_kernel( + ndim=label_image.ndim, + int32_coords=int32_coords, + pixels_per_thread=pixels_per_thread, + ) + + ndim = label_image.ndim + bbox_coords = cp.zeros((max_label, 2 * ndim), dtype=coord_dtype) + + # Initialize value for atomicMin on first ndim coordinates + # The value for atomicMax columns is already 0 as desired. + bbox_coords[:, :ndim] = cp.iinfo(coord_dtype).max + + # make a copy if the inputs are not already C-contiguous + if not label_image.flags.c_contiguous: + label_image = cp.ascontiguousarray(label_image) + + bbox_kernel( + label_image, + label_image.size, + bbox_coords, + size=math.ceil(label_image.size / pixels_per_thread), + ) + if props_dict is not None: + props_dict["bbox"] = bbox_coords + + if return_slices: + bbox_coords_cpu = cp.asnumpy(bbox_coords) + if ndim == 2: + # explicitly writing out the 2d case here for clarity + bbox_slices = [ + ( + slice(int(box[0]), int(box[2])), + slice(int(box[1]), int(box[3])), + ) + for box in bbox_coords_cpu + ] + else: + # general n-dimensional case + bbox_slices = [ + tuple( + slice(int(box[d]), int(box[d + ndim])) for d in range(ndim) + ) + for box in bbox_coords_cpu + ] + if props_dict is not None: + props_dict["slice"] = bbox_slices + else: + bbox_slices = None + + return bbox_coords, bbox_slices + + +@cp.memoize(for_each_device=True) +def get_area_bbox_kernel( + coord_dtype, area_dtype, ndim, compute_coordinate_sums=False +): + coord_dtype = cp.dtype(coord_dtype) + area_dtype = cp.dtype(area_dtype) + uint_t = ( + "unsigned int" if coord_dtype.itemsize <= 4 else "unsigned long long" + ) + + source = f""" + {uint_t} dim_max_offset; + unsigned long long num_pixels_bbox = 1; + """ + for d in range(ndim): + source += f""" + dim_max_offset = i * {2 * ndim} + {d + ndim}; + num_pixels_bbox *= bbox[dim_max_offset] - bbox[dim_max_offset - {ndim}]; + """ # noqa: E501 + source += """ + area_bbox = num_pixels_bbox * pixel_area; + """ + inputs = f"raw {coord_dtype.name} bbox, float64 pixel_area" + outputs = f"{area_dtype.name} area_bbox" + name = f"cucim_area_bbox_{coord_dtype.name}_{area_dtype.name}_{ndim}d" + return cp.ElementwiseKernel( + inputs, outputs, source, preamble=_includes, name=name + ) + + +def regionprops_area_bbox( + bbox, area_dtype=cp.float32, spacing=None, props_dict=None +): + """Determine the area of the bounding box of each labeled region. + + Takes the "bbox" property as input. + + writes "area_bbox" to props_dict. + """ + num_label = bbox.shape[0] + ndim = bbox.shape[1] // 2 + + if spacing is None: + pixel_area = 1.0 + else: + if isinstance(spacing, cp.ndarray): + pixel_area = cp.product(spacing) + else: + pixel_area = math.prod(spacing) + + # make a copy if the inputs are not already C-contiguous + if not bbox.flags.c_contiguous: + bbox = cp.ascontiguousarray(bbox) + + kernel = get_area_bbox_kernel(bbox.dtype, area_dtype, ndim) + area_bbox = cp.empty((num_label,), dtype=area_dtype) + kernel(bbox, pixel_area, area_bbox) + if props_dict is not None: + props_dict["area_bbox"] = area_bbox + return area_bbox + + +def regionprops_extent(area, area_bbox, props_dict=None): + """Compute extent as the ratio of area / area_bbox for each labeled region. + + Takes the "area" and "area_bbox" properties as input. + + writes "extent" to props_dict. + """ + extent = area / area_bbox + if props_dict is not None: + props_dict["extent"] = extent + return extent + + +def regionprops_image( + label_image, + intensity_image=None, + slices=None, + max_label=None, + compute_image=True, + compute_convex=False, + offset_coordinates=True, + props_dict=None, + on_cpu=False, +): + """Return tuples of images of isolated label and/or intensity images. + + Each image incorporates only the bounding box region for a given label. + + This function can also optionally return convex images. + + Length of the tuple(s) is equal to `max_label`. + + reuses "slice" from `props_dict` if it is present + + if compute_image: + - writes "image" to `props_dict` + + if compute_convex: + - writes "image_convex" to `props_dict` + 4 + if intensity_image is not None: + - writes "image_intensity" to `props_dict` + """ + if max_label is None: + max_label = int(label_image.max()) + if props_dict is None: + props_dict = dict() + + if slices is None: + if "slice" not in props_dict: + regionprops_bbox_coords( + label_image, + max_label=max_label, + return_slices=True, + props_dict=props_dict, + ) + slices = props_dict["slice"] + + # mask so there will only be a single label value in each returned slice + masks = tuple( + label_image[sl] == lab for lab, sl in enumerate(slices, start=1) + ) + + if compute_convex: + convex_results = tuple( + convex_hull_image( + m, + offset_coordinates=offset_coordinates, + omit_empty_coords_check=True, + float64_computation=True, + ) + for m in masks + ) + image_convex = convex_results + + if on_cpu: + image_convex = tuple(cp.asnumpy(m) for m in image_convex) + props_dict["image_convex"] = image_convex + else: + image_convex = None + + if on_cpu: + masks = tuple(cp.asnumpy(m) for m in masks) + if intensity_image is not None: + intensity_image = cp.asnumpy(intensity_image) + + props_dict["image"] = masks + + if intensity_image is not None: + if intensity_image.ndim > label_image.ndim: + if intensity_image.ndim != label_image.ndim + 1: + raise ValueError( + "Unexpected intensity_image.ndim. Should be " + "label_image.ndim or label_image.ndim + 1" + ) + imslices = tuple(sl + (slice(None),) for sl in slices) + intensity_images = tuple( + intensity_image[sl] * mask[..., cp.newaxis] + for img, (sl, mask) in enumerate(zip(imslices, masks), start=1) + ) + + else: + intensity_images = tuple( + intensity_image[sl] * mask + for img, (sl, mask) in enumerate(zip(slices, masks), start=1) + ) + if on_cpu: + intensity_images = (cp.asnumpy(img) for img in intensity_images) + props_dict["image_intensity"] = intensity_images + if not compute_image: + return props_dict["image_intensity"] + else: + intensity_images = None + return masks, intensity_images, image_convex + + +def _get_compressed_labels( + labels, max_label, intensity_image=None, sort_labels=True +): + """Produce raveled list of coordinates and label values, excluding any + background pixels. + + Some region properties can be applied to this data format more efficiently, + than for the original labels image. + + Currently being used to compute "coords" and "coords_scaled" properties. + """ + label_dtype = _get_min_integer_dtype(max_label, signed=False) + if labels.dtype != label_dtype: + labels = labels.astype(dtype=label_dtype) + coords_dtype = _get_min_integer_dtype(max(labels.shape), signed=False) + label_coords = cp.nonzero(labels) + if label_coords[0].dtype != coords_dtype: + label_coords = tuple(c.astype(coords_dtype) for c in label_coords) + labels1d = labels[label_coords] + if sort_labels: + sort_indices = cp.argsort(labels1d) + label_coords = tuple(c[sort_indices] for c in label_coords) + labels1d = labels1d[sort_indices] + if intensity_image: + img1d = intensity_image[label_coords] + return label_coords, labels1d, img1d + # max_label = int(labels1d[-1]) + return label_coords, labels1d + + +def regionprops_coords( + label_image, + max_label=None, + spacing=None, + compute_coords=True, + compute_coords_scaled=False, + props_dict=None, +): + """Return tuple(s) of arrays of coordinates for each labeled region. + + Length of the tuple(s) is equal to `max_label`. + + reuses "num_pixels" from `props_dict` if it is present + + writes "coords" to `props_dict` if compute_coords is True + + writes "coords_scaled" to `props_dict` if compute_coords_scaled is True + + Notes + ----- + This is provided only for completeness, but unlike for the RegionProps + class, these are not needed in order to compute any of the other properties. + """ + if max_label is None: + max_label = int(label_image.max()) + if props_dict is None: + props_dict = dict() + + coords_concat, _ = _get_compressed_labels( + label_image, max_label=max_label, sort_labels=True + ) + + if "num_pixels" not in props_dict: + num_pixels = regionprops_num_pixels( + label_image, max_label=max_label, props_dict=props_dict + ) + else: + num_pixels = props_dict["num_pixels"] + + # stack ndim arrays into a single (pixels, ndim) array + coords_concat = cp.stack(coords_concat, axis=-1) + + # scale based on spacing + if compute_coords_scaled: + max_exact_float32_int = 16777216 # 2 ** 24 + max_sz = max(label_image.shape) + float_type = ( + cp.float32 if max_sz < max_exact_float32_int else cp.float64 + ) + coords_concat_scaled = coords_concat.astype(float_type) + if spacing is not None: + scale_factor = cp.asarray(spacing, dtype=float_type).reshape(1, -1) + coords_concat_scaled *= scale_factor + coords_scaled = [] + + if compute_coords: + coords = [] + + # split separate labels out of the concatenated array above + num_pixels_cpu = cp.asnumpy(num_pixels) + slice_start = 0 + slice_stops = np.cumsum(num_pixels_cpu) + for slice_stop in slice_stops: + sl = slice(slice_start, slice_stop) + if compute_coords: + coords.append(coords_concat[sl, :]) + if compute_coords_scaled: + coords_scaled.append(coords_concat_scaled[sl, :]) + slice_start = slice_stop + + if compute_coords: + coords = tuple(coords) + props_dict["coords"] = coords + if not compute_coords_scaled: + return coords + if compute_coords_scaled: + coords_scaled = tuple(coords_scaled) + props_dict["coords_scaled"] = coords_scaled + if not compute_coords: + return coords_scaled + return coords, coords_scaled + + +def _boundary_mask(labels): + """Generate a binary mask corresponding to the pixels touching the image + boundary. + """ + ndim = labels.ndim + slices = [ + slice( + None, + ) + ] * ndim + boundary_mask = cp.zeros(labels.shape, dtype=bool) + for d in range(ndim): + edge_slices1 = slices[:d] + [slice(0, 1)] + slices[d + 1 :] + edge_slices2 = slices[:d] + [slice(-1, None)] + slices[d + 1 :] + boundary_mask[tuple(edge_slices1)] = 1 + boundary_mask[tuple(edge_slices2)] = 1 + slices[d] = slice(1, -1) + return boundary_mask + + +def regionprops_num_boundary_pixels(labels, max_label=None, props_dict=None): + """Determine the number of pixels touching the image boundary for each + labeled region. + + writes "num_boundary_pixels" to props_dict. + """ + if max_label is None: + max_label = int(labels.max()) + + # get mask of edge pixels + boundary_mask = _boundary_mask(labels) + + # include a bin for the background + nbins = max_label + 1 + # exclude background region from edge_counts + edge_counts = cp.bincount(labels[boundary_mask], minlength=nbins)[1:] + if props_dict is not None: + props_dict["num_boundary_pixels"] = edge_counts + return edge_counts + + +def regionprops_label_filled( + labels, + max_label=None, + props_dict=None, +): + """Fill holes in each labeled region. + + writes "label_filled" to props_dict + + Parameters + ---------- + labels : cupy.ndarray + The label image + max_label : the maximum label present in labels + If None, will be determined internally. + props_dict : dict or None + Dictionary to store any measured properties. + + Returns + ------- + label_filled : cupy.ndarray + The label image, but with the holes in each region filled. + """ + if max_label is None: + max_label = int(labels.max()) + + # make sure all background pixels at the boundary have the same label + labels = cp.pad(labels, 1, mode="constant", constant_values=0) + + binary_labels = labels > 0 + + # assign unique labels the background and holes + inverse_binary_mask = ~binary_labels + inverse_labels = label(inverse_binary_mask) + + # After inversion, what was originally the background will now be the + # first foreground label encountered. This is ensured due to the + # single voxel padding done above and the fact that the `label` + # function scans linearly through the array. + background_index = 1 + # set the background back to 0 in the inverse mask so we have a mask + # of just the holes + inverse_binary_mask[inverse_labels == background_index] = 0 + + # add binary holes to the original mask and relabel + binary_holes_filled = cp.logical_or(binary_labels, inverse_binary_mask) + label_filled = label(binary_holes_filled) + + if props_dict is not None: + props_dict["label_filled"] = label_filled + label_filled = label_filled[(slice(1, -1),) * label_filled.ndim] + return cp.ascontiguousarray(label_filled) diff --git a/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_utils.py b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_utils.py new file mode 100644 index 00000000..5d5f19ac --- /dev/null +++ b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_utils.py @@ -0,0 +1,79 @@ +import cupy as cp +from packaging.version import parse + +CUPY_GTE_13_3_0 = parse(cp.__version__) >= parse("13.3.0") + +# Need some default includes so uint32_t, uint64_t, etc. are defined + +if CUPY_GTE_13_3_0: + _includes = r""" +#include // provide std:: coverage +""" +else: + _includes = r""" +#include // let Jitify handle this +""" + + +def _get_count_dtype(label_image_size): + """atomicAdd only supports int32, uint32, int64, uint64, float32, float64""" + int32_count = label_image_size < 2**32 + count_dtype = cp.dtype(cp.uint32 if int32_count else cp.uint64) + return count_dtype, int32_count + + +def _get_min_integer_dtype(max_size, signed=False): + # negate to get a signed integer type, but need to also subtract 1, due + # to asymmetric range on positive side, e.g. we want + # max_sz = 127 -> int8 (signed) uint8 (unsigned) + # max_sz = 128 -> int16 (signed) uint8 (unsigned) + func = cp.min_scalar_type + return func(-max_size - 1) if signed else func(max_size) + + +def _unravel_loop_index_declarations(var_name, ndim, uint_t="unsigned int"): + if ndim == 1: + code = f""" + {uint_t} in_coord[1];""" + return code + + code = f""" + // variables for unraveling a linear index to a coordinate array + {uint_t} in_coord[{ndim}]; + {uint_t} temp_floor;""" + for d in range(ndim): + code += f""" + {uint_t} dim{d}_size = {var_name}.shape()[{d}];""" + return code + + +def _unravel_loop_index( + var_name, + ndim, + uint_t="unsigned int", + raveled_index="i", + omit_declarations=False, +): + """ + declare a multi-index array in_coord and unravel the 1D index, i into it. + This code assumes that the array is a C-ordered array. + """ + code = ( + "" + if omit_declarations + else _unravel_loop_index_declarations(var_name, ndim, uint_t) + ) + if ndim == 1: + code = f""" + in_coord[0] = {raveled_index};\n""" + return code + + code += f"{uint_t} temp_idx = {raveled_index};" + for d in range(ndim - 1, 0, -1): + code += f""" + temp_floor = temp_idx / dim{d}_size; + in_coord[{d}] = temp_idx - temp_floor * dim{d}_size; + temp_idx = temp_floor;""" + code += """ + in_coord[0] = temp_idx;""" + return code diff --git a/python/cucim/src/cucim/skimage/measure/tests/test_regionprops_gpu_kernels.py b/python/cucim/src/cucim/skimage/measure/tests/test_regionprops_gpu_kernels.py new file mode 100644 index 00000000..74d5457c --- /dev/null +++ b/python/cucim/src/cucim/skimage/measure/tests/test_regionprops_gpu_kernels.py @@ -0,0 +1,303 @@ +import math +import warnings + +import cupy as cp +import pytest +from cupy.testing import ( + assert_allclose, + assert_array_equal, +) +from scipy.ndimage import find_objects as cpu_find_objects +from skimage import measure as measure_cpu + +from cucim.skimage import data, measure +from cucim.skimage._vendored import ndimage as ndi +from cucim.skimage.measure._regionprops_gpu import ( + equivalent_diameter_area, + regionprops_area, + regionprops_area_bbox, + regionprops_bbox_coords, + regionprops_coords, + regionprops_extent, + regionprops_image, + regionprops_num_pixels, +) + + +def get_labels_nd( + shape, + blob_size_fraction=0.05, + volume_fraction=0.25, + rng=5, + insert_holes=False, + dilate_blobs=False, +): + ndim = len(shape) + blobs_kwargs = dict( + blob_size_fraction=blob_size_fraction, + volume_fraction=volume_fraction, + rng=rng, + ) + blobs = data.binary_blobs(max(shape), n_dim=ndim, **blobs_kwargs) + # crop to rectangular + blobs = blobs[tuple(slice(s) for s in shape)] + + if dilate_blobs: + blobs = ndi.binary_dilation(blobs, 3) + + if insert_holes: + blobs2_kwargs = dict( + blob_size_fraction=blob_size_fraction / 5, + volume_fraction=0.1, + rng=rng, + ) + # create smaller blobs and invert them to create a holes mask to apply + # to the original blobs + temp = data.binary_blobs(max(shape), n_dim=ndim, **blobs2_kwargs) + temp = temp[tuple(slice(s) for s in shape)] + mask = cp.logical_and(blobs > 0, temp == 0) + blobs = blobs * mask + + # binary blobs only creates square outputs + labels = measure.label(blobs) + # print(f"# labels generated = {labels.max()}") + return labels + + +def get_intensity_image(shape, dtype=cp.float32, seed=5, num_channels=None): + npixels = math.prod(shape) + rng = cp.random.default_rng(seed) + dtype = cp.dtype(dtype) + if dtype.kind == "f": + img = cp.arange(npixels, dtype=cp.float32) - npixels // 2 + img = img.reshape(shape) + if dtype == cp.float16: + temp = 100 * rng.standard_normal(img.shape, dtype=cp.float32) + img += temp.astype(cp.float16) + else: + img += 100 * rng.standard_normal(img.shape, dtype=dtype) + else: + iinfo = cp.iinfo(dtype) + imax = min(16384, iinfo.max) + imin = max(0, iinfo.min) + img = rng.integers(imin, imax, shape) + + if num_channels and num_channels > 1: + # generate slightly shifted versions for the additional channels + img = cp.stack((img,) * num_channels, axis=-1) + for c in range(1, num_channels): + img[..., c] = cp.roll(img[..., c], shift=c, axis=0) + return img + + +@pytest.mark.parametrize("precompute_max", [False, True]) +@pytest.mark.parametrize("ndim", [2, 3]) +def test_num_pixels(precompute_max, ndim): + shape = (256, 512) if ndim == 2 else (15, 63, 37) + labels = get_labels_nd(shape) + + max_label = int(cp.max(labels)) if precompute_max else None + num_pixels = regionprops_num_pixels(labels, max_label=max_label) + expected = measure_cpu.regionprops_table( + cp.asnumpy(labels), properties=["num_pixels"] + ) + assert_allclose(num_pixels, expected["num_pixels"]) + + +@pytest.mark.parametrize("precompute_max", [False, True]) +@pytest.mark.parametrize("ndim", [2, 3]) +@pytest.mark.parametrize("area_dtype", [cp.float32, cp.float64]) +@pytest.mark.parametrize("spacing", [None, (0.5, 0.35, 0.75)]) +def test_area(precompute_max, ndim, area_dtype, spacing): + shape = (256, 512) if ndim == 2 else (45, 63, 37) + labels = get_labels_nd(shape) + # discard any extra dimensions from spacing + if spacing is not None: + spacing = spacing[:ndim] + + max_label = int(cp.max(labels)) if precompute_max else None + area = regionprops_area( + labels, spacing=spacing, max_label=max_label, dtype=area_dtype + ) + expected = measure_cpu.regionprops_table( + cp.asnumpy(labels), + spacing=spacing, + properties=["area", "equivalent_diameter_area"], + ) + assert_allclose(area, expected["area"]) + + ed = equivalent_diameter_area(area, ndim) + assert_allclose( + ed, expected["equivalent_diameter_area"], rtol=1e-5, atol=1e-5 + ) + + +@pytest.mark.parametrize("ndim", [2, 3]) +@pytest.mark.parametrize("area_dtype", [cp.float32, cp.float64]) +@pytest.mark.parametrize("spacing", [None, (0.5, 0.35, 0.75)]) +def test_extent(ndim, area_dtype, spacing): + shape = (512, 512) if ndim == 2 else (64, 64, 64) + labels = get_labels_nd(shape) + # discard any extra dimensions from spacing + if spacing is not None: + spacing = spacing[:ndim] + + # compute area + max_label = int(cp.max(labels)) + area = regionprops_area( + labels, spacing=spacing, max_label=max_label, dtype=area_dtype + ) + + # compute bounding-box area + bbox, slices = regionprops_bbox_coords( + labels, + max_label=max_label, + return_slices=True, + ) + area_bbox = regionprops_area_bbox( + bbox, area_dtype=cp.float32, spacing=spacing + ) + + # compute extents from these + extent = regionprops_extent(area=area, area_bbox=area_bbox) + + # compare to expected result + expected = measure_cpu.regionprops_table( + cp.asnumpy(labels), spacing=spacing, properties=["extent"] + ) + assert_allclose(extent, expected["extent"], rtol=1e-5, atol=1e-5) + + +@pytest.mark.parametrize("precompute_max", [False, True]) +@pytest.mark.parametrize("dtype", [cp.uint32, cp.int64]) +@pytest.mark.parametrize("return_slices", [False, True]) +@pytest.mark.parametrize("ndim", [2, 3]) +def test_bbox_coords_and_area(precompute_max, ndim, dtype, return_slices): + shape = (1024, 512) if ndim == 2 else (80, 64, 48) + labels = get_labels_nd(shape) + + max_label = int(cp.max(labels)) if precompute_max else None + bbox, slices = regionprops_bbox_coords( + labels, + max_label=max_label, + return_slices=return_slices, + ) + assert bbox.dtype == cp.uint32 + if not return_slices: + slices is None + else: + expected_slices = cpu_find_objects(cp.asnumpy(labels)) + assert slices == expected_slices + + spacing = (0.35, 0.75, 0.5)[:ndim] + expected_bbox = measure_cpu.regionprops_table( + cp.asnumpy(labels), spacing=spacing, properties=["bbox", "area_bbox"] + ) + if ndim == 2: + # TODO make ordering of bbox consistent with regionprops bbox? + assert_allclose(bbox[:, 0], expected_bbox["bbox-0"]) + assert_allclose(bbox[:, 1], expected_bbox["bbox-1"]) + assert_allclose(bbox[:, 2], expected_bbox["bbox-2"]) + assert_allclose(bbox[:, 3], expected_bbox["bbox-3"]) + elif ndim == 3: + assert_allclose(bbox[:, 0], expected_bbox["bbox-0"]) + assert_allclose(bbox[:, 1], expected_bbox["bbox-1"]) + assert_allclose(bbox[:, 2], expected_bbox["bbox-2"]) + assert_allclose(bbox[:, 3], expected_bbox["bbox-3"]) + assert_allclose(bbox[:, 4], expected_bbox["bbox-4"]) + assert_allclose(bbox[:, 5], expected_bbox["bbox-5"]) + + # compute area_bbox from bbox array + area_bbox = regionprops_area_bbox( + bbox, area_dtype=cp.float32, spacing=spacing + ) + assert_allclose(area_bbox, expected_bbox["area_bbox"], rtol=1e-5) + + +@pytest.mark.parametrize("ndim", [2, 3]) +@pytest.mark.parametrize("num_channels", [1, 3]) +@pytest.mark.parametrize( + "blob_kwargs", [{}, dict(blob_size_fraction=0.12, volume_fraction=0.3)] +) +def test_image(ndim, num_channels, blob_kwargs): + shape = (256, 512) if ndim == 2 else (64, 64, 80) + + labels = get_labels_nd(shape, **blob_kwargs) + intensity_image = get_intensity_image( + shape, dtype=cp.uint16, num_channels=num_channels + ) + max_label = int(cp.max(labels)) + images, intensity_images, images_convex = regionprops_image( + labels, + max_label=max_label, + intensity_image=intensity_image, + compute_convex=True, + ) + assert len(images) == max_label + assert len(intensity_images) == max_label + assert len(images_convex) == max_label + + # suppress any QHull warnings coming from the scikit-image implementation + warnings.filterwarnings( + "ignore", + message="Failed to get convex hull image", + category=UserWarning, + ) + warnings.filterwarnings( + "ignore", + message="divide by zero", + category=RuntimeWarning, + ) + expected = measure_cpu.regionprops_table( + cp.asnumpy(labels), + intensity_image=cp.asnumpy(intensity_image), + properties=["image", "image_intensity", "image_convex"], + ) + warnings.resetwarnings() + + for n in range(max_label): + assert_array_equal(images[n], expected["image"][n]) + assert_array_equal(intensity_images[n], expected["image_intensity"][n]) + # Note if 3d blobs are size 1 on one of the axes, it can cause QHull to + # fail and return a zeros convex image for that label. This has been + # resolved for cuCIM, but not yet for scikit-image. + # The test case with blob_kwargs != {} was chosen as a known good + # setting where such an edge case does NOT occur. + if blob_kwargs: + assert_array_equal(images_convex[n], expected["image_convex"][n]) + else: + # Can't compare to scikit-image in this case + # Just make sure the convex size is not smaller than the original + assert (images_convex[n].sum()) >= (images[n].sum()) + + +@pytest.mark.parametrize("ndim", [2, 3]) +@pytest.mark.parametrize("spacing", [None, (1, 1, 1), (1.5, 0.5, 0.76)]) +def test_coords(ndim, spacing): + shape = (256, 512) if ndim == 2 else (35, 63, 37) + if spacing is not None: + spacing = spacing[:ndim] + labels = get_labels_nd(shape) + max_label = int(cp.max(labels)) + coords, coords_scaled = regionprops_coords( + labels, + max_label=max_label, + spacing=spacing, + compute_coords=True, + compute_coords_scaled=True, + ) + assert len(coords) == max_label + assert len(coords_scaled) == max_label + + expected = measure_cpu.regionprops_table( + cp.asnumpy(labels), + spacing=spacing, + properties=["coords", "coords_scaled"], + ) + for n in range(max_label): + # cast to Python int to match dtype from CPU case + assert_array_equal(coords[n].astype(int), expected["coords"][n]) + + assert_allclose( + coords_scaled[n], expected["coords_scaled"][n], rtol=1e-5 + ) From d7e724cfc86399e1e1b3922b31d1f5c1a7fc2e25 Mon Sep 17 00:00:00 2001 From: Gregory Lee Date: Sat, 1 Mar 2025 09:55:39 -0500 Subject: [PATCH 03/14] add regionprops_dict interface for computing multiple properties This function operates similarly to `regionprops_table`. In a future commit, once all properties have been supported, it will be used within the existing regionprops_table function so that it will provide much higher performance. --- .../cucim/skimage/measure/_regionprops_gpu.py | 382 ++++++++++++++++++ .../measure/_regionprops_gpu_basic_kernels.py | 35 ++ .../tests/test_regionprops_gpu_kernels.py | 108 +++++ 3 files changed, 525 insertions(+) create mode 100644 python/cucim/src/cucim/skimage/measure/_regionprops_gpu.py diff --git a/python/cucim/src/cucim/skimage/measure/_regionprops_gpu.py b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu.py new file mode 100644 index 00000000..6c8b2866 --- /dev/null +++ b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu.py @@ -0,0 +1,382 @@ +import warnings +from copy import copy + +import cupy as cp + +from cucim.skimage.measure._regionprops import ( + COL_DTYPES, + PROPS, +) + +from ._regionprops_gpu_basic_kernels import ( + basic_deps, + equivalent_diameter_area, + equivalent_spherical_perimeter, + regionprops_area, + regionprops_area_bbox, + regionprops_bbox_coords, + regionprops_coords, + regionprops_extent, + regionprops_image, + regionprops_label_filled, + regionprops_num_boundary_pixels, + regionprops_num_perimeter_pixels, + regionprops_num_pixels, +) +from ._regionprops_gpu_utils import _get_min_integer_dtype + +__all__ = [ + "equivalent_diameter_area", + "regionprops_area", + "regionprops_area_bbox", + "regionprops_bbox_coords", + "regionprops_coords", + "regionprops_dict", + "regionprops_extent", + "regionprops_image", + # extra functions for cuCIM not currently in scikit-image + "equivalent_spherical_perimeter", # as in ITK + "regionprops_num_boundary_pixels", + "regionprops_num_perimeter_pixels", + "regionprops_label_filled", +] + + +# Master list of properties currently supported by regionprops_dict for faster +# computation on the GPU. +# +# One caveat is that centroid/moment/inertia_tensor properties currently only +# support 2D and 3D data with moments up to 3rd order. + +# all properties from PROPS have been implemented +PROPS_GPU = copy(PROPS) +# extra properties not currently in scikit-image +PROPS_GPU_EXTRA = { + "num_pixels_filled": "num_pixels_filled", + # a few extra parameters as in ITK + "num_perimeter_pixels": "num_perimeter_pixels", + "num_boundary_pixels": "num_boundary_pixels", + "perimeter_on_border_ratio": "perimeter_on_border_ratio", + "equivalent_spherical_perimeter": "equivalent_spherical_perimeter", +} +PROPS_GPU.update(PROPS_GPU_EXTRA) + +CURRENT_PROPS_GPU = set(PROPS_GPU.values()) + +COL_DTYPES_EXTRA = { + "num_pixels_filled": int, + "num_perimeter_pixels": int, + "num_boundary_pixels": int, + "perimeter_on_border_ratio": float, + "equivalent_spherical_perimeter": float, +} + +# expand column dtypes from _regionprops to include the extra properties +COL_DTYPES_GPU = copy(COL_DTYPES) +COL_DTYPES_GPU.update(COL_DTYPES_EXTRA) + +# Any extra 'property' that is computed on the full labels image and not +# per-region. +GLOBAL_PROPS = {"label_filled"} + +# list of the columns that are stored as a numpy object array when converted +# to tabular format by `regionprops_table` +OBJECT_COLUMNS_GPU = [ + col for col, dtype in COL_DTYPES_GPU.items() if dtype == object +] + + +# `property_deps` is a dictionary where each key is a property and values are +# the other properties that property directly depends on (indirect dependencies +# do not need to be listed as that is handled by traversing a tree structure via +# get_property_dependencies below). +property_deps = dict() +property_deps.update(basic_deps) + +# set of properties that require an intensity_image also be provided +need_intensity_image = {"image_intensity"} + +# set of properties that only supports 2D images +ndim_2_only = set() + + +def get_property_dependencies(dependencies, node): + """Get all direct and indirect dependencies for a specific property""" + visited = set() + result = [] + + def depth_first_search(n): + if n not in visited: + visited.add(n) + if n in dependencies: + for dep in dependencies[n]: + depth_first_search(dep) + # If n is not in dependencies, assume it has no dependencies + result.append(n) + + depth_first_search(node) + return set(result) + + +# precompute full set of direct and indirect dependencies for each property +property_requirements = { + k: get_property_dependencies(property_deps, k) + for k in (CURRENT_PROPS_GPU | GLOBAL_PROPS) +} + + +def regionprops_dict( + label_image, + intensity_image=None, + properties=[], + *, + spacing=None, + max_label=None, + pixels_per_thread=16, +): + """Compute image properties and return them as a pandas-compatible table. + + The table is a dictionary mapping column names to value arrays. See Notes + section below for details. + + .. versionadded:: 0.16 + + Parameters + ---------- + label_image : (M, N[, P]) ndarray + Labeled input image. Labels with value 0 are ignored. + intensity_image : (M, N[, P][, C]) ndarray, optional + Intensity (i.e., input) image with same size as labeled image, plus + optionally an extra dimension for multichannel data. The channel + dimension, if present, must be the last axis. Default is None. + properties : tuple or list of str, optional + Properties that will be included in the resulting dictionary + For a list of available properties, please see :func:`regionprops`. + Users should remember to add "label" to keep track of region + identities. + spacing : tuple of float, shape (ndim,) + The pixel spacing along each axis of the image. + + Extra Parameters + ---------------- + max_label : int or None + The maximum label value. If not provided it will be computed from + `label_image`. + pixels_per_thread : int + A number of properties support computation of multiple adjacent pixels + from each GPU thread. The number of adjacent pixels processed + corresponds to `pixels_per_thread` and can be used as a performance + tuning parameter. + """ + supported_properties = CURRENT_PROPS_GPU | GLOBAL_PROPS + properties = set(properties) + + valid_names = properties & supported_properties + invalid_names = set(properties) - valid_names + valid_names = list(valid_names) + + # Use only the modern names internally, but keep list of mappings back to + # any deprecated names in restore_legacy_names and use that at the end to + # restore the requested deprecated property names. + restore_legacy_names = dict() + for name in invalid_names: + if name in PROPS: + vname = PROPS[name] + if vname in valid_names: + raise ValueError( + f"Property name: {name} is a duplicate of {vname}" + ) + else: + restore_legacy_names[vname] = name + valid_names.append(vname) + else: + raise ValueError(f"Unrecognized property name: {name}") + for v in restore_legacy_names.values(): + invalid_names.discard(v) + # warn if there are any names that did not match a deprecated name + if invalid_names: + warnings.warn( + "The following property names were unrecognized and will not be " + "computed: {invalid_names}" + ) + + requested_props = set(sorted(valid_names)) + + if len(requested_props) == 0: + return {} + + required_props = set() + for prop in requested_props: + required_props.update(property_requirements[prop]) + + ndim = label_image.ndim + if ndim != 2: + invalid_names = requested_props & ndim_2_only + if any(invalid_names): + raise ValueError( + f"{label_image.ndim=}, but the following properties are for " + "2D label images only: {invalid_names}" + ) + if intensity_image is None: + invalid_names = requested_props & need_intensity_image + if any(invalid_names): + raise ValueError( + "No intensity_image provided, but the following requested " + "properties require one: {invalid_names}" + ) + + out = {} + if max_label is None: + max_label = int(label_image.max()) + label_dtype = _get_min_integer_dtype(max_label, signed=False) + # For performance, shrink label's data type to the minimum possible + # unsigned integer type. + if label_image.dtype != label_dtype: + label_image = label_image.astype(label_dtype) + + # create vector of label values + if "label" in required_props: + out["label"] = cp.arange(1, max_label + 1, dtype=label_dtype) + + perf_kwargs = {} + if pixels_per_thread is not None: + perf_kwargs["pixels_per_thread"] = pixels_per_thread + + if "num_pixels" in required_props: + regionprops_num_pixels( + label_image, + max_label=max_label, + filled=False, + **perf_kwargs, + props_dict=out, + ) + + if "area" in required_props: + regionprops_area( + label_image, + spacing=spacing, + max_label=max_label, + dtype=cp.float32, + filled=False, + **perf_kwargs, + props_dict=out, + ) + + if "equivalent_diameter_area" in required_props: + ed = equivalent_diameter_area(out["area"], ndim) + out["equivalent_diameter_area"] = ed + if "equivalent_spherical_perimeter" in required_props: + out[ + "equivalent_spherical_perimeter" + ] = equivalent_spherical_perimeter(out["area"], ndim, ed) + + compute_bbox = "bbox" in required_props + if compute_bbox: + # compute bbox (and slice) + regionprops_bbox_coords( + label_image, + max_label=max_label, + return_slices="slice" in required_props, + **perf_kwargs, + props_dict=out, + ) + + if "area_bbox" in required_props: + regionprops_area_bbox( + out["bbox"], + area_dtype=cp.float32, + spacing=None, + props_dict=out, + ) + + if "extent" in required_props: + out["extent"] = out["area"] / out["area_bbox"] + + if "num_boundary_pixels" in required_props: + regionprops_num_boundary_pixels( + label_image, + max_label=max_label, + props_dict=out, + ) + + if "num_perimeter_pixels" in required_props: + regionprops_num_perimeter_pixels( + label_image, + max_label=max_label, + props_dict=out, + ) + + if "perimeter_on_border_ratio" in required_props: + out["perimeter_on_border_ratio"] = ( + out["num_boundary_pixels"] / out["num_perimeter_pixels"] + ) + + compute_images = "image" in required_props + compute_intensity_images = "image_intensity" in required_props + compute_convex = "image_convex" in required_props + if compute_intensity_images or compute_images or compute_convex: + regionprops_image( + label_image, + intensity_image=intensity_image + if compute_intensity_images + else None, # noqa: E501 + max_label=max_label, + props_dict=out, + compute_image=compute_images, + compute_convex=compute_convex, + offset_coordinates=True, + ) + + compute_coords = "coords" in required_props + compute_coords_scaled = "coords_scaled" in required_props + if compute_coords or compute_coords_scaled: + regionprops_coords( + label_image, + max_label=max_label, + spacing=spacing, + compute_coords=compute_coords, + compute_coords_scaled=compute_coords_scaled, + props_dict=out, + ) + + if "label_filled" in required_props: + regionprops_label_filled( + label_image, + max_label=max_label, + props_dict=out, + ) + if "area_filled" in required_props: + # also handles "num_pixels_filled" + out["area_filled"] = regionprops_area( + out["label_filled"], + max_label=max_label, + filled=True, + props_dict=out, + ) + elif "num_pixels_filled" in required_props: + regionprops_num_pixels( + label_image, + max_label=max_label, + filled=True, + **perf_kwargs, + props_dict=out, + ) + if "image_filled" in required_props: + out["image_filled"], _, _ = regionprops_image( + out["label_filled"], + max_label=max_label, + compute_image=True, + compute_convex=False, + props_dict=None, # omit: using custom "image_filled" key + ) + + # If user had requested properties via their deprecated names, set the + # canonical names for the computed properties to the corresponding + # deprecated one. + for k, v in restore_legacy_names.items(): + out[v] = out.pop(k) + + # only return the properties that were explicitly requested + out = {k: out[k] for k in properties} + + return out diff --git a/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_basic_kernels.py b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_basic_kernels.py index 098e1d8b..b3ca745d 100644 --- a/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_basic_kernels.py +++ b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_basic_kernels.py @@ -34,6 +34,37 @@ ] +# Store information on which other properties a given property depends on +# This information will be used by `regionprops_dict` to make sure that when +# a particular property is requested any dependent properties are computed +# first. +basic_deps = dict() +basic_deps["area"] = ["num_pixels"] +basic_deps["area_bbox"] = ["bbox"] +basic_deps["area_filled"] = ["label_filled", "num_pixels_filled"] +basic_deps["bbox"] = [] +basic_deps["coords"] = ["num_pixels"] +basic_deps["coords_scaled"] = ["num_pixels"] +basic_deps["equivalent_diameter_area"] = ["area"] +basic_deps["equivalent_spherical_perimeter"] = ["equivalent_diameter_area"] +basic_deps["extent"] = ["area", "area_bbox"] +basic_deps["image"] = [] +basic_deps["image_convex"] = [] +basic_deps["image_filled"] = ["label_filled"] +basic_deps["image_intensity"] = [] +basic_deps["label"] = [] +basic_deps["label_filled"] = [] +basic_deps["num_boundary_pixels"] = [] +basic_deps["num_perimeter_pixels"] = [] +basic_deps["num_pixels"] = [] +basic_deps["num_pixels_filled"] = ["label_filled"] +basic_deps["perimeter_on_border_ratio"] = [ + "num_perimeter_pixels", + "num_boundary_pixels", +] +basic_deps["slice"] = ["bbox"] + + def _get_bbox_code(uint_t, ndim, array_size): """ Notes @@ -386,9 +417,11 @@ def regionprops_area( if filled: - will reuse "num_pixels_filled" from `props_dict` if present - will write "area_filled" to `props_dict` + - will write "num_pixels_filled" to `props_dict` if not already present else: - will reuse "num_pixels" from `props_dict` if present - will write "area" to `props_dict` + - will write "num_pixels" to `props_dict` if not already present """ if props_dict is None: props_dict = {} @@ -406,6 +439,8 @@ def regionprops_area( pixels_per_thread=pixels_per_thread, props_dict=props_dict, ) + if num_pixels_prop_name not in props_dict: + props_dict[num_pixels_prop_name] = num_pixels area = num_pixels.astype(dtype) if spacing is not None: diff --git a/python/cucim/src/cucim/skimage/measure/tests/test_regionprops_gpu_kernels.py b/python/cucim/src/cucim/skimage/measure/tests/test_regionprops_gpu_kernels.py index 74d5457c..7243d766 100644 --- a/python/cucim/src/cucim/skimage/measure/tests/test_regionprops_gpu_kernels.py +++ b/python/cucim/src/cucim/skimage/measure/tests/test_regionprops_gpu_kernels.py @@ -1,3 +1,4 @@ +import functools import math import warnings @@ -12,16 +13,19 @@ from cucim.skimage import data, measure from cucim.skimage._vendored import ndimage as ndi +from cucim.skimage.measure._regionprops import PROPS from cucim.skimage.measure._regionprops_gpu import ( equivalent_diameter_area, regionprops_area, regionprops_area_bbox, regionprops_bbox_coords, regionprops_coords, + regionprops_dict, regionprops_extent, regionprops_image, regionprops_num_pixels, ) +from cucim.skimage.measure._regionprops_gpu_basic_kernels import basic_deps def get_labels_nd( @@ -301,3 +305,107 @@ def test_coords(ndim, spacing): assert_allclose( coords_scaled[n], expected["coords_scaled"][n], rtol=1e-5 ) + + +@pytest.mark.parametrize("ndim", [2, 3]) +@pytest.mark.parametrize("spacing", [None, (1.5, 0.5, 0.76)]) +@pytest.mark.parametrize("property_name", list(basic_deps.keys())) +def test_regionprops_dict_single_property(ndim, spacing, property_name): + """Test to verify that any dependencies for a given property are + automatically handled. + """ + shape = (768, 512) if ndim == 2 else (64, 64, 64) + if spacing is not None: + spacing = spacing[:ndim] + labels = get_labels_nd(shape) + if property_name == "image_intensity": + intensity_image = get_intensity_image( + shape, dtype=cp.uint16, num_channels=1 + ) + else: + intensity_image = None + props = regionprops_dict( + labels, intensity_image, properties=[property_name], spacing=spacing + ) + assert property_name in props + # any unrequested dependent properties are not retained in the output dict + assert len(props) == 1 + + +@pytest.mark.parametrize("ndim", [2, 3]) +@pytest.mark.parametrize( + "property_name", + [ + "label", + "image", + "image_convex", + "image_intensity", + "image_filled", + "coords", + "coords_scaled", + ], +) +def test_regionprops_image_and_coords_sequence(ndim, property_name): + shape = (768, 512) if ndim == 2 else (64, 64, 64) + spacing = (1.5, 0.5, 0.76) + if spacing is not None: + spacing = spacing[:ndim] + labels = get_labels_nd(shape) + max_label = int(labels.max()) + if property_name == "image_intensity": + intensity_image = get_intensity_image( + shape, dtype=cp.uint16, num_channels=1 + ) + else: + intensity_image = None + props = regionprops_dict( + labels, + intensity_image, + properties=[property_name], + spacing=spacing, + max_label=max_label, + ) + assert property_name in props + result = props[property_name] + assert len(result) == max_label + + # compute expected result on CPU + labels_cpu = cp.asnumpy(labels) + if intensity_image is not None: + intensity_image_cpu = cp.asnumpy(intensity_image) + else: + intensity_image_cpu = None + expected = measure_cpu.regionprops_table( + labels_cpu, + intensity_image_cpu, + properties=[property_name], + spacing=spacing, + )[property_name] + assert len(expected) == max_label + + # verify + if property_name == "label": + assert_array_equal(expected, result) + else: + if property_name == "coords_scaled": + comparison_func = functools.partial( + assert_allclose, atol=1e-6, rtol=1e-6 + ) + else: + comparison_func = assert_array_equal + for i, (expected_val, val) in enumerate(zip(expected, result)): + comparison_func(expected_val, val) + return + + +@pytest.mark.parametrize( + "property_name", ["Area", "BoundingBoxArea", "Image", "Slice"] +) +def test_regionprops_dict_deprecated_property_names(property_name): + shape = (1024, 1024) + labels = get_labels_nd(shape) + props = regionprops_dict(labels, properties=[property_name]) + # deprecated name is used in the returned results dict + assert property_name in props + # non-deprecated version of the name is not also present + assert PROPS[property_name] not in props From 783a4343aefb3b6a1706421d5d86cf5435823252 Mon Sep 17 00:00:00 2001 From: Gregory Lee Date: Sat, 1 Mar 2025 11:06:01 -0500 Subject: [PATCH 04/14] fix shape bug in convex_hull when 2d image has a singletone size on one axis --- python/cucim/src/cucim/skimage/morphology/convex_hull.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/cucim/src/cucim/skimage/morphology/convex_hull.py b/python/cucim/src/cucim/skimage/morphology/convex_hull.py index 3896b37a..262af83d 100644 --- a/python/cucim/src/cucim/skimage/morphology/convex_hull.py +++ b/python/cucim/src/cucim/skimage/morphology/convex_hull.py @@ -207,7 +207,7 @@ def convex_hull_image( original_shape = image.shape image = cp.squeeze(image) if image.ndim < 2: - return image + return image.reshape(original_shape) if image.size < cpu_fallback_threshold: # Fallback to pure CPU implementation From b64de6e2dadfb0bfac5c1ccd1d00cfc392ef1588 Mon Sep 17 00:00:00 2001 From: Gregory Lee Date: Sat, 1 Mar 2025 11:31:33 -0500 Subject: [PATCH 05/14] add batch kernels for intensity-based regionprops - intensity_mean - intensity_std - intensity_min - intensity_max Both single and multi-channel intensity images are supported --- python/cucim/pyproject.toml | 2 +- .../cucim/skimage/measure/_regionprops_gpu.py | 55 +- .../_regionprops_gpu_intensity_kernels.py | 657 ++++++++++++++++++ .../skimage/measure/_regionprops_gpu_utils.py | 19 + .../tests/test_regionprops_gpu_kernels.py | 149 +++- 5 files changed, 875 insertions(+), 7 deletions(-) create mode 100644 python/cucim/src/cucim/skimage/measure/_regionprops_gpu_intensity_kernels.py diff --git a/python/cucim/pyproject.toml b/python/cucim/pyproject.toml index 1f4ee694..1473cbae 100644 --- a/python/cucim/pyproject.toml +++ b/python/cucim/pyproject.toml @@ -225,7 +225,7 @@ exclude = ''' # codespell --toml python/cucim/pyproject.toml . -i 3 -w skip = "build*,dist,.cache,html,_build,_deps,3rdparty/*,_static,generated,latex,.git,*.ipynb,test_data/input/LICENSE-3rdparty,jitify_testing" # ignore-regex = "" -ignore-words-list = "ans,coo,boun,bui,gool,hart,lond,manuel,nd,paeth,unser,wronly" +ignore-words-list = "ans,coo,boun,bu,bui,gool,hart,lond,manuel,nd,paeth,unser,wronly" quiet-level = 3 # to undo: ./test_data/input/LICENSE-3rdparty diff --git a/python/cucim/src/cucim/skimage/measure/_regionprops_gpu.py b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu.py index 6c8b2866..349f27a2 100644 --- a/python/cucim/src/cucim/skimage/measure/_regionprops_gpu.py +++ b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu.py @@ -23,6 +23,12 @@ regionprops_num_perimeter_pixels, regionprops_num_pixels, ) +from ._regionprops_gpu_intensity_kernels import ( + intensity_deps, + regionprops_intensity_mean, + regionprops_intensity_min_max, + regionprops_intensity_std, +) from ._regionprops_gpu_utils import _get_min_integer_dtype __all__ = [ @@ -34,6 +40,9 @@ "regionprops_dict", "regionprops_extent", "regionprops_image", + "regionprops_intensity_mean", + "regionprops_intensity_min_max", + "regionprops_intensity_std", # extra functions for cuCIM not currently in scikit-image "equivalent_spherical_perimeter", # as in ITK "regionprops_num_boundary_pixels", @@ -92,9 +101,7 @@ # get_property_dependencies below). property_deps = dict() property_deps.update(basic_deps) - -# set of properties that require an intensity_image also be provided -need_intensity_image = {"image_intensity"} +property_deps.update(intensity_deps) # set of properties that only supports 2D images ndim_2_only = set() @@ -124,6 +131,9 @@ def depth_first_search(n): for k in (CURRENT_PROPS_GPU | GLOBAL_PROPS) } +# set of properties that require an intensity_image also be provided +need_intensity_image = set(intensity_deps.keys()) | {"image_intensity"} + def regionprops_dict( label_image, @@ -218,12 +228,15 @@ def regionprops_dict( "2D label images only: {invalid_names}" ) if intensity_image is None: + has_intensity = False invalid_names = requested_props & need_intensity_image if any(invalid_names): raise ValueError( "No intensity_image provided, but the following requested " "properties require one: {invalid_names}" ) + else: + has_intensity = True out = {} if max_label is None: @@ -270,6 +283,42 @@ def regionprops_dict( "equivalent_spherical_perimeter" ] = equivalent_spherical_perimeter(out["area"], ndim, ed) + if has_intensity: + if "intensity_std" in required_props: + # std also computes mean + regionprops_intensity_std( + label_image, + intensity_image, + max_label=max_label, + std_dtype=cp.float64, + sample_std=False, + **perf_kwargs, + props_dict=out, + ) + + elif "intensity_mean" in required_props: + regionprops_intensity_mean( + label_image, + intensity_image, + max_label=max_label, + mean_dtype=cp.float32, + **perf_kwargs, + props_dict=out, + ) + + compute_min = "intensity_min" in required_props + compute_max = "intensity_max" in required_props + if compute_min or compute_max: + regionprops_intensity_min_max( + label_image, + intensity_image, + max_label=max_label, + compute_min=compute_min, + compute_max=compute_max, + **perf_kwargs, + props_dict=out, + ) + compute_bbox = "bbox" in required_props if compute_bbox: # compute bbox (and slice) diff --git a/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_intensity_kernels.py b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_intensity_kernels.py new file mode 100644 index 00000000..dcf1a61c --- /dev/null +++ b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_intensity_kernels.py @@ -0,0 +1,657 @@ +import math + +import cupy as cp + +from ._regionprops_gpu_utils import ( + _check_intensity_image_shape, + _get_count_dtype, + _includes, +) + +__all__ = [ + "regionprops_intensity_mean", + "regionprops_intensity_min_max", + "regionprops_intensity_std", +] + +# Store information on which other properties a given property depends on +# This information will be used by `regionprops_dict` to make sure that when +# a particular property is requested any dependent properties are computed +# first. +intensity_deps = dict() +intensity_deps["intensity_min"] = [] +intensity_deps["intensity_max"] = [] +intensity_deps["intensity_mean"] = ["num_pixels"] +intensity_deps["intensity_std"] = ["num_pixels"] + + +def _get_img_sums_code( + c_sum_type, + pixels_per_thread, + array_size, + num_channels=1, + compute_num_pixels=True, + compute_sum=True, + compute_sum_sq=False, +): + """ + Notes + ----- + Local variables created: + + - num_pixels : shape (array_size, ) + The number of pixels encountered per label value + + Output variables written to: + + - counts : shape (max_label,) + """ + pixel_count_dtype = "int8_t" if pixels_per_thread < 256 else "int16_t" + + source_pre = "" + if compute_num_pixels: + source_pre += f""" + {pixel_count_dtype} num_pixels[{array_size}] = {{0}};""" + if compute_sum: + source_pre += f""" + {c_sum_type} img_sums[{array_size * num_channels}] = {{0}};""" + if compute_sum_sq: + source_pre += f""" + {c_sum_type} img_sum_sqs[{array_size * num_channels}] = {{0}};""" + if compute_sum or compute_sum_sq: + source_pre += f""" + {c_sum_type} v = 0;\n""" + + # source_operation requires external variables: + # ii : index into labels array + # offset : index into local region's num_pixels array + # (number of unique labels encountered so far by this thread) + source_operation = "" + if compute_num_pixels: + source_operation += """ + num_pixels[offset] += 1;""" + nc = f"{num_channels}*" if num_channels > 1 else "" + if compute_sum or compute_sum_sq: + for c in range(num_channels): + source_operation += f""" + v = static_cast<{c_sum_type}>(img[{nc}ii + {c}]);""" + if compute_sum: + source_operation += f""" + img_sums[{nc}offset + {c}] += v;""" + if compute_sum_sq: + source_operation += f""" + img_sum_sqs[{nc}offset + {c}] += v * v;\n""" + + # post_operation requires external variables: + # jj : index into num_pixels array + # lab : label value that corresponds to location ii + # num_pixels : output with shape (max_label,) + # sums : output with shape (max_label, num_channels) + # sumsqs : output with shape (max_label, num_channels) + source_post = "" + if compute_num_pixels: + source_post += """ + atomicAdd(&counts[lab - 1], num_pixels[jj]);""" + if compute_sum: + for c in range(num_channels): + source_post += f""" + atomicAdd(&sums[{nc}(lab - 1) + {c}], img_sums[{nc}jj + {c}]);""" + if compute_sum_sq: + for c in range(num_channels): + source_post += f""" + atomicAdd(&sumsqs[{nc}(lab - 1) + {c}], img_sum_sqs[{nc}jj + {c}]);""" # noqa: E501 + return source_pre, source_operation, source_post + + +def _get_intensity_min_max_code( + min_max_dtype, + c_min_max_type, + array_size, + initial_min_val, + initial_max_val, + compute_min=True, + compute_max=True, + num_channels=1, +): + min_max_dtype = cp.dtype(min_max_dtype) + c_type = c_min_max_type + + # Note: CuPy provides atomicMin and atomicMax for float and double in + # cupy/_core/include/atomics.cuh + # The integer variants are part of CUDA itself. + + source_pre = "" + if compute_min: + source_pre += f""" + {c_type} min_vals[{array_size * num_channels}]; + // initialize minimum coordinate to array size + for (size_t ii = 0; ii < {array_size * num_channels}; ii++) {{ + min_vals[ii] = {initial_min_val}; + }}""" + if compute_max: + source_pre += f""" + {c_type} max_vals[{array_size * num_channels}]; + // initialize minimum coordinate to array size + for (size_t ii = 0; ii < {array_size * num_channels}; ii++) {{ + max_vals[ii] = {initial_max_val}; + }}""" + source_pre += f""" + {c_type} v = 0;\n""" + + # source_operation requires external variables: + # ii : index into labels array + # offset : index into local region's num_pixels array + # (number of unique labels encountered so far by this thread) + source_operation = "" + nc = f"{num_channels}*" if num_channels > 1 else "" + if compute_min or compute_max: + for c in range(num_channels): + source_operation += f""" + v = static_cast<{c_type}>(img[{nc}ii + {c}]);""" + if compute_min: + source_operation += f""" + min_vals[{nc}offset + {c}] = min(v, min_vals[{nc}offset + {c}]);""" + if compute_max: + source_operation += f""" + max_vals[{nc}offset + {c}] = max(v, max_vals[{nc}offset + {c}]);\n""" # noqa: E501 + + # post_operation requires external variables: + # jj : offset index into min_vals or max_vals array + # lab : label value that corresponds to location ii + # min_vals : output with shape (max_label, num_channels) + # max_vals : output with shape (max_label, num_channels) + source_post = "" + if compute_min: + for c in range(num_channels): + source_post += f""" + atomicMin(&minimums[{nc}(lab - 1) + {c}], min_vals[{nc}jj + {c}]);""" # noqa: E501 + if compute_max: + for c in range(num_channels): + source_post += f""" + atomicMax(&maximums[{nc}(lab - 1) + {c}], max_vals[{nc}jj + {c}]);""" # noqa: E501 + return source_pre, source_operation, source_post + + +@cp.memoize() +def _get_intensity_img_kernel_dtypes(image_dtype): + """Determine CuPy dtype and C++ type for image sum operations.""" + image_dtype = cp.dtype(image_dtype) + if image_dtype.kind == "f": + # use double for accuracy of mean/std computations + c_sum_type = "double" + dtype = cp.float64 + # atomicMin, atomicMax support 32 and 64-bit float + if image_dtype.itemsize > 4: + min_max_dtype = cp.float64 + c_min_max_type = "double" + else: + min_max_dtype = cp.float32 + c_min_max_type = "float" + elif image_dtype.kind in "bu": + c_sum_type = "uint64_t" + dtype = cp.uint64 + if image_dtype.itemsize > 4: + min_max_dtype = cp.uint64 + c_min_max_type = "uint64_t" + else: + min_max_dtype = cp.uint32 + c_min_max_type = "uint32_t" + elif image_dtype.kind in "i": + c_sum_type = "int64_t" + dtype = cp.int64 + if image_dtype.itemsize > 4: + min_max_dtype = cp.int64 + c_min_max_type = "int64_t" + else: + min_max_dtype = cp.int32 + c_min_max_type = "int32_t" + else: + raise ValueError( + f"Invalid intensity image dtype {image_dtype.name}. " + "Must be an unsigned, integer or floating point type." + ) + return cp.dtype(dtype), c_sum_type, cp.dtype(min_max_dtype), c_min_max_type + + +@cp.memoize() +def _get_intensity_range(image_dtype): + """Determine CuPy dtype and C++ type for image sum operations.""" + image_dtype = cp.dtype(image_dtype) + if image_dtype.kind == "f": + # use double for accuracy of mean/std computations + info = cp.finfo(image_dtype) + elif image_dtype.kind in "bui": + info = cp.iinfo(image_dtype) + else: + raise ValueError( + f"Invalid intensity image dtype {image_dtype.name}. " + "Must be an unsigned, integer or floating point type." + ) + return (info.min, info.max) + + +@cp.memoize(for_each_device=True) +def get_intensity_measure_kernel( + image_dtype=None, + int32_count=True, + num_channels=1, + compute_num_pixels=True, + compute_sum=True, + compute_sum_sq=False, + compute_min=False, + compute_max=False, + pixels_per_thread=8, +): + if compute_num_pixels: + count_dtype = cp.dtype(cp.uint32 if int32_count else cp.uint64) + + ( + sum_dtype, + c_sum_type, + min_max_dtype, + c_min_max_type, + ) = _get_intensity_img_kernel_dtypes(image_dtype) + + array_size = pixels_per_thread + any_sums = compute_num_pixels or compute_sum or compute_sum_sq + + if any_sums: + sums_pre, sums_op, sums_post = _get_img_sums_code( + c_sum_type=c_sum_type, + pixels_per_thread=pixels_per_thread, + array_size=array_size, + num_channels=num_channels, + compute_num_pixels=compute_num_pixels, + compute_sum=compute_sum, + compute_sum_sq=compute_sum_sq, + ) + + any_min_max = compute_min or compute_max + if any_min_max: + if min_max_dtype is None: + raise ValueError("min_max_dtype must be specified") + range_min, range_max = _get_intensity_range(min_max_dtype) + min_max_pre, min_max_op, min_max_post = _get_intensity_min_max_code( + min_max_dtype=min_max_dtype, + c_min_max_type=c_min_max_type, + array_size=array_size, + num_channels=num_channels, + initial_max_val=range_min, + initial_min_val=range_max, + compute_min=compute_min, + compute_max=compute_max, + ) + + if not (any_min_max or any_sums): + raise ValueError("no output values requested") + + # store only counts for label > 0 (label = 0 is the background) + source = f""" + uint64_t start_index = {pixels_per_thread}*i; + """ + + if any_sums: + source += sums_pre + if any_min_max: + source += min_max_pre + + inner_op = "" + if any_sums: + inner_op += sums_op + if any_min_max: + inner_op += min_max_op + + source += f""" + X encountered_labels[{array_size}] = {{0}}; + X current_label; + X prev_label = labels[start_index]; + int offset = 0; + encountered_labels[0] = prev_label; + uint64_t ii_max = min(start_index + {pixels_per_thread}, labels_size); + for (uint64_t ii = start_index; ii < ii_max; ii++) {{ + current_label = labels[ii]; + if (current_label == 0) {{ continue; }} + if (current_label != prev_label) {{ + offset += 1; + prev_label = current_label; + encountered_labels[offset] = current_label; + }} + {inner_op} + }}""" + source += """ + for (size_t jj = 0; jj <= offset; jj++) { + X lab = encountered_labels[jj]; + if (lab != 0) {""" + + if any_sums: + source += sums_post + if any_min_max: + source += min_max_post + source += """ + } + }\n""" + + # print(source) + inputs = "raw X labels, raw uint64 labels_size, raw Y img" + outputs = [] + name = "cucim_" + if compute_num_pixels: + outputs.append(f"raw {count_dtype.name} counts") + name += f"_numpix_{count_dtype.char}" + if compute_sum: + outputs.append(f"raw {sum_dtype.name} sums") + name += "_sum" + if compute_sum_sq: + outputs.append(f"raw {sum_dtype.name} sumsqs") + name += "_sumsq" + if compute_sum or compute_sum_sq: + name += f"_{sum_dtype.char}" + if compute_min: + outputs.append(f"raw {min_max_dtype.name} minimums") + name += "_min" + if compute_max: + outputs.append(f"raw {min_max_dtype.name} maximums") + name += "_max" + if compute_min or compute_max: + name += f"{min_max_dtype.char}" + outputs = ", ".join(outputs) + name += f"_batch{pixels_per_thread}" + return cp.ElementwiseKernel( + inputs, outputs, source, preamble=_includes, name=name + ) + + +def regionprops_intensity_mean( + label_image, + intensity_image, + max_label=None, + mean_dtype=cp.float32, + pixels_per_thread=16, + props_dict=None, +): + """Compute the mean intensity of each region. + + reuses "num_pixels" from `props_dict` if it exists + + writes "intensity_mean" to `props_dict` + writes "num_pixels" to `props_dict` if it was not already present + """ + if props_dict is None: + props_dict = {} + if max_label is None: + max_label = int(label_image.max()) + num_counts = max_label + + num_channels = _check_intensity_image_shape(label_image, intensity_image) + + count_dtype, int32_count = _get_count_dtype(label_image.size) + + image_dtype = intensity_image.dtype + sum_dtype, _, _, _ = _get_intensity_img_kernel_dtypes(image_dtype) + + if "num_pixels" in props_dict: + counts = props_dict["num_pixels"] + if counts.dtype != count_dtype: + counts = counts.astype(count_dtype, copy=False) + compute_num_pixels = False + else: + counts = cp.zeros(num_counts, dtype=count_dtype) + compute_num_pixels = True + + sum_shape = ( + (num_counts,) if num_channels == 1 else (num_counts, num_channels) + ) + sums = cp.zeros(sum_shape, dtype=sum_dtype) + + kernel = get_intensity_measure_kernel( + int32_count=int32_count, + image_dtype=image_dtype, + num_channels=num_channels, + compute_num_pixels=compute_num_pixels, + compute_sum=True, + compute_sum_sq=False, + pixels_per_thread=pixels_per_thread, + ) + + # make a copy if the inputs are not already C-contiguous + if not label_image.flags.c_contiguous: + label_image = cp.ascontiguousarray(label_image) + if not intensity_image.flags.c_contiguous: + intensity_image = cp.ascontiguousarray(intensity_image) + + if compute_num_pixels: + outputs = (counts, sums) + else: + outputs = (sums,) + + kernel( + label_image, + label_image.size, + intensity_image, + *outputs, + size=math.ceil(label_image.size / pixels_per_thread), + ) + + if num_channels > 1: + means = sums / counts[:, cp.newaxis] + else: + means = sums / counts + means = means.astype(mean_dtype, copy=False) + props_dict["intensity_mean"] = means + if "num_pixels" not in props_dict: + props_dict["num_pixels"] = counts + return props_dict + + +@cp.memoize(for_each_device=True) +def get_mean_var_kernel(dtype, sample_std=False): + dtype = cp.dtype(dtype) + + if dtype.kind != "f": + raise ValueError("dtype must be a floating point type") + if dtype == cp.float64: + c_type = "double" + nan_val = "CUDART_NAN" + else: + c_type = "float" + nan_val = "CUDART_NAN_F" + + if sample_std: + source = f""" + if (count == 1) {{ + m = static_cast<{c_type}>(sum); + var = {nan_val}; + }} else {{ + m = static_cast(sum) / count; + var = sqrt( + (static_cast(sumsq) - m * m * count) / (count - 1)); + }}\n""" + else: + source = f""" + if (count == 0) {{ + m = static_cast<{c_type}>(sum); + var = {nan_val}; + }} else if (count == 1) {{ + m = static_cast<{c_type}>(sum); + var = 0.0; + }} else {{ + m = static_cast(sum) / count; + var = sqrt( + (static_cast(sumsq) - m * m * count) / count); + }}\n""" + inputs = "X count, Y sum, Y sumsq" + outputs = "Z m, Z var" + name = f"cucim_sample_std_naive_{dtype.name}" + return cp.ElementwiseKernel( + inputs, outputs, source, preamble=_includes, name=name + ) + + +def regionprops_intensity_std( + label_image, + intensity_image, + sample_std=False, + max_label=None, + std_dtype=cp.float64, + pixels_per_thread=4, + props_dict=None, +): + """Compute the mean and standard deviation of the intensity of each region. + + reuses "num_pixels" from `props_dict` if it exists + + writes "intensity_mean" to `props_dict` + writes "intensity_std" to `props_dict` + writes "num_pixels" to `props_dict` if it was not already present + """ + if props_dict is None: + props_dict = {} + if max_label is None: + max_label = int(label_image.max()) + num_counts = max_label + + num_channels = _check_intensity_image_shape(label_image, intensity_image) + + image_dtype = intensity_image.dtype + sum_dtype, _, _, _ = _get_intensity_img_kernel_dtypes(image_dtype) + + count_dtype, int32_count = _get_count_dtype(label_image.size) + + if "num_pixels" in props_dict: + counts = props_dict["num_pixels"] + if counts.dtype != count_dtype: + counts = counts.astype(count_dtype, copy=False) + compute_num_pixels = False + else: + counts = cp.zeros(num_counts, dtype=count_dtype) + compute_num_pixels = True + + sum_shape = ( + (num_counts,) if num_channels == 1 else (num_counts, num_channels) + ) + sums = cp.zeros(sum_shape, dtype=sum_dtype) + sumsqs = cp.zeros(sum_shape, dtype=sum_dtype) + + # make a copy if the inputs are not already C-contiguous + if not label_image.flags.c_contiguous: + label_image = cp.ascontiguousarray(label_image) + if not intensity_image.flags.c_contiguous: + intensity_image = cp.ascontiguousarray(intensity_image) + + # TODO(grelee): May want to provide an approach with better numerical + # stability (i.e.like the two-pass algorithm or Welford's online algorithm) + kernel = get_intensity_measure_kernel( + int32_count=int32_count, + image_dtype=image_dtype, + num_channels=num_channels, + compute_num_pixels=compute_num_pixels, + compute_sum=True, + compute_sum_sq=True, + pixels_per_thread=pixels_per_thread, + ) + if compute_num_pixels: + outputs = (counts, sums, sumsqs) + else: + outputs = (sums, sumsqs) + kernel( + label_image, + label_image.size, + intensity_image, + *outputs, + size=math.ceil(label_image.size / pixels_per_thread), + ) + + if cp.dtype(std_dtype).kind != "f": + raise ValueError("mean_dtype must be a floating point type") + + # compute means and standard deviations from the counts, sums and + # squared sums (use float64 here since the numerical stability of this + # approach is poor) + means = cp.zeros(sum_shape, dtype=cp.float64) + stds = cp.zeros(sum_shape, dtype=cp.float64) + kernel2 = get_mean_var_kernel(stds.dtype, sample_std=sample_std) + if num_channels > 1: + kernel2(counts[..., cp.newaxis], sums, sumsqs, means, stds) + else: + kernel2(counts, sums, sumsqs, means, stds) + + means = means.astype(std_dtype, copy=False) + stds = stds.astype(std_dtype, copy=False) + props_dict["intensity_std"] = stds + props_dict["intensity_mean"] = means + if "num_pixels" not in props_dict: + props_dict["num_pixels"] = counts + return props_dict + + +def regionprops_intensity_min_max( + label_image, + intensity_image, + max_label=None, + compute_min=True, + compute_max=False, + pixels_per_thread=8, + props_dict=None, +): + """Compute the minimum and maximum intensity of each region. + + writes "intensity_min" to `props_dict` if `compute_min` is True + writes "intensity_max" to `props_dict` if `compute_max` is True + """ + if not (compute_min or compute_max): + raise ValueError("Nothing to compute") + if props_dict is None: + props_dict = {} + + if max_label is None: + max_label = int(label_image.max()) + num_counts = max_label + + num_channels = _check_intensity_image_shape(label_image, intensity_image) + + # use an appropriate data type supported by atomicMin and atomicMax + image_dtype = intensity_image.dtype + _, _, min_max_dtype, _ = _get_intensity_img_kernel_dtypes(image_dtype) + range_min, range_max = _get_intensity_range(image_dtype) + out_shape = ( + (num_counts,) if num_channels == 1 else (num_counts, num_channels) + ) + if compute_min: + minimums = cp.full(out_shape, range_max, dtype=min_max_dtype) + if compute_max: + maximums = cp.full(out_shape, range_min, dtype=min_max_dtype) + + kernel = get_intensity_measure_kernel( + image_dtype=image_dtype, + num_channels=num_channels, + compute_num_pixels=False, + compute_sum=False, + compute_sum_sq=False, + compute_min=compute_min, + compute_max=compute_max, + pixels_per_thread=pixels_per_thread, + ) + + # make a copy if the inputs are not already C-contiguous + if not label_image.flags.c_contiguous: + label_image = cp.ascontiguousarray(label_image) + if not intensity_image.flags.c_contiguous: + intensity_image = cp.ascontiguousarray(intensity_image) + + lab_size = label_image.size + sz = math.ceil(label_image.size / pixels_per_thread) + if compute_min and compute_max: + outputs = (minimums, maximums) + elif compute_min: + outputs = (minimums,) + else: + outputs = (maximums,) + + kernel( + label_image, lab_size, intensity_image, *outputs, size=sz + ) # noqa: E501 + if compute_min: + props_dict["intensity_min"] = minimums + if compute_max: + props_dict["intensity_max"] = maximums + return props_dict diff --git a/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_utils.py b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_utils.py index 5d5f19ac..3e331efc 100644 --- a/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_utils.py +++ b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_utils.py @@ -1,3 +1,5 @@ +import math + import cupy as cp from packaging.version import parse @@ -31,6 +33,23 @@ def _get_min_integer_dtype(max_size, signed=False): return func(-max_size - 1) if signed else func(max_size) +def _check_intensity_image_shape(label_image, intensity_image): + ndim = label_image.ndim + if intensity_image.shape[:ndim] != label_image.shape: + raise ValueError( + "Initial dimensions of `intensity_image` must match the shape of " + "`label_image`. (`intensity_image` may have additional trailing " + "channels/batch dimensions)" + ) + + num_channels = ( + math.prod(intensity_image.shape[ndim:]) + if intensity_image.ndim > ndim + else 1 + ) + return num_channels + + def _unravel_loop_index_declarations(var_name, ndim, uint_t="unsigned int"): if ndim == 1: code = f""" diff --git a/python/cucim/src/cucim/skimage/measure/tests/test_regionprops_gpu_kernels.py b/python/cucim/src/cucim/skimage/measure/tests/test_regionprops_gpu_kernels.py index 7243d766..26ad67ca 100644 --- a/python/cucim/src/cucim/skimage/measure/tests/test_regionprops_gpu_kernels.py +++ b/python/cucim/src/cucim/skimage/measure/tests/test_regionprops_gpu_kernels.py @@ -16,6 +16,7 @@ from cucim.skimage.measure._regionprops import PROPS from cucim.skimage.measure._regionprops_gpu import ( equivalent_diameter_area, + need_intensity_image, regionprops_area, regionprops_area_bbox, regionprops_bbox_coords, @@ -23,9 +24,15 @@ regionprops_dict, regionprops_extent, regionprops_image, + regionprops_intensity_mean, + regionprops_intensity_min_max, + regionprops_intensity_std, regionprops_num_pixels, ) from cucim.skimage.measure._regionprops_gpu_basic_kernels import basic_deps +from cucim.skimage.measure._regionprops_gpu_intensity_kernels import ( + intensity_deps, +) def get_labels_nd( @@ -172,6 +179,140 @@ def test_extent(ndim, area_dtype, spacing): assert_allclose(extent, expected["extent"], rtol=1e-5, atol=1e-5) +@pytest.mark.parametrize("precompute_max", [False, True]) +@pytest.mark.parametrize("ndim", [2, 3]) +@pytest.mark.parametrize("image_dtype", [cp.uint16, cp.uint8, cp.float32]) +@pytest.mark.parametrize("mean_dtype", [cp.float32, cp.float64]) +@pytest.mark.parametrize("num_channels", [1, 4]) +def test_mean_intensity( + precompute_max, ndim, image_dtype, mean_dtype, num_channels +): + shape = (256, 512) if ndim == 2 else (15, 63, 37) + labels = get_labels_nd(shape) + intensity_image = get_intensity_image( + shape, dtype=image_dtype, num_channels=num_channels + ) + + max_label = int(cp.max(labels)) if precompute_max else None + props_dict = regionprops_intensity_mean( + labels, intensity_image, max_label=max_label, mean_dtype=mean_dtype + ) + expected = measure_cpu.regionprops_table( + cp.asnumpy(labels), + intensity_image=cp.asnumpy(intensity_image), + properties=["num_pixels", "intensity_mean"], + ) + assert_array_equal(props_dict["num_pixels"], expected["num_pixels"]) + if num_channels == 1: + assert_allclose( + props_dict["intensity_mean"], expected["intensity_mean"], rtol=1e-3 + ) + else: + for c in range(num_channels): + assert_allclose( + props_dict["intensity_mean"][..., c], + expected[f"intensity_mean-{c}"], + rtol=1e-3, + ) + + +@pytest.mark.parametrize("precompute_max", [False, True]) +@pytest.mark.parametrize("ndim", [2, 3]) +@pytest.mark.parametrize( + "image_dtype", [cp.uint16, cp.uint8, cp.float16, cp.float32, cp.float64] +) +@pytest.mark.parametrize("op_name", ["intensity_min", "intensity_max"]) +@pytest.mark.parametrize("num_channels", [1, 3]) +def test_intensity_min_and_max( + precompute_max, ndim, image_dtype, op_name, num_channels +): + shape = (256, 512) if ndim == 2 else (15, 63, 37) + labels = get_labels_nd(shape) + intensity_image = get_intensity_image( + shape, dtype=image_dtype, num_channels=num_channels + ) + + max_label = int(cp.max(labels)) if precompute_max else None + + compute_min = op_name == "intensity_min" + compute_max = not compute_min + + values = regionprops_intensity_min_max( + labels, + intensity_image, + max_label=max_label, + compute_min=compute_min, + compute_max=compute_max, + )[op_name] + + expected = measure_cpu.regionprops_table( + cp.asnumpy(labels), + intensity_image=cp.asnumpy(intensity_image), + properties=[op_name], + ) + if num_channels == 1: + assert_allclose(values, expected[op_name]) + else: + for c in range(num_channels): + assert_allclose(values[..., c], expected[f"{op_name}-{c}"]) + + +@pytest.mark.parametrize("precompute_max", [False, True]) +@pytest.mark.parametrize("ndim", [2, 3]) +@pytest.mark.parametrize("image_dtype", [cp.uint16, cp.uint8, cp.float32]) +@pytest.mark.parametrize("std_dtype", [cp.float32, cp.float64]) +@pytest.mark.parametrize("num_channels", [1, 5]) +def test_intensity_std( + precompute_max, ndim, image_dtype, std_dtype, num_channels +): + shape = (1024, 2048) if ndim == 2 else (40, 64, 80) + labels = get_labels_nd(shape) + intensity_image = get_intensity_image( + shape, dtype=image_dtype, num_channels=num_channels + ) + + max_label = int(cp.max(labels)) if precompute_max else None + + # add some specifically sized regions + if ndim == 2 and precompute_max: + # clear small region + labels[50:54, 50:56] = 0 + # add a single pixel labeled region + labels[51, 51] = max_label + 1 + # add a two pixel labeled region + labels[53, 53:55] = max_label + 2 + max_label += 2 + + props_dict = regionprops_intensity_std( + labels, intensity_image, max_label=max_label, std_dtype=std_dtype + ) + expected = measure_cpu.regionprops_table( + cp.asnumpy(labels), + intensity_image=cp.asnumpy(intensity_image), + properties=["num_pixels", "intensity_mean", "intensity_std"], + ) + assert_array_equal(props_dict["num_pixels"], expected["num_pixels"]) + if num_channels == 1: + assert_allclose( + props_dict["intensity_mean"], expected["intensity_mean"], rtol=1e-3 + ) + assert_allclose( + props_dict["intensity_std"], expected["intensity_std"], rtol=1e-3 + ) + else: + for c in range(num_channels): + assert_allclose( + props_dict["intensity_mean"][..., c], + expected[f"intensity_mean-{c}"], + rtol=1e-3, + ) + assert_allclose( + props_dict["intensity_std"][..., c], + expected[f"intensity_std-{c}"], + rtol=1e-3, + ) + + @pytest.mark.parametrize("precompute_max", [False, True]) @pytest.mark.parametrize("dtype", [cp.uint32, cp.int64]) @pytest.mark.parametrize("return_slices", [False, True]) @@ -309,7 +450,9 @@ def test_coords(ndim, spacing): @pytest.mark.parametrize("ndim", [2, 3]) @pytest.mark.parametrize("spacing", [None, (1.5, 0.5, 0.76)]) -@pytest.mark.parametrize("property_name", list(basic_deps.keys())) +@pytest.mark.parametrize( + "property_name", list(basic_deps.keys()) + list(intensity_deps.keys()) +) def test_regionprops_dict_single_property(ndim, spacing, property_name): """Test to verify that any dependencies for a given property are automatically handled. @@ -318,7 +461,7 @@ def test_regionprops_dict_single_property(ndim, spacing, property_name): if spacing is not None: spacing = spacing[:ndim] labels = get_labels_nd(shape) - if property_name == "image_intensity": + if property_name in need_intensity_image: intensity_image = get_intensity_image( shape, dtype=cp.uint16, num_channels=1 ) @@ -352,7 +495,7 @@ def test_regionprops_image_and_coords_sequence(ndim, property_name): spacing = spacing[:ndim] labels = get_labels_nd(shape) max_label = int(labels.max()) - if property_name == "image_intensity": + if property_name in need_intensity_image: intensity_image = get_intensity_image( shape, dtype=cp.uint16, num_channels=1 ) From 15b57c1250669725ce338d2391c14427ca18a8de Mon Sep 17 00:00:00 2001 From: Gregory Lee Date: Sat, 1 Mar 2025 13:13:43 -0500 Subject: [PATCH 06/14] implement convex region properties These properties are computed based on the image_convex subimages: - area_convex - feret_diameter_max - solidity --- .../cucim/src/cucim/skimage/_shared/utils.py | 22 +++ .../cucim/skimage/measure/_regionprops_gpu.py | 23 +++ .../measure/_regionprops_gpu_basic_kernels.py | 1 - .../measure/_regionprops_gpu_convex.py | 147 ++++++++++++++++++ .../tests/test_regionprops_gpu_kernels.py | 130 +++++++++++++++- 5 files changed, 321 insertions(+), 2 deletions(-) create mode 100644 python/cucim/src/cucim/skimage/measure/_regionprops_gpu_convex.py diff --git a/python/cucim/src/cucim/skimage/_shared/utils.py b/python/cucim/src/cucim/skimage/_shared/utils.py index 7312eaf4..8ff3c1d7 100644 --- a/python/cucim/src/cucim/skimage/_shared/utils.py +++ b/python/cucim/src/cucim/skimage/_shared/utils.py @@ -27,6 +27,28 @@ ] +# For n nonzero elements cupy.nonzero returns a tuple of length ndim where +# each element is an array of size (n, ) corresponding to the coordinates on +# a specific axis. +# +# Often for regionprops purposes we would rather have a single array of +# size (n, ndim) instead of a the tuple of arrays. +# +# CuPy's `_ndarray_argwhere` (used internally by cupy.nonzero) already provides +# this but is not part of the public API. To guard against potential future +# change we provide a less efficient fallback implementation. +try: + from cupy._core._routines_indexing import _ndarray_argwhere +except ImportError: + + def _ndarray_argwhere(a): + """Stack the result of cupy.nonzero into a single array + + output shape will be (num_nonzero, ndim) + """ + return cp.stack(cp.nonzero(a), axis=-1) + + def _count_wrappers(func): """Count the number of wrappers around `func`.""" unwrapped = func diff --git a/python/cucim/src/cucim/skimage/measure/_regionprops_gpu.py b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu.py index 349f27a2..f4da492d 100644 --- a/python/cucim/src/cucim/skimage/measure/_regionprops_gpu.py +++ b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu.py @@ -23,6 +23,11 @@ regionprops_num_perimeter_pixels, regionprops_num_pixels, ) +from ._regionprops_gpu_convex import ( + convex_deps, + regionprops_area_convex, + regionprops_feret_diameter_max, +) from ._regionprops_gpu_intensity_kernels import ( intensity_deps, regionprops_intensity_mean, @@ -35,10 +40,12 @@ "equivalent_diameter_area", "regionprops_area", "regionprops_area_bbox", + "regionprops_area_convex", "regionprops_bbox_coords", "regionprops_coords", "regionprops_dict", "regionprops_extent", + "regionprops_feret_diameter_max", "regionprops_image", "regionprops_intensity_mean", "regionprops_intensity_min_max", @@ -101,6 +108,7 @@ # get_property_dependencies below). property_deps = dict() property_deps.update(basic_deps) +property_deps.update(convex_deps) property_deps.update(intensity_deps) # set of properties that only supports 2D images @@ -376,6 +384,21 @@ def regionprops_dict( offset_coordinates=True, ) + if "area_convex" in required_props: + regionprops_area_convex( + out["image_convex"], max_label=max_label, props_dict=out + ) + + if "solidity" in required_props: + out["solidity"] = out["area"] / out["area_convex"] + + if "feret_diameter_max" in required_props: + regionprops_feret_diameter_max( + out["image_convex"], + spacing=spacing, + props_dict=out, + ) + compute_coords = "coords" in required_props compute_coords_scaled = "coords_scaled" in required_props if compute_coords or compute_coords_scaled: diff --git a/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_basic_kernels.py b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_basic_kernels.py index b3ca745d..90b2d286 100644 --- a/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_basic_kernels.py +++ b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_basic_kernels.py @@ -49,7 +49,6 @@ basic_deps["equivalent_spherical_perimeter"] = ["equivalent_diameter_area"] basic_deps["extent"] = ["area", "area_bbox"] basic_deps["image"] = [] -basic_deps["image_convex"] = [] basic_deps["image_filled"] = ["label_filled"] basic_deps["image_intensity"] = [] basic_deps["label"] = [] diff --git a/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_convex.py b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_convex.py new file mode 100644 index 00000000..aa95d7e8 --- /dev/null +++ b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_convex.py @@ -0,0 +1,147 @@ +import math +import warnings +from collections.abc import Sequence + +import cupy as cp + +from cucim.skimage._shared.distance import pdist_max_blockwise +from cucim.skimage._shared.utils import _ndarray_argwhere +from cucim.skimage._vendored import ndimage as ndi + +# Store information on which other properties a given property depends on +# This information will be used by `regionprops_dict` to make sure that when +# a particular property is requested any dependent properties are computed +# first. +convex_deps = dict() +convex_deps["image_convex"] = ["image"] # computed by regionprops_image +convex_deps["area_convex"] = ["image_convex"] +convex_deps["feret_diameter_max"] = ["image_convex"] +convex_deps["solidity"] = ["area", "area_convex"] + + +def regionprops_area_convex( + images_convex, + max_label=None, + spacing=None, + area_dtype=cp.float64, + props_dict=None, +): + """Compute the area of each convex image. + + writes "area_convex" to props_dict + + Parameters + ---------- + images_convex : sequence of cupy.ndarray + Convex images for each region as produced by ``regionprops_image`` with + ``compute_convex=True``. + """ + if max_label is None: + max_label = len(images_convex) + if not isinstance(images_convex, Sequence): + raise ValueError("Expected `images_convex` to be a sequence of images") + area_convex = cp.zeros((max_label,), dtype=area_dtype) + for i in range(max_label): + area_convex[i] = images_convex[i].sum() + if spacing is not None: + if isinstance(spacing, cp.ndarray): + pixel_area = cp.product(spacing) + else: + pixel_area = math.prod(spacing) + area_convex *= pixel_area + if props_dict is not None: + props_dict["area_convex"] = area_convex + return area_convex + + +def _regionprops_coords_perimeter( + image, + connectivity=1, +): + """ + Takes an image of a single labeled region (e.g. one element of the tuple + resulting from regionprops_image) and returns the coordinates of the voxels + at the edge of that region. + """ + + # remove non-boundary pixels + binary_image = image > 0 + footprint = ndi.generate_binary_structure( + binary_image.ndim, connectivity=connectivity + ) + binary_image_eroded = ndi.binary_erosion(binary_image, footprint) + binary_edges = binary_image * ~binary_image_eroded + edge_coords = _ndarray_argwhere(binary_edges) + return edge_coords + + +def _feret_diameter_max(image_convex, spacing=None, return_argmax=False): + """Compute the maximum Feret diameter of a single convex image region.""" + if image_convex.size == 1: + warnings.warn( + "single element image, returning 0 for feret diameter", UserWarning + ) + return 0 + coords = _regionprops_coords_perimeter(image_convex, connectivity=1) + coords = coords.astype(cp.float32) + + if spacing is not None: + if all(s == 1.0 for s in spacing): + spacing = None + else: + spacing = cp.asarray(spacing, dtype=cp.float32).reshape(1, -1) + coords *= spacing + + out = pdist_max_blockwise( + coords, + metric="sqeuclidean", + compute_argmax=return_argmax, + coords_per_block=4000, + ) + if return_argmax: + return math.sqrt(out[0]), out[1] + return math.sqrt(out[0]) + + +def regionprops_feret_diameter_max( + images_convex, spacing=None, props_dict=None +): + """Compute the maximum Feret diameter of the convex hull of each image in + images_convex. + + writes "feret_diameter_max" to props_dict + + Parameters + ---------- + images_convex : sequence of cupy.ndarray + Convex images for each region as produced by ``regionprops_image`` with + ``compute_convex=True``. + spacing : tuple of float, optional + The pixel spacing of the image. + props_dict : dict, optional + A dictionary to store the computed properties. + + Notes + ----- + The maximum Feret diameter is the maximum distance between any two + points on the convex hull of the region. The implementation here is based + on pairwise distances of all boundary coordinates rather than using + marching squares or marching cubes as in scikit-image. The implementation + here is n-dimensional. + + The distance is between pixel centers and so may be approximately one pixel + width less than the one computed by scikit-image. + """ + if not isinstance(images_convex, Sequence): + raise ValueError("Expected `images_convex` to be a sequence of images") + diameters = cp.asarray( + tuple( + _feret_diameter_max( + image_convex, spacing=spacing, return_argmax=False + ) + for image_convex in images_convex + ) + ) + if props_dict is not None: + props_dict["feret_diameter_max"] = diameters + return diameters diff --git a/python/cucim/src/cucim/skimage/measure/tests/test_regionprops_gpu_kernels.py b/python/cucim/src/cucim/skimage/measure/tests/test_regionprops_gpu_kernels.py index 26ad67ca..1f162692 100644 --- a/python/cucim/src/cucim/skimage/measure/tests/test_regionprops_gpu_kernels.py +++ b/python/cucim/src/cucim/skimage/measure/tests/test_regionprops_gpu_kernels.py @@ -3,6 +3,7 @@ import warnings import cupy as cp +import numpy as np import pytest from cupy.testing import ( assert_allclose, @@ -19,10 +20,12 @@ need_intensity_image, regionprops_area, regionprops_area_bbox, + regionprops_area_convex, regionprops_bbox_coords, regionprops_coords, regionprops_dict, regionprops_extent, + regionprops_feret_diameter_max, regionprops_image, regionprops_intensity_mean, regionprops_intensity_min_max, @@ -30,6 +33,7 @@ regionprops_num_pixels, ) from cucim.skimage.measure._regionprops_gpu_basic_kernels import basic_deps +from cucim.skimage.measure._regionprops_gpu_convex import convex_deps from cucim.skimage.measure._regionprops_gpu_intensity_kernels import ( intensity_deps, ) @@ -143,6 +147,72 @@ def test_area(precompute_max, ndim, area_dtype, spacing): ) +@pytest.mark.parametrize("ndim", [2, 3]) +@pytest.mark.parametrize("spacing", [None, (1, 1, 1), (0.5, 0.35, 0.75)]) +@pytest.mark.parametrize( + "blob_kwargs", [{}, dict(blob_size_fraction=0.12, volume_fraction=0.3)] +) +def test_area_convex_and_solidity(ndim, spacing, blob_kwargs): + shape = (256, 512) if ndim == 2 else (64, 64, 80) + labels = get_labels_nd(shape, **blob_kwargs) + # discard any extra dimensions from spacing + if spacing is not None: + spacing = spacing[:ndim] + + max_label = int(cp.max(labels)) + area = regionprops_area( + labels, + spacing=spacing, + max_label=max_label, + ) + _, _, images_convex = regionprops_image( + labels, + max_label=max_label, + compute_convex=True, + ) + area_convex = regionprops_area_convex( + images_convex, + max_label=max_label, + spacing=spacing, + ) + solidity = area / area_convex + + # suppress any QHull warnings coming from the scikit-image implementation + warnings.filterwarnings( + "ignore", + message="Failed to get convex hull image", + category=UserWarning, + ) + warnings.filterwarnings( + "ignore", + message="divide by zero", + category=RuntimeWarning, + ) + expected = measure_cpu.regionprops_table( + cp.asnumpy(labels), + spacing=spacing, + properties=["area", "area_convex", "solidity"], + ) + warnings.resetwarnings() + + assert_allclose(area, expected["area"]) + + # Note if 3d blobs are size 1 on one of the axes, it can cause QHull to + # fail and return a zeros convex image for that label. This has been + # resolved for cuCIM, but not yet for scikit-image. + # The test case with blob_kwargs != {} was chosen as a known good + # setting where such an edge case does NOT occur. + if blob_kwargs: + assert_allclose(area_convex, expected["area_convex"]) + assert_allclose(solidity, expected["solidity"]) + else: + # Can't compare to scikit-image in this case + # Just make sure the convex area is not smaller than the original + rtol = 1e-4 + assert cp.all(area_convex >= (area - rtol)) + assert not cp.any(cp.isnan(solidity)) + + @pytest.mark.parametrize("ndim", [2, 3]) @pytest.mark.parametrize("area_dtype", [cp.float32, cp.float64]) @pytest.mark.parametrize("spacing", [None, (0.5, 0.35, 0.75)]) @@ -448,10 +518,68 @@ def test_coords(ndim, spacing): ) +@pytest.mark.parametrize("ndim", [2, 3]) +@pytest.mark.parametrize("spacing", [None, (1, 1, 1), (0.5, 0.35, 0.75)]) +@pytest.mark.parametrize( + "blob_kwargs", [dict(blob_size_fraction=0.15, volume_fraction=0.1)] +) +def test_feret_diameter_max(ndim, spacing, blob_kwargs): + shape = (1024, 2048) if ndim == 2 else (64, 80, 48) + # use dilate blobs to avoid error from singleton dimension regions in + # scikit-image + labels = get_labels_nd(shape, dilate_blobs=ndim == 3, **blob_kwargs) + # discard any extra dimensions from spacing + if spacing is not None: + spacing = spacing[:ndim] + + max_label = int(cp.max(labels)) + _, _, images_convex = regionprops_image( + labels, + max_label=max_label, + compute_convex=True, + ) + feret_diameters = regionprops_feret_diameter_max( + images_convex, + spacing=spacing, + ) + + # suppress any QHull warnings coming from the scikit-image implementation + warnings.filterwarnings( + "ignore", + message="Failed to get convex hull image", + category=UserWarning, + ) + warnings.filterwarnings( + "ignore", + message="divide by zero", + category=RuntimeWarning, + ) + expected = measure_cpu.regionprops_table( + cp.asnumpy(labels), + spacing=spacing, + properties=["num_pixels", "feret_diameter_max"], + ) + warnings.resetwarnings() + + # print(f"{ndim=}, {spacing=}, {max_label=}") + # print(f"num_pixels={expected['num_pixels']}") + # print(f"diameters={expected['feret_diameter_max']}") + max_diff = np.max( + np.abs(feret_diameters.get() - expected["feret_diameter_max"]) + ) + # print(f"max_diff = {max_diff}") + assert max_diff < math.sqrt(ndim) + + @pytest.mark.parametrize("ndim", [2, 3]) @pytest.mark.parametrize("spacing", [None, (1.5, 0.5, 0.76)]) @pytest.mark.parametrize( - "property_name", list(basic_deps.keys()) + list(intensity_deps.keys()) + "property_name", + ( + list(basic_deps.keys()) + + list(convex_deps.keys()) + + list(intensity_deps.keys()) + ), ) def test_regionprops_dict_single_property(ndim, spacing, property_name): """Test to verify that any dependencies for a given property are From ae2d29ab15673efbf07f407c7a672b3376283782 Mon Sep 17 00:00:00 2001 From: Gregory Lee Date: Mon, 3 Mar 2025 10:36:03 -0500 Subject: [PATCH 07/14] update feret_diameter_max output dtype --- .../cucim/skimage/measure/_regionprops_gpu_convex.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_convex.py b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_convex.py index aa95d7e8..890f0a01 100644 --- a/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_convex.py +++ b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_convex.py @@ -92,15 +92,16 @@ def _feret_diameter_max(image_convex, spacing=None, return_argmax=False): spacing = cp.asarray(spacing, dtype=cp.float32).reshape(1, -1) coords *= spacing - out = pdist_max_blockwise( + squared_dist, index_argmax = pdist_max_blockwise( coords, metric="sqeuclidean", compute_argmax=return_argmax, coords_per_block=4000, ) + max_diameter = math.sqrt(float(squared_dist)) if return_argmax: - return math.sqrt(out[0]), out[1] - return math.sqrt(out[0]) + return max_diameter, index_argmax + return max_diameter def regionprops_feret_diameter_max( @@ -140,7 +141,8 @@ def regionprops_feret_diameter_max( image_convex, spacing=spacing, return_argmax=False ) for image_convex in images_convex - ) + ), + dtype=cp.float64, ) if props_dict is not None: props_dict["feret_diameter_max"] = diameters From f73f779fc8d810eeadc4f4bd3cc29cbd87a5642d Mon Sep 17 00:00:00 2001 From: Gregory Lee Date: Mon, 3 Mar 2025 11:13:47 -0500 Subject: [PATCH 08/14] dtype fix for temp output array in pdist_max_blockwise --- python/cucim/src/cucim/skimage/_shared/distance.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/cucim/src/cucim/skimage/_shared/distance.py b/python/cucim/src/cucim/skimage/_shared/distance.py index f051977c..127b7aeb 100644 --- a/python/cucim/src/cucim/skimage/_shared/distance.py +++ b/python/cucim/src/cucim/skimage/_shared/distance.py @@ -128,12 +128,14 @@ def pdist_max_blockwise( ) blocks_per_dim = math.ceil(num_coords / coords_per_block) + if coords.dtype not in [xp.float32, xp.float64]: + coords = coords.astype(xp.float32, copy=False) if blocks_per_dim > 1: # reuse the same temporary storage array for most blocks # (last block in row and column may be smaller) - temp = xp.zeros((coords_per_block, coords_per_block), dtype=xp.float32) - if coords.dtype not in [xp.float32, xp.float64]: - coords = coords.astype(xp.float32, copy=False) + temp = xp.zeros( + (coords_per_block, coords_per_block), dtype=coords.dtype + ) if not coords.flags.c_contiguous: coords = xp.ascontiguousarray(coords) max_dist = 0 From 9901ad4e9c0dec397577ab7d249cca76bb122687 Mon Sep 17 00:00:00 2001 From: Gregory Lee Date: Mon, 3 Mar 2025 11:53:43 -0500 Subject: [PATCH 09/14] always use float64 dtype within pdist_max_blockwise --- python/cucim/src/cucim/skimage/_shared/distance.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/cucim/src/cucim/skimage/_shared/distance.py b/python/cucim/src/cucim/skimage/_shared/distance.py index 127b7aeb..dbef09ca 100644 --- a/python/cucim/src/cucim/skimage/_shared/distance.py +++ b/python/cucim/src/cucim/skimage/_shared/distance.py @@ -128,8 +128,7 @@ def pdist_max_blockwise( ) blocks_per_dim = math.ceil(num_coords / coords_per_block) - if coords.dtype not in [xp.float32, xp.float64]: - coords = coords.astype(xp.float32, copy=False) + coords = coords.astype(xp.float64, copy=False) if blocks_per_dim > 1: # reuse the same temporary storage array for most blocks # (last block in row and column may be smaller) From dd68d0dd46a7cebff46475e84bd835a04ec247c3 Mon Sep 17 00:00:00 2001 From: Gregory Lee Date: Sat, 1 Mar 2025 15:07:07 -0500 Subject: [PATCH 10/14] add computation of many properties based on image moments --- .../src/cucim/skimage/measure/_regionprops.py | 49 +- .../cucim/skimage/measure/_regionprops_gpu.py | 205 +- .../_regionprops_gpu_moments_kernels.py | 1958 +++++++++++++++++ .../tests/test_regionprops_gpu_kernels.py | 584 +++++ 4 files changed, 2745 insertions(+), 51 deletions(-) create mode 100644 python/cucim/src/cucim/skimage/measure/_regionprops_gpu_moments_kernels.py diff --git a/python/cucim/src/cucim/skimage/measure/_regionprops.py b/python/cucim/src/cucim/skimage/measure/_regionprops.py index 9678e3ac..4248ed3b 100644 --- a/python/cucim/src/cucim/skimage/measure/_regionprops.py +++ b/python/cucim/src/cucim/skimage/measure/_regionprops.py @@ -258,51 +258,6 @@ def func2d(self, *args, **kwargs): return func2d -def _inertia_eigvals_to_axes_lengths_3D(inertia_tensor_eigvals): - """Compute ellipsoid axis lengths from inertia tensor eigenvalues. - - Parameters - --------- - inertia_tensor_eigvals : sequence of float - A sequence of 3 floating point eigenvalues, sorted in descending order. - - Returns - ------- - axis_lengths : list of float - The ellipsoid axis lengths sorted in descending order. - - Notes - ----- - Let a >= b >= c be the ellipsoid semi-axes and s1 >= s2 >= s3 be the - inertia tensor eigenvalues. - - The inertia tensor eigenvalues are given for a solid ellipsoid in [1]_. - s1 = 1 / 5 * (a**2 + b**2) - s2 = 1 / 5 * (a**2 + c**2) - s3 = 1 / 5 * (b**2 + c**2) - - Rearranging to solve for a, b, c in terms of s1, s2, s3 gives - a = math.sqrt(5 / 2 * ( s1 + s2 - s3)) - b = math.sqrt(5 / 2 * ( s1 - s2 + s3)) - c = math.sqrt(5 / 2 * (-s1 + s2 + s3)) - - We can then simply replace sqrt(5/2) by sqrt(10) to get the full axes - lengths rather than the semi-axes lengths. - - References - ---------- - ..[1] https://en.wikipedia.org/wiki/List_of_moments_of_inertia#List_of_3D_inertia_tensors - """ # noqa: E501 - axis_lengths = [] - for ax in range(2, -1, -1): - w = sum( - v * -1 if i == ax else v - for i, v in enumerate(inertia_tensor_eigvals) - ) - axis_lengths.append(math.sqrt(10 * w)) - return axis_lengths - - class RegionProperties: """Please refer to `skimage.measure.regionprops` for more information on the available region properties. @@ -625,7 +580,9 @@ def axis_minor_length(self): elif self._ndim == 3: # equivalent to _inertia_eigvals_to_axes_lengths_3D(ev)[-1] ev = self.inertia_tensor_eigvals - return math.sqrt(10 * (-ev[0] + ev[1] + ev[2])) + # use max to avoid possibly very small negative value due to + # numeric error + return math.sqrt(max(10 * (-ev[0] + ev[1] + ev[2]), 0.0)) else: raise ValueError("axis_minor_length only available in 2D and 3D") diff --git a/python/cucim/src/cucim/skimage/measure/_regionprops_gpu.py b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu.py index f4da492d..5d6d81b6 100644 --- a/python/cucim/src/cucim/skimage/measure/_regionprops_gpu.py +++ b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu.py @@ -34,6 +34,19 @@ regionprops_intensity_min_max, regionprops_intensity_std, ) +from ._regionprops_gpu_moments_kernels import ( + moment_deps, + regionprops_centroid, + regionprops_centroid_local, + regionprops_centroid_weighted, + regionprops_inertia_tensor, + regionprops_inertia_tensor_eigvals, + regionprops_moments, + regionprops_moments_central, + regionprops_moments_hu, + regionprops_moments_normalized, + required_order, +) from ._regionprops_gpu_utils import _get_min_integer_dtype __all__ = [ @@ -42,14 +55,24 @@ "regionprops_area_bbox", "regionprops_area_convex", "regionprops_bbox_coords", + "regionprops_centroid", + "regionprops_centroid_local", + "regionprops_centroid_weighted", "regionprops_coords", "regionprops_dict", "regionprops_extent", "regionprops_feret_diameter_max", "regionprops_image", + "regionprops_inertia_tensor", + "regionprops_inertia_tensor_eigvals", "regionprops_intensity_mean", "regionprops_intensity_min_max", "regionprops_intensity_std", + "regionprops_moments", + "regionprops_moments_central", + "regionprops_moments_hu", + "regionprops_moments_normalized", + "regionprops_num_pixels", # extra functions for cuCIM not currently in scikit-image "equivalent_spherical_perimeter", # as in ITK "regionprops_num_boundary_pixels", @@ -68,6 +91,8 @@ PROPS_GPU = copy(PROPS) # extra properties not currently in scikit-image PROPS_GPU_EXTRA = { + "axis_lengths": "axis_lengths", + "inertia_tensor_eigenvectors": "inertia_tensor_eigenvectors", "num_pixels_filled": "num_pixels_filled", # a few extra parameters as in ITK "num_perimeter_pixels": "num_perimeter_pixels", @@ -80,6 +105,8 @@ CURRENT_PROPS_GPU = set(PROPS_GPU.values()) COL_DTYPES_EXTRA = { + "axis_lengths": float, + "inertia_tensor_eigenvectors": float, "num_pixels_filled": int, "num_perimeter_pixels": int, "num_boundary_pixels": int, @@ -110,9 +137,7 @@ property_deps.update(basic_deps) property_deps.update(convex_deps) property_deps.update(intensity_deps) - -# set of properties that only supports 2D images -ndim_2_only = set() +property_deps.update(moment_deps) def get_property_dependencies(dependencies, node): @@ -140,7 +165,40 @@ def depth_first_search(n): } # set of properties that require an intensity_image also be provided -need_intensity_image = set(intensity_deps.keys()) | {"image_intensity"} +need_intensity_image = ( + set(intensity_deps.keys()) + | {"image_intensity"} + | set(p for p in CURRENT_PROPS_GPU if "weighted" in p) +) + +# set of properties that can only be computed for 2D regions +ndim_2_only = { + "eccentricity", + "moments_hu", + "moments_weighted_hu", + "orientation", + "perimeter", + "perimeter_crofton", # could be updated to nD as in ITK +} + + +def _check_moment_order(moment_order: int | None, requested_moment_props: set): + """Helper function for input validation in `regionprops_dict`. + + Determines the minimum order required across all requested moment + properties and validates the `moment_order` kwarg. + """ + min_order_required = max(required_order[p] for p in requested_moment_props) + if moment_order is not None: + if moment_order < min_order_required: + raise ValueError( + f"can't compute {requested_moment_props} with moment_order<" + f"{min_order_required}, but {moment_order=} was specified." + ) + order = moment_order + else: + order = min_order_required + return order def regionprops_dict( @@ -149,6 +207,7 @@ def regionprops_dict( properties=[], *, spacing=None, + moment_order=None, max_label=None, pixels_per_thread=16, ): @@ -177,6 +236,11 @@ def regionprops_dict( Extra Parameters ---------------- + moment_order : int or None + When computing moment properties, only moments up to this order are + computed. The default value of None results in the minimum order + required in order to compute the requested properties. For example, + properties based on the inertia_tensor require moment_order >= 2. max_label : int or None The maximum label value. If not provided it will be computed from `label_image`. @@ -233,7 +297,7 @@ def regionprops_dict( if any(invalid_names): raise ValueError( f"{label_image.ndim=}, but the following properties are for " - "2D label images only: {invalid_names}" + f"2D label images only: {invalid_names}" ) if intensity_image is None: has_intensity = False @@ -368,6 +432,137 @@ def regionprops_dict( out["num_boundary_pixels"] / out["num_perimeter_pixels"] ) + compute_unweighted_moments = "moments" in required_props + compute_weighted_moments = "moments_weighted" in required_props + compute_moments = compute_unweighted_moments or compute_weighted_moments + compute_inertia_tensor = "inertia_tensor" in required_props + + if compute_moments: + required_moment_props = set(moment_deps.keys()) & required_props + # determine minimum necessary order (or validate the user-provided one) + order = _check_moment_order(moment_order, required_moment_props) + + imgs = [] + if compute_unweighted_moments: + imgs.append(None) + if compute_weighted_moments: + imgs.append(intensity_image) + + # compute raw moments (weighted and/or unweighted) + for img in imgs: + regionprops_moments( + label_image, + intensity_image=img, + max_label=max_label, + order=order, + spacing=spacing, + **perf_kwargs, + props_dict=out, + ) + + compute_centroid_local = ( + "centroid_local" in required_moment_props + ) # noqa:E501 + compute_centroid = "centroid" in required_moment_props + if compute_centroid or compute_centroid_local: + regionprops_centroid_weighted( + moments_raw=out["moments"], + ndim=label_image.ndim, + bbox=out["bbox"], + compute_local=compute_centroid_local, + compute_global=compute_centroid, + weighted=False, + props_dict=out, + ) + + compute_centroid_weighted_local = ( + "centroid_weighted_local" in required_moment_props + ) # noqa: E501 + compute_centroid_weighted = ( + "centroid_weighted" in required_moment_props + ) # noqa: E501 + if compute_centroid_weighted or compute_centroid_weighted_local: + regionprops_centroid_weighted( + moments_raw=out["moments_weighted"], + ndim=label_image.ndim, + bbox=out["bbox"], + compute_local=compute_centroid_weighted_local, + compute_global=compute_centroid_weighted, + weighted=True, + props_dict=out, + ) + + if "moments_central" in required_moment_props: + regionprops_moments_central( + out["moments"], ndim=ndim, weighted=False, props_dict=out + ) + + if "moments_normalized" in required_moment_props: + regionprops_moments_normalized( + out["moments_central"], + ndim=ndim, + spacing=None, + pixel_correction=False, + weighted=False, + props_dict=out, + ) + if "moments_hu" in required_moment_props: + regionprops_moments_hu( + out["moments_normalized"], + weighted=False, + props_dict=out, + ) + + if "moments_weighted_central" in required_moment_props: + regionprops_moments_central( + out["moments_weighted"], ndim, weighted=True, props_dict=out + ) + + if "moments_weighted_normalized" in required_moment_props: + regionprops_moments_normalized( + out["moments_weighted_central"], + ndim=ndim, + spacing=None, + pixel_correction=False, + weighted=True, + props_dict=out, + ) + + if "moments_weighted_hu" in required_moment_props: + regionprops_moments_hu( + out["moments_weighted_normalized"], + weighted=True, + props_dict=out, + ) + + # inertia tensor computations come after moment computations + if compute_inertia_tensor: + regionprops_inertia_tensor( + out["moments_central"], + ndim=ndim, + compute_orientation=( + "orientation" in required_moment_props + ), # noqa: E501 + props_dict=out, + ) + + if "inertia_tensor_eigvals" in required_moment_props: + compute_axis_lengths = ( + "axis_minor_length" in required_moment_props + or "axis_major_length" in required_moment_props + ) + regionprops_inertia_tensor_eigvals( + out["inertia_tensor"], + compute_axis_lengths=compute_axis_lengths, + compute_eccentricity=( + "eccentricity" in required_moment_props + ), + compute_eigenvectors=( + "inertia_tensor_eigenvectors" in required_moment_props + ), + props_dict=out, + ) + compute_images = "image" in required_props compute_intensity_images = "image_intensity" in required_props compute_convex = "image_convex" in required_props diff --git a/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_moments_kernels.py b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_moments_kernels.py new file mode 100644 index 00000000..e8a10602 --- /dev/null +++ b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_moments_kernels.py @@ -0,0 +1,1958 @@ +import math + +import cupy as cp + +from ._regionprops_gpu_basic_kernels import get_bbox_coords_kernel +from ._regionprops_gpu_utils import ( + _check_intensity_image_shape, + _includes, + _unravel_loop_index, + _unravel_loop_index_declarations, +) + +__all__ = [ + "regionprops_centroid", + "regionprops_centroid_local", + "regionprops_centroid_weighted", + "regionprops_inertia_tensor", + "regionprops_inertia_tensor_eigvals", + "regionprops_moments", + "regionprops_moments_central", + "regionprops_moments_hu", + "regionprops_moments_normalized", +] + + +# Store information on which other properties a given property depends on +# This information will be used by `regionprops_dict` to make sure that when +# a particular property is requested any dependent properties are computed +# first. +moment_deps = dict() +moment_deps["moments"] = ["bbox"] +moment_deps["moments_weighted"] = ["bbox"] +moment_deps["eccentricity"] = ["inertia_tensor_eigvals"] +moment_deps["axis_major_length"] = ["inertia_tensor_eigvals"] +moment_deps["axis_minor_length"] = ["inertia_tensor_eigvals"] +moment_deps["inertia_tensor_eigenvectors"] = ["inertia_tensor_eigvals"] +moment_deps["inertia_tensor_eigvals"] = ["inertia_tensor"] +moment_deps["orientation"] = ["inertia_tensor"] +moment_deps["moments_hu"] = ["moments_normalized"] +moment_deps["moments_normalized"] = ["moments_central"] +moment_deps["inertia_tensor"] = ["moments_central"] +moment_deps["moments_central"] = ["moments"] +moment_deps["centroid"] = ["moments"] +moment_deps["centroid_local"] = ["moments"] +moment_deps["moments_weighted_central"] = ["moments_weighted"] +moment_deps["moments_weighted_normalized"] = ["moments_weighted_central"] +moment_deps["moments_weighted_hu"] = ["moments_weighted_normalized"] +moment_deps["centroid_weighted"] = ["moments_weighted"] +moment_deps["centroid_weighted_local"] = ["moments_weighted"] + + +# The minimum moment "order" required to compute each property +required_order = { + "centroid": 1, + "centroid_local": 1, + "centroid_weighted": 1, + "centroid_weighted_local": 1, + "axis_major_length": 2, + "axis_minor_length": 2, + "eccentricity": 2, + "inertia_tensor": 2, + "inertia_tensor_eigvals": 2, + "inertia_tensor_eigenvectors": 2, + "moments": 2, + "moments_central": 2, + "moments_normalized": 2, + "moments_weighted": 2, + "moments_weighted_central": 2, + "moments_weighted_normalized": 2, + "orientation": 2, + "moments_hu": 3, + "moments_weighted_hu": 3, +} + + +def regionprops_centroid( + label_image, + max_label=None, + pixels_per_thread=16, + props_dict=None, +): + """Compute the centroid of each labeled region in the image. + + reuses "num_pixels" from previously computed properties if present + + writes "centroid" to `props_dict` + + Returns + ------- + centroid : cp.ndarray + The centroid of each region. + """ + if max_label is None: + max_label = int(label_image.max()) + ndim = label_image.ndim + + int32_coords = max(label_image.shape) < 2**32 + if props_dict is not None and "num_pixels" in props_dict: + centroid_counts = props_dict["num_pixels"] + if centroid_counts.dtype != cp.uint32: + centroid_counts = centroid_counts.astype(cp.uint32) + compute_num_pixels = False + else: + centroid_counts = cp.zeros((max_label,), dtype=cp.uint32) + compute_num_pixels = True + + bbox_coords_kernel = get_bbox_coords_kernel( + ndim=label_image.ndim, + int32_coords=int32_coords, + compute_bbox=False, + compute_num_pixels=compute_num_pixels, + compute_coordinate_sums=True, + pixels_per_thread=pixels_per_thread, + ) + + centroid_sums = cp.zeros((max_label, ndim), dtype=cp.uint64) + + # make a copy if the inputs are not already C-contiguous + if not label_image.flags.c_contiguous: + label_image = cp.ascontiguousarray(label_image) + + if compute_num_pixels: + outputs = (centroid_counts, centroid_sums) + else: + outputs = centroid_sums + bbox_coords_kernel( + label_image, + label_image.size, + *outputs, + size=math.ceil(label_image.size / pixels_per_thread), + ) + + centroid = centroid_sums / centroid_counts[:, cp.newaxis] + if props_dict is not None: + props_dict["centroid"] = centroid + if "num_pixels" not in props_dict: + props_dict["num_pixels"] = centroid_counts + return centroid + + +@cp.memoize(for_each_device=True) +def get_centroid_local_kernel(coord_dtype, ndim): + """Keep this kernel for n-dimensional support as the raw_moments kernels + currently only support 2D and 3D data. + """ + coord_dtype = cp.dtype(coord_dtype) + sum_dtype = cp.dtype(cp.uint64) + count_dtype = cp.dtype(cp.uint32) + uint_t = ( + "unsigned int" if coord_dtype.itemsize <= 4 else "unsigned long long" + ) + + source = """ + auto L = label[i]; + if (L != 0) {""" + source += _unravel_loop_index("label", ndim, uint_t=uint_t) + for d in range(ndim): + source += f""" + atomicAdd(¢roid_sums[(L - 1) * {ndim} + {d}], + in_coord[{d}] - bbox[(L - 1) * {2 * ndim} + {d}]); + """ + source += """ + atomicAdd(¢roid_counts[L - 1], 1); + }\n""" + inputs = f"raw X label, raw {coord_dtype.name} bbox" + outputs = f"raw {count_dtype.name} centroid_counts, " + outputs += f"raw {sum_dtype.name} centroid_sums" + name = f"cucim_centroid_local_{ndim}d_{coord_dtype.name}" + return cp.ElementwiseKernel( + inputs, outputs, source, preamble=_includes, name=name + ) + + +def regionprops_centroid_local( + label_image, + max_label=None, + coord_dtype=cp.uint32, + pixels_per_thread=16, + props_dict=None, +): + """Compute the central moments of the labeled regions. + + dimensions supported: nD + + reuses "moments" from previously computed properties if present + reuses "bbox" from previously computed properties if present + + writes "centroid_local" to `props_dict` + writes "bbox" to `props_dict` if it was not already present + writes "num_pixels" to `props_dict` if it was not already present + + Parameters + ---------- + label_image : cp.ndarray + Image containing labels where 0 is the background and sequential + values > 0 are the labels. + max_label : int or None + The maximum label value present in label_image. Will be computed if not + provided. + coord_dtype : dtype, optional + The data type to use for coordinate calculations. Should be + ``cp.uint32`` or ``cp.uint64``. + + Returns + ------- + counts : cp.ndarray + The number of samples in each region. + centroid_local : cp.ndarray + The local centroids + + Notes + ----- + The centroid could also be extracted from the raw moments + computed via `regionprops_moments`. That will be more efficient than + running this separate function if additional moment-based properties + are also needed. + + This function is also useful for data with more than 3 dimensions as + regionprops_moments currently only supports 2D and 3D data. + """ + if props_dict is None: + props_dict = {} + if max_label is None: + max_label = int(label_image.max()) + + int32_coords = max(label_image.shape) < 2**32 + coord_dtype = cp.dtype(cp.uint32 if int32_coords else cp.uint64) + + ndim = label_image.ndim + + if "moments" in props_dict and ndim in [2, 3]: + # already have the moments needed in previously computed properties + moments = props_dict["moments"] + # can't compute if only zeroth moment is present + if moments.shape[-1] > 1: + centroid_local = cp.empty((max_label, ndim), dtype=moments.dtype) + if ndim == 2: + m0 = moments[:, 0, 0] + centroid_local[:, 0] = moments[:, 1, 0] / m0 + centroid_local[:, 1] = moments[:, 0, 1] / m0 + else: + m0 = moments[:, 0, 0, 0] + centroid_local[:, 0] = moments[:, 1, 0, 0] / m0 + centroid_local[:, 1] = moments[:, 0, 1, 0] / m0 + centroid_local[:, 2] = moments[:, 0, 0, 1] / m0 + props_dict["centroid_local"] = centroid_local + return centroid_local + + if "bbox" in props_dict: + # reuse previously computed bounding box coordinates + bbox_coords = props_dict["bbox"] + if bbox_coords.dtype != coord_dtype: + bbox_coords = bbox_coords.astype(coord_dtype) + + else: + bbox_coords_kernel = get_bbox_coords_kernel( + ndim=label_image.ndim, + int32_coords=int32_coords, + compute_bbox=True, + compute_num_pixels=False, + compute_coordinate_sums=False, + pixels_per_thread=pixels_per_thread, + ) + + bbox_coords = cp.zeros((max_label, 2 * ndim), dtype=coord_dtype) + + # Initialize value for atomicMin on first ndim coordinates + # The value for atomicMax columns is already 0 as desired. + bbox_coords[:, :ndim] = cp.iinfo(coord_dtype).max + + # make a copy if the inputs are not already C-contiguous + if not label_image.flags.c_contiguous: + label_image = cp.ascontiguousarray(label_image) + + bbox_coords_kernel( + label_image, + label_image.size, + bbox_coords, + size=math.ceil(label_image.size / pixels_per_thread), + ) + if "bbox" not in props_dict: + props_dict["bbox"] = bbox_coords + + counts = cp.zeros((max_label,), dtype=cp.uint32) + centroids_sums = cp.zeros((max_label, ndim), dtype=cp.uint64) + centroid_local_kernel = get_centroid_local_kernel( + coord_dtype, label_image.ndim + ) + centroid_local_kernel( + label_image, bbox_coords, counts, centroids_sums, size=label_image.size + ) + + centroid_local = centroids_sums / counts[:, cp.newaxis] + props_dict["centroid_local"] = centroid_local + if "num_pixels" not in props_dict: + props_dict["num_pixels"] = counts + return centroid_local + + +def _get_raw_moments_code( + coord_c_type, + moments_c_type, + ndim, + order, + array_size, + num_channels=1, + has_spacing=False, + has_weights=False, +): + """ + Notes + ----- + Local variables created: + + - local_moments : shape (array_size, num_channels, num_moments) + local set of moments up to the specified order (1-3 supported) + + Output variables written to: + + - moments : shape (max_label, num_channels, num_moments) + """ + + # number is for a densely populated moments matrix of size (order + 1) per + # side (values at locations where order is greater than specified will be 0) + num_moments = (order + 1) ** ndim + + if order > 3: + raise ValueError("Only moments of orders 0-3 are supported") + + use_floating_point = moments_c_type in ["float", "double"] + + source_pre = f""" + {moments_c_type} local_moments[{array_size*num_channels*num_moments}] = {{0}}; + {coord_c_type} m_offset = 0; + {coord_c_type} local_off = 0;\n""" # noqa: E501 + if has_weights: + source_pre += f""" + {moments_c_type} w = 0.0;\n""" + + # op uses external coordinate array variables: + # bbox : bounding box coordinates, shape (max_label, 2*ndim) + # in_coord[0]...in_coord[ndim - 1] : coordinates + # coordinates in the labeled image at the current index + # ii : index into labels array + # current_label : value of the label image at location ii + # spacing (optional) : pixel spacings + # img (optional) : intensity image + source_operation = "" + # using bounding box to transform the global coordinates to local ones + # (c0 = local coordinate on axis 0, etc.) + for d in range(ndim): + source_operation += f""" + {moments_c_type} c{d} = in_coord[{d}] + - bbox[(current_label - 1) * {2 * ndim} + {d}];""" + if has_spacing: + source_operation += f""" + c{d} *= spacing[{d}];""" + + # need additional multiplication by the intensity value for weighted case + w = "w * " if has_weights else "" + for c in range(num_channels): + source_operation += f""" + local_off = {num_moments*num_channels}*offset + {c * num_moments};\n""" # noqa: E501 + + # zeroth moment + if has_weights: + source_operation += f""" + w = static_cast<{moments_c_type}>(img[{num_channels} * ii + {c}]); + local_moments[local_off] += w;\n""" # noqa: E501 + elif use_floating_point: + source_operation += """ + local_moments[local_off] += 1.0;\n""" + else: + source_operation += """ + local_moments[local_off] += 1;\n""" + + # moments for order 1-3 + if ndim == 2: + if order == 1: + source_operation += f""" + local_moments[local_off + 1] += {w}c1; + local_moments[local_off + 2] += {w}c0;\n""" + elif order == 2: + source_operation += f""" + local_moments[local_off + 1] += {w}c1; + local_moments[local_off + 2] += {w}c1 * c1; + local_moments[local_off + 3] += {w}c0; + local_moments[local_off + 4] += {w}c0 * c1; + local_moments[local_off + 6] += {w}c0 * c0;\n""" + elif order == 3: + source_operation += f""" + local_moments[local_off + 1] += {w}c1; + local_moments[local_off + 2] += {w}c1 * c1; + local_moments[local_off + 3] += {w}c1 * c1 * c1; + local_moments[local_off + 4] += {w}c0; + local_moments[local_off + 5] += {w}c0 * c1; + local_moments[local_off + 6] += {w}c0 * c1 * c1; + local_moments[local_off + 8] += {w}c0 * c0; + local_moments[local_off + 9] += {w}c0 * c0 * c1; + local_moments[local_off + 12] += {w}c0 * c0 * c0;\n""" + elif ndim == 3: + if order == 1: + source_operation += f""" + local_moments[local_off + 1] += {w}c2; + local_moments[local_off + 2] += {w}c1; + local_moments[local_off + 4] += {w}c0;\n""" + elif order == 2: + source_operation += f""" + local_moments[local_off + 1] += {w}c2; + local_moments[local_off + 2] += {w}c2 * c2; + local_moments[local_off + 3] += {w}c1; + local_moments[local_off + 4] += {w}c1 * c2; + local_moments[local_off + 6] += {w}c1 * c1; + local_moments[local_off + 9] += {w}c0; + local_moments[local_off + 10] += {w}c0 * c2; + local_moments[local_off + 12] += {w}c0 * c1; + local_moments[local_off + 18] += {w}c0 * c0;\n""" + elif order == 3: + source_operation += f""" + local_moments[local_off + 1] += {w}c2; + local_moments[local_off + 2] += {w}c2 * c2; + local_moments[local_off + 3] += {w}c2 * c2 * c2; + local_moments[local_off + 4] += {w}c1; + local_moments[local_off + 5] += {w}c1 * c2; + local_moments[local_off + 6] += {w}c1 * c2 * c2; + local_moments[local_off + 8] += {w}c1 * c1; + local_moments[local_off + 9] += {w}c1 * c1 * c2; + local_moments[local_off + 12] += {w}c1 * c1 * c1; + local_moments[local_off + 16] += {w}c0; + local_moments[local_off + 17] += {w}c0 * c2; + local_moments[local_off + 18] += {w}c0 * c2 * c2; + local_moments[local_off + 20] += {w}c0 * c1; + local_moments[local_off + 21] += {w}c0 * c1 * c2; + local_moments[local_off + 24] += {w}c0 * c1 * c1; + local_moments[local_off + 32] += {w}c0 * c0; + local_moments[local_off + 33] += {w}c0 * c0 * c2; + local_moments[local_off + 36] += {w}c0 * c0 * c1; + local_moments[local_off + 48] += {w}c0 * c0 * c0;\n""" + else: + raise ValueError("only ndim = 2 or 3 is supported") + + # post_operation uses external variables: + # ii : index into num_pixels array + # lab : label value that corresponds to location ii + # coord_sums : output with shape (max_label, ndim) + source_post = "" + for c in range(0, num_channels): + source_post += f""" + // moments outputs + m_offset = {num_moments*num_channels}*(lab - 1) + {c * num_moments}; + local_off = {num_moments*num_channels}*ii + {c * num_moments}; + atomicAdd(&moments[m_offset], local_moments[local_off]);\n""" # noqa: E501 + + if ndim == 2: + if order == 1: + source_post += """ + atomicAdd(&moments[m_offset + 1], local_moments[local_off + 1]); + atomicAdd(&moments[m_offset + 2], local_moments[local_off + 2]);\n""" # noqa: E501 + elif order == 2: + source_post += """ + atomicAdd(&moments[m_offset + 1], local_moments[local_off + 1]); + atomicAdd(&moments[m_offset + 2], local_moments[local_off + 2]); + atomicAdd(&moments[m_offset + 3], local_moments[local_off + 3]); + atomicAdd(&moments[m_offset + 4], local_moments[local_off + 4]); + atomicAdd(&moments[m_offset + 6], local_moments[local_off + 6]);\n""" # noqa: E501 + elif order == 3: + source_post += """ + atomicAdd(&moments[m_offset + 1], local_moments[local_off + 1]); + atomicAdd(&moments[m_offset + 2], local_moments[local_off + 2]); + atomicAdd(&moments[m_offset + 3], local_moments[local_off + 3]); + atomicAdd(&moments[m_offset + 4], local_moments[local_off + 4]); + atomicAdd(&moments[m_offset + 5], local_moments[local_off + 5]); + atomicAdd(&moments[m_offset + 6], local_moments[local_off + 6]); + atomicAdd(&moments[m_offset + 8], local_moments[local_off + 8]); + atomicAdd(&moments[m_offset + 9], local_moments[local_off + 9]); + atomicAdd(&moments[m_offset + 12], local_moments[local_off + 12]);\n""" # noqa: E501 + elif ndim == 3: + if order == 1: + source_post += """ + atomicAdd(&moments[m_offset + 1], local_moments[local_off + 1]); + atomicAdd(&moments[m_offset + 2], local_moments[local_off + 2]); + atomicAdd(&moments[m_offset + 4], local_moments[local_off + 4]);\n""" # noqa: E501 + elif order == 2: + source_post += """ + atomicAdd(&moments[m_offset + 1], local_moments[local_off + 1]); + atomicAdd(&moments[m_offset + 2], local_moments[local_off + 2]); + atomicAdd(&moments[m_offset + 3], local_moments[local_off + 3]); + atomicAdd(&moments[m_offset + 4], local_moments[local_off + 4]); + atomicAdd(&moments[m_offset + 6], local_moments[local_off + 6]); + atomicAdd(&moments[m_offset + 9], local_moments[local_off + 9]); + atomicAdd(&moments[m_offset + 10], local_moments[local_off + 10]); + atomicAdd(&moments[m_offset + 12], local_moments[local_off + 12]); + atomicAdd(&moments[m_offset + 18], local_moments[local_off + 18]);\n""" # noqa: E501 + elif order == 3: + source_post += """ + atomicAdd(&moments[m_offset + 1], local_moments[local_off + 1]); + atomicAdd(&moments[m_offset + 2], local_moments[local_off + 2]); + atomicAdd(&moments[m_offset + 3], local_moments[local_off + 3]); + atomicAdd(&moments[m_offset + 4], local_moments[local_off + 4]); + atomicAdd(&moments[m_offset + 5], local_moments[local_off + 5]); + atomicAdd(&moments[m_offset + 6], local_moments[local_off + 6]); + atomicAdd(&moments[m_offset + 8], local_moments[local_off + 8]); + atomicAdd(&moments[m_offset + 9], local_moments[local_off + 9]); + atomicAdd(&moments[m_offset + 12], local_moments[local_off + 12]); + atomicAdd(&moments[m_offset + 16], local_moments[local_off + 16]); + atomicAdd(&moments[m_offset + 17], local_moments[local_off + 17]); + atomicAdd(&moments[m_offset + 18], local_moments[local_off + 18]); + atomicAdd(&moments[m_offset + 20], local_moments[local_off + 20]); + atomicAdd(&moments[m_offset + 21], local_moments[local_off + 21]); + atomicAdd(&moments[m_offset + 24], local_moments[local_off + 24]); + atomicAdd(&moments[m_offset + 32], local_moments[local_off + 32]); + atomicAdd(&moments[m_offset + 33], local_moments[local_off + 33]); + atomicAdd(&moments[m_offset + 36], local_moments[local_off + 36]); + atomicAdd(&moments[m_offset + 48], local_moments[local_off + 48]);\n""" # noqa: E501 + return source_pre, source_operation, source_post + + +@cp.memoize(for_each_device=True) +def get_raw_moments_kernel( + ndim, + order, + moments_dtype=cp.float64, + int32_coords=True, + spacing=None, + weighted=False, + num_channels=1, + pixels_per_thread=8, +): + moments_dtype = cp.dtype(moments_dtype) + + array_size = pixels_per_thread + + coord_dtype = cp.dtype(cp.uint32 if int32_coords else cp.uint64) + if coord_dtype.itemsize <= 4: + coord_c_type = "unsigned int" + else: + coord_c_type = "unsigned long long" + + use_floating_point = moments_dtype.kind == "f" + has_spacing = spacing is not None + if (weighted or has_spacing) and not use_floating_point: + raise ValueError( + "`moments_dtype` must be a floating point type for weighted " + "moments calculations or moment calculations using spacing." + ) + moments_c_type = "double" if use_floating_point else "unsigned long long" + if spacing is not None: + if len(spacing) != ndim: + raise ValueError("len(spacing) must equal len(shape)") + if moments_dtype.kind != "f": + raise ValueError("moments must have a floating point data type") + + moments_pre, moments_op, moments_post = _get_raw_moments_code( + coord_c_type=coord_c_type, + moments_c_type=moments_c_type, + ndim=ndim, + order=order, + array_size=array_size, + has_weights=weighted, + has_spacing=spacing is not None, + num_channels=num_channels, + ) + + # store only counts for label > 0 (label = 0 is the background) + source = f""" + uint64_t start_index = {pixels_per_thread}*i; + """ + source += moments_pre + + inner_op = "" + + source += _unravel_loop_index_declarations( + "labels", ndim, uint_t=coord_c_type + ) + + inner_op += _unravel_loop_index( + "labels", + ndim=ndim, + uint_t=coord_c_type, + raveled_index="ii", + omit_declarations=True, + ) + inner_op += moments_op + + source += f""" + X encountered_labels[{array_size}] = {{0}}; + X current_label; + X prev_label = labels[start_index]; + int offset = 0; + encountered_labels[0] = prev_label; + uint64_t ii_max = min(start_index + {pixels_per_thread}, labels_size); + for (uint64_t ii = start_index; ii < ii_max; ii++) {{ + current_label = labels[ii]; + if (current_label == 0) {{ continue; }} + if (current_label != prev_label) {{ + offset += 1; + prev_label = current_label; + encountered_labels[offset] = current_label; + }} + {inner_op} + }}""" + source += """ + for (size_t ii = 0; ii <= offset; ii++) { + X lab = encountered_labels[ii]; + if (lab != 0) {""" + + source += moments_post + source += """ + } + }\n""" + + # print(source) + inputs = ( + f"raw X labels, raw uint64 labels_size, raw {coord_dtype.name} bbox" + ) + if spacing: + inputs += ", raw float64 spacing" + if weighted: + inputs += ", raw Y img" + outputs = f"raw {moments_dtype.name} moments" + weighted_str = "_weighted" if weighted else "" + spacing_str = "_sp" if spacing else "" + name = f"cucim_moments{weighted_str}{spacing_str}_order{order}_{ndim}d" + name += f"_{coord_dtype.char}_{moments_dtype.char}_batch{pixels_per_thread}" + return cp.ElementwiseKernel( + inputs, outputs, source, preamble=_includes, name=name + ) + + +def regionprops_moments( + label_image, + intensity_image=None, + max_label=None, + order=2, + spacing=None, + pixels_per_thread=10, + props_dict=None, +): + """Compute the raw moments of the labeled regions. + + reuses "bbox" from previously computed properties if present + + writes "moments" to `props_dict` if `intensity_image` is not provided + writes "moments_weighted" to `props_dict` if `intensity_image` is provided + + Parameters + ---------- + label_image : cp.ndarray + Image containing labels where 0 is the background and sequential + values > 0 are the labels. + intensity_image : cp.ndarray, optional + Image of intensities. If provided, weighted moments are computed. If + this is a multi-channel image, moments are computed independently for + each channel. + max_label : int or None, optional + The maximum label value present in label_image. Will be computed if not + provided. + + Returns + ------- + moments : cp.ndarray + The moments up to the specified order. Will be stored in an + ``(order + 1, ) * ndim`` matrix where any elements corresponding to + order greater than that specified will be set to 0. For example, for + the 2D case, the last two axes represent the 2D moments matrix, ``M`` + where each matrix would have the following sizes and non-zero entries: + + ```py + # for a 2D image with order = 1 + M = [ + [m00, m01], + [m10, 0], + ] + + # for a 2D image with order = 2 + M = [ + [m00, m01, m02], + [m10, m11, 0], + [m20, 0, 0], + ] + + # for a 2D image with order = 3 + M = [ + [m00, m01, m02, m03], + [m10, m11, m12, 0], + [m20, m21, 0, 0], + [m30, 0, 0, 0], + ] + ``` + + When there is no `intensity_image` or the `intensity_image` is single + channel, the shape of the moments output is + ``shape = (max_label, ) + (order + 1, ) * ndim``. + When the ``intensity_image`` is multichannel a channel axis will be + present in the `moments` output at position 1 to give + ``shape = (max_label, ) + (num_channels, ) + (order + 1,) * ndim``. + """ + + if props_dict is None: + props_dict = {} + + if max_label is None: + max_label = int(label_image.max()) + + # make a copy if the inputs are not already C-contiguous + if not label_image.flags.c_contiguous: + label_image = cp.ascontiguousarray(label_image) + ndim = label_image.ndim + + int32_coords = max(label_image.shape) < 2**32 + coord_dtype = cp.dtype(cp.uint32 if int32_coords else cp.uint64) + if "bbox" in props_dict: + bbox_coords = props_dict["bbox"] + if bbox_coords.dtype != coord_dtype: + bbox_coords = bbox_coords.astype(coord_dtype) + else: + bbox_kernel = get_bbox_coords_kernel( + ndim=ndim, + int32_coords=int32_coords, + compute_bbox=True, + compute_num_pixels=False, + compute_coordinate_sums=False, + pixels_per_thread=pixels_per_thread, + ) + + bbox_coords = cp.zeros((max_label, 2 * ndim), dtype=coord_dtype) + + # Initialize value for atomicMin on first ndim coordinates + # The value for atomicMax columns is already 0 as desired. + bbox_coords[:, :ndim] = cp.iinfo(coord_dtype).max + + bbox_kernel( + label_image, + label_image.size, + bbox_coords, + size=math.ceil(label_image.size / pixels_per_thread), + ) + if props_dict is not None: + props_dict["bbox"] = bbox_coords + + moments_shape = (max_label,) + (order + 1,) * ndim + if intensity_image is not None: + if not intensity_image.flags.c_contiguous: + intensity_image = cp.ascontiguousarray(intensity_image) + + num_channels = _check_intensity_image_shape( + label_image, intensity_image + ) + if num_channels > 1: + moments_shape = (max_label,) + (num_channels,) + (order + 1,) * ndim + weighted = True + else: + num_channels = 1 + weighted = False + + # total number of elements in the moments matrix + moments = cp.zeros(moments_shape, dtype=cp.float64) + moments_kernel = get_raw_moments_kernel( + ndim=label_image.ndim, + order=order, + moments_dtype=moments.dtype, + int32_coords=int32_coords, + spacing=spacing, + weighted=weighted, + num_channels=num_channels, + pixels_per_thread=pixels_per_thread, + ) + input_args = ( + label_image, + label_image.size, + bbox_coords, + ) + if spacing: + input_args = input_args + (cp.asarray(spacing, dtype=cp.float64),) + if weighted: + input_args = input_args + (intensity_image,) + size = math.ceil(label_image.size / pixels_per_thread) + moments_kernel(*input_args, moments, size=size) + if weighted: + props_dict["moments_weighted"] = moments + else: + props_dict["moments"] = moments + return moments + + +@cp.memoize(for_each_device=True) +def get_moments_central_kernel( + moments_dtype, + ndim, + order, +): + """Applies analytical formulas to convert raw moments to central moments + + These are as in `_moments_raw_to_central_fast` from + `_moments_analytical.py` but that kernel is scalar, while this one will be + applied to all labeled regions (and any channels dimension) at once. + """ + moments_dtype = cp.dtype(moments_dtype) + + uint_t = "unsigned int" + + # number is for a densely populated moments matrix of size (order + 1) per + # side (values at locations where order is greater than specified will be 0) + num_moments = (order + 1) ** ndim + + if moments_dtype.kind != "f": + raise ValueError( + "`moments_dtype` must be a floating point type for central moments " + "calculations." + ) + + # floating point type used for the intermediate computations + float_type = "double" + + source = f""" + {uint_t} offset = i * {num_moments};\n""" + if ndim == 2: + if order <= 1: + # only zeroth moment is non-zero for central moments + source += """ + out[offset] = moments_raw[offset];\n""" + elif order == 2: + source += f""" + // retrieve the 2nd order raw moments + {float_type} m00 = moments_raw[offset]; + {float_type} m01 = moments_raw[offset + 1]; + {float_type} m02 = moments_raw[offset + 2]; + {float_type} m10 = moments_raw[offset + 3]; + {float_type} m11 = moments_raw[offset + 4]; + {float_type} m20 = moments_raw[offset + 6]; + + // compute centroids + // (TODO: add option to output the centroids as well?) + {float_type} cx = m10 / m00; + {float_type} cy = m01 / m00; + + // analytical expressions for the central moments + out[offset] = m00; // out[0, 0] + // 2nd order central moments + out[offset + 2] = m02 - cy * m01; // out[0, 2] + out[offset + 4] = m11 - cx * m01; // out[1, 1] + out[offset + 6] = m20 - cx * m10; // out[2, 0]\n""" + elif order == 3: + source += f""" + // retrieve the 2nd order raw moments + {float_type} m00 = moments_raw[offset]; + {float_type} m01 = moments_raw[offset + 1]; + {float_type} m02 = moments_raw[offset + 2]; + {float_type} m03 = moments_raw[offset + 3]; + {float_type} m10 = moments_raw[offset + 4]; + {float_type} m11 = moments_raw[offset + 5]; + {float_type} m12 = moments_raw[offset + 6]; + {float_type} m20 = moments_raw[offset + 8]; + {float_type} m21 = moments_raw[offset + 9]; + {float_type} m30 = moments_raw[offset + 12]; + + // compute centroids + {float_type} cx = m10 / m00; + {float_type} cy = m01 / m00; + + // zeroth moment + out[offset] = m00; // out[0, 0] + // 2nd order central moments + out[offset + 2] = m02 - cy * m01; // out[0, 2] + out[offset + 5] = m11 - cx * m01; // out[1, 1] + out[offset + 8] = m20 - cx * m10; // out[2, 0] + // 3rd order central moments + out[offset + 3] = m03 - 3*cy*m02 + 2*cy*cy*m01; // out[0, 3] + out[offset + 6] = m12 - 2*cy*m11 - cx*m02 + 2*cy*cx*m01; // out[1, 2] + out[offset + 9] = m21 - 2*cx*m11 - cy*m20 + cx*cx*m01 + cy*cx*m10; // out[2, 1] + out[offset + 12] = m30 - 3*cx*m20 + 2*cx*cx*m10; // out[3, 0]\n""" # noqa: E501 + else: + raise ValueError("only order <= 3 is supported") + elif ndim == 3: + if order <= 1: + # only zeroth moment is non-zero for central moments + source += """ + out[offset] = moments_raw[offset];\n""" + elif order == 2: + source += f""" + // retrieve the 2nd order raw moments + {float_type} m000 = moments_raw[offset]; + {float_type} m001 = moments_raw[offset + 1]; + {float_type} m002 = moments_raw[offset + 2]; + {float_type} m010 = moments_raw[offset + 3]; + {float_type} m011 = moments_raw[offset + 4]; + {float_type} m020 = moments_raw[offset + 6]; + {float_type} m100 = moments_raw[offset + 9]; + {float_type} m101 = moments_raw[offset + 10]; + {float_type} m110 = moments_raw[offset + 12]; + {float_type} m200 = moments_raw[offset + 18]; + + // compute centroids + {float_type} cx = m100 / m000; + {float_type} cy = m010 / m000; + {float_type} cz = m001 / m000; + + // zeroth moment + out[offset] = m000; // out[0, 0, 0] + // 2nd order central moments + out[offset + 2] = -cz*m001 + m002; // out[0, 0, 2] + out[offset + 4] = -cy*m001 + m011; // out[0, 1, 1] + out[offset + 6] = -cy*m010 + m020; // out[0, 2, 0] + out[offset + 10] = -cx*m001 + m101; // out[1, 0, 1] + out[offset + 12] = -cx*m010 + m110; // out[1, 1, 0] + out[offset + 18] = -cx*m100 + m200; // out[2, 0, 0]\n""" + elif order == 3: + source += f""" + // retrieve the 3rd order raw moments + {float_type} m000 = moments_raw[offset]; + {float_type} m001 = moments_raw[offset + 1]; + {float_type} m002 = moments_raw[offset + 2]; + {float_type} m003 = moments_raw[offset + 3]; + {float_type} m010 = moments_raw[offset + 4]; + {float_type} m011 = moments_raw[offset + 5]; + {float_type} m012 = moments_raw[offset + 6]; + {float_type} m020 = moments_raw[offset + 8]; + {float_type} m021 = moments_raw[offset + 9]; + {float_type} m030 = moments_raw[offset + 12]; + {float_type} m100 = moments_raw[offset + 16]; + {float_type} m101 = moments_raw[offset + 17]; + {float_type} m102 = moments_raw[offset + 18]; + {float_type} m110 = moments_raw[offset + 20]; + {float_type} m111 = moments_raw[offset + 21]; + {float_type} m120 = moments_raw[offset + 24]; + {float_type} m200 = moments_raw[offset + 32]; + {float_type} m201 = moments_raw[offset + 33]; + {float_type} m210 = moments_raw[offset + 36]; + {float_type} m300 = moments_raw[offset + 48]; + + // compute centroids + {float_type} cx = m100 / m000; + {float_type} cy = m010 / m000; + {float_type} cz = m001 / m000; + + // zeroth moment + out[offset] = m000; + // 2nd order central moments + out[offset + 2] = -cz*m001 + m002; // out[0, 0, 2] + out[offset + 5] = -cy*m001 + m011; // out[0, 1, 1] + out[offset + 8] = -cy*m010 + m020; // out[0, 2, 0] + out[offset + 17] = -cx*m001 + m101; // out[1, 0, 1] + out[offset + 20] = -cx*m010 + m110; // out[1, 1, 0] + out[offset + 32] = -cx*m100 + m200; // out[2, 0, 0] + // 3rd order central moments + out[offset + 3] = 2*cz*cz*m001 - 3*cz*m002 + m003; // out[0, 0, 3] + out[offset + 6] = -cy*m002 + 2*cz*(cy*m001 - m011) + m012; // out[0, 1, 2] + out[offset + 9] = cy*cy*m001 - 2*cy*m011 + cz*(cy*m010 - m020) + m021; // out[0, 2, 1] + out[offset + 12] = 2*cy*cy*m010 - 3*cy*m020 + m030; // out[0, 3, 0] + out[offset + 18] = -cx*m002 + 2*cz*(cx*m001 - m101) + m102; // out[1, 0, 2] + out[offset + 21] = -cx*m011 + cy*(cx*m001 - m101) + cz*(cx*m010 - m110) + m111; // out[1, 1, 1] + out[offset + 24] = -cx*m020 - 2*cy*(-cx*m010 + m110) + m120; // out[1, 2, 0] + out[offset + 33] = cx*cx*m001 - 2*cx*m101 + cz*(cx*m100 - m200) + m201; // out[2, 0, 1] + out[offset + 36] = cx*cx*m010 - 2*cx*m110 + cy*(cx*m100 - m200) + m210; // out[2, 1, 0] + out[offset + 48] = 2*cx*cx*m100 - 3*cx*m200 + m300; // out[3, 0, 0]\n""" # noqa: E501 + else: + raise ValueError("only order <= 3 is supported") + else: + # note: ndim here is the number of spatial image dimensions + raise ValueError("only ndim = 2 or 3 is supported") + inputs = "raw X moments_raw" + outputs = "raw X out" + name = f"cucim_moments_central_order{order}_{ndim}d" + return cp.ElementwiseKernel( + inputs, outputs, source, preamble=_includes, name=name + ) + + +def regionprops_moments_central( + moments_raw, ndim, weighted=False, props_dict=None +): + """Compute the central moments of the labeled regions. + + Computes the central moments from raw moments. + + Writes "moments_central" to `props_dict` if `weighted` is ``False``. + + Writes "moments_weighted_central" to `props_dict` if `weighted` is ``True``. + """ + if props_dict is None: + props_dict = {} + + if moments_raw.ndim == 2 + ndim: + num_channels = moments_raw.shape[1] + elif moments_raw.ndim == 1 + ndim: + num_channels = 1 + else: + raise ValueError( + f"{moments_raw.shape=} does not have expected length of `ndim + 1`" + " (or `ndim + 2` for the multi-channel weighted moments case)." + ) + order = moments_raw.shape[-1] - 1 + max_label = moments_raw.shape[0] + + if moments_raw.dtype.kind != "f": + float_dtype = cp.promote_types(cp.float32, moments_raw.dtype) + moments_raw = moments_raw.astype(float_dtype) + + # make a copy if the inputs are not already C-contiguous + if not moments_raw.flags.c_contiguous: + moments_raw = cp.ascontiguousarray(moments_raw) + + moments_kernel = get_moments_central_kernel(moments_raw.dtype, ndim, order) + moments_central = cp.zeros_like(moments_raw) + # kernel loops over moments so size is max_label * num_channels + moments_kernel(moments_raw, moments_central, size=max_label * num_channels) + if props_dict is not None: + if weighted: + props_dict["moments_weighted_central"] = moments_central + else: + props_dict["moments_central"] = moments_central + return moments_central + + +@cp.memoize(for_each_device=True) +def get_moments_normalize_kernel( + moments_dtype, ndim, order, unit_scale=False, pixel_correction=False +): + """Normalizes central moments of order >=2""" + moments_dtype = cp.dtype(moments_dtype) + + uint_t = "unsigned int" + + # number is for a densely populated moments matrix of size (order + 1) per + # side (values at locations where order is greater than specified will be 0) + num_moments = (order + 1) ** ndim + + if moments_dtype.kind != "f": + raise ValueError( + "`moments_dtype` must be a floating point type for central moments " + "calculations." + ) + + # floating point type used for the intermediate computations + float_type = "double" + source = f""" + {uint_t} offset = i * {num_moments};\n""" + if ndim == 2: + if order == 2: + source += f""" + // retrieve zeroth moment + {float_type} m00 = moments_central[offset];\n""" + + # compute normalization factor + source += f""" + {float_type} norm_order2 = pow(m00, 2.0 / {ndim} + 1.0);""" + if not unit_scale: + source += """ + norm_order2 *= scale * scale;\n""" + + # normalize + source += """ + // normalize the 2nd order central moments + out[offset + 2] = moments_central[offset + 2] / norm_order2; // out[0, 2] + out[offset + 4] = moments_central[offset + 4] / norm_order2; // out[1, 1] + out[offset + 6] = moments_central[offset + 6] / norm_order2; // out[2, 0]\n""" # noqa: E501 + elif order == 3: + source += f""" + // retrieve zeroth moment + {float_type} m00 = moments_central[offset];\n""" + + # compute normalization factor + source += f""" + {float_type} norm_order2 = pow(m00, 2.0 / {ndim} + 1.0); + {float_type} norm_order3 = pow(m00, 3.0 / {ndim} + 1.0);""" + if not unit_scale: + source += """ + norm_order2 *= scale * scale; + norm_order3 *= scale * scale * scale;\n""" + + # normalize + source += """ + // normalize the 2nd order central moments + out[offset + 2] = moments_central[offset + 2] / norm_order2; // out[0, 2] + out[offset + 5] = moments_central[offset + 5] / norm_order2; // out[1, 1] + out[offset + 8] = moments_central[offset + 8] / norm_order2; // out[2, 0] + // normalize the 3rd order central moments + out[offset + 3] = moments_central[offset + 3] / norm_order3; // out[0, 3] + out[offset + 6] = moments_central[offset + 6] / norm_order3; // out[1, 2] + out[offset + 9] = moments_central[offset + 9] / norm_order3; // out[2, 1] + out[offset + 12] = moments_central[offset + 12] / norm_order3; // out[3, 0]\n""" # noqa: E501 + else: + raise ValueError("only order = 2 or 3 is supported") + elif ndim == 3: + if order == 2: + source += f""" + // retrieve the zeroth moment + {float_type} m000 = moments_central[offset];\n""" + + # compute normalization factor + source += f""" + {float_type} norm_order2 = pow(m000, 2.0 / {ndim} + 1.0);""" + if not unit_scale: + source += """ + norm_order2 *= scale * scale;\n""" + + # normalize + source += """ + // normalize the 2nd order central moments + out[offset + 2] = moments_central[offset + 2] / norm_order2; // out[0, 0, 2] + out[offset + 4] = moments_central[offset + 4] / norm_order2; // out[0, 1, 1] + out[offset + 6] = moments_central[offset + 6] / norm_order2; // out[0, 2, 0] + out[offset + 10] = moments_central[offset + 10] / norm_order2; // out[1, 0, 1] + out[offset + 12] = moments_central[offset + 12] / norm_order2; // out[1, 1, 0] + out[offset + 18] = moments_central[offset + 18] / norm_order2; // out[2, 0, 0]\n""" # noqa: E501 + elif order == 3: + source += f""" + // retrieve the zeroth moment + {float_type} m000 = moments_central[offset];\n""" + + # compute normalization factor + source += f""" + {float_type} norm_order2 = pow(m000, 2.0 / {ndim} + 1.0); + {float_type} norm_order3 = pow(m000, 3.0 / {ndim} + 1.0);""" + if not unit_scale: + source += """ + norm_order2 *= scale * scale; + norm_order3 *= scale * scale * scale;\n""" + + # normalize + source += """ + // normalize the 2nd order central moments + out[offset + 2] = moments_central[offset + 2] / norm_order2; // out[0, 0, 2] + out[offset + 5] = moments_central[offset + 5] / norm_order2; // out[0, 1, 1] + out[offset + 8] = moments_central[offset + 8] / norm_order2; // out[0, 2, 0] + out[offset + 17] = moments_central[offset + 17] / norm_order2; // out[1, 0, 1] + out[offset + 20] = moments_central[offset + 20] / norm_order2; // out[1, 1, 0] + out[offset + 32] = moments_central[offset + 32] / norm_order2; // out[2, 0, 0] + // normalize the 3rd order central moments + out[offset + 3] = moments_central[offset + 3] / norm_order3; // out[0, 0, 3] + out[offset + 6] = moments_central[offset + 6] / norm_order3; // out[0, 1, 2] + out[offset + 9] = moments_central[offset + 9] / norm_order3; // out[0, 2, 1] + out[offset + 12] = moments_central[offset + 12] / norm_order3; // out[0, 3, 0] + out[offset + 18] = moments_central[offset + 18] / norm_order3; // out[1, 0, 2] + out[offset + 21] = moments_central[offset + 21] / norm_order3; // out[1, 1, 1] + out[offset + 24] = moments_central[offset + 24] / norm_order3; // out[1, 2, 0] + out[offset + 33] = moments_central[offset + 33] / norm_order3; // out[2, 0, 1] + out[offset + 36] = moments_central[offset + 36] / norm_order3; // out[2, 1, 0] + out[offset + 48] = moments_central[offset + 48] / norm_order3; // out[3, 0, 0]\n""" # noqa: E501 + else: + raise ValueError("only order = 2 or 3 is supported") + else: + # note: ndim here is the number of spatial image dimensions + raise ValueError("only ndim = 2 or 3 is supported") + inputs = "raw X moments_central" + if not unit_scale: + inputs += ", float64 scale" + outputs = "raw X out" + name = f"cucim_moments_normalized_order{order}_{ndim}d" + return cp.ElementwiseKernel( + inputs, outputs, source, preamble=_includes, name=name + ) + + +def regionprops_moments_normalized( + moments_central, + ndim, + spacing=None, + pixel_correction=False, + weighted=False, + props_dict=None, +): + """Compute the normalizedcentral moments of the labeled regions. + + Computes normalized central moments from central moments. + + Writes "moments_normalized" to `props_dict` if `weighted` is ``False``. + + Writes "moments_weighted_normalized" to `props_dict` if `weighted` is + ``True``. + + Notes + ----- + Default setting of `pixel_correction=False` matches the scikit-image + behavior (as of v0.25). + + The `pixel_correction` is to account for pixel/voxel size and is only + implemented for 2nd order moments currently based on the derivation in: + + The correction should need to be updated to take 'spacing' into account as + it currently assumes unit size. + + Padfield D., Miller J. "A Label Geometry Image Filter for Multiple Object + Measurement". The Insight Journal. 2013 Mar. + https://doi.org/10.54294/saa3nn + """ + if moments_central.ndim == 2 + ndim: + num_channels = moments_central.shape[1] + elif moments_central.ndim == 1 + ndim: + num_channels = 1 + else: + raise ValueError( + f"{moments_central.shape=} does not have expected length of " + " `ndim + 1` (or `ndim + 2` for the multi-channel weighted moments " + "case)." + ) + order = moments_central.shape[-1] - 1 + if order < 2 or order > 3: + raise ValueError( + "normalized moment calculations only implemented for order=2 " + "and order=3" + ) + if ndim < 2 or ndim > 3: + raise ValueError( + "moment normalization only implemented for 2D and 3D images" + ) + max_label = moments_central.shape[0] + + if moments_central.dtype.kind != "f": + raise ValueError("moments_central must have a floating point dtype") + + # make a copy if the inputs are not already C-contiguous + if not moments_central.flags.c_contiguous: + moments_central = cp.ascontiguousarray(moments_central) + + if spacing is None: + unit_scale = True + inputs = (moments_central,) + else: + if spacing: + if isinstance(spacing, cp.ndarray): + scale = spacing.min() + else: + scale = float(min(spacing)) + unit_scale = False + inputs = (moments_central, scale) + + moments_norm_kernel = get_moments_normalize_kernel( + moments_central.dtype, + ndim, + order, + unit_scale=unit_scale, + pixel_correction=pixel_correction, + ) + # output is NaN except for locations with orders in range [2, order] + moments_norm = cp.full( + moments_central.shape, cp.nan, dtype=moments_central.dtype + ) + + # kernel loops over moments so size is max_label * num_channels + moments_norm_kernel(*inputs, moments_norm, size=max_label * num_channels) + if props_dict is not None: + if weighted: + props_dict["moments_weighted_normalized"] = moments_norm + else: + props_dict["moments_normalized"] = moments_norm + return moments_norm + + +@cp.memoize(for_each_device=True) +def get_moments_hu_kernel(moments_dtype): + """Normalizes central moments of order >=2""" + moments_dtype = cp.dtype(moments_dtype) + + uint_t = "unsigned int" + + # number is for a densely populated moments matrix of size (order + 1) per + # side (values at locations where order is greater than specified will be 0) + num_moments = 16 + + if moments_dtype.kind != "f": + raise ValueError( + "`moments_dtype` must be a floating point type for central moments " + "calculations." + ) + + # floating point type used for the intermediate computations + float_type = "double" + + # compute offset to the current moment matrix and hu moment vector + source = f""" + {uint_t} offset_normalized = i * {num_moments}; + {uint_t} offset_hu = i * 7;\n""" + + source += f""" + // retrieve 2nd and 3rd order normalized moments + {float_type} m02 = moments_normalized[offset_normalized + 2]; + {float_type} m03 = moments_normalized[offset_normalized + 3]; + {float_type} m12 = moments_normalized[offset_normalized + 6]; + {float_type} m11 = moments_normalized[offset_normalized + 5]; + {float_type} m20 = moments_normalized[offset_normalized + 8]; + {float_type} m21 = moments_normalized[offset_normalized + 9]; + {float_type} m30 = moments_normalized[offset_normalized + 12]; + + {float_type} t0 = m30 + m12; + {float_type} t1 = m21 + m03; + {float_type} q0 = t0 * t0; + {float_type} q1 = t1 * t1; + {float_type} n4 = 4 * m11; + {float_type} s = m20 + m02; + {float_type} d = m20 - m02; + hu[offset_hu] = s; + hu[offset_hu + 1] = d * d + n4 * m11; + hu[offset_hu + 3] = q0 + q1; + hu[offset_hu + 5] = d * (q0 - q1) + n4 * t0 * t1; + t0 *= q0 - 3 * q1; + t1 *= 3 * q0 - q1; + q0 = m30- 3 * m12; + q1 = 3 * m21 - m03; + hu[offset_hu + 2] = q0 * q0 + q1 * q1; + hu[offset_hu + 4] = q0 * t0 + q1 * t1; + hu[offset_hu + 6] = q1 * t0 - q0 * t1;\n""" + + inputs = f"raw {moments_dtype.name} moments_normalized" + outputs = f"raw {moments_dtype.name} hu" + name = f"cucim_moments_hu_order_{moments_dtype.name}" + return cp.ElementwiseKernel( + inputs, outputs, source, preamble=_includes, name=name + ) + + +def regionprops_moments_hu(moments_normalized, weighted=False, props_dict=None): + """Compute the 2D Hu invariant moments from 3rd ordernormalized central + moments. + + Writes "moments_hu" to `props_dict` if `weighted` is ``False`` + + Writes "moments_weighted_hu" to `props_dict` if `weighted` is ``True``. + """ + if props_dict is None: + props_dict = {} + + if moments_normalized.ndim == 4: + num_channels = moments_normalized.shape[1] + elif moments_normalized.ndim == 3: + num_channels = 1 + else: + raise ValueError( + "Hu's moments are only defined for 2D images. Expected " + "`moments_normalized to have 3 dimensions (or 4 for the " + "multi-channel `intensity_image` case)." + ) + order = moments_normalized.shape[-1] - 1 + if order < 3: + raise ValueError( + "Calculating Hu's moments requires normalized moments of " + "order >= 3 to be provided as input" + ) + elif order > 3: + # truncate any unused moments + moments_normalized = cp.ascontiguousarray( + moments_normalized[..., :4, :4] + ) + max_label = moments_normalized.shape[0] + + if moments_normalized.dtype.kind != "f": + raise ValueError("moments_normalized must have a floating point dtype") + + # make a copy if the inputs are not already C-contiguous + if not moments_normalized.flags.c_contiguous: + moments_normalized = cp.ascontiguousarray(moments_normalized) + + moments_hu_kernel = get_moments_hu_kernel(moments_normalized.dtype) + # Hu's moments are a set of 7 moments stored instead of a moment matrix + hu_shape = moments_normalized.shape[:-2] + (7,) + moments_hu = cp.full(hu_shape, cp.nan, dtype=moments_normalized.dtype) + + # kernel loops over moments so size is max_label * num_channels + moments_hu_kernel( + moments_normalized, moments_hu, size=max_label * num_channels + ) + if props_dict is not None: + if weighted: + props_dict["moments_weighted_hu"] = moments_hu + else: + props_dict["moments_hu"] = moments_hu + return moments_hu + + +@cp.memoize(for_each_device=True) +def get_inertia_tensor_kernel(moments_dtype, ndim, compute_orientation): + """Normalizes central moments of order >=2""" + moments_dtype = cp.dtype(moments_dtype) + + # assume moments input was truncated to only hold order<=2 moments + num_moments = 3**ndim + + # size of the inertia_tensor matrix + num_out = ndim * ndim + + if moments_dtype.kind != "f": + raise ValueError( + "`moments_dtype` must be a floating point type for central moments " + "calculations." + ) + + source = f""" + unsigned int offset = i * {num_moments}; + unsigned int offset_out = i * {num_out};\n""" + if ndim == 2: + source += """ + F mu0 = moments_central[offset]; + F mxx = moments_central[offset + 6]; + F myy = moments_central[offset + 2]; + F mxy = moments_central[offset + 4]; + + F a = myy / mu0; + F b = -mxy / mu0; + F c = mxx / mu0; + out[offset_out + 0] = a; + out[offset_out + 1] = b; + out[offset_out + 2] = b; + out[offset_out + 3] = c; + """ + if compute_orientation: + source += """ + if (a - c == 0) { + // had to use <= 0 to get same result as Python's atan2 with < 0 + if (b < 0) { + orientation[i] = -M_PI / 4.0; + } else { + orientation[i] = M_PI / 4.0; + } + } else { + orientation[i] = 0.5 * atan2(-2 * b, c - a); + }\n""" + elif ndim == 3: + if compute_orientation: + raise ValueError("orientation can only be computed for 2d images") + source += """ + F mu0 = moments_central[offset]; // [0, 0, 0] + F mxx = moments_central[offset + 18]; // [2, 0, 0] + F myy = moments_central[offset + 6]; // [0, 2, 0] + F mzz = moments_central[offset + 2]; // [0, 0, 2] + + F mxy = moments_central[offset + 12]; // [1, 1, 0] + F mxz = moments_central[offset + 10]; // [1, 0, 1] + F myz = moments_central[offset + 4]; // [0, 1, 1] + + out[offset_out + 0] = (myy + mzz) / mu0; + out[offset_out + 4] = (mxx + mzz) / mu0; + out[offset_out + 8] = (mxx + myy) / mu0; + out[offset_out + 1] = -mxy / mu0; + out[offset_out + 3] = -mxy / mu0; + out[offset_out + 2] = -mxz / mu0; + out[offset_out + 6] = -mxz / mu0; + out[offset_out + 5] = -myz / mu0; + out[offset_out + 7] = -myz / mu0;\n""" + else: + # note: ndim here is the number of spatial image dimensions + raise ValueError("only ndim = 2 or 3 is supported") + inputs = "raw F moments_central" + outputs = "raw F out" + if compute_orientation: + outputs += ", raw F orientation" + name = f"cucim_inertia_tensor_{ndim}d" + return cp.ElementwiseKernel( + inputs, outputs, source, preamble=_includes, name=name + ) + + +def regionprops_inertia_tensor( + moments_central, ndim, compute_orientation=False, props_dict=None +): + """ "Compute the inertia tensor from the central moments. + + The input to this function is the output of `regionprops_moments_central`. + + Writes "inertia_tensor" to `props_dict`. + Writes "orientation" to `props_dict` if `compute_orientation` is ``True``. + """ + if ndim < 2 or ndim > 3: + raise ValueError("inertia tensor only implemented for 2D and 3D images") + if compute_orientation and ndim != 2: + raise ValueError("orientation can only be computed for ndim=2") + + nbatch = math.prod(moments_central.shape[:-ndim]) + + if moments_central.dtype.kind != "f": + raise ValueError("moments_central must have a floating point dtype") + + # make a copy if the inputs are not already C-contiguous + if not moments_central.flags.c_contiguous: + moments_central = cp.ascontiguousarray(moments_central) + + order = moments_central.shape[-1] - 1 + if order < 2: + raise ValueError( + f"inertia tensor calculation requires order>=2, found {order}" + ) + if order > 2: + # truncate to only the 2nd order moments + slice_kept = (Ellipsis,) + (slice(0, 3),) * ndim + moments_central = cp.ascontiguousarray(moments_central[slice_kept]) + + kernel = get_inertia_tensor_kernel( + moments_central.dtype, ndim, compute_orientation=compute_orientation + ) + itensor_shape = moments_central.shape[:-ndim] + (ndim, ndim) + itensor = cp.zeros(itensor_shape, dtype=moments_central.dtype) + if compute_orientation: + orientation = cp.zeros( + moments_central.shape[:-ndim], dtype=moments_central.dtype + ) + kernel(moments_central, itensor, orientation, size=nbatch) + if props_dict is not None: + props_dict["inertia_tensor"] = itensor + props_dict["orientation"] = orientation + return itensor, orientation + + kernel(moments_central, itensor, size=nbatch) + if props_dict is not None: + props_dict["inertia_tensor"] = itensor + return itensor + + +@cp.memoize(for_each_device=True) +def get_spd_matrix_eigvals_kernel( + rank, + compute_axis_lengths=False, + compute_eccentricity=False, +): + """Compute symmetric positive definite (SPD) matrix eigenvalues + + Implements closed-form analytical solutions for 2x2 and 3x3 matrices. + + C. Deledalle, L. Denis, S. Tabti, F. Tupin. Closed-form expressions + of the eigen decomposition of 2 x 2 and 3 x 3 Hermitian matrices. + + [Research Report] Université de Lyon. 2017. + https://hal.archives-ouvertes.fr/hal-01501221/file/matrix_exp_and_log_formula.pdf + """ # noqa: E501 + + # assume moments input was truncated to only hold order<=2 moments + num_elements = rank * rank + + # size of the inertia_tensor matrix + source = f""" + unsigned int offset = i * {num_elements}; + unsigned int offset_evals = i * {rank};\n""" + if rank == 2: + source += """ + F tmp1, tmp2; + double m00 = static_cast(spd_matrix[offset]); + double m01 = static_cast(spd_matrix[offset + 1]); + double m11 = static_cast(spd_matrix[offset + 3]); + tmp1 = m01 * m01; + tmp1 *= 4; + + tmp2 = m00 - m11; + tmp2 *= tmp2; + tmp2 += tmp1; + tmp2 = sqrt(tmp2); + tmp2 /= 2; + + tmp1 = m00 + m11; + tmp1 /= 2; + + // store in "descending" order and clip to positive values + // (matrix is spd, so negatives values can only be due to + // numerical errors) + F lam1 = max(tmp1 + tmp2, 0.0); + F lam2 = max(tmp1 - tmp2, 0.0); + evals[offset_evals] = lam1; + evals[offset_evals + 1] = lam2;\n""" + if compute_axis_lengths: + source += """ + axis_lengths[offset_evals] = 4.0 * sqrt(lam1); + axis_lengths[offset_evals + 1] = 4.0 * sqrt(lam2);\n""" + if compute_eccentricity: + source += """ + eccentricity[i] = sqrt(1.0 - lam2 / lam1);\n""" + elif rank == 3: + if compute_eccentricity: + raise ValueError("eccentricity only supported for 2D images") + + source += """ + double x1, x2, phi; + // extract triangle of (spd) inertia tensor values + // [a, d, f] + // [-, b, e] + // [-, -, c] + double a = static_cast(spd_matrix[offset]); + double b = static_cast(spd_matrix[offset + 4]); + double c = static_cast(spd_matrix[offset + 8]); + double d = static_cast(spd_matrix[offset + 1]); + double e = static_cast(spd_matrix[offset + 5]); + double f = static_cast(spd_matrix[offset + 2]); + double d_sq = d * d; + double e_sq = e * e; + double f_sq = f * f; + double tmpa = (2*a - b - c); + double tmpb = (2*b - a - c); + double tmpc = (2*c - a - b); + x2 = - tmpa * tmpb * tmpc; + x2 += 9 * (tmpc*d_sq + tmpb*f_sq + tmpa*e_sq); + x2 -= 54 * (d * e * f); + x1 = a*a + b*b + c*c - a*b - a*c - b*c + 3 * (d_sq + e_sq + f_sq); + + // grlee77: added max() here for numerical stability + // (avoid NaN values in ridge filter test cases) + x1 = max(x1, 0.0); + + if (x2 == 0.0) { + phi = M_PI / 2.0; + } else { + // grlee77: added max() here for numerical stability + // (avoid NaN values in test_hessian_matrix_eigvals_3d) + double arg = max(4*x1*x1*x1 - x2*x2, 0.0); + phi = atan(sqrt(arg)/x2); + if (x2 < 0) { + phi += M_PI; + } + } + double x1_term = (2.0 / 3.0) * sqrt(x1); + double abc = (a + b + c) / 3.0; + F lam1 = abc - x1_term * cos(phi / 3.0); + F lam2 = abc + x1_term * cos((phi - M_PI) / 3.0); + F lam3 = abc + x1_term * cos((phi + M_PI) / 3.0); + + // abc = 141.94321771 + // x1_term = 1279.25821493 + // M_PI = 3.14159265 + // phi = 1.91643394 + // cos(phi/3.0) = 0.80280507 + // cos((phi - M_PI) / 3.0) = 0.91776289 + + F stmp; + if (lam3 > lam2) { + stmp = lam2; + lam2 = lam3; + lam3 = stmp; + } + if (lam3 > lam1) { + stmp = lam1; + lam1 = lam3; + lam3 = stmp; + } + if (lam2 > lam1) { + stmp = lam1; + lam1 = lam2; + lam2 = stmp; + } + // clip to positive values + // (matrix is spd, so negatives values can only be due to + // numerical errors) + lam1 = max(lam1, 0.0); + lam2 = max(lam2, 0.0); + lam3 = max(lam3, 0.0); + evals[offset_evals] = lam1; + evals[offset_evals + 1] = lam2; + evals[offset_evals + 2] = lam3;\n""" + if compute_axis_lengths: + """ + Notes + ----- + Let a >= b >= c be the ellipsoid semi-axes and s1 >= s2 >= s3 be the + inertia tensor eigenvalues. + + The inertia tensor eigenvalues are given for a solid ellipsoid in [1]_. + s1 = 1 / 5 * (a**2 + b**2) + s2 = 1 / 5 * (a**2 + c**2) + s3 = 1 / 5 * (b**2 + c**2) + + Rearranging to solve for a, b, c in terms of s1, s2, s3 gives + a = math.sqrt(5 / 2 * ( s1 + s2 - s3)) + b = math.sqrt(5 / 2 * ( s1 - s2 + s3)) + c = math.sqrt(5 / 2 * (-s1 + s2 + s3)) + + We can then simply replace sqrt(5/2) by sqrt(10) to get the full axes + lengths rather than the semi-axes lengths. + + References + ---------- + ..[1] https://en.wikipedia.org/wiki/List_of_moments_of_inertia#List_of_3D_inertia_tensors + """ # noqa: E501 + source += """ + // formula reference: + // https://github.com/scikit-image/scikit-image/blob/v0.25.0/skimage/measure/_regionprops.py#L275-L295 + // note: added max to clip possible small (e.g. 1e-7) negative value due to numerical error + axis_lengths[offset_evals] = sqrt(10.0 * (lam1 + lam2 - lam3)); + axis_lengths[offset_evals + 1] = sqrt(10.0 * (lam1 - lam2 + lam3)); + axis_lengths[offset_evals + 2] = sqrt(10.0 * max(-lam1 + lam2 + lam3, 0.0));\n""" # noqa: E501 + else: + # note: ndim here is the number of spatial image dimensions + raise ValueError("only rank = 2 or 3 is supported") + inputs = "raw F spd_matrix" + outputs = ["raw F evals"] + name = f"cucim_spd_matrix_eigvals_{rank}d" + if compute_axis_lengths: + outputs.append("raw F axis_lengths") + name += "_with_axis" + if compute_eccentricity: + outputs.append("raw F eccentricity") + name += "_eccen" + outputs = ", ".join(outputs) + return cp.ElementwiseKernel( + inputs, outputs, source, preamble=_includes, name=name + ) + + +def regionprops_inertia_tensor_eigvals( + inertia_tensor, + compute_eigenvectors=False, + compute_axis_lengths=False, + compute_eccentricity=False, + props_dict=None, +): + """ "Compute the inertia tensor eigenvalues (and eigenvectors) from the + inertia tensor of each labeled region. + + The input to this function is the output of `regionprops_inertia_tensor`. + + writes "inertia_tensor_eigvals" to `props_dict` + if compute_eigenvectors: + - writes "inertia_tensor_eigenvectors" to `props_dict` + if compute_axis_lengths: + - writes "axis_major_length" to `props_dict` + - writes "axis_minor_length" to `props_dict` + - writes "axis_lengths" to `props_dict` + if compute_eccentricity: + - writes "eccentricity" to `props_dict` + """ + # inertia tensor should have shape (ndim, ndim) on last two axes + ndim = inertia_tensor.shape[-1] + if ndim < 2 or ndim > 3: + raise ValueError("inertia tensor only implemented for 2D and 3D images") + nbatch = math.prod(inertia_tensor.shape[:-2]) + + if compute_eccentricity and ndim != 2: + raise ValueError("eccentricity is only supported for 2D images") + + if inertia_tensor.dtype.kind != "f": + raise ValueError("moments_central must have a floating point dtype") + + if not inertia_tensor.flags.c_contiguous: + inertia_tensor = cp.ascontiguousarray(inertia_tensor) + + # don't use this kernel for eigenvectors as it is not robust to 0 entries + kernel = get_spd_matrix_eigvals_kernel( + rank=ndim, + compute_axis_lengths=compute_axis_lengths, + compute_eccentricity=compute_eccentricity, + ) + eigvals_shape = inertia_tensor.shape[:-2] + (ndim,) + eigvals = cp.empty(eigvals_shape, dtype=inertia_tensor.dtype) + outputs = [eigvals] + if compute_axis_lengths: + axis_lengths = cp.empty(eigvals_shape, dtype=inertia_tensor.dtype) + outputs.append(axis_lengths) + if compute_eccentricity: + eccentricity = cp.empty( + inertia_tensor.shape[:-2], dtype=inertia_tensor.dtype + ) + outputs.append(eccentricity) + kernel(inertia_tensor, *outputs, size=nbatch) + if compute_eigenvectors: + # eigenvectors computed by the kernel are not robust to 0 entries, so + # use slightly slow cp.linalg.eigh instead + eigvals, eigvecs = cp.linalg.eigh(inertia_tensor) + # swap from ascending to descending order + eigvals = eigvals[:, ::-1] + eigvecs = eigvecs[:, ::-1] + if props_dict is None: + props_dict = {} + props_dict["inertia_tensor_eigvals"] = eigvals + if compute_eccentricity: + props_dict["eccentricity"] = eccentricity + if compute_axis_lengths: + props_dict["axis_lengths"] = axis_lengths + props_dict["axis_major_length"] = axis_lengths[..., 0] + props_dict["axis_minor_length"] = axis_lengths[..., -1] + if compute_eigenvectors: + props_dict["inertia_tensor_eigenvectors"] = eigvecs + return props_dict + + +@cp.memoize(for_each_device=True) +def get_centroid_weighted_kernel( + moments_dtype, + ndim, + compute_local=True, + compute_global=False, + unit_spacing=True, + num_channels=1, +): + """Centroid (in global or local coordinates) from 1st order moment matrix""" + moments_dtype = cp.dtype(moments_dtype) + + # assume moments input was truncated to only hold order<=2 moments + num_moments = 2**ndim + if moments_dtype.kind != "f": + raise ValueError( + "`moments_dtype` must be a floating point type for central moments " + "calculations." + ) + source = "" + if compute_global: + source += f""" + unsigned int offset_coords = i * {2 * ndim};\n""" + + if num_channels > 1: + source += f""" + uint32_t num_channels = moments_raw.shape()[1]; + for (int c = 0; c < num_channels; c++) {{ + unsigned int offset = i * {num_moments} * num_channels + c * {num_moments}; + unsigned int offset_out = i * {ndim} * num_channels + c * {ndim}; + F m0 = moments_raw[offset];\n""" # noqa: E501 + else: + source += f""" + unsigned int offset = i * {num_moments}; + unsigned int offset_out = i * {ndim}; + F m0 = moments_raw[offset];\n""" + + # general formula for the n-dimensional case + # + # in 2D it gives: + # out[offset_out + 1] = moments_raw[offset + 1] / m0; // m[0, 1] + # out[offset_out] = moments_raw[offset + 2] / m0; // m[1, 0] + # + # in 3D it gives: + # out[offset_out + 2] = moments_raw[offset + 1] / m0; // m[0, 0, 1] + # out[offset_out + 1] = moments_raw[offset + 2] / m0; // m[0, 1, 0] + # out[offset_out] = moments_raw[offset + 4] / m0; // m[1, 0, 0] + axis_offset = 1 + for d in range(ndim - 1, -1, -1): + if compute_local: + source += f""" + out_local[offset_out + {d}] = moments_raw[offset + {axis_offset}] / m0;""" # noqa: E501 + if compute_global: + spc = "" if unit_spacing else f" * spacing[{d}]" + source += f""" + out_global[offset_out + {d}] = moments_raw[offset + {axis_offset}] / m0 + bbox[offset_coords + {d}]{spc};""" # noqa: E501 + axis_offset *= 2 + if num_channels > 1: + source += """ + } // channels loop\n""" + name = f"cucim_centroid_weighted_{ndim}d" + inputs = ["raw F moments_raw"] + outputs = [] + if compute_global: + name += "_global" + outputs.append("raw F out_global") + # bounding box coordinates + inputs.append("raw Y bbox") + if not unit_spacing: + inputs.append("raw float64 spacing") + name += "_spacing" + if compute_local: + name += "_local" + outputs.append("raw F out_local") + inputs = ", ".join(inputs) + outputs = ", ".join(outputs) + return cp.ElementwiseKernel( + inputs, outputs, source, preamble=_includes, name=name + ) + + +def regionprops_centroid_weighted( + moments_raw, + ndim, + bbox=None, + compute_local=True, + compute_global=False, + weighted=True, + spacing=None, + props_dict=None, +): + """Centroid (in global or local coordinates) from 1st order moment matrix + + If `compute_local` the centroid is in local coordinates, otherwise it is in + global coordinates. + + `bbox` property must be provided either via kwarg or within `props_dict` if + `compute_global` is ``True``. + + if weighted: + if compute_global: + writes "centroid_weighted" to `props_dict` + if compute_local: + writes "centroid_weighted_local" to `props_dict` + else: + if compute_global: + writes "centroid" to `props_dict` + if compute_local: + writes "centroid_local" to `props_dict` + """ + max_label = moments_raw.shape[0] + if moments_raw.ndim == ndim + 2: + num_channels = moments_raw.shape[1] + elif moments_raw.ndim == ndim + 1: + num_channels = 1 + else: + raise ValueError("moments_raw has unexpected shape") + + if compute_global and bbox is None: + if "bbox" in props_dict: + bbox = props_dict["bbox"] + else: + raise ValueError( + "bbox coordinates must be provided to get the non-local " + "centroid" + ) + + if not (compute_local or compute_global): + raise ValueError( + "nothing to compute: either compute_global and/or compute_local " + "must be true" + ) + if moments_raw.dtype.kind != "f": + raise ValueError("moments_raw must have a floating point dtype") + order = moments_raw.shape[-1] - 1 + if order < 1: + raise ValueError( + f"inertia tensor calculation requires order>=1, found {order}" + ) + if order >= 1: + # truncate to only the 1st order moments + slice_kept = (Ellipsis,) + (slice(0, 2),) * ndim + moments_raw = cp.ascontiguousarray(moments_raw[slice_kept]) + + # make a copy if the inputs are not already C-contiguous + if not moments_raw.flags.c_contiguous: + moments_raw = cp.ascontiguousarray(moments_raw) + + unit_spacing = spacing is None + + if compute_local and not compute_global: + inputs = (moments_raw,) + else: + if not bbox.flags.c_contiguous: + bbox = cp.ascontiguousarray(bbox) + inputs = (moments_raw, bbox) + if not unit_spacing: + inputs = inputs + (cp.asarray(spacing),) + kernel = get_centroid_weighted_kernel( + moments_raw.dtype, + ndim, + compute_local=compute_local, + compute_global=compute_global, + unit_spacing=unit_spacing, + num_channels=num_channels, + ) + centroid_shape = moments_raw.shape[:-ndim] + (ndim,) + outputs = [] + if compute_global: + centroid_global = cp.zeros(centroid_shape, dtype=moments_raw.dtype) + outputs.append(centroid_global) + if compute_local: + centroid_local = cp.zeros(centroid_shape, dtype=moments_raw.dtype) + outputs.append(centroid_local) + # Note: order of inputs and outputs here must match + # get_centroid_weighted_kernel + kernel(*inputs, *outputs, size=max_label) + if props_dict is None: + props_dict = {} + if compute_local: + if weighted: + props_dict["centroid_weighted_local"] = centroid_local + else: + props_dict["centroid_local"] = centroid_local + if compute_global: + if weighted: + props_dict["centroid_weighted"] = centroid_global + else: + props_dict["centroid"] = centroid_global + + return props_dict diff --git a/python/cucim/src/cucim/skimage/measure/tests/test_regionprops_gpu_kernels.py b/python/cucim/src/cucim/skimage/measure/tests/test_regionprops_gpu_kernels.py index 1f162692..7dfa16a1 100644 --- a/python/cucim/src/cucim/skimage/measure/tests/test_regionprops_gpu_kernels.py +++ b/python/cucim/src/cucim/skimage/measure/tests/test_regionprops_gpu_kernels.py @@ -1,6 +1,7 @@ import functools import math import warnings +from copy import deepcopy import cupy as cp import numpy as np @@ -17,19 +18,29 @@ from cucim.skimage.measure._regionprops import PROPS from cucim.skimage.measure._regionprops_gpu import ( equivalent_diameter_area, + ndim_2_only, need_intensity_image, regionprops_area, regionprops_area_bbox, regionprops_area_convex, regionprops_bbox_coords, + regionprops_centroid, + regionprops_centroid_local, + regionprops_centroid_weighted, regionprops_coords, regionprops_dict, regionprops_extent, regionprops_feret_diameter_max, regionprops_image, + regionprops_inertia_tensor, + regionprops_inertia_tensor_eigvals, regionprops_intensity_mean, regionprops_intensity_min_max, regionprops_intensity_std, + regionprops_moments, + regionprops_moments_central, + regionprops_moments_hu, + regionprops_moments_normalized, regionprops_num_pixels, ) from cucim.skimage.measure._regionprops_gpu_basic_kernels import basic_deps @@ -37,6 +48,9 @@ from cucim.skimage.measure._regionprops_gpu_intensity_kernels import ( intensity_deps, ) +from cucim.skimage.measure._regionprops_gpu_moments_kernels import ( + moment_deps, +) def get_labels_nd( @@ -429,6 +443,572 @@ def test_bbox_coords_and_area(precompute_max, ndim, dtype, return_slices): assert_allclose(area_bbox, expected_bbox["area_bbox"], rtol=1e-5) +@pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("ndim", [2, 3]) +@pytest.mark.parametrize("via_moments", [False, True]) +def test_centroid(via_moments, local, ndim): + shape = (1024, 512) if ndim == 2 else (80, 64, 48) + labels = get_labels_nd(shape) + max_label = int(cp.max(labels)) + if via_moments: + props = {} + moments = regionprops_moments( + labels, max_label=max_label, order=1, props_dict=props + ) + assert "bbox" in props + assert "moments" in props + if local: + name = "centroid_local" + if via_moments: + centroid = regionprops_centroid_weighted( + moments_raw=moments, + ndim=labels.ndim, + bbox=props["bbox"], + compute_local=True, + compute_global=False, + weighted=False, + props_dict=props, + )[name] + assert name in props + else: + centroid = regionprops_centroid_local(labels, max_label=max_label) + else: + name = "centroid" + if via_moments: + centroid = regionprops_centroid_weighted( + moments_raw=moments, + ndim=labels.ndim, + bbox=props["bbox"], + compute_local=False, + compute_global=True, + weighted=False, + props_dict=props, + )[name] + assert name in props + else: + centroid = regionprops_centroid(labels, max_label=max_label) + expected = measure_cpu.regionprops_table( + cp.asnumpy(labels), properties=[name] + ) + assert_allclose(centroid[:, 0], expected[name + "-0"]) + if ndim > 1: + assert_allclose(centroid[:, 1], expected[name + "-1"]) + if ndim > 2: + assert_allclose(centroid[:, 2], expected[name + "-2"]) + + +@pytest.mark.parametrize("spacing", [None, (0.8, 0.5), (0.2, 1.3)]) +@pytest.mark.parametrize("order", [0, 1, 2, 3]) +@pytest.mark.parametrize( + "weighted, intensity_dtype, num_channels", + [ + (False, None, 1), + (True, cp.float32, 1), + (True, cp.uint8, 3), + ], +) +@pytest.mark.parametrize("norm_type", ["raw", "central", "normalized", "hu"]) +@pytest.mark.parametrize("blob_size_fraction", [0.03, 0.1, 0.3]) +def test_moments_2d( + spacing, + order, + weighted, + intensity_dtype, + num_channels, + norm_type, + blob_size_fraction, +): + shape = (800, 600) + labels = get_labels_nd(shape, blob_size_fraction=blob_size_fraction) + max_label = int(cp.max(labels)) + kwargs = {"spacing": spacing} + prop = "moments" + if norm_type == "hu": + if order != 3: + pytest.skip("Hu moments require order = 3") + elif spacing and spacing != (1.0, 1.0): + pytest.skip("Hu moments only support spacing = (1.0, 1.0)") + if norm_type == "normalized" and order < 2: + pytest.skip("normalized case only supports order >=2") + if weighted: + intensity_image = get_intensity_image( + shape, dtype=intensity_dtype, num_channels=num_channels + ) + kwargs["intensity_image"] = intensity_image + prop += "_weighted" + if norm_type == "central": + prop += "_central" + elif norm_type == "normalized": + prop += "_normalized" + elif norm_type == "hu": + prop += "_hu" + kwargs_cpu = deepcopy(kwargs) + if "intensity_image" in kwargs_cpu: + kwargs_cpu["intensity_image"] = cp.asnumpy( + kwargs_cpu["intensity_image"] + ) + + # ignore possible warning from skimage implementation + warnings.filterwarnings( + "ignore", + message="invalid value encountered in scalar power", + category=RuntimeWarning, + ) + expected = measure_cpu.regionprops_table( + cp.asnumpy(labels), properties=[prop], **kwargs_cpu + ) + warnings.resetwarnings() + + moments = regionprops_moments( + labels, max_label=max_label, order=order, **kwargs + ) + if norm_type in ["central", "normalized", "hu"]: + ndim = len(shape) + moments = regionprops_moments_central(moments, ndim=ndim) + if norm_type in ["normalized", "hu"]: + moments = regionprops_moments_normalized( + moments, ndim=ndim, spacing=spacing + ) + if norm_type == "normalized": + # assert that np.nan values were set for non-computed orders + orders = cp.arange(order + 1)[:, cp.newaxis] + orders = orders + orders.T + mask = cp.logical_and(orders < 1, orders > order) + # prepend labels (and channels) axes + if num_channels > 1: + mask = mask[cp.newaxis, cp.newaxis, ...] + mask = cp.tile(mask, moments.shape[:2] + (1, 1)) + else: + mask = mask[cp.newaxis, ...] + mask = cp.tile(mask, moments.shape[:1] + (1, 1)) + assert cp.all(cp.isnan(moments[mask])) + + if norm_type == "hu": + moments = regionprops_moments_hu(moments) + assert moments.shape[-1] == 7 + + # regionprops does not use the more accurate analytical expressions for the + # central moments, so need to relax tolerance in the "central" moments case + rtol = 1e-4 if norm_type != "raw" else 1e-5 + atol = 1e-4 if norm_type != "raw" else 1e-7 + allclose = functools.partial(assert_allclose, rtol=rtol, atol=atol) + if norm_type == "hu": + # hu moments are stored as a 7-element vector + if num_channels == 1: + for d in range(7): + allclose(moments[:, d], expected[prop + f"-{d}"]) + else: + for c in range(num_channels): + for d in range(7): + allclose(moments[:, c, d], expected[prop + f"-{d}-{c}"]) + else: + # All other moment types produce a (order + 1, order + 1) matrix + if num_channels == 1: + # zeroth moment + allclose(moments[:, 0, 0], expected[prop + "-0-0"]) + + if order > 0 and norm_type != "normalized": + # first-order moments + if norm_type == "central": + assert_array_equal(moments[:, 0, 1], 0.0) + assert_array_equal(moments[:, 1, 0], 0.0) + else: + allclose(moments[:, 0, 1], expected[prop + "-0-1"]) + allclose(moments[:, 1, 0], expected[prop + "-1-0"]) + if order > 1: + # second-order moments + allclose(moments[:, 0, 2], expected[prop + "-0-2"]) + allclose(moments[:, 1, 1], expected[prop + "-1-1"]) + allclose(moments[:, 2, 0], expected[prop + "-2-0"]) + if order > 3: + # third-order moments + allclose(moments[:, 0, 3], expected[prop + "-0-3"]) + allclose(moments[:, 1, 2], expected[prop + "-1-2"]) + allclose(moments[:, 2, 1], expected[prop + "-2-1"]) + allclose(moments[:, 3, 0], expected[prop + "-3-0"]) + else: + for c in range(num_channels): + # zeroth moment + allclose(moments[:, c, 0, 0], expected[prop + f"-0-0-{c}"]) + + if order > 0 and norm_type != "normalized": + # first-order moments + if norm_type == "central": + assert_array_equal(moments[:, c, 0, 1], 0.0) + assert_array_equal(moments[:, c, 1, 0], 0.0) + else: + allclose(moments[:, c, 0, 1], expected[prop + f"-0-1-{c}"]) + allclose(moments[:, c, 1, 0], expected[prop + f"-1-0-{c}"]) + if order > 1: + # second-order moments + allclose(moments[:, c, 0, 2], expected[prop + f"-0-2-{c}"]) + allclose(moments[:, c, 1, 1], expected[prop + f"-1-1-{c}"]) + allclose(moments[:, c, 2, 0], expected[prop + f"-2-0-{c}"]) + if order > 3: + # third-order moments + allclose(moments[:, c, 0, 3], expected[prop + f"-0-3-{c}"]) + allclose(moments[:, c, 1, 2], expected[prop + f"-1-2-{c}"]) + allclose(moments[:, c, 2, 1], expected[prop + f"-2-1-{c}"]) + allclose(moments[:, c, 3, 0], expected[prop + f"-3-0-{c}"]) + + +@pytest.mark.parametrize("spacing", [None, (0.8, 0.5, 0.75)]) +@pytest.mark.parametrize("order", [0, 1, 2, 3]) +@pytest.mark.parametrize( + "weighted, intensity_dtype, num_channels", + [ + (False, None, 1), + (True, cp.float32, 1), + (True, cp.uint8, 3), + ], +) +@pytest.mark.parametrize("norm_type", ["raw", "central", "normalized"]) +def test_moments_3d( + spacing, order, weighted, intensity_dtype, num_channels, norm_type +): + shape = (80, 64, 48) + labels = get_labels_nd(shape) + max_label = int(cp.max(labels)) + kwargs = {"spacing": spacing} + prop = "moments" + if norm_type == "normalized" and order < 2: + pytest.skip("normalized case only supports order >=2") + if weighted: + intensity_image = get_intensity_image( + shape, dtype=intensity_dtype, num_channels=num_channels + ) + kwargs["intensity_image"] = intensity_image + prop += "_weighted" + if norm_type == "central": + prop += "_central" + elif norm_type == "normalized": + prop += "_normalized" + kwargs_cpu = deepcopy(kwargs) + if "intensity_image" in kwargs_cpu: + kwargs_cpu["intensity_image"] = cp.asnumpy( + kwargs_cpu["intensity_image"] + ) + + # ignore possible warning from skimage implementation + warnings.filterwarnings( + "ignore", + message="invalid value encountered in scalar power", + category=RuntimeWarning, + ) + expected = measure_cpu.regionprops_table( + cp.asnumpy(labels), properties=[prop], **kwargs_cpu + ) + warnings.resetwarnings() + + moments = regionprops_moments( + labels, max_label=max_label, order=order, **kwargs + ) + if norm_type in ["central", "normalized"]: + ndim = len(shape) + moments = regionprops_moments_central(moments, ndim=ndim) + if norm_type == "normalized": + moments = regionprops_moments_normalized( + moments, ndim=ndim, spacing=spacing + ) + + # assert that np.nan values were set for non-computed orders + orders = cp.arange(order + 1) + orders = ( + orders[:, cp.newaxis, cp.newaxis] + + orders[cp.newaxis, :, cp.newaxis] + + orders[cp.newaxis, cp.newaxis, :] + ) + mask = cp.logical_and(orders < 1, orders > order) + # prepend labels (and channels) axes and replicate mask to match + # the moments shape + if num_channels > 1: + mask = mask[cp.newaxis, cp.newaxis, ...] + mask = cp.tile(mask, moments.shape[:2] + (1, 1, 1)) + else: + mask = mask[cp.newaxis, ...] + mask = cp.tile(mask, moments.shape[:1] + (1, 1, 1)) + assert cp.all(cp.isnan(moments[mask])) + + # regionprops does not use the more accurate analytical expressions for the + # central moments, so need to relax tolerance in the "central" moments case + rtol = 1e-3 if norm_type != "raw" else 1e-5 + atol = 1e-4 if norm_type != "raw" else 1e-7 + allclose = functools.partial(assert_allclose, rtol=rtol, atol=atol) + if num_channels == 1: + # zeroth moment + allclose(moments[:, 0, 0, 0], expected[prop + "-0-0-0"]) + if order > 0 and norm_type != "normalized": + # first-order moments + if norm_type == "central": + assert_array_equal(moments[:, 0, 0, 1], 0.0) + assert_array_equal(moments[:, 0, 1, 0], 0.0) + assert_array_equal(moments[:, 1, 0, 0], 0.0) + else: + allclose(moments[:, 0, 0, 1], expected[prop + "-0-0-1"]) + allclose(moments[:, 0, 1, 0], expected[prop + "-0-1-0"]) + allclose(moments[:, 1, 0, 0], expected[prop + "-1-0-0"]) + if order > 1: + # second-order moments + allclose(moments[:, 0, 0, 2], expected[prop + "-0-0-2"]) + allclose(moments[:, 0, 2, 0], expected[prop + "-0-2-0"]) + allclose(moments[:, 2, 0, 0], expected[prop + "-2-0-0"]) + allclose(moments[:, 1, 1, 0], expected[prop + "-1-1-0"]) + allclose(moments[:, 1, 0, 1], expected[prop + "-1-0-1"]) + allclose(moments[:, 0, 1, 1], expected[prop + "-0-1-1"]) + if order > 2: + # third-order moments + allclose(moments[:, 0, 0, 3], expected[prop + "-0-0-3"]) + allclose(moments[:, 0, 3, 0], expected[prop + "-0-3-0"]) + allclose(moments[:, 3, 0, 0], expected[prop + "-3-0-0"]) + allclose(moments[:, 1, 2, 0], expected[prop + "-1-2-0"]) + allclose(moments[:, 2, 1, 0], expected[prop + "-2-1-0"]) + allclose(moments[:, 1, 0, 2], expected[prop + "-1-0-2"]) + allclose(moments[:, 2, 0, 1], expected[prop + "-2-0-1"]) + allclose(moments[:, 0, 1, 2], expected[prop + "-0-1-2"]) + allclose(moments[:, 0, 2, 1], expected[prop + "-0-2-1"]) + allclose(moments[:, 1, 1, 1], expected[prop + "-1-1-1"]) + else: + for c in range(num_channels): + # zeroth moment + allclose(moments[:, c, 0, 0, 0], expected[prop + f"-0-0-0-{c}"]) + if order > 0 and norm_type != "normalized": + # first-order moments + if norm_type == "central": + assert_array_equal(moments[:, c, 0, 0, 1], 0.0) + assert_array_equal(moments[:, c, 0, 1, 0], 0.0) + assert_array_equal(moments[:, c, 1, 0, 0], 0.0) + else: + allclose( + moments[:, c, 0, 0, 1], expected[prop + f"-0-0-1-{c}"] + ) + allclose( + moments[:, c, 0, 1, 0], expected[prop + f"-0-1-0-{c}"] + ) + allclose( + moments[:, c, 1, 0, 0], expected[prop + f"-1-0-0-{c}"] + ) + if order > 1: + # second-order moments + allclose(moments[:, c, 0, 0, 2], expected[prop + f"-0-0-2-{c}"]) + allclose(moments[:, c, 0, 2, 0], expected[prop + f"-0-2-0-{c}"]) + allclose(moments[:, c, 2, 0, 0], expected[prop + f"-2-0-0-{c}"]) + allclose(moments[:, c, 1, 1, 0], expected[prop + f"-1-1-0-{c}"]) + allclose(moments[:, c, 1, 0, 1], expected[prop + f"-1-0-1-{c}"]) + allclose(moments[:, c, 0, 1, 1], expected[prop + f"-0-1-1-{c}"]) + if order > 2: + # third-order moments + allclose(moments[:, c, 0, 0, 3], expected[prop + f"-0-0-3-{c}"]) + allclose(moments[:, c, 0, 3, 0], expected[prop + f"-0-3-0-{c}"]) + allclose(moments[:, c, 3, 0, 0], expected[prop + f"-3-0-0-{c}"]) + allclose(moments[:, c, 1, 2, 0], expected[prop + f"-1-2-0-{c}"]) + allclose(moments[:, c, 2, 1, 0], expected[prop + f"-2-1-0-{c}"]) + allclose(moments[:, c, 1, 0, 2], expected[prop + f"-1-0-2-{c}"]) + allclose(moments[:, c, 2, 0, 1], expected[prop + f"-2-0-1-{c}"]) + allclose(moments[:, c, 0, 1, 2], expected[prop + f"-0-1-2-{c}"]) + allclose(moments[:, c, 0, 2, 1], expected[prop + f"-0-2-1-{c}"]) + allclose(moments[:, c, 1, 1, 1], expected[prop + f"-1-1-1-{c}"]) + + +@pytest.mark.parametrize("spacing", [None, (0.8, 0.5, 1.2)]) +@pytest.mark.parametrize("order", [1, 2, 3]) +@pytest.mark.parametrize("shape", [(500, 400), (64, 96, 32)]) +@pytest.mark.parametrize("compute_orientation", [False, True]) +@pytest.mark.parametrize("compute_axis_lengths", [False, True]) +@pytest.mark.parametrize("blob_size_fraction", [0.03, 0.1, 0.3]) +def test_inertia_tensor( + shape, + spacing, + order, + compute_orientation, + compute_axis_lengths, + blob_size_fraction, +): + ndim = len(shape) + labels = get_labels_nd(shape, blob_size_fraction=blob_size_fraction) + max_label = int(cp.max(labels)) + if spacing is not None: + # omit 3rd element for 2d images + spacing = spacing[:ndim] + props = ["inertia_tensor", "inertia_tensor_eigvals"] + compute_eccentricity = True if ndim == 2 else False + if compute_eccentricity: + props += ["eccentricity"] + if compute_orientation: + props += ["orientation"] + if compute_axis_lengths: + props += ["axis_major_length", "axis_minor_length"] + moments_raw = regionprops_moments( + labels, max_label=max_label, order=order, spacing=spacing + ) + moments_central = regionprops_moments_central(moments_raw, ndim=ndim) + + itensor_kwargs = dict(ndim=ndim, compute_orientation=compute_orientation) + if order < 2: + # can't compute inertia tensor without 2nd order moments + with pytest.raises(ValueError): + regionprops_inertia_tensor(moments_central, **itensor_kwargs) + return + + if compute_orientation: + if ndim != 2: + with pytest.raises(ValueError): + regionprops_inertia_tensor(moments_central, **itensor_kwargs) + return + itensor, orientation = regionprops_inertia_tensor( + moments_central, **itensor_kwargs + ) + assert orientation.shape == itensor.shape[:-2] + else: + itensor = regionprops_inertia_tensor(moments_central, **itensor_kwargs) + + assert itensor.shape[-2:] == (ndim, ndim) + + props_dict = regionprops_inertia_tensor_eigvals( + itensor, + compute_axis_lengths=compute_axis_lengths, + compute_eccentricity=compute_eccentricity, + ) + eigvals = props_dict["inertia_tensor_eigvals"] + assert eigvals.shape == (max_label, ndim) + if compute_eccentricity: + eccentricity = props_dict["eccentricity"] + assert eccentricity.shape == (max_label,) + if compute_axis_lengths: + axis_lengths = props_dict["axis_lengths"] + assert axis_lengths.shape == (max_label, ndim) + + # Do not compare to scikit-image via measure_cpu due to unhandled + # ValueError: math domain error in scikit-image. (Floating point numeric + # error can cause square root of very slightly negative value) + expected = measure.regionprops_table( + labels, + properties=props, + spacing=spacing, + ) + + # regionprops does not use the more accurate analytical expressions for the + # central moments, so need to relax tolerance in the "central" moments case + rtol = 1e-4 + atol = 1e-5 + allclose = functools.partial(assert_allclose, rtol=rtol, atol=atol) + if ndim == 2: + # valida inertia tensor + allclose(itensor[:, 0, 0], expected["inertia_tensor-0-0"]) + allclose(itensor[:, 0, 1], expected["inertia_tensor-0-1"]) + allclose(itensor[:, 1, 0], expected["inertia_tensor-1-0"]) + allclose(itensor[:, 1, 1], expected["inertia_tensor-1-1"]) + + # validate eigenvalues + allclose(eigvals[:, 0], expected["inertia_tensor_eigvals-0"]) + allclose(eigvals[:, 1], expected["inertia_tensor_eigvals-1"]) + + if compute_orientation: + pass + # Disabled orientation comparison as it is currently not robust + # (fails for the spacing != None case) + # + # # validate orientation + # # use sin/cos to avoid PI and -PI from being considered different + # tol_kw = dict(rtol=1e-3, atol=1e-3) + # assert_allclose( + # cp.cos(orientation), cp.cos(expected["orientation"]), **tol_kw + # ) + # assert_allclose( + # cp.sin(orientation), cp.sin(expected["orientation"]), **tol_kw + # ) + if compute_eccentricity: + allclose(eccentricity, expected["eccentricity"]) + + elif ndim == 3: + # valida inertia tensor + allclose(itensor[:, 0, 0], expected["inertia_tensor-0-0"]) + allclose(itensor[:, 0, 1], expected["inertia_tensor-0-1"]) + allclose(itensor[:, 0, 2], expected["inertia_tensor-0-2"]) + allclose(itensor[:, 1, 0], expected["inertia_tensor-1-0"]) + allclose(itensor[:, 1, 1], expected["inertia_tensor-1-1"]) + allclose(itensor[:, 1, 2], expected["inertia_tensor-1-2"]) + allclose(itensor[:, 2, 0], expected["inertia_tensor-2-0"]) + allclose(itensor[:, 2, 1], expected["inertia_tensor-2-1"]) + allclose(itensor[:, 2, 2], expected["inertia_tensor-2-2"]) + + # validate eigenvalues + allclose(eigvals[:, 0], expected["inertia_tensor_eigvals-0"]) + allclose(eigvals[:, 1], expected["inertia_tensor_eigvals-1"]) + allclose(eigvals[:, 2], expected["inertia_tensor_eigvals-2"]) + rtol = 1e-5 + # seems to be a larger fractional pixel error in length in 3D case + atol = 5e-3 if ndim == 3 else 1e-5 + allclose = functools.partial(assert_allclose, rtol=rtol, atol=atol) + if compute_axis_lengths: + allclose(axis_lengths[..., 0], expected["axis_major_length"]) + allclose(axis_lengths[..., -1], expected["axis_minor_length"]) + + +@pytest.mark.parametrize("spacing", [None, (0.8, 0.5, 1.2)]) +@pytest.mark.parametrize( + "intensity_dtype, num_channels", + [(cp.float32, 1), (cp.uint8, 3)], +) +@pytest.mark.parametrize("shape", [(800, 600), (80, 60, 40)]) +@pytest.mark.parametrize("local", [False, True]) +def test_centroid_weighted( + shape, spacing, intensity_dtype, num_channels, local +): + ndim = len(shape) + labels = get_labels_nd(shape) + + max_label = int(cp.max(labels)) + if spacing is not None: + # omit 3rd element for 2d images + spacing = spacing[:ndim] + intensity_image = get_intensity_image( + shape, dtype=intensity_dtype, num_channels=num_channels + ) + kwargs = {"spacing": spacing, "intensity_image": intensity_image} + prop = "centroid_weighted" + if local: + prop += "_local" + + expected = measure_cpu.regionprops_table( + cp.asnumpy(labels), + properties=[prop], + spacing=spacing, + intensity_image=cp.asnumpy(intensity_image), + ) + moments_raw = regionprops_moments( + labels, max_label=max_label, order=1, **kwargs + ) + + if local: + bbox = None + else: + bbox, _ = regionprops_bbox_coords( + labels, max_label=max_label, return_slices=False + ) + + centroids = regionprops_centroid_weighted( + moments_raw, + ndim=ndim, + bbox=bbox, + compute_local=local, + compute_global=not local, + spacing=spacing, + )[prop] + + assert centroids.shape[-1] == ndim + + rtol = 1e-7 + atol = 0 + allclose = functools.partial(assert_allclose, rtol=rtol, atol=atol) + if num_channels == 1: + for d in range(ndim): + allclose(centroids[:, d], expected[prop + f"-{d}"]) + else: + for c in range(num_channels): + for d in range(ndim): + allclose(centroids[:, c, d], expected[prop + f"-{d}-{c}"]) + + @pytest.mark.parametrize("ndim", [2, 3]) @pytest.mark.parametrize("num_channels", [1, 3]) @pytest.mark.parametrize( @@ -579,12 +1159,16 @@ def test_feret_diameter_max(ndim, spacing, blob_kwargs): list(basic_deps.keys()) + list(convex_deps.keys()) + list(intensity_deps.keys()) + + list(moment_deps.keys()) ), ) def test_regionprops_dict_single_property(ndim, spacing, property_name): """Test to verify that any dependencies for a given property are automatically handled. """ + if ndim != 2 and property_name in ndim_2_only: + pytest.skip(f"{property_name} is for 2d images only.") + return shape = (768, 512) if ndim == 2 else (64, 64, 64) if spacing is not None: spacing = spacing[:ndim] From df640e3a205aa0cfa7fa8c8982d3565633ce6e52 Mon Sep 17 00:00:00 2001 From: Gregory Lee Date: Sat, 1 Mar 2025 15:37:31 -0500 Subject: [PATCH 11/14] reduce code duplication by reusing regionprops_moments_central from _moments_analytical.py --- .../skimage/measure/_moments_analytical.py | 194 +----------------- 1 file changed, 9 insertions(+), 185 deletions(-) diff --git a/python/cucim/src/cucim/skimage/measure/_moments_analytical.py b/python/cucim/src/cucim/skimage/measure/_moments_analytical.py index 2d110056..0c91f501 100644 --- a/python/cucim/src/cucim/skimage/measure/_moments_analytical.py +++ b/python/cucim/src/cucim/skimage/measure/_moments_analytical.py @@ -4,190 +4,9 @@ import cupy as cp import numpy as np -_order0_or_1 = """ - mc[0] = m[0]; -""" - -_order2_2d = """ - /* Implementation of the commented code below with C-order raveled - * indices into 3 x 3 matrices, m and mc. - * - * mc[0, 0] = m[0, 0]; - * cx = m[1, 0] / m[0, 0]; - * cy = m[0, 1] / m[0, 0]; - * mc[1, 1] = m[1, 1] - cx*m[0, 1]; - * mc[2, 0] = m[2, 0] - cx*m[1, 0]; - * mc[0, 2] = m[0, 2] - cy*m[0, 1]; - */ - mc[0] = m[0]; - F cx = m[3] / m[0]; - F cy = m[1] / m[0]; - mc[4] = m[4] - cx*m[1]; - mc[6] = m[6] - cx*m[3]; - mc[2] = m[2] - cy*m[1]; -""" - -_order3_2d = """ - /* Implementation of the commented code below with C-order raveled - * indices into 4 x 4 matrices, m and mc. - * - * mc[0, 0] = m[0, 0]; - * cx = m[1, 0] / m[0, 0]; - * cy = m[0, 1] / m[0, 0]; - * mc[1, 1] = m[1, 1] - cx*m[0, 1]; - * mc[2, 0] = m[2, 0] - cx*m[1, 0]; - * mc[0, 2] = m[0, 2] - cy*m[0, 1]; - * mc[2, 1] = (m[2, 1] - 2*cx*m[1, 1] - cy*m[2, 0] + cx*cx*m[0, 1] + cy*cx*m[1, 0]); - * mc[1, 2] = (m[1, 2] - 2*cy*m[1, 1] - cx*m[0, 2] + 2*cy*cx*m[0, 1]); - * mc[3, 0] = m[3, 0] - 3*cx*m[2, 0] + 2*cx*cx*m[1, 0]; - * mc[0, 3] = m[0, 3] - 3*cy*m[0, 2] + 2*cy*cx*m[0, 1]; - */ - - mc[0] = m[0]; - F cx = m[4] / m[0]; - F cy = m[1] / m[0]; - // 2nd order moments - mc[5] = m[5] - cx*m[1]; - mc[8] = m[8] - cx*m[4]; - mc[2] = m[2] - cy*m[1]; - // 3rd order moments - mc[9] = (m[9] - 2*cx*m[5] - cy*m[8] + cx*cx*m[1] + cy*cx*m[4]); - mc[6] = (m[6] - 2*cy*m[5] - cx*m[2] + 2*cy*cx*m[1]); - mc[12] = m[12] - 3*cx*m[8] + 2*cx*cx*m[4]; - mc[3] = m[3] - 3*cy*m[2] + 2*cy*cy*m[1]; -""" # noqa - - -# Note for 2D kernels using C-order raveled indices -_order2_3d = """ - /* Implementation of the commented code below with C-order raveled - * indices into shape (3, 3, 3) matrices, m and mc. - * - * mc[0, 0, 0] = m[0, 0, 0]; - * cx = m[1, 0, 0] / m[0, 0, 0]; - * cy = m[0, 1, 0] / m[0, 0, 0]; - * cz = m[0, 0, 1] / m[0, 0, 0]; - * mc[0, 0, 2] = -cz*m[0, 0, 1] + m[0, 0, 2]; - * mc[0, 1, 1] = -cy*m[0, 0, 1] + m[0, 1, 1]; - * mc[0, 2, 0] = -cy*m[0, 1, 0] + m[0, 2, 0]; - * mc[1, 0, 1] = -cx*m[0, 0, 1] + m[1, 0, 1]; - * mc[1, 1, 0] = -cx*m[0, 1, 0] + m[1, 1, 0]; - * mc[2, 0, 0] = -cx*m[1, 0, 0] + m[2, 0, 0]; - */ - mc[0] = m[0]; - F cx = m[9] / m[0]; - F cy = m[3] / m[0]; - F cz = m[1] / m[0]; - // 2nd order moments - mc[2] = -cz*m[1] + m[2]; - mc[4] = -cy*m[1] + m[4]; - mc[6] = -cy*m[3] + m[6]; - mc[10] = -cx*m[1] + m[10]; - mc[12] = -cx*m[3] + m[12]; - mc[18] = -cx*m[9] + m[18]; -""" - -_order3_3d = """ - /* Implementation of the commented code below with C-order raveled - * indices into shape (4, 4, 4) matrices, m and mc. - * - * mc[0, 0, 0] = m[0, 0, 0]; - * cx = m[1, 0, 0] / m[0, 0, 0]; - * cy = m[0, 1, 0] / m[0, 0, 0]; - * cz = m[0, 0, 1] / m[0, 0, 0]; - * // 2nd order moments - * mc[0, 0, 2] = -cz*m[0, 0, 1] + m[0, 0, 2]; - * mc[0, 1, 1] = -cy*m[0, 0, 1] + m[0, 1, 1]; - * mc[0, 2, 0] = -cy*m[0, 1, 0] + m[0, 2, 0]; - * mc[1, 0, 1] = -cx*m[0, 0, 1] + m[1, 0, 1]; - * mc[1, 1, 0] = -cx*m[0, 1, 0] + m[1, 1, 0]; - * mc[2, 0, 0] = -cx*m[1, 0, 0] + m[2, 0, 0]; - * // 3rd order moments - * mc[0, 0, 3] = (2*cz*cz*m[0, 0, 1] - 3*cz*m[0, 0, 2] + m[0, 0, 3]); - * mc[0, 1, 2] = (-cy*m[0, 0, 2] + 2*cz*(cy*m[0, 0, 1] - m[0, 1, 1]) + m[0, 1, 2]); - * mc[0, 2, 1] = (cy*cy*m[0, 0, 1] - 2*cy*m[0, 1, 1] + cz*(cy*m[0, 1, 0] - m[0, 2, 0]) + m[0, 2, 1]); - * mc[0, 3, 0] = (2*cy*cy*m[0, 1, 0] - 3*cy*m[0, 2, 0] + m[0, 3, 0]); - * mc[1, 0, 2] = (-cx*m[0, 0, 2] + 2*cz*(cx*m[0, 0, 1] - m[1, 0, 1]) + m[1, 0, 2]); - * mc[1, 1, 1] = (-cx*m[0, 1, 1] + cy*(cx*m[0, 0, 1] - m[1, 0, 1]) + cz*(cx*m[0, 1, 0] - m[1, 1, 0]) + m[1, 1, 1]); - * mc[1, 2, 0] = (-cx*m[0, 2, 0] - 2*cy*(-cx*m[0, 1, 0] + m[1, 1, 0]) + m[1, 2, 0]); - * mc[2, 0, 1] = (cx*cx*m[0, 0, 1] - 2*cx*m[1, 0, 1] + cz*(cx*m[1, 0, 0] - m[2, 0, 0]) + m[2, 0, 1]); - * mc[2, 1, 0] = (cx*cx*m[0, 1, 0] - 2*cx*m[1, 1, 0] + cy*(cx*m[1, 0, 0] - m[2, 0, 0]) + m[2, 1, 0]); - * mc[3, 0, 0] = (2*cx*cx*m[1, 0, 0] - 3*cx*m[2, 0, 0] + m[3, 0, 0]); - */ - mc[0] = m[0]; - F cx = m[16] / m[0]; - F cy = m[4] / m[0]; - F cz = m[1] / m[0]; - // 2nd order moments - mc[2] = -cz*m[1] + m[2]; - mc[5] = -cy*m[1] + m[5]; - mc[8] = -cy*m[4] + m[8]; - mc[17] = -cx*m[1] + m[17]; - mc[20] = -cx*m[4] + m[20]; - mc[32] = -cx*m[16] + m[32]; - // 3rd order moments - mc[3] = (2*cz*cz*m[1] - 3*cz*m[2] + m[3]); - mc[6] = (-cy*m[2] + 2*cz*(cy*m[1] - m[5]) + m[6]); - mc[9] = (cy*cy*m[1] - 2*cy*m[5] + cz*(cy*m[4] - m[8]) + m[9]); - mc[12] = (2*cy*cy*m[4] - 3*cy*m[8] + m[12]); - mc[18] = (-cx*m[2] + 2*cz*(cx*m[1] - m[17]) + m[18]); - mc[21] = (-cx*m[5] + cy*(cx*m[1] - m[17]) + cz*(cx*m[4] - m[20]) + m[21]); - mc[24] = (-cx*m[8] - 2*cy*(-cx*m[4] + m[20]) + m[24]); - mc[33] = (cx*cx*m[1] - 2*cx*m[17] + cz*(cx*m[16] - m[32]) + m[33]); - mc[36] = (cx*cx*m[4] - 2*cx*m[20] + cy*(cx*m[16] - m[32]) + m[36]); - mc[48] = (2*cx*cx*m[16] - 3*cx*m[32] + m[48]); -""" # noqa - - -def _moments_raw_to_central_fast(moments_raw): - """Analytical formulae for 2D and 3D central moments of order < 4. - - `moments_raw_to_central` will automatically call this function when - ndim < 4 and order < 4. - - Parameters - ---------- - moments_raw : ndarray - The raw moments. - - Returns - ------- - moments_central : ndarray - The central moments. - """ - ndim = moments_raw.ndim - order = moments_raw.shape[0] - 1 - # convert to float64 during the computation for better accuracy - moments_raw = moments_raw.astype(cp.float64, copy=False) - moments_central = cp.zeros_like(moments_raw) - if order >= 4 or ndim not in [2, 3]: - raise ValueError( - "This function only supports 2D or 3D moments of order < 4." - ) - if ndim == 2: - if order < 2: - operation = _order0_or_1 - elif order == 2: - operation = _order2_2d - elif order == 3: - operation = _order3_2d - elif ndim == 3: - if order < 2: - operation = _order0_or_1 - elif order == 2: - operation = _order2_3d - elif order == 3: - operation = _order3_3d - - kernel = cp.ElementwiseKernel( - "raw F m", - "raw F mc", - operation=operation, - name=f"order{order}_{ndim}d_kernel", - ) - # run a single-threaded kernel, so we can avoid device->host->device copy - kernel(moments_raw, moments_central, size=1) - return moments_central +from ._regionprops_gpu_moments_kernels import ( + regionprops_moments_central, +) def moments_raw_to_central(moments_raw): @@ -196,7 +15,12 @@ def moments_raw_to_central(moments_raw): if ndim in [2, 3] and order < 4: # fast path with analytical GPU kernels # (avoids any host/device transfers) - moments_central = _moments_raw_to_central_fast(moments_raw) + + # have to temporarily prepend a "labels" dimension to reuse + # regionprops_moments_central + moments_central = regionprops_moments_central( + moments_raw[cp.newaxis, ...], ndim + )[0] return moments_central.astype(moments_raw.dtype, copy=False) # Fallback to general formula applied on the host From 597784ad0402306e5ac201d7caaa5f83897b0289 Mon Sep 17 00:00:00 2001 From: Gregory Lee Date: Mon, 3 Mar 2025 20:39:15 -0500 Subject: [PATCH 12/14] restore _inertia_eigvals_to_axes_lengths_3D for use in test case --- .../src/cucim/skimage/measure/_regionprops.py | 4 +- .../skimage/measure/tests/test_regionprops.py | 46 ++++++++++++++++++- 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/python/cucim/src/cucim/skimage/measure/_regionprops.py b/python/cucim/src/cucim/skimage/measure/_regionprops.py index 4248ed3b..cd9b649d 100644 --- a/python/cucim/src/cucim/skimage/measure/_regionprops.py +++ b/python/cucim/src/cucim/skimage/measure/_regionprops.py @@ -566,7 +566,6 @@ def axis_major_length(self): l1 = self.inertia_tensor_eigvals[0] return 4 * math.sqrt(l1) elif self._ndim == 3: - # equivalent to _inertia_eigvals_to_axes_lengths_3D(ev)[0] ev = self.inertia_tensor_eigvals return math.sqrt(10 * (ev[0] + ev[1] - ev[2])) else: @@ -578,11 +577,10 @@ def axis_minor_length(self): l2 = self.inertia_tensor_eigvals[-1] return 4 * math.sqrt(l2) elif self._ndim == 3: - # equivalent to _inertia_eigvals_to_axes_lengths_3D(ev)[-1] ev = self.inertia_tensor_eigvals # use max to avoid possibly very small negative value due to # numeric error - return math.sqrt(max(10 * (-ev[0] + ev[1] + ev[2]), 0.0)) + return math.sqrt(10 * max(-ev[0] + ev[1] + ev[2], 0.0)) else: raise ValueError("axis_minor_length only available in 2D and 3D") diff --git a/python/cucim/src/cucim/skimage/measure/tests/test_regionprops.py b/python/cucim/src/cucim/skimage/measure/tests/test_regionprops.py index 4db42b9b..21c89961 100644 --- a/python/cucim/src/cucim/skimage/measure/tests/test_regionprops.py +++ b/python/cucim/src/cucim/skimage/measure/tests/test_regionprops.py @@ -24,7 +24,6 @@ COL_DTYPES, OBJECT_COLUMNS, PROPS, - _inertia_eigvals_to_axes_lengths_3D, _parse_docs, _props_to_dict, _require_intensity_image, @@ -1402,6 +1401,51 @@ def test_multichannel(): assert_array_equal(p, p_multi[..., 1]) +def _inertia_eigvals_to_axes_lengths_3D(inertia_tensor_eigvals): + """Compute ellipsoid axis lengths from inertia tensor eigenvalues. + + Parameters + --------- + inertia_tensor_eigvals : sequence of float + A sequence of 3 floating point eigenvalues, sorted in descending order. + + Returns + ------- + axis_lengths : list of float + The ellipsoid axis lengths sorted in descending order. + + Notes + ----- + Let a >= b >= c be the ellipsoid semi-axes and s1 >= s2 >= s3 be the + inertia tensor eigenvalues. + + The inertia tensor eigenvalues are given for a solid ellipsoid in [1]_. + s1 = 1 / 5 * (a**2 + b**2) + s2 = 1 / 5 * (a**2 + c**2) + s3 = 1 / 5 * (b**2 + c**2) + + Rearranging to solve for a, b, c in terms of s1, s2, s3 gives + a = math.sqrt(5 / 2 * ( s1 + s2 - s3)) + b = math.sqrt(5 / 2 * ( s1 - s2 + s3)) + c = math.sqrt(5 / 2 * (-s1 + s2 + s3)) + + We can then simply replace sqrt(5/2) by sqrt(10) to get the full axes + lengths rather than the semi-axes lengths. + + References + ---------- + ..[1] https://en.wikipedia.org/wiki/List_of_moments_of_inertia#List_of_3D_inertia_tensors + """ # noqa: E501 + axis_lengths = [] + for ax in range(2, -1, -1): + w = sum( + v * -1 if i == ax else v + for i, v in enumerate(inertia_tensor_eigvals) + ) + axis_lengths.append(math.sqrt(10 * w)) + return axis_lengths + + def test_3d_ellipsoid_axis_lengths(): """Verify that estimated axis lengths are correct. From d21edcb23d24504424e064752467fbb60d5f08d6 Mon Sep 17 00:00:00 2001 From: Gregory Lee Date: Sat, 1 Mar 2025 15:17:30 -0500 Subject: [PATCH 13/14] add implementations for perimeter, perimeter_crofton and euler_number --- .../cucim/skimage/measure/_regionprops_gpu.py | 56 +- .../measure/_regionprops_gpu_misc_kernels.py | 721 ++++++++++++++++++ .../skimage/measure/_regionprops_gpu_utils.py | 25 + .../tests/test_regionprops_gpu_kernels.py | 126 +++ 4 files changed, 927 insertions(+), 1 deletion(-) create mode 100644 python/cucim/src/cucim/skimage/measure/_regionprops_gpu_misc_kernels.py diff --git a/python/cucim/src/cucim/skimage/measure/_regionprops_gpu.py b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu.py index 5d6d81b6..e8555610 100644 --- a/python/cucim/src/cucim/skimage/measure/_regionprops_gpu.py +++ b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu.py @@ -34,6 +34,12 @@ regionprops_intensity_min_max, regionprops_intensity_std, ) +from ._regionprops_gpu_misc_kernels import ( + misc_deps, + regionprops_euler, + regionprops_perimeter, + regionprops_perimeter_crofton, +) from ._regionprops_gpu_moments_kernels import ( moment_deps, regionprops_centroid, @@ -47,7 +53,7 @@ regionprops_moments_normalized, required_order, ) -from ._regionprops_gpu_utils import _get_min_integer_dtype +from ._regionprops_gpu_utils import _find_close_labels, _get_min_integer_dtype __all__ = [ "equivalent_diameter_area", @@ -60,6 +66,7 @@ "regionprops_centroid_weighted", "regionprops_coords", "regionprops_dict", + "regionprops_euler", "regionprops_extent", "regionprops_feret_diameter_max", "regionprops_image", @@ -73,6 +80,8 @@ "regionprops_moments_hu", "regionprops_moments_normalized", "regionprops_num_pixels", + "regionprops_perimeter", + "regionprops_perimeter_crofton", # extra functions for cuCIM not currently in scikit-image "equivalent_spherical_perimeter", # as in ITK "regionprops_num_boundary_pixels", @@ -137,6 +146,7 @@ property_deps.update(basic_deps) property_deps.update(convex_deps) property_deps.update(intensity_deps) +property_deps.update(misc_deps) property_deps.update(moment_deps) @@ -563,6 +573,50 @@ def regionprops_dict( props_dict=out, ) + compute_perimeter = "perimeter" in required_props + compute_perimeter_crofton = "perimeter_crofton" in required_props + compute_euler = "euler_number" in required_props + + if compute_euler or compute_perimeter or compute_perimeter_crofton: + # precompute list of labels with <2 pixels space between them + if label_image.dtype == cp.uint8: + labels_mask = label_image.view("bool") + else: + labels_mask = label_image > 0 + labels_close = _find_close_labels( + label_image, binary_image=labels_mask, max_label=max_label + ) + + if compute_perimeter: + regionprops_perimeter( + label_image, + neighborhood=4, + max_label=max_label, + robust=True, + labels_close=labels_close, + props_dict=out, + ) + if compute_perimeter_crofton: + regionprops_perimeter_crofton( + label_image, + directions=4, + max_label=max_label, + robust=True, + omit_image_edges=False, + labels_close=labels_close, + props_dict=out, + ) + + if compute_euler: + regionprops_euler( + label_image, + connectivity=None, + max_label=max_label, + robust=True, + labels_close=labels_close, + props_dict=out, + ) + compute_images = "image" in required_props compute_intensity_images = "image_intensity" in required_props compute_convex = "image_convex" in required_props diff --git a/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_misc_kernels.py b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_misc_kernels.py new file mode 100644 index 00000000..1b5edd8c --- /dev/null +++ b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_misc_kernels.py @@ -0,0 +1,721 @@ +import math + +import cupy as cp +import numpy as np + +from cucim.skimage._vendored import ndimage as ndi, pad + +from ._regionprops_gpu_basic_kernels import regionprops_bbox_coords +from ._regionprops_gpu_intensity_kernels import ( + _get_intensity_img_kernel_dtypes, + get_intensity_measure_kernel, +) +from ._regionprops_gpu_utils import _find_close_labels, _get_min_integer_dtype + +__all__ = [ + "regionprops_euler", + "regionprops_perimeter", + "regionprops_perimeter_crofton", +] + +misc_deps = dict() +misc_deps["perimeter"] = ["slice"] +misc_deps["perimeter_crofton"] = ["slice"] +misc_deps["euler_number"] = ["slice"] + + +def _weighted_sum_of_filtered_image( + label_image, max_label, image_filtered, coefs, pixels_per_thread=16 +): + """Compute weighted sums of pixels for each label. + + 1. Apply the coefs LUT to the filtered image to get a coefficient image + 2. Sum the values in the coefficient image for each labeled region + + This function is used during computation of the Euler characteristic and + perimeter properties. + + Parameters + ---------- + label_image : cupy.ndarray + Label image. + max_label : int + Maximum label value. + image_filtered : (M, N) ndarray + Filtered image (must have integer values that can be used to index into + the coefs LUT) + coefs : cupy.ndarray + Coefficients look-up table (LUT). + pixels_per_thread : int, optional + Number of pixels per thread. + + Returns + ------- + output : cupy.ndarray + Weighted sum of pixels for each label. + """ + coefs_image = coefs[image_filtered] + + # generate kernel for per-label weighted sum + coefs_sum_kernel = get_intensity_measure_kernel( + coefs_image.dtype, + num_channels=1, + compute_num_pixels=False, + compute_sum=True, + compute_sum_sq=False, + compute_min=False, + compute_max=False, + pixels_per_thread=pixels_per_thread, + ) + + # prepare output array + sum_dtype, _, _, _ = _get_intensity_img_kernel_dtypes(coefs_image.dtype) + output = cp.zeros((max_label,), dtype=sum_dtype) + + coefs_sum_kernel( + label_image, + label_image.size, + coefs_image, + output, + size=math.ceil(label_image.size / pixels_per_thread), + ) + return output + + +@cp.memoize(for_each_device=True) +def _get_perimeter_weights_and_coefs(coefs_dtype=cp.float32): + # convolution weights + weights = cp.array( + [[10, 2, 10], [2, 1, 2], [10, 2, 10]], + ) + + # LUT for weighted sum + coefs = np.zeros(50, dtype=coefs_dtype) + coefs[[5, 7, 15, 17, 25, 27]] = 1 + coefs[[21, 33]] = math.sqrt(2) + coefs[[13, 23]] = (1 + math.sqrt(2)) / 2 + coefs = cp.asarray(coefs) + return weights, coefs + + +def regionprops_perimeter( + labels, + neighborhood=4, + *, + max_label=None, + robust=True, + labels_close=None, + props_dict=None, + pixels_per_thread=10, +): + """Calculate total perimeter of all objects in binary image. + + when `robust` is ``True``, reuses "slice" from `props_dict` + + writes "perimeter" to `props_dict` + + Parameters + ---------- + labels : (M, N) ndarray + Binary input image. + neighborhood : 4 or 8, optional + Neighborhood connectivity for border pixel determination. It is used to + compute the contour. A higher neighborhood widens the border on which + the perimeter is computed. + + Extra Parameters + ---------------- + max_label : int or None, optional + The maximum label in labels can be provided to avoid recomputing it if + it was already known. + robust : bool, optional + If True, extra computation will be done to detect if any labeled + regions are <=2 pixel spacing from another. Any regions that meet that + criteria will have their perimeter recomputed in isolation to avoid + possible error that would otherwise occur in this case. Turning this + on will make the run time substantially longer, so it should only be + used when labeled regions may have a non-negligible portion of their + boundary within a <2 pixel gap from another label. + labels_close : numpy.ndarray or sequence of int + List of labeled regions that are less than 2 pixel gap from another + label. Used when robust=True. If not provided and robust=True, it + will be computed internally. + props_dict : dict or None, optional + Dictionary of pre-computed properties (e.g. "slice"). The output of this + function will be stored under key "perimeter" within this dictionary. + pixels_per_thread : int, optional + Number of pixels processed per thread on the GPU during the final + weighted summation. + + Returns + ------- + perimeter : float + Total perimeter of all objects in binary image. + + Notes + ----- + The `perimeter` method does not consider the boundary along the image edge + as image as part of the perimeter, while the `perimeter_crofton` method + does. In any case, an object touching the image edge likely extends outside + of the field of view, so an accurate perimeter cannot be measured for such + objects. + + If the labeled regions have holes, the hole edges will be included in this + measurement. If this is not desired, use regionprops_label_filled to fill + the holes and then pass the filled labels image to this function. + + References + ---------- + .. [1] K. Benkrid, D. Crookes. Design and FPGA Implementation of + a Perimeter Estimator. The Queen's University of Belfast. + http://www.cs.qub.ac.uk/~d.crookes/webpubs/papers/perimeter.doc + + See Also + -------- + perimeter_crofton + + Examples + -------- + >>> import cupy as cp + >>> from skimage import data + >>> from cucim.skimage import util + >>> from cucim.skimage.measure import label + >>> # coins image (binary) + >>> img_coins = cp.array(data.coins() > 110) + >>> # total perimeter of all objects in the image + >>> perimeter(img_coins, neighborhood=4) # doctest: +ELLIPSIS + array(7796.86799644) + >>> perimeter(img_coins, neighborhood=8) # doctest: +ELLIPSIS + array(8806.26807333) + """ + if max_label is None: + max_label = int(labels.max()) + + binary_image = labels > 0 + if robust and labels_close is None: + labels_close = _find_close_labels(labels, binary_image, max_label) + if neighborhood == 4: + footprint = cp.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=cp.uint8) + else: + footprint = 3 + + eroded_image = ndi.binary_erosion(binary_image, footprint, border_value=0) + border_image = binary_image.view(cp.uint8) - eroded_image + + perimeter_weights, perimeter_coefs = _get_perimeter_weights_and_coefs( + cp.float32 + ) + + perimeter_image = ndi.convolve( + border_image, + perimeter_weights, + mode="constant", + cval=0, + output=cp.uint8, + ) + + min_integer_type = _get_min_integer_dtype(max_label, signed=False) + # if labels.dtype != min_integer_type: + # labels = labels_dilated.astype(min_integer_type) + + # dilate labels by 1 pixel so we can sum with values in XF to give + # unique histogram bins for each labeled regions (as long as no labeled + # regions are within < 2 pixels from another labeled region) + labels_dilated = ndi.grey_dilation( + labels, 3, mode="constant", output=min_integer_type + ) + + if robust and labels_close.size > 0: + if props_dict is not None and "slice" in props_dict: + slices = props_dict["slice"] + else: + _, slices = regionprops_bbox_coords(labels, return_slices=True) + + # sum the coefficients for each label to compute the perimeter + perimeters = _weighted_sum_of_filtered_image( + label_image=labels_dilated, + max_label=max_label, + image_filtered=perimeter_image, + coefs=perimeter_coefs, + pixels_per_thread=pixels_per_thread, + ) + if robust: + # recompute perimeter in isolation for each region that may be too + # close to another one + shape = binary_image.shape + for lab in labels_close: + sl = slices[lab - 1] + + # keep boundary of 1 so object is not at 'edge' of cropped + # region (unless it is at a true image edge) + ld = labels[ + max(sl[0].start - 1, 0) : min(sl[0].stop + 1, shape[0]), + max(sl[1].start - 1, 0) : min(sl[1].stop + 1, shape[1]), + ] + + # print(f"{lab=}, {sl=}") + # import matplotlib.pyplot as plt + # plt.figure(); plt.imshow(ld.get()); plt.show() + + p = regionprops_perimeter( + ld == lab, max_label=1, neighborhood=neighborhood, robust=False + ) + perimeters[lab - 1] = p[0] + if props_dict is not None: + props_dict["perimeter"] = perimeters + return perimeters + + +@cp.memoize(for_each_device=True) +def _get_perimeter_crofton_weights_and_coefs( + directions, coefs_dtype=cp.float32 +): + # determine convolution weights + filter_weights = cp.array( + [[0, 0, 0], [0, 1, 4], [0, 2, 8]], dtype=cp.float32 + ) + + if directions == 2: + coefs = [ + 0, + np.pi / 2, + 0, + 0, + 0, + np.pi / 2, + 0, + 0, + np.pi / 2, + np.pi, + 0, + 0, + np.pi / 2, + np.pi, + 0, + 0, + 0, + ] + else: + sq2 = math.sqrt(2) + coefs = [ + 0, + np.pi / 4 * (1 + 1 / sq2), + np.pi / (4 * sq2), + np.pi / (2 * sq2), + 0, + np.pi / 4 * (1 + 1 / sq2), + 0, + np.pi / (4 * sq2), + np.pi / 4, + np.pi / 2, + np.pi / (4 * sq2), + np.pi / (4 * sq2), + np.pi / 4, + np.pi / 2, + 0, + 0, + 0, + 0, + 0, + ] + coefs = cp.asarray(coefs, dtype=coefs_dtype) + return filter_weights, coefs + + +def regionprops_perimeter_crofton( + labels, + directions=4, + *, + max_label=None, + robust=True, + omit_image_edges=False, + labels_close=None, + props_dict=None, + pixels_per_thread=10, +): + """Calculate total Crofton perimeter of all objects in binary image. + + when `robust` is ``True``, reuses "slice" from `props_dict` + + writes "perimeter_crofton" to `props_dict` + + Parameters + ---------- + labels : (M, N) ndarray + Input image. If image is not binary, all values greater than zero + are considered as the object. + directions : 2 or 4, optional + Number of directions used to approximate the Crofton perimeter. By + default, 4 is used: it should be more accurate than 2. + Computation time is the same in both cases. + + Extra Parameters + ---------------- + max_label : int or None, optional + The maximum label in labels can be provided to avoid recomputing it if + it was already known. + robust : bool, optional + If True, extra computation will be done to detect if any labeled + regions are <=2 pixel spacing from another. Any regions that meet that + criteria will have their perimeter recomputed in isolation to avoid + possible error that would otherwise occur in this case. Turning this + on will make the run time substantially longer, so it should only be + used when labeled regions may have a non-negligible portion of their + boundary within a <2 pixel gap from another label. + omit_image_edges : bool, optional + This can be set to avoid an additional padding step that includes the + edges of objects that correspond to the image edge as part of the + perimeter. We cannot accurately estimate the perimeter of objects + falling partly outside of `image`, so it seems acceptable to just set + this to True. The default remains False for consistency with upstream + scikit-image. + labels_close : numpy.ndarray or sequence of int + List of labeled regions that are less than 2 pixel gap from another + label. Used when robust=True. If not provided and robust=True, it + will be computed internally. + props_dict : dict or None, optional + Dictionary of pre-computed properties (e.g. "slice"). The output of this + function will be stored under key "perimeter_crofton" within this + dictionary. + pixels_per_thread : int, optional + Number of pixels processed per thread on the GPU during the final + weighted summation. + + Returns + ------- + perimeter : float + Total perimeter of all objects in binary image. + + Notes + ----- + This measure is based on Crofton formula [1], which is a measure from + integral geometry. It is defined for general curve length evaluation via + a double integral along all directions. In a discrete + space, 2 or 4 directions give a quite good approximation, 4 being more + accurate than 2 for more complex shapes. + + Similar to :func:`~.measure.perimeter`, this function returns an + approximation of the perimeter in continuous space. + + The `perimeter` method does not consider the boundary along the image edge + as image as part of the perimeter, while the `perimeter_crofton` method + does. In any case, an object touching the image edge likely extends outside + of the field of view, so an accurate perimeter cannot be measured for such + objects. + + If the labeled regions have holes, the hole edges will be included in this + measurement. If this is not desired, use regionprops_label_filled to fill + the holes and then pass the filled labels image to this function. + + See Also + -------- + perimeter + + References + ---------- + .. [1] https://en.wikipedia.org/wiki/Crofton_formula + .. [2] S. Rivollier. Analyse d’image geometrique et morphometrique par + diagrammes de forme et voisinages adaptatifs generaux. PhD thesis, + 2010. + Ecole Nationale Superieure des Mines de Saint-Etienne. + https://tel.archives-ouvertes.fr/tel-00560838 + """ + if max_label is None: + max_label = int(labels.max()) + + ndim = labels.ndim + if ndim not in [2, 3]: + raise ValueError("labels must be 2D or 3D") + + binary_image = labels > 0 + if robust and labels_close is None: + labels_close = _find_close_labels(labels, binary_image, max_label) + + footprint = 3 # scalar 3 -> (3, ) * ndim array of ones + + if not omit_image_edges: + # Dilate labels by 1 pixel so we can sum with values in image_filtered + # to give unique histogram bins for each labeled regions (As long as no + # labeled regions are within < 2 pixels from another labeled region) + labels_pad = cp.pad(labels, pad_width=1, mode="constant") + labels_dilated = ndi.grey_dilation(labels_pad, 3, mode="constant") + binary_image = pad(binary_image, pad_width=1, mode="constant") + # need dilated mask for later use for indexing into + # `image_filtered_labeled` for bincount + binary_image_mask = ndi.binary_dilation(binary_image, footprint) + binary_image_mask = cp.logical_xor( + binary_image_mask, ndi.binary_erosion(binary_image, footprint) + ) + else: + labels_dilated = ndi.grey_dilation(labels, footprint, mode="constant") + binary_image_mask = binary_image + + # determine convolution weights and LUT for weighted sum + filter_weights, coefs = _get_perimeter_crofton_weights_and_coefs( + directions, cp.float32 + ) + + image_filtered = ndi.convolve( + binary_image.view(cp.uint8), + filter_weights, + mode="constant", + cval=0, + output=cp.uint8, + ) + + if robust and labels_close.size > 0: + if props_dict is not None and "slice" in props_dict: + slices = props_dict["slice"] + else: + _, slices = regionprops_bbox_coords(labels, return_slices=True) + + # sum the coefficients for each label to compute the perimeter + perimeters = _weighted_sum_of_filtered_image( + label_image=labels_dilated, + max_label=max_label, + image_filtered=image_filtered, + coefs=coefs, + pixels_per_thread=pixels_per_thread, + ) + if robust: + # recompute perimeter in isolation for each region that may be too + # close to another one + shape = labels_dilated.shape + for lab in labels_close: + sl = slices[lab - 1] + ld = labels[ + max(sl[0].start, 0) : min(sl[0].stop, shape[0]), + max(sl[1].start, 0) : min(sl[1].stop, shape[1]), + ] + p = regionprops_perimeter_crofton( + ld == lab, + max_label=1, + directions=directions, + omit_image_edges=False, + robust=False, + ) + perimeters[lab - 1] = p[0] + if props_dict is not None: + props_dict["perimeter_crofton"] = perimeters + return perimeters + + +@cp.memoize(for_each_device=True) +def _get_euler_weights_and_coefs(ndim, connectivity, coefs_dtype=cp.float32): + from cucim.skimage.measure._regionprops_utils import ( + EULER_COEFS2D_4, + EULER_COEFS2D_8, + EULER_COEFS3D_26, + ) + + if ndim not in [2, 3]: + raise ValueError("only 2D and 3D images are supported") + + if ndim == 2: + filter_weights = cp.array([[0, 0, 0], [0, 1, 4], [0, 2, 8]]) + if connectivity == 1: + coefs = EULER_COEFS2D_4 + else: + coefs = EULER_COEFS2D_8 + coefs = cp.asarray(coefs, dtype=coefs_dtype) + else: # 3D images + if connectivity == 2: + raise NotImplementedError( + "For 3D images, Euler number is implemented " + "for connectivities 1 and 3 only" + ) + + filter_weights = cp.array( + [ + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + [[0, 0, 0], [0, 1, 4], [0, 2, 8]], + [[0, 0, 0], [0, 16, 64], [0, 32, 128]], + ] + ) + if connectivity == 1: + coefs = EULER_COEFS3D_26[::-1] + else: + coefs = EULER_COEFS3D_26 + coefs = cp.asarray(0.125 * coefs, dtype=coefs_dtype) + + return filter_weights, coefs + + +def regionprops_euler( + labels, + connectivity=None, + *, + max_label=None, + robust=True, + labels_close=None, + props_dict=None, + pixels_per_thread=10, +): + """Calculate the Euler characteristic in binary image. + + For 2D objects, the Euler number is the number of objects minus the number + of holes. For 3D objects, the Euler number is obtained as the number of + objects plus the number of holes, minus the number of tunnels, or loops. + + when `robust` is ``True``, reuses "slice" from `props_dict` + + writes "euler_number" to `props_dict` + + Parameters + ---------- + labels: (M, N[, P]) cupy.ndarray + Input image. If image is not binary, all values greater than zero + are considered as the object. + connectivity : int, optional + Maximum number of orthogonal hops to consider a pixel/voxel + as a neighbor. + Accepted values are ranging from 1 to input.ndim. If ``None``, a full + connectivity of ``input.ndim`` is used. + 4 or 8 neighborhoods are defined for 2D images (connectivity 1 and 2, + respectively). + 6 or 26 neighborhoods are defined for 3D images, (connectivity 1 and 3, + respectively). Connectivity 2 is not defined. + + Extra Parameters + ---------------- + max_label : int or None, optional + The maximum label in labels can be provided to avoid recomputing it if + it was already known. + robust : bool, optional + If True, extra computation will be done to detect if any labeled + regions are <=2 pixel spacing from another. Any regions that meet that + criteria will have their perimeter recomputed in isolation to avoid + possible error that would otherwise occur in this case. Turning this + on will make the run time substantially longer, so it should only be + used when labeled regions may have a non-negligible portion of their + boundary within a <2 pixel gap from another label. + labels_close : numpy.ndarray or sequence of int + List of labeled regions that are less than 2 pixel gap from another + label. Used when robust=True. If not provided and robust=True, it + will be computed internally. + props_dict : dict or None, optional + Dictionary of pre-computed properties (e.g. "slice"). The output of this + function will be stored under key "euler_number" within this dictionary. + pixels_per_thread : int, optional + Number of pixels processed per thread on the GPU during the final + weighted summation. + + Returns + ------- + euler_number : cp.ndarray of int + Euler characteristic of the set of all objects in the image. + + Notes + ----- + The Euler characteristic is an integer number that describes the + topology of the set of all objects in the input image. If object is + 4-connected, then background is 8-connected, and conversely. + + The computation of the Euler characteristic is based on an integral + geometry formula in discretized space. In practice, a neighborhood + configuration is constructed, and a LUT is applied for each + configuration. The coefficients used are the ones of Ohser et al. + + It can be useful to compute the Euler characteristic for several + connectivities. A large relative difference between results + for different connectivities suggests that the image resolution + (with respect to the size of objects and holes) is too low. + + References + ---------- + .. [1] S. Rivollier. Analyse d’image geometrique et morphometrique par + diagrammes de forme et voisinages adaptatifs generaux. PhD thesis, + 2010. Ecole Nationale Superieure des Mines de Saint-Etienne. + https://tel.archives-ouvertes.fr/tel-00560838 + .. [2] Ohser J., Nagel W., Schladitz K. (2002) The Euler Number of + Discretized Sets - On the Choice of Adjacency in Homogeneous + Lattices. In: Mecke K., Stoyan D. (eds) Morphology of Condensed + Matter. Lecture Notes in Physics, vol 600. Springer, Berlin, + Heidelberg. + -------- + perimeter_crofton + + Examples + -------- + >>> import cupy as cp + >>> from skimage import data + >>> from cucim.skimage import util + >>> from cucim.skimage.measure import label + >>> # coins image (binary) + >>> img_coins = cp.array(data.coins() > 110) + >>> # total perimeter of all objects in the image + >>> perimeter(img_coins, neighborhood=4) # doctest: +ELLIPSIS + array(7796.86799644) + >>> perimeter(img_coins, neighborhood=8) # doctest: +ELLIPSIS + array(8806.26807333) + """ + + if max_label is None: + max_label = int(labels.max()) + + # check connectivity + if connectivity is None: + connectivity = labels.ndim + + binary_image = labels > 0 + + if robust and labels_close is None: + labels_close = _find_close_labels(labels, binary_image, max_label) + + filter_weights, coefs = _get_euler_weights_and_coefs( + labels.ndim, connectivity, cp.float32 + ) + binary_image = pad(binary_image, pad_width=1, mode="constant") + image_filtered = ndi.convolve( + binary_image.view(cp.uint8), + filter_weights, + mode="constant", + cval=0, + output=cp.uint8, + ) + + if robust and labels_close.size > 0: + if props_dict is not None and "slice" in props_dict: + slices = props_dict["slice"] + else: + _, slices = regionprops_bbox_coords(labels, return_slices=True) + + min_integer_type = _get_min_integer_dtype(max_label, signed=False) + if labels.dtype != min_integer_type: + labels = labels.astype(min_integer_type) + # dilate labels by 1 pixel so we can sum with values in XF to give + # unique histogram bins for each labeled regions (as long as no labeled + # regions are within < 2 pixels from another labeled region) + labels_pad = pad(labels, pad_width=1, mode="constant") + labels_dilated = ndi.grey_dilation(labels_pad, 3, mode="constant") + + # sum the coefficients for each label to compute the Euler number + euler_number = _weighted_sum_of_filtered_image( + label_image=labels_dilated, + max_label=max_label, + image_filtered=image_filtered, + coefs=coefs, + pixels_per_thread=pixels_per_thread, + ) + euler_number = euler_number.astype(cp.int64, copy=False) + + if robust: + # recompute perimeter in isolation for each region that may be too + # close to another one + shape = labels_dilated.shape + for lab in labels_close: + sl = slices[lab - 1] + # keep boundary of 1 so object is not at 'edge' of cropped + # region (unless it is at a true image edge) + # + 2 is because labels_pad is padded, but labels was not + ld = labels_pad[ + max(sl[0].start, 0) : min(sl[0].stop + 2, shape[0]), + max(sl[1].start, 0) : min(sl[1].stop + 2, shape[1]), + ] + euler_num = regionprops_euler( + ld == lab, connectivity=connectivity, max_label=1, robust=False + ) + euler_number[lab - 1] = euler_num[0] + if props_dict is not None: + props_dict["euler_number"] = euler_number + return euler_number diff --git a/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_utils.py b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_utils.py index 3e331efc..fad8b2d9 100644 --- a/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_utils.py +++ b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_utils.py @@ -3,6 +3,9 @@ import cupy as cp from packaging.version import parse +from cucim.skimage._vendored import ndimage as ndi +from cucim.skimage.util import map_array + CUPY_GTE_13_3_0 = parse(cp.__version__) >= parse("13.3.0") # Need some default includes so uint32_t, uint64_t, etc. are defined @@ -96,3 +99,25 @@ def _unravel_loop_index( code += """ in_coord[0] = temp_idx;""" return code + + +def _reverse_label_values(label_image, max_label): + """reverses the value of all labels (keeping background value=0 the same)""" + dtype = label_image.dtype + labs = cp.asarray(tuple(range(max_label + 1)), dtype=dtype) + rev_labs = cp.asarray((0,) + tuple(range(max_label, 0, -1)), dtype=dtype) + return map_array(label_image, labs, rev_labs) + + +def _find_close_labels(labels, binary_image, max_label): + # check possibly too-close regions for which we may need to + # manually recompute the regions perimeter in isolation + labels_dilated2 = ndi.grey_dilation(labels, 5, mode="constant") + labels2 = labels_dilated2 * binary_image + rev_labels = _reverse_label_values(labels, max_label=max_label) + rev_labels = ndi.grey_dilation(rev_labels, 5, mode="constant") + rev_labels = rev_labels * binary_image + labels3 = _reverse_label_values(rev_labels, max_label=max_label) + diffs = cp.logical_or(labels != labels2, labels != labels3) + labels_close = cp.asnumpy(cp.unique(labels[diffs])) + return labels_close diff --git a/python/cucim/src/cucim/skimage/measure/tests/test_regionprops_gpu_kernels.py b/python/cucim/src/cucim/skimage/measure/tests/test_regionprops_gpu_kernels.py index 7dfa16a1..6edd499b 100644 --- a/python/cucim/src/cucim/skimage/measure/tests/test_regionprops_gpu_kernels.py +++ b/python/cucim/src/cucim/skimage/measure/tests/test_regionprops_gpu_kernels.py @@ -17,6 +17,7 @@ from cucim.skimage._vendored import ndimage as ndi from cucim.skimage.measure._regionprops import PROPS from cucim.skimage.measure._regionprops_gpu import ( + _find_close_labels, equivalent_diameter_area, ndim_2_only, need_intensity_image, @@ -29,6 +30,7 @@ regionprops_centroid_weighted, regionprops_coords, regionprops_dict, + regionprops_euler, regionprops_extent, regionprops_feret_diameter_max, regionprops_image, @@ -42,12 +44,15 @@ regionprops_moments_hu, regionprops_moments_normalized, regionprops_num_pixels, + regionprops_perimeter, + regionprops_perimeter_crofton, ) from cucim.skimage.measure._regionprops_gpu_basic_kernels import basic_deps from cucim.skimage.measure._regionprops_gpu_convex import convex_deps from cucim.skimage.measure._regionprops_gpu_intensity_kernels import ( intensity_deps, ) +from cucim.skimage.measure._regionprops_gpu_misc_kernels import misc_deps from cucim.skimage.measure._regionprops_gpu_moments_kernels import ( moment_deps, ) @@ -1009,6 +1014,126 @@ def test_centroid_weighted( allclose(centroids[:, c, d], expected[prop + f"-{d}-{c}"]) +@pytest.mark.parametrize("shape", [(256, 512), (4096, 1024)]) +@pytest.mark.parametrize("volume_fraction", [0.1, 0.25, 0.5]) +@pytest.mark.parametrize("blob_size_fraction", [0.025, 0.05, 0.1]) +@pytest.mark.parametrize("robust", [False, True]) +def test_perimeter(shape, robust, volume_fraction, blob_size_fraction): + labels = get_labels_nd( + shape, + blob_size_fraction=blob_size_fraction, + volume_fraction=volume_fraction, + ) + + max_label = int(cp.max(labels)) + if not robust: + # remove any regions that are to close for non-robust algorithm + labels_close = _find_close_labels(labels, labels > 0, max_label) + for value in labels_close: + labels[labels == value] = 0 + + # relabel to ensure sequential + labels = measure.label(labels > 0) + max_label = int(cp.max(labels)) + + max_label = int(cp.max(labels)) + values = regionprops_perimeter(labels, max_label=max_label, robust=robust) + expected = measure_cpu.regionprops_table( + cp.asnumpy(labels), properties=["perimeter"] + ) + rtol = 1e-5 + atol = 1e-5 + allclose = functools.partial(assert_allclose, rtol=rtol, atol=atol) + allclose(values, expected["perimeter"]) + + +@pytest.mark.parametrize("shape", [(256, 512), (4096, 1024)]) +@pytest.mark.parametrize("volume_fraction", [0.1, 0.25, 0.5]) +@pytest.mark.parametrize("blob_size_fraction", [0.025, 0.05, 0.1]) +@pytest.mark.parametrize("robust", [False, True]) +def test_perimeter_crofton(shape, robust, volume_fraction, blob_size_fraction): + labels = get_labels_nd( + shape, + blob_size_fraction=blob_size_fraction, + volume_fraction=volume_fraction, + ) + + max_label = int(cp.max(labels)) + if not robust: + # remove any regions that are to close for non-robust algorithm + labels_close = _find_close_labels(labels, labels > 0, max_label) + for value in labels_close: + labels[labels == value] = 0 + + # relabel to ensure sequential + labels = measure.label(labels > 0) + max_label = int(cp.max(labels)) + + values = regionprops_perimeter_crofton( + labels, max_label=max_label, robust=robust + ) + expected = measure_cpu.regionprops_table( + cp.asnumpy(labels), properties=["perimeter_crofton"] + ) + rtol = 1e-5 + atol = 1e-5 + allclose = functools.partial(assert_allclose, rtol=rtol, atol=atol) + allclose(values, expected["perimeter_crofton"]) + + +@pytest.mark.parametrize( + "shape, connectivity", + [ + ((256, 512), None), + ((256, 512), 1), + ((256, 512), 2), + ((96, 64, 48), None), + ((96, 64, 48), 1), + ((96, 64, 48), 3), + ], +) +@pytest.mark.parametrize("volume_fraction", [0.25]) +@pytest.mark.parametrize("blob_size_fraction", [0.04]) +@pytest.mark.parametrize("insert_holes", [False, True]) +@pytest.mark.parametrize("robust", [False, True]) +def test_euler( + shape, + connectivity, + robust, + volume_fraction, + blob_size_fraction, + insert_holes, +): + labels = get_labels_nd( + shape, + blob_size_fraction=blob_size_fraction, + volume_fraction=volume_fraction, + insert_holes=insert_holes, + ) + + max_label = int(cp.max(labels)) + if not robust: + # remove any regions that are to close for non-robust algorithm + labels_close = _find_close_labels(labels, labels > 0, max_label) + for value in labels_close: + labels[labels == value] = 0 + + # relabel to ensure sequential + labels = measure.label(labels > 0) + max_label = int(cp.max(labels)) + + values = regionprops_euler( + labels, connectivity, max_label=max_label, robust=robust + ) + if connectivity is None or connectivity == labels.ndim: + # regionprops_table has hard-coded connectivity, can't use it to verify + # other values + expected = measure_cpu.regionprops_table( + cp.asnumpy(labels), properties=["euler_number"] + ) + assert_array_equal(values, expected["euler_number"]) + + @pytest.mark.parametrize("ndim", [2, 3]) @pytest.mark.parametrize("num_channels", [1, 3]) @pytest.mark.parametrize( @@ -1159,6 +1284,7 @@ def test_feret_diameter_max(ndim, spacing, blob_kwargs): list(basic_deps.keys()) + list(convex_deps.keys()) + list(intensity_deps.keys()) + + list(misc_deps.keys()) + list(moment_deps.keys()) ), ) From 80199fd178323c1bd9b82ffbbbfacf93518aa3b6 Mon Sep 17 00:00:00 2001 From: Gregory Lee Date: Sun, 2 Mar 2025 20:33:30 -0500 Subject: [PATCH 14/14] add robust_perimeter flag --- .../cucim/skimage/measure/_regionprops_gpu.py | 33 +++++++++++++++---- .../_regionprops_gpu_moments_kernels.py | 1 - 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/python/cucim/src/cucim/skimage/measure/_regionprops_gpu.py b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu.py index e8555610..398f9754 100644 --- a/python/cucim/src/cucim/skimage/measure/_regionprops_gpu.py +++ b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu.py @@ -220,6 +220,7 @@ def regionprops_dict( moment_order=None, max_label=None, pixels_per_thread=16, + robust_perimeter=True, ): """Compute image properties and return them as a pandas-compatible table. @@ -259,6 +260,15 @@ def regionprops_dict( from each GPU thread. The number of adjacent pixels processed corresponds to `pixels_per_thread` and can be used as a performance tuning parameter. + robust_perimeter : bool, optional + Batch computation of perimeter and euler characteristics can give + incorrect results for perimeter pixels that are not more than 1 pixel + spacing from another label. If True, a check for this condition is + performed and any labels close to another label have their perimeter + recomputed in isolation. Doing this check results in performance + overhead so can optionally be disabled. This parameter effects the + following regionprops: {"perimeter", "perimeter_crofton", + "euler_number"}. """ supported_properties = CURRENT_PROPS_GPU | GLOBAL_PROPS properties = set(properties) @@ -583,16 +593,27 @@ def regionprops_dict( labels_mask = label_image.view("bool") else: labels_mask = label_image > 0 - labels_close = _find_close_labels( - label_image, binary_image=labels_mask, max_label=max_label - ) + if robust_perimeter: + # avoid repeatedly computing "labels_close" for + # perimeter, perimeter_crofton and euler_number regionprops + labels_close = _find_close_labels( + label_image, binary_image=labels_mask, max_label=max_label + ) + if labels_close.size > 0: + print( + f"Found {labels_close.size} regions with <=1 background " + "pixel spacing from another region. Using slower robust " + "perimeter/euler measurements for these regions." + ) + else: + labels_close = None if compute_perimeter: regionprops_perimeter( label_image, neighborhood=4, max_label=max_label, - robust=True, + robust=robust_perimeter, labels_close=labels_close, props_dict=out, ) @@ -601,7 +622,7 @@ def regionprops_dict( label_image, directions=4, max_label=max_label, - robust=True, + robust=robust_perimeter, omit_image_edges=False, labels_close=labels_close, props_dict=out, @@ -612,7 +633,7 @@ def regionprops_dict( label_image, connectivity=None, max_label=max_label, - robust=True, + robust=robust_perimeter, labels_close=labels_close, props_dict=out, ) diff --git a/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_moments_kernels.py b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_moments_kernels.py index e8a10602..a3b81a7a 100644 --- a/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_moments_kernels.py +++ b/python/cucim/src/cucim/skimage/measure/_regionprops_gpu_moments_kernels.py @@ -22,7 +22,6 @@ "regionprops_moments_normalized", ] - # Store information on which other properties a given property depends on # This information will be used by `regionprops_dict` to make sure that when # a particular property is requested any dependent properties are computed