diff --git a/monai/deploy/operators/dicom_seg_writer_operator.py b/monai/deploy/operators/dicom_seg_writer_operator.py index f5bc5ae4..94e1a6e5 100644 --- a/monai/deploy/operators/dicom_seg_writer_operator.py +++ b/monai/deploy/operators/dicom_seg_writer_operator.py @@ -185,6 +185,8 @@ def __init__( output_folder: Path, custom_tags: Optional[Dict[str, str]] = None, omit_empty_frames: bool = True, + force_contiguous_labels: bool = False, + label_mapping_dict: Optional[Dict[str, int]] = None, **kwargs, ): """Instantiates the DICOM Seg Writer instance with optional list of segment label strings. @@ -211,12 +213,22 @@ def __init__( omit_empty_frames: bool, optional Whether to omit frames that contain no segmented pixels from the output segmentation. Defaults to True, same as the underlying lib API. + force_contiguous_labels: bool + Based on the input image or algorithm results some structures may be undetected. Similarly, + internal algorithm labeling may be non-contiguous. This flag ensures that before writing the DICOM SEG + SegmentDescription contains contiguous SegmentNumbers and SegmentLabels + label_mapping_dict: Optional[Dict[str, int]] + If force_contiguous_labels is True and label_mapping_dict provides full set of possible lables and their + respective numerical values in segmentation, if any label is missing, SegmentDescriptions and numerical + values will be consolidated to contiguous set """ self._seg_descs = [sd.to_segment_description(n) for n, sd in enumerate(segment_descriptions, 1)] self._custom_tags = custom_tags self._omit_empty_frames = omit_empty_frames self.output_folder = output_folder if output_folder else DICOMSegmentationWriterOperator.DEFAULT_OUTPUT_FOLDER + self.force_contiguous_labels = force_contiguous_labels + self.label_mapping_dict = label_mapping_dict self.input_name_seg = "seg_image" self.input_name_series = "study_selected_series_list" @@ -291,6 +303,12 @@ def process_images( elif not isinstance(image, np.ndarray): raise ValueError("'image' is not a numpy array, Image object, or supported image file.") + if self.force_contiguous_labels: + if self.label_mapping_dict is not None: + seg_image_numpy, self._seg_descs = self._consolidate_contiguous_labels(seg_image_numpy, self._seg_descs) + else: + raise ValueError("force_contiguous_labels is True, but label_mapping_dict was not provided.") + # Pick DICOM Series that was used as input for getting the seg image. # For now, first one in the list. for study_selected_series in study_selected_series_list: @@ -301,6 +319,41 @@ def process_images( self.create_dicom_seg(seg_image_numpy, dicom_series, output_dir) break + def _consolidate_contiguous_labels(self, image: np.ndarray, segment_descriptions: List[SegmentDescription]): + # Get all inferred classes + inferred_classes = np.unique(image) + renumber_segment_descriptions = False + updated_segment_descriptions = [] + + # Go through all segment descriptions, look if any segment is not in inferred classes and if yes remove that + # segment from segment descriptions + for segment in segment_descriptions: + label_int = self.label_mapping_dict[segment.segment_label] + if label_int in inferred_classes: + updated_segment_descriptions.append(segment) + else: + renumber_segment_descriptions = True + + segment_descriptions = updated_segment_descriptions + + # Renumber the segment_numbers if needed + if renumber_segment_descriptions: + for n, segment in enumerate(segment_descriptions, 1): + segment.SegmentNumber = n + + # In case segment mask numbers are no contiguous, ensure their numpy values match the segment numbers + if image.max() != (len(inferred_classes) - 1): # do not count background - 0 - as class + # work rather on new array to avoid accidentally overwriting some values + remapped_image = np.zeros_like(image) + for segment in segment_descriptions: + label_int = self.label_mapping_dict[segment.segment_label] + mask = (image == label_int) + remapped_image[mask] = segment.SegmentNumber + + image = remapped_image + + return image, segment_descriptions + def create_dicom_seg(self, image: np.ndarray, dicom_series: DICOMSeries, output_dir: Path): # Generate SOP instance UID, and use it as dcm file name too seg_sop_instance_uid = hd.UID() # generate_uid() can be used too.