Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion micro_sam/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,9 @@ def batched_inference(
# Call get image embeddings, this will throw an error if they have not yet been computed.
predictor.get_image_embedding()
else:
input_ = image if i is None else image[i]
image_embeddings = util.precompute_image_embeddings(
predictor, image, embedding_path, verbose=verbose_embeddings, i=i,
predictor, input_, embedding_path, verbose=verbose_embeddings
)
util.set_precomputed(predictor, image_embeddings)

Expand Down
22 changes: 21 additions & 1 deletion scripts/apg_experiments/README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,23 @@
## Experiments for Automatic Prompt Generation (APG)

More explanation coming soon!
This folder contains evaluaton code for applying the new APG method, built on `micro-sam` to microscopy data using the `micro_sam` library. This code was used for our experiments in preparation of the [manuscript](https://openreview.net/forum?id=xFO3DFZN45).

Please note that this folder may become outdated due to changes in function signatures, etc., and often does not use the functionality that we recommend to users. We also will not actively maintain the code here. Please refer to the [example notebooks](https://github.com/computational-cell-analytics/micro-sam/tree/master/notebooks) and [example scripts](https://github.com/computational-cell-analytics/micro-sam/tree/master/examples) for well maintained and documented `micro-sam`'s APG examples.

### Evaluation Scripts:

The top-level folder contains scripts to evaluate other models with `micro-sam`, and the `plotting` subfolder contains scripts for visualizations and plots for the manuscript.

- `analyze_posthoc.py`: Experiments to visually understand and debug the APG method.
- `perform_posthoc.py`: Experiments to store results to work with `analyze_posthoc.py`.
- `prepare_baselines.py`: Experiments to run all methods presented in the manuscript with default parameters on all microscopy imaging data.
- `run_evaluation.py`: Experiments related to APG for understanding the hyperparameters using grid-search.
- `submit_evaluation.py`: Scripts for submitting jobs to slurm.
- `util,py`: Scripts containing data processing scripts and other miscellanous stuff.
- `plotting`/
- `calculate_mean.py`: Scripts to calculate mean performance metric over per modality per method.
- `plot_ablation.py`: Scripts to show results for comparing connected components- vs. boundary distance-based prompt extraction.
- `plot_average.py`: Same as `calculate_mean.py`. The scripts plot the mean values in a barplot and displays absolute rank per method over all datasets per modality.
- `plot_qualitative.py`: Scripts to display qualitative results over all datasets.
- `plot_quantitative.py`: Scripts to display quantitative results over all datasets.
- `plot_util.py`: Stores related information helpful for plotting.
75 changes: 75 additions & 0 deletions scripts/apg_experiments/plotting/calculate_mean.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import numpy as np

from plot_util import (
msa_results,
msa_results_fluorescence,
msa_results_label_free,
msa_results_histopathology,
)


AVG_METHODS = [
"AMG (vit_b) - without grid search",
"AIS - without grid search",
"SAM3",
"CellPose3",
"CellPoseSAM",
"CellSAM",
"APG - without grid search (cc)",
]

AVG_DISPLAY_NAME_MAP = {
"AMG (vit_b) - without grid search": "SAM",
"AIS - without grid search": "AIS (µSAM)",
"SAM3": "SAM3",
"CellPose3": "CellPose 3",
"CellPoseSAM": "CellPoseSAM",
"CellSAM": "CellSAM",
"APG - without grid search (cc)": "APG (µSAM)",
}


def compute_method_means(msa_results_dict, methods_filter=None):
vals_per_method = {}

for _, entries in msa_results_dict.items():
for e in entries:
m = e["method"]
v = e["mSA"]
if methods_filter is not None and m not in methods_filter:
continue
if v is None:
continue
vals_per_method.setdefault(m, []).append(float(v))

mean_msa = {m: float(np.mean(vs)) for m, vs in vals_per_method.items()}
return mean_msa


def format_float(v):
return f"{v:.3f}"


if __name__ == "__main__":
datasets = {
"Label-Free (Cell)": msa_results_label_free,
"Fluorescence (Cell)": msa_results_fluorescence,
"Fluorescence (Nucleus)": msa_results,
"Histopathology (Nucleus)": msa_results_histopathology,
}

means_by_modality = {}
for mod_name, data in datasets.items():
means_by_modality[mod_name] = compute_method_means(data, methods_filter=AVG_METHODS)

# Print a compact validation table
col_order = list(datasets.keys())
print("Averages (mSA) per modality:")
print("Method".ljust(22), *(c[:18].rjust(18) for c in col_order))
for m in AVG_METHODS:
disp = AVG_DISPLAY_NAME_MAP.get(m, m)
row = [disp.ljust(22)]
for mod in col_order:
v = means_by_modality[mod].get(m, np.nan)
row.append((format_float(v) if np.isfinite(v) else "--").rjust(18))
print("".join(row))
69 changes: 18 additions & 51 deletions scripts/apg_experiments/plotting/plot_ablation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

from util import (
from plot_util import (
msa_results,
msa_results_fluorescence,
msa_results_label_free,
Expand All @@ -18,10 +18,10 @@
APG_GS_CC = "APG - with grid search (cc)"

plt.rcParams.update({
"axes.titlesize": 10,
"axes.titlesize": 11,
"axes.labelsize": 9,
"xtick.labelsize": 8,
"ytick.labelsize": 8,
"xtick.labelsize": 11,
"ytick.labelsize": 10,
})


Expand Down Expand Up @@ -142,47 +142,18 @@ def max_improvement(d):
if max_val > 0:
ax.axhspan(0, max_val, color="#e0f3db", alpha=0.8, zorder=0)

bars_wo_bd = ax.bar(
x - 1.5 * width,
rel_wo_bd,
width,
color=color_wo_bd,
zorder=1,
)
bars_wo_cc = ax.bar(
x - 0.5 * width,
rel_wo_cc,
width,
color=color_wo_cc,
zorder=1,
)
bars_gs_bd = ax.bar(
x + 0.5 * width,
rel_gs_bd,
width,
color=color_gs_bd,
zorder=1,
)
bars_gs_cc = ax.bar(
x + 1.5 * width,
rel_gs_cc,
width,
color=color_gs_cc,
zorder=1,
)
bars_wo_bd = ax.bar(x - 1.5 * width, rel_wo_bd, width, color=color_wo_bd, zorder=1)
bars_wo_cc = ax.bar(x - 0.5 * width, rel_wo_cc, width, color=color_wo_cc, zorder=1)
bars_gs_bd = ax.bar(x + 0.5 * width, rel_gs_bd, width, color=color_gs_bd, zorder=1)
bars_gs_cc = ax.bar(x + 1.5 * width, rel_gs_cc, width, color=color_gs_cc, zorder=1)

best_mask_wo_bd = []
best_mask_wo_cc = []
best_mask_gs_bd = []
best_mask_gs_cc = []

for i in range(len(datasets)):
vals_i = [
rel_wo_bd[i],
rel_wo_cc[i],
rel_gs_bd[i],
rel_gs_cc[i],
]
vals_i = [rel_wo_bd[i], rel_wo_cc[i], rel_gs_bd[i], rel_gs_cc[i]]
max_i = max(vals_i)

best_mask_wo_bd.append(rel_wo_bd[i] == max_i)
Expand Down Expand Up @@ -216,7 +187,7 @@ def annotate_bars(bars, vals, best_mask):
f"{v:+.1f}%",
ha="center",
va=va,
fontsize=7,
fontsize=9,
rotation=90,
fontweight=fontweight,
)
Expand All @@ -234,31 +205,27 @@ def annotate_bars(bars, vals, best_mask):

fig.tight_layout(rect=[0.06, 0.18, 1, 0.97])
fig.text(
0.05, 0.55,
0.06, 0.575,
"Relative Mean Segmentation Accuracy (compared to AIS)",
va="center",
ha="center",
rotation="vertical",
fontsize=11,
fontsize=10,
fontweight="bold",
)

legend_patches = [
Patch(facecolor=color_wo_bd, edgecolor="black",
label="APG (Boundary Distance) - Default"),
Patch(facecolor=color_wo_cc, edgecolor="black",
label="APG (Components) - Default"),
Patch(facecolor=color_gs_bd, edgecolor="black",
label="APG (Boundary) - GS"),
Patch(facecolor=color_gs_cc, edgecolor="black",
label="APG (Components) - GS"),
Patch(facecolor=color_wo_bd, label="APG (Boundary Distance) - Default"),
Patch(facecolor=color_wo_cc, label="APG (Components) - Default"),
Patch(facecolor=color_gs_bd, label="APG (Boundary Distance) - Grid Search"),
Patch(facecolor=color_gs_cc, label="APG (Components) - Grid Search"),
]

fig.legend(
handles=legend_patches,
loc="lower center",
bbox_to_anchor=(0.5, 0.125),
ncol=4,
bbox_to_anchor=(0.5, 0.1),
ncol=2,
fontsize=8,
frameon=True,
)
Expand Down
23 changes: 12 additions & 11 deletions scripts/apg_experiments/plotting/plot_average.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import matplotlib.pyplot as plt

from util import (
from plot_util import (
msa_results,
msa_results_fluorescence,
msa_results_label_free,
Expand All @@ -23,23 +23,24 @@
APG_METHOD = "APG - without grid search (cc)"

AVG_DISPLAY_NAME_MAP = {
"AMG (vit_b) - without grid search": "SAM",
"AIS - without grid search": r"$\mu$SAM",
"AMG (vit_b) - without grid search": "AMG (SAM)",
"AIS - without grid search": "AIS (µSAM)",
"SAM3": "SAM3",
"CellPose3": "CellPose 3",
"CellPoseSAM": "CellPoseSAM",
"CellSAM": "CellSAM",
"APG - without grid search (cc)": "APG",
"APG - without grid search (cc)": "APG (µSAM)",
}

AVG_DISPLAY_NAME_MAP_HISTO = AVG_DISPLAY_NAME_MAP.copy()
AVG_DISPLAY_NAME_MAP_HISTO["AIS - without grid search"] = "PathoSAM"
AVG_DISPLAY_NAME_MAP_HISTO["AIS - without grid search"] = "AIS (PathoSAM)"
AVG_DISPLAY_NAME_MAP_HISTO["APG - without grid search (cc)"] = "APG \n (PathoSAM)"

plt.rcParams.update({
"axes.titlesize": 10,
"axes.labelsize": 9,
"xtick.labelsize": 8,
"ytick.labelsize": 8,
"axes.titlesize": 11,
"axes.labelsize": 10,
"xtick.labelsize": 10,
"ytick.labelsize": 10,
})


Expand Down Expand Up @@ -217,7 +218,7 @@ def filtered_methods_for_modality(mean_dict):

disp_names = [disp_map[m] for m in methods]
ax.set_xticks(x)
ax.set_xticklabels(disp_names, rotation=60, ha="right")
ax.set_xticklabels(disp_names, rotation=45, ha="right")

xticklabels = ax.get_xticklabels()
for idx_lbl, lbl in enumerate(xticklabels):
Expand All @@ -231,7 +232,7 @@ def filtered_methods_for_modality(mean_dict):
ax.set_ylim(0.0, 1.0)

fig.text(
0.06, 0.5,
0.05, 0.575,
"Mean Segmentation Accuracy (mSA)",
va="center",
ha="center",
Expand Down
Loading