12 changes: 9 additions & 3 deletions micro_sam/sam_annotator/_state.py
@@ -10,6 +10,7 @@
import zarr
import numpy as np

from napari.layers import Image
from qtpy.QtWidgets import QWidget

import torch.nn as nn
@@ -19,7 +20,6 @@
from micro_sam.instance_segmentation import AMGBase, get_decoder
from micro_sam.precompute_state import cache_amg_state, cache_is_state

from napari.layers import Image
from segment_anything import SamPredictor

try:
@@ -102,6 +102,7 @@ def initialize_predictor(
pbar_update=None,
skip_load=True,
use_cli=False,
is_sam2=False, # By default, we use SAM1.
):
assert ndim in (2, 3)

@@ -132,8 +133,13 @@ def progress_bar_factory(model_type):
self.image_embeddings = save_path
self.embedding_path = None # setting this to 'None' as we do not have embeddings cached.

else: # otherwise, compute the image embeddings.
self.image_embeddings = util.precompute_image_embeddings(
else: # Otherwise, compute the image embeddings.
if is_sam2:
from micro_sam.v2.util import precompute_image_embeddings as _comp_embed_fn
else:
_comp_embed_fn = util.precompute_image_embeddings

self.image_embeddings = _comp_embed_fn(
predictor=self.predictor,
input_=image_data,
save_path=save_path,
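Note: micro_sam.v2.util.precompute_image_embeddings is new in this PR and its implementation is not part of this diff. As a rough sketch of what the SAM2 path has to do for a single 2D image (the wrapper name, grayscale handling, and return convention below are assumptions): SAM2 caches its image features via SAM2ImagePredictor.set_image rather than via micro_sam.util.precompute_image_embeddings.

import numpy as np
from sam2.sam2_image_predictor import SAM2ImagePredictor

def precompute_sam2_embeddings_2d(predictor: SAM2ImagePredictor, input_: np.ndarray, save_path=None, **kwargs):
    # SAM2 expects an RGB (H, W, 3) image; stack grayscale inputs to three channels.
    image = np.stack([input_] * 3, axis=-1) if input_.ndim == 2 else input_
    # set_image runs the image encoder once and caches the features on the predictor,
    # which is the SAM2 analogue of precomputing embeddings for SAM1.
    predictor.set_image(image)
    # Return the predictor as the handle to the cached features (return convention assumed).
    return predictor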
192 changes: 163 additions & 29 deletions micro_sam/sam_annotator/_widgets.py
@@ -31,6 +31,8 @@
# from napari.qt.threading import thread_worker
from napari.utils import progress

from segment_anything import SamPredictor

from . import util as vutil
from ._tooltips import get_tooltip
from ._state import AnnotatorState
@@ -237,8 +239,11 @@ def _get_model_size_options(self):
# We store the actual model names mapped to UI labels.
self.model_size_mapping = {}
if self.model_family == "Natural Images (SAM)":
self.model_size_options = list(self._model_size_map .values())
self.model_size_options = list(self._model_size_map.values())
self.model_size_mapping = {self._model_size_map[k]: f"vit_{k}" for k in self._model_size_map.keys()}
elif self.model_family == "Natural Images (SAM2)":
self.model_size_options = list(self._model_size_map.values())
self.model_size_mapping = {self._model_size_map[k]: f"hvit_{k}" for k in self._model_size_map.keys()}
else:
model_suffix = self.supported_dropdown_maps[self.model_family]
self.model_size_options = []
@@ -278,7 +283,10 @@ def _update_model_type(self):
size_key = next(
(k for k, v in self._model_size_map.items() if v == self.model_size), "b"
)
self.model_type = f"vit_{size_key}" + self.supported_dropdown_maps[self.model_family]
if "SAM2" in self.model_family:
self.model_type = f"hvit_{size_key}"
else:
self.model_type = f"vit_{size_key}" + self.supported_dropdown_maps[self.model_family]

self.model_size_dropdown.setCurrentText(self.model_size) # Apply the selected text to the dropdown

@@ -293,6 +301,7 @@ def _create_model_section(self, default_model: str = util._DEFAULT_MODEL, create
# Create a list of supported dropdown values and map them to suffixes.
self.supported_dropdown_maps = {
"Natural Images (SAM)": "",
"Natural Images (SAM2)": "_sam2",
"Light Microscopy": "_lm",
"Electron Microscopy": "_em_organelles",
"Medical Imaging": "_medical_imaging",
@@ -343,7 +352,10 @@ def _create_model_size_section(self):

def _validate_model_type_and_custom_weights(self):
# Let's get all model combination stuff into the desired `model_type` structure.
self.model_type = "vit_" + self.model_size[0] + self.supported_dropdown_maps[self.model_family]
if "SAM2" in self.model_family:
self.model_type = "hvit_" + self.model_size[0]
else:
self.model_type = "vit_" + self.model_size[0] + self.supported_dropdown_maps[self.model_family]

# For 'custom_weights', we remove the displayed text on top of the drop-down menu.
if self.custom_weights:
@@ -1012,10 +1024,21 @@ def segment(viewer: "napari.viewer.Viewer", batched: bool = False) -> None:

predictor = AnnotatorState().predictor
image_embeddings = AnnotatorState().image_embeddings
seg = vutil.prompt_segmentation(
predictor, points, labels, boxes, masks, shape, image_embeddings=image_embeddings,
multiple_box_prompts=True, batched=batched, previous_segmentation=viewer.layers["current_object"].data,
)

if isinstance(predictor, SamPredictor): # This is the SAM1 predictor.
seg = vutil.prompt_segmentation(
predictor, points, labels, boxes, masks, shape, image_embeddings=image_embeddings,
multiple_box_prompts=True, batched=batched, previous_segmentation=viewer.layers["current_object"].data,
)
else: # Otherwise, this is a SAM2 predictor.
from micro_sam.v2.prompt_based_segmentation import promptable_segmentation_2d
seg = promptable_segmentation_2d(
predictor=predictor,
points=points,
labels=labels,
boxes=boxes,
masks=masks,
)

# no prompts were given or prompts were invalid, skip segmentation
if seg is None:
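Note: promptable_segmentation_2d is introduced by this PR and its body is not shown in this diff. For orientation, the underlying sam2 image-predictor call such a helper typically wraps is sketched below; the wrapper function and its return convention are assumptions, only the SAM2ImagePredictor API itself is standard.

import numpy as np
from sam2.sam2_image_predictor import SAM2ImagePredictor

def sam2_segment_from_prompts_sketch(predictor: SAM2ImagePredictor, image, points=None, labels=None, box=None):
    # Encode the image (skipped in practice if the features were already precomputed).
    predictor.set_image(image)
    masks, scores, _ = predictor.predict(
        point_coords=points,     # (N, 2) array of (x, y) coordinates, or None
        point_labels=labels,     # (N,) array with 1 = positive, 0 = negative, or None
        box=box,                 # (4,) array in (x0, y0, x1, y1) format, or None
        multimask_output=False,  # return a single mask for the given prompt set
    )
    # Keep the highest-scoring mask as a boolean segmentation.
    return masks[np.argmax(scores)] > 0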
@@ -1053,10 +1076,37 @@ def segment_slice(viewer: "napari.viewer.Viewer") -> None:
points, labels = point_prompts

state = AnnotatorState()
seg = vutil.prompt_segmentation(
state.predictor, points, labels, boxes, masks, shape, multiple_box_prompts=False,
image_embeddings=state.image_embeddings, i=z,
)

# Check if using SAM1 or SAM2
if isinstance(state.predictor, SamPredictor):
# SAM1 path (existing code)
seg = vutil.prompt_segmentation(
state.predictor, points, labels, boxes, masks, shape, multiple_box_prompts=False,
image_embeddings=state.image_embeddings, i=z,
)
else:
# SAM2 path - use PromptableSegmentation3D class
from micro_sam.v2.prompt_based_segmentation import PromptableSegmentation3D

# Get the volume from the viewer
image_name = state.get_image_name(viewer)
volume = viewer.layers[image_name].data

# Create a PromptableSegmentation3D instance with existing inference state
seg_handler = PromptableSegmentation3D(
predictor=state.predictor,
volume=volume,
)

# Use the segment_slice method
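# Coordinate note: napari stores points as (y, x) and boxes as (y0, x0, y1, x1); the box
# reindexing below and the point flip convert them to the (x, y)-based order SAM2 expects.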
boxes = [box[[1, 0, 3, 2]] for box in boxes]
seg = seg_handler.segment_slice(
frame_idx=z,
points=points[:, ::-1].copy(),
labels=labels,
boxes=boxes,
masks=masks
)

# no prompts were given or prompts were invalid, skip segmentation
if seg is None:
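Note: PromptableSegmentation3D also comes from this PR and is not shown here. The sam2 video-predictor primitives such a class would typically build on are sketched below; the config/checkpoint names are illustrative only, and the frame-loading step is an assumption, since sam2's init_state expects a video file or a folder of frames rather than an in-memory volume.

import numpy as np
import torch
from sam2.build_sam import build_sam2_video_predictor

# Build the video predictor (illustrative config/checkpoint paths).
predictor = build_sam2_video_predictor(
    "configs/sam2.1/sam2.1_hiera_b+.yaml", "./checkpoints/sam2.1_hiera_base_plus.pt"
)

# init_state expects a video path or a directory of frame images, so an in-memory
# numpy volume would first need to be written out as frames (assumption).
inference_state = predictor.init_state(video_path="./volume_as_frames")

# Register a positive point prompt for object 1 on slice/frame 12, in (x, y) order.
predictor.add_new_points_or_box(
    inference_state=inference_state, frame_idx=12, obj_id=1,
    points=np.array([[128.0, 96.0]], dtype=np.float32),
    labels=np.array([1], dtype=np.int32),
)

# Propagate the prompt through all frames and threshold the logits into boolean masks.
with torch.inference_mode():
    for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(inference_state):
        masks = (mask_logits > 0.0).cpu().numpy()  # (num_objects, 1, H, W) per frame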
@@ -1449,11 +1499,35 @@ def pbar_init(total, description):
if self.automatic_segmentation_mode == "amg":
prefer_decoder = False

# Define a predictor for SAM2 models.
predictor = None
if self.model_type.startswith("h"): # i.e. SAM2 models.
from micro_sam.v2.util import get_sam2_model

if ndim == 2: # Get the SAM2 model and prepare the image predictor.
model = get_sam2_model(model_type=self.model_type, input_type="images")
# Prepare the SAM2 predictor.
from sam2.sam2_image_predictor import SAM2ImagePredictor
predictor = SAM2ImagePredictor(model)
elif ndim == 3: # Get SAM2 video predictor
predictor = get_sam2_model(model_type=self.model_type, input_type="videos")
else:
raise ValueError(f"Invalid number of dimensions: {ndim}. Expected 2 or 3.")

state.initialize_predictor(
image_data, model_type=self.model_type, save_path=save_path, ndim=ndim,
device=self.device, checkpoint_path=self.custom_weights, tile_shape=tile_shape, halo=halo,
prefer_decoder=prefer_decoder, pbar_init=pbar_init,
image_data,
model_type=self.model_type,
save_path=save_path,
ndim=ndim,
device=self.device,
checkpoint_path=self.custom_weights,
predictor=predictor,
tile_shape=tile_shape,
halo=halo,
prefer_decoder=prefer_decoder,
pbar_init=pbar_init,
pbar_update=lambda update: pbar_signals.pbar_update.emit(update),
is_sam2=self.model_type.startswith("h"),
)
pbar_signals.pbar_stop.emit()
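Putting this hunk together, the 2D SAM2 flow amounts to the following usage sketch (get_sam2_model and the "hvit_b" identifier come from this PR; the placeholder image stands in for the napari layer data, and default values are assumed for the remaining initialize_predictor arguments):

import numpy as np
from sam2.sam2_image_predictor import SAM2ImagePredictor
from micro_sam.v2.util import get_sam2_model
from micro_sam.sam_annotator._state import AnnotatorState

image_data = np.random.randint(0, 255, (512, 512), dtype="uint8")  # placeholder 2D image

# Build the SAM2 image model and wrap it in the image predictor, as done above for ndim == 2.
model = get_sam2_model(model_type="hvit_b", input_type="images")
predictor = SAM2ImagePredictor(model)

# Hand the prebuilt predictor to the annotator state; is_sam2=True routes the embedding
# computation to micro_sam.v2.util.precompute_image_embeddings (see _state.py above).
state = AnnotatorState()
state.initialize_predictor(
    image_data, model_type="hvit_b", ndim=2, predictor=predictor, is_sam2=True,
)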

@@ -1536,6 +1610,14 @@ def _create_settings(self):
)
setting_values.layout().addLayout(layout)

# Create the UI element in the form of a checkbox for multi-object segmentation.
self.batched = False
setting_values.layout().addWidget(
self._add_boolean_param(
"batched", self.batched, title="batched", tooltip=get_tooltip("segmentnd", "batched")
)
)

# Create the UI element for the motion smoothing (if we have the tracking widget).
if self.tracking:
self.motion_smoothing = 0.5
@@ -1611,24 +1693,76 @@ def volumetric_segmentation_impl():
pbar_signals.pbar_total.emit(shape[0])
pbar_signals.pbar_description.emit("Segment object")

# Step 1: Segment all slices with prompts.
seg, slices, stop_lower, stop_upper = vutil.segment_slices_with_prompts(
state.predictor, self._viewer.layers["point_prompts"], self._viewer.layers["prompts"],
state.image_embeddings, shape,
update_progress=lambda update: pbar_signals.pbar_update.emit(update),
)
if isinstance(state.predictor, SamPredictor): # This is the SAM1 predictor.
# Step 1: Segment all slices with prompts.
seg, slices, stop_lower, stop_upper = vutil.segment_slices_with_prompts(
state.predictor, self._viewer.layers["point_prompts"], self._viewer.layers["prompts"],
state.image_embeddings, shape,
update_progress=lambda update: pbar_signals.pbar_update.emit(update),
)

# Step 2: Segment the rest of the volume based on projecting prompts.
seg, (z_min, z_max) = segment_mask_in_volume(
seg, state.predictor, state.image_embeddings, slices,
stop_lower, stop_upper,
iou_threshold=self.iou_threshold, projection=self.projection,
box_extension=self.box_extension,
update_progress=lambda update: pbar_signals.pbar_update.emit(update),
)

state.z_range = (z_min, z_max)

else: # Otherwise, this is a SAM2 predictor.
# Prepare the prompts
point_prompts = self._viewer.layers["point_prompts"]
box_prompts = self._viewer.layers["prompts"]
z_values_points = np.round(point_prompts.data[:, 0])
z_values_boxes = np.concatenate(
[box[:1, 0] for box in box_prompts.data]
) if box_prompts.data else np.zeros(0, dtype="int")

# Get the volumetric data.
# TODO: We need to switch to volume embeddings later.
volume = self._viewer.layers[0].data # Assumes that the image layer is at index 0.

# NOTE: Prototype for new design of prompting in volumetric data.
# CP: this looks redundant. We redo the initialization each time a prompt is added.
from micro_sam.v2.prompt_based_segmentation import PromptableSegmentation3D
segmenter = PromptableSegmentation3D(predictor=state.predictor, volume=volume)

# Whether the user decided to provide batched prompts for multi-object segmentation.
is_batched = bool(self.batched)

# Let's do points first.
for curr_z_values_point in z_values_points:
# Extract the point prompts from the points layer first.
points, labels = vutil.point_layer_to_prompts(layer=point_prompts, i=curr_z_values_point)

# Add prompts one after the other.
for i, (curr_point, curr_label) in enumerate(zip(points, labels), start=1):
    segmenter.add_point_prompts(
        frame_ids=curr_z_values_point,
        points=np.array([curr_point]),
        point_labels=np.array([curr_label]),
        object_id=i if is_batched else None,
    )

# Next, we add box prompts.
for curr_z_values_box in z_values_boxes:
# Extract the box prompts from the shapes layer first.
boxes, _ = vutil.shape_layer_to_prompts(
layer=box_prompts, shape=state.image_shape, i=curr_z_values_box,
)

# Add prompts one after the other.
segmenter.add_box_prompts(frame_ids=curr_z_values_box, boxes=boxes)

# Propagate the prompts throughout the volume and combine the propagated segmentations.
seg = segmenter.predict()

# Step 2: Segment the rest of the volume based on projecting prompts.
seg, (z_min, z_max) = segment_mask_in_volume(
seg, state.predictor, state.image_embeddings, slices,
stop_lower, stop_upper,
iou_threshold=self.iou_threshold, projection=self.projection,
box_extension=self.box_extension,
update_progress=lambda update: pbar_signals.pbar_update.emit(update),
)
pbar_signals.pbar_stop.emit()

state.z_range = (z_min, z_max)
return seg

def update_segmentation(seg):
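Note: segmenter.predict() (not shown in this diff) has to return a single label volume to the annotator. Below is a sketch of how per-object masks from a SAM2-style propagation could be merged into such a volume; the merging rule (later object ids overwrite earlier ones) is an assumption, not code from this PR.

import numpy as np

def merge_object_masks(object_masks, volume_shape):
    # Merge {object_id: boolean (Z, Y, X) mask} into one integer label volume.
    segmentation = np.zeros(volume_shape, dtype="uint32")
    for object_id, mask in sorted(object_masks.items()):
        segmentation[mask] = object_id  # later object ids overwrite earlier ones on overlap
    return segmentation

# Example with two toy objects in a tiny 4 x 8 x 8 volume.
masks = {1: np.zeros((4, 8, 8), dtype=bool), 2: np.zeros((4, 8, 8), dtype=bool)}
masks[1][1:3, 2:5, 2:5] = True
masks[2][2:4, 5:8, 5:8] = True
print(np.unique(merge_object_masks(masks, (4, 8, 8))))  # -> [0 1 2]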
3 changes: 2 additions & 1 deletion micro_sam/sam_annotator/util.py
@@ -693,7 +693,8 @@ def _sync_embedding_widget(widget, model_type, save_path, checkpoint_path, devic

# Update the index for model size, e.g. 'base', 'tiny', etc.
size_map = {"t": "tiny", "b": "base", "l": "large", "h": "huge"}
model_size = size_map[model_type[4]]
size_idx = 5 if model_type.startswith("h") else 4
model_size = size_map[model_type[size_idx]]

index = widget.model_size_dropdown.findText(model_size)
if index > 0:
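A quick worked example of the updated size lookup, showing why the character index shifts by one for the new "hvit_" identifiers ("vit_b" and "vit_t_lm" are existing micro_sam model names; "hvit_l" follows the SAM2 naming introduced in this PR):

size_map = {"t": "tiny", "b": "base", "l": "large", "h": "huge"}
for model_type in ("vit_b", "vit_t_lm", "hvit_l"):
    size_idx = 5 if model_type.startswith("h") else 4  # "hvit_" is one character longer than "vit_"
    print(model_type, "->", size_map[model_type[size_idx]])
# vit_b -> base, vit_t_lm -> tiny, hvit_l -> large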
Empty file added micro_sam/v2/__init__.py
Empty file.
Empty file added micro_sam/v2/models/__init__.py
Empty file.