Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
"1. Load `dataset_catalog.csv` from Drive and keep classification tasks.\n",
"2. Sample datasets with a fixed random seed and persist the selected list.\n",
"3. Train one baseline (“good”) model per sampled dataset.\n",
"4. For each completed baseline dataset, train one model per selected overfit mode (`OVERFIT_MODES[:MAX_OVERFIT_CASES]`).\n",
"4. For each completed baseline dataset, train one model per selected overfit mode (`ACTIVE_OVERFIT_MODES`).\n",
"5. Run `xgboost2ww` conversion + WeightWatcher metrics for both baseline and overfit runs.\n",
"6. Save checkpoint files continuously so interrupted runs can resume.\n",
"7. Write an aggregated output table containing both `case_type=\"good\"` and `case_type=\"overfit\"` rows.\n",
"\n",
"## Overfit behavior in this notebook\n",
"- Overfit modes are configured by `OVERFIT_MODES` and capped by `MAX_OVERFIT_CASES` (default 6).\n",
"- Overfit modes are configured by `OVERFIT_MODES` and capped by `MAX_OVERFIT_CASES` (default 2 in FAST_MODE, 6 in full mode).\n",
"- Modes are applied **per completed dataset** (not globally across only a few datasets).\n",
"- Default modes:\n",
" 1. `deep_trees`\n",
Expand Down Expand Up @@ -59,7 +59,7 @@
"source": [
"# XGBWW catalog-driven random-per-source XGBoost benchmark + targeted overfit cases\n",
"\n",
"This notebook keeps the original catalog benchmark workflow, then adds intentionally overfit models per dataset (5–6 cases per dataset, based on `OVERFIT_MODES[:MAX_OVERFIT_CASES]`) and writes an aggregated checkpoint with both good and overfit results.\n"
"This notebook keeps the original catalog benchmark workflow, then adds intentionally overfit models per dataset (2 strongest overfit cases per dataset in FAST_MODE (or up to 6 in full mode), based on `ACTIVE_OVERFIT_MODES`) and writes an aggregated checkpoint with both good and overfit results.\n"
],
"id": "7XlqyD7nRp1K"
},
Expand Down Expand Up @@ -106,6 +106,15 @@
"RANDOM_SEED = 42\n",
"RANDOM_SAMPLE_SIZE = 100\n",
"TEST_SIZE = 0.20\n",
"\n",
"# Runtime tuning (set FAST_MODE=False for the full exhaustive run)\n",
"FAST_MODE = True\n",
"\n",
"# Fast-profile guards (material runtime reducers)\n",
"RANDOM_SAMPLE_SIZE = 25 if FAST_MODE else RANDOM_SAMPLE_SIZE\n",
"MAX_OVERFIT_MODELS = 20 if FAST_MODE else None\n",
"MAX_OVERFIT_PAIRS_PER_RUN = 120 if FAST_MODE else None\n",
"INCLUDE_CONTROL_CASE = False if FAST_MODE else True\n",
"EXPERIMENT_ROOT = Path(\"/content/drive/MyDrive/xgbwwdata/experiment_checkpoints\")\n",
"DEFAULT_EXPERIMENT_BASENAME = \"random100_xgboost_accuracy_plus_overfit\"\n",
"\n",
Expand All @@ -118,9 +127,18 @@
" \"tiny_trainset\",\n",
" \"leakage\",\n",
"]\n",
"MAX_OVERFIT_CASES = 6\n",
"MAX_OVERFIT_CASES = 2 if FAST_MODE else 6\n",
"TINY_TRAIN_FRAC = 0.02 # More extreme to amplify overfitting for tiny_trainset mode\n",
"\n",
"ACTIVE_OVERFIT_MODES = (\n",
" [\"deep_trees\", \"tiny_trainset\"] if FAST_MODE else OVERFIT_MODES[:MAX_OVERFIT_CASES]\n",
")\n",
"\n",
"CHECKPOINT_SAVE_EVERY = 10 if FAST_MODE else 1\n",
"WW_T_POINTS = 64 if FAST_MODE else 160\n",
"WW_NFOLDS = 3 if FAST_MODE else 5\n",
"WW_RANDOMIZE = False if FAST_MODE else True\n",
"\n",
"# Restart control\n",
"RESTART_EXPERIMENT = True\n",
"RETRY_FAILED_DATASETS = False # Default: do not retry failed/model_failed datasets on restart.\n",
Expand Down Expand Up @@ -1580,6 +1598,21 @@
"\n",
"\n",
"def training_schedule(case_type: str, overfit_mode: str | None, control_mode: str | None = None):\n",
" if FAST_MODE:\n",
" if case_type == \"good\":\n",
" return 450, True\n",
" overfit_rounds = {\n",
" \"too_many_rounds\": 1600,\n",
" \"deep_trees\": 900,\n",
" \"no_regularization\": 1200,\n",
" \"no_subsampling\": 1100,\n",
" \"tiny_trainset\": 1200,\n",
" \"leakage\": 700,\n",
" \"random_labels\": 1200,\n",
" }\n",
" key = control_mode if case_type == \"control\" else overfit_mode\n",
" return overfit_rounds.get(key, 900), False\n",
"\n",
" if case_type == \"good\":\n",
" return 1200, True\n",
"\n",
Expand Down Expand Up @@ -1697,7 +1730,7 @@
" num_boost_round=num_boost_round,\n",
" nfold=5,\n",
" stratified=True,\n",
" early_stopping_rounds=50,\n",
" early_stopping_rounds=30 if FAST_MODE else 50,\n",
" seed=local_seed,\n",
" verbose_eval=False,\n",
" )\n",
Expand Down Expand Up @@ -1742,7 +1775,7 @@
" num_boost_round=num_boost_round,\n",
" nfold=5,\n",
" stratified=True,\n",
" early_stopping_rounds=60,\n",
" early_stopping_rounds=35 if FAST_MODE else 60,\n",
" seed=local_seed,\n",
" verbose_eval=False,\n",
" )\n",
Expand All @@ -1766,16 +1799,16 @@
" X_train,\n",
" y_train,\n",
" W=\"W7\",\n",
" nfolds=5,\n",
" t_points=160,\n",
" nfolds=WW_NFOLDS,\n",
" t_points=WW_T_POINTS,\n",
" random_state=local_seed,\n",
" train_params=params,\n",
" num_boost_round=rounds,\n",
" multiclass=\"avg\" if n_classes > 2 else \"error\",\n",
" return_type=\"torch\",\n",
" )\n",
" watcher = ww.WeightWatcher(model=ww_layer)\n",
" details_df = watcher.analyze(ERG=True, randomize=True, plot=False)\n",
" details_df = watcher.analyze(ERG=True, randomize=WW_RANDOMIZE, plot=False)\n",
"\n",
" alpha = float(details_df[\"alpha\"].iloc[0]) if \"alpha\" in details_df else np.nan\n",
" erg_gap = float(details_df[\"ERG_gap\"].iloc[0]) if \"ERG_gap\" in details_df else np.nan\n",
Expand Down Expand Up @@ -2711,6 +2744,7 @@
" \"random_sample_size\": RANDOM_SAMPLE_SIZE,\n",
" \"test_size\": TEST_SIZE,\n",
" \"overfit_modes\": OVERFIT_MODES,\n",
" \"active_overfit_modes\": ACTIVE_OVERFIT_MODES,\n",
" \"max_overfit_cases\": MAX_OVERFIT_CASES,\n",
" \"selected_dataset_count\": int(len(df_pick)),\n",
" \"successful_dataset_count\": int(len(results_df)),\n",
Expand Down Expand Up @@ -2907,14 +2941,26 @@
" if str(r.get(\"status\", \"completed\")) == \"completed\":\n",
" completed_pairs.add(key)\n",
"\n",
" run_specs = [{\"case_type\": \"overfit\", \"overfit_mode\": m, \"control_mode\": \"none\"} for m in OVERFIT_MODES[:MAX_OVERFIT_CASES]]\n",
" run_specs.append({\"case_type\": \"control\", \"overfit_mode\": \"tiny_trainset\", \"control_mode\": \"random_labels\"})\n",
" run_specs = [{\"case_type\": \"overfit\", \"overfit_mode\": m, \"control_mode\": \"none\"} for m in ACTIVE_OVERFIT_MODES]\n",
" if INCLUDE_CONTROL_CASE:\n",
" run_specs.append({\"case_type\": \"control\", \"overfit_mode\": \"tiny_trainset\", \"control_mode\": \"random_labels\"})\n",
"\n",
" expected_total = len(good_df) * len(run_specs)\n",
" print(f\"Ensuring overfit/control analysis for every model/spec pair: {len(good_df)} models x {len(run_specs)} specs = {expected_total} pairs\")\n",
" good_df_overfit = good_df.sort_values([\"model_id\"]).reset_index(drop=True)\n",
" if MAX_OVERFIT_MODELS is not None:\n",
" good_df_overfit = good_df_overfit.head(int(MAX_OVERFIT_MODELS)).copy()\n",
"\n",
" total_good = len(good_df)\n",
" for ds_idx, (_, row) in enumerate(good_df.iterrows(), start=1):\n",
" expected_total = len(good_df_overfit) * len(run_specs)\n",
" print(f\"Ensuring overfit/control analysis for every model/spec pair: {len(good_df_overfit)} models x {len(run_specs)} specs = {expected_total} pairs\")\n",
"\n",
" if len(good_df_overfit) < len(good_df):\n",
" print(f\"FAST_MODE cap active: running overfit/control on first {len(good_df_overfit)} of {len(good_df)} completed good models\")\n",
" if MAX_OVERFIT_PAIRS_PER_RUN is not None:\n",
" print(f\"FAST_MODE cap active: max newly trained overfit/control pairs this run = {MAX_OVERFIT_PAIRS_PER_RUN}\")\n",
"\n",
" total_good = len(good_df_overfit)\n",
" newly_trained_pairs = 0\n",
" stop_early = False\n",
" for ds_idx, (_, row) in enumerate(good_df_overfit.iterrows(), start=1):\n",
" row_data = row.to_dict()\n",
" uid = str(row_data[\"dataset_uid\"])\n",
" model_id = str(row_data.get(\"model_id\", uid))\n",
Expand All @@ -2931,6 +2977,10 @@
" run_records[pair] = rec\n",
" continue\n",
"\n",
" if MAX_OVERFIT_PAIRS_PER_RUN is not None and newly_trained_pairs >= int(MAX_OVERFIT_PAIRS_PER_RUN):\n",
" stop_early = True\n",
" break\n",
"\n",
" print(\n",
" f\"[model {ds_idx}/{total_good}] model_id={model_id} | dataset={uid} \"\n",
" f\"| case_type={spec['case_type']} | overfit_mode={spec['overfit_mode']} | control_mode={spec['control_mode']}\"\n",
Expand All @@ -2951,6 +3001,7 @@
" run_result[\"easy_dataset_flag\"] = bool(id_to_easy.get(model_id, False))\n",
" run_records[pair] = run_result\n",
" completed_pairs.add(pair)\n",
" newly_trained_pairs += 1\n",
" except Exception as e:\n",
" run_records[pair] = {\n",
" \"model_id\": model_id,\n",
Expand All @@ -2974,9 +3025,16 @@
" \"num_traps\": np.nan,\n",
" }\n",
"\n",
" run_df = pd.DataFrame(list(run_records.values()))\n",
" combined_df = pd.concat([good_df, run_df], ignore_index=True, sort=False)\n",
" combined_df.to_csv(CHECKPOINT_AGGREGATED_CSV, index=False)\n",
" if (len(run_records) % CHECKPOINT_SAVE_EVERY) == 0:\n",
" run_df = pd.DataFrame(list(run_records.values()))\n",
" combined_df = pd.concat([good_df, run_df], ignore_index=True, sort=False)\n",
" combined_df.to_csv(CHECKPOINT_AGGREGATED_CSV, index=False)\n",
"\n",
" if stop_early:\n",
" break\n",
"\n",
" if stop_early:\n",
" print(\"Reached MAX_OVERFIT_PAIRS_PER_RUN; save checkpoint and rerun to continue.\")\n",
"\n",
" run_df = pd.DataFrame(list(run_records.values()))\n",
"\n",
Expand Down Expand Up @@ -3014,11 +3072,18 @@
" & ((completed_runs[\"accuracy_gap\"] >= 0.10) | (completed_runs[\"logloss_gap\"] >= 0.5))\n",
" ]\n",
"\n",
" overfit_total = int((completed_runs[\"case_type\"] == \"overfit\").sum())\n",
" overfit_strong_rate = (len(overfit_strong) / overfit_total) if overfit_total else 0.0\n",
"\n",
" easy_count = int(good_df[\"easy_dataset_flag\"].fillna(False).sum())\n",
" print(\"Validation checks:\")\n",
" print(f\" - control random_labels near-chance test_accuracy count: {len(control_near_chance)}\")\n",
" print(f\" - overfit runs with large gap count: {len(overfit_strong)}\")\n",
" print(f\" - easy_dataset_flag True count: {easy_count}\")\n"
" print(f\" - overfit runs with large gap count: {len(overfit_strong)} / {overfit_total} ({overfit_strong_rate:.1%})\")\n",
" print(f\" - easy_dataset_flag True count: {easy_count}\")\n",
"\n",
" if FAST_MODE and overfit_total > 0 and overfit_strong_rate < 0.60:\n",
" print(\"WARNING: FAST_MODE produced a weak overfit signal (<60% strong-gap rows).\")\n",
" print(\" For publication-quality overfit evidence, set FAST_MODE=False and rerun.\")\n"
],
"metadata": {
"id": "a_SNj9ibYHgA",
Expand Down Expand Up @@ -5042,7 +5107,7 @@
" plot_df.loc[plot_df['case_type'] != 'overfit', 'overfit_mode'] = 'good_model'\n",
"\n",
" overfit_modes = [\n",
" m for m in OVERFIT_MODES[:MAX_OVERFIT_CASES]\n",
" m for m in ACTIVE_OVERFIT_MODES\n",
" if m in set(plot_df.loc[plot_df['case_type'] == 'overfit', 'overfit_mode'])\n",
" ]\n",
" modes_in_order = overfit_modes + (['good_model'] if 'good_model' in set(plot_df['overfit_mode']) else [])\n",
Expand Down Expand Up @@ -5492,7 +5557,7 @@
"\n",
" for group in split_groups:\n",
" group_df = scatter_df[scatter_df['good_case_group'] == group].copy()\n",
" overfit_modes = [m for m in OVERFIT_MODES[:MAX_OVERFIT_CASES] if m in group_df['overfit_mode'].unique()]\n",
" overfit_modes = [m for m in ACTIVE_OVERFIT_MODES if m in group_df['overfit_mode'].unique()]\n",
"\n",
" if group_df.empty or not overfit_modes:\n",
" print(f'No overfit rows found for group={group}.')\n",
Expand Down Expand Up @@ -5578,7 +5643,7 @@
"\n",
" for group in split_groups:\n",
" group_df = scatter_df[scatter_df['good_case_group'] == group].copy()\n",
" overfit_modes = [m for m in OVERFIT_MODES[:MAX_OVERFIT_CASES] if m in group_df['overfit_mode'].unique()]\n",
" overfit_modes = [m for m in ACTIVE_OVERFIT_MODES if m in group_df['overfit_mode'].unique()]\n",
"\n",
" if group_df.empty or not overfit_modes:\n",
" print(f'No overfit rows found for group={group}.')\n",
Expand Down