diff --git a/scripts/apg_experiments/plotting/calculate_mean.py b/scripts/apg_experiments/plotting/calculate_mean.py new file mode 100644 index 00000000..70a24e26 --- /dev/null +++ b/scripts/apg_experiments/plotting/calculate_mean.py @@ -0,0 +1,75 @@ +import numpy as np + +from plot_util import ( + msa_results, + msa_results_fluorescence, + msa_results_label_free, + msa_results_histopathology, +) + + +AVG_METHODS = [ + "AMG (vit_b) - without grid search", + "AIS - without grid search", + "SAM3", + "CellPose3", + "CellPoseSAM", + "CellSAM", + "APG - without grid search (cc)", +] + +AVG_DISPLAY_NAME_MAP = { + "AMG (vit_b) - without grid search": "SAM", + "AIS - without grid search": "AIS (µSAM)", + "SAM3": "SAM3", + "CellPose3": "CellPose 3", + "CellPoseSAM": "CellPoseSAM", + "CellSAM": "CellSAM", + "APG - without grid search (cc)": "APG (µSAM)", +} + + +def compute_method_means(msa_results_dict, methods_filter=None): + vals_per_method = {} + + for _, entries in msa_results_dict.items(): + for e in entries: + m = e["method"] + v = e["mSA"] + if methods_filter is not None and m not in methods_filter: + continue + if v is None: + continue + vals_per_method.setdefault(m, []).append(float(v)) + + mean_msa = {m: float(np.mean(vs)) for m, vs in vals_per_method.items()} + return mean_msa + + +def format_float(v): + return f"{v:.3f}" + + +if __name__ == "__main__": + datasets = { + "Label-Free (Cell)": msa_results_label_free, + "Fluorescence (Cell)": msa_results_fluorescence, + "Fluorescence (Nucleus)": msa_results, + "Histopathology (Nucleus)": msa_results_histopathology, + } + + means_by_modality = {} + for mod_name, data in datasets.items(): + means_by_modality[mod_name] = compute_method_means(data, methods_filter=AVG_METHODS) + + # Print a compact validation table + col_order = list(datasets.keys()) + print("Averages (mSA) per modality:") + print("Method".ljust(22), *(c[:18].rjust(18) for c in col_order)) + for m in AVG_METHODS: + disp = AVG_DISPLAY_NAME_MAP.get(m, m) + row = [disp.ljust(22)] + for mod in col_order: + v = means_by_modality[mod].get(m, np.nan) + row.append((format_float(v) if np.isfinite(v) else "--").rjust(18)) + print("".join(row)) diff --git a/scripts/apg_experiments/plotting/plot_ablation.py b/scripts/apg_experiments/plotting/plot_ablation.py index 3ce5060f..360eff4c 100644 --- a/scripts/apg_experiments/plotting/plot_ablation.py +++ b/scripts/apg_experiments/plotting/plot_ablation.py @@ -4,7 +4,7 @@ import matplotlib.pyplot as plt from matplotlib.patches import Patch -from util import ( +from plot_util import ( msa_results, msa_results_fluorescence, msa_results_label_free, @@ -18,10 +18,10 @@ APG_GS_CC = "APG - with grid search (cc)" plt.rcParams.update({ - "axes.titlesize": 10, + "axes.titlesize": 11, "axes.labelsize": 9, - "xtick.labelsize": 8, - "ytick.labelsize": 8, + "xtick.labelsize": 11, + "ytick.labelsize": 10, }) @@ -142,34 +142,10 @@ def max_improvement(d): if max_val > 0: ax.axhspan(0, max_val, color="#e0f3db", alpha=0.8, zorder=0) - bars_wo_bd = ax.bar( - x - 1.5 * width, - rel_wo_bd, - width, - color=color_wo_bd, - zorder=1, - ) - bars_wo_cc = ax.bar( - x - 0.5 * width, - rel_wo_cc, - width, - color=color_wo_cc, - zorder=1, - ) - bars_gs_bd = ax.bar( - x + 0.5 * width, - rel_gs_bd, - width, - color=color_gs_bd, - zorder=1, - ) - bars_gs_cc = ax.bar( - x + 1.5 * width, - rel_gs_cc, - width, - color=color_gs_cc, - zorder=1, - ) + bars_wo_bd = ax.bar(x - 1.5 * width, rel_wo_bd, width, color=color_wo_bd, zorder=1) + bars_wo_cc = ax.bar(x - 0.5 * width, rel_wo_cc, width, color=color_wo_cc, zorder=1) + bars_gs_bd = ax.bar(x + 0.5 * width, rel_gs_bd, width, color=color_gs_bd, zorder=1) + bars_gs_cc = ax.bar(x + 1.5 * width, rel_gs_cc, width, color=color_gs_cc, zorder=1) best_mask_wo_bd = [] best_mask_wo_cc = [] @@ -177,12 +153,7 @@ def max_improvement(d): best_mask_gs_cc = [] for i in range(len(datasets)): - vals_i = [ - rel_wo_bd[i], - rel_wo_cc[i], - rel_gs_bd[i], - rel_gs_cc[i], - ] + vals_i = [rel_wo_bd[i], rel_wo_cc[i], rel_gs_bd[i], rel_gs_cc[i]] max_i = max(vals_i) best_mask_wo_bd.append(rel_wo_bd[i] == max_i) @@ -216,7 +187,7 @@ def annotate_bars(bars, vals, best_mask): f"{v:+.1f}%", ha="center", va=va, - fontsize=7, + fontsize=9, rotation=90, fontweight=fontweight, ) @@ -234,31 +205,27 @@ def annotate_bars(bars, vals, best_mask): fig.tight_layout(rect=[0.06, 0.18, 1, 0.97]) fig.text( - 0.05, 0.55, + 0.06, 0.575, "Relative Mean Segmentation Accuracy (compared to AIS)", va="center", ha="center", rotation="vertical", - fontsize=11, + fontsize=10, fontweight="bold", ) legend_patches = [ - Patch(facecolor=color_wo_bd, edgecolor="black", - label="APG (Boundary Distance) - Default"), - Patch(facecolor=color_wo_cc, edgecolor="black", - label="APG (Components) - Default"), - Patch(facecolor=color_gs_bd, edgecolor="black", - label="APG (Boundary) - GS"), - Patch(facecolor=color_gs_cc, edgecolor="black", - label="APG (Components) - GS"), + Patch(facecolor=color_wo_bd, label="APG (Boundary Distance) - Default"), + Patch(facecolor=color_wo_cc, label="APG (Components) - Default"), + Patch(facecolor=color_gs_bd, label="APG (Boundary Distance) - Grid Search"), + Patch(facecolor=color_gs_cc, label="APG (Components) - Grid Search"), ] fig.legend( handles=legend_patches, loc="lower center", - bbox_to_anchor=(0.5, 0.125), - ncol=4, + bbox_to_anchor=(0.5, 0.1), + ncol=2, fontsize=8, frameon=True, ) diff --git a/scripts/apg_experiments/plotting/plot_average.py b/scripts/apg_experiments/plotting/plot_average.py index 150e054a..53c3c169 100644 --- a/scripts/apg_experiments/plotting/plot_average.py +++ b/scripts/apg_experiments/plotting/plot_average.py @@ -3,7 +3,7 @@ import numpy as np import matplotlib.pyplot as plt -from util import ( +from plot_util import ( msa_results, msa_results_fluorescence, msa_results_label_free, @@ -23,23 +23,24 @@ APG_METHOD = "APG - without grid search (cc)" AVG_DISPLAY_NAME_MAP = { - "AMG (vit_b) - without grid search": "SAM", - "AIS - without grid search": r"$\mu$SAM", + "AMG (vit_b) - without grid search": "AMG (SAM)", + "AIS - without grid search": "AIS (µSAM)", "SAM3": "SAM3", "CellPose3": "CellPose 3", "CellPoseSAM": "CellPoseSAM", "CellSAM": "CellSAM", - "APG - without grid search (cc)": "APG", + "APG - without grid search (cc)": "APG (µSAM)", } AVG_DISPLAY_NAME_MAP_HISTO = AVG_DISPLAY_NAME_MAP.copy() -AVG_DISPLAY_NAME_MAP_HISTO["AIS - without grid search"] = "PathoSAM" +AVG_DISPLAY_NAME_MAP_HISTO["AIS - without grid search"] = "AIS (PathoSAM)" +AVG_DISPLAY_NAME_MAP_HISTO["APG - without grid search (cc)"] = "APG \n (PathoSAM)" plt.rcParams.update({ - "axes.titlesize": 10, - "axes.labelsize": 9, - "xtick.labelsize": 8, - "ytick.labelsize": 8, + "axes.titlesize": 11, + "axes.labelsize": 10, + "xtick.labelsize": 10, + "ytick.labelsize": 10, }) @@ -217,7 +218,7 @@ def filtered_methods_for_modality(mean_dict): disp_names = [disp_map[m] for m in methods] ax.set_xticks(x) - ax.set_xticklabels(disp_names, rotation=60, ha="right") + ax.set_xticklabels(disp_names, rotation=45, ha="right") xticklabels = ax.get_xticklabels() for idx_lbl, lbl in enumerate(xticklabels): @@ -231,7 +232,7 @@ def filtered_methods_for_modality(mean_dict): ax.set_ylim(0.0, 1.0) fig.text( - 0.06, 0.5, + 0.05, 0.575, "Mean Segmentation Accuracy (mSA)", va="center", ha="center", diff --git a/scripts/apg_experiments/plotting/plot_qualitative.py b/scripts/apg_experiments/plotting/plot_qualitative.py new file mode 100644 index 00000000..cb9b5ba7 --- /dev/null +++ b/scripts/apg_experiments/plotting/plot_qualitative.py @@ -0,0 +1,131 @@ +import os +import sys +from glob import glob +from tqdm import tqdm +from natsort import natsorted + +import imageio.v3 as imageio +import matplotlib.pyplot as plt + +from torch_em.util.util import get_random_colors + +from elf.evaluation import mean_segmentation_accuracy + +from micro_sam.evaluation.model_comparison import _overlay_outline +from micro_sam.automatic_segmentation import get_predictor_and_segmenter, automatic_instance_segmentation + + +sys.path.append("..") + + +ROOT = "/mnt/vast-nhr/projects/cidas/cca" + + +def plot_quali(dataset_name, model_type): + # Get APG result paths. + pred_paths = natsorted( + glob( + os.path.join( + ROOT, "experiments", "micro_sam", + "apg_baselines", "cc_without_box", "inference", f"{dataset_name}_apg_*", + "*" + ) + ) + ) + + # Get the image and label paths + from util import get_image_label_paths + image_paths, label_paths = get_image_label_paths(dataset_name=dataset_name, split="test") + + # HACK: Sometimes there's just too many images. Just compare over first 500. + image_paths, label_paths, pred_paths = image_paths[:500], label_paths[:500], pred_paths[:500] + + # Get the cellpose model to run CPSAM on the fly. + from cellpose import models + model = models.CellposeModel(gpu=True, model_type="cpsam") + + # Let's iterate over each image, label and predictions. + results = [] + for image_path, label_path, pred_path in tqdm( + zip(image_paths, label_paths, pred_paths), desc="Qualitative analysis", total=len(image_paths) + ): + # Load everything first + image = imageio.imread(image_path) + label = imageio.imread(label_path) + pred = imageio.imread(pred_path) + + assert label.shape == pred.shape + + # Let's run CellPoseSAM to get the best relative results for APG. + seg, _, _ = model.eval(image) + + # Compare both + apg_msa = mean_segmentation_accuracy(pred, label) + cpsam_msa = mean_segmentation_accuracy(seg, label) + + # Next, find the relative score and store it for the dataset + diff = apg_msa - cpsam_msa + results.append((diff, image_path, label_path, pred_path, apg_msa, cpsam_msa)) + + # Let's fetch top-k where APG wins + k = 10 + results.sort(key=lambda x: abs(x[0]), reverse=True) + top_k = results[:k] + + # Prepare other model stuff + # SAM-related models + predictor_amg, segmenter_amg = get_predictor_and_segmenter(model_type="vit_b", segmentation_mode="amg") + predictor_ais, segmenter_ais = get_predictor_and_segmenter(model_type=model_type, segmentation_mode="ais") + + # Plot each of the top-k images as a separate horizontal triplet + os.makedirs(f"./quali_figures/{dataset_name}", exist_ok=True) + for rank, (diff, image_path, label_path, pred_path, apg_msa, cpsam_msa) in enumerate(top_k, 1): + from micro_sam.util import _to_image + image = _to_image(imageio.imread(image_path)) + gt = imageio.imread(label_path) + + # APG prediction is read from disk, CPSAM is recomputed for plotting + amg_pred = automatic_instance_segmentation( + input_path=image, verbose=False, ndim=2, predictor=predictor_amg, segmenter=segmenter_amg, + ) + ais_pred = automatic_instance_segmentation( + input_path=image, verbose=False, ndim=2, predictor=predictor_ais, segmenter=segmenter_ais, + ) + apg_pred = imageio.imread(pred_path) + cpsam_pred, _, _ = model.eval(image) + + from cellSAM import cellsam_pipeline + cellsam_pred = cellsam_pipeline(image, use_wsi=False) + if cellsam_pred.ndim == 3: # NOTE: For images with no objects found, a weird segmentation is returned. + cellsam_pred = cellsam_pred[0] + + # Prepare image as expected (normalize and overlay labels) + curr_image = _overlay_outline(image, gt, outline_dilation=1) + + fig, axes = plt.subplots(1, 6, figsize=(15, 5), constrained_layout=True) + for ax in axes: + ax.set_axis_off() + + axes[0].imshow(curr_image[:256, :256]) + axes[1].imshow(amg_pred[:256, :256], cmap=get_random_colors(amg_pred)) + axes[2].imshow(ais_pred[:256, :256], cmap=get_random_colors(ais_pred)) + axes[3].imshow(cpsam_pred[:256, :256], cmap=get_random_colors(cpsam_pred)) + axes[4].imshow(cellsam_pred[:256, :256], cmap=get_random_colors(cellsam_pred)) + axes[5].imshow(apg_pred[:256, :256], cmap=get_random_colors(apg_pred)) + + plt.savefig(f"./quali_figures/{dataset_name}/{rank}.png", dpi=400, bbox_inches="tight") + plt.savefig(f"./quali_figures/{dataset_name}/{rank}.svg", dpi=400, bbox_inches="tight") + plt.close() + + +def main(args): + plot_quali(args.dataset, args.model_type) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("-d", "--dataset", type=str, required=True) + parser.add_argument("-m", "--model_type", type=str, default="vit_b_lm") # NOTE: For AIS + args = parser.parse_args() + main(args) diff --git a/scripts/apg_experiments/plotting/plot_quantitative.py b/scripts/apg_experiments/plotting/plot_quantitative.py index 014ea083..934b5f32 100644 --- a/scripts/apg_experiments/plotting/plot_quantitative.py +++ b/scripts/apg_experiments/plotting/plot_quantitative.py @@ -3,7 +3,7 @@ import numpy as np import matplotlib.pyplot as plt -from util import ( +from plot_util import ( msa_results, msa_results_fluorescence, msa_results_label_free, @@ -16,12 +16,14 @@ DISPLAY_NAME_MAP_HISTO["AMG - without grid search"] = "AMG (PathoSAM)" if "AIS - without grid search" in DISPLAY_NAME_MAP_HISTO: DISPLAY_NAME_MAP_HISTO["AIS - without grid search"] = "AIS (PathoSAM)" +if "APG - without grid search (cc)" in DISPLAY_NAME_MAP_HISTO: + DISPLAY_NAME_MAP_HISTO["APG - without grid search (cc)"] = r"$\mathbf{APG}$ " + "\n" + r"$\mathbf{(PathoSAM)}$ " plt.rcParams.update({ - "axes.titlesize": 10, - "axes.labelsize": 9, - "xtick.labelsize": 8, - "ytick.labelsize": 8, + "axes.titlesize": 15, + "axes.labelsize": 12, + "xtick.labelsize": 11, + "ytick.labelsize": 11, }) @@ -40,13 +42,12 @@ def plot_msa_grid( fig, axes = plt.subplots( nrows=nrows, ncols=ncols, - figsize=(figsize_per_subplot[0] * ncols, - figsize_per_subplot[1] * nrows), + figsize=(figsize_per_subplot[0] * ncols, figsize_per_subplot[1] * nrows), sharey=True, ) axes = np.array(axes).reshape(-1) - color_top1 = "#1f77b4" # darkest + color_top1 = "#1f77b4" color_top2 = "#6baed6" color_top3 = "#c6dbef" color_rest = "#d9d9d9" @@ -96,8 +97,7 @@ def plot_msa_grid( if IN_DOMAIN.get((dataset, raw_name), False): bar.set_hatch("////") - apg_indices = [i for i, name in enumerate(raw_methods) - if name.startswith("APG")] + apg_indices = [i for i, name in enumerate(raw_methods) if name.startswith("APG")] ais_idx = None for i, name in enumerate(raw_methods): @@ -119,14 +119,7 @@ def plot_msa_grid( bar_height = v y_text = min(bar_height + 0.01, 0.98) - ax.text( - x[i], - y_text, - f"{v:.3f}", - ha="center", - va="bottom", - fontsize=8, - ) + ax.text(x[i], y_text, f"{v:.3f}", ha="center", va="bottom", fontsize=9) if ais_value is not None: for i in apg_indices: @@ -145,44 +138,30 @@ def plot_msa_grid( y_text = min(bar_height + 0.02, 0.99) ax.text( - x[i], - y_text, - f"{diff_pct:+.1f}%", - ha="center", - va="bottom", - fontsize=8, - fontweight="bold", - color=color, + x[i], y_text, f"{diff_pct:+.1f}%", ha="center", + va="bottom", fontsize=10, fontweight="bold", color=color, ) ax.set_title(dataset, fontweight="bold") ax.set_xticks(x) - ax.set_xticklabels(methods, rotation=60, ha="right") + ax.set_xticklabels(methods, rotation=45, ha="right") ax.set_ylim(0.0, 1.0) for j in range(len(datasets), len(axes)): fig.delaxes(axes[j]) - if suptitle is not None: - fig.tight_layout(rect=[0.08, 0, 1, 0.95]) - else: - fig.tight_layout(rect=[0.08, 0, 1, 1]) - fig.text( - 0.075, 0.5, - "Mean Segmentation Accuracy", - va="center", - ha="center", - rotation="vertical", - fontsize=11, - fontweight="bold", + 0.02, 0.525, "Mean Segmentation Accuracy", va="center", ha="center", + rotation="vertical", fontsize=14, fontweight="bold", ) + fig.subplots_adjust(left=0.075, right=0.98, bottom=0.125, top=0.95, wspace=0.1, hspace=0.9) + if save_path is not None: - fig.savefig(save_path, bbox_inches="tight", dpi=300) + fig.savefig(save_path, dpi=300) root, _ = os.path.splitext(save_path) svg_path = root + ".svg" - fig.savefig(svg_path, bbox_inches="tight") + fig.savefig(svg_path) return fig, axes diff --git a/scripts/apg_experiments/plotting/util.py b/scripts/apg_experiments/plotting/plot_util.py similarity index 94% rename from scripts/apg_experiments/plotting/util.py rename to scripts/apg_experiments/plotting/plot_util.py index 7df7d0e4..1bf71c51 100644 --- a/scripts/apg_experiments/plotting/util.py +++ b/scripts/apg_experiments/plotting/plot_util.py @@ -1,3 +1,6 @@ +# NOTE: All SAM3 prompts are the same: "cell" (matching with the training routine) + + msa_results = { "DSB": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.331}, @@ -10,7 +13,7 @@ {"method": "CellPose3", "mSA": 0.484}, {"method": "CellPoseSAM", "mSA": 0.656}, {"method": "CellSAM", "mSA": 0.634}, - {"method": "SAM3", "mSA": 0.383}, + {"method": "SAM3", "mSA": 0.367}, ], "U20S": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.258}, @@ -23,7 +26,7 @@ {"method": "CellPose3", "mSA": 0.787}, {"method": "CellPoseSAM", "mSA": 0.787}, {"method": "CellSAM", "mSA": 0.673}, - {"method": "SAM3", "mSA": 0.636}, + {"method": "SAM3", "mSA": 0.674}, ], "Arvidsson": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.416}, @@ -36,7 +39,7 @@ {"method": "CellPose3", "mSA": 0.611}, {"method": "CellPoseSAM", "mSA": 0.484}, {"method": "CellSAM", "mSA": 0.434}, - {"method": "SAM3", "mSA": 0.488}, + {"method": "SAM3", "mSA": 0.297}, ], "BitDepth NucSeg": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.224}, @@ -49,7 +52,7 @@ {"method": "CellPose3", "mSA": 0.302}, {"method": "CellPoseSAM", "mSA": 0.377}, {"method": "CellSAM", "mSA": 0.168}, - {"method": "SAM3", "mSA": 0.201}, + {"method": "SAM3", "mSA": 0.182}, ], "IFNuclei": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.293}, @@ -62,7 +65,7 @@ {"method": "CellPose3", "mSA": 0.404}, {"method": "CellPoseSAM", "mSA": 0.728}, {"method": "CellSAM", "mSA": 0.589}, - {"method": "SAM3", "mSA": 0.405}, + {"method": "SAM3", "mSA": 0.301}, ], "DynamicNuclearNet": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.298}, @@ -75,7 +78,7 @@ {"method": "CellPose3", "mSA": 0.512}, {"method": "CellPoseSAM", "mSA": 0.379}, {"method": "CellSAM", "mSA": 0.455}, - {"method": "SAM3", "mSA": 0.376}, + {"method": "SAM3", "mSA": 0.346}, ], "GoNuclear": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.339}, @@ -88,7 +91,7 @@ {"method": "CellPose3", "mSA": 0.447}, {"method": "CellPoseSAM", "mSA": 0.415}, {"method": "CellSAM", "mSA": 0.112}, - {"method": "SAM3", "mSA": 0.132}, + {"method": "SAM3", "mSA": 0.034}, ], "NIS3D": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.216}, @@ -101,7 +104,7 @@ {"method": "CellPose3", "mSA": 0.255}, {"method": "CellPoseSAM", "mSA": 0.246}, {"method": "CellSAM", "mSA": 0.264}, - {"method": "SAM3", "mSA": 0.008}, + {"method": "SAM3", "mSA": 0.031}, ], "Parhyale Regen": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.173}, @@ -114,7 +117,7 @@ {"method": "CellPose3", "mSA": 0.138}, {"method": "CellPoseSAM", "mSA": 0.272}, {"method": "CellSAM", "mSA": 0.144}, - {"method": "SAM3", "mSA": 0.039}, + {"method": "SAM3", "mSA": 0.063}, ], } @@ -130,7 +133,7 @@ {"method": "CellPose3", "mSA": 0.154}, {"method": "CellPoseSAM", "mSA": 0.475}, {"method": "CellSAM", "mSA": 0.345}, - {"method": "SAM3", "mSA": 0.089}, + {"method": "SAM3", "mSA": 0.121}, ], "CellPose": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.147}, @@ -143,7 +146,7 @@ {"method": "CellPose3", "mSA": 0.431}, {"method": "CellPoseSAM", "mSA": 0.566}, {"method": "CellSAM", "mSA": 0.413}, - {"method": "SAM3", "mSA": 0.229}, + {"method": "SAM3", "mSA": 0.299}, ], "PlantSeg (Root)": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.091}, @@ -156,7 +159,7 @@ {"method": "CellPose3", "mSA": 0.076}, {"method": "CellPoseSAM", "mSA": 0.161}, {"method": "CellSAM", "mSA": 0.096}, - {"method": "SAM3", "mSA": 0.023}, + {"method": "SAM3", "mSA": 0.067}, ], "PlantSeg (Ovules)": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.135}, @@ -169,7 +172,7 @@ {"method": "CellPose3", "mSA": 0.266}, {"method": "CellPoseSAM", "mSA": 0.331}, {"method": "CellSAM", "mSA": 0.333}, - {"method": "SAM3", "mSA": 0.071}, + {"method": "SAM3", "mSA": 0.184}, ], "PNAS Arabidopsis": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.145}, @@ -182,7 +185,7 @@ {"method": "CellPose3", "mSA": 0.411}, {"method": "CellPoseSAM", "mSA": 0.471}, {"method": "CellSAM", "mSA": 0.459}, - {"method": "SAM3", "mSA": 0.073}, + {"method": "SAM3", "mSA": 0.241}, ], "Covid-IF": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.007}, @@ -195,7 +198,7 @@ {"method": "CellPose3", "mSA": 0.161}, {"method": "CellPoseSAM", "mSA": 0.333}, {"method": "CellSAM", "mSA": 0.154}, - {"method": "SAM3", "mSA": 0.004}, + {"method": "SAM3", "mSA": 0.005}, ], "HPA": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.043}, @@ -208,7 +211,7 @@ {"method": "CellPose3", "mSA": 0.078}, {"method": "CellPoseSAM", "mSA": 0.431}, {"method": "CellSAM", "mSA": 0.301}, - {"method": "SAM3", "mSA": 0.152}, + {"method": "SAM3", "mSA": 0.155}, ], "CellBinDB": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.177}, @@ -221,7 +224,7 @@ {"method": "CellPose3", "mSA": 0.279}, {"method": "CellPoseSAM", "mSA": 0.342}, {"method": "CellSAM", "mSA": 0.264}, - {"method": "SAM3", "mSA": 0.136}, + {"method": "SAM3", "mSA": 0.137}, ], "Mouse Embryo": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.003}, @@ -234,7 +237,7 @@ {"method": "CellPose3", "mSA": 0.109}, {"method": "CellPoseSAM", "mSA": 0.155}, {"method": "CellSAM", "mSA": 0.083}, - {"method": "SAM3", "mSA": 0.073}, + {"method": "SAM3", "mSA": 0.081}, ], } @@ -250,7 +253,7 @@ {"method": "CellPose3", "mSA": 0.414}, {"method": "CellPoseSAM", "mSA": 0.444}, {"method": "CellSAM", "mSA": 0.098}, - {"method": "SAM3", "mSA": 0.313}, + {"method": "SAM3", "mSA": 0.331}, ], "OmniPose": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.137}, @@ -263,7 +266,7 @@ {"method": "CellPose3", "mSA": 0.468}, {"method": "CellPoseSAM", "mSA": 0.644}, {"method": "CellSAM", "mSA": 0.531}, - {"method": "SAM3", "mSA": 0.131}, + {"method": "SAM3", "mSA": 0.356}, ], "DeepBacs": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.057}, @@ -276,7 +279,7 @@ {"method": "CellPose3", "mSA": 0.455}, {"method": "CellPoseSAM", "mSA": 0.612}, {"method": "CellSAM", "mSA": 0.441}, - {"method": "SAM3", "mSA": 0.083}, + {"method": "SAM3", "mSA": 0.157}, ], "Usiigaci": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.051}, @@ -289,7 +292,7 @@ {"method": "CellPose3", "mSA": 0.291}, {"method": "CellPoseSAM", "mSA": 0.445}, {"method": "CellSAM", "mSA": 0.167}, - {"method": "SAM3", "mSA": 0.377}, + {"method": "SAM3", "mSA": 0.362}, ], "Vicar": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.115}, @@ -302,7 +305,7 @@ {"method": "CellPose3", "mSA": 0.338}, {"method": "CellPoseSAM", "mSA": 0.458}, {"method": "CellSAM", "mSA": 0.426}, - {"method": "SAM3", "mSA": 0.122}, + {"method": "SAM3", "mSA": 0.086}, ], "TOIAM": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.009}, @@ -315,7 +318,7 @@ {"method": "CellPose3", "mSA": 0.837}, {"method": "CellPoseSAM", "mSA": 0.898}, {"method": "CellSAM", "mSA": 0.631}, - {"method": "SAM3", "mSA": 0.0}, + {"method": "SAM3", "mSA": 0.027}, ], "DeepSeas": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.098}, @@ -328,7 +331,7 @@ {"method": "CellPose3", "mSA": 0.191}, {"method": "CellPoseSAM", "mSA": 0.345}, {"method": "CellSAM", "mSA": 0.203}, - {"method": "SAM3", "mSA": 0.227}, + {"method": "SAM3", "mSA": 0.277}, ], "YeaZ": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.382}, @@ -341,7 +344,7 @@ {"method": "CellPose3", "mSA": 0.817}, {"method": "CellPoseSAM", "mSA": 0.873}, {"method": "CellSAM", "mSA": 0.853}, - {"method": "SAM3", "mSA": 0.764}, + {"method": "SAM3", "mSA": 0.723}, ], "SegPC": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.027}, @@ -354,7 +357,7 @@ {"method": "CellPose3", "mSA": 0.001}, {"method": "CellPoseSAM", "mSA": 0.178}, {"method": "CellSAM", "mSA": 0.069}, - {"method": "SAM3", "mSA": 0.152}, + {"method": "SAM3", "mSA": 0.106}, ], } @@ -370,7 +373,7 @@ {"method": "CellPose3", "mSA": 0.152}, {"method": "CellPoseSAM", "mSA": 0.342}, {"method": "CellSAM", "mSA": 0.244}, - {"method": "SAM3", "mSA": 0.159}, + {"method": "SAM3", "mSA": 0.341}, ], "IHC TMA": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.236}, @@ -383,7 +386,7 @@ {"method": "CellPose3", "mSA": 0.297}, {"method": "CellPoseSAM", "mSA": 0.452}, {"method": "CellSAM", "mSA": 0.333}, - {"method": "SAM3", "mSA": 0.353}, + {"method": "SAM3", "mSA": 0.435}, ], "LynSec": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.233}, @@ -396,7 +399,7 @@ {"method": "CellPose3", "mSA": 0.163}, {"method": "CellPoseSAM", "mSA": 0.561}, {"method": "CellSAM", "mSA": 0.213}, - {"method": "SAM3", "mSA": 0.003}, + {"method": "SAM3", "mSA": 0.157}, ], "MoNuSeg": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.182}, @@ -409,7 +412,7 @@ {"method": "CellPose3", "mSA": 0.125}, {"method": "CellPoseSAM", "mSA": 0.373}, {"method": "CellSAM", "mSA": 0.302}, - {"method": "SAM3", "mSA": 0.207}, + {"method": "SAM3", "mSA": 0.345}, ], "NuInsSeg": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.161}, @@ -422,7 +425,7 @@ {"method": "CellPose3", "mSA": 0.144}, {"method": "CellPoseSAM", "mSA": 0.349}, {"method": "CellSAM", "mSA": 0.229}, - {"method": "SAM3", "mSA": 0.287}, + {"method": "SAM3", "mSA": 0.312}, ], "PUMA": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.232}, @@ -435,7 +438,7 @@ {"method": "CellPose3", "mSA": 0.101}, {"method": "CellPoseSAM", "mSA": 0.501}, {"method": "CellSAM", "mSA": 0.294}, - {"method": "SAM3", "mSA": 0.324}, + {"method": "SAM3", "mSA": 0.415}, ], "TNBC": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.209}, @@ -448,7 +451,7 @@ {"method": "CellPose3", "mSA": 0.075}, {"method": "CellPoseSAM", "mSA": 0.451}, {"method": "CellSAM", "mSA": 0.383}, - {"method": "SAM3", "mSA": 0.359}, + {"method": "SAM3", "mSA": 0.443}, ], "CryoNuSeg": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.165}, @@ -461,7 +464,7 @@ {"method": "CellPose3", "mSA": 0.113}, {"method": "CellPoseSAM", "mSA": 0.295}, {"method": "CellSAM", "mSA": 0.177}, - {"method": "SAM3", "mSA": 0.079}, + {"method": "SAM3", "mSA": 0.118}, ], "CytoDark0": [ {"method": "AMG (vit_b) - without grid search", "mSA": 0.182}, @@ -474,19 +477,19 @@ {"method": "CellPose3", "mSA": 0.222}, {"method": "CellPoseSAM", "mSA": 0.441}, {"method": "CellSAM", "mSA": 0.315}, - {"method": "SAM3", "mSA": 0.369}, + {"method": "SAM3", "mSA": 0.414}, ], } DISPLAY_NAME_MAP = { "AMG (vit_b) - without grid search": "AMG (SAM)", # "AMG - without grid search": r"AMG ($\mu$SAM)", - "AIS - without grid search": r"AIS ($\mu$SAM)", + "AIS - without grid search": "AIS (µSAM)", "CellPose3": "CellPose 3", "CellPoseSAM": "CellPoseSAM", "CellSAM": "CellSAM", "SAM3": "SAM3", - "APG - without grid search (cc)": r"$\mathbf{APG}$", + "APG - without grid search (cc)": r"$\mathbf{APG}$ " + r"$\mathbf{(µSAM)}$ ", # "APG - with grid search (cc)": r"$\mathbf{APG^{*}}$", } diff --git a/scripts/apg_experiments/prepare_baselines.py b/scripts/apg_experiments/prepare_baselines.py index 90271539..efeb2d92 100644 --- a/scripts/apg_experiments/prepare_baselines.py +++ b/scripts/apg_experiments/prepare_baselines.py @@ -4,6 +4,7 @@ import numpy as np import pandas as pd +from PIL import Image import imageio.v3 as imageio from elf.evaluation import mean_segmentation_accuracy, matching @@ -30,18 +31,15 @@ def run_baseline_engine(image, method, **kwargs): # Newer SAM methods. elif method == "sam3": # TODO: Wrap this out in a modular function too? - from PIL import Image - from sam3.model_builder import build_sam3_image_model - from sam3.model.sam3_image_processor import Sam3Processor - model = build_sam3_image_model() - processor = Sam3Processor(model) + processor = kwargs["processor"] + # Set the image to the processor inference_state = processor.set_image(Image.fromarray(_to_image(image))) # Prompt the model with text processor.reset_all_prompts(inference_state) segmentation = processor.set_text_prompt(state=inference_state, prompt=kwargs["prompt"]) segmentation = segmentation["masks"] # Get the masks only. - if len(segmentation) == 0: + if len(segmentation) == 0: # Handles when no objects are segmented. segmentation = np.zeros(image.shape[:2], dtype="uint32") else: # HACK: Let's get a cheap merging strategy segmentation = segmentation.squeeze(1).detach().cpu().numpy() @@ -72,9 +70,11 @@ def run_default_baselines(dataset_name, method, model_type, experiment_folder, t os.makedirs(res_folder, exist_ok=True) os.makedirs(inference_folder, exist_ok=True) - csv_path = os.path.join(experiment_folder, "results", f"{dataset_name}_{method}_{model_type}.csv") + fnext = (target if model_type == "sam3" else model_type) + csv_path = os.path.join(res_folder, f"{dataset_name}_{method}_{fnext}.csv") if os.path.exists(csv_path): - print(f"The results are computed and stored at {csv_path}") + print(pd.read_csv(csv_path)) + print(f"The results are computed and stored at '{csv_path}'.") return # Get the image and label paths. @@ -83,7 +83,7 @@ def run_default_baselines(dataset_name, method, model_type, experiment_folder, t assert isinstance(method, str) kwargs = {} if method in ["ais", "amg"]: - predictor, segmenter = get_predictor_and_segmenter(model_type=model_type, amg=(method == "amg")) + predictor, segmenter = get_predictor_and_segmenter(model_type=model_type, segmentation_mode="amg") kwargs["predictor"] = predictor kwargs["segmenter"] = segmenter elif method == "apg": @@ -99,6 +99,10 @@ def run_default_baselines(dataset_name, method, model_type, experiment_folder, t elif method == "sam2": kwargs["model_type"] = model_type elif method == "sam3": + from micro_sam3.util import get_sam3_model + from sam3.model.sam3_image_processor import Sam3Processor + model = get_sam3_model(input_type="image") + kwargs["processor"] = Sam3Processor(model) kwargs["prompt"] = target msas, sa50s, precisions, recalls, f1s = [], [], [], [], [] @@ -135,7 +139,8 @@ def run_default_baselines(dataset_name, method, model_type, experiment_folder, t } results = pd.DataFrame.from_dict([results]) results.to_csv(csv_path) - print(f"The results are stored at {csv_path}") + print(results) + print(f"The results above are stored at '{csv_path}'.") def main(args):