From ac88354df9653b6fea70f3b68b50ea8d8540d067 Mon Sep 17 00:00:00 2001 From: Charles Martin Date: Tue, 3 Mar 2026 11:15:55 -0800 Subject: [PATCH] Prioritize strong overfit modes in FAST_MODE and add signal checks --- ..._XGBoost_Accuracy_WithOverfitCatalog.ipynb | 111 ++++++++++++++---- 1 file changed, 88 insertions(+), 23 deletions(-) diff --git a/notebooks/XGBWW_Catalog_Random100_XGBoost_Accuracy_WithOverfitCatalog.ipynb b/notebooks/XGBWW_Catalog_Random100_XGBoost_Accuracy_WithOverfitCatalog.ipynb index 5537c53..bd2dd2d 100644 --- a/notebooks/XGBWW_Catalog_Random100_XGBoost_Accuracy_WithOverfitCatalog.ipynb +++ b/notebooks/XGBWW_Catalog_Random100_XGBoost_Accuracy_WithOverfitCatalog.ipynb @@ -16,13 +16,13 @@ "1. Load `dataset_catalog.csv` from Drive and keep classification tasks.\n", "2. Sample datasets with a fixed random seed and persist the selected list.\n", "3. Train one baseline (“good”) model per sampled dataset.\n", - "4. For each completed baseline dataset, train one model per selected overfit mode (`OVERFIT_MODES[:MAX_OVERFIT_CASES]`).\n", + "4. For each completed baseline dataset, train one model per selected overfit mode (`ACTIVE_OVERFIT_MODES`).\n", "5. Run `xgboost2ww` conversion + WeightWatcher metrics for both baseline and overfit runs.\n", "6. Save checkpoint files continuously so interrupted runs can resume.\n", "7. Write an aggregated output table containing both `case_type=\"good\"` and `case_type=\"overfit\"` rows.\n", "\n", "## Overfit behavior in this notebook\n", - "- Overfit modes are configured by `OVERFIT_MODES` and capped by `MAX_OVERFIT_CASES` (default 6).\n", + "- Overfit modes are configured by `OVERFIT_MODES` and capped by `MAX_OVERFIT_CASES` (default 2 in FAST_MODE, 6 in full mode).\n", "- Modes are applied **per completed dataset** (not globally across only a few datasets).\n", "- Default modes:\n", " 1. `deep_trees`\n", @@ -59,7 +59,7 @@ "source": [ "# XGBWW catalog-driven random-per-source XGBoost benchmark + targeted overfit cases\n", "\n", - "This notebook keeps the original catalog benchmark workflow, then adds intentionally overfit models per dataset (5–6 cases per dataset, based on `OVERFIT_MODES[:MAX_OVERFIT_CASES]`) and writes an aggregated checkpoint with both good and overfit results.\n" + "This notebook keeps the original catalog benchmark workflow, then adds intentionally overfit models per dataset (2 strongest overfit cases per dataset in FAST_MODE (or up to 6 in full mode), based on `ACTIVE_OVERFIT_MODES`) and writes an aggregated checkpoint with both good and overfit results.\n" ], "id": "7XlqyD7nRp1K" }, @@ -106,6 +106,15 @@ "RANDOM_SEED = 42\n", "RANDOM_SAMPLE_SIZE = 100\n", "TEST_SIZE = 0.20\n", + "\n", + "# Runtime tuning (set FAST_MODE=False for the full exhaustive run)\n", + "FAST_MODE = True\n", + "\n", + "# Fast-profile guards (material runtime reducers)\n", + "RANDOM_SAMPLE_SIZE = 25 if FAST_MODE else RANDOM_SAMPLE_SIZE\n", + "MAX_OVERFIT_MODELS = 20 if FAST_MODE else None\n", + "MAX_OVERFIT_PAIRS_PER_RUN = 120 if FAST_MODE else None\n", + "INCLUDE_CONTROL_CASE = False if FAST_MODE else True\n", "EXPERIMENT_ROOT = Path(\"/content/drive/MyDrive/xgbwwdata/experiment_checkpoints\")\n", "DEFAULT_EXPERIMENT_BASENAME = \"random100_xgboost_accuracy_plus_overfit\"\n", "\n", @@ -118,9 +127,18 @@ " \"tiny_trainset\",\n", " \"leakage\",\n", "]\n", - "MAX_OVERFIT_CASES = 6\n", + "MAX_OVERFIT_CASES = 2 if FAST_MODE else 6\n", "TINY_TRAIN_FRAC = 0.02 # More extreme to amplify overfitting for tiny_trainset mode\n", "\n", + "ACTIVE_OVERFIT_MODES = (\n", + " [\"deep_trees\", \"tiny_trainset\"] if FAST_MODE else OVERFIT_MODES[:MAX_OVERFIT_CASES]\n", + ")\n", + "\n", + "CHECKPOINT_SAVE_EVERY = 10 if FAST_MODE else 1\n", + "WW_T_POINTS = 64 if FAST_MODE else 160\n", + "WW_NFOLDS = 3 if FAST_MODE else 5\n", + "WW_RANDOMIZE = False if FAST_MODE else True\n", + "\n", "# Restart control\n", "RESTART_EXPERIMENT = True\n", "RETRY_FAILED_DATASETS = False # Default: do not retry failed/model_failed datasets on restart.\n", @@ -1580,6 +1598,21 @@ "\n", "\n", "def training_schedule(case_type: str, overfit_mode: str | None, control_mode: str | None = None):\n", + " if FAST_MODE:\n", + " if case_type == \"good\":\n", + " return 450, True\n", + " overfit_rounds = {\n", + " \"too_many_rounds\": 1600,\n", + " \"deep_trees\": 900,\n", + " \"no_regularization\": 1200,\n", + " \"no_subsampling\": 1100,\n", + " \"tiny_trainset\": 1200,\n", + " \"leakage\": 700,\n", + " \"random_labels\": 1200,\n", + " }\n", + " key = control_mode if case_type == \"control\" else overfit_mode\n", + " return overfit_rounds.get(key, 900), False\n", + "\n", " if case_type == \"good\":\n", " return 1200, True\n", "\n", @@ -1697,7 +1730,7 @@ " num_boost_round=num_boost_round,\n", " nfold=5,\n", " stratified=True,\n", - " early_stopping_rounds=50,\n", + " early_stopping_rounds=30 if FAST_MODE else 50,\n", " seed=local_seed,\n", " verbose_eval=False,\n", " )\n", @@ -1742,7 +1775,7 @@ " num_boost_round=num_boost_round,\n", " nfold=5,\n", " stratified=True,\n", - " early_stopping_rounds=60,\n", + " early_stopping_rounds=35 if FAST_MODE else 60,\n", " seed=local_seed,\n", " verbose_eval=False,\n", " )\n", @@ -1766,8 +1799,8 @@ " X_train,\n", " y_train,\n", " W=\"W7\",\n", - " nfolds=5,\n", - " t_points=160,\n", + " nfolds=WW_NFOLDS,\n", + " t_points=WW_T_POINTS,\n", " random_state=local_seed,\n", " train_params=params,\n", " num_boost_round=rounds,\n", @@ -1775,7 +1808,7 @@ " return_type=\"torch\",\n", " )\n", " watcher = ww.WeightWatcher(model=ww_layer)\n", - " details_df = watcher.analyze(ERG=True, randomize=True, plot=False)\n", + " details_df = watcher.analyze(ERG=True, randomize=WW_RANDOMIZE, plot=False)\n", "\n", " alpha = float(details_df[\"alpha\"].iloc[0]) if \"alpha\" in details_df else np.nan\n", " erg_gap = float(details_df[\"ERG_gap\"].iloc[0]) if \"ERG_gap\" in details_df else np.nan\n", @@ -2711,6 +2744,7 @@ " \"random_sample_size\": RANDOM_SAMPLE_SIZE,\n", " \"test_size\": TEST_SIZE,\n", " \"overfit_modes\": OVERFIT_MODES,\n", + " \"active_overfit_modes\": ACTIVE_OVERFIT_MODES,\n", " \"max_overfit_cases\": MAX_OVERFIT_CASES,\n", " \"selected_dataset_count\": int(len(df_pick)),\n", " \"successful_dataset_count\": int(len(results_df)),\n", @@ -2907,14 +2941,26 @@ " if str(r.get(\"status\", \"completed\")) == \"completed\":\n", " completed_pairs.add(key)\n", "\n", - " run_specs = [{\"case_type\": \"overfit\", \"overfit_mode\": m, \"control_mode\": \"none\"} for m in OVERFIT_MODES[:MAX_OVERFIT_CASES]]\n", - " run_specs.append({\"case_type\": \"control\", \"overfit_mode\": \"tiny_trainset\", \"control_mode\": \"random_labels\"})\n", + " run_specs = [{\"case_type\": \"overfit\", \"overfit_mode\": m, \"control_mode\": \"none\"} for m in ACTIVE_OVERFIT_MODES]\n", + " if INCLUDE_CONTROL_CASE:\n", + " run_specs.append({\"case_type\": \"control\", \"overfit_mode\": \"tiny_trainset\", \"control_mode\": \"random_labels\"})\n", "\n", - " expected_total = len(good_df) * len(run_specs)\n", - " print(f\"Ensuring overfit/control analysis for every model/spec pair: {len(good_df)} models x {len(run_specs)} specs = {expected_total} pairs\")\n", + " good_df_overfit = good_df.sort_values([\"model_id\"]).reset_index(drop=True)\n", + " if MAX_OVERFIT_MODELS is not None:\n", + " good_df_overfit = good_df_overfit.head(int(MAX_OVERFIT_MODELS)).copy()\n", "\n", - " total_good = len(good_df)\n", - " for ds_idx, (_, row) in enumerate(good_df.iterrows(), start=1):\n", + " expected_total = len(good_df_overfit) * len(run_specs)\n", + " print(f\"Ensuring overfit/control analysis for every model/spec pair: {len(good_df_overfit)} models x {len(run_specs)} specs = {expected_total} pairs\")\n", + "\n", + " if len(good_df_overfit) < len(good_df):\n", + " print(f\"FAST_MODE cap active: running overfit/control on first {len(good_df_overfit)} of {len(good_df)} completed good models\")\n", + " if MAX_OVERFIT_PAIRS_PER_RUN is not None:\n", + " print(f\"FAST_MODE cap active: max newly trained overfit/control pairs this run = {MAX_OVERFIT_PAIRS_PER_RUN}\")\n", + "\n", + " total_good = len(good_df_overfit)\n", + " newly_trained_pairs = 0\n", + " stop_early = False\n", + " for ds_idx, (_, row) in enumerate(good_df_overfit.iterrows(), start=1):\n", " row_data = row.to_dict()\n", " uid = str(row_data[\"dataset_uid\"])\n", " model_id = str(row_data.get(\"model_id\", uid))\n", @@ -2931,6 +2977,10 @@ " run_records[pair] = rec\n", " continue\n", "\n", + " if MAX_OVERFIT_PAIRS_PER_RUN is not None and newly_trained_pairs >= int(MAX_OVERFIT_PAIRS_PER_RUN):\n", + " stop_early = True\n", + " break\n", + "\n", " print(\n", " f\"[model {ds_idx}/{total_good}] model_id={model_id} | dataset={uid} \"\n", " f\"| case_type={spec['case_type']} | overfit_mode={spec['overfit_mode']} | control_mode={spec['control_mode']}\"\n", @@ -2951,6 +3001,7 @@ " run_result[\"easy_dataset_flag\"] = bool(id_to_easy.get(model_id, False))\n", " run_records[pair] = run_result\n", " completed_pairs.add(pair)\n", + " newly_trained_pairs += 1\n", " except Exception as e:\n", " run_records[pair] = {\n", " \"model_id\": model_id,\n", @@ -2974,9 +3025,16 @@ " \"num_traps\": np.nan,\n", " }\n", "\n", - " run_df = pd.DataFrame(list(run_records.values()))\n", - " combined_df = pd.concat([good_df, run_df], ignore_index=True, sort=False)\n", - " combined_df.to_csv(CHECKPOINT_AGGREGATED_CSV, index=False)\n", + " if (len(run_records) % CHECKPOINT_SAVE_EVERY) == 0:\n", + " run_df = pd.DataFrame(list(run_records.values()))\n", + " combined_df = pd.concat([good_df, run_df], ignore_index=True, sort=False)\n", + " combined_df.to_csv(CHECKPOINT_AGGREGATED_CSV, index=False)\n", + "\n", + " if stop_early:\n", + " break\n", + "\n", + " if stop_early:\n", + " print(\"Reached MAX_OVERFIT_PAIRS_PER_RUN; save checkpoint and rerun to continue.\")\n", "\n", " run_df = pd.DataFrame(list(run_records.values()))\n", "\n", @@ -3014,11 +3072,18 @@ " & ((completed_runs[\"accuracy_gap\"] >= 0.10) | (completed_runs[\"logloss_gap\"] >= 0.5))\n", " ]\n", "\n", + " overfit_total = int((completed_runs[\"case_type\"] == \"overfit\").sum())\n", + " overfit_strong_rate = (len(overfit_strong) / overfit_total) if overfit_total else 0.0\n", + "\n", " easy_count = int(good_df[\"easy_dataset_flag\"].fillna(False).sum())\n", " print(\"Validation checks:\")\n", " print(f\" - control random_labels near-chance test_accuracy count: {len(control_near_chance)}\")\n", - " print(f\" - overfit runs with large gap count: {len(overfit_strong)}\")\n", - " print(f\" - easy_dataset_flag True count: {easy_count}\")\n" + " print(f\" - overfit runs with large gap count: {len(overfit_strong)} / {overfit_total} ({overfit_strong_rate:.1%})\")\n", + " print(f\" - easy_dataset_flag True count: {easy_count}\")\n", + "\n", + " if FAST_MODE and overfit_total > 0 and overfit_strong_rate < 0.60:\n", + " print(\"WARNING: FAST_MODE produced a weak overfit signal (<60% strong-gap rows).\")\n", + " print(\" For publication-quality overfit evidence, set FAST_MODE=False and rerun.\")\n" ], "metadata": { "id": "a_SNj9ibYHgA", @@ -5042,7 +5107,7 @@ " plot_df.loc[plot_df['case_type'] != 'overfit', 'overfit_mode'] = 'good_model'\n", "\n", " overfit_modes = [\n", - " m for m in OVERFIT_MODES[:MAX_OVERFIT_CASES]\n", + " m for m in ACTIVE_OVERFIT_MODES\n", " if m in set(plot_df.loc[plot_df['case_type'] == 'overfit', 'overfit_mode'])\n", " ]\n", " modes_in_order = overfit_modes + (['good_model'] if 'good_model' in set(plot_df['overfit_mode']) else [])\n", @@ -5492,7 +5557,7 @@ "\n", " for group in split_groups:\n", " group_df = scatter_df[scatter_df['good_case_group'] == group].copy()\n", - " overfit_modes = [m for m in OVERFIT_MODES[:MAX_OVERFIT_CASES] if m in group_df['overfit_mode'].unique()]\n", + " overfit_modes = [m for m in ACTIVE_OVERFIT_MODES if m in group_df['overfit_mode'].unique()]\n", "\n", " if group_df.empty or not overfit_modes:\n", " print(f'No overfit rows found for group={group}.')\n", @@ -5578,7 +5643,7 @@ "\n", " for group in split_groups:\n", " group_df = scatter_df[scatter_df['good_case_group'] == group].copy()\n", - " overfit_modes = [m for m in OVERFIT_MODES[:MAX_OVERFIT_CASES] if m in group_df['overfit_mode'].unique()]\n", + " overfit_modes = [m for m in ACTIVE_OVERFIT_MODES if m in group_df['overfit_mode'].unique()]\n", "\n", " if group_df.empty or not overfit_modes:\n", " print(f'No overfit rows found for group={group}.')\n",