diff --git a/doc/apg.md b/doc/apg.md
new file mode 100644
index 000000000..7e863f65c
--- /dev/null
+++ b/doc/apg.md
@@ -0,0 +1,22 @@
+# APG
+
+`micro_sam` supports three different modes for instance segmentation:
+- Automatic Mask Generation (AMG) covers the image with a grid of points. These points are used as prompts and the resulting masks are merged via non-maximum suppression (NMS) to obtain the instance segmentation. This method was introduced in the original SAM publication.
+- Automatic Instance Segmentation (AIS) uses an additional segmentation decoder, which we introduced in the `micro_sam` publication. This decoder predicts foreground probabilities as well as the normalized distances to cell centroids and boundaries. These predictions are used as input to a watershed to obtain the instances.
+- Automatic Prompt Generation (APG) is an instance segmentation approach that we introduced in [a new paper](https://openreview.net/forum?id=xFO3DFZN45). It derives point prompts from the segmentation decoder (see AIS) and merges the resulting masks via NMS.
+
+In our experiments, APG yields the best overall instance segmentation results (compared to AMG and AIS) and is competitive with CellPose-SAM, the state-of-the-art model for cell instance segmentation.
+
+The segmentation mode can be selected with the argument `mode` or `segmentation_mode` in the [CLI](#using-the-command-line-interface-cli) and [python functionality](https://computational-cell-analytics.github.io/micro-sam/micro_sam/automatic_segmentation.html). For details on how to use the different automatic segmentation modes check out the [automatic segmentation notebook](https://github.com/computational-cell-analytics/micro-sam/blob/master/notebooks/automatic_segmentation.ipynb). The code for the experiments comparing the different segmentation modes (from [the new paper](https://openreview.net/forum?id=xFO3DFZN45)) can be found [here](https://github.com/computational-cell-analytics/micro-sam/tree/master/scripts/apg_experiments).
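+
+A minimal sketch of selecting APG from python (the model choice and image path below are placeholders; adapt them to your data):
+
+```python
+from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter
+
+# APG requires a model trained with the additional segmentation decoder, e.g. 'vit_b_lm'.
+predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_lm", segmentation_mode="apg")
+segmentation = automatic_instance_segmentation(
+    predictor=predictor, segmenter=segmenter, input_path="my_image.tif", ndim=2,
+)
+```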
diff --git a/doc/cli_tools.md b/doc/cli_tools.md
index 277830d36..bf57afb2c 100644
--- a/doc/cli_tools.md
+++ b/doc/cli_tools.md
@@ -11,14 +11,16 @@ The supported CLIs can be used by
 - Running `$ micro_sam.image_series_annotator` for starting the image series annotator.
 - Running `$ micro_sam.train` for finetuning Segment Anything models on your data.
 - Running `$ micro_sam.automatic_segmentation` for automatic instance segmentation.
-    - We support all post-processing parameters for automatic instance segmentation (for both AMG and AIS).
-    - The automatic segmentation mode can be controlled by: `--mode <MODE_NAME>`, where the available choice for `MODE_NAME` is `amg` / `ais`.
+    - We support all post-processing parameters for automatic instance segmentation (for AMG, AIS and APG).
+    - The automatic segmentation mode can be controlled by `--mode <MODE_NAME>`, where the available choices for `MODE_NAME` are `amg` / `ais` / `apg`.
         - AMG is supported by both default Segment Anything models and `micro-sam` models / finetuned models.
         - AIS is supported by `micro-sam` models (or finetuned models, provided they are trained with the additional instance segmentation decoder).
+        - APG is supported by `micro-sam` models (or finetuned models, provided they are trained with the additional instance segmentation decoder).
     - If these parameters are not provided by the user, `micro-sam` makes use of the best post-processing parameters (depending on the choice of model).
     - The post-processing parameters can be changed by passing them via the CLI using `--<parameter_name> <value>`. For example, one can update the parameter values (e.g. `pred_iou_thresh`, `stability_score_thresh`, etc. - supported by AMG) using `$ micro_sam.automatic_segmentation ... --pred_iou_thresh 0.6 --stability_score_thresh 0.6 ...` (an APG example is shown at the end of this list).
     - Remember to specify the automatic segmentation mode using `--mode <MODE_NAME>` when using additional post-processing parameters.
-    - You can check details for supported parameters and their respective default values at `micro_sam/instance_segmentation.py` under the `generate` method for `AutomaticMaskGenerator` and `InstanceSegmentationWithDecoder` class.
+    - You can check details for supported parameters and their respective default values at `micro_sam/instance_segmentation.py` under the `generate` method for the `AutomaticMaskGenerator`, `InstanceSegmentationWithDecoder` and `AutomaticPromptGenerator` classes.
     - A good practice is to set `--ndim <NDIM>`, where `<NDIM>` corresponds to the number of dimensions of the input images.
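+    - As an illustration (the parameter values here are examples from `examples/automatic_segmentation.py`, not tuned recommendations), an APG run with custom post-processing parameters could look like `$ micro_sam.automatic_segmentation ... --mode apg --center_distance_threshold 0.5 --boundary_distance_threshold 0.5 --nms_threshold 0.9`.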
 - Running `$ micro_sam.evaluate` for evaluating instance segmentation.
diff --git a/doc/faq.md b/doc/faq.md
index 76eabc5f3..3b13889a0 100644
--- a/doc/faq.md
+++ b/doc/faq.md
@@ -94,7 +94,7 @@ We recommend transferring the model checkpoints to the system-level cache directory
-### 1. I have some micropscopy images. Can I use the annotator tool for segmenting them?
+### 1. I have some microscopy images. Can I use the annotator tool for segmenting them?
 Yes, you can use the annotator tool for:
 - Segmenting objects in 2d images (using automatic and/or interactive segmentation).
 - Segmenting objects in 3d volumes (using automatic and/or interactive segmentation for the entire object(s)).
@@ -214,30 +214,49 @@ You can load your finetuned model by entering the path to its checkpoint in the
 If you are using the python library or CLI you can specify this path with the `checkpoint_path` parameter.
 
-### 5. What is the background of the new AIS (Automatic Instance Segmentation) feature in `micro_sam`?
-`micro_sam` introduces a new segmentation decoder to the Segment Anything backbone, for enabling faster and accurate automatic instance segmentation, by predicting the [distances to the object center and boundary](https://github.com/constantinpape/torch-em/blob/main/torch_em/transform/label.py#L284) as well as predicting foregrund, and performing [seeded watershed-based postprocessing](https://github.com/constantinpape/torch-em/blob/main/torch_em/util/segmentation.py#L122) to obtain the instances.
+### 5. What is the background of the AIS (Automatic Instance Segmentation) feature in `micro_sam`?
+`micro_sam` introduces a new segmentation decoder to the Segment Anything backbone, enabling faster and more accurate automatic instance segmentation by predicting the [distances to the object center and boundary](https://github.com/constantinpape/torch-em/blob/main/torch_em/transform/label.py#L284) as well as the foreground, and performing [seeded watershed-based postprocessing](https://github.com/constantinpape/torch-em/blob/main/torch_em/util/segmentation.py#L122) to obtain the instances.
 
+### 6. What is the background of the new APG (Automatic Prompt Generation) feature in `micro_sam`?
+From version 1.7.0 onwards, `micro_sam` provides a new automatic instance segmentation method, called APG (automatic prompt generation). It builds on `micro_sam` by extracting prompts from the boundary and center distances predicted by the pretrained segmentation decoder. The derived prompts are passed to the prompt encoder and mask decoder, and the outputs are post-processed to obtain the instances. The method is available via the `micro_sam.automatic_segmentation` CLI (by selecting `--mode apg`) and via the python interface (by selecting `segmentation_mode="apg"`). See [APG](#apg) for details.
+
-### 6. I want to finetune only the Segment Anything model without the additional instance decoder.
-The instance segmentation decoder is optional. So you can only finetune SAM or SAM and the additional decoder. Finetuning with the decoder will increase training times, but will enable you to use AIS. See [this example](https://github.com/computational-cell-analytics/micro-sam/tree/master/examples/finetuning#example-for-model-finetuning) for finetuning with both the objectives.
+### 7. I want to finetune only the Segment Anything model without the additional instance decoder.
+The instance segmentation decoder is optional. So you can finetune either only SAM, or SAM together with the additional decoder. Finetuning with the decoder will increase training times, but will enable you to use AIS and APG. See [this example](https://github.com/computational-cell-analytics/micro-sam/tree/master/examples/finetuning#example-for-model-finetuning) for finetuning with both objectives.
 
 > NOTE: To try out the other way round (i.e. the automatic instance segmentation framework without the interactive capability, namely a UNETR: a vision transformer encoder and a convolutional decoder), you can take inspiration from this [example on LIVECell](https://github.com/constantinpape/torch-em/blob/main/experiments/vision-transformer/unetr/for_vimunet_benchmarking/run_livecell.py).
 
-### 7. I have a NVIDIA RTX 4090Ti GPU with 24GB VRAM. Can I finetune Segment Anything?
+### 8. I have an NVIDIA RTX 4090Ti GPU with 24GB VRAM. Can I finetune Segment Anything?
 Finetuning Segment Anything is possible on most consumer-grade GPU and CPU resources (though training is a lot slower on the CPU). For the mentioned resource, it should be possible to finetune a ViT Base (also abbreviated as `vit_b`) by reducing the number of objects per image to 15. This parameter has the biggest impact on the VRAM consumption and the quality of the finetuned model. You can find an overview of the resources we have tested for finetuning [here](#training-your-own-model). We also provide the convenience function `micro_sam.training.train_sam_for_configuration` that selects the best training settings for these configurations. This function is also used by the finetuning UI.
 
-### 8. I want to create a dataloader for my data, to finetune Segment Anything.
+### 9. I want to create a dataloader for my data, to finetune Segment Anything.
 Thanks to `torch-em`: a) creating PyTorch datasets and dataloaders using the python library is convenient and supported for various data formats and data structures (see the [tutorial notebook](https://github.com/constantinpape/torch-em/blob/main/notebooks/tutorial_create_dataloaders.ipynb) on how to create dataloaders using `torch-em` and the [documentation](https://github.com/constantinpape/torch-em/blob/main/doc/datasets_and_dataloaders.md) for details on creating your own datasets and dataloaders); and b) finetuning using the `napari` tool eases the aforementioned process, by allowing you to add the input parameters (paths to the directories for inputs and labels, etc.) directly in the tool. A small dataloader sketch is shown after the note below.
 
 > NOTE: If you have images with large input shapes and a sparse density of instance segmentations, we recommend using a [`sampler`](https://github.com/constantinpape/torch-em/blob/main/torch_em/data/sampler.py) for choosing the patches with valid segmentations for finetuning (see the [example](https://github.com/computational-cell-analytics/micro-sam/blob/master/finetuning/specialists/training/light_microscopy/plantseg_root_finetuning.py#L29) for the PlantSeg (Root) specialist model in `micro_sam`).
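+
+A rough sketch of a) with `torch-em` (the paths, keys, and patch shape below are placeholders, and the exact arguments may differ for your data format):
+
+```python
+import torch_em
+from torch_em.data import MinInstanceSampler
+
+# Placeholder locations; point these to your own images and labels.
+loader = torch_em.default_segmentation_loader(
+    raw_paths="data/images", raw_key="*.tif",
+    label_paths="data/labels", label_key="*.tif",
+    patch_shape=(512, 512), batch_size=1,
+    # Only sample patches that contain objects (cf. the sampler note above).
+    sampler=MinInstanceSampler(min_num_instances=2),
+)
+```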
-### 9. How can I evaluate a model I have finetuned?
+### 10. How can I evaluate a model I have finetuned?
 To validate a Segment Anything model for your data, you have different options, depending on the task you want to solve and whether you have segmentation annotations for your data.
 
 - If you don't have any annotations you will have to validate the model visually. We suggest doing this with the `micro_sam` GUI tools. You can learn how to use them in the `micro_sam` documentation.
diff --git a/doc/start_page.md b/doc/start_page.md
index b158134e7..6651bd323 100644
--- a/doc/start_page.md
+++ b/doc/start_page.md
@@ -14,7 +14,7 @@ Based on these components `micro_sam` enables fast interactive and automatic ann
 We are still working on improving and extending its functionality. The current roadmap includes:
 - Releasing more and better finetuned models for the biomedical imaging domain.
 - Integrating parameter efficient training and compressed models for efficient fine-tuning and faster inference.
-- Support for [SAM2](https://ai.meta.com/sam2/).
+- Support for [SAM2](https://ai.meta.com/sam2/) and [SAM3](https://ai.meta.com/sam3/).
 
 If you run into any problems or have questions please [open an issue](https://github.com/computational-cell-analytics/micro-sam/issues/new) or reach out via [image.sc](https://forum.image.sc/) using the tag `micro-sam`. You can follow recent updates on `micro_sam` in our [news feed](https://forum.image.sc/t/microsam-news-feed).
diff --git a/examples/automatic_segmentation.py b/examples/automatic_segmentation.py
index 449c3f37c..70fd77d62 100644
--- a/examples/automatic_segmentation.py
+++ b/examples/automatic_segmentation.py
@@ -10,7 +10,7 @@
 DATA_CACHE = os.path.join(get_cache_directory(), "sample_data")
 
 
-def livecell_automatic_segmentation(model_type, use_amg, generate_kwargs):
+def livecell_automatic_segmentation(model_type, segmentation_mode, generate_kwargs):
     """Run the automatic segmentation for an example image from the LIVECell dataset.
 
     See https://doi.org/10.1038/s41592-021-01249-6 for details on the data.
@@ -21,7 +21,7 @@ def livecell_automatic_segmentation(model_type, use_amg, generate_kwargs):
     predictor, segmenter = get_predictor_and_segmenter(
         model_type=model_type,
         checkpoint=None,  # Replace this with your custom checkpoint.
-        amg=use_amg,
+        segmentation_mode=segmentation_mode,
         is_tiled=False,  # Switch to 'True' in case you would like to perform tiling-window based prediction.
     )
@@ -42,7 +42,7 @@ def livecell_automatic_segmentation(model_type, use_amg, generate_kwargs):
     napari.run()
 
 
-def hela_automatic_segmentation(model_type, use_amg, generate_kwargs):
+def hela_automatic_segmentation(model_type, segmentation_mode, generate_kwargs):
     """Run the automatic segmentation for an example image from the Cell Tracking Challenge (HeLa 2d) dataset.
""" example_data = fetch_hela_2d_example_data(DATA_CACHE) @@ -51,7 +51,7 @@ def hela_automatic_segmentation(model_type, use_amg, generate_kwargs): predictor, segmenter = get_predictor_and_segmenter( model_type=model_type, checkpoint=None, # Replace this with your custom checkpoint. - amg=use_amg, + segmentation_mode=segmentation_mode, is_tiled=False, # Switch to 'True' in case you would like to perform tiling-window based prediction. ) @@ -72,7 +72,7 @@ def hela_automatic_segmentation(model_type, use_amg, generate_kwargs): napari.run() -def wholeslide_automatic_segmentation(model_type, use_amg, generate_kwargs): +def wholeslide_automatic_segmentation(model_type, segmentation_mode, generate_kwargs): """Run the automatic segmentation with tiling for an example whole-slide image from the NeurIPS Cell Segmentation challenge. """ @@ -82,7 +82,7 @@ def wholeslide_automatic_segmentation(model_type, use_amg, generate_kwargs): predictor, segmenter = get_predictor_and_segmenter( model_type=model_type, checkpoint=None, # Replace this with your custom checkpoint. - amg=use_amg, + segmentation_mode=segmentation_mode, is_tiled=True, ) @@ -110,19 +110,19 @@ def main(): # Whether to use: # the automatic mask generation (AMG): supported by all our models. # the automatic instance segmentation (AIS): supported by 'micro-sam' models. - use_amg = False # 'False' chooses AIS as the automatic segmentation mode. + # the automatic prompt generation (APG): supported by 'micro-sam' models. + segmentation_mode = "apg" # available choices for automatic segmentation modes are 'amg' / 'ais' / 'apg'. # Post-processing parameters for automatic segmentation. - if use_amg: # AMG parameters + if segmentation_mode == "amg": # AMG parameters generate_kwargs = { "pred_iou_thresh": 0.88, "stability_score_thresh": 0.95, "box_nms_thresh": 0.7, "crop_nms_thresh": 0.7, "min_mask_region_area": 0, - "output_mode": "binary_mask", } - else: # AIS parameters + elif segmentation_mode == "ais": # AIS parameters generate_kwargs = { "center_distance_threshold": 0.5, "boundary_distance_threshold": 0.5, @@ -130,17 +130,25 @@ def main(): "foreground_smoothing": 1.0, "distance_smoothing": 1.6, "min_size": 0, - "output_mode": "binary_mask", } + elif segmentation_mode == "apg": # APG parameters + generate_kwargs = { + "center_distance_threshold": 0.5, + "boundary_distance_threshold": 0.5, + "foreground_threshold": 0.5, + "nms_threshold": 0.9, + } + else: + raise ValueError("The selected 'segmentation_mode' is not a supported segmentation method.") # Automatic segmentation for livecell data. - livecell_automatic_segmentation(model_type, use_amg, generate_kwargs) + livecell_automatic_segmentation(model_type, segmentation_mode, generate_kwargs) # Automatic segmentation for cell tracking challenge hela data. - # hela_automatic_segmentation(model_type, use_amg, generate_kwargs) + # hela_automatic_segmentation(model_type, segmentation_mode, generate_kwargs) # Automatic segmentation for a whole slide image. 
-    # wholeslide_automatic_segmentation(model_type, use_amg, generate_kwargs)
+    # wholeslide_automatic_segmentation(model_type, segmentation_mode, generate_kwargs)
 
     # The corresponding CLI call for hela_automatic_segmentation:
diff --git a/examples/automatic_tracking.py b/examples/automatic_tracking.py
index 0b5f49df7..18fb99775 100644
--- a/examples/automatic_tracking.py
+++ b/examples/automatic_tracking.py
@@ -29,7 +29,7 @@ def example_automatic_tracking(use_finetuned_model):
     embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-ctc.zarr")
     model_type = "vit_h"
 
-    predictor, segmenter = get_predictor_and_segmenter(model_type=model_type, amg=False)
+    predictor, segmenter = get_predictor_and_segmenter(model_type=model_type, segmentation_mode="ais")
 
     masks_tracked, _ = automatic_tracking(
         predictor=predictor,
diff --git a/micro_sam/__init__.py b/micro_sam/__init__.py
index 7934d4dc9..805ca80d4 100644
--- a/micro_sam/__init__.py
+++ b/micro_sam/__init__.py
@@ -5,6 +5,7 @@
 .. include:: ../doc/cli_tools.md
 .. include:: ../doc/python_library.md
 .. include:: ../doc/finetuned_models.md
+.. include:: ../doc/apg.md
 .. include:: ../doc/data_submission.md
 .. include:: ../doc/faq.md
 .. include:: ../doc/contributing.md
diff --git a/micro_sam/automatic_segmentation.py b/micro_sam/automatic_segmentation.py
index 72703f4eb..81f6f1cdb 100644
--- a/micro_sam/automatic_segmentation.py
+++ b/micro_sam/automatic_segmentation.py
@@ -4,7 +4,7 @@
 from tqdm import tqdm
 from pathlib import Path
 from functools import partial
-from typing import Dict, List, Optional, Union, Tuple
+from typing import Dict, List, Optional, Union, Tuple, Literal
 
 import numpy as np
 import imageio.v3 as imageio
@@ -26,7 +26,7 @@ def get_predictor_and_segmenter(
     model_type: str,
     checkpoint: Optional[Union[os.PathLike, str]] = None,
     device: str = None,
-    segmentation_mode: Optional[str] = None,
+    segmentation_mode: Optional[Literal["amg", "ais", "apg"]] = None,
     is_tiled: bool = False,
     predictor=None,
     state=None,
diff --git a/micro_sam/evaluation/benchmark_datasets.py b/micro_sam/evaluation/benchmark_datasets.py
index a1679c9a5..1a590e88a 100644
--- a/micro_sam/evaluation/benchmark_datasets.py
+++ b/micro_sam/evaluation/benchmark_datasets.py
@@ -23,7 +23,9 @@
 from .inference import run_inference_with_iterative_prompting
 from .evaluation import run_evaluation_for_iterative_prompting
 from .multi_dimensional_segmentation import segment_slices_from_ground_truth
-from ..automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter
+from ..automatic_segmentation import (
+    automatic_instance_segmentation, get_predictor_and_segmenter, DEFAULT_SEGMENTATION_MODE_WITH_DECODER,
+)
 
 
 LM_2D_DATASETS = [
@@ -513,7 +515,7 @@ def _run_automatic_segmentation_per_dataset(
     ndim: Optional[int] = None,
     device: Optional[Union[torch.device, str]] = None,
     checkpoint_path: Optional[Union[os.PathLike, str]] = None,
-    run_amg: Optional[bool] = None,
+    segmentation_mode: Optional[Literal["amg", "ais", "apg"]] = "ais",
     **auto_seg_kwargs
 ):
     """Functionality to run automatic segmentation for multiple input files at once.
@@ -527,19 +529,16 @@
         ndim: The number of input dimensions.
         device: The torch device.
         checkpoint_path: The filepath where the model checkpoints are stored.
-        run_amg: Whether to run automatic segmentation in AMG mode.
+        segmentation_mode: The mode for automatic segmentation ('amg', 'ais' or 'apg'). If 'None', the best mode supported by the model is chosen.
         auto_seg_kwargs: Additional arguments for automatic segmentation parameters.
""" - # First, we check if 'run_amg' is done, whether decoder is available or not. - # Depending on that, we can set 'run_amg' to the default best automatic segmentation (i.e. AIS > AMG). - if run_amg is None or (not run_amg): # The 2nd condition checks if you want AIS and if decoder state exists or not. + if segmentation_mode is None: # The 2nd condition checks if you want AIS and if decoder state exists or not. _, state = util.get_sam_model( model_type=model_type, checkpoint_path=checkpoint_path, device=device, return_state=True ) - run_amg = ("decoder_state" not in state) + segmentation_mode = DEFAULT_SEGMENTATION_MODE_WITH_DECODER if "decoder_state" in state else "amg" - experiment_name = "AMG" if run_amg else "AIS" - fname = f"{experiment_name.lower()}_{ndim}d" + fname = f"{segmentation_mode}_{ndim}d" result_path = os.path.join(output_folder, "results", f"{fname}.csv") if os.path.exists(result_path): @@ -550,10 +549,11 @@ def _run_automatic_segmentation_per_dataset( # Get the predictor (and the additional instance segmentation decoder, if available). predictor, segmenter = get_predictor_and_segmenter( - model_type=model_type, checkpoint=checkpoint_path, device=device, amg=run_amg, is_tiled=False, + model_type=model_type, checkpoint=checkpoint_path, device=device, + segmentation_mode=segmentation_mode, is_tiled=False, ) - for image_path in tqdm(image_paths, desc=f"Run {experiment_name} in {ndim}d"): + for image_path in tqdm(image_paths, desc=f"Run {segmentation_mode} in {ndim}d"): output_path = os.path.join(prediction_dir, os.path.basename(image_path)) if os.path.exists(output_path): continue @@ -667,7 +667,8 @@ def _run_interactive_segmentation_per_dataset( def _run_benchmark_evaluation_series( - image_paths, gt_paths, model_type, output_folder, ndim, device, checkpoint_path, run_amg, evaluation_methods, + image_paths, gt_paths, model_type, output_folder, ndim, device, checkpoint_path, + segmentation_mode, evaluation_methods, ): seg_kwargs = { "image_paths": image_paths, @@ -687,11 +688,11 @@ def _run_benchmark_evaluation_series( if evaluation_methods != "interactive": # Avoid auto. seg. evaluation for 'interactive'-only run choice. # i. Run automatic segmentation method supported with the SAM model (AMG or AIS). - _run_automatic_segmentation_per_dataset(run_amg=None, **seg_kwargs) + _run_automatic_segmentation_per_dataset(segmentation_mode=None, **seg_kwargs) # ii. Run automatic mask generation (AMG). # NOTE: This would only run if the user wants to. Else by default, it is set to 'False'. - _run_automatic_segmentation_per_dataset(run_amg=run_amg, **seg_kwargs) + _run_automatic_segmentation_per_dataset(segmentation_mode=segmentation_mode, **seg_kwargs) if evaluation_methods != "automatic": # Avoid int. seg. evaluation for 'automatic'-only run choice. # b. Run interactive segmentation (supported in both 2d and 3d, wherever relevant) @@ -746,7 +747,7 @@ def run_benchmark_evaluations( model_type: str = util._DEFAULT_MODEL, output_folder: Optional[Union[str, os.PathLike]] = None, checkpoint_path: Optional[Union[str, os.PathLike]] = None, - run_amg: bool = False, + segmentation_mode: Optional[Literal["amg", "ais", "apg"]] = None, retain: Optional[List[str]] = None, evaluation_methods: Literal["all", "automatic", "interactive"] = "all", ignore_warnings: bool = False, @@ -759,7 +760,7 @@ def run_benchmark_evaluations( model_type: The model choice for SAM. output_folder: The path to directory where all outputs will be stored. 
         checkpoint_path: The checkpoint path.
-        run_amg: Whether to run automatic segmentation in AMG mode.
+        segmentation_mode: The segmentation mode. One of 'amg', 'ais', or 'apg'.
         retain: Whether to retain certain parts of the benchmark runs. By default, removes everything besides quantitative results.
             There is the choice to retain 'data', 'crops', 'automatic', or 'interactive'.
@@ -799,7 +800,7 @@
             ndim=ndim,
             device=device,
             checkpoint_path=checkpoint_path,
-            run_amg=run_amg,
+            segmentation_mode=segmentation_mode,
             evaluation_methods=evaluation_methods,
         )
@@ -814,7 +815,7 @@
             ndim=2,
             device=device,
             checkpoint_path=checkpoint_path,
-            run_amg=run_amg,
+            segmentation_mode=segmentation_mode,
             evaluation_methods=evaluation_methods,
         )
@@ -885,7 +886,7 @@ def main():
         model_type=args.model_type,
         output_folder=args.output_folder,
         checkpoint_path=args.checkpoint_path,
-        run_amg=args.amg,
+        segmentation_mode=args.segmentation_mode,
         retain=args.retain,
         evaluation_methods=args.evaluate,
         ignore_warnings=True,
diff --git a/micro_sam/instance_segmentation.py b/micro_sam/instance_segmentation.py
index 1e444a04c..33aea7653 100644
--- a/micro_sam/instance_segmentation.py
+++ b/micro_sam/instance_segmentation.py
@@ -8,7 +8,7 @@
 from abc import ABC
 from copy import deepcopy
 from collections import OrderedDict
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, Literal, List, Optional, Tuple, Union
 
 import vigra
 import numpy as np
@@ -1375,7 +1375,7 @@ def generate(
                 - 'binary_mask': Return a list of dictionaries with masks encoded as binary masks.
                 - 'instance_segmentation': Return masks merged into an instance segmentation in a single array.
                 By default, set to 'instance_segmentation'.
-            mask_threshold: The threshold for turining logits into masks in `micro_sam.inference.batched_inference`.`
+            mask_threshold: The threshold for turning logits into masks in `micro_sam.inference.batched_inference`.
             refine_with_box_prompts: Whether to refine the mask outputs with another round of box prompts derived from
                 the segmentations after point prompts.
             prompt_function: A custom function for deriving prompts from the segmentation decoder predictions.
@@ -1572,7 +1572,7 @@ def get_instance_segmentation_generator(
     predictor: SamPredictor,
     is_tiled: bool,
     decoder: Optional[torch.nn.Module] = None,
-    segmentation_mode: Optional[str] = None,
+    segmentation_mode: Optional[Literal["amg", "ais", "apg"]] = None,
     **kwargs,
 ) -> Union[AMGBase, InstanceSegmentationWithDecoder]:
     f"""Get the automatic mask generator.
diff --git a/notebooks/automatic_segmentation.ipynb b/notebooks/automatic_segmentation.ipynb
index e2549f061..3f97cc71f 100644
--- a/notebooks/automatic_segmentation.ipynb
+++ b/notebooks/automatic_segmentation.ipynb
@@ -6,12 +6,13 @@
    "source": [
     "# Automatic Instance Segmentation with Segment Anything for Microscopy\n",
     "\n",
-    "This notebook shows how to use Segment Anything (SAM) for automatic instance segmentation, using the corresponding functionality from `µsam` (Segment Anything for Microscopy). We use immunoflourescence microscopy images, abbreviated as `Covid IF` (from [Pape et al](https://doi.org/10.1002/bies.202000257)), in this notebook. The functionalities shown here should work for your (microscopy) images too!\n",
+    "This notebook shows how to use Segment Anything (SAM) for automatic instance segmentation, using the corresponding functionality from `µsam` ([Segment Anything for Microscopy](https://www.nature.com/articles/s41592-024-02580-4)). We use immunofluorescence microscopy images, abbreviated as `Covid IF` (from [Pape et al](https://doi.org/10.1002/bies.202000257)), in this notebook. The functionalities shown here should work for your (microscopy) images too!\n",
     "\n",
-    "We demonstrate using two different functionalities:\n",
+    "We demonstrate using three different functionalities:\n",
     "\n",
     "1. Automatic Mask Generation (AMG): The \"Segment Anything\" feature where positive point prompts are sampled in a grid over the entire image to perform instance segmentation on 2d images.\n",
-    "2. Automatic Instance Segmentation (AIS): A new feature introduced in `µsam` where we train an additional decoder to perform automatic instance segmentation. This method is much faster in AMG and yields better results if it is applied to data that is similar to the finetuning dataset on 2d images and volumetric (3d) data."
+    "2. Automatic Instance Segmentation (AIS): A feature introduced in `µsam` where we train an additional decoder to perform automatic instance segmentation. This method is much faster than AMG and yields better results if it is applied to data that is similar to the finetuning dataset, on 2d images and volumetric (3d) data.\n",
+    "3. Automatic Prompt Generation (APG): A new feature introduced in [`µsam++`](https://openreview.net/forum?id=xFO3DFZN45) where we derive prompts from the distance maps predicted by `µsam`'s pretrained additional decoder to perform automatic promptable instance segmentation. This method is also much faster than AMG and yields better results if it is applied to data that is similar to the finetuning dataset, on 2d images and volumetric (3d) data."
   ]
  },
 {
@@ -293,7 +294,7 @@
  },
 {
   "cell_type": "code",
-  "execution_count": 8,
+  "execution_count": null,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-04-16T15:11:55.246016Z",
@@ -305,6 +306,53 @@
    }
   },
  "outputs": [],
  "source": [
+    "def run_automatic_mask_generation(\n",
+    "    image: np.ndarray,\n",
+    "    ndim: int,\n",
+    "    checkpoint_path: Optional[Union[os.PathLike, str]] = None,\n",
+    "    model_type: str = \"vit_b\",\n",
+    "    device: Optional[Union[str, torch.device]] = None,\n",
+    "    tile_shape: Optional[Tuple[int, int]] = None,\n",
+    "    halo: Optional[Tuple[int, int]] = None,\n",
+    "):\n",
+    "    \"\"\"Automatic Mask Generation (AMG) is the automatic segmentation method offered by SAM.\n",
+    "\n",
+    "    NOTE: AMG is supported for both Segment Anything models and `µsam` models.\n",
+    "\n",
+    "    Args:\n",
+    "        image: The input image.\n",
+    "        ndim: The number of dimensions for the input data.\n",
+    "        checkpoint_path: The path to stored checkpoints.\n",
+    "        model_type: The choice of the SAM / `µsam` model.\n",
+    "        device: The device to run the model inference.\n",
+    "        tile_shape: The tile shape for tiling-based segmentation.\n",
+    "        halo: The overlap shape on each side per tile for stitching the segmented tiles.\n",
+    "\n",
+    "    Returns:\n",
+    "        The instance segmentation.\n",
+    "    \"\"\"\n",
+    "    # Step 1: Get the 'predictor' and 'segmenter' to perform automatic instance segmentation.\n",
+    "    predictor, segmenter = get_predictor_and_segmenter(\n",
+    "        model_type=model_type,  # choice of the Segment Anything model\n",
+    "        checkpoint=checkpoint_path,  # overwrite to pass your own finetuned model.\n",
+    "        device=device,  # the device to run the model inference.\n",
+    "        segmentation_mode=\"amg\",  # set the automatic segmentation mode to AMG.\n",
+    "        is_tiled=(tile_shape is not None),  # whether to run automatic segmentation with tiling.\n",
+    "    )\n",
+    "\n",
+    "    # Step 2: Get the instance segmentation for the given image.\n",
+    "    prediction = automatic_instance_segmentation(\n",
+    "        predictor=predictor,  # the predictor for the Segment Anything model.\n",
+    "        segmenter=segmenter,  # the segmenter class responsible for generating predictions.\n",
+    "        input_path=image,  # the filepath to image or the input array for automatic segmentation.\n",
+    "        ndim=ndim,  # the number of input dimensions.\n",
+    "        tile_shape=tile_shape,  # the tile shape for tiling-based prediction.\n",
+    "        halo=halo,  # the overlap shape for tiling-based prediction.\n",
+    "    )\n",
+    "\n",
+    "    return prediction\n",
+    "\n",
+    "\n",
     "def run_automatic_instance_segmentation(\n",
     "    image: np.ndarray,\n",
     "    ndim: int,\n",
@@ -335,7 +383,7 @@
     "        model_type=model_type,  # choice of the Segment Anything model\n",
     "        checkpoint=checkpoint_path,  # overwrite to pass your own finetuned model.\n",
     "        device=device,  # the device to run the model inference.\n",
-    "        amg=False,  # set the automatic segmentation mode to AIS.\n",
+    "        segmentation_mode=\"ais\",  # set the automatic segmentation mode to AIS.\n",
     "        is_tiled=(tile_shape is not None),  # whether to run automatic segmentation with tiling.\n",
     "    )\n",
@@ -352,24 +400,24 @@
     "    return prediction\n",
     "\n",
     "\n",
-    "def run_automatic_mask_generation(\n",
+    "def run_automatic_prompt_generation(\n",
     "    image: np.ndarray,\n",
     "    ndim: int,\n",
     "    checkpoint_path: Optional[Union[os.PathLike, str]] = None,\n",
-    "    model_type: str = \"vit_b\",\n",
+    "    model_type: str = \"vit_b_lm\",\n",
     "    device: Optional[Union[str, torch.device]] = None,\n",
     "    tile_shape: Optional[Tuple[int, int]] = None,\n",
     "    halo: Optional[Tuple[int, int]] = None,\n",
     "):\n",
-    "    \"\"\"Automatic Mask Generation (AMG) is the automatic segmentation method offered by SAM.\n",
+    "    \"\"\"Automatic Prompt Generation (APG) derives prompts from the additional instance decoder in `µsam`.\n",
     "\n",
-    "    NOTE: AMG is supported for both Segment Anything models and `µsam` models.\n",
+    "    NOTE: APG is supported only for `µsam` models.\n",
     "\n",
     "    Args:\n",
     "        image: The input image.\n",
     "        ndim: The number of dimensions for the input data.\n",
     "        checkpoint_path: The path to stored checkpoints.\n",
-    "        model_type: The choice of the SAM / `µsam` model.\n",
+    "        model_type: The choice of the `µsam` model.\n",
     "        device: The device to run the model inference.\n",
     "        tile_shape: The tile shape for tiling-based segmentation.\n",
     "        halo: The overlap shape on each side per tile for stitching the segmented tiles.\n",
@@ -382,7 +430,7 @@
     "        model_type=model_type,  # choice of the Segment Anything model\n",
     "        checkpoint=checkpoint_path,  # overwrite to pass your own finetuned model.\n",
     "        device=device,  # the device to run the model inference.\n",
-    "        amg=True,  # set the automatic segmentation mode to AMG.\n",
+    "        segmentation_mode=\"apg\",  # set the automatic segmentation mode to APG.\n",
     "        is_tiled=(tile_shape is not None),  # whether to run automatic segmentation with tiling.\n",
     "    )\n",
@@ -468,7 +516,46 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-    "The segmentation quality looks okay. How does it compare with the original automatic segmentation mode (called AMG, automatic mask generation) offered by the original SAM?"
+    "The segmentation quality looks okay. How does it compare with the new automatic segmentation mode (called [APG](https://openreview.net/forum?id=xFO3DFZN45), automatic prompt generation), recently introduced by `µsam`?"
   ]
  },
+ {
+  "cell_type": "markdown",
+  "metadata": {},
+  "source": [
+   "Building on `µsam`, APG is a new method for automatic instance segmentation, which derives meaningful prompts from the outputs of the additional `µsam` decoder, improving the segmentation quality by finding relevant objects with a better heuristic.\n",
+   "\n",
+   "Since the method builds on `µsam`, the offered models are directly compatible with APG. We test the smallest light microscopy (LM) `µsam` model, ViT Base (LM) (abbreviated as `vit_b_lm`), on the microscopy data."
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "model_choice = \"vit_b_lm\"\n",
+   "\n",
+   "for sample_path in sample_paths:\n",
+   "    with h5py.File(sample_path, \"r\") as f:\n",
+   "        raw = f[\"raw/serum_IgG/s0\"][:]\n",
+   "        labels = f[\"labels/cells/s0\"][:]\n",
+   "\n",
+   "    # NOTE: If you have large images, we recommend using tiling for automatic segmentation.\n",
+   "    # e.g. for a training patch size of (512, 512), you can provide the following example combination:\n",
+   "    # 'tile_shape=(384, 384), halo=(64, 64)' for running automatic segmentation over tiles.\n",
+   "    prediction = run_automatic_prompt_generation(raw, ndim=2, model_type=model_choice)\n",
+   "\n",
+   "    plot_samples(image=raw, gt=labels, segmentation=prediction)\n",
+   "\n",
+   "    break  # comment this out in case you want to visualize all images"
+  ]
+ },
+ {
+  "cell_type": "markdown",
+  "metadata": {},
+  "source": [
+   "The segmentation quality looks better. How does it compare with the original automatic segmentation mode (called AMG, automatic mask generation) offered by the original SAM?"
   ]
  },
 {
@@ -764,7 +851,7 @@
    "sourceType": "notebook"
   },
   "kernelspec": {
-   "display_name": "Python 3",
+   "display_name": "super",
    "language": "python",
    "name": "python3"
   },
@@ -778,7 +865,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.11"
+   "version": "3.11.12"
  }
 },
 "nbformat": 4,
diff --git a/test/test_training.py b/test/test_training.py
index 07aa2ff49..e78bc04a6 100644
--- a/test/test_training.py
+++ b/test/test_training.py
@@ -258,7 +258,7 @@ def test_train_instance_segmentation(self):
         self.assertTrue(os.path.exists(export_path))
 
         # Check that this model works for AIS.
-        predictor, segmenter = get_predictor_and_segmenter(model_type, export_path, amg=False)
+        predictor, segmenter = get_predictor_and_segmenter(model_type, export_path, segmentation_mode="ais")
         image_path = os.path.join(self.tmp_folder, "synthetic-data", "images", "test", "data-0.tif")
         segmentation = automatic_instance_segmentation(predictor, segmenter, image_path)
         expected_shape = imageio.imread(image_path).shape