Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions scripts/apg_experiments/plotting/calculate_mean.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import numpy as np

from plot_util import (
msa_results,
msa_results_fluorescence,
msa_results_label_free,
msa_results_histopathology,
)


AVG_METHODS = [
"AMG (vit_b) - without grid search",
"AIS - without grid search",
"SAM3",
"CellPose3",
"CellPoseSAM",
"CellSAM",
"APG - without grid search (cc)",
]

AVG_DISPLAY_NAME_MAP = {
"AMG (vit_b) - without grid search": "SAM",
"AIS - without grid search": "AIS (µSAM)",
"SAM3": "SAM3",
"CellPose3": "CellPose 3",
"CellPoseSAM": "CellPoseSAM",
"CellSAM": "CellSAM",
"APG - without grid search (cc)": "APG (µSAM)",
}


def compute_method_means(msa_results_dict, methods_filter=None):
vals_per_method = {}

for _, entries in msa_results_dict.items():
for e in entries:
m = e["method"]
v = e["mSA"]
if methods_filter is not None and m not in methods_filter:
continue
if v is None:
continue
vals_per_method.setdefault(m, []).append(float(v))

mean_msa = {m: float(np.mean(vs)) for m, vs in vals_per_method.items()}
return mean_msa


def format_float(v):
return f"{v:.3f}"


if __name__ == "__main__":
datasets = {
"Label-Free (Cell)": msa_results_label_free,
"Fluorescence (Cell)": msa_results_fluorescence,
"Fluorescence (Nucleus)": msa_results,
"Histopathology (Nucleus)": msa_results_histopathology,
}

means_by_modality = {}
for mod_name, data in datasets.items():
means_by_modality[mod_name] = compute_method_means(data, methods_filter=AVG_METHODS)

# Print a compact validation table
col_order = list(datasets.keys())
print("Averages (mSA) per modality:")
print("Method".ljust(22), *(c[:18].rjust(18) for c in col_order))
for m in AVG_METHODS:
disp = AVG_DISPLAY_NAME_MAP.get(m, m)
row = [disp.ljust(22)]
for mod in col_order:
v = means_by_modality[mod].get(m, np.nan)
row.append((format_float(v) if np.isfinite(v) else "--").rjust(18))
print("".join(row))
69 changes: 18 additions & 51 deletions scripts/apg_experiments/plotting/plot_ablation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

from util import (
from plot_util import (
msa_results,
msa_results_fluorescence,
msa_results_label_free,
Expand All @@ -18,10 +18,10 @@
APG_GS_CC = "APG - with grid search (cc)"

plt.rcParams.update({
"axes.titlesize": 10,
"axes.titlesize": 11,
"axes.labelsize": 9,
"xtick.labelsize": 8,
"ytick.labelsize": 8,
"xtick.labelsize": 11,
"ytick.labelsize": 10,
})


Expand Down Expand Up @@ -142,47 +142,18 @@ def max_improvement(d):
if max_val > 0:
ax.axhspan(0, max_val, color="#e0f3db", alpha=0.8, zorder=0)

bars_wo_bd = ax.bar(
x - 1.5 * width,
rel_wo_bd,
width,
color=color_wo_bd,
zorder=1,
)
bars_wo_cc = ax.bar(
x - 0.5 * width,
rel_wo_cc,
width,
color=color_wo_cc,
zorder=1,
)
bars_gs_bd = ax.bar(
x + 0.5 * width,
rel_gs_bd,
width,
color=color_gs_bd,
zorder=1,
)
bars_gs_cc = ax.bar(
x + 1.5 * width,
rel_gs_cc,
width,
color=color_gs_cc,
zorder=1,
)
bars_wo_bd = ax.bar(x - 1.5 * width, rel_wo_bd, width, color=color_wo_bd, zorder=1)
bars_wo_cc = ax.bar(x - 0.5 * width, rel_wo_cc, width, color=color_wo_cc, zorder=1)
bars_gs_bd = ax.bar(x + 0.5 * width, rel_gs_bd, width, color=color_gs_bd, zorder=1)
bars_gs_cc = ax.bar(x + 1.5 * width, rel_gs_cc, width, color=color_gs_cc, zorder=1)

best_mask_wo_bd = []
best_mask_wo_cc = []
best_mask_gs_bd = []
best_mask_gs_cc = []

for i in range(len(datasets)):
vals_i = [
rel_wo_bd[i],
rel_wo_cc[i],
rel_gs_bd[i],
rel_gs_cc[i],
]
vals_i = [rel_wo_bd[i], rel_wo_cc[i], rel_gs_bd[i], rel_gs_cc[i]]
max_i = max(vals_i)

best_mask_wo_bd.append(rel_wo_bd[i] == max_i)
Expand Down Expand Up @@ -216,7 +187,7 @@ def annotate_bars(bars, vals, best_mask):
f"{v:+.1f}%",
ha="center",
va=va,
fontsize=7,
fontsize=9,
rotation=90,
fontweight=fontweight,
)
Expand All @@ -234,31 +205,27 @@ def annotate_bars(bars, vals, best_mask):

fig.tight_layout(rect=[0.06, 0.18, 1, 0.97])
fig.text(
0.05, 0.55,
0.06, 0.575,
"Relative Mean Segmentation Accuracy (compared to AIS)",
va="center",
ha="center",
rotation="vertical",
fontsize=11,
fontsize=10,
fontweight="bold",
)

legend_patches = [
Patch(facecolor=color_wo_bd, edgecolor="black",
label="APG (Boundary Distance) - Default"),
Patch(facecolor=color_wo_cc, edgecolor="black",
label="APG (Components) - Default"),
Patch(facecolor=color_gs_bd, edgecolor="black",
label="APG (Boundary) - GS"),
Patch(facecolor=color_gs_cc, edgecolor="black",
label="APG (Components) - GS"),
Patch(facecolor=color_wo_bd, label="APG (Boundary Distance) - Default"),
Patch(facecolor=color_wo_cc, label="APG (Components) - Default"),
Patch(facecolor=color_gs_bd, label="APG (Boundary Distance) - Grid Search"),
Patch(facecolor=color_gs_cc, label="APG (Components) - Grid Search"),
]

fig.legend(
handles=legend_patches,
loc="lower center",
bbox_to_anchor=(0.5, 0.125),
ncol=4,
bbox_to_anchor=(0.5, 0.1),
ncol=2,
fontsize=8,
frameon=True,
)
Expand Down
23 changes: 12 additions & 11 deletions scripts/apg_experiments/plotting/plot_average.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import matplotlib.pyplot as plt

from util import (
from plot_util import (
msa_results,
msa_results_fluorescence,
msa_results_label_free,
Expand All @@ -23,23 +23,24 @@
APG_METHOD = "APG - without grid search (cc)"

AVG_DISPLAY_NAME_MAP = {
"AMG (vit_b) - without grid search": "SAM",
"AIS - without grid search": r"$\mu$SAM",
"AMG (vit_b) - without grid search": "AMG (SAM)",
"AIS - without grid search": "AIS (µSAM)",
"SAM3": "SAM3",
"CellPose3": "CellPose 3",
"CellPoseSAM": "CellPoseSAM",
"CellSAM": "CellSAM",
"APG - without grid search (cc)": "APG",
"APG - without grid search (cc)": "APG (µSAM)",
}

AVG_DISPLAY_NAME_MAP_HISTO = AVG_DISPLAY_NAME_MAP.copy()
AVG_DISPLAY_NAME_MAP_HISTO["AIS - without grid search"] = "PathoSAM"
AVG_DISPLAY_NAME_MAP_HISTO["AIS - without grid search"] = "AIS (PathoSAM)"
AVG_DISPLAY_NAME_MAP_HISTO["APG - without grid search (cc)"] = "APG \n (PathoSAM)"

plt.rcParams.update({
"axes.titlesize": 10,
"axes.labelsize": 9,
"xtick.labelsize": 8,
"ytick.labelsize": 8,
"axes.titlesize": 11,
"axes.labelsize": 10,
"xtick.labelsize": 10,
"ytick.labelsize": 10,
})


Expand Down Expand Up @@ -217,7 +218,7 @@ def filtered_methods_for_modality(mean_dict):

disp_names = [disp_map[m] for m in methods]
ax.set_xticks(x)
ax.set_xticklabels(disp_names, rotation=60, ha="right")
ax.set_xticklabels(disp_names, rotation=45, ha="right")

xticklabels = ax.get_xticklabels()
for idx_lbl, lbl in enumerate(xticklabels):
Expand All @@ -231,7 +232,7 @@ def filtered_methods_for_modality(mean_dict):
ax.set_ylim(0.0, 1.0)

fig.text(
0.06, 0.5,
0.05, 0.575,
"Mean Segmentation Accuracy (mSA)",
va="center",
ha="center",
Expand Down
131 changes: 131 additions & 0 deletions scripts/apg_experiments/plotting/plot_qualitative.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import os
import sys
from glob import glob
from tqdm import tqdm
from natsort import natsorted

import imageio.v3 as imageio
import matplotlib.pyplot as plt

from torch_em.util.util import get_random_colors

from elf.evaluation import mean_segmentation_accuracy

from micro_sam.evaluation.model_comparison import _overlay_outline
from micro_sam.automatic_segmentation import get_predictor_and_segmenter, automatic_instance_segmentation


sys.path.append("..")


ROOT = "/mnt/vast-nhr/projects/cidas/cca"


def plot_quali(dataset_name, model_type):
# Get APG result paths.
pred_paths = natsorted(
glob(
os.path.join(
ROOT, "experiments", "micro_sam",
"apg_baselines", "cc_without_box", "inference", f"{dataset_name}_apg_*",
"*"
)
)
)

# Get the image and label paths
from util import get_image_label_paths
image_paths, label_paths = get_image_label_paths(dataset_name=dataset_name, split="test")

# HACK: Sometimes there's just too many images. Just compare over first 500.
image_paths, label_paths, pred_paths = image_paths[:500], label_paths[:500], pred_paths[:500]

# Get the cellpose model to run CPSAM on the fly.
from cellpose import models
model = models.CellposeModel(gpu=True, model_type="cpsam")

# Let's iterate over each image, label and predictions.
results = []
for image_path, label_path, pred_path in tqdm(
zip(image_paths, label_paths, pred_paths), desc="Qualitative analysis", total=len(image_paths)
):
# Load everything first
image = imageio.imread(image_path)
label = imageio.imread(label_path)
pred = imageio.imread(pred_path)

assert label.shape == pred.shape

# Let's run CellPoseSAM to get the best relative results for APG.
seg, _, _ = model.eval(image)

# Compare both
apg_msa = mean_segmentation_accuracy(pred, label)
cpsam_msa = mean_segmentation_accuracy(seg, label)

# Next, find the relative score and store it for the dataset
diff = apg_msa - cpsam_msa
results.append((diff, image_path, label_path, pred_path, apg_msa, cpsam_msa))

# Let's fetch top-k where APG wins
k = 10
results.sort(key=lambda x: abs(x[0]), reverse=True)
top_k = results[:k]

# Prepare other model stuff
# SAM-related models
predictor_amg, segmenter_amg = get_predictor_and_segmenter(model_type="vit_b", segmentation_mode="amg")
predictor_ais, segmenter_ais = get_predictor_and_segmenter(model_type=model_type, segmentation_mode="ais")

# Plot each of the top-k images as a separate horizontal triplet
os.makedirs(f"./quali_figures/{dataset_name}", exist_ok=True)
for rank, (diff, image_path, label_path, pred_path, apg_msa, cpsam_msa) in enumerate(top_k, 1):
from micro_sam.util import _to_image
image = _to_image(imageio.imread(image_path))
gt = imageio.imread(label_path)

# APG prediction is read from disk, CPSAM is recomputed for plotting
amg_pred = automatic_instance_segmentation(
input_path=image, verbose=False, ndim=2, predictor=predictor_amg, segmenter=segmenter_amg,
)
ais_pred = automatic_instance_segmentation(
input_path=image, verbose=False, ndim=2, predictor=predictor_ais, segmenter=segmenter_ais,
)
apg_pred = imageio.imread(pred_path)
cpsam_pred, _, _ = model.eval(image)

from cellSAM import cellsam_pipeline
cellsam_pred = cellsam_pipeline(image, use_wsi=False)
if cellsam_pred.ndim == 3: # NOTE: For images with no objects found, a weird segmentation is returned.
cellsam_pred = cellsam_pred[0]

# Prepare image as expected (normalize and overlay labels)
curr_image = _overlay_outline(image, gt, outline_dilation=1)

fig, axes = plt.subplots(1, 6, figsize=(15, 5), constrained_layout=True)
for ax in axes:
ax.set_axis_off()

axes[0].imshow(curr_image[:256, :256])
axes[1].imshow(amg_pred[:256, :256], cmap=get_random_colors(amg_pred))
axes[2].imshow(ais_pred[:256, :256], cmap=get_random_colors(ais_pred))
axes[3].imshow(cpsam_pred[:256, :256], cmap=get_random_colors(cpsam_pred))
axes[4].imshow(cellsam_pred[:256, :256], cmap=get_random_colors(cellsam_pred))
axes[5].imshow(apg_pred[:256, :256], cmap=get_random_colors(apg_pred))

plt.savefig(f"./quali_figures/{dataset_name}/{rank}.png", dpi=400, bbox_inches="tight")
plt.savefig(f"./quali_figures/{dataset_name}/{rank}.svg", dpi=400, bbox_inches="tight")
plt.close()


def main(args):
plot_quali(args.dataset, args.model_type)


if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--dataset", type=str, required=True)
parser.add_argument("-m", "--model_type", type=str, default="vit_b_lm") # NOTE: For AIS
args = parser.parse_args()
main(args)
Loading