Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -161,12 +161,14 @@
"CHECKPOINT_RESULTS_CSV = CHECKPOINT_DIR / \"checkpoint_results.csv\"\n",
"CHECKPOINT_ERRORS_CSV = CHECKPOINT_DIR / \"errors.csv\"\n",
"CHECKPOINT_AGGREGATED_CSV = CHECKPOINT_DIR / \"checkpoint_results_good_plus_overfit.csv\"\n",
"OVERFIT_RESULTS_CSV = CHECKPOINT_DIR / \"overfit_results.csv\"\n",
"SELECTED_DATASETS_CSV = CHECKPOINT_DIR / \"selected_datasets.csv\"\n",
"EXPERIMENT_CONFIG_JSON = CHECKPOINT_DIR / \"experiment_config.json\"\n",
"\n",
"print(\"Catalog path:\", CATALOG_CSV)\n",
"print(\"Experiment checkpoint:\", CHECKPOINT_DIR)\n",
"print(\"Restart mode:\", RESTART_EXPERIMENT)\n",
"print(\"Overfit results CSV:\", OVERFIT_RESULTS_CSV)\n",
"\n",
"\n"
],
Expand Down Expand Up @@ -1043,9 +1045,17 @@
" combined_df.to_csv(CHECKPOINT_AGGREGATED_CSV, index=False)\n",
"\n",
" overfit_df = pd.DataFrame(list(overfit_records.values()))\n",
" for acc_col in [\"train_accuracy\", \"test_accuracy\", \"accuracy_gap\"]:\n",
" if acc_col not in overfit_df.columns:\n",
" overfit_df[acc_col] = np.nan\n",
" overfit_df[acc_col] = pd.to_numeric(overfit_df[acc_col], errors=\"coerce\")\n",
" overfit_df.to_csv(OVERFIT_RESULTS_CSV, index=False)\n",
"\n",
" combined_df = pd.concat([good_df, overfit_df], ignore_index=True, sort=False)\n",
" combined_df.to_csv(CHECKPOINT_AGGREGATED_CSV, index=False)\n",
"\n",
" print(\"Saved overfit-only checkpoint:\")\n",
" print(\" -\", OVERFIT_RESULTS_CSV)\n",
" print(\"Saved aggregated checkpoint:\")\n",
" print(\" -\", CHECKPOINT_AGGREGATED_CSV)\n",
" print(\n",
Expand Down Expand Up @@ -1117,57 +1127,85 @@
{
"cell_type": "code",
"source": [
"# 9) Histograms of WW metrics by overfit case (mode)\n",
"# 9) Per-overfit-mode accuracy visualizations\n",
"\n",
"import math\n",
"import matplotlib.pyplot as plt\n",
"\n",
"if 'combined_df' not in globals() or combined_df.empty:\n",
" print('No combined results available. Run the training/aggregation cells first.')\n",
"else:\n",
" hist_df = combined_df[combined_df['case_type'] == 'overfit'].copy()\n",
" if hist_df.empty:\n",
" print('No overfit rows available for histogram plots.')\n",
" overfit_plot_df = combined_df[combined_df['case_type'] == 'overfit'].copy()\n",
" if overfit_plot_df.empty:\n",
" print('No overfit rows available for accuracy plots.')\n",
" else:\n",
" metrics = ['alpha', 'ERG_gap', 'num_traps']\n",
" modes = [m for m in OVERFIT_MODES[:MAX_OVERFIT_CASES] if m in set(hist_df['overfit_mode'].astype(str))]\n",
"\n",
" print(f'Building WW histograms for {len(modes)} overfit modes across {hist_df[\"dataset_uid\"].nunique()} datasets.')\n",
" # Keep only completed runs. The fallback must be a Series (not a bare string) so that .astype() works when the 'status' column is absent; in that case every row is treated as completed.\n",
" overfit_plot_df = overfit_plot_df[overfit_plot_df.get('status', pd.Series('completed', index=overfit_plot_df.index)).astype(str) == 'completed'].copy()\n",
" for acc_col in ['train_accuracy', 'test_accuracy']:\n",
" overfit_plot_df[acc_col] = pd.to_numeric(overfit_plot_df[acc_col], errors='coerce')\n",
" overfit_plot_df = overfit_plot_df.dropna(subset=['train_accuracy', 'test_accuracy'])\n",
" overfit_plot_df['generalization_gap'] = overfit_plot_df['train_accuracy'] - overfit_plot_df['test_accuracy']\n",
"\n",
" if overfit_plot_df.empty:\n",
" print('No completed overfit rows with train/test accuracies available.')\n",
" else:\n",
" modes = [m for m in OVERFIT_MODES[:MAX_OVERFIT_CASES] if m in set(overfit_plot_df['overfit_mode'].astype(str))]\n",
" print(\n",
" f'Building overfit-mode accuracy plots for {len(modes)} modes '\n",
" f'with threshold RANDOM_SAMPLE_SIZE={RANDOM_SAMPLE_SIZE}.'\n",
" )\n",
"\n",
" for metric in metrics:\n",
" mode_frames = []\n",
" for mode in modes:\n",
" sub = hist_df.loc[hist_df['overfit_mode'] == mode, [metric]].copy()\n",
" sub[metric] = pd.to_numeric(sub[metric], errors='coerce')\n",
" sub = sub.dropna(subset=[metric])\n",
" mode_frames.append((mode, sub))\n",
"\n",
" valid_frames = [(m, d) for m, d in mode_frames if not d.empty]\n",
" if not valid_frames:\n",
" print(f'Skipping {metric}: no numeric values in overfit rows.')\n",
" continue\n",
"\n",
" n_modes = len(valid_frames)\n",
" n_cols = min(3, n_modes)\n",
" n_rows = math.ceil(n_modes / n_cols)\n",
"\n",
" fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 3.5 * n_rows), squeeze=False)\n",
" axes_flat = axes.flatten()\n",
"\n",
" for ax, (mode, sub) in zip(axes_flat, valid_frames):\n",
" ax.hist(sub[metric].values, bins=20, alpha=0.85, edgecolor='black')\n",
" ax.set_title(f'{metric} | overfit_mode={mode} | n={len(sub)}')\n",
" ax.set_xlabel(metric)\n",
" ax.set_ylabel('Count')\n",
" ax.grid(alpha=0.2)\n",
"\n",
" for ax in axes_flat[len(valid_frames):]:\n",
" ax.axis('off')\n",
"\n",
" fig.suptitle(f'Overfit-case histograms by mode: {metric}', y=1.02)\n",
" fig.tight_layout()\n",
" plt.show()\n",
"\n"
" mode_df = overfit_plot_df.loc[overfit_plot_df['overfit_mode'] == mode].copy()\n",
" n_models = len(mode_df)\n",
" print(f'overfit_mode={mode} | completed models={n_models}')\n",
"\n",
" if n_models <= RANDOM_SAMPLE_SIZE:\n",
" print(\n",
" f'Generating per-model train/test bar charts for mode={mode} '\n",
" f'(n={n_models} <= {RANDOM_SAMPLE_SIZE}).'\n",
" )\n",
" mode_df = mode_df.sort_values(['source', 'dataset_uid']).reset_index(drop=True)\n",
" for idx, row in mode_df.iterrows():\n",
" fig, ax = plt.subplots(figsize=(6, 4))\n",
" bars = ax.bar(\n",
" ['Train accuracy', 'Test accuracy'],\n",
" [row['train_accuracy'], row['test_accuracy']],\n",
" color=['tab:blue', 'tab:orange'],\n",
" )\n",
" ax.set_ylim(0, 1)\n",
" ax.set_ylabel('Accuracy')\n",
" ax.set_title(\n",
" f\"{mode} | dataset={row['dataset_uid']} | \"\n",
" f\"model {idx + 1}/{n_models}\"\n",
" )\n",
" ax.grid(axis='y', alpha=0.2)\n",
" for bar in bars:\n",
" height = bar.get_height()\n",
" ax.text(\n",
" bar.get_x() + bar.get_width() / 2,\n",
" min(height + 0.02, 1.0),\n",
" f'{height:.3f}',\n",
" ha='center',\n",
" va='bottom',\n",
" fontsize=9,\n",
" )\n",
" fig.tight_layout()\n",
" plt.show()\n",
" else:\n",
" print(\n",
" f'Generating generalization-gap histogram for mode={mode} '\n",
" f'(n={n_models} > {RANDOM_SAMPLE_SIZE}).'\n",
" )\n",
" fig, ax = plt.subplots(figsize=(7, 4))\n",
" ax.hist(mode_df['generalization_gap'].values, bins=25, alpha=0.85, edgecolor='black')\n",
" ax.set_title(\n",
" f'Generalization gap histogram | mode={mode} | '\n",
" f'n={n_models}'\n",
" )\n",
" ax.set_xlabel('Generalization gap (train_accuracy - test_accuracy)')\n",
" ax.set_ylabel('Count')\n",
" ax.grid(alpha=0.2)\n",
" fig.tight_layout()\n",
" plt.show()\n"
],
"metadata": {
"id": "_9cTHchdNwhS"
Expand Down