diff --git a/micro_sam/inference.py b/micro_sam/inference.py index 7ff77e4c..dba925c8 100644 --- a/micro_sam/inference.py +++ b/micro_sam/inference.py @@ -214,8 +214,9 @@ def batched_inference( # Call get image embeddings, this will throw an error if they have not yet been computed. predictor.get_image_embedding() else: + input_ = image if i is None else image[i] image_embeddings = util.precompute_image_embeddings( - predictor, image, embedding_path, verbose=verbose_embeddings, i=i, + predictor, input_, embedding_path, verbose=verbose_embeddings ) util.set_precomputed(predictor, image_embeddings) diff --git a/scripts/apg_experiments/README.md b/scripts/apg_experiments/README.md index dd880239..84f1d9c5 100644 --- a/scripts/apg_experiments/README.md +++ b/scripts/apg_experiments/README.md @@ -1,3 +1,23 @@ ## Experiments for Automatic Prompt Generation (APG) -More explanation coming soon! +This folder contains evaluaton code for applying the new APG method, built on `micro-sam` to microscopy data using the `micro_sam` library. This code was used for our experiments in preparation of the [manuscript](https://openreview.net/forum?id=xFO3DFZN45). + +Please note that this folder may become outdated due to changes in function signatures, etc., and often does not use the functionality that we recommend to users. We also will not actively maintain the code here. Please refer to the [example notebooks](https://github.com/computational-cell-analytics/micro-sam/tree/master/notebooks) and [example scripts](https://github.com/computational-cell-analytics/micro-sam/tree/master/examples) for well maintained and documented `micro-sam`'s APG examples. + +### Evaluation Scripts: + +The top-level folder contains scripts to evaluate other models with `micro-sam`, and the `plotting` subfolder contains scripts for visualizations and plots for the manuscript. + +- `analyze_posthoc.py`: Experiments to visually understand and debug the APG method. +- `perform_posthoc.py`: Experiments to store results to work with `analyze_posthoc.py`. +- `prepare_baselines.py`: Experiments to run all methods presented in the manuscript with default parameters on all microscopy imaging data. +- `run_evaluation.py`: Experiments related to APG for understanding the hyperparameters using grid-search. +- `submit_evaluation.py`: Scripts for submitting jobs to slurm. +- `util,py`: Scripts containing data processing scripts and other miscellanous stuff. +- `plotting`/ + - `calculate_mean.py`: Scripts to calculate mean performance metric over per modality per method. + - `plot_ablation.py`: Scripts to show results for comparing connected components- vs. boundary distance-based prompt extraction. + - `plot_average.py`: Same as `calculate_mean.py`. The scripts plot the mean values in a barplot and displays absolute rank per method over all datasets per modality. + - `plot_qualitative.py`: Scripts to display qualitative results over all datasets. + - `plot_quantitative.py`: Scripts to display quantitative results over all datasets. + - `plot_util.py`: Stores related information helpful for plotting. diff --git a/scripts/apg_experiments/plotting/calculate_mean.py b/scripts/apg_experiments/plotting/calculate_mean.py new file mode 100644 index 00000000..70a24e26 --- /dev/null +++ b/scripts/apg_experiments/plotting/calculate_mean.py @@ -0,0 +1,75 @@ +import numpy as np + +from plot_util import ( + msa_results, + msa_results_fluorescence, + msa_results_label_free, + msa_results_histopathology, +) + + +AVG_METHODS = [ + "AMG (vit_b) - without grid search", + "AIS - without grid search", + "SAM3", + "CellPose3", + "CellPoseSAM", + "CellSAM", + "APG - without grid search (cc)", +] + +AVG_DISPLAY_NAME_MAP = { + "AMG (vit_b) - without grid search": "SAM", + "AIS - without grid search": "AIS (µSAM)", + "SAM3": "SAM3", + "CellPose3": "CellPose 3", + "CellPoseSAM": "CellPoseSAM", + "CellSAM": "CellSAM", + "APG - without grid search (cc)": "APG (µSAM)", +} + + +def compute_method_means(msa_results_dict, methods_filter=None): + vals_per_method = {} + + for _, entries in msa_results_dict.items(): + for e in entries: + m = e["method"] + v = e["mSA"] + if methods_filter is not None and m not in methods_filter: + continue + if v is None: + continue + vals_per_method.setdefault(m, []).append(float(v)) + + mean_msa = {m: float(np.mean(vs)) for m, vs in vals_per_method.items()} + return mean_msa + + +def format_float(v): + return f"{v:.3f}" + + +if __name__ == "__main__": + datasets = { + "Label-Free (Cell)": msa_results_label_free, + "Fluorescence (Cell)": msa_results_fluorescence, + "Fluorescence (Nucleus)": msa_results, + "Histopathology (Nucleus)": msa_results_histopathology, + } + + means_by_modality = {} + for mod_name, data in datasets.items(): + means_by_modality[mod_name] = compute_method_means(data, methods_filter=AVG_METHODS) + + # Print a compact validation table + col_order = list(datasets.keys()) + print("Averages (mSA) per modality:") + print("Method".ljust(22), *(c[:18].rjust(18) for c in col_order)) + for m in AVG_METHODS: + disp = AVG_DISPLAY_NAME_MAP.get(m, m) + row = [disp.ljust(22)] + for mod in col_order: + v = means_by_modality[mod].get(m, np.nan) + row.append((format_float(v) if np.isfinite(v) else "--").rjust(18)) + print("".join(row)) diff --git a/scripts/apg_experiments/plotting/plot_ablation.py b/scripts/apg_experiments/plotting/plot_ablation.py index 3ce5060f..360eff4c 100644 --- a/scripts/apg_experiments/plotting/plot_ablation.py +++ b/scripts/apg_experiments/plotting/plot_ablation.py @@ -4,7 +4,7 @@ import matplotlib.pyplot as plt from matplotlib.patches import Patch -from util import ( +from plot_util import ( msa_results, msa_results_fluorescence, msa_results_label_free, @@ -18,10 +18,10 @@ APG_GS_CC = "APG - with grid search (cc)" plt.rcParams.update({ - "axes.titlesize": 10, + "axes.titlesize": 11, "axes.labelsize": 9, - "xtick.labelsize": 8, - "ytick.labelsize": 8, + "xtick.labelsize": 11, + "ytick.labelsize": 10, }) @@ -142,34 +142,10 @@ def max_improvement(d): if max_val > 0: ax.axhspan(0, max_val, color="#e0f3db", alpha=0.8, zorder=0) - bars_wo_bd = ax.bar( - x - 1.5 * width, - rel_wo_bd, - width, - color=color_wo_bd, - zorder=1, - ) - bars_wo_cc = ax.bar( - x - 0.5 * width, - rel_wo_cc, - width, - color=color_wo_cc, - zorder=1, - ) - bars_gs_bd = ax.bar( - x + 0.5 * width, - rel_gs_bd, - width, - color=color_gs_bd, - zorder=1, - ) - bars_gs_cc = ax.bar( - x + 1.5 * width, - rel_gs_cc, - width, - color=color_gs_cc, - zorder=1, - ) + bars_wo_bd = ax.bar(x - 1.5 * width, rel_wo_bd, width, color=color_wo_bd, zorder=1) + bars_wo_cc = ax.bar(x - 0.5 * width, rel_wo_cc, width, color=color_wo_cc, zorder=1) + bars_gs_bd = ax.bar(x + 0.5 * width, rel_gs_bd, width, color=color_gs_bd, zorder=1) + bars_gs_cc = ax.bar(x + 1.5 * width, rel_gs_cc, width, color=color_gs_cc, zorder=1) best_mask_wo_bd = [] best_mask_wo_cc = [] @@ -177,12 +153,7 @@ def max_improvement(d): best_mask_gs_cc = [] for i in range(len(datasets)): - vals_i = [ - rel_wo_bd[i], - rel_wo_cc[i], - rel_gs_bd[i], - rel_gs_cc[i], - ] + vals_i = [rel_wo_bd[i], rel_wo_cc[i], rel_gs_bd[i], rel_gs_cc[i]] max_i = max(vals_i) best_mask_wo_bd.append(rel_wo_bd[i] == max_i) @@ -216,7 +187,7 @@ def annotate_bars(bars, vals, best_mask): f"{v:+.1f}%", ha="center", va=va, - fontsize=7, + fontsize=9, rotation=90, fontweight=fontweight, ) @@ -234,31 +205,27 @@ def annotate_bars(bars, vals, best_mask): fig.tight_layout(rect=[0.06, 0.18, 1, 0.97]) fig.text( - 0.05, 0.55, + 0.06, 0.575, "Relative Mean Segmentation Accuracy (compared to AIS)", va="center", ha="center", rotation="vertical", - fontsize=11, + fontsize=10, fontweight="bold", ) legend_patches = [ - Patch(facecolor=color_wo_bd, edgecolor="black", - label="APG (Boundary Distance) - Default"), - Patch(facecolor=color_wo_cc, edgecolor="black", - label="APG (Components) - Default"), - Patch(facecolor=color_gs_bd, edgecolor="black", - label="APG (Boundary) - GS"), - Patch(facecolor=color_gs_cc, edgecolor="black", - label="APG (Components) - GS"), + Patch(facecolor=color_wo_bd, label="APG (Boundary Distance) - Default"), + Patch(facecolor=color_wo_cc, label="APG (Components) - Default"), + Patch(facecolor=color_gs_bd, label="APG (Boundary Distance) - Grid Search"), + Patch(facecolor=color_gs_cc, label="APG (Components) - Grid Search"), ] fig.legend( handles=legend_patches, loc="lower center", - bbox_to_anchor=(0.5, 0.125), - ncol=4, + bbox_to_anchor=(0.5, 0.1), + ncol=2, fontsize=8, frameon=True, ) diff --git a/scripts/apg_experiments/plotting/plot_average.py b/scripts/apg_experiments/plotting/plot_average.py index 150e054a..53c3c169 100644 --- a/scripts/apg_experiments/plotting/plot_average.py +++ b/scripts/apg_experiments/plotting/plot_average.py @@ -3,7 +3,7 @@ import numpy as np import matplotlib.pyplot as plt -from util import ( +from plot_util import ( msa_results, msa_results_fluorescence, msa_results_label_free, @@ -23,23 +23,24 @@ APG_METHOD = "APG - without grid search (cc)" AVG_DISPLAY_NAME_MAP = { - "AMG (vit_b) - without grid search": "SAM", - "AIS - without grid search": r"$\mu$SAM", + "AMG (vit_b) - without grid search": "AMG (SAM)", + "AIS - without grid search": "AIS (µSAM)", "SAM3": "SAM3", "CellPose3": "CellPose 3", "CellPoseSAM": "CellPoseSAM", "CellSAM": "CellSAM", - "APG - without grid search (cc)": "APG", + "APG - without grid search (cc)": "APG (µSAM)", } AVG_DISPLAY_NAME_MAP_HISTO = AVG_DISPLAY_NAME_MAP.copy() -AVG_DISPLAY_NAME_MAP_HISTO["AIS - without grid search"] = "PathoSAM" +AVG_DISPLAY_NAME_MAP_HISTO["AIS - without grid search"] = "AIS (PathoSAM)" +AVG_DISPLAY_NAME_MAP_HISTO["APG - without grid search (cc)"] = "APG \n (PathoSAM)" plt.rcParams.update({ - "axes.titlesize": 10, - "axes.labelsize": 9, - "xtick.labelsize": 8, - "ytick.labelsize": 8, + "axes.titlesize": 11, + "axes.labelsize": 10, + "xtick.labelsize": 10, + "ytick.labelsize": 10, }) @@ -217,7 +218,7 @@ def filtered_methods_for_modality(mean_dict): disp_names = [disp_map[m] for m in methods] ax.set_xticks(x) - ax.set_xticklabels(disp_names, rotation=60, ha="right") + ax.set_xticklabels(disp_names, rotation=45, ha="right") xticklabels = ax.get_xticklabels() for idx_lbl, lbl in enumerate(xticklabels): @@ -231,7 +232,7 @@ def filtered_methods_for_modality(mean_dict): ax.set_ylim(0.0, 1.0) fig.text( - 0.06, 0.5, + 0.05, 0.575, "Mean Segmentation Accuracy (mSA)", va="center", ha="center", diff --git a/scripts/apg_experiments/plotting/plot_qualitative.py b/scripts/apg_experiments/plotting/plot_qualitative.py new file mode 100644 index 00000000..cb9b5ba7 --- /dev/null +++ b/scripts/apg_experiments/plotting/plot_qualitative.py @@ -0,0 +1,131 @@ +import os +import sys +from glob import glob +from tqdm import tqdm +from natsort import natsorted + +import imageio.v3 as imageio +import matplotlib.pyplot as plt + +from torch_em.util.util import get_random_colors + +from elf.evaluation import mean_segmentation_accuracy + +from micro_sam.evaluation.model_comparison import _overlay_outline +from micro_sam.automatic_segmentation import get_predictor_and_segmenter, automatic_instance_segmentation + + +sys.path.append("..") + + +ROOT = "/mnt/vast-nhr/projects/cidas/cca" + + +def plot_quali(dataset_name, model_type): + # Get APG result paths. + pred_paths = natsorted( + glob( + os.path.join( + ROOT, "experiments", "micro_sam", + "apg_baselines", "cc_without_box", "inference", f"{dataset_name}_apg_*", + "*" + ) + ) + ) + + # Get the image and label paths + from util import get_image_label_paths + image_paths, label_paths = get_image_label_paths(dataset_name=dataset_name, split="test") + + # HACK: Sometimes there's just too many images. Just compare over first 500. + image_paths, label_paths, pred_paths = image_paths[:500], label_paths[:500], pred_paths[:500] + + # Get the cellpose model to run CPSAM on the fly. + from cellpose import models + model = models.CellposeModel(gpu=True, model_type="cpsam") + + # Let's iterate over each image, label and predictions. + results = [] + for image_path, label_path, pred_path in tqdm( + zip(image_paths, label_paths, pred_paths), desc="Qualitative analysis", total=len(image_paths) + ): + # Load everything first + image = imageio.imread(image_path) + label = imageio.imread(label_path) + pred = imageio.imread(pred_path) + + assert label.shape == pred.shape + + # Let's run CellPoseSAM to get the best relative results for APG. + seg, _, _ = model.eval(image) + + # Compare both + apg_msa = mean_segmentation_accuracy(pred, label) + cpsam_msa = mean_segmentation_accuracy(seg, label) + + # Next, find the relative score and store it for the dataset + diff = apg_msa - cpsam_msa + results.append((diff, image_path, label_path, pred_path, apg_msa, cpsam_msa)) + + # Let's fetch top-k where APG wins + k = 10 + results.sort(key=lambda x: abs(x[0]), reverse=True) + top_k = results[:k] + + # Prepare other model stuff + # SAM-related models + predictor_amg, segmenter_amg = get_predictor_and_segmenter(model_type="vit_b", segmentation_mode="amg") + predictor_ais, segmenter_ais = get_predictor_and_segmenter(model_type=model_type, segmentation_mode="ais") + + # Plot each of the top-k images as a separate horizontal triplet + os.makedirs(f"./quali_figures/{dataset_name}", exist_ok=True) + for rank, (diff, image_path, label_path, pred_path, apg_msa, cpsam_msa) in enumerate(top_k, 1): + from micro_sam.util import _to_image + image = _to_image(imageio.imread(image_path)) + gt = imageio.imread(label_path) + + # APG prediction is read from disk, CPSAM is recomputed for plotting + amg_pred = automatic_instance_segmentation( + input_path=image, verbose=False, ndim=2, predictor=predictor_amg, segmenter=segmenter_amg, + ) + ais_pred = automatic_instance_segmentation( + input_path=image, verbose=False, ndim=2, predictor=predictor_ais, segmenter=segmenter_ais, + ) + apg_pred = imageio.imread(pred_path) + cpsam_pred, _, _ = model.eval(image) + + from cellSAM import cellsam_pipeline + cellsam_pred = cellsam_pipeline(image, use_wsi=False) + if cellsam_pred.ndim == 3: # NOTE: For images with no objects found, a weird segmentation is returned. + cellsam_pred = cellsam_pred[0] + + # Prepare image as expected (normalize and overlay labels) + curr_image = _overlay_outline(image, gt, outline_dilation=1) + + fig, axes = plt.subplots(1, 6, figsize=(15, 5), constrained_layout=True) + for ax in axes: + ax.set_axis_off() + + axes[0].imshow(curr_image[:256, :256]) + axes[1].imshow(amg_pred[:256, :256], cmap=get_random_colors(amg_pred)) + axes[2].imshow(ais_pred[:256, :256], cmap=get_random_colors(ais_pred)) + axes[3].imshow(cpsam_pred[:256, :256], cmap=get_random_colors(cpsam_pred)) + axes[4].imshow(cellsam_pred[:256, :256], cmap=get_random_colors(cellsam_pred)) + axes[5].imshow(apg_pred[:256, :256], cmap=get_random_colors(apg_pred)) + + plt.savefig(f"./quali_figures/{dataset_name}/{rank}.png", dpi=400, bbox_inches="tight") + plt.savefig(f"./quali_figures/{dataset_name}/{rank}.svg", dpi=400, bbox_inches="tight") + plt.close() + + +def main(args): + plot_quali(args.dataset, args.model_type) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("-d", "--dataset", type=str, required=True) + parser.add_argument("-m", "--model_type", type=str, default="vit_b_lm") # NOTE: For AIS + args = parser.parse_args() + main(args) diff --git a/scripts/apg_experiments/plotting/plot_quantitative.py b/scripts/apg_experiments/plotting/plot_quantitative.py index 014ea083..934b5f32 100644 --- a/scripts/apg_experiments/plotting/plot_quantitative.py +++ b/scripts/apg_experiments/plotting/plot_quantitative.py @@ -3,7 +3,7 @@ import numpy as np import matplotlib.pyplot as plt -from util import ( +from plot_util import ( msa_results, msa_results_fluorescence, msa_results_label_free, @@ -16,12 +16,14 @@ DISPLAY_NAME_MAP_HISTO["AMG - without grid search"] = "AMG (PathoSAM)" if "AIS - without grid search" in DISPLAY_NAME_MAP_HISTO: DISPLAY_NAME_MAP_HISTO["AIS - without grid search"] = "AIS (PathoSAM)" +if "APG - without grid search (cc)" in DISPLAY_NAME_MAP_HISTO: + DISPLAY_NAME_MAP_HISTO["APG - without grid search (cc)"] = r"$\mathbf{APG}$ " + "\n" + r"$\mathbf{(PathoSAM)}$ " plt.rcParams.update({ - "axes.titlesize": 10, - "axes.labelsize": 9, - "xtick.labelsize": 8, - "ytick.labelsize": 8, + "axes.titlesize": 15, + "axes.labelsize": 12, + "xtick.labelsize": 11, + "ytick.labelsize": 11, }) @@ -40,13 +42,12 @@ def plot_msa_grid( fig, axes = plt.subplots( nrows=nrows, ncols=ncols, - figsize=(figsize_per_subplot[0] * ncols, - figsize_per_subplot[1] * nrows), + figsize=(figsize_per_subplot[0] * ncols, figsize_per_subplot[1] * nrows), sharey=True, ) axes = np.array(axes).reshape(-1) - color_top1 = "#1f77b4" # darkest + color_top1 = "#1f77b4" color_top2 = "#6baed6" color_top3 = "#c6dbef" color_rest = "#d9d9d9" @@ -96,8 +97,7 @@ def plot_msa_grid( if IN_DOMAIN.get((dataset, raw_name), False): bar.set_hatch("////") - apg_indices = [i for i, name in enumerate(raw_methods) - if name.startswith("APG")] + apg_indices = [i for i, name in enumerate(raw_methods) if name.startswith("APG")] ais_idx = None for i, name in enumerate(raw_methods): @@ -119,14 +119,7 @@ def plot_msa_grid( bar_height = v y_text = min(bar_height + 0.01, 0.98) - ax.text( - x[i], - y_text, - f"{v:.3f}", - ha="center", - va="bottom", - fontsize=8, - ) + ax.text(x[i], y_text, f"{v:.3f}", ha="center", va="bottom", fontsize=9) if ais_value is not None: for i in apg_indices: @@ -145,44 +138,30 @@ def plot_msa_grid( y_text = min(bar_height + 0.02, 0.99) ax.text( - x[i], - y_text, - f"{diff_pct:+.1f}%", - ha="center", - va="bottom", - fontsize=8, - fontweight="bold", - color=color, + x[i], y_text, f"{diff_pct:+.1f}%", ha="center", + va="bottom", fontsize=10, fontweight="bold", color=color, ) ax.set_title(dataset, fontweight="bold") ax.set_xticks(x) - ax.set_xticklabels(methods, rotation=60, ha="right") + ax.set_xticklabels(methods, rotation=45, ha="right") ax.set_ylim(0.0, 1.0) for j in range(len(datasets), len(axes)): fig.delaxes(axes[j]) - if suptitle is not None: - fig.tight_layout(rect=[0.08, 0, 1, 0.95]) - else: - fig.tight_layout(rect=[0.08, 0, 1, 1]) - fig.text( - 0.075, 0.5, - "Mean Segmentation Accuracy", - va="center", - ha="center", - rotation="vertical", - fontsize=11, - fontweight="bold", + 0.02, 0.525, "Mean Segmentation Accuracy", va="center", ha="center", + rotation="vertical", fontsize=14, fontweight="bold", ) + fig.subplots_adjust(left=0.075, right=0.98, bottom=0.125, top=0.95, wspace=0.1, hspace=0.9) + if save_path is not None: - fig.savefig(save_path, bbox_inches="tight", dpi=300) + fig.savefig(save_path, dpi=300) root, _ = os.path.splitext(save_path) svg_path = root + ".svg" - fig.savefig(svg_path, bbox_inches="tight") + fig.savefig(svg_path) return fig, axes diff --git a/scripts/apg_experiments/plotting/util.py b/scripts/apg_experiments/plotting/plot_util.py similarity index 94% rename from scripts/apg_experiments/plotting/util.py rename to scripts/apg_experiments/plotting/plot_util.py index 7df7d0e4..1bf71c51 100644 --- a/scripts/apg_experiments/plotting/util.py +++ b/scripts/apg_experiments/plotting/plot_util.py @@ -1,3 +1,6 @@ +# NOTE: All SAM3 prompts are the same: "cell" (matching with the training routine) + + msa_results = { "DSB": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.331}, @@ -10,7 +13,7 @@ {"method": "CellPose3", "mSA": 0.484}, {"method": "CellPoseSAM", "mSA": 0.656}, {"method": "CellSAM", "mSA": 0.634}, - {"method": "SAM3", "mSA": 0.383}, + {"method": "SAM3", "mSA": 0.367}, ], "U20S": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.258}, @@ -23,7 +26,7 @@ {"method": "CellPose3", "mSA": 0.787}, {"method": "CellPoseSAM", "mSA": 0.787}, {"method": "CellSAM", "mSA": 0.673}, - {"method": "SAM3", "mSA": 0.636}, + {"method": "SAM3", "mSA": 0.674}, ], "Arvidsson": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.416}, @@ -36,7 +39,7 @@ {"method": "CellPose3", "mSA": 0.611}, {"method": "CellPoseSAM", "mSA": 0.484}, {"method": "CellSAM", "mSA": 0.434}, - {"method": "SAM3", "mSA": 0.488}, + {"method": "SAM3", "mSA": 0.297}, ], "BitDepth NucSeg": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.224}, @@ -49,7 +52,7 @@ {"method": "CellPose3", "mSA": 0.302}, {"method": "CellPoseSAM", "mSA": 0.377}, {"method": "CellSAM", "mSA": 0.168}, - {"method": "SAM3", "mSA": 0.201}, + {"method": "SAM3", "mSA": 0.182}, ], "IFNuclei": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.293}, @@ -62,7 +65,7 @@ {"method": "CellPose3", "mSA": 0.404}, {"method": "CellPoseSAM", "mSA": 0.728}, {"method": "CellSAM", "mSA": 0.589}, - {"method": "SAM3", "mSA": 0.405}, + {"method": "SAM3", "mSA": 0.301}, ], "DynamicNuclearNet": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.298}, @@ -75,7 +78,7 @@ {"method": "CellPose3", "mSA": 0.512}, {"method": "CellPoseSAM", "mSA": 0.379}, {"method": "CellSAM", "mSA": 0.455}, - {"method": "SAM3", "mSA": 0.376}, + {"method": "SAM3", "mSA": 0.346}, ], "GoNuclear": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.339}, @@ -88,7 +91,7 @@ {"method": "CellPose3", "mSA": 0.447}, {"method": "CellPoseSAM", "mSA": 0.415}, {"method": "CellSAM", "mSA": 0.112}, - {"method": "SAM3", "mSA": 0.132}, + {"method": "SAM3", "mSA": 0.034}, ], "NIS3D": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.216}, @@ -101,7 +104,7 @@ {"method": "CellPose3", "mSA": 0.255}, {"method": "CellPoseSAM", "mSA": 0.246}, {"method": "CellSAM", "mSA": 0.264}, - {"method": "SAM3", "mSA": 0.008}, + {"method": "SAM3", "mSA": 0.031}, ], "Parhyale Regen": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.173}, @@ -114,7 +117,7 @@ {"method": "CellPose3", "mSA": 0.138}, {"method": "CellPoseSAM", "mSA": 0.272}, {"method": "CellSAM", "mSA": 0.144}, - {"method": "SAM3", "mSA": 0.039}, + {"method": "SAM3", "mSA": 0.063}, ], } @@ -130,7 +133,7 @@ {"method": "CellPose3", "mSA": 0.154}, {"method": "CellPoseSAM", "mSA": 0.475}, {"method": "CellSAM", "mSA": 0.345}, - {"method": "SAM3", "mSA": 0.089}, + {"method": "SAM3", "mSA": 0.121}, ], "CellPose": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.147}, @@ -143,7 +146,7 @@ {"method": "CellPose3", "mSA": 0.431}, {"method": "CellPoseSAM", "mSA": 0.566}, {"method": "CellSAM", "mSA": 0.413}, - {"method": "SAM3", "mSA": 0.229}, + {"method": "SAM3", "mSA": 0.299}, ], "PlantSeg (Root)": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.091}, @@ -156,7 +159,7 @@ {"method": "CellPose3", "mSA": 0.076}, {"method": "CellPoseSAM", "mSA": 0.161}, {"method": "CellSAM", "mSA": 0.096}, - {"method": "SAM3", "mSA": 0.023}, + {"method": "SAM3", "mSA": 0.067}, ], "PlantSeg (Ovules)": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.135}, @@ -169,7 +172,7 @@ {"method": "CellPose3", "mSA": 0.266}, {"method": "CellPoseSAM", "mSA": 0.331}, {"method": "CellSAM", "mSA": 0.333}, - {"method": "SAM3", "mSA": 0.071}, + {"method": "SAM3", "mSA": 0.184}, ], "PNAS Arabidopsis": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.145}, @@ -182,7 +185,7 @@ {"method": "CellPose3", "mSA": 0.411}, {"method": "CellPoseSAM", "mSA": 0.471}, {"method": "CellSAM", "mSA": 0.459}, - {"method": "SAM3", "mSA": 0.073}, + {"method": "SAM3", "mSA": 0.241}, ], "Covid-IF": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.007}, @@ -195,7 +198,7 @@ {"method": "CellPose3", "mSA": 0.161}, {"method": "CellPoseSAM", "mSA": 0.333}, {"method": "CellSAM", "mSA": 0.154}, - {"method": "SAM3", "mSA": 0.004}, + {"method": "SAM3", "mSA": 0.005}, ], "HPA": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.043}, @@ -208,7 +211,7 @@ {"method": "CellPose3", "mSA": 0.078}, {"method": "CellPoseSAM", "mSA": 0.431}, {"method": "CellSAM", "mSA": 0.301}, - {"method": "SAM3", "mSA": 0.152}, + {"method": "SAM3", "mSA": 0.155}, ], "CellBinDB": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.177}, @@ -221,7 +224,7 @@ {"method": "CellPose3", "mSA": 0.279}, {"method": "CellPoseSAM", "mSA": 0.342}, {"method": "CellSAM", "mSA": 0.264}, - {"method": "SAM3", "mSA": 0.136}, + {"method": "SAM3", "mSA": 0.137}, ], "Mouse Embryo": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.003}, @@ -234,7 +237,7 @@ {"method": "CellPose3", "mSA": 0.109}, {"method": "CellPoseSAM", "mSA": 0.155}, {"method": "CellSAM", "mSA": 0.083}, - {"method": "SAM3", "mSA": 0.073}, + {"method": "SAM3", "mSA": 0.081}, ], } @@ -250,7 +253,7 @@ {"method": "CellPose3", "mSA": 0.414}, {"method": "CellPoseSAM", "mSA": 0.444}, {"method": "CellSAM", "mSA": 0.098}, - {"method": "SAM3", "mSA": 0.313}, + {"method": "SAM3", "mSA": 0.331}, ], "OmniPose": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.137}, @@ -263,7 +266,7 @@ {"method": "CellPose3", "mSA": 0.468}, {"method": "CellPoseSAM", "mSA": 0.644}, {"method": "CellSAM", "mSA": 0.531}, - {"method": "SAM3", "mSA": 0.131}, + {"method": "SAM3", "mSA": 0.356}, ], "DeepBacs": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.057}, @@ -276,7 +279,7 @@ {"method": "CellPose3", "mSA": 0.455}, {"method": "CellPoseSAM", "mSA": 0.612}, {"method": "CellSAM", "mSA": 0.441}, - {"method": "SAM3", "mSA": 0.083}, + {"method": "SAM3", "mSA": 0.157}, ], "Usiigaci": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.051}, @@ -289,7 +292,7 @@ {"method": "CellPose3", "mSA": 0.291}, {"method": "CellPoseSAM", "mSA": 0.445}, {"method": "CellSAM", "mSA": 0.167}, - {"method": "SAM3", "mSA": 0.377}, + {"method": "SAM3", "mSA": 0.362}, ], "Vicar": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.115}, @@ -302,7 +305,7 @@ {"method": "CellPose3", "mSA": 0.338}, {"method": "CellPoseSAM", "mSA": 0.458}, {"method": "CellSAM", "mSA": 0.426}, - {"method": "SAM3", "mSA": 0.122}, + {"method": "SAM3", "mSA": 0.086}, ], "TOIAM": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.009}, @@ -315,7 +318,7 @@ {"method": "CellPose3", "mSA": 0.837}, {"method": "CellPoseSAM", "mSA": 0.898}, {"method": "CellSAM", "mSA": 0.631}, - {"method": "SAM3", "mSA": 0.0}, + {"method": "SAM3", "mSA": 0.027}, ], "DeepSeas": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.098}, @@ -328,7 +331,7 @@ {"method": "CellPose3", "mSA": 0.191}, {"method": "CellPoseSAM", "mSA": 0.345}, {"method": "CellSAM", "mSA": 0.203}, - {"method": "SAM3", "mSA": 0.227}, + {"method": "SAM3", "mSA": 0.277}, ], "YeaZ": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.382}, @@ -341,7 +344,7 @@ {"method": "CellPose3", "mSA": 0.817}, {"method": "CellPoseSAM", "mSA": 0.873}, {"method": "CellSAM", "mSA": 0.853}, - {"method": "SAM3", "mSA": 0.764}, + {"method": "SAM3", "mSA": 0.723}, ], "SegPC": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.027}, @@ -354,7 +357,7 @@ {"method": "CellPose3", "mSA": 0.001}, {"method": "CellPoseSAM", "mSA": 0.178}, {"method": "CellSAM", "mSA": 0.069}, - {"method": "SAM3", "mSA": 0.152}, + {"method": "SAM3", "mSA": 0.106}, ], } @@ -370,7 +373,7 @@ {"method": "CellPose3", "mSA": 0.152}, {"method": "CellPoseSAM", "mSA": 0.342}, {"method": "CellSAM", "mSA": 0.244}, - {"method": "SAM3", "mSA": 0.159}, + {"method": "SAM3", "mSA": 0.341}, ], "IHC TMA": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.236}, @@ -383,7 +386,7 @@ {"method": "CellPose3", "mSA": 0.297}, {"method": "CellPoseSAM", "mSA": 0.452}, {"method": "CellSAM", "mSA": 0.333}, - {"method": "SAM3", "mSA": 0.353}, + {"method": "SAM3", "mSA": 0.435}, ], "LynSec": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.233}, @@ -396,7 +399,7 @@ {"method": "CellPose3", "mSA": 0.163}, {"method": "CellPoseSAM", "mSA": 0.561}, {"method": "CellSAM", "mSA": 0.213}, - {"method": "SAM3", "mSA": 0.003}, + {"method": "SAM3", "mSA": 0.157}, ], "MoNuSeg": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.182}, @@ -409,7 +412,7 @@ {"method": "CellPose3", "mSA": 0.125}, {"method": "CellPoseSAM", "mSA": 0.373}, {"method": "CellSAM", "mSA": 0.302}, - {"method": "SAM3", "mSA": 0.207}, + {"method": "SAM3", "mSA": 0.345}, ], "NuInsSeg": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.161}, @@ -422,7 +425,7 @@ {"method": "CellPose3", "mSA": 0.144}, {"method": "CellPoseSAM", "mSA": 0.349}, {"method": "CellSAM", "mSA": 0.229}, - {"method": "SAM3", "mSA": 0.287}, + {"method": "SAM3", "mSA": 0.312}, ], "PUMA": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.232}, @@ -435,7 +438,7 @@ {"method": "CellPose3", "mSA": 0.101}, {"method": "CellPoseSAM", "mSA": 0.501}, {"method": "CellSAM", "mSA": 0.294}, - {"method": "SAM3", "mSA": 0.324}, + {"method": "SAM3", "mSA": 0.415}, ], "TNBC": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.209}, @@ -448,7 +451,7 @@ {"method": "CellPose3", "mSA": 0.075}, {"method": "CellPoseSAM", "mSA": 0.451}, {"method": "CellSAM", "mSA": 0.383}, - {"method": "SAM3", "mSA": 0.359}, + {"method": "SAM3", "mSA": 0.443}, ], "CryoNuSeg": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.165}, @@ -461,7 +464,7 @@ {"method": "CellPose3", "mSA": 0.113}, {"method": "CellPoseSAM", "mSA": 0.295}, {"method": "CellSAM", "mSA": 0.177}, - {"method": "SAM3", "mSA": 0.079}, + {"method": "SAM3", "mSA": 0.118}, ], "CytoDark0": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.182}, @@ -474,19 +477,19 @@ {"method": "CellPose3", "mSA": 0.222}, {"method": "CellPoseSAM", "mSA": 0.441}, {"method": "CellSAM", "mSA": 0.315}, - {"method": "SAM3", "mSA": 0.369}, + {"method": "SAM3", "mSA": 0.414}, ], } DISPLAY_NAME_MAP = { "AMG (vit_b) - without grid search": "AMG (SAM)", # "AMG - without grid search": r"AMG ($\mu$SAM)", - "AIS - without grid search": r"AIS ($\mu$SAM)", + "AIS - without grid search": "AIS (µSAM)", "CellPose3": "CellPose 3", "CellPoseSAM": "CellPoseSAM", "CellSAM": "CellSAM", "SAM3": "SAM3", - "APG - without grid search (cc)": r"$\mathbf{APG}$", + "APG - without grid search (cc)": r"$\mathbf{APG}$ " + r"$\mathbf{(µSAM)}$ ", # "APG - with grid search (cc)": r"$\mathbf{APG^{*}}$", } diff --git a/scripts/apg_experiments/prepare_baselines.py b/scripts/apg_experiments/prepare_baselines.py index 90271539..efeb2d92 100644 --- a/scripts/apg_experiments/prepare_baselines.py +++ b/scripts/apg_experiments/prepare_baselines.py @@ -4,6 +4,7 @@ import numpy as np import pandas as pd +from PIL import Image import imageio.v3 as imageio from elf.evaluation import mean_segmentation_accuracy, matching @@ -30,18 +31,15 @@ def run_baseline_engine(image, method, **kwargs): # Newer SAM methods. elif method == "sam3": # TODO: Wrap this out in a modular function too? - from PIL import Image - from sam3.model_builder import build_sam3_image_model - from sam3.model.sam3_image_processor import Sam3Processor - model = build_sam3_image_model() - processor = Sam3Processor(model) + processor = kwargs["processor"] + # Set the image to the processor inference_state = processor.set_image(Image.fromarray(_to_image(image))) # Prompt the model with text processor.reset_all_prompts(inference_state) segmentation = processor.set_text_prompt(state=inference_state, prompt=kwargs["prompt"]) segmentation = segmentation["masks"] # Get the masks only. - if len(segmentation) == 0: + if len(segmentation) == 0: # Handles when no objects are segmented. segmentation = np.zeros(image.shape[:2], dtype="uint32") else: # HACK: Let's get a cheap merging strategy segmentation = segmentation.squeeze(1).detach().cpu().numpy() @@ -72,9 +70,11 @@ def run_default_baselines(dataset_name, method, model_type, experiment_folder, t os.makedirs(res_folder, exist_ok=True) os.makedirs(inference_folder, exist_ok=True) - csv_path = os.path.join(experiment_folder, "results", f"{dataset_name}_{method}_{model_type}.csv") + fnext = (target if model_type == "sam3" else model_type) + csv_path = os.path.join(res_folder, f"{dataset_name}_{method}_{fnext}.csv") if os.path.exists(csv_path): - print(f"The results are computed and stored at {csv_path}") + print(pd.read_csv(csv_path)) + print(f"The results are computed and stored at '{csv_path}'.") return # Get the image and label paths. @@ -83,7 +83,7 @@ def run_default_baselines(dataset_name, method, model_type, experiment_folder, t assert isinstance(method, str) kwargs = {} if method in ["ais", "amg"]: - predictor, segmenter = get_predictor_and_segmenter(model_type=model_type, amg=(method == "amg")) + predictor, segmenter = get_predictor_and_segmenter(model_type=model_type, segmentation_mode="amg") kwargs["predictor"] = predictor kwargs["segmenter"] = segmenter elif method == "apg": @@ -99,6 +99,10 @@ def run_default_baselines(dataset_name, method, model_type, experiment_folder, t elif method == "sam2": kwargs["model_type"] = model_type elif method == "sam3": + from micro_sam3.util import get_sam3_model + from sam3.model.sam3_image_processor import Sam3Processor + model = get_sam3_model(input_type="image") + kwargs["processor"] = Sam3Processor(model) kwargs["prompt"] = target msas, sa50s, precisions, recalls, f1s = [], [], [], [], [] @@ -135,7 +139,8 @@ def run_default_baselines(dataset_name, method, model_type, experiment_folder, t } results = pd.DataFrame.from_dict([results]) results.to_csv(csv_path) - print(f"The results are stored at {csv_path}") + print(results) + print(f"The results above are stored at '{csv_path}'.") def main(args):