diff --git a/micro_sam/sam_annotator/_state.py b/micro_sam/sam_annotator/_state.py index c15a1fcf..eefb07b8 100644 --- a/micro_sam/sam_annotator/_state.py +++ b/micro_sam/sam_annotator/_state.py @@ -10,6 +10,7 @@ import zarr import numpy as np +from napari.layers import Image from qtpy.QtWidgets import QWidget import torch.nn as nn @@ -19,7 +20,6 @@ from micro_sam.instance_segmentation import AMGBase, get_decoder from micro_sam.precompute_state import cache_amg_state, cache_is_state -from napari.layers import Image from segment_anything import SamPredictor try: @@ -102,6 +102,7 @@ def initialize_predictor( pbar_update=None, skip_load=True, use_cli=False, + is_sam2=False, # By default, we use SAM1. ): assert ndim in (2, 3) @@ -132,8 +133,13 @@ def progress_bar_factory(model_type): self.image_embeddings = save_path self.embedding_path = None # setting this to 'None' as we do not have embeddings cached. - else: # otherwise, compute the image embeddings. - self.image_embeddings = util.precompute_image_embeddings( + else: # Otherwise, compute the image embeddings. + if is_sam2: + from micro_sam.v2.util import precompute_image_embeddings as _comp_embed_fn + else: + _comp_embed_fn = util.precompute_image_embeddings + + self.image_embeddings = _comp_embed_fn( predictor=self.predictor, input_=image_data, save_path=save_path, diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index 9ec2d542..fb9e1b76 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -31,6 +31,8 @@ # from napari.qt.threading import thread_worker from napari.utils import progress +from segment_anything import SamPredictor + from . import util as vutil from ._tooltips import get_tooltip from ._state import AnnotatorState @@ -237,8 +239,11 @@ def _get_model_size_options(self): # We store the actual model names mapped to UI labels. self.model_size_mapping = {} if self.model_family == "Natural Images (SAM)": - self.model_size_options = list(self._model_size_map .values()) + self.model_size_options = list(self._model_size_map.values()) self.model_size_mapping = {self._model_size_map[k]: f"vit_{k}" for k in self._model_size_map.keys()} + elif self.model_family == "Natural Images (SAM2)": + self.model_size_options = list(self._model_size_map.values()) + self.model_size_mapping = {self._model_size_map[k]: f"hvit_{k}" for k in self._model_size_map.keys()} else: model_suffix = self.supported_dropdown_maps[self.model_family] self.model_size_options = [] @@ -278,7 +283,10 @@ def _update_model_type(self): size_key = next( (k for k, v in self._model_size_map.items() if v == self.model_size), "b" ) - self.model_type = f"vit_{size_key}" + self.supported_dropdown_maps[self.model_family] + if "SAM2" in self.model_family: + self.model_type = f"hvit_{size_key}" + else: + self.model_type = f"vit_{size_key}" + self.supported_dropdown_maps[self.model_family] self.model_size_dropdown.setCurrentText(self.model_size) # Apply the selected text to the dropdown @@ -293,6 +301,7 @@ def _create_model_section(self, default_model: str = util._DEFAULT_MODEL, create # Create a list of support dropdown values and correspond them to suffixes. 
self.supported_dropdown_maps = { "Natural Images (SAM)": "", + "Natural Images (SAM2)": "_sam2", "Light Microscopy": "_lm", "Electron Microscopy": "_em_organelles", "Medical Imaging": "_medical_imaging", @@ -343,7 +352,10 @@ def _create_model_size_section(self): def _validate_model_type_and_custom_weights(self): # Let's get all model combination stuff into the desired `model_type` structure. - self.model_type = "vit_" + self.model_size[0] + self.supported_dropdown_maps[self.model_family] + if "SAM2" in self.model_family: + self.model_type = "hvit_" + self.model_size[0] + else: + self.model_type = "vit_" + self.model_size[0] + self.supported_dropdown_maps[self.model_family] # For 'custom_weights', we remove the displayed text on top of the drop-down menu. if self.custom_weights: @@ -1012,10 +1024,21 @@ def segment(viewer: "napari.viewer.Viewer", batched: bool = False) -> None: predictor = AnnotatorState().predictor image_embeddings = AnnotatorState().image_embeddings - seg = vutil.prompt_segmentation( - predictor, points, labels, boxes, masks, shape, image_embeddings=image_embeddings, - multiple_box_prompts=True, batched=batched, previous_segmentation=viewer.layers["current_object"].data, - ) + + if isinstance(predictor, SamPredictor): # This is SAM1 predictor. + seg = vutil.prompt_segmentation( + predictor, points, labels, boxes, masks, shape, image_embeddings=image_embeddings, + multiple_box_prompts=True, batched=batched, previous_segmentation=viewer.layers["current_object"].data, + ) + else: # This would be SAM2 predictors. + from micro_sam.v2.prompt_based_segmentation import promptable_segmentation_2d + seg = promptable_segmentation_2d( + predictor=predictor, + points=points, + labels=labels, + boxes=boxes, + masks=masks, + ) # no prompts were given or prompts were invalid, skip segmentation if seg is None: @@ -1053,10 +1076,37 @@ def segment_slice(viewer: "napari.viewer.Viewer") -> None: points, labels = point_prompts state = AnnotatorState() - seg = vutil.prompt_segmentation( - state.predictor, points, labels, boxes, masks, shape, multiple_box_prompts=False, - image_embeddings=state.image_embeddings, i=z, - ) + + # Check if using SAM1 or SAM2 + if isinstance(state.predictor, SamPredictor): + # SAM1 path (existing code) + seg = vutil.prompt_segmentation( + state.predictor, points, labels, boxes, masks, shape, multiple_box_prompts=False, + image_embeddings=state.image_embeddings, i=z, + ) + else: + # SAM2 path - use PromptableSegmentation3D class + from micro_sam.v2.prompt_based_segmentation import PromptableSegmentation3D + + # Get the volume from the viewer + image_name = state.get_image_name(viewer) + volume = viewer.layers[image_name].data + + # Create a PromptableSegmentation3D instance with existing inference state + seg_handler = PromptableSegmentation3D( + predictor=state.predictor, + volume=volume, + ) + + # Use the segment_slice method + boxes = [box[[1, 0, 3, 2]] for box in boxes] + seg = seg_handler.segment_slice( + frame_idx=z, + points=points[:, ::-1].copy(), + labels=labels, + boxes=boxes, + masks=masks + ) # no prompts were given or prompts were invalid, skip segmentation if seg is None: @@ -1449,11 +1499,35 @@ def pbar_init(total, description): if self.automatic_segmentation_mode == "amg": prefer_decoder = False + # Define a predictor for SAM2 models. + predictor = None + if self.model_type.startswith("h"): # i.e. SAM2 models. + from micro_sam.v2.util import get_sam2_model + + if ndim == 2: # Get the SAM2 model and prepare the image predictor. 
+ model = get_sam2_model(model_type=self.model_type, input_type="images") + # Prepare the SAM2 predictor. + from sam2.sam2_image_predictor import SAM2ImagePredictor + predictor = SAM2ImagePredictor(model) + elif ndim == 3: # Get SAM2 video predictor + predictor = get_sam2_model(model_type=self.model_type, input_type="videos") + else: + raise ValueError + state.initialize_predictor( - image_data, model_type=self.model_type, save_path=save_path, ndim=ndim, - device=self.device, checkpoint_path=self.custom_weights, tile_shape=tile_shape, halo=halo, - prefer_decoder=prefer_decoder, pbar_init=pbar_init, + image_data, + model_type=self.model_type, + save_path=save_path, + ndim=ndim, + device=self.device, + checkpoint_path=self.custom_weights, + predictor=predictor, + tile_shape=tile_shape, + halo=halo, + prefer_decoder=prefer_decoder, + pbar_init=pbar_init, pbar_update=lambda update: pbar_signals.pbar_update.emit(update), + is_sam2=self.model_type.startswith("h"), ) pbar_signals.pbar_stop.emit() @@ -1536,6 +1610,14 @@ def _create_settings(self): ) setting_values.layout().addLayout(layout) + # Create the UI element in form of a checkbox for multi-object segmentation. + self.batched = False + setting_values.layout().addWidget( + self._add_boolean_param( + "batched", self.batched, title="batched", tooltip=get_tooltip("segmentnd", "batched") + ) + ) + # Create the UI element for the motion smoothing (if we have the tracking widget). if self.tracking: self.motion_smoothing = 0.5 @@ -1611,24 +1693,76 @@ def volumetric_segmentation_impl(): pbar_signals.pbar_total.emit(shape[0]) pbar_signals.pbar_description.emit("Segment object") - # Step 1: Segment all slices with prompts. - seg, slices, stop_lower, stop_upper = vutil.segment_slices_with_prompts( - state.predictor, self._viewer.layers["point_prompts"], self._viewer.layers["prompts"], - state.image_embeddings, shape, - update_progress=lambda update: pbar_signals.pbar_update.emit(update), - ) + if isinstance(state.predictor, SamPredictor): # This is SAM2 predictor. + # Step 1: Segment all slices with prompts. + seg, slices, stop_lower, stop_upper = vutil.segment_slices_with_prompts( + state.predictor, self._viewer.layers["point_prompts"], self._viewer.layers["prompts"], + state.image_embeddings, shape, + update_progress=lambda update: pbar_signals.pbar_update.emit(update), + ) + + # Step 2: Segment the rest of the volume based on projecting prompts. + seg, (z_min, z_max) = segment_mask_in_volume( + seg, state.predictor, state.image_embeddings, slices, + stop_lower, stop_upper, + iou_threshold=self.iou_threshold, projection=self.projection, + box_extension=self.box_extension, + update_progress=lambda update: pbar_signals.pbar_update.emit(update), + ) + + state.z_range = (z_min, z_max) + + else: # This would be SAM2 predictors. + # Prepare the prompts + point_prompts = self._viewer.layers["point_prompts"] + box_prompts = self._viewer.layers["prompts"] + z_values_points = np.round(point_prompts.data[:, 0]) + z_values_boxes = np.concatenate( + [box[:1, 0] for box in box_prompts.data] + ) if box_prompts.data else np.zeros(0, dtype="int") + + # Get the volumetric data. + # TODO: We need to switch later to volume embeddings. + volume = self._viewer.layers[0].data # Assumption is image is in the first index. + + # NOTE: Prototype for new design of prompting in volumetric data. + # CP: this looks redundant. We redo the initialization each time a prompt is added. 
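# Editorial sketch, not part of the patch: one way to address the "CP: this looks redundant" note
# above is to cache the segmenter on the annotator state and only rebuild it when the volume changes.
# `state.sam2_segmenter` and `state.sam2_volume_id` are hypothetical attributes, not existing
# fields of AnnotatorState.
#
#     if getattr(state, "sam2_volume_id", None) != id(volume):
#         from micro_sam.v2.prompt_based_segmentation import PromptableSegmentation3D
#         state.sam2_segmenter = PromptableSegmentation3D(predictor=state.predictor, volume=volume)
#         state.sam2_volume_id = id(volume)
#     segmenter = state.sam2_segmenter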
+ from micro_sam.v2.prompt_based_segmentation import PromptableSegmentation3D + segmenter = PromptableSegmentation3D(predictor=state.predictor, volume=volume) + + # Whether the user decide to provide batched prompts for multi-object segmentation. + is_batched = bool(self.batched) + + # Let's do points first. + for curr_z_values_point in z_values_points: + # Extract the point prompts from the points layer first. + points, labels = vutil.point_layer_to_prompts(layer=point_prompts, i=curr_z_values_point) + + # Add prompts one after the other. + [ + segmenter.add_point_prompts( + frame_ids=curr_z_values_point, + points=np.array([curr_point]), + point_labels=np.array([curr_label]), + object_id=i if is_batched else None, + ) for i, (curr_point, curr_label) in enumerate(zip(points, labels), start=1) + ] + + # Next, we add box prompts. + for curr_z_values_box in z_values_boxes: + # Extract the box prompts from the shapes layer first. + boxes, _ = vutil.shape_layer_to_prompts( + layer=box_prompts, shape=state.image_shape, i=curr_z_values_box, + ) + + # Add prompts one after the other. + segmenter.add_box_prompts(frame_ids=curr_z_values_box, boxes=boxes) + + # Propagate the prompts throughout the volume and combine the propagated segmentations. + seg = segmenter.predict() - # Step 2: Segment the rest of the volume based on projecting prompts. - seg, (z_min, z_max) = segment_mask_in_volume( - seg, state.predictor, state.image_embeddings, slices, - stop_lower, stop_upper, - iou_threshold=self.iou_threshold, projection=self.projection, - box_extension=self.box_extension, - update_progress=lambda update: pbar_signals.pbar_update.emit(update), - ) pbar_signals.pbar_stop.emit() - state.z_range = (z_min, z_max) return seg def update_segmentation(seg): diff --git a/micro_sam/sam_annotator/util.py b/micro_sam/sam_annotator/util.py index 85096fb1..7c1bea4c 100644 --- a/micro_sam/sam_annotator/util.py +++ b/micro_sam/sam_annotator/util.py @@ -693,7 +693,8 @@ def _sync_embedding_widget(widget, model_type, save_path, checkpoint_path, devic # Update the index for model size, eg. 'base', 'tiny', etc. 
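# Editorial sketch, not part of the patch: the fixed-index parsing below is easy to break if more
# model families are added. A small regex handles both "vit_b_lm"-style and "hvit_b"-style names;
# `_parse_size_letter` is a hypothetical helper, not something defined in this module.
#
#     import re
#     def _parse_size_letter(model_type: str) -> str:
#         match = re.match(r"h?vit_([tsblh])", model_type)
#         if match is None:
#             raise ValueError(f"Unexpected model type: {model_type}")
#         return match.group(1)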
size_map = {"t": "tiny", "b": "base", "l": "large", "h": "huge"} - model_size = size_map[model_type[4]] + size_idx = 5 if model_type.startswith("h") else 4 + model_size = size_map[model_type[size_idx]] index = widget.model_size_dropdown.findText(model_size) if index > 0: diff --git a/micro_sam/v2/__init__.py b/micro_sam/v2/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/micro_sam/v2/models/__init__.py b/micro_sam/v2/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/micro_sam/v2/models/_video_predictor.py b/micro_sam/v2/models/_video_predictor.py new file mode 100644 index 00000000..fa9969a6 --- /dev/null +++ b/micro_sam/v2/models/_video_predictor.py @@ -0,0 +1,232 @@ +import os +from tqdm import tqdm +from typing import Optional +from collections import OrderedDict + +import numpy as np +from PIL import Image +from skimage.transform import resize + +import torch + +from sam2.build_sam import _load_checkpoint +from sam2.sam2_video_predictor import SAM2VideoPredictor +from sam2.utils.misc import AsyncVideoFrameLoader + + +def _load_img_as_tensor(img_path, image_size): + if isinstance(img_path, str): + img_pil = Image.open(img_path) + img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size))) + video_width, video_height = img_pil.size # the original video size + else: + img_np = img_path + img_np = np.stack([img_np] * 3, axis=-1) if img_np.ndim == 2 else img_np # Make it in RGB style. + img_np = resize( + img_np, + output_shape=(image_size, image_size, 3), + order=0, + anti_aliasing=False, + preserve_range=True, + ).astype(img_np.dtype) + video_height, video_width = img_path.shape + + if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images + img_np = img_np / 255.0 + else: + raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}") + + img = torch.from_numpy(img_np).permute(2, 0, 1) + return img, video_height, video_width + + +def _load_video_frames_from_images( + video_path, + volume, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + async_loading_frames=False, + compute_device=torch.device("cuda"), + verbosity=True, +): + """Based on 'load_video_frames_from_jpg_images'. + + Load the video frames from a directory of image files (eg. ".jpg" format). + + The frames are resized to image_size x image_size and are loaded to GPU if + `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`. + + You can load a frame asynchronously by setting `async_loading_frames` to `True`. + """ + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + + if video_path is None: + assert isinstance(volume, np.ndarray) and volume.ndim == 3, "Something is off with the 'volume'." + # Iterate over each slice. + images = [] + for i, curr_slice in enumerate(volume): + curr_image, video_height, video_width = _load_img_as_tensor(curr_slice, image_size) + images.append(curr_image) + images = torch.stack(images) # Stack the inputs in expected format. + else: + if isinstance(video_path, str) and os.path.isdir(video_path): + frames_folder = video_path + else: + raise AssertionError("The video predictor expects the user to provide the folder where frames are stored.") + + frame_names = [p for p in os.listdir(frames_folder)] # NOTE: This part has changed to support multiple ffs. 
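# Editorial sketch, not part of the patch: if "multiple file formats" is the intent of the listing
# above, it could still be restricted to known image extensions, e.g.
#
#     valid_exts = (".jpg", ".jpeg", ".png", ".tif", ".tiff")  # assumed set of supported formats
#     frame_names = [p for p in os.listdir(frames_folder) if p.lower().endswith(valid_exts)]
#
# Note that the sort key below still assumes numeric file stems such as "0.jpg", "1.jpg", ...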
+ frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) + num_frames = len(frame_names) + if num_frames == 0: + raise RuntimeError(f"No images found in '{frames_folder}'.") + + img_paths = [os.path.join(frames_folder, frame_name) for frame_name in frame_names] + + if async_loading_frames: + lazy_images = AsyncVideoFrameLoader( + img_paths, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + compute_device, + ) + return lazy_images, lazy_images.video_height, lazy_images.video_width + + images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32) + for n, img_path in enumerate(tqdm(img_paths, desc="frame loading", disable=not verbosity)): + images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size) + + if not offload_video_to_cpu: + images = images.to(compute_device) + img_mean = img_mean.to(compute_device) + img_std = img_std.to(compute_device) + + # Normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width + + +class CustomVideoPredictor(SAM2VideoPredictor): + """The video predictor class inherited from the original predictor class to update 'init_state'. + """ + + @torch.inference_mode() + def init_state( + self, + video_path: Optional[str], + volume: Optional[np.ndarray] = None, + offload_video_to_cpu=False, + offload_state_to_cpu=False, + async_loading_frames=False, + verbosity=True, + ): + """Initialize an inference state.""" + compute_device = self.device # device of the model + + # Either video_path or volume is None. Both cannot be None or passed at the same time. + if (video_path is None) == (volume is None): + raise ValueError("Only one of 'video_path' or 'volume' must be provided (not both or neither).") + + # CP: this looks redundant. We load the video data each time this is called. + images, video_height, video_width = _load_video_frames_from_images( + video_path=video_path, + volume=volume, + image_size=self.image_size, + offload_video_to_cpu=offload_video_to_cpu, + async_loading_frames=async_loading_frames, + compute_device=compute_device, + verbosity=verbosity, + ) + + inference_state = {} + inference_state["images"] = images + inference_state["num_frames"] = len(images) + # whether to offload the video frames to CPU memory + # turning on this option saves the GPU memory with only a very small overhead + inference_state["offload_video_to_cpu"] = offload_video_to_cpu + # whether to offload the inference state to CPU memory + # turning on this option saves the GPU memory at the cost of a lower tracking fps + # (e.g. 
in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object + # and from 24 to 21 when tracking two objects) + inference_state["offload_state_to_cpu"] = offload_state_to_cpu + # the original video height and width, used for resizing final output scores + inference_state["video_height"] = video_height + inference_state["video_width"] = video_width + inference_state["device"] = compute_device + if offload_state_to_cpu: + inference_state["storage_device"] = torch.device("cpu") + else: + inference_state["storage_device"] = compute_device + # inputs on each frame + inference_state["point_inputs_per_obj"] = {} + inference_state["mask_inputs_per_obj"] = {} + # visual features on a small number of recently visited frames for quick interactions + inference_state["cached_features"] = {} + # values that don't change across frames (so we only need to hold one copy of them) + inference_state["constants"] = {} + # mapping between client-side object id and model-side object index + inference_state["obj_id_to_idx"] = OrderedDict() + inference_state["obj_idx_to_id"] = OrderedDict() + inference_state["obj_ids"] = [] + # Slice (view) of each object tracking results, sharing the same memory with "output_dict" + inference_state["output_dict_per_obj"] = {} + # A temporary storage to hold new outputs when user interact with a frame + # to add clicks or mask (it's merged into "output_dict" before propagation starts) + inference_state["temp_output_dict_per_obj"] = {} + # Frames that already holds consolidated outputs from click or mask inputs + # (we directly use their consolidated outputs during tracking) + # metadata for each tracking frame (e.g. which direction it's tracked) + inference_state["frames_tracked_per_obj"] = {} + + # CP: this looks redundant. We compute the embedding each time this is called despite it already being there. 
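# Editorial sketch, not part of the patch: the warm-up below could be made conditional to address
# the "CP: this looks redundant" note, e.g. by only computing frame 0 features when they are not
# cached yet (assuming "cached_features" is keyed by frame index, as in the upstream predictor):
#
#     if 0 not in inference_state["cached_features"]:
#         self._get_image_feature(inference_state, frame_idx=0, batch_size=1)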
+ # Warm up the visual backbone and cache the image feature on frame 0 + self._get_image_feature(inference_state, frame_idx=0, batch_size=1) + return inference_state + + +def _build_sam2_video_predictor( + config_file, + ckpt_path=None, + device="cuda", + mode="eval", + hydra_overrides_extra=[], + apply_postprocessing=True, + **kwargs, +): + from hydra import compose + from hydra.utils import instantiate + from omegaconf import OmegaConf + + hydra_overrides = [ + "++model._target_=micro_sam2.models._video_predictor.CustomVideoPredictor", + ] + if apply_postprocessing: + hydra_overrides_extra = hydra_overrides_extra.copy() + hydra_overrides_extra += [ + # dynamically fall back to multi-mask if the single mask is not stable + "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", + # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks + # are exactly as what users see from clicking + "++model.binarize_mask_from_pts_for_mem_enc=true", + # fill small holes in the low-res masks up to `fill_hole_area` + # (before resizing them to the original video resolution) + "++model.fill_hole_area=8", + ] + hydra_overrides.extend(hydra_overrides_extra) + + # Read config and init model + cfg = compose(config_name=config_file, overrides=hydra_overrides) + OmegaConf.resolve(cfg) + model = instantiate(cfg.model, _recursive_=True) + _load_checkpoint(model, ckpt_path) + model = model.to(device) + if mode == "eval": + model.eval() + return model diff --git a/micro_sam/v2/prompt_based_segmentation.py b/micro_sam/v2/prompt_based_segmentation.py new file mode 100644 index 00000000..95ab0ad6 --- /dev/null +++ b/micro_sam/v2/prompt_based_segmentation.py @@ -0,0 +1,491 @@ +from typing import Optional, Union, List + +import numpy as np + +from micro_sam.prompt_based_segmentation import _process_box + + +def promptable_segmentation_2d( + predictor, + image: Optional[np.ndarray] = None, + points: Optional[np.ndarray] = None, + labels: Optional[np.ndarray] = None, + boxes: Optional[np.ndarray] = None, + masks: Optional[np.ndarray] = None, + batched: Optional[bool] = None, +): + """@private""" + + if image is not None: + if image.ndim == 2: + image = np.stack([image] * 3, axis=-1) + + assert image.ndim == 3 + + if image.shape[0] == 3: # Make channels last, as expected in RGB images. + image = image.transpose(1, 2, 0) + + # Set the predictor state. + predictor.set_image(image.astype("uint8")) + + assert len(points) == len(labels) + have_points = points is not None and len(points) > 0 + have_boxes = boxes is not None and len(boxes) > 0 + + # If no prompts are provided, return 'None'. + if not have_points and not have_boxes: + return + + kwargs = {} + if have_points: + kwargs["point_coords"] = points[:, ::-1].copy() # Ensure contiguous array convention so that PyTorch likes it. + kwargs["point_labels"] = labels + if have_boxes: + shape = predictor._orig_hw[0] + kwargs["box"] = np.array([_process_box(b, shape) for b in boxes]) + + # Run interactive segmentation. + masks, scores, logits = predictor.predict( + # mask_input=masks, + multimask_output=False, # NOTE: Hard-coded to 'False' atm. + **kwargs + ) + + # Get the count of points / boxes. 
+ n_points = len(points) if have_points else 0 + n_boxes = len(boxes) if have_boxes else 0 + + if n_points > 1 or n_boxes > 1: # Has more than one object, expected instance segmentation. + out = np.zeros(masks.shape[-2:]) + for i, curr_mask in enumerate(masks, start=1): + out[curr_mask.squeeze() > 0] = i + else: + out = masks.squeeze() + + # HACK: Hard-code the expected data type for labels for napari labels layer: uint8 + out = out.astype("uint8") + + return out + + +def promptable_segmentation_3d( + predictor, + volume: np.ndarray, + frame_id: int, + volume_embeddings: Optional[...] = None, + points: Optional[np.ndarray] = None, + labels: Optional[np.ndarray] = None, + boxes: Optional[np.ndarray] = None, + masks: Optional[np.ndarray] = None, +): + """@private""" + + assert volume.ndim == 3 + + # Initialize the inference state + inference_state = predictor.init_state(video_path=None, volume=volume) + + assert len(points) == len(labels) + have_points = points is not None and len(points) > 0 + have_boxes = boxes is not None and len(boxes) > 0 + + # If no prompts are provided, return 'None'. + if not have_points and not have_boxes: + return + + kwargs = {} + if have_points: + kwargs["points"] = points[:, ::-1].copy() # Ensure contiguous array convention so that PyTorch likes it. + kwargs["labels"] = labels + if have_boxes: + shape = volume.shape[-2:] + kwargs["box"] = np.array([_process_box(b, shape) for b in boxes]) + + # Add point/box prompts in a single frame. + _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box( + inference_state=inference_state, + frame_idx=int(frame_id), + obj_id=1, # NOTE: Setting a fixed object id, assuming only one object is being segmented. + clear_old_points=True, # Whether to make use of old points in memory. + **kwargs + ) + + # TODO: Figure out how to integrate mask prompts in 3d. + + # Next, propagate the masklets throughout the frames using the input prompts in selected frames. + forward_video_segments = {} + for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state): + forward_video_segments[out_frame_idx] = { + out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy() for i, out_obj_id in enumerate(out_obj_ids) + } + + # Let's do the propagation reverse in time now. + reverse_video_segments = {} + if len(forward_video_segments) < volume.shape[0]: # Perform reverse propagation only if necessary + for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video( + inference_state, reverse=True, + ): + reverse_video_segments[out_frame_idx] = { + out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy() for i, out_obj_id in enumerate(out_obj_ids) + } + # NOTE: The order is reversed to stitch the reverse propagation with forward. + reverse_video_segments = dict(reversed(list(reverse_video_segments.items()))) + + # We stitch the segmented slices together. + video_segments = {**reverse_video_segments, **forward_video_segments} + + # Now, let's merge the segmented objects per frame back together as instances per slice. + segmentation = [] + for slice_idx in video_segments.keys(): + per_slice_seg = np.zeros(volume.shape[-2:]) + for _instance_idx, _instance_mask in video_segments[slice_idx].items(): + per_slice_seg[_instance_mask.squeeze()] = _instance_idx + segmentation.append(per_slice_seg) + + segmentation = (np.stack(segmentation) > 0).astype("uint64") + + # Reset the state after finishing the segmentation round. 
+ predictor.reset_state(inference_state) + + return segmentation + + +class PromptableSegmentation3D: + """Promptable segmentation class for volumetric data. + """ + def __init__(self, predictor, volume): + self.predictor = predictor + self.volume = volume + + if self.volume.ndim != 3: + raise AssertionError(f"The dimensionality of the volume should be 3, got '{self.volume.ndim}'") + + self.init_predictor() + + # Store prompts per instance. + self.running_point_frame_ids: Optional[Union[List[int]]] = None + self.running_points: Optional[np.ndarray] = None + self.running_point_labels: Optional[np.ndarray] = None + + self.running_box_frame_ids: Optional[Union[int, List[int]]] = None + self.running_boxes: Optional[np.ndarray] = None + + self.running_mask_frame_ids: Optional[Union[int, List[int]]] = None + self.running_masks: Optional[np.ndarray] = None + + def init_predictor(self): + # Initialize the inference state. + self.inference_state = self.predictor.init_state(video_path=None, volume=self.volume) + + def reset_predictor(self): + # Reset the state after finishing the segmentation round. + self.predictor.reset_state(self.inference_state) + + def _as_array(self, x): + return None if x is None else np.asarray(x) + + def _is_array_equal(self, a, b): + if a is None and b is None: + return True + if (a is None) != (b is None): + return False + a = np.asarray(a) + b = np.asarray(b) + return a.shape == b.shape and np.array_equal(a, b) + + def _is_prefix(self, old, new) -> bool: + """Checks whether the new object is a prefix element of the older object.""" + if old is None: + return True + if new is None: + return False + + old = np.asarray(old) + new = np.asarray(new) + + if old.ndim == 0 or new.ndim == 0: + return False + if old.shape[1:] != new.shape[1:]: + return False + if old.shape[0] > new.shape[0]: + return False + return np.array_equal(old, new[: old.shape[0]]) + + def _tail(self, old, new): + """Returns the trailing tail by eliminating the prefix from new object.""" + if old is None: + return new + old = np.asarray(old) + new = np.asarray(new) + if old.shape[0] == new.shape[0]: + return None + return new[old.shape[0]:] + + def get_valid_prompts(self, frame_ids, points=None, labels=None, boxes=None, masks=None): + """Returns the valid prompts to add for promptable segmentation. + + This workflow manages and returns prompts to add to make sure of the following: + 1. Either use new (unused) prompts, or. + 2. Reprompt all prompts in case an old prompt is deleted. + """ + have_points = (points is not None) or (labels is not None) + if have_points and (points is None or labels is None): + raise ValueError("For using point prompts, both 'points' and 'labels' must be provided.") + have_boxes = boxes is not None + have_masks = masks is not None + + valid_prompt_combinations = sum([have_points, have_boxes, have_masks]) + if valid_prompt_combinations == 0: + raise ValueError("You must provide a valid prompt combination.") + elif valid_prompt_combinations > 1: + raise ValueError("Please choose only one of the prompt combinations.") + + # The core manager for maintaining prompts in memory and returning valid prompts. + if have_points: + points = self._as_array(points) + labels = self._as_array(labels) + + # Let's perform a quick point prompt sanity check. 
+ if points.ndim == 0: + raise ValueError("'points' must be array-like, not a scalar.") + if labels.ndim == 0: + raise ValueError("'labels' must be array-like, not a scalar.") + if points.shape[0] == 0: + raise ValueError("'points' must contain at least one point.") + if labels.shape[0] != points.shape[0]: + raise ValueError("'labels' must have the same length as 'points'.") + + # If the prompt arrive here for the first time, remember me :) + if self.running_point_frame_ids is None: + self.running_point_frame_ids = frame_ids + self.running_points = points + self.running_point_labels = labels + return {"mode": "all", "frame_ids": frame_ids, "points": points, "labels": labels} + + # If the 'frame_ids' change, the safest would be to reprompt all and overwrite. + if frame_ids != self.running_point_frame_ids: + self.running_point_frame_ids = frame_ids + self.running_points = points + self.running_point_labels = labels + return {"mode": "all", "frame_ids": frame_ids, "points": points, "labels": labels} + + # If the prompt arrive and exactly match the stored prompts, return no prompts. + if ( + self._is_array_equal(points, self.running_points) and + self._is_array_equal(labels, self.running_point_labels) + ): + return {} + + # If the prompts arrive and have some new prompts compared to stored ones, only return the new ones. + if self._is_prefix(self.running_points, points) and self._is_prefix(self.running_point_labels, labels): + new_points = self._tail(self.running_points, points) + new_labels = self._tail(self.running_point_labels, labels) + + # Let's update the prompt storage to the full incoming prompts. + self.running_points = points + self.running_point_labels = labels + + if new_points is None: + return {} + return {"mode": "tail", "frame_ids": frame_ids, "points": new_points, "labels": new_labels} + + # If the prompts arrive and have some old stored prompts deleted, return all arrived prompts as is. + # NOTE: It could be deletion / modification / reordering, we simply reprompt all prompts again. + self.running_points = points + self.running_point_labels = labels + return {"mode": "all", "frame_ids": frame_ids, "points": points, "labels": labels} + + def add_point_prompts( + self, + frame_ids: Union[int, List[int]], + points: np.ndarray, + point_labels: np.ndarray, + object_id: Optional[Union[List[int], int]] = None, + multiple_objects: bool = False, # Enables multi-object segmentation. + ): + """ + """ + # Support multi-object segmentation. + if multiple_objects and object_id is not None: + raise ValueError("Well you can't segment multiple objects and provide a specific id, duh!") + + # In case there is no multi-object segmentation happening and the user forgot to specify object, pin obj_id=1. + if object_id is None: + object_id = 1 + + # If no point prompts are provided, return 'None'. + if points is None or len(points) == 0: + return + + # Check what's been provided by the user. + if not isinstance(frame_ids, list): + frame_ids = [frame_ids] + + if len(points) != len(point_labels): + raise AssertionError("The number of points and corresponding labels for it are mismatching.") + + # Prepare the point prompts. + expected_prompts = self.get_valid_prompts(frame_ids=frame_ids, points=points, labels=point_labels) + if not expected_prompts: # If there are no new prompts, we should not add them. 
+            return
+
+        mode = expected_prompts["mode"]
+        frame_ids = expected_prompts["frame_ids"]
+        points = expected_prompts["points"]
+        point_labels = expected_prompts["labels"]
+
+        clear_old_points = (mode == "all")  # TODO: Make use of this in a smarter way!
+        points = points[:, ::-1].copy()  # Ensure contiguous array convention so that PyTorch likes it.
+
+        # Make the object ids consistent with our per-prompt addition strategy.
+        if not isinstance(object_id, list):
+            object_id = [object_id]
+
+        # Now that we have lists, a single object id is broadcast to the total number of prompts
+        # (this is the hook for multi-object segmentation).
+        if len(object_id) != len(point_labels) and len(object_id) == 1:
+            object_id = object_id * len(point_labels)
+
+        # At this stage, the lengths of points, point_labels and object_id must match.
+        assert len(object_id) == len(point_labels) == len(points), "Number of object ids should match total prompts."
+
+        # Add the point prompts frame by frame.
+        for i, (curr_frame_id, curr_point, curr_point_label, curr_obj_id) in enumerate(
+            zip(frame_ids, points, point_labels, object_id)
+        ):
+            self.predictor.add_new_points_or_box(
+                inference_state=self.inference_state,
+                frame_idx=int(curr_frame_id),
+                obj_id=curr_obj_id,  # The object id this prompt belongs to; 1 unless batched prompts assign separate ids.
+                clear_old_points=False,  # HACK: hard-coded for now; keeps previously added points in memory.
+                points=np.array([curr_point]),
+                labels=np.array([curr_point_label]),
+            )
+
+    def add_box_prompts(self, frame_ids: Union[int, List[int]], boxes: Optional[np.ndarray] = None):
+        # Check what has been provided by the user.
+        have_boxes = boxes is not None and len(boxes) > 0
+
+        # If no box prompts are provided, return 'None'.
+        if not have_boxes:
+            return
+
+        if not isinstance(frame_ids, list):
+            frame_ids = [frame_ids]
+
+        # Prepare the box prompts.
+        # TODO: Validate based on the running prompts.
+        clear_old_points = True  # TODO: Must depend on the running prompt logic.
+        boxes = np.array([_process_box(b, self.volume.shape[-2:]) for b in boxes])
+
+        # Add the box prompts frame by frame.
+        for curr_frame_id, curr_box in zip(frame_ids, boxes):
+            self.predictor.add_new_points_or_box(
+                inference_state=self.inference_state,
+                frame_idx=int(curr_frame_id),
+                obj_id=1,  # NOTE: Setting a fixed object id, assuming only one object is being segmented.
+                clear_old_points=clear_old_points,  # Whether to make use of old points in memory.
+                box=np.array([curr_box]),
+            )
+
+    def add_mask_prompts(
+        self, frame_ids: Union[int, List[int]], masks: Optional[np.ndarray] = None,
+    ):
+        raise NotImplementedError
+
+    def propagate_prompts(self):
+        # First, we propagate the masklets through the frames using the input prompts in the selected frames.
+        forward_video_segments = {}
+        for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(self.inference_state):
+            forward_video_segments[out_frame_idx] = {
+                out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy() for i, out_obj_id in enumerate(out_obj_ids)
+            }
+
+        # Next, we do the propagation in reverse time order.
+ reverse_video_segments = {} + if len(forward_video_segments) < self.volume.shape[0]: # Perform reverse propagation only if necessary + for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video( + self.inference_state, reverse=True, + ): + reverse_video_segments[out_frame_idx] = { + out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy() for i, out_obj_id in enumerate(out_obj_ids) + } + # NOTE: The order is reversed to stitch the reverse propagation with forward. + reverse_video_segments = dict(reversed(list(reverse_video_segments.items()))) + + # Now, we should stitch the segmented slices together. + video_segments = {**reverse_video_segments, **forward_video_segments} + return video_segments + + def segment_slice( + self, + frame_idx: int, + points: Optional[np.ndarray] = None, + labels: Optional[np.ndarray] = None, + boxes: Optional[List] = None, + masks: Optional[List] = None, + object_id: int = 1, + ): + """Segment a single slice using SAM2 video predictor. + + Args: + frame_idx: Slice index to segment. + points: Point prompts (N, 2) array. + labels: Point labels (N,) array. + boxes: List of box prompts. + masks: List of mask prompts (can be None). + object_id: Object ID to use for the segmentation (default: 1). + + Returns: + Segmentation mask for the slice (2D array), or None if no valid prompts provided. + """ + # Validate prompts + have_points = points is not None and len(points) > 0 + have_boxes = boxes is not None and len(boxes) > 0 + + if not have_points and not have_boxes: + return None + + try: + # Prepare prompts + box = boxes[0] if have_boxes else None + + # Add prompts to the specific frame + _, out_obj_ids, out_mask_logits = self.predictor.add_new_points_or_box( + inference_state=self.inference_state, + frame_idx=frame_idx, + obj_id=object_id, + points=points if have_points else None, + labels=labels if have_points else None, + box=box, + ) + + # Extract the mask from logits + # out_mask_logits shape: (num_objects, 1, H, W) + mask_logits = out_mask_logits[0] # Get first object + seg = (mask_logits.squeeze() > 0.0).cpu().numpy() + + # Ensure correct output type + seg = seg.astype("uint32") + + finally: + # Reset the state to clear this object's prompts + # This ensures the next segmentation starts fresh + self.predictor.reset_state(self.inference_state) + + return seg + + def predict(self): + # First, we propagate prompts. + video_segments = self.propagate_prompts() + + # Next, let's merge the segmented objects per frame back together as instances per slice. + segmentation = [] + for slice_idx in video_segments.keys(): + per_slice_seg = np.zeros(self.volume.shape[-2:]) + for _instance_idx, _instance_mask in video_segments[slice_idx].items(): + per_slice_seg[_instance_mask.squeeze()] = _instance_idx + segmentation.append(per_slice_seg) + + segmentation = np.stack(segmentation).astype("uint64") + + return segmentation diff --git a/micro_sam/v2/util.py b/micro_sam/v2/util.py new file mode 100644 index 00000000..a55957f9 --- /dev/null +++ b/micro_sam/v2/util.py @@ -0,0 +1,391 @@ +import os +import sys +import pooch +from pathlib import Path +from typing import Union, Literal, Optional, Tuple + +import zarr +import numpy as np + +import torch + +from micro_sam.util import get_device + +import sam2 +from sam2.build_sam import build_sam2 + +from .models._video_predictor import _build_sam2_video_predictor + + +# NOTE: The model config is expected to be fetched from the module's relative path location. 
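# Editorial note, not part of the patch: sys.path entries are conventionally plain strings, so an
# explicit conversion is slightly safer than appending a pathlib.Path object below:
#
#     sys.path.append(str(Path(sam2.__file__).parent))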
+sys.path.append(Path(sam2.__file__).parents[0]) + + +_DEFAULT_MODEL = "hvit_t" + +BACKBONE = "sam2.1" + +CFG_PATHS = { + "sam2.0": { + "hvit_t": "configs/sam2/sam2_hiera_t.yaml", + "hvit_s": "configs/sam2/sam2_hiera_s.yaml", + "hvit_b": "configs/sam2/sam2_hiera_b+.yaml", + "hvit_l": "configs/sam2/sam2_hiera_l.yaml", + }, + "sam2.1": { + "hvit_t": "configs/sam2.1/sam2.1_hiera_t.yaml", + "hvit_s": "configs/sam2.1/sam2.1_hiera_s.yaml", + "hvit_b": "configs/sam2.1/sam2.1_hiera_b+.yaml", + "hvit_l": "configs/sam2.1/sam2.1_hiera_l.yaml", + } +} + +SUPPORTED_MODELS = ["hvit_t", "hvit_s", "hvit_b", "hvit_l"] + +URLS = { + "sam2.0": { + "hvit_t": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt", + "hvit_s": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt", + "hvit_b": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt", + "hvit_l": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt", + }, + "sam2.1": { + "hvit_t": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt", + "hvit_s": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt", + "hvit_b": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt", + "hvit_l": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt", + }, +} + +HASHES = { + "sam2.0": { + "hvit_t": "65b50056e05bcb13694174f51bb6da89c894b57b75ccdf0ba6352c597c5d1125", + "hvit_s": "95949964d4e548409021d47b22712d5f1abf2564cc0c3c765ba599a24ac7dce3", + "hvit_b": "d0bb7f236400a49669ffdd1be617959a8b1d1065081789d7bbff88eded3a8071", + "hvit_l": "7442e4e9b732a508f80e141e7c2913437a3610ee0c77381a66658c3a445df87b", + }, + "sam2.1": { + "hvit_t": "7402e0d864fa82708a20fbd15bc84245c2f26dff0eb43a4b5b93452deb34be69", + "hvit_s": "6d1aa6f30de5c92224f8172114de081d104bbd23dd9dc5c58996f0cad5dc4d38", + "hvit_b": "a2345aede8715ab1d5d31b4a509fb160c5a4af1970f199d9054ccfb746c004c5", + "hvit_l": "2647878d5dfa5098f2f8649825738a9345572bae2d4350a2468587ece47dd318", + }, +} + + +def _get_device(device=None): + if device is None: + device = get_device() + + if device == "cuda": + # NOTE: adapt global variables to work with flash attentions. + sam2.modeling.sam.transformer.OLD_GPU = True + sam2.modeling.sam.transformer.USE_FLASH_ATTN = True + sam2.modeling.sam.transformer.MATH_KERNEL_ON = True + elif device == "mps": + raise ValueError("The scripts have not been tested on MPS device.") + + return device + + +def _get_checkpoint(model_type=_DEFAULT_MODEL, backbone=BACKBONE): + # Let's first create a cache directory. + save_directory = os.path.expanduser(pooch.os_cache("micro_sam2/models")) + + # Download the checkpoint paths if the user does not provide them. + fname = f"{model_type}_{backbone}" + pooch.retrieve( + url=URLS[backbone][model_type], + known_hash=HASHES[backbone][model_type], + fname=fname, + path=save_directory, + progressbar=True + ) + + # Finally, get the filepath to the cached checkpoint. + checkpoint_path = os.path.join(save_directory, fname) + + return checkpoint_path + + +def get_sam2_model( + model_type: str = _DEFAULT_MODEL, + device: Optional[Union[torch.device, str]] = None, + checkpoint_path: Optional[Union[os.PathLike, str]] = None, + input_type: Literal["images", "videos"] = "images", + backbone: Literal["sam2.0", "sam2.1"] = BACKBONE, +): + """Get the Segment Anything 2 (SAM2) model for interactive segmentation of images and videos. 
+ + Args: + model_type: The choice of size for the vision transformer, eg. `hvit_t`. The default is `hvit_t` model. + device: The pytorch device. + checkpoint_path: Filepath to the pretrained model weights. + input_type: Whether the inputs are images or videos. + backbone: Whether the SAM2 backbone is initialized with `sam2.0` or `sam2.1` model configuration. + The default is `sam2.1`. + + Returns: + The SAM2 model. + """ + model_cfg = CFG_PATHS[backbone][model_type[:6]] + + device = _get_device(device) + + if input_type == "images": + _build_segment_anything_2 = build_sam2 + elif input_type == "videos": + _build_segment_anything_2 = _build_sam2_video_predictor + else: + raise ValueError(f"'{input_type}' is not a valid input type.") + + if checkpoint_path is None: + checkpoint_path = _get_checkpoint(model_type=model_type, backbone=backbone) + + model = _build_segment_anything_2( + config_file=model_cfg, + ckpt_path=checkpoint_path, + device=device, + mode="eval", + apply_postprocessing=False, + ) + + return model + + +def _check_saved_embeddings(): + raise NotImplementedError + + +def _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update): + # Check if the embeddings are already cached. + if save_path is not None and "original_size" in f.attrs: + # In this case we load the embeddings. + features = f["features"][:] + original_size = f.attrs["original_size"] + image_embeddings = {"features": features, "original_size": original_size} + # Also set the embeddings. + set_precomputed(predictor, image_embeddings) + return image_embeddings + + pbar_init(1, "Compute Image Embeddings 2D") + # Otherwise we have to compute the embeddings. + predictor.reset_predictor() + + from micro_sam.util import _to_image + predictor.set_image(_to_image(input_)) + features = predictor.get_image_embedding().cpu().numpy() + high_res_features = predictor._features.get("high_res_feats") + original_size = predictor._orig_hw + pbar_update(1) + + # Save the embeddings if we have a save_path. + if save_path is not None: + from micro_sam.util import _create_dataset_with_data + _create_dataset_with_data(f, "features", data=features) + # TODO: Write the embedding signature. + + image_embeddings = {"features": features, "high_res_feats": high_res_features, "original_size": original_size} + return image_embeddings + + +@torch.no_grad +def _compute_embeddings_batched_3d(inference_state, predictor, batched_z, batched_images): + batched_features, original_sizes = [], [] + + for image, z_id in zip(batched_images, batched_z): + _, _, curr_features, _, feat_sizes = predictor._get_image_feature(inference_state, frame_idx=z_id, batch_size=1) + + # Convert features to expected shape + curr_feat, curr_feat_size = curr_features[-1], feat_sizes[-1] + curr_feat = curr_feat.permute(1, 2, 0).view(1, -1, *curr_feat_size) + + # TODO: We probably need 'backbone_out' here too. + + batched_features.append(curr_feat) + original_sizes.append(image.shape[:2]) + + tensors = torch.cat(batched_features) + return tensors, inference_state, original_sizes + + +def _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update, batch_size): + # Check if the embeddings are already fully cached. + if save_path is not None and "original_size" in f.attrs: + # In this case we load the embeddings. 
+ features = f["features"] if lazy_loading else f["features"][:] + original_size = f.attrs["original_size"] + image_embeddings = {"features": features, "original_size": original_size} + return image_embeddings + + # Otherwise we have to compute the embeddings. + + # First check if we have a save path or not and set things up accordingly. + if save_path is None: + features = [] + save_features = False + partial_features = False + else: + save_features = True + embed_shape = (1, 256, 64, 64) # TODO: Check this? + shape = (input_.shape[0],) + embed_shape + chunks = (1,) + embed_shape + if "features" in f: + partial_features = True + features = f["features"] + if features.shape != shape or features.chunks != chunks: + raise RuntimeError("Invalid partial features") + else: + partial_features = False + from micro_sam.util import _create_dataset_without_data + features = _create_dataset_without_data(f, "features", shape=shape, chunks=chunks, dtype="float32") + + # We create the 'inference_state' object which keeps all important components in memory. + inference_state = predictor.init_state(video_path=None, volume=input_) + + # Initialize the pbar and batches. + n_slices = input_.shape[0] + pbar_init(n_slices, "Compute Image Embeddings 3D") + n_batches = int(np.ceil(n_slices / batch_size)) + + for batch_id in range(n_batches): + z_start = batch_id * batch_size + z_stop = min(z_start + batch_size, n_slices) + + batched_images, batched_z = [], [] + for z in range(z_start, z_stop): + # Skip feature computation in case of partial features in non-zero slice. + if partial_features and np.count_nonzero(features[z]) != 0: + continue + + from micro_sam.util import _to_image + tile_input = _to_image(input_[z]) + batched_images.append(tile_input) + batched_z.append(z) + + batched_embeddings, inference_state, original_sizes = _compute_embeddings_batched_3d( + inference_state, predictor, batched_z, batched_images + ) + + for z, embedding in zip(batched_z, batched_embeddings): + embedding = embedding.unsqueeze(0) + if save_features: + features[z] = embedding.cpu().numpy() + else: + features.append(embedding.unsqueeze(0)) + pbar_update(1) + + if save_features: + pass # TODO: Write the embedding signature? + else: + # Concatenate across the z axis. + features = torch.cat(features).cpu().numpy() + + image_embeddings = {"features": features, "original_size": original_sizes[-1]} + return image_embeddings + + +def precompute_image_embeddings( + predictor, + input_: np.ndarray, + save_path: Optional[Union[str, os.PathLike]] = None, + lazy_loading: bool = False, + ndim: Optional[int] = None, + tile_shape: Optional[Tuple[int, int]] = None, + halo: Optional[Tuple[int, int]] = None, + verbose: bool = True, + batch_size: int = 1, + pbar_init: Optional[callable] = None, + pbar_update: Optional[callable] = None, +): + """Compute the image embeddings (output of the encoder) for the input. + + If 'save_path' is given the embeddings will be loaded/saved in a zarr container. + + Args: + ... + + Returns: + The image embeddings. + """ + ndim = input_.ndim if ndim is None else ndim + + # Handle the embedding save_path. + # We don't have a save path, open in memory zarr file to hold tiled embeddings. + if save_path is None: + f = zarr.group() + + # We have a save path and it already exists. 
+    # Embeddings will be loaded from it; check that the saved embeddings in there match
+    # the parameters of the function call.
+    elif os.path.exists(save_path):
+        f = zarr.open(save_path, mode="a")
+        _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo)
+
+    # We have a save path and it does not exist yet. Create the zarr file to which the
+    # embeddings will then be saved.
+    else:
+        f = zarr.open(save_path, mode="a")
+
+    from micro_sam.util import handle_pbar
+    _, pbar_init, pbar_update, pbar_close = handle_pbar(verbose, pbar_init, pbar_update)
+
+    if ndim == 2 and tile_shape is None:
+        embeddings = _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update)
+    elif ndim == 2 and tile_shape is not None:
+        raise NotImplementedError
+    elif ndim == 3 and tile_shape is None:
+        embeddings = _compute_3d(input_, predictor, f, save_path, lazy_loading, pbar_init, pbar_update, batch_size)
+    elif ndim == 3 and tile_shape is not None:
+        raise NotImplementedError
+    else:
+        raise ValueError(f"Invalid dimensionality {input_.ndim}, expected 2 or 3 dimensional data.")
+
+    pbar_close()
+    return embeddings
+
+
+def set_precomputed(
+    predictor,
+    image_embeddings,
+    i=None,
+    tile_id=None,
+):
+    """Set the precomputed image embeddings for a predictor.
+
+    Args:
+        ...
+
+    Returns:
+        ...
+    """
+    if tile_id is not None:
+        raise NotImplementedError
+
+    try:
+        device = predictor.device()  # Works for the video predictor.
+    except TypeError:
+        device = predictor.device  # Otherwise, for the image predictor.
+
+    features = image_embeddings["features"]
+    assert features.ndim in (4, 5), f"{features.ndim}"
+    if features.ndim == 5:
+        if i is None:
+            raise ValueError("The data is 3D so an index i is needed.")
+
+        # NOTE: The assumption is that 'predictor' is a tuple of the
+        # predictor object and the pre-initialized 'inference_state'.
+        _predictor, inference_state = predictor
+
+        # TODO: This still needs to be puzzled together; there is no elegant way to initialize it yet.
+        # We need to figure out the 'backbone_out' from 'prepare_features'.
+
+        return _predictor, inference_state
+
+    elif features.ndim == 4:
+        if i is not None:
+            raise ValueError("The data is 2D so an index is not needed.")
+
+        predictor._features = {"image_embed": features, "high_res_feats": image_embeddings["high_res_feats"]}
+        predictor._is_image_set = True
+        predictor._orig_hw = image_embeddings["original_size"]
+        return predictor
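Editorial addendum, not part of the patch: a minimal end-to-end sketch of how the new `micro_sam.v2` helpers introduced above are meant to be used, assuming the `sam2` package is installed (checkpoints are downloaded on first use) and that `image` / `volume` are user-provided numpy arrays; here placeholders are used.

    import numpy as np
    from sam2.sam2_image_predictor import SAM2ImagePredictor
    from micro_sam.v2.util import get_sam2_model
    from micro_sam.v2.prompt_based_segmentation import (
        promptable_segmentation_2d, PromptableSegmentation3D,
    )

    # 2D: wrap the SAM2 image model in an image predictor and segment with a point prompt.
    image = np.random.randint(0, 255, (512, 512), dtype="uint8")  # placeholder input
    model = get_sam2_model(model_type="hvit_t", input_type="images")
    predictor = SAM2ImagePredictor(model)
    seg2d = promptable_segmentation_2d(
        predictor, image=image, points=np.array([[256, 256]]), labels=np.array([1]),
    )

    # 3D: use the video predictor to propagate a prompt placed on one slice through the volume.
    volume = np.random.randint(0, 255, (16, 256, 256), dtype="uint8")  # placeholder input
    video_predictor = get_sam2_model(model_type="hvit_t", input_type="videos")
    segmenter = PromptableSegmentation3D(predictor=video_predictor, volume=volume)
    segmenter.add_point_prompts(frame_ids=8, points=np.array([[128, 128]]), point_labels=np.array([1]))
    seg3d = segmenter.predict()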