diff --git a/README.md b/README.md index 40ffa64..e5db941 100644 --- a/README.md +++ b/README.md @@ -95,3 +95,46 @@ agentevac-study \ ``` This runs a grid search over information noise, delay, and trust parameters and fits results against a reference metrics file. + +## Plotting Completed Runs + +Install the plotting dependency: + +```bash +pip install -e .[plot] +``` + +Generate figures for the latest run: + +```bash +python3 scripts/plot_all_run_artifacts.py +``` + +Generate figures for a specific run ID: + +```bash +python3 scripts/plot_all_run_artifacts.py --run-id 20260309_030340 +``` + +Useful individual plotting commands: + +```bash +# 2x2 dashboard for one run_metrics_*.json +python3 scripts/plot_run_metrics.py --metrics outputs/run_metrics_20260309_030340.json + +# Departures, messages, system observations, and route changes over time +python3 scripts/plot_departure_timeline.py \ + --events outputs/events_20260309_030340.jsonl \ + --replay outputs/llm_routes_20260309_030340.jsonl + +# Messaging and dialog activity +python3 scripts/plot_agent_communication.py \ + --events outputs/events_20260309_030340.jsonl \ + --dialogs outputs/llm_routes_20260309_030340.dialogs.csv + +# Compare multiple completed runs or sweep outputs +python3 scripts/plot_experiment_comparison.py \ + --results-json outputs/experiments/experiment_results.json +``` + +By default, plots are saved under `outputs/figures/` or next to the selected input file. diff --git a/pyproject.toml b/pyproject.toml index 7e7ebda..a2f80d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,9 @@ dev = [ "mkdocs-material", "build", ] +plot = [ + "matplotlib>=3.8", +] [project.scripts] # Calibration / sweep tools expose a proper main() and work as CLI scripts. diff --git a/scripts/_plot_common.py b/scripts/_plot_common.py new file mode 100644 index 0000000..9f837ae --- /dev/null +++ b/scripts/_plot_common.py @@ -0,0 +1,106 @@ +"""Shared helpers for plotting completed simulation artifacts.""" + +from __future__ import annotations + +import json +import os +from pathlib import Path +from typing import Any, Iterable, List + + +def newest_file(pattern: str) -> Path: + """Return the newest file matching ``pattern``. + + Raises: + FileNotFoundError: If no matching files exist. + """ + matches = sorted(Path().glob(pattern), key=lambda p: p.stat().st_mtime, reverse=True) + if not matches: + raise FileNotFoundError(f"No files match pattern: {pattern}") + return matches[0] + + +def resolve_input(path_arg: str | None, pattern: str) -> Path: + """Resolve an explicit input path or fall back to the newest matching file.""" + if path_arg: + path = Path(path_arg) + if not path.exists(): + raise FileNotFoundError(f"Input file does not exist: {path}") + return path + return newest_file(pattern) + + +def load_json(path: Path) -> Any: + """Load a JSON document from ``path``.""" + with path.open("r", encoding="utf-8") as fh: + return json.load(fh) + + +def load_jsonl(path: Path) -> List[dict[str, Any]]: + """Load JSON Lines from ``path`` into a list of dicts.""" + rows: List[dict[str, Any]] = [] + with path.open("r", encoding="utf-8") as fh: + for line in fh: + text = line.strip() + if not text: + continue + rows.append(json.loads(text)) + return rows + + +def ensure_output_path( + input_path: Path, + output_arg: str | None, + *, + suffix: str, +) -> Path: + """Resolve output path and ensure its parent directory exists.""" + if output_arg: + out = Path(output_arg) + else: + out = input_path.with_suffix("") + out = out.with_name(f"{out.name}.{suffix}.png") + out.parent.mkdir(parents=True, exist_ok=True) + return out + + +def top_items(mapping: dict[str, float], limit: int) -> list[tuple[str, float]]: + """Return up to ``limit`` items sorted by descending value then key.""" + items = sorted(mapping.items(), key=lambda item: (-item[1], item[0])) + return items[: max(1, int(limit))] + + +def bin_counts( + times_s: Iterable[float], + *, + bin_s: float, +) -> list[tuple[float, int]]: + """Bin event times into fixed-width buckets. + + Returns: + List of ``(bin_start_s, count)`` tuples in ascending order. + """ + counts: dict[float, int] = {} + width = max(float(bin_s), 1e-9) + for t in times_s: + bucket = width * int(float(t) // width) + counts[bucket] = counts.get(bucket, 0) + 1 + return sorted(counts.items(), key=lambda item: item[0]) + + +def require_matplotlib(): + """Import matplotlib lazily with a useful error message.""" + # Constrain thread-hungry numeric backends before importing matplotlib/numpy. + os.environ.setdefault("MPLBACKEND", "Agg") + os.environ.setdefault("OMP_NUM_THREADS", "1") + os.environ.setdefault("OPENBLAS_NUM_THREADS", "1") + os.environ.setdefault("MKL_NUM_THREADS", "1") + os.environ.setdefault("NUMEXPR_NUM_THREADS", "1") + try: + import matplotlib.pyplot as plt + except ImportError as exc: + raise SystemExit( + "matplotlib is required for plotting. Install it with " + "`pip install -e .[plot]` or `pip install matplotlib`." + ) from exc + return plt diff --git a/scripts/plot_agent_communication.py b/scripts/plot_agent_communication.py new file mode 100644 index 0000000..6474639 --- /dev/null +++ b/scripts/plot_agent_communication.py @@ -0,0 +1,230 @@ +#!/usr/bin/env python3 +"""Visualize agent-to-agent messaging and LLM dialog volume for one run.""" + +from __future__ import annotations + +import argparse +import csv +from pathlib import Path +from typing import Any + +try: + from scripts._plot_common import ensure_output_path, load_jsonl, require_matplotlib, resolve_input, top_items +except ModuleNotFoundError: + from _plot_common import ensure_output_path, load_jsonl, require_matplotlib, resolve_input, top_items + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Visualize messaging and dialog activity from events JSONL and dialogs CSV." + ) + parser.add_argument( + "--events", + help="Path to an events_*.jsonl file. Defaults to the newest outputs/events_*.jsonl.", + ) + parser.add_argument( + "--dialogs", + help="Path to a *.dialogs.csv file. Defaults to the newest outputs/*.dialogs.csv.", + ) + parser.add_argument( + "--out", + help="Output PNG path. Defaults to .communication.png.", + ) + parser.add_argument( + "--show", + action="store_true", + help="Open the figure window in addition to saving the PNG.", + ) + parser.add_argument( + "--top-n", + type=int, + default=15, + help="Maximum number of bars to draw in sender/recipient charts (default: 15).", + ) + return parser.parse_args() + + +def _load_dialog_rows(path: Path) -> list[dict[str, str]]: + with path.open("r", encoding="utf-8", newline="") as fh: + return list(csv.DictReader(fh)) + + +def _draw_bar(ax, items: list[tuple[str, float]], title: str, ylabel: str, color: str) -> None: + if not items: + ax.text(0.5, 0.5, "No data", ha="center", va="center", fontsize=11) + ax.set_title(title) + ax.set_axis_off() + return + labels = [k for k, _ in items] + values = [v for _, v in items] + ax.bar(range(len(values)), values, color=color) + ax.set_xticks(range(len(labels))) + ax.set_xticklabels(labels, rotation=60, ha="right", fontsize=8) + ax.set_title(title) + ax.set_ylabel(ylabel) + + +def _round_value(rec: dict[str, Any]) -> int | None: + for key in ("delivery_round", "deliver_round", "sent_round", "round"): + value = rec.get(key) + if value is None: + continue + try: + return int(value) + except (TypeError, ValueError): + continue + return None + + +def _plot_round_series(ax, event_rows: list[dict[str, Any]]) -> None: + series = { + "queued": {}, + "delivered": {}, + "llm": {}, + "predeparture": {}, + } + for rec in event_rows: + event = rec.get("event") + round_idx = _round_value(rec) + if round_idx is None: + continue + if event == "message_queued": + series["queued"][round_idx] = series["queued"].get(round_idx, 0) + 1 + elif event == "message_delivered": + series["delivered"][round_idx] = series["delivered"].get(round_idx, 0) + 1 + elif event == "llm_decision": + series["llm"][round_idx] = series["llm"].get(round_idx, 0) + 1 + elif event == "predeparture_llm_decision": + series["predeparture"][round_idx] = series["predeparture"].get(round_idx, 0) + 1 + + plotted = False + colors = { + "queued": "#4C78A8", + "delivered": "#54A24B", + "llm": "#F58518", + "predeparture": "#E45756", + } + for name, mapping in series.items(): + if not mapping: + continue + xs = sorted(mapping.keys()) + ys = [mapping[x] for x in xs] + ax.plot(xs, ys, marker="o", linewidth=1.8, label=name, color=colors[name]) + plotted = True + + if not plotted: + ax.text(0.5, 0.5, "No data", ha="center", va="center", fontsize=11) + ax.set_axis_off() + return + ax.set_title("Message and Decision Volume by Round") + ax.set_xlabel("Decision Round") + ax.set_ylabel("Event Count") + ax.legend() + + +def _plot_dialog_modes(ax, dialog_rows: list[dict[str, str]]) -> None: + counts: dict[str, int] = {} + response_lengths: dict[str, list[int]] = {} + for row in dialog_rows: + mode = str(row.get("control_mode") or "unknown") + counts[mode] = counts.get(mode, 0) + 1 + response_text = row.get("response_text") or "" + response_lengths.setdefault(mode, []).append(len(response_text)) + + labels = sorted(counts.keys()) + if not labels: + ax.text(0.5, 0.5, "No data", ha="center", va="center", fontsize=11) + ax.set_axis_off() + return + + xs = list(range(len(labels))) + count_vals = [counts[label] for label in labels] + avg_lens = [ + (sum(response_lengths[label]) / float(len(response_lengths[label]))) + if response_lengths[label] else 0.0 + for label in labels + ] + + ax.bar(xs, count_vals, color="#72B7B2", label="dialogs") + ax.set_xticks(xs) + ax.set_xticklabels(labels, rotation=20, ha="right") + ax.set_title("Dialog Volume and Avg Response Length") + ax.set_ylabel("Dialog Count") + + ax2 = ax.twinx() + ax2.plot(xs, avg_lens, color="#B279A2", marker="o", linewidth=1.8, label="avg response chars") + ax2.set_ylabel("Average Response Length (chars)") + + +def plot_agent_communication( + *, + events_path: Path, + dialogs_path: Path, + out_path: Path, + show: bool, + top_n: int, +) -> None: + plt = require_matplotlib() + event_rows = load_jsonl(events_path) + dialog_rows = _load_dialog_rows(dialogs_path) + + sender_counts: dict[str, int] = {} + recipient_counts: dict[str, int] = {} + for rec in event_rows: + event = rec.get("event") + if event == "message_queued": + sender = str(rec.get("from_id") or "unknown") + sender_counts[sender] = sender_counts.get(sender, 0) + 1 + elif event == "message_delivered": + recipient = str(rec.get("to_id") or "unknown") + recipient_counts[recipient] = recipient_counts.get(recipient, 0) + 1 + + fig, axes = plt.subplots(2, 2, figsize=(14, 10)) + fig.suptitle( + f"AgentEvac Communication Analysis\n{events_path.name} | {dialogs_path.name}", + fontsize=14, + ) + + _draw_bar( + axes[0, 0], + top_items({k: float(v) for k, v in sender_counts.items()}, top_n), + f"Top Message Senders (top {top_n})", + "Queued Messages", + "#4C78A8", + ) + _draw_bar( + axes[0, 1], + top_items({k: float(v) for k, v in recipient_counts.items()}, top_n), + f"Top Message Recipients (top {top_n})", + "Delivered Messages", + "#54A24B", + ) + _plot_round_series(axes[1, 0], event_rows) + _plot_dialog_modes(axes[1, 1], dialog_rows) + + fig.tight_layout(rect=(0, 0, 1, 0.95)) + fig.savefig(out_path, dpi=160, bbox_inches="tight") + print(f"[PLOT] events={events_path}") + print(f"[PLOT] dialogs={dialogs_path}") + print(f"[PLOT] output={out_path}") + if show: + plt.show() + plt.close(fig) + + +def main() -> None: + args = _parse_args() + events_path = resolve_input(args.events, "outputs/events_*.jsonl") + dialogs_path = resolve_input(args.dialogs, "outputs/*.dialogs.csv") + out_path = ensure_output_path(events_path, args.out, suffix="communication") + plot_agent_communication( + events_path=events_path, + dialogs_path=dialogs_path, + out_path=out_path, + show=args.show, + top_n=args.top_n, + ) + + +if __name__ == "__main__": + main() diff --git a/scripts/plot_all_run_artifacts.py b/scripts/plot_all_run_artifacts.py new file mode 100644 index 0000000..68ad3ad --- /dev/null +++ b/scripts/plot_all_run_artifacts.py @@ -0,0 +1,167 @@ +#!/usr/bin/env python3 +"""Generate all standard figures for one completed AgentEvac run.""" + +from __future__ import annotations + +import argparse +import re +from pathlib import Path + +try: + from scripts._plot_common import newest_file + from scripts.plot_agent_communication import plot_agent_communication + from scripts.plot_departure_timeline import plot_timeline + from scripts.plot_experiment_comparison import load_cases, plot_experiment_comparison + from scripts.plot_run_metrics import plot_metrics_dashboard +except ModuleNotFoundError: + from _plot_common import newest_file + from plot_agent_communication import plot_agent_communication + from plot_departure_timeline import plot_timeline + from plot_experiment_comparison import load_cases, plot_experiment_comparison + from plot_run_metrics import plot_metrics_dashboard + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Generate the standard dashboard, timeline, comparison, and communication plots for one run." + ) + parser.add_argument("--run-id", help="Timestamp token such as 20260309_030340.") + parser.add_argument("--metrics", help="Explicit run_metrics JSON path.") + parser.add_argument("--events", help="Explicit events JSONL path.") + parser.add_argument("--replay", help="Explicit llm_routes JSONL path.") + parser.add_argument("--dialogs", help="Explicit dialogs CSV path.") + parser.add_argument( + "--results-json", + help="Optional experiment_results.json to also generate the multi-run comparison figure.", + ) + parser.add_argument( + "--out-dir", + help="Output directory. Defaults to outputs/figures//.", + ) + parser.add_argument("--show", action="store_true", help="Show figures interactively as they are generated.") + parser.add_argument("--top-n", type=int, default=15, help="Top-N bars for agent-level charts.") + parser.add_argument("--bin-s", type=float, default=30.0, help="Time-bin width in seconds for timeline counts.") + return parser.parse_args() + + +def _maybe_path(path_arg: str | None) -> Path | None: + if not path_arg: + return None + path = Path(path_arg) + if not path.exists(): + raise SystemExit(f"Input file does not exist: {path}") + return path + + +def _resolve_run_id(args: argparse.Namespace) -> str: + if args.run_id: + return str(args.run_id) + for path_arg in (args.events, args.metrics, args.replay, args.dialogs): + if path_arg: + match = re.search(r"(\d{8}_\d{6})", Path(path_arg).name) + if match: + return match.group(1) + newest = newest_file("outputs/events_*.jsonl") + stem = newest.stem + return stem.replace("events_", "", 1) + + +def _resolve_paths(args: argparse.Namespace, run_id: str) -> dict[str, Path | None]: + metrics = _maybe_path(args.metrics) + events = _maybe_path(args.events) + replay = _maybe_path(args.replay) + dialogs = _maybe_path(args.dialogs) + + if metrics is None: + candidate = Path(f"outputs/run_metrics_{run_id}.json") + metrics = candidate if candidate.exists() else newest_file("outputs/run_metrics_*.json") + if events is None: + candidate = Path(f"outputs/events_{run_id}.jsonl") + events = candidate if candidate.exists() else newest_file("outputs/events_*.jsonl") + if replay is None: + candidate = Path(f"outputs/llm_routes_{run_id}.jsonl") + replay = candidate if candidate.exists() else None + if dialogs is None: + candidate = Path(f"outputs/llm_routes_{run_id}.dialogs.csv") + dialogs = candidate if candidate.exists() else newest_file("outputs/*.dialogs.csv") + + return { + "metrics": metrics, + "events": events, + "replay": replay, + "dialogs": dialogs, + } + + +def main() -> None: + args = _parse_args() + run_id = _resolve_run_id(args) + paths = _resolve_paths(args, run_id) + + out_dir = Path(args.out_dir) if args.out_dir else Path("outputs/figures") / run_id + out_dir.mkdir(parents=True, exist_ok=True) + + metrics_path = paths["metrics"] + events_path = paths["events"] + replay_path = paths["replay"] + dialogs_path = paths["dialogs"] + assert metrics_path is not None + assert events_path is not None + assert dialogs_path is not None + + plot_metrics_dashboard( + metrics_path, + out_path=out_dir / "run_metrics.dashboard.png", + show=args.show, + top_n=args.top_n, + ) + plot_timeline( + events_path, + replay_path=replay_path, + out_path=out_dir / "run_timeline.png", + show=args.show, + bin_s=args.bin_s, + ) + plot_agent_communication( + events_path=events_path, + dialogs_path=dialogs_path, + out_path=out_dir / "agent_communication.png", + show=args.show, + top_n=args.top_n, + ) + comparison_source: Path | None = None + if args.results_json: + results_path = Path(args.results_json) + if not results_path.exists(): + raise SystemExit(f"Results JSON does not exist: {results_path}") + comparison_rows, comparison_source = load_cases(results_path, "outputs/run_metrics_*.json") + plot_experiment_comparison( + comparison_rows, + source_path=comparison_source, + out_path=out_dir / "experiment_comparison.png", + show=args.show, + ) + else: + metrics_matches = sorted(Path().glob("outputs/run_metrics_*.json")) + if len(metrics_matches) > 1: + comparison_rows, comparison_source = load_cases(None, "outputs/run_metrics_*.json") + plot_experiment_comparison( + comparison_rows, + source_path=comparison_source, + out_path=out_dir / "experiment_comparison.png", + show=args.show, + ) + + print(f"[PLOT] run_id={run_id}") + print(f"[PLOT] figures_dir={out_dir}") + print(f"[PLOT] metrics={metrics_path}") + print(f"[PLOT] events={events_path}") + if replay_path: + print(f"[PLOT] replay={replay_path}") + print(f"[PLOT] dialogs={dialogs_path}") + if comparison_source: + print(f"[PLOT] comparison_source={comparison_source}") + + +if __name__ == "__main__": + main() diff --git a/scripts/plot_departure_timeline.py b/scripts/plot_departure_timeline.py new file mode 100644 index 0000000..5e8ab67 --- /dev/null +++ b/scripts/plot_departure_timeline.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python3 +"""Plot departure and communication timelines from completed simulation logs.""" + +from __future__ import annotations + +import argparse +from pathlib import Path + +try: + from scripts._plot_common import ( + bin_counts, + ensure_output_path, + load_jsonl, + require_matplotlib, + resolve_input, + ) +except ModuleNotFoundError: + from _plot_common import ( + bin_counts, + ensure_output_path, + load_jsonl, + require_matplotlib, + resolve_input, + ) + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Visualize departures, messages, and route changes over time." + ) + parser.add_argument( + "--events", + help="Path to an events_*.jsonl file. Defaults to the newest outputs/events_*.jsonl.", + ) + parser.add_argument( + "--replay", + help="Optional llm_routes_*.jsonl replay log for route-change counts.", + ) + parser.add_argument( + "--out", + help="Output PNG path. Defaults to .timeline.png.", + ) + parser.add_argument( + "--show", + action="store_true", + help="Open the figure window in addition to saving the PNG.", + ) + parser.add_argument( + "--bin-s", + type=float, + default=30.0, + help="Time-bin width in seconds for event counts (default: 30).", + ) + return parser.parse_args() + + +def _extract_times(rows: list[dict], event_type: str) -> list[float]: + out = [] + for rec in rows: + if rec.get("event") == event_type and rec.get("time_s") is not None: + out.append(float(rec["time_s"])) + return sorted(out) + + +def _plot_cumulative(ax, times: list[float], title: str, color: str) -> None: + if not times: + ax.text(0.5, 0.5, "No data", ha="center", va="center", fontsize=11) + ax.set_title(title) + ax.set_axis_off() + return + y = list(range(1, len(times) + 1)) + ax.step(times, y, where="post", color=color, linewidth=2) + ax.scatter(times, y, color=color, s=16) + ax.set_title(title) + ax.set_xlabel("Simulation Time (s)") + ax.set_ylabel("Cumulative Count") + + +def _plot_binned(ax, series: list[tuple[str, list[float], str]], *, bin_s: float) -> None: + plotted = False + for label, times, color in series: + binned = bin_counts(times, bin_s=bin_s) + if not binned: + continue + xs = [x for x, _ in binned] + ys = [y for _, y in binned] + ax.plot(xs, ys, marker="o", linewidth=1.8, label=label, color=color) + plotted = True + if not plotted: + ax.text(0.5, 0.5, "No data", ha="center", va="center", fontsize=11) + ax.set_axis_off() + return + ax.set_title(f"Event Volume per {int(bin_s) if float(bin_s).is_integer() else bin_s}s Bin") + ax.set_xlabel("Simulation Time (s)") + ax.set_ylabel("Event Count") + ax.legend() + + +def plot_timeline(events_path: Path, *, replay_path: Path | None, out_path: Path, show: bool, bin_s: float) -> None: + plt = require_matplotlib() + event_rows = load_jsonl(events_path) + replay_rows = load_jsonl(replay_path) if replay_path else [] + + departure_times = _extract_times(event_rows, "departure_release") + message_times = _extract_times(event_rows, "message_delivered") + _extract_times(event_rows, "message_queued") + observation_times = _extract_times(event_rows, "system_observation_generated") + llm_times = _extract_times(event_rows, "llm_decision") + _extract_times(event_rows, "predeparture_llm_decision") + route_change_times = _extract_times(replay_rows, "route_change") + + fig, axes = plt.subplots(2, 1, figsize=(14, 9)) + fig.suptitle( + f"AgentEvac Timeline\n{events_path.name}" + (f" | replay={replay_path.name}" if replay_path else ""), + fontsize=14, + ) + + _plot_cumulative(axes[0], departure_times, "Cumulative Departures", "#E45756") + _plot_binned( + axes[1], + [ + ("Messages", sorted(message_times), "#4C78A8"), + ("System observations", sorted(observation_times), "#54A24B"), + ("LLM decisions", sorted(llm_times), "#F58518"), + ("Route changes", sorted(route_change_times), "#B279A2"), + ], + bin_s=bin_s, + ) + + fig.tight_layout(rect=(0, 0, 1, 0.95)) + fig.savefig(out_path, dpi=160, bbox_inches="tight") + print(f"[PLOT] events={events_path}") + if replay_path: + print(f"[PLOT] replay={replay_path}") + print(f"[PLOT] output={out_path}") + if show: + plt.show() + plt.close(fig) + + +def main() -> None: + args = _parse_args() + events_path = resolve_input(args.events, "outputs/events_*.jsonl") + replay_path = Path(args.replay) if args.replay else None + if replay_path and not replay_path.exists(): + raise SystemExit(f"Replay file does not exist: {replay_path}") + out_path = ensure_output_path(events_path, args.out, suffix="timeline") + plot_timeline(events_path, replay_path=replay_path, out_path=out_path, show=args.show, bin_s=args.bin_s) + + +if __name__ == "__main__": + main() diff --git a/scripts/plot_experiment_comparison.py b/scripts/plot_experiment_comparison.py new file mode 100644 index 0000000..e63ba50 --- /dev/null +++ b/scripts/plot_experiment_comparison.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python3 +"""Compare multiple completed runs from an experiment sweep or metrics glob.""" + +from __future__ import annotations + +import argparse +from pathlib import Path +from typing import Any + +try: + from scripts._plot_common import ensure_output_path, load_json, require_matplotlib +except ModuleNotFoundError: + from _plot_common import ensure_output_path, load_json, require_matplotlib + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Compare multiple AgentEvac runs from experiment_results.json or a metrics glob." + ) + parser.add_argument( + "--results-json", + help="Path to experiment_results.json from agentevac.analysis.experiments.", + ) + parser.add_argument( + "--metrics-glob", + default="outputs/run_metrics_*.json", + help="Glob of metrics JSON files used if --results-json is omitted " + "(default: outputs/run_metrics_*.json).", + ) + parser.add_argument( + "--out", + help="Output PNG path. Defaults to .comparison.png or outputs/metrics_comparison.png.", + ) + parser.add_argument( + "--show", + action="store_true", + help="Open the figure window in addition to saving the PNG.", + ) + return parser.parse_args() + + +def _safe_float(value: Any, default: float = 0.0) -> float: + try: + return float(value) + except (TypeError, ValueError): + return default + + +def _metrics_row(metrics: dict[str, Any]) -> dict[str, float]: + return { + "departure_variability": _safe_float(metrics.get("departure_time_variability")), + "route_entropy": _safe_float(metrics.get("route_choice_entropy")), + "hazard_exposure": _safe_float(metrics.get("average_hazard_exposure", {}).get("global_average")), + "avg_travel_time": _safe_float(metrics.get("average_travel_time", {}).get("average")), + "arrived_agents": _safe_float(metrics.get("arrived_agents")), + "departed_agents": _safe_float(metrics.get("departed_agents")), + } + + +def load_cases(results_json: Path | None, metrics_glob: str) -> tuple[list[dict[str, Any]], Path]: + rows: list[dict[str, Any]] = [] + if results_json is not None: + payload = load_json(results_json) + if not isinstance(payload, list): + raise SystemExit(f"Expected a list in {results_json}") + for item in payload: + metrics_path = item.get("metrics_path") + if not metrics_path: + continue + path = Path(str(metrics_path)) + if not path.exists(): + continue + metrics = load_json(path) + case = item.get("case") or {} + row = { + "label": str(item.get("case_id") or path.stem), + "scenario": str(case.get("scenario", "unknown")), + "info_sigma": _safe_float(case.get("info_sigma")), + "info_delay_s": _safe_float(case.get("info_delay_s")), + "theta_trust": _safe_float(case.get("theta_trust")), + "metrics_path": str(path), + } + row.update(_metrics_row(metrics)) + rows.append(row) + return rows, results_json + + matches = sorted(Path().glob(metrics_glob)) + if not matches: + raise SystemExit(f"No metrics files match pattern: {metrics_glob}") + for path in matches: + metrics = load_json(path) + row = { + "label": path.stem, + "scenario": "unknown", + "info_sigma": 0.0, + "info_delay_s": 0.0, + "theta_trust": 0.0, + "metrics_path": str(path), + } + row.update(_metrics_row(metrics)) + rows.append(row) + return rows, matches[-1] + + +def _scatter_by_scenario(ax, rows: list[dict[str, Any]]) -> None: + scenario_colors = { + "no_notice": "#E45756", + "alert_guided": "#F58518", + "advice_guided": "#4C78A8", + "unknown": "#777777", + } + seen = set() + for row in rows: + scenario = str(row.get("scenario", "unknown")) + label = scenario if scenario not in seen else None + seen.add(scenario) + size = max(30.0, 20.0 + 20.0 * row.get("theta_trust", 0.0)) + ax.scatter( + row["hazard_exposure"], + row["avg_travel_time"], + s=size, + color=scenario_colors.get(scenario, "#777777"), + alpha=0.85, + label=label, + ) + ax.set_title("Hazard Exposure vs Travel Time") + ax.set_xlabel("Global Hazard Exposure") + ax.set_ylabel("Average Travel Time (s)") + if seen: + ax.legend() + + +def _line_vs_sigma(ax, rows: list[dict[str, Any]]) -> None: + by_scenario: dict[str, list[dict[str, Any]]] = {} + for row in rows: + by_scenario.setdefault(str(row.get("scenario", "unknown")), []).append(row) + if not by_scenario: + ax.text(0.5, 0.5, "No data", ha="center", va="center") + ax.set_axis_off() + return + for scenario, scenario_rows in sorted(by_scenario.items()): + ordered = sorted(scenario_rows, key=lambda item: item.get("info_sigma", 0.0)) + xs = [r.get("info_sigma", 0.0) for r in ordered] + ys = [r.get("route_entropy", 0.0) for r in ordered] + ax.plot(xs, ys, marker="o", linewidth=1.8, label=scenario) + ax.set_title("Route Entropy vs Info Sigma") + ax.set_xlabel("INFO_SIGMA") + ax.set_ylabel("Route Choice Entropy") + ax.legend() + + +def _bar_mean_by_scenario(ax, rows: list[dict[str, Any]], field: str, title: str, ylabel: str, color: str) -> None: + groups: dict[str, list[float]] = {} + for row in rows: + groups.setdefault(str(row.get("scenario", "unknown")), []).append(float(row.get(field, 0.0))) + labels = sorted(groups.keys()) + if not labels: + ax.text(0.5, 0.5, "No data", ha="center", va="center") + ax.set_axis_off() + return + means = [sum(groups[label]) / float(len(groups[label])) for label in labels] + ax.bar(range(len(labels)), means, color=color) + ax.set_xticks(range(len(labels))) + ax.set_xticklabels(labels, rotation=20, ha="right") + ax.set_title(title) + ax.set_ylabel(ylabel) + + +def plot_experiment_comparison(rows: list[dict[str, Any]], *, source_path: Path, out_path: Path, show: bool) -> None: + plt = require_matplotlib() + fig, axes = plt.subplots(2, 2, figsize=(14, 10)) + fig.suptitle( + f"AgentEvac Experiment Comparison\n{source_path.name} | runs={len(rows)}", + fontsize=14, + ) + + _scatter_by_scenario(axes[0, 0], rows) + _line_vs_sigma(axes[0, 1], rows) + _bar_mean_by_scenario( + axes[1, 0], + rows, + field="avg_travel_time", + title="Mean Travel Time by Scenario", + ylabel="Average Travel Time (s)", + color="#4C78A8", + ) + _bar_mean_by_scenario( + axes[1, 1], + rows, + field="hazard_exposure", + title="Mean Hazard Exposure by Scenario", + ylabel="Global Hazard Exposure", + color="#E45756", + ) + + fig.tight_layout(rect=(0, 0, 1, 0.95)) + fig.savefig(out_path, dpi=160, bbox_inches="tight") + print(f"[PLOT] source={source_path}") + print(f"[PLOT] output={out_path}") + if show: + plt.show() + plt.close(fig) + + +def main() -> None: + args = _parse_args() + results_path = Path(args.results_json) if args.results_json else None + if results_path and not results_path.exists(): + raise SystemExit(f"Results JSON does not exist: {results_path}") + rows, source_path = load_cases(results_path, args.metrics_glob) + out_path = ensure_output_path(source_path, args.out, suffix="comparison") + plot_experiment_comparison(rows, source_path=source_path, out_path=out_path, show=args.show) + + +if __name__ == "__main__": + main() diff --git a/scripts/plot_run_metrics.py b/scripts/plot_run_metrics.py new file mode 100644 index 0000000..d09638e --- /dev/null +++ b/scripts/plot_run_metrics.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python3 +"""Plot a compact dashboard for one completed simulation metrics JSON.""" + +from __future__ import annotations + +import argparse +from pathlib import Path + +try: + from scripts._plot_common import ensure_output_path, load_json, require_matplotlib, resolve_input, top_items +except ModuleNotFoundError: + from _plot_common import ensure_output_path, load_json, require_matplotlib, resolve_input, top_items + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Visualize one run_metrics_*.json file as a 2x2 dashboard." + ) + parser.add_argument( + "--metrics", + help="Path to a metrics JSON file. Defaults to the newest outputs/run_metrics_*.json.", + ) + parser.add_argument( + "--out", + help="Output PNG path. Defaults to .dashboard.png.", + ) + parser.add_argument( + "--show", + action="store_true", + help="Open the figure window in addition to saving the PNG.", + ) + parser.add_argument( + "--top-n", + type=int, + default=20, + help="Maximum number of per-agent bars to draw in each panel (default: 20).", + ) + return parser.parse_args() + + +def _draw_or_empty(ax, items: list[tuple[str, float]], title: str, ylabel: str, color: str, *, highest_first: bool = True): + if not items: + ax.text(0.5, 0.5, "No data", ha="center", va="center", fontsize=11) + ax.set_title(title) + ax.set_axis_off() + return + labels = [k for k, _ in items] + values = [v for _, v in items] + if not highest_first: + labels = list(reversed(labels)) + values = list(reversed(values)) + ax.bar(range(len(values)), values, color=color) + ax.set_xticks(range(len(labels))) + ax.set_xticklabels(labels, rotation=60, ha="right", fontsize=8) + ax.set_title(title) + ax.set_ylabel(ylabel) + + +def plot_metrics_dashboard(metrics_path: Path, *, out_path: Path, show: bool, top_n: int) -> None: + plt = require_matplotlib() + metrics = load_json(metrics_path) + + kpis = { + "Departure variance": float(metrics.get("departure_time_variability", 0.0)), + "Route entropy": float(metrics.get("route_choice_entropy", 0.0)), + "Hazard exposure": float(metrics.get("average_hazard_exposure", {}).get("global_average", 0.0)), + "Avg travel time": float(metrics.get("average_travel_time", {}).get("average", 0.0)), + } + exposure = metrics.get("average_hazard_exposure", {}).get("per_agent_average", {}) or {} + travel = metrics.get("average_travel_time", {}).get("per_agent", {}) or {} + instability = metrics.get("decision_instability", {}).get("per_agent_changes", {}) or {} + + fig, axes = plt.subplots(2, 2, figsize=(14, 10)) + fig.suptitle( + f"AgentEvac Run Metrics\n{metrics_path.name} | mode={metrics.get('run_mode', 'unknown')} " + f"| departed={metrics.get('departed_agents', 0)} | arrived={metrics.get('arrived_agents', 0)}", + fontsize=14, + ) + + axes[0, 0].bar(range(len(kpis)), list(kpis.values()), color=["#4C78A8", "#F58518", "#E45756", "#54A24B"]) + axes[0, 0].set_xticks(range(len(kpis))) + axes[0, 0].set_xticklabels(list(kpis.keys()), rotation=20, ha="right") + axes[0, 0].set_title("Run KPI Summary") + axes[0, 0].set_ylabel("Value") + + _draw_or_empty( + axes[0, 1], + top_items(travel, top_n), + f"Per-Agent Travel Time (top {top_n})", + "Seconds", + "#4C78A8", + ) + _draw_or_empty( + axes[1, 0], + top_items(exposure, top_n), + f"Per-Agent Hazard Exposure (top {top_n})", + "Average Risk Score", + "#E45756", + ) + _draw_or_empty( + axes[1, 1], + top_items({k: float(v) for k, v in instability.items()}, top_n), + f"Per-Agent Decision Instability (top {top_n})", + "Choice Changes", + "#72B7B2", + ) + + fig.tight_layout(rect=(0, 0, 1, 0.95)) + fig.savefig(out_path, dpi=160, bbox_inches="tight") + print(f"[PLOT] metrics={metrics_path}") + print(f"[PLOT] output={out_path}") + if show: + plt.show() + plt.close(fig) + + +def main() -> None: + args = _parse_args() + metrics_path = resolve_input(args.metrics, "outputs/run_metrics_*.json") + out_path = ensure_output_path(metrics_path, args.out, suffix="dashboard") + plot_metrics_dashboard(metrics_path, out_path=out_path, show=args.show, top_n=args.top_n) + + +if __name__ == "__main__": + main()