diff --git a/notebooks/XGBWW_Random10_LongRun_Alpha_Tracking.ipynb b/notebooks/XGBWW_Random10_LongRun_Alpha_Tracking.ipynb
index caef712..e91c4d4 100644
--- a/notebooks/XGBWW_Random10_LongRun_Alpha_Tracking.ipynb
+++ b/notebooks/XGBWW_Random10_LongRun_Alpha_Tracking.ipynb
@@ -1,1356 +1,1408 @@
{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "0wW66rXME4dH"
- },
- "source": [
- "[](https://colab.research.google.com/github/CalculatedContent/xgboost2ww/blob/main/notebooks/XGBWW_Random10_LongRun_Alpha_Tracking.ipynb)"
- ],
- "id": "0wW66rXME4dH"
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "7bRnlL4uE4dI"
- },
- "source": [
- "# Random-10 Dataset Long-Run Alpha Tracking\n",
- "\n",
- "**Designed for Google Colab.** This notebook is Colab-first: it mounts Google Drive for persistent checkpointing, installs `xgboost2ww` and `xgbwwdata` from source, resumes long runs from Drive checkpoints, and prefers Colab GPU when available (with CPU `hist` fallback)."
- ],
- "id": "7bRnlL4uE4dI"
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "crO-HlpcE4dJ"
- },
- "source": [
- "## 1. Colab / Drive setup"
- ],
- "id": "crO-HlpcE4dJ"
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "w6n5vumaE4dJ",
- "outputId": "58d551c8-4eda-4a5d-8264-e1eaedbb4a8a",
- "colab": {
- "base_uri": "https://localhost:8080/"
- }
- },
- "execution_count": 50,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "IN_COLAB=True\n",
- "Python=3.12.12\n",
- "Platform=Linux-6.6.113+-x86_64-with-glibc2.35\n",
- "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n",
- "Drive mounted\n",
- "Google Drive checkpointing enabled\n",
- "project_root: /content/drive/MyDrive/xgboost2ww_runs/random10_longrun_alpha_tracking\n",
- "registry_dir: /content/drive/MyDrive/xgboost2ww_runs/random10_longrun_alpha_tracking/registry\n",
- "per_dataset_dir: /content/drive/MyDrive/xgboost2ww_runs/random10_longrun_alpha_tracking/per_dataset\n",
- "aggregate_dir: /content/drive/MyDrive/xgboost2ww_runs/random10_longrun_alpha_tracking/aggregate\n",
- "logs_dir: /content/drive/MyDrive/xgboost2ww_runs/random10_longrun_alpha_tracking/logs\n"
- ]
- }
- ],
- "source": [
- "# Colab-first runtime + Drive setup\n",
- "import os\n",
- "import sys\n",
- "import platform\n",
- "from pathlib import Path\n",
- "\n",
- "try:\n",
- " import google.colab # noqa: F401\n",
- " IN_COLAB = True\n",
- "except Exception:\n",
- " IN_COLAB = False\n",
- "\n",
- "USE_GOOGLE_DRIVE = True if IN_COLAB else False\n",
- "FORCE_REMOUNT_DRIVE = False\n",
- "\n",
- "print(f\"IN_COLAB={IN_COLAB}\")\n",
- "print(f\"Python={sys.version.split()[0]}\")\n",
- "print(f\"Platform={platform.platform()}\")\n",
- "\n",
- "if IN_COLAB:\n",
- " from google.colab import drive\n",
- " drive.mount('/content/drive', force_remount=FORCE_REMOUNT_DRIVE)\n",
- " project_root = Path('/content/drive/MyDrive/xgboost2ww_runs/random10_longrun_alpha_tracking')\n",
- "else:\n",
- " project_root = Path('./random10_longrun_alpha_tracking')\n",
- "\n",
- "registry_dir = project_root / 'registry'\n",
- "per_dataset_dir = project_root / 'per_dataset'\n",
- "aggregate_dir = project_root / 'aggregate'\n",
- "logs_dir = project_root / 'logs'\n",
- "for p in [project_root, registry_dir, per_dataset_dir, aggregate_dir, logs_dir]:\n",
- " p.mkdir(parents=True, exist_ok=True)\n",
- "\n",
- "errors_path = logs_dir / 'errors.csv'\n",
- "skipped_path = aggregate_dir / 'skipped_datasets.csv'\n",
- "\n",
- "if IN_COLAB:\n",
- " print('Drive mounted')\n",
- " print('Google Drive checkpointing enabled')\n",
- "\n",
- "print('project_root:', project_root)\n",
- "print('registry_dir:', registry_dir)\n",
- "print('per_dataset_dir:', per_dataset_dir)\n",
- "print('aggregate_dir:', aggregate_dir)\n",
- "print('logs_dir:', logs_dir)\n"
- ],
- "id": "w6n5vumaE4dJ"
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Vkmos-hTE4dK"
- },
- "source": [
- "## 2. Colab/bootstrap installs"
- ],
- "id": "Vkmos-hTE4dK"
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "mOauVnRtE4dK",
- "outputId": "a6147224-6754-4f31-a955-df695a3bf993",
- "colab": {
- "base_uri": "https://localhost:8080/"
- }
- },
- "execution_count": 51,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "[INFO] INCLUDE_KEEL_DATASETS=False -> skipping keel-ds install.\n",
- " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
- " Building wheel for xgboost2ww (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
- " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
- " Building wheel for xgbwwdata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
- "[INFO] INCLUDE_KEEL_DATASETS=False -> skipping KEEL source install checks.\n",
- "xgboost: 3.2.0\n",
- "weightwatcher: 0.7.7\n",
- "xgboost2ww location: /usr/local/lib/python3.12/dist-packages/xgboost2ww/__init__.py\n",
- "xgbwwdata location: /usr/local/lib/python3.12/dist-packages/xgbwwdata/__init__.py\n",
- "xgboost2ww installed from source\n",
- "xgbwwdata installed from source\n"
- ]
- }
- ],
- "source": [
- "# Colab/bootstrap installs\n",
- "INCLUDE_KEEL_DATASETS = False # Optional dataset source; default off for smoother Colab installs\n",
- "\n",
- "%pip install -q xgboost weightwatcher scikit-learn pandas matplotlib seaborn scipy feather-format pyarrow\n",
- "if INCLUDE_KEEL_DATASETS:\n",
- " %pip install -q openml pmlb keel-ds\n",
- "else:\n",
- " %pip install -q openml pmlb\n",
- " print('[INFO] INCLUDE_KEEL_DATASETS=False -> skipping keel-ds install.')\n",
- "\n",
- "import pathlib\n",
- "import shutil\n",
- "import subprocess\n",
- "import sys\n",
- "\n",
- "def clone_or_update(repo_url: str, target_dir: str, branch: str = 'main') -> None:\n",
- " target = pathlib.Path(target_dir)\n",
- " if (target / '.git').exists():\n",
- " subprocess.run(['git', '-C', str(target), 'fetch', '--depth', '1', 'origin', branch], check=True)\n",
- " subprocess.run(['git', '-C', str(target), 'reset', '--hard', 'FETCH_HEAD'], check=True)\n",
- " else:\n",
- " if target.exists():\n",
- " shutil.rmtree(target)\n",
- " subprocess.run(['git', 'clone', '--depth', '1', '--branch', branch, repo_url, str(target)], check=True)\n",
- "\n",
- "def ensure_keel_ds() -> bool:\n",
- " try:\n",
- " __import__('keel_ds')\n",
- " print('keel_ds already importable')\n",
- " return True\n",
- " except Exception:\n",
- " pass\n",
- "\n",
- " repo_candidates = [\n",
- " 'https://github.com/CalculatedContent/keel_ds.git',\n",
- " 'https://github.com/CalculatedContent/keel-ds.git',\n",
- " ]\n",
- " for repo in repo_candidates:\n",
- " try:\n",
- " print(f'[INFO] Trying keel_ds source install from: {repo}')\n",
- " clone_or_update(repo, '/tmp/keel-ds-src')\n",
- " subprocess.run([sys.executable, '-m', 'pip', 'install', '-q', '--no-build-isolation', '/tmp/keel-ds-src'], check=True)\n",
- " __import__('keel_ds')\n",
- " print(f'[INFO] keel_ds installed from source: {repo}')\n",
- " return True\n",
- " except Exception:\n",
- " continue\n",
- "\n",
- " print('[WARN] Unable to install keel_ds from source repo candidates; continuing without KEEL datasets.')\n",
- " return False\n",
- "\n",
- "clone_or_update('https://github.com/CalculatedContent/xgboost2ww.git', '/tmp/xgboost2ww-src')\n",
- "clone_or_update('https://github.com/CalculatedContent/xgbwwdata.git', '/tmp/xgbwwdata-src')\n",
- "\n",
- "%pip install -q --no-build-isolation /tmp/xgboost2ww-src\n",
- "%pip install -q --no-build-isolation /tmp/xgbwwdata-src\n",
- "\n",
- "if INCLUDE_KEEL_DATASETS:\n",
- " ensure_keel_ds()\n",
- "else:\n",
- " print('[INFO] INCLUDE_KEEL_DATASETS=False -> skipping KEEL source install checks.')\n",
- "\n",
- "import importlib\n",
- "\n",
- "xgboost = importlib.import_module('xgboost')\n",
- "weightwatcher = importlib.import_module('weightwatcher')\n",
- "xgboost2ww = importlib.import_module('xgboost2ww')\n",
- "xgbwwdata = importlib.import_module('xgbwwdata')\n",
- "\n",
- "print('xgboost:', getattr(xgboost, '__version__', 'unknown'))\n",
- "print('weightwatcher:', getattr(weightwatcher, '__version__', 'unknown'))\n",
- "print('xgboost2ww location:', pathlib.Path(xgboost2ww.__file__).resolve())\n",
- "print('xgbwwdata location:', pathlib.Path(xgbwwdata.__file__).resolve())\n",
- "print('xgboost2ww installed from source')\n",
- "print('xgbwwdata installed from source')\n",
- "\n"
- ],
- "id": "mOauVnRtE4dK"
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "eK1Kae5yE4dK"
- },
- "source": [
- "If installs are rerun many times and import state becomes inconsistent, restart runtime and rerun from the top. In normal usage, this notebook is structured to avoid requiring a manual restart."
- ],
- "id": "eK1Kae5yE4dK"
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Of0eSX0kE4dL"
- },
- "source": [
- "## 3. Imports"
- ],
- "id": "Of0eSX0kE4dL"
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "mEqMoblyE4dL"
- },
- "execution_count": 52,
- "outputs": [],
- "source": [
- "# Imports\n",
- "import json\n",
- "import time\n",
- "import random\n",
- "import traceback\n",
- "import subprocess\n",
- "\n",
- "import numpy as np\n",
- "import pandas as pd\n",
- "import matplotlib.pyplot as plt\n",
- "import seaborn as sns\n",
- "import xgboost as xgb\n",
- "import weightwatcher as ww\n",
- "\n",
- "from scipy import sparse\n",
- "from sklearn.model_selection import train_test_split\n",
- "from sklearn.preprocessing import StandardScaler\n",
- "from sklearn.metrics import accuracy_score\n",
- "\n",
- "from xgboost2ww import convert\n",
- "from xgbwwdata import Filters, load_dataset\n",
- "\n",
- "try:\n",
- " from xgbwwdata import scan_datasets as xgbww_scan_datasets\n",
- "except Exception:\n",
- " xgbww_scan_datasets = None\n",
- "\n",
- "try:\n",
- " from xgbwwdata import list_datasets as xgbww_list_datasets\n",
- "except Exception:\n",
- " xgbww_list_datasets = None\n"
- ],
- "id": "mEqMoblyE4dL"
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "XnJvev0xE4dL"
- },
- "source": [
- "## 4. Runtime configuration\n",
- "\n",
- "Long Colab runs may disconnect. This notebook is checkpointed to Drive so you can reconnect and rerun from the top to resume. Use `FORCE_RESTART_ALL=True` to start over, or set `FORCE_RESTART_DATASETS` to selected dataset slugs/uids."
- ],
- "id": "XnJvev0xE4dL"
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "ieH6OSQjE4dL",
- "outputId": "a471fbd2-15f5-4e8e-a7cc-a78de33c4a9d",
- "colab": {
- "base_uri": "https://localhost:8080/"
- }
- },
- "execution_count": 53,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "N_STEPS=48, sample=10, dry_run=False\n",
- "USE_GOOGLE_DRIVE=True, RESTART=True, RESUME_FROM_CHECKPOINT=True, REUSE_SAMPLED_REGISTRY=True\n",
- "INCLUDE_KEEL_DATASETS=False (default=False)\n"
- ]
- }
- ],
- "source": [
- "# Runtime + global config\n",
- "RANDOM_STATE = 42\n",
- "DATASET_SAMPLE_SIZE = 10\n",
- "TOTAL_ROUNDS = 1200\n",
- "CHUNK_SIZE = 25\n",
- "N_STEPS = TOTAL_ROUNDS // CHUNK_SIZE\n",
- "CHECKPOINT_EVERY_STEPS = 1\n",
- "\n",
- "INCLUDE_KEEL_DATASETS = globals().get('INCLUDE_KEEL_DATASETS', False)\n",
- "RESTART = True # Optional: set to False to ignore existing checkpoints and overwrite them\n",
- "RESUME_FROM_CHECKPOINT = RESTART\n",
- "FORCE_RESTART_ALL = True\n",
- "FORCE_RESTART_DATASETS = []\n",
- "REUSE_SAMPLED_REGISTRY = True\n",
- "\n",
- "SELECTED_DATASET_UIDS = []\n",
- "DRY_RUN = False\n",
- "MAX_DENSE_ELEMENTS = int(2e8)\n",
- "\n",
- "if DRY_RUN:\n",
- " DATASET_SAMPLE_SIZE = 2\n",
- " TOTAL_ROUNDS = 50\n",
- " CHUNK_SIZE = 25\n",
- " N_STEPS = TOTAL_ROUNDS // CHUNK_SIZE\n",
- "\n",
- "random.seed(RANDOM_STATE)\n",
- "np.random.seed(RANDOM_STATE)\n",
- "sns.set_theme(style='whitegrid')\n",
- "print(f'N_STEPS={N_STEPS}, sample={DATASET_SAMPLE_SIZE}, dry_run={DRY_RUN}')\n",
- "print(f'USE_GOOGLE_DRIVE={USE_GOOGLE_DRIVE}, RESTART={RESTART}, RESUME_FROM_CHECKPOINT={RESUME_FROM_CHECKPOINT}, REUSE_SAMPLED_REGISTRY={REUSE_SAMPLED_REGISTRY}')\n",
- "print(f'INCLUDE_KEEL_DATASETS={INCLUDE_KEEL_DATASETS} (default=False)')\n"
- ],
- "id": "ieH6OSQjE4dL"
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "owCeTeVWE4dM"
- },
- "source": [
- "## 5. Runtime diagnostics"
- ],
- "id": "owCeTeVWE4dM"
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "UIteF5tmE4dM",
- "outputId": "2d2b7a77-5230-4491-f36e-242d51a6d918",
- "colab": {
- "base_uri": "https://localhost:8080/"
- }
- },
- "execution_count": 62,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "cpu_count: 2\n",
- "platform: Linux-6.6.113+-x86_64-with-glibc2.35\n",
- "[WARN] nvidia-smi unavailable: [Errno 2] No such file or directory: 'nvidia-smi'\n",
- "colab_gpu_runtime=True\n",
- "CUDA visible to XGBoost: no (modern config failed)\n",
- "[fallback] modern GPU failed -> trying gpu_hist: [05:09:52] WARNING: /__w/xgboost/xgboost/src/context.cc:53: No visible GPU is found, setting device to CPU. ; [05:09:52] WARNING: /__w/xgboost/xgboost/src/context.cc:207: Device is changed from GPU to CPU as we couldn't find any available GPU on the system.\n",
- "[fallback] gpu_hist failed -> using CPU hist: Invalid Input: 'gpu_hist', valid values are: {'approx', 'auto', 'exact', 'hist'}\n",
- "Backend fallback path: CPU hist\n",
- "Chosen backend: cpu_hist\n",
- "Chosen XGBoost params patch: {'tree_method': 'hist', 'nthread': 1}\n"
- ]
- }
- ],
- "source": [
- "# Runtime diagnostics + Colab-aware backend selection\n",
- "import warnings\n",
- "\n",
- "print('cpu_count:', os.cpu_count())\n",
- "print('platform:', platform.platform())\n",
- "\n",
- "if IN_COLAB:\n",
- " try:\n",
- " smi = subprocess.run(['nvidia-smi'], capture_output=True, text=True, check=False)\n",
- " print('nvidia-smi returncode:', smi.returncode)\n",
- " print(smi.stdout[:1000] if smi.stdout else smi.stderr[:500])\n",
- " except Exception as e:\n",
- " print('[WARN] nvidia-smi unavailable:', e)\n",
- "\n",
- "def _probe_xgb(params_patch):\n",
- " X_probe = np.array([[0.0], [1.0], [2.0], [3.0]], dtype=np.float32)\n",
- " y_probe = np.array([0, 0, 1, 1], dtype=np.float32)\n",
- " dprobe = xgb.DMatrix(X_probe, label=y_probe)\n",
- " with warnings.catch_warnings(record=True) as caught:\n",
- " warnings.simplefilter('always')\n",
- " xgb.train(params={'objective': 'binary:logistic', 'eval_metric': 'logloss', **params_patch}, dtrain=dprobe, num_boost_round=1, verbose_eval=False)\n",
- " gpu_fallback_msgs = [\n",
- " str(w.message) for w in caught\n",
- " if ('No visible GPU is found' in str(w.message)) or ('Device is changed from GPU to CPU' in str(w.message))\n",
- " ]\n",
- " if gpu_fallback_msgs:\n",
- " raise RuntimeError(' ; '.join(gpu_fallback_msgs))\n",
- "\n",
- "def choose_xgb_runtime_colab_aware():\n",
- " cpu_threads = max(1, (os.cpu_count() or 1) - 1)\n",
- " print(f'colab_gpu_runtime={IN_COLAB}')\n",
- "\n",
- " if IN_COLAB:\n",
- " try:\n",
- " modern_gpu = {'tree_method': 'hist', 'device': 'cuda'}\n",
- " _probe_xgb(modern_gpu)\n",
- " print('CUDA visible to XGBoost: yes')\n",
- " print('Backend fallback path: modern gpu config succeeded')\n",
- " return {'backend_name': 'colab_cuda_hist', 'params_patch': modern_gpu, 'is_gpu': True}\n",
- " except Exception as e1:\n",
- " print('CUDA visible to XGBoost: no (modern config failed)')\n",
- " print('[fallback] modern GPU failed -> trying gpu_hist:', e1)\n",
- " try:\n",
- " legacy_gpu = {'tree_method': 'gpu_hist'}\n",
- " _probe_xgb(legacy_gpu)\n",
- " print('Backend fallback path: legacy gpu_hist succeeded')\n",
- " return {'backend_name': 'colab_gpu_hist_legacy', 'params_patch': legacy_gpu, 'is_gpu': True}\n",
- " except Exception as e2:\n",
- " print('[fallback] gpu_hist failed -> using CPU hist:', e2)\n",
- "\n",
- " cpu_patch = {'tree_method': 'hist', 'nthread': cpu_threads}\n",
- " print('Backend fallback path: CPU hist')\n",
- " return {'backend_name': 'cpu_hist', 'params_patch': cpu_patch, 'is_gpu': False}\n",
- "\n",
- "BACKEND_INFO = choose_xgb_runtime_colab_aware()\n",
- "print('Chosen backend:', BACKEND_INFO['backend_name'])\n",
- "print('Chosen XGBoost params patch:', BACKEND_INFO['params_patch'])\n",
- "\n"
- ],
- "id": "UIteF5tmE4dM"
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "fD9dmsP5E4dM"
- },
- "source": [
- "## 6. Dataset catalog scan / cached registry"
- ],
- "id": "fD9dmsP5E4dM"
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "cmP4hat2E4dM",
- "outputId": "4e3fb4f4-7302-4d24-b849-8a1491314638",
- "colab": {
- "base_uri": "https://localhost:8080/"
- }
- },
- "execution_count": 63,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "[INFO] INCLUDE_KEEL_DATASETS=False -> not attempting keel_ds installation.\n",
- "[WARN] Proceeding without KEEL datasets; registry scan returned empty due to missing keel_ds.\n",
- "Scanned full registry: 0 rows\n",
- "Filtered registry rows: 0\n"
- ]
- }
- ],
- "source": [
- "# xgbwwdata scan / registry build\n",
- "\n",
- "import inspect\n",
- "import importlib\n",
- "import sys\n",
- "from dataclasses import asdict, is_dataclass\n",
- "\n",
- "\n",
- "def _install_missing_dataset_dependency(module_name: str) -> bool:\n",
- " if module_name != 'keel_ds':\n",
- " return False\n",
- " if not INCLUDE_KEEL_DATASETS:\n",
- " print('[INFO] INCLUDE_KEEL_DATASETS=False -> not attempting keel_ds installation.')\n",
- " return False\n",
- "\n",
- " print('[INFO] Missing optional dataset dependency: keel_ds. Installing now...')\n",
- " candidates = [\n",
- " 'keel-ds',\n",
- " 'keel_ds',\n",
- " 'git+https://github.com/CalculatedContent/keel_ds.git',\n",
- " 'git+https://github.com/CalculatedContent/keel-ds.git',\n",
- " ]\n",
- " install_errors = []\n",
- " for pkg in candidates:\n",
- " try:\n",
- " subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', pkg])\n",
- " importlib.import_module('keel_ds')\n",
- " print(f'[INFO] Installed KEEL dependency using installer target: {pkg}')\n",
- " return True\n",
- " except Exception as ex:\n",
- " install_errors.append(f'{pkg}: {type(ex).__name__}: {ex}')\n",
- " continue\n",
- "\n",
- " print('[WARN] Unable to auto-install keel_ds. KEEL datasets will be skipped for this run.')\n",
- " if install_errors:\n",
- " print('[WARN] KEEL install attempts failed with:')\n",
- " for err in install_errors:\n",
- " print(' -', err)\n",
- " return False\n",
- "\n",
- "\n",
- "def _copy_filters_with_overrides(filters: Filters, **overrides):\n",
- " if is_dataclass(filters):\n",
- " data = asdict(filters)\n",
- " elif hasattr(filters, 'model_dump'):\n",
- " data = filters.model_dump()\n",
- " elif hasattr(filters, '__dict__'):\n",
- " data = dict(filters.__dict__)\n",
- " else:\n",
- " data = {}\n",
- " data.update(overrides)\n",
- " return type(filters)(**data)\n",
- "\n",
- "\n",
- "def _filters_without_keel(filters: Filters):\n",
- " source_keys = ('sources', 'include_sources', 'source_allowlist')\n",
- " for key in source_keys:\n",
- " current = getattr(filters, key, None)\n",
- " if not current:\n",
- " continue\n",
- " try:\n",
- " cleaned = [s for s in current if 'keel' not in str(s).lower()]\n",
- " if cleaned and len(cleaned) != len(current):\n",
- " print(f'[INFO] Retrying scan without KEEL via filters.{key}.')\n",
- " return _copy_filters_with_overrides(filters, **{key: cleaned})\n",
- " except Exception:\n",
- " continue\n",
- "\n",
- " # If the current filters do not explicitly set source lists, force a KEEL-free allowlist.\n",
- " non_keel_sources = ['openml', 'pmlb', 'uci', 'sklearn']\n",
- " for key in source_keys:\n",
- " if hasattr(filters, key):\n",
- " try:\n",
- " print(f'[INFO] Retrying scan with explicit non-KEEL sources via filters.{key}.')\n",
- " return _copy_filters_with_overrides(filters, **{key: non_keel_sources})\n",
- " except Exception:\n",
- " continue\n",
- " return None\n",
- "\n",
- "\n",
- "def _call_xgbww_scan(api_func, filters: Filters):\n",
- " sig = inspect.signature(api_func)\n",
- " kwargs = {}\n",
- " if 'filters' in sig.parameters:\n",
- " kwargs['filters'] = filters\n",
- " if 'preprocess' in sig.parameters:\n",
- " kwargs['preprocess'] = True\n",
- " if 'smoke_train' in sig.parameters:\n",
- " kwargs['smoke_train'] = True\n",
- " if not kwargs:\n",
- " return api_func()\n",
- " return api_func(**kwargs)\n",
- "\n",
- "\n",
- "def _materialize_registry(reg):\n",
- " if isinstance(reg, pd.DataFrame):\n",
- " return reg\n",
- " if hasattr(reg, 'to_pandas'):\n",
- " return reg.to_pandas()\n",
- " return pd.DataFrame(list(reg))\n",
- "\n",
- "\n",
- "def _safe_call_scan(api_func, filters: Filters):\n",
- " try:\n",
- " return _materialize_registry(_call_xgbww_scan(api_func, filters))\n",
- " except ModuleNotFoundError as e:\n",
- " missing = getattr(e, 'name', None)\n",
- " if not missing:\n",
- " raise\n",
- "\n",
- " installed = _install_missing_dataset_dependency(missing)\n",
- " if installed:\n",
- " return _materialize_registry(_call_xgbww_scan(api_func, filters))\n",
- "\n",
- " if missing == 'keel_ds':\n",
- " fallback_filters = _filters_without_keel(filters)\n",
- " if fallback_filters is not None:\n",
- " try:\n",
- " return _materialize_registry(_call_xgbww_scan(api_func, fallback_filters))\n",
- " except ModuleNotFoundError as fallback_err:\n",
- " if getattr(fallback_err, 'name', None) != 'keel_ds':\n",
- " raise\n",
- " print('[WARN] Proceeding without KEEL datasets; registry scan returned empty due to missing keel_ds.')\n",
- " return pd.DataFrame()\n",
- "\n",
- " raise\n",
- "\n",
- "\n",
- "def build_registry_df(filters: Filters) -> pd.DataFrame:\n",
- " if xgbww_scan_datasets is not None:\n",
- " return _safe_call_scan(xgbww_scan_datasets, filters)\n",
- " if xgbww_list_datasets is not None:\n",
- " return _safe_call_scan(xgbww_list_datasets, filters)\n",
- " try:\n",
- " from xgbwwdata import get_registry\n",
- " return _safe_call_scan(get_registry, filters)\n",
- " except Exception as e:\n",
- " raise RuntimeError('No compatible xgbwwdata scan/list API found in this environment.') from e\n",
- "\n",
- "\n",
- "filters = Filters(min_rows=200, max_rows=60000, max_features=50000, max_dense_elements=MAX_DENSE_ELEMENTS, preprocess=True)\n",
- "if not INCLUDE_KEEL_DATASETS:\n",
- " keel_free_filters = _filters_without_keel(filters)\n",
- " if keel_free_filters is not None:\n",
- " filters = keel_free_filters\n",
- " print('[INFO] KEEL disabled: forcing non-KEEL dataset sources in filters before scanning.')\n",
- "\n",
- "full_registry_csv = registry_dir / 'full_registry.csv'\n",
- "full_registry_feather = registry_dir / 'full_registry.feather'\n",
- "\n",
- "if RESUME_FROM_CHECKPOINT and full_registry_csv.exists() and not FORCE_RESTART_ALL:\n",
- " full_registry_df = pd.read_csv(full_registry_csv)\n",
- " print(f'Loaded cached full registry: {len(full_registry_df)} rows from {full_registry_csv}')\n",
- "else:\n",
- " full_registry_df = build_registry_df(filters)\n",
- " full_registry_df.to_csv(full_registry_csv, index=False)\n",
- " try:\n",
- " full_registry_df.reset_index(drop=True).to_feather(full_registry_feather)\n",
- " except Exception as e:\n",
- " print('[WARN] Feather save failed for full registry:', e)\n",
- " print(f'Scanned full registry: {len(full_registry_df)} rows')\n",
- "\n",
- "registry_df = full_registry_df.copy()\n",
- "if 'task' in registry_df.columns:\n",
- " registry_df = registry_df[registry_df['task'].astype(str).str.contains('class', case=False, na=False)]\n",
- "if 'n_classes' in registry_df.columns:\n",
- " registry_df = registry_df[(registry_df['n_classes'].fillna(0) >= 2) & (registry_df['n_classes'].fillna(0) <= 20)]\n",
- "registry_df = registry_df.reset_index(drop=True)\n",
- "print('Filtered registry rows:', len(registry_df))\n",
- "\n"
- ],
- "id": "cmP4hat2E4dM"
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "4Qp4CSnzE4dM"
- },
- "source": [
- "## 7. Random-10 dataset selection"
- ],
- "id": "4Qp4CSnzE4dM"
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "hD5MMjf5E4dM",
- "outputId": "19e51d34-70f2-431f-d3cc-49efb4b387a2",
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 88
- }
- },
- "execution_count": 64,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "Sampled 0 datasets\n",
- "Saved sampled registry to: /content/drive/MyDrive/xgboost2ww_runs/random10_longrun_alpha_tracking/registry/random10_registry.csv\n"
- ]
- },
- {
- "output_type": "display_data",
- "data": {
- "text/plain": [
- "Empty DataFrame\n",
- "Columns: []\n",
- "Index: []"
- ],
- "text/html": [
- "\n",
- "
\n",
- "
\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- "
\n",
- "
\n",
- "
\n",
- "
\n"
- ],
- "application/vnd.google.colaboratory.intrinsic+json": {
- "type": "dataframe",
- "summary": "{\n \"name\": \"display(sampled_registry\",\n \"rows\": 0,\n \"fields\": []\n}"
- }
- },
- "metadata": {}
- }
- ],
- "source": [
- "# Random sample of 10 datasets (saved immediately to Drive in Colab)\n",
- "random10_csv = registry_dir / 'random10_registry.csv'\n",
- "random10_feather = registry_dir / 'random10_registry.feather'\n",
- "\n",
- "if SELECTED_DATASET_UIDS:\n",
- " sampled_registry = registry_df[registry_df['dataset_uid'].isin(SELECTED_DATASET_UIDS)].copy()\n",
- " print(f'Using explicit SELECTED_DATASET_UIDS ({len(sampled_registry)} rows).')\n",
- "elif REUSE_SAMPLED_REGISTRY and RESTART and RESUME_FROM_CHECKPOINT and random10_csv.exists() and not FORCE_RESTART_ALL:\n",
- " sampled_registry = pd.read_csv(random10_csv)\n",
- " print(f'Reusing sampled registry from checkpoint: {len(sampled_registry)} rows ({random10_csv})')\n",
- "else:\n",
- " n_take = min(DATASET_SAMPLE_SIZE, len(registry_df))\n",
- " sampled_registry = registry_df.sample(n=n_take, random_state=RANDOM_STATE).copy()\n",
- " print(f'Sampled {len(sampled_registry)} datasets')\n",
- "\n",
- "sampled_registry = sampled_registry.reset_index(drop=True)\n",
- "sampled_registry.to_csv(random10_csv, index=False)\n",
- "try:\n",
- " sampled_registry.reset_index(drop=True).to_feather(random10_feather)\n",
- "except Exception as e:\n",
- " print('[WARN] Feather save failed for sampled registry:', e)\n",
- "print('Saved sampled registry to:', random10_csv)\n",
- "\n",
- "display(sampled_registry.head(20))\n"
- ],
- "id": "hD5MMjf5E4dM"
- },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "0wW66rXME4dH"
+ },
+ "source": [
+ "[](https://colab.research.google.com/github/CalculatedContent/xgboost2ww/blob/main/notebooks/XGBWW_Random10_LongRun_Alpha_Tracking.ipynb)"
+ ],
+ "id": "0wW66rXME4dH"
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "7bRnlL4uE4dI"
+ },
+ "source": [
+ "# Random-10 Dataset Long-Run Alpha Tracking\n",
+ "\n",
+ "**Designed for Google Colab.** This notebook is Colab-first: it mounts Google Drive for persistent checkpointing, installs `xgboost2ww` and `xgbwwdata` from source, resumes long runs from Drive checkpoints, and prefers Colab GPU when available (with CPU `hist` fallback)."
+ ],
+ "id": "7bRnlL4uE4dI"
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "crO-HlpcE4dJ"
+ },
+ "source": [
+ "## 1. Colab / Drive setup"
+ ],
+ "id": "crO-HlpcE4dJ"
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "w6n5vumaE4dJ",
+ "outputId": "58d551c8-4eda-4a5d-8264-e1eaedbb4a8a",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ }
+ },
+ "execution_count": 50,
+ "outputs": [
{
- "cell_type": "markdown",
- "metadata": {
- "id": "P-Y2_0X6E4dN"
- },
- "source": [
- "## 8. Helper functions"
- ],
- "id": "P-Y2_0X6E4dN"
- },
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "IN_COLAB=True\n",
+ "Python=3.12.12\n",
+ "Platform=Linux-6.6.113+-x86_64-with-glibc2.35\n",
+ "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n",
+ "Drive mounted\n",
+ "Google Drive checkpointing enabled\n",
+ "project_root: /content/drive/MyDrive/xgboost2ww_runs/random10_longrun_alpha_tracking\n",
+ "registry_dir: /content/drive/MyDrive/xgboost2ww_runs/random10_longrun_alpha_tracking/registry\n",
+ "per_dataset_dir: /content/drive/MyDrive/xgboost2ww_runs/random10_longrun_alpha_tracking/per_dataset\n",
+ "aggregate_dir: /content/drive/MyDrive/xgboost2ww_runs/random10_longrun_alpha_tracking/aggregate\n",
+ "logs_dir: /content/drive/MyDrive/xgboost2ww_runs/random10_longrun_alpha_tracking/logs\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Colab-first runtime + Drive setup\n",
+ "import os\n",
+ "import sys\n",
+ "import platform\n",
+ "from pathlib import Path\n",
+ "\n",
+ "try:\n",
+ " import google.colab # noqa: F401\n",
+ " IN_COLAB = True\n",
+ "except Exception:\n",
+ " IN_COLAB = False\n",
+ "\n",
+ "USE_GOOGLE_DRIVE = True if IN_COLAB else False\n",
+ "FORCE_REMOUNT_DRIVE = False\n",
+ "\n",
+ "print(f\"IN_COLAB={IN_COLAB}\")\n",
+ "print(f\"Python={sys.version.split()[0]}\")\n",
+ "print(f\"Platform={platform.platform()}\")\n",
+ "\n",
+ "if IN_COLAB:\n",
+ " from google.colab import drive\n",
+ " drive.mount('/content/drive', force_remount=FORCE_REMOUNT_DRIVE)\n",
+ " project_root = Path('/content/drive/MyDrive/xgboost2ww_runs/random10_longrun_alpha_tracking')\n",
+ "else:\n",
+ " project_root = Path('./random10_longrun_alpha_tracking')\n",
+ "\n",
+ "registry_dir = project_root / 'registry'\n",
+ "per_dataset_dir = project_root / 'per_dataset'\n",
+ "aggregate_dir = project_root / 'aggregate'\n",
+ "logs_dir = project_root / 'logs'\n",
+ "for p in [project_root, registry_dir, per_dataset_dir, aggregate_dir, logs_dir]:\n",
+ " p.mkdir(parents=True, exist_ok=True)\n",
+ "\n",
+ "errors_path = logs_dir / 'errors.csv'\n",
+ "skipped_path = aggregate_dir / 'skipped_datasets.csv'\n",
+ "\n",
+ "if IN_COLAB:\n",
+ " print('Drive mounted')\n",
+ " print('Google Drive checkpointing enabled')\n",
+ "\n",
+ "print('project_root:', project_root)\n",
+ "print('registry_dir:', registry_dir)\n",
+ "print('per_dataset_dir:', per_dataset_dir)\n",
+ "print('aggregate_dir:', aggregate_dir)\n",
+ "print('logs_dir:', logs_dir)\n"
+ ],
+ "id": "w6n5vumaE4dJ"
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Vkmos-hTE4dK"
+ },
+ "source": [
+ "## 2. Colab/bootstrap installs"
+ ],
+ "id": "Vkmos-hTE4dK"
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "mOauVnRtE4dK",
+ "outputId": "a6147224-6754-4f31-a955-df695a3bf993",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ }
+ },
+ "execution_count": 51,
+ "outputs": [
{
- "cell_type": "code",
- "metadata": {
- "id": "5bTodyRVE4dN"
- },
- "execution_count": 65,
- "outputs": [],
- "source": [
- "# Helper functions\n",
- "W_LIST = ['W1', 'W2', 'W7', 'W8', 'W9', 'W10']\n",
- "\n",
- "def make_safe_slug(text):\n",
- " text = str(text)\n",
- " out = ''.join(ch.lower() if ch.isalnum() else '_' for ch in text)\n",
- " while '__' in out:\n",
- " out = out.replace('__', '_')\n",
- " return out.strip('_')[:120]\n",
- "\n",
- "def log_error(error_rows, dataset_uid, dataset_slug, stage, exc):\n",
- " error_rows.append({'dataset_uid': dataset_uid, 'dataset_slug': dataset_slug, 'stage': stage, 'error_type': type(exc).__name__, 'error_message': str(exc), 'traceback': traceback.format_exc(), 'timestamp': pd.Timestamp.utcnow().isoformat()})\n",
- "\n",
- "def save_error_log(error_rows, errors_path):\n",
- " if error_rows:\n",
- " pd.DataFrame(error_rows).to_csv(errors_path, index=False)\n",
- "\n",
- "def make_dataset_paths(dataset_slug):\n",
- " base = per_dataset_dir / dataset_slug\n",
- " out = {'base': base, 'results': base / 'results', 'models': base / 'models', 'plots': base / 'plots'}\n",
- " for p in out.values():\n",
- " p.mkdir(parents=True, exist_ok=True)\n",
- " return out\n",
- "\n",
- "def detect_task_type(y):\n",
- " y_np = np.asarray(y)\n",
- " n_classes = int(len(np.unique(y_np)))\n",
- " if n_classes < 2:\n",
- " return 'invalid', n_classes, y_np\n",
- " return ('binary' if n_classes == 2 else 'multiclass'), n_classes, y_np\n",
- "\n",
- "def build_params_for_dataset(base_backend_patch, n_classes):\n",
- " params = {'max_depth': 4, 'eta': 0.05, 'subsample': 0.9, 'colsample_bytree': 0.8, 'min_child_weight': 5, 'reg_lambda': 5.0, 'reg_alpha': 0.5, 'gamma': 1.0, 'seed': RANDOM_STATE, **base_backend_patch}\n",
- " if n_classes == 2:\n",
- " params.update({'objective': 'binary:logistic', 'eval_metric': 'logloss'})\n",
- " else:\n",
- " params.update({'objective': 'multi:softprob', 'eval_metric': 'mlogloss', 'num_class': int(n_classes)})\n",
- " return params\n",
- "\n",
- "def prepare_train_test_data(X, y):\n",
- " stratify = y if len(np.unique(y)) > 1 else None\n",
- " X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=RANDOM_STATE, stratify=stratify)\n",
- " scaler = StandardScaler(with_mean=False) if sparse.issparse(X_train) else StandardScaler()\n",
- " X_train_scaled = scaler.fit_transform(X_train).astype(np.float32)\n",
- " X_test_scaled = scaler.transform(X_test).astype(np.float32)\n",
- " y_train_np = np.asarray(y_train).astype(np.int32).reshape(-1)\n",
- " y_test_np = np.asarray(y_test).astype(np.int32).reshape(-1)\n",
- " return X_train_scaled, X_test_scaled, y_train_np, y_test_np\n",
- "\n",
- "def densify_for_convert_if_safe(X_train_scaled, max_dense_elements):\n",
- " if sparse.issparse(X_train_scaled):\n",
- " n_elements = int(X_train_scaled.shape[0]) * int(X_train_scaled.shape[1])\n",
- " if n_elements > max_dense_elements:\n",
- " raise MemoryError(f'Refusing to densify convert input: shape={X_train_scaled.shape}, elements={n_elements:,}')\n",
- " return X_train_scaled.toarray().astype(np.float32)\n",
- " return np.asarray(X_train_scaled).astype(np.float32)\n",
- "\n",
- "def build_layer_for_W(bst, W_name, current_round, X_train_for_convert, y_train_np, params):\n",
- " multiclass_mode = 'softprob' if int(params.get('num_class', 2)) > 2 else 'error'\n",
- " return convert(model=bst, data=X_train_for_convert, labels=y_train_np, W=W_name, return_type='torch', nfolds=5, t_points=min(current_round, 160), random_state=RANDOM_STATE, train_params=params, num_boost_round=current_round, multiclass=multiclass_mode)\n",
- "\n",
- "def compute_alpha_from_layer(layer):\n",
- " watcher = ww.WeightWatcher(model=layer)\n",
- " df = watcher.analyze(randomize=True, detX=True)\n",
- " return float(df['alpha'].iloc[0])\n",
- "\n",
- "def save_dataset_checkpoint(rows, metrics_csv_path, metrics_feather_path, bst, current_round, models_dir, metadata, meta_json_path):\n",
- " df = pd.DataFrame(rows)\n",
- " df.to_csv(metrics_csv_path, index=False)\n",
- " try:\n",
- " df.reset_index(drop=True).to_feather(metrics_feather_path)\n",
- " except Exception as e:\n",
- " print('[WARN] Feather save failed:', e)\n",
- "\n",
- " round_model_path = models_dir / f'model_round_{current_round:04d}.json'\n",
- " latest_model_path = models_dir / 'model_latest.json'\n",
- " bst.save_model(round_model_path)\n",
- " bst.save_model(latest_model_path)\n",
- "\n",
- " metadata = dict(metadata)\n",
- " metadata['latest_completed_round'] = int(current_round)\n",
- " metadata['updated_at'] = pd.Timestamp.utcnow().isoformat()\n",
- " meta_json_path.write_text(json.dumps(metadata, indent=2))\n",
- "\n",
- " print(f\"[CHECKPOINT] round={current_round} saved -> {metrics_csv_path}\")\n",
- " print(f\"[CHECKPOINT] model_latest -> {latest_model_path}\")\n",
- "\n",
- "def load_dataset_checkpoint(metrics_csv_path, latest_model_path, chunk_size, n_steps):\n",
- " if not (metrics_csv_path.exists() and latest_model_path.exists()):\n",
- " return None, 1, []\n",
- " prior_df = pd.read_csv(metrics_csv_path)\n",
- " if prior_df.empty:\n",
- " return None, 1, []\n",
- " max_round = int(prior_df['boosting_round'].max())\n",
- " if max_round % chunk_size != 0:\n",
- " raise ValueError(f'Checkpoint round {max_round} incompatible with CHUNK_SIZE={chunk_size}')\n",
- " completed_steps = max_round // chunk_size\n",
- " if completed_steps > n_steps:\n",
- " raise ValueError(f'Checkpoint completed_steps={completed_steps} > N_STEPS={n_steps}')\n",
- " bst = xgb.Booster()\n",
- " bst.load_model(latest_model_path)\n",
- " return bst, completed_steps + 1, prior_df.to_dict('records')\n",
- "\n",
- "def plot_dataset_dynamics(df_dataset, dataset_slug, plots_dir):\n",
- " if df_dataset.empty:\n",
- " return None, None\n",
- " long_alpha = df_dataset.melt(id_vars=['boosting_round', 'test_accuracy'], value_vars=[f'alpha_{w}' for w in W_LIST], var_name='alpha_name', value_name='alpha_value')\n",
- " fig, axes = plt.subplots(1, 3, figsize=(20, 5))\n",
- " axes[0].plot(df_dataset['boosting_round'], df_dataset['test_accuracy'], marker='o')\n",
- " axes[0].set_title(f'{dataset_slug}: test accuracy vs round')\n",
- " sns.lineplot(data=long_alpha, x='boosting_round', y='alpha_value', hue='alpha_name', ax=axes[1])\n",
- " axes[1].set_title('alpha vs round')\n",
- " sns.scatterplot(data=long_alpha, x='alpha_value', y='test_accuracy', hue='alpha_name', ax=axes[2])\n",
- " axes[2].set_title('accuracy vs alpha')\n",
- " fig.tight_layout()\n",
- " p1 = plots_dir / f'{dataset_slug}_dynamics.png'\n",
- " fig.savefig(p1, dpi=140)\n",
- " plt.close(fig)\n",
- " fig2, ax2 = plt.subplots(figsize=(7, 5))\n",
- " for w in W_LIST:\n",
- " col = f'alpha_{w}'\n",
- " ax2.scatter(df_dataset[col], df_dataset['test_accuracy'], alpha=0.6, label=w)\n",
- " ax2.set_title(f'{dataset_slug}: accuracy vs alpha (all W)')\n",
- " ax2.legend(ncol=2)\n",
- " fig2.tight_layout()\n",
- " p2 = plots_dir / f'{dataset_slug}_accuracy_vs_alpha.png'\n",
- " fig2.savefig(p2, dpi=140)\n",
- " plt.close(fig2)\n",
- " return p1, p2\n"
- ],
- "id": "5bTodyRVE4dN"
- },
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "[INFO] INCLUDE_KEEL_DATASETS=False -> skipping keel-ds install.\n",
+ " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ " Building wheel for xgboost2ww (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ " Building wheel for xgbwwdata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ "[INFO] INCLUDE_KEEL_DATASETS=False -> skipping KEEL source install checks.\n",
+ "xgboost: 3.2.0\n",
+ "weightwatcher: 0.7.7\n",
+ "xgboost2ww location: /usr/local/lib/python3.12/dist-packages/xgboost2ww/__init__.py\n",
+ "xgbwwdata location: /usr/local/lib/python3.12/dist-packages/xgbwwdata/__init__.py\n",
+ "xgboost2ww installed from source\n",
+ "xgbwwdata installed from source\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Colab/bootstrap installs\n",
+ "INCLUDE_KEEL_DATASETS = False # Optional dataset source; default off for smoother Colab installs\n",
+ "\n",
+ "%pip install -q xgboost weightwatcher scikit-learn pandas matplotlib seaborn scipy feather-format pyarrow\n",
+ "if INCLUDE_KEEL_DATASETS:\n",
+ " %pip install -q openml pmlb keel-ds\n",
+ "else:\n",
+ " %pip install -q openml pmlb\n",
+ " print('[INFO] INCLUDE_KEEL_DATASETS=False -> skipping keel-ds install.')\n",
+ "\n",
+ "import pathlib\n",
+ "import shutil\n",
+ "import subprocess\n",
+ "import sys\n",
+ "\n",
+ "def clone_or_update(repo_url: str, target_dir: str, branch: str = 'main') -> None:\n",
+ " target = pathlib.Path(target_dir)\n",
+ " if (target / '.git').exists():\n",
+ " subprocess.run(['git', '-C', str(target), 'fetch', '--depth', '1', 'origin', branch], check=True)\n",
+ " subprocess.run(['git', '-C', str(target), 'reset', '--hard', 'FETCH_HEAD'], check=True)\n",
+ " else:\n",
+ " if target.exists():\n",
+ " shutil.rmtree(target)\n",
+ " subprocess.run(['git', 'clone', '--depth', '1', '--branch', branch, repo_url, str(target)], check=True)\n",
+ "\n",
+ "def ensure_keel_ds() -> bool:\n",
+ " try:\n",
+ " __import__('keel_ds')\n",
+ " print('keel_ds already importable')\n",
+ " return True\n",
+ " except Exception:\n",
+ " pass\n",
+ "\n",
+ " repo_candidates = [\n",
+ " 'https://github.com/CalculatedContent/keel_ds.git',\n",
+ " 'https://github.com/CalculatedContent/keel-ds.git',\n",
+ " ]\n",
+ " for repo in repo_candidates:\n",
+ " try:\n",
+ " print(f'[INFO] Trying keel_ds source install from: {repo}')\n",
+ " clone_or_update(repo, '/tmp/keel-ds-src')\n",
+ " subprocess.run([sys.executable, '-m', 'pip', 'install', '-q', '--no-build-isolation', '/tmp/keel-ds-src'], check=True)\n",
+ " __import__('keel_ds')\n",
+ " print(f'[INFO] keel_ds installed from source: {repo}')\n",
+ " return True\n",
+ " except Exception:\n",
+ " continue\n",
+ "\n",
+ " print('[WARN] Unable to install keel_ds from source repo candidates; continuing without KEEL datasets.')\n",
+ " return False\n",
+ "\n",
+ "clone_or_update('https://github.com/CalculatedContent/xgboost2ww.git', '/tmp/xgboost2ww-src')\n",
+ "clone_or_update('https://github.com/CalculatedContent/xgbwwdata.git', '/tmp/xgbwwdata-src')\n",
+ "\n",
+ "%pip install -q --no-build-isolation /tmp/xgboost2ww-src\n",
+ "%pip install -q --no-build-isolation /tmp/xgbwwdata-src\n",
+ "\n",
+ "if INCLUDE_KEEL_DATASETS:\n",
+ " ensure_keel_ds()\n",
+ "else:\n",
+ " print('[INFO] INCLUDE_KEEL_DATASETS=False -> skipping KEEL source install checks.')\n",
+ "\n",
+ "import importlib\n",
+ "\n",
+ "xgboost = importlib.import_module('xgboost')\n",
+ "weightwatcher = importlib.import_module('weightwatcher')\n",
+ "xgboost2ww = importlib.import_module('xgboost2ww')\n",
+ "xgbwwdata = importlib.import_module('xgbwwdata')\n",
+ "\n",
+ "print('xgboost:', getattr(xgboost, '__version__', 'unknown'))\n",
+ "print('weightwatcher:', getattr(weightwatcher, '__version__', 'unknown'))\n",
+ "print('xgboost2ww location:', pathlib.Path(xgboost2ww.__file__).resolve())\n",
+ "print('xgbwwdata location:', pathlib.Path(xgbwwdata.__file__).resolve())\n",
+ "print('xgboost2ww installed from source')\n",
+ "print('xgbwwdata installed from source')\n",
+ "\n"
+ ],
+ "id": "mOauVnRtE4dK"
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "eK1Kae5yE4dK"
+ },
+ "source": [
+ "If installs are rerun many times and import state becomes inconsistent, restart runtime and rerun from the top. In normal usage, this notebook is structured to avoid requiring a manual restart."
+ ],
+ "id": "eK1Kae5yE4dK"
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Of0eSX0kE4dL"
+ },
+ "source": [
+ "## 3. Imports"
+ ],
+ "id": "Of0eSX0kE4dL"
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "mEqMoblyE4dL"
+ },
+ "execution_count": 52,
+ "outputs": [],
+ "source": [
+ "# Imports\n",
+ "import json\n",
+ "import time\n",
+ "import random\n",
+ "import traceback\n",
+ "import subprocess\n",
+ "\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import matplotlib.pyplot as plt\n",
+ "import seaborn as sns\n",
+ "import xgboost as xgb\n",
+ "import weightwatcher as ww\n",
+ "\n",
+ "from scipy import sparse\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "from sklearn.preprocessing import StandardScaler\n",
+ "from sklearn.metrics import accuracy_score\n",
+ "\n",
+ "from xgboost2ww import convert\n",
+ "from xgbwwdata import Filters, load_dataset\n",
+ "\n",
+ "try:\n",
+ " from xgbwwdata import scan_datasets as xgbww_scan_datasets\n",
+ "except Exception:\n",
+ " xgbww_scan_datasets = None\n",
+ "\n",
+ "try:\n",
+ " from xgbwwdata import list_datasets as xgbww_list_datasets\n",
+ "except Exception:\n",
+ " xgbww_list_datasets = None\n"
+ ],
+ "id": "mEqMoblyE4dL"
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "XnJvev0xE4dL"
+ },
+ "source": [
+ "## 4. Runtime configuration\n",
+ "\n",
+ "Long Colab runs may disconnect. This notebook is checkpointed to Drive so you can reconnect and rerun from the top to resume. Use `FORCE_RESTART_ALL=True` to start over, or set `FORCE_RESTART_DATASETS` to selected dataset slugs/uids."
+ ],
+ "id": "XnJvev0xE4dL"
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "ieH6OSQjE4dL",
+ "outputId": "a471fbd2-15f5-4e8e-a7cc-a78de33c4a9d",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ }
+ },
+ "execution_count": 53,
+ "outputs": [
{
- "cell_type": "markdown",
- "metadata": {
- "id": "vzDVS76QE4dN"
- },
- "source": [
- "## 9. Per-dataset long-run training"
- ],
- "id": "vzDVS76QE4dN"
- },
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "N_STEPS=48, sample=10, dry_run=False\n",
+ "USE_GOOGLE_DRIVE=True, RESTART=True, RESUME_FROM_CHECKPOINT=True, REUSE_SAMPLED_REGISTRY=True\n",
+ "INCLUDE_KEEL_DATASETS=False (default=False)\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Runtime + global config\n",
+ "RANDOM_STATE = 42\n",
+ "DATASET_SAMPLE_SIZE = 10\n",
+ "TOTAL_ROUNDS = 1200\n",
+ "CHUNK_SIZE = 25\n",
+ "N_STEPS = TOTAL_ROUNDS // CHUNK_SIZE\n",
+ "CHECKPOINT_EVERY_STEPS = 1\n",
+ "\n",
+ "INCLUDE_KEEL_DATASETS = globals().get('INCLUDE_KEEL_DATASETS', False)\n",
+ "RESTART = True # Set False to start fresh: delete prior checkpoint artifacts before running.\n",
+ "RESUME_FROM_CHECKPOINT = RESTART\n",
+ "FORCE_RESTART_ALL = True\n",
+ "FORCE_RESTART_DATASETS = []\n",
+ "REUSE_SAMPLED_REGISTRY = True\n",
+ "\n",
+ "SELECTED_DATASET_UIDS = []\n",
+ "DRY_RUN = False\n",
+ "MAX_DENSE_ELEMENTS = int(2e8)\n",
+ "\n",
+ "if DRY_RUN:\n",
+ " DATASET_SAMPLE_SIZE = 2\n",
+ " TOTAL_ROUNDS = 50\n",
+ " CHUNK_SIZE = 25\n",
+ " N_STEPS = TOTAL_ROUNDS // CHUNK_SIZE\n",
+ "\n",
+ "random.seed(RANDOM_STATE)\n",
+ "np.random.seed(RANDOM_STATE)\n",
+ "sns.set_theme(style='whitegrid')\n",
+ "print(f'N_STEPS={N_STEPS}, sample={DATASET_SAMPLE_SIZE}, dry_run={DRY_RUN}')\n",
+ "print(f'USE_GOOGLE_DRIVE={USE_GOOGLE_DRIVE}, RESTART={RESTART}, RESUME_FROM_CHECKPOINT={RESUME_FROM_CHECKPOINT}, REUSE_SAMPLED_REGISTRY={REUSE_SAMPLED_REGISTRY}')\n",
+ "print(f'INCLUDE_KEEL_DATASETS={INCLUDE_KEEL_DATASETS} (default=False)')\n",
+ "\n",
+ "\n",
+ "def remove_stale_run_artifacts():\n",
+ " targets = [registry_dir, per_dataset_dir, aggregate_dir, logs_dir]\n",
+ " removed = []\n",
+ " for root in targets:\n",
+ " if not root.exists():\n",
+ " continue\n",
+ " for child in list(root.iterdir()):\n",
+ " try:\n",
+ " if child.is_dir():\n",
+ " shutil.rmtree(child)\n",
+ " else:\n",
+ " child.unlink()\n",
+ " removed.append(str(child))\n",
+ " except Exception as e:\n",
+ " print(f'[WARN] Could not remove stale artifact {child}: {e}')\n",
+ " print(f'[RESET] Removed {len(removed)} stale artifact(s) from prior runs.')\n",
+ "\n",
+ "start_fresh = FORCE_RESTART_ALL or (not RESTART) or (not RESUME_FROM_CHECKPOINT)\n",
+ "if start_fresh:\n",
+ " print('[RESET] Starting fresh run; clearing prior registry/checkpoint/results artifacts.')\n",
+ " remove_stale_run_artifacts()\n"
+ ],
+ "id": "ieH6OSQjE4dL"
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "owCeTeVWE4dM"
+ },
+ "source": [
+ "## 5. Runtime diagnostics"
+ ],
+ "id": "owCeTeVWE4dM"
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "UIteF5tmE4dM",
+ "outputId": "2d2b7a77-5230-4491-f36e-242d51a6d918",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ }
+ },
+ "execution_count": 62,
+ "outputs": [
{
- "cell_type": "code",
- "metadata": {
- "id": "bzbd6Jk7E4dN",
- "outputId": "16bf26a9-68dd-40dd-b70c-60317547a065",
- "colab": {
- "base_uri": "https://localhost:8080/"
- }
- },
- "execution_count": 66,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "Done processing sampled datasets.\n"
- ]
- }
- ],
- "source": [
- "# Per-dataset training loop\n",
- "all_results, skipped_rows, error_rows = [], [], []\n",
- "\n",
- "for i, row in sampled_registry.iterrows():\n",
- " dataset_uid = row.get('dataset_uid', f'row_{i}')\n",
- " source = row.get('source', 'unknown')\n",
- " dataset_slug = make_safe_slug(f\"{dataset_uid}_{source}\")\n",
- " print(f\"\\n[{i+1}/{len(sampled_registry)}] dataset_uid={dataset_uid} source={source}\")\n",
- " force_restart_this = FORCE_RESTART_ALL or (dataset_slug in FORCE_RESTART_DATASETS) or (dataset_uid in FORCE_RESTART_DATASETS)\n",
- "\n",
- " paths = make_dataset_paths(dataset_slug)\n",
- " metrics_csv = paths['results'] / f'{dataset_slug}_metrics.csv'\n",
- " metrics_feather = paths['results'] / f'{dataset_slug}_metrics.feather'\n",
- " latest_model_path = paths['models'] / 'model_latest.json'\n",
- " dataset_meta_json = paths['results'] / '_meta.json'\n",
- "\n",
- " try:\n",
- " X, y, meta = load_dataset(dataset_uid, filters=filters)\n",
- " except Exception as e:\n",
- " log_error(error_rows, dataset_uid, dataset_slug, 'dataset_load', e)\n",
- " skipped_rows.append({'dataset_uid': dataset_uid, 'dataset_slug': dataset_slug, 'reason': 'dataset_load_failed'})\n",
- " save_error_log(error_rows, errors_path)\n",
- " continue\n",
- "\n",
- " try:\n",
- " task_type, n_classes, y_np = detect_task_type(y)\n",
- " if task_type == 'invalid' or n_classes > 20:\n",
- " raise ValueError(f'Invalid class count: n_classes={n_classes}')\n",
- " X_train_scaled, X_test_scaled, y_train_np, y_test_np = prepare_train_test_data(X, y_np)\n",
- " X_train_for_convert = densify_for_convert_if_safe(X_train_scaled, MAX_DENSE_ELEMENTS)\n",
- " except Exception as e:\n",
- " log_error(error_rows, dataset_uid, dataset_slug, 'split_preprocess', e)\n",
- " skipped_rows.append({'dataset_uid': dataset_uid, 'dataset_slug': dataset_slug, 'reason': 'split_preprocess_failed'})\n",
- " save_error_log(error_rows, errors_path)\n",
- " continue\n",
- "\n",
- " dtrain, dtest = xgb.DMatrix(X_train_scaled, label=y_train_np), xgb.DMatrix(X_test_scaled, label=y_test_np)\n",
- " params = build_params_for_dataset(BACKEND_INFO['params_patch'], n_classes)\n",
- "\n",
- " base_meta = {\n",
- " 'dataset_uid': str(dataset_uid),\n",
- " 'dataset_slug': dataset_slug,\n",
- " 'source': str(source),\n",
- " 'shapes': {\n",
- " 'X_train': [int(X_train_scaled.shape[0]), int(X_train_scaled.shape[1])],\n",
- " 'X_test': [int(X_test_scaled.shape[0]), int(X_test_scaled.shape[1])],\n",
- " },\n",
- " 'class_count': int(n_classes),\n",
- " 'selected_params': params,\n",
- " 'backend_info': BACKEND_INFO,\n",
- " 'start_time_utc': pd.Timestamp.utcnow().isoformat(),\n",
- " 'latest_completed_round': 0,\n",
- " }\n",
- " dataset_meta_json.write_text(json.dumps(base_meta, indent=2))\n",
- "\n",
- " bst, rows, start_step = None, [], 1\n",
- " if (not force_restart_this) and RESUME_FROM_CHECKPOINT and RESTART:\n",
- " try:\n",
- " bst, start_step, rows = load_dataset_checkpoint(metrics_csv, latest_model_path, CHUNK_SIZE, N_STEPS)\n",
- " if bst is not None:\n",
- " print(f'[RESUME] step={start_step}/{N_STEPS} from {metrics_csv}')\n",
- " except Exception as e:\n",
- " log_error(error_rows, dataset_uid, dataset_slug, 'checkpoint_load', e)\n",
- " bst, rows, start_step = None, [], 1\n",
- " elif not RESTART:\n",
- " print('[RESTART=False] Ignoring checkpoints and overwriting saved outputs for this dataset.')\n",
- "\n",
- " dataset_t0 = time.time()\n",
- " dataset_failed = False\n",
- " for step in range(start_step, N_STEPS + 1):\n",
- " current_round = step * CHUNK_SIZE\n",
- " step_t0 = time.time()\n",
- " try:\n",
- " bst = xgb.train(params=params, dtrain=dtrain, num_boost_round=CHUNK_SIZE, xgb_model=bst, verbose_eval=False)\n",
- " y_prob = bst.predict(dtest)\n",
- " y_pred = (y_prob >= 0.5).astype(np.int32) if n_classes == 2 else np.argmax(y_prob, axis=1).astype(np.int32)\n",
- " test_accuracy = float(accuracy_score(y_test_np, y_pred))\n",
- " except Exception as e:\n",
- " log_error(error_rows, dataset_uid, dataset_slug, 'train_or_predict', e)\n",
- " dataset_failed = True\n",
- " break\n",
- "\n",
- " alpha_values, alpha_failures = {}, 0\n",
- " for w_name in W_LIST:\n",
- " try:\n",
- " layer = build_layer_for_W(bst, w_name, current_round, X_train_for_convert, y_train_np, params)\n",
- " alpha_values[f'alpha_{w_name}'] = compute_alpha_from_layer(layer)\n",
- " except Exception as e:\n",
- " alpha_values[f'alpha_{w_name}'] = np.nan\n",
- " alpha_failures += 1\n",
- " log_error(error_rows, dataset_uid, dataset_slug, f'alpha_{w_name}', e)\n",
- "\n",
- " if alpha_failures == len(W_LIST):\n",
- " dataset_failed = True\n",
- " log_error(error_rows, dataset_uid, dataset_slug, 'alpha_all_failed', RuntimeError('All W matrices failed this round'))\n",
- " break\n",
- "\n",
- " rows.append({'dataset_uid': dataset_uid, 'dataset_slug': dataset_slug, 'source': source, 'n_samples': int(np.asarray(y_np).shape[0]), 'n_features': int(X_train_scaled.shape[1]), 'n_classes': int(n_classes), 'boosting_round': int(current_round), 'test_accuracy': test_accuracy, **alpha_values, 'backend': BACKEND_INFO.get('backend_name'), 'tree_method': params.get('tree_method'), 'device': params.get('device', 'cpu'), 'elapsed_seconds_round': float(time.time() - step_t0), 'elapsed_seconds_total': float(time.time() - dataset_t0)})\n",
- "\n",
- " if step % CHECKPOINT_EVERY_STEPS == 0:\n",
- " try:\n",
- " save_dataset_checkpoint(rows, metrics_csv, metrics_feather, bst, current_round, paths['models'], base_meta, dataset_meta_json)\n",
- " except Exception as e:\n",
- " log_error(error_rows, dataset_uid, dataset_slug, 'checkpoint_save', e)\n",
- "\n",
- " print(f\" step={step:03d}/{N_STEPS} round={current_round:04d} acc={test_accuracy:.4f}\")\n",
- "\n",
- " df_dataset = pd.DataFrame(rows)\n",
- " if not df_dataset.empty:\n",
- " all_results.append(df_dataset)\n",
- " try:\n",
- " plot_dataset_dynamics(df_dataset, dataset_slug, paths['plots'])\n",
- " except Exception as e:\n",
- " log_error(error_rows, dataset_uid, dataset_slug, 'plotting', e)\n",
- " if dataset_failed and df_dataset.empty:\n",
- " skipped_rows.append({'dataset_uid': dataset_uid, 'dataset_slug': dataset_slug, 'reason': 'dataset_failed'})\n",
- "\n",
- " save_error_log(error_rows, errors_path)\n",
- "\n",
- "print('Done processing sampled datasets.')\n"
- ],
- "id": "bzbd6Jk7E4dN"
- },
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "cpu_count: 2\n",
+ "platform: Linux-6.6.113+-x86_64-with-glibc2.35\n",
+ "[WARN] nvidia-smi unavailable: [Errno 2] No such file or directory: 'nvidia-smi'\n",
+ "colab_gpu_runtime=True\n",
+ "CUDA visible to XGBoost: no (modern config failed)\n",
+ "[fallback] modern GPU failed -> trying gpu_hist: [05:09:52] WARNING: /__w/xgboost/xgboost/src/context.cc:53: No visible GPU is found, setting device to CPU. ; [05:09:52] WARNING: /__w/xgboost/xgboost/src/context.cc:207: Device is changed from GPU to CPU as we couldn't find any available GPU on the system.\n",
+ "[fallback] gpu_hist failed -> using CPU hist: Invalid Input: 'gpu_hist', valid values are: {'approx', 'auto', 'exact', 'hist'}\n",
+ "Backend fallback path: CPU hist\n",
+ "Chosen backend: cpu_hist\n",
+ "Chosen XGBoost params patch: {'tree_method': 'hist', 'nthread': 1}\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Runtime diagnostics + Colab-aware backend selection\n",
+ "import warnings\n",
+ "\n",
+ "print('cpu_count:', os.cpu_count())\n",
+ "print('platform:', platform.platform())\n",
+ "\n",
+ "if IN_COLAB:\n",
+ " try:\n",
+ " smi = subprocess.run(['nvidia-smi'], capture_output=True, text=True, check=False)\n",
+ " print('nvidia-smi returncode:', smi.returncode)\n",
+ " print(smi.stdout[:1000] if smi.stdout else smi.stderr[:500])\n",
+ " except Exception as e:\n",
+ " print('[WARN] nvidia-smi unavailable:', e)\n",
+ "\n",
+ "def _probe_xgb(params_patch):\n",
+ " X_probe = np.array([[0.0], [1.0], [2.0], [3.0]], dtype=np.float32)\n",
+ " y_probe = np.array([0, 0, 1, 1], dtype=np.float32)\n",
+ " dprobe = xgb.DMatrix(X_probe, label=y_probe)\n",
+ " with warnings.catch_warnings(record=True) as caught:\n",
+ " warnings.simplefilter('always')\n",
+ " xgb.train(params={'objective': 'binary:logistic', 'eval_metric': 'logloss', **params_patch}, dtrain=dprobe, num_boost_round=1, verbose_eval=False)\n",
+ " gpu_fallback_msgs = [\n",
+ " str(w.message) for w in caught\n",
+ " if ('No visible GPU is found' in str(w.message)) or ('Device is changed from GPU to CPU' in str(w.message))\n",
+ " ]\n",
+ " if gpu_fallback_msgs:\n",
+ " raise RuntimeError(' ; '.join(gpu_fallback_msgs))\n",
+ "\n",
+ "def choose_xgb_runtime_colab_aware():\n",
+ " cpu_threads = max(1, (os.cpu_count() or 1) - 1)\n",
+ " print(f'colab_gpu_runtime={IN_COLAB}')\n",
+ "\n",
+ " if IN_COLAB:\n",
+ " try:\n",
+ " modern_gpu = {'tree_method': 'hist', 'device': 'cuda'}\n",
+ " _probe_xgb(modern_gpu)\n",
+ " print('CUDA visible to XGBoost: yes')\n",
+ " print('Backend fallback path: modern gpu config succeeded')\n",
+ " return {'backend_name': 'colab_cuda_hist', 'params_patch': modern_gpu, 'is_gpu': True}\n",
+ " except Exception as e1:\n",
+ " print('CUDA visible to XGBoost: no (modern config failed)')\n",
+ " print('[fallback] modern GPU failed -> trying gpu_hist:', e1)\n",
+ " try:\n",
+ " legacy_gpu = {'tree_method': 'gpu_hist'}\n",
+ " _probe_xgb(legacy_gpu)\n",
+ " print('Backend fallback path: legacy gpu_hist succeeded')\n",
+ " return {'backend_name': 'colab_gpu_hist_legacy', 'params_patch': legacy_gpu, 'is_gpu': True}\n",
+ " except Exception as e2:\n",
+ " print('[fallback] gpu_hist failed -> using CPU hist:', e2)\n",
+ "\n",
+ " cpu_patch = {'tree_method': 'hist', 'nthread': cpu_threads}\n",
+ " print('Backend fallback path: CPU hist')\n",
+ " return {'backend_name': 'cpu_hist', 'params_patch': cpu_patch, 'is_gpu': False}\n",
+ "\n",
+ "BACKEND_INFO = choose_xgb_runtime_colab_aware()\n",
+ "print('Chosen backend:', BACKEND_INFO['backend_name'])\n",
+ "print('Chosen XGBoost params patch:', BACKEND_INFO['params_patch'])\n",
+ "\n"
+ ],
+ "id": "UIteF5tmE4dM"
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "fD9dmsP5E4dM"
+ },
+ "source": [
+ "## 6. Dataset catalog scan / cached registry"
+ ],
+ "id": "fD9dmsP5E4dM"
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "cmP4hat2E4dM",
+ "outputId": "4e3fb4f4-7302-4d24-b849-8a1491314638",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ }
+ },
+ "execution_count": 63,
+ "outputs": [
{
- "cell_type": "markdown",
- "metadata": {
- "id": "ZVykdmiqE4dN"
- },
- "source": [
- "## 10. Per-dataset plots\n",
- "\n",
- "Per-dataset alpha/accuracy plots are saved under each dataset folder in `per_dataset//plots/` on Drive when running in Colab."
- ],
- "id": "ZVykdmiqE4dN"
- },
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "[INFO] INCLUDE_KEEL_DATASETS=False -> not attempting keel_ds installation.\n",
+ "[WARN] Proceeding without KEEL datasets; registry scan returned empty due to missing keel_ds.\n",
+ "Scanned full registry: 0 rows\n",
+ "Filtered registry rows: 0\n"
+ ]
+ }
+ ],
+ "source": [
+ "# xgbwwdata scan / registry build\n",
+ "\n",
+ "import inspect\n",
+ "import importlib\n",
+ "import sys\n",
+ "from dataclasses import asdict, is_dataclass\n",
+ "\n",
+ "\n",
+ "def _install_missing_dataset_dependency(module_name: str) -> bool:\n",
+ " if module_name != 'keel_ds':\n",
+ " return False\n",
+ " if not INCLUDE_KEEL_DATASETS:\n",
+ " print('[INFO] INCLUDE_KEEL_DATASETS=False -> not attempting keel_ds installation.')\n",
+ " return False\n",
+ "\n",
+ " print('[INFO] Missing optional dataset dependency: keel_ds. Installing now...')\n",
+ " candidates = [\n",
+ " 'keel-ds',\n",
+ " 'keel_ds',\n",
+ " 'git+https://github.com/CalculatedContent/keel_ds.git',\n",
+ " 'git+https://github.com/CalculatedContent/keel-ds.git',\n",
+ " ]\n",
+ " install_errors = []\n",
+ " for pkg in candidates:\n",
+ " try:\n",
+ " subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', pkg])\n",
+ " importlib.import_module('keel_ds')\n",
+ " print(f'[INFO] Installed KEEL dependency using installer target: {pkg}')\n",
+ " return True\n",
+ " except Exception as ex:\n",
+ " install_errors.append(f'{pkg}: {type(ex).__name__}: {ex}')\n",
+ " continue\n",
+ "\n",
+ " print('[WARN] Unable to auto-install keel_ds. KEEL datasets will be skipped for this run.')\n",
+ " if install_errors:\n",
+ " print('[WARN] KEEL install attempts failed with:')\n",
+ " for err in install_errors:\n",
+ " print(' -', err)\n",
+ " return False\n",
+ "\n",
+ "\n",
+ "def _copy_filters_with_overrides(filters: Filters, **overrides):\n",
+ " if is_dataclass(filters):\n",
+ " data = asdict(filters)\n",
+ " elif hasattr(filters, 'model_dump'):\n",
+ " data = filters.model_dump()\n",
+ " elif hasattr(filters, '__dict__'):\n",
+ " data = dict(filters.__dict__)\n",
+ " else:\n",
+ " data = {}\n",
+ " data.update(overrides)\n",
+ " return type(filters)(**data)\n",
+ "\n",
+ "\n",
+ "def _filters_without_keel(filters: Filters):\n",
+ " source_keys = ('sources', 'include_sources', 'source_allowlist')\n",
+ " for key in source_keys:\n",
+ " current = getattr(filters, key, None)\n",
+ " if not current:\n",
+ " continue\n",
+ " try:\n",
+ " cleaned = [s for s in current if 'keel' not in str(s).lower()]\n",
+ " if cleaned and len(cleaned) != len(current):\n",
+ " print(f'[INFO] Retrying scan without KEEL via filters.{key}.')\n",
+ " return _copy_filters_with_overrides(filters, **{key: cleaned})\n",
+ " except Exception:\n",
+ " continue\n",
+ "\n",
+ " # If the current filters do not explicitly set source lists, force a KEEL-free allowlist.\n",
+ " non_keel_sources = ['openml', 'pmlb', 'uci', 'sklearn']\n",
+ " for key in source_keys:\n",
+ " if hasattr(filters, key):\n",
+ " try:\n",
+ " print(f'[INFO] Retrying scan with explicit non-KEEL sources via filters.{key}.')\n",
+ " return _copy_filters_with_overrides(filters, **{key: non_keel_sources})\n",
+ " except Exception:\n",
+ " continue\n",
+ " return None\n",
+ "\n",
+ "\n",
+ "def _call_xgbww_scan(api_func, filters: Filters):\n",
+ " sig = inspect.signature(api_func)\n",
+ " kwargs = {}\n",
+ " if 'filters' in sig.parameters:\n",
+ " kwargs['filters'] = filters\n",
+ " if 'preprocess' in sig.parameters:\n",
+ " kwargs['preprocess'] = True\n",
+ " if 'smoke_train' in sig.parameters:\n",
+ " kwargs['smoke_train'] = True\n",
+ " if not kwargs:\n",
+ " return api_func()\n",
+ " return api_func(**kwargs)\n",
+ "\n",
+ "\n",
+ "def _materialize_registry(reg):\n",
+ " if isinstance(reg, pd.DataFrame):\n",
+ " return reg\n",
+ " if hasattr(reg, 'to_pandas'):\n",
+ " return reg.to_pandas()\n",
+ " return pd.DataFrame(list(reg))\n",
+ "\n",
+ "\n",
+ "def _safe_call_scan(api_func, filters: Filters):\n",
+ " try:\n",
+ " return _materialize_registry(_call_xgbww_scan(api_func, filters))\n",
+ " except ModuleNotFoundError as e:\n",
+ " missing = getattr(e, 'name', None)\n",
+ " if not missing:\n",
+ " raise\n",
+ "\n",
+ " installed = _install_missing_dataset_dependency(missing)\n",
+ " if installed:\n",
+ " return _materialize_registry(_call_xgbww_scan(api_func, filters))\n",
+ "\n",
+ " if missing == 'keel_ds':\n",
+ " fallback_filters = _filters_without_keel(filters)\n",
+ " if fallback_filters is not None:\n",
+ " try:\n",
+ " return _materialize_registry(_call_xgbww_scan(api_func, fallback_filters))\n",
+ " except ModuleNotFoundError as fallback_err:\n",
+ " if getattr(fallback_err, 'name', None) != 'keel_ds':\n",
+ " raise\n",
+ " print('[WARN] Proceeding without KEEL datasets; registry scan returned empty due to missing keel_ds.')\n",
+ " return pd.DataFrame()\n",
+ "\n",
+ " raise\n",
+ "\n",
+ "\n",
+ "def build_registry_df(filters: Filters) -> pd.DataFrame:\n",
+ " if xgbww_scan_datasets is not None:\n",
+ " return _safe_call_scan(xgbww_scan_datasets, filters)\n",
+ " if xgbww_list_datasets is not None:\n",
+ " return _safe_call_scan(xgbww_list_datasets, filters)\n",
+ " try:\n",
+ " from xgbwwdata import get_registry\n",
+ " return _safe_call_scan(get_registry, filters)\n",
+ " except Exception as e:\n",
+ " raise RuntimeError('No compatible xgbwwdata scan/list API found in this environment.') from e\n",
+ "\n",
+ "\n",
+ "filters = Filters(min_rows=200, max_rows=60000, max_features=50000, max_dense_elements=MAX_DENSE_ELEMENTS, preprocess=True)\n",
+ "if not INCLUDE_KEEL_DATASETS:\n",
+ " keel_free_filters = _filters_without_keel(filters)\n",
+ " if keel_free_filters is not None:\n",
+ " filters = keel_free_filters\n",
+ " print('[INFO] KEEL disabled: forcing non-KEEL dataset sources in filters before scanning.')\n",
+ "\n",
+ "full_registry_csv = registry_dir / 'full_registry.csv'\n",
+ "full_registry_feather = registry_dir / 'full_registry.feather'\n",
+ "\n",
+ "if RESUME_FROM_CHECKPOINT and full_registry_csv.exists() and not FORCE_RESTART_ALL:\n",
+ " try:\n",
+ " full_registry_df = pd.read_csv(full_registry_csv)\n",
+ " if full_registry_df.empty:\n",
+ " raise pd.errors.EmptyDataError('Cached full registry CSV has 0 rows')\n",
+ " print(f'Loaded cached full registry: {len(full_registry_df)} rows from {full_registry_csv}')\n",
+ " except (pd.errors.EmptyDataError, pd.errors.ParserError) as e:\n",
+ " print(f'[WARN] Invalid cached full registry ({e}); rebuilding from dataset scan.')\n",
+ " full_registry_csv.unlink(missing_ok=True)\n",
+ " full_registry_feather.unlink(missing_ok=True)\n",
+ " full_registry_df = build_registry_df(filters)\n",
+ " full_registry_df.to_csv(full_registry_csv, index=False)\n",
+ " try:\n",
+ " full_registry_df.reset_index(drop=True).to_feather(full_registry_feather)\n",
+ " except Exception as fe:\n",
+ " print('[WARN] Feather save failed for full registry:', fe)\n",
+ "else:\n",
+ " full_registry_df = build_registry_df(filters)\n",
+ " full_registry_df.to_csv(full_registry_csv, index=False)\n",
+ " try:\n",
+ " full_registry_df.reset_index(drop=True).to_feather(full_registry_feather)\n",
+ " except Exception as e:\n",
+ " print('[WARN] Feather save failed for full registry:', e)\n",
+ " print(f'Scanned full registry: {len(full_registry_df)} rows')\n",
+ "\n",
+ "registry_df = full_registry_df.copy()\n",
+ "if 'task' in registry_df.columns:\n",
+ " registry_df = registry_df[registry_df['task'].astype(str).str.contains('class', case=False, na=False)]\n",
+ "if 'n_classes' in registry_df.columns:\n",
+ " registry_df = registry_df[(registry_df['n_classes'].fillna(0) >= 2) & (registry_df['n_classes'].fillna(0) <= 20)]\n",
+ "registry_df = registry_df.reset_index(drop=True)\n",
+ "print('Filtered registry rows:', len(registry_df))\n",
+ "\n"
+ ],
+ "id": "cmP4hat2E4dM"
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "4Qp4CSnzE4dM"
+ },
+ "source": [
+ "## 7. Random-10 dataset selection"
+ ],
+ "id": "4Qp4CSnzE4dM"
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "hD5MMjf5E4dM",
+ "outputId": "19e51d34-70f2-431f-d3cc-49efb4b387a2",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 88
+ }
+ },
+ "execution_count": 64,
+ "outputs": [
{
- "cell_type": "markdown",
- "metadata": {
- "id": "dITJemlaE4dN"
- },
- "source": [
- "## 11. Aggregate summaries"
- ],
- "id": "dITJemlaE4dN"
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Sampled 0 datasets\n",
+ "Saved sampled registry to: /content/drive/MyDrive/xgboost2ww_runs/random10_longrun_alpha_tracking/registry/random10_registry.csv\n"
+ ]
},
{
- "cell_type": "code",
- "metadata": {
- "id": "R8NPHRR8E4dN",
- "outputId": "eaa88724-f44f-42a3-e134-623dafff2c9b",
- "colab": {
- "base_uri": "https://localhost:8080/"
- }
- },
- "execution_count": 67,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "No completed dataset results to aggregate.\n"
- ]
- }
- ],
- "source": [
- "# Aggregate summary / plots\n",
- "all_df = pd.concat(all_results, ignore_index=True) if all_results else pd.DataFrame()\n",
- "\n",
- "all_csv = aggregate_dir / 'all_datasets_longrun_metrics.csv'\n",
- "all_feather = aggregate_dir / 'all_datasets_longrun_metrics.feather'\n",
- "best_csv = aggregate_dir / 'best_rows_per_dataset.csv'\n",
- "corr_csv = aggregate_dir / 'correlation_summary.csv'\n",
- "run_summary_json = aggregate_dir / 'run_summary.json'\n",
- "\n",
- "if not all_df.empty:\n",
- " all_df.to_csv(all_csv, index=False)\n",
- " try:\n",
- " all_df.reset_index(drop=True).to_feather(all_feather)\n",
- " except Exception as e:\n",
- " print('[WARN] Feather save failed for aggregate metrics:', e)\n",
- "\n",
- " best_df = all_df.loc[all_df.groupby('dataset_uid')['test_accuracy'].idxmax()].sort_values('test_accuracy', ascending=False)\n",
- " best_df.to_csv(best_csv, index=False)\n",
- "\n",
- " corr_rows = []\n",
- " for dataset_uid, g in all_df.groupby('dataset_uid'):\n",
- " row = {'dataset_uid': dataset_uid, 'dataset_slug': g['dataset_slug'].iloc[0]}\n",
- " for w in W_LIST:\n",
- " c = g['test_accuracy'].corr(g[f'alpha_{w}'])\n",
- " row[f'corr_test_accuracy_alpha_{w}'] = float(c) if pd.notnull(c) else np.nan\n",
- " corr_rows.append(row)\n",
- " corr_df = pd.DataFrame(corr_rows)\n",
- " corr_df.to_csv(corr_csv, index=False)\n",
- "\n",
- " rank_rows = []\n",
- " for w in W_LIST:\n",
- " cvals = corr_df[f'corr_test_accuracy_alpha_{w}'].dropna()\n",
- " rank_rows.append({'W': w, 'avg_abs_corr': float(np.abs(cvals).mean()) if len(cvals) else np.nan, 'share_lower_alpha_better': float((cvals < 0).mean()) if len(cvals) else np.nan})\n",
- " rank_df = pd.DataFrame(rank_rows).sort_values('avg_abs_corr', ascending=False)\n",
- "\n",
- " run_summary = {'project_root': str(project_root), 'n_sampled_datasets': int(len(sampled_registry)), 'n_completed_datasets': int(all_df['dataset_uid'].nunique()), 'n_rows_total': int(len(all_df)), 'backend': BACKEND_INFO, 'timestamp_utc': pd.Timestamp.utcnow().isoformat()}\n",
- " run_summary_json.write_text(json.dumps(run_summary, indent=2))\n",
- "\n",
- " if skipped_rows:\n",
- " pd.DataFrame(skipped_rows).to_csv(skipped_path, index=False)\n",
- "\n",
- " display(sampled_registry)\n",
- " display(pd.DataFrame(skipped_rows) if skipped_rows else pd.DataFrame(columns=['dataset_uid', 'dataset_slug', 'reason']))\n",
- " display(best_df)\n",
- " display(all_df.sort_values('test_accuracy', ascending=False).head(20))\n",
- " display(corr_df)\n",
- " display(rank_df)\n",
- "else:\n",
- " print('No completed dataset results to aggregate.')\n"
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Empty DataFrame\n",
+ "Columns: []\n",
+ "Index: []"
],
- "id": "R8NPHRR8E4dN"
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "-YHPQf9EE4dN"
- },
- "source": [
- "## 12. Saved artifacts on Google Drive"
+ "text/html": [
+ "\n",
+ " \n",
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n"
],
- "id": "-YHPQf9EE4dN"
- },
+ "application/vnd.google.colaboratory.intrinsic+json": {
+ "type": "dataframe",
+ "summary": "{\n \"name\": \"display(sampled_registry\",\n \"rows\": 0,\n \"fields\": []\n}"
+ }
+ },
+ "metadata": {}
+ }
+ ],
+ "source": [
+ "# Random sample of 10 datasets (saved immediately to Drive in Colab)\n",
+ "random10_csv = registry_dir / 'random10_registry.csv'\n",
+ "random10_feather = registry_dir / 'random10_registry.feather'\n",
+ "\n",
+ "if SELECTED_DATASET_UIDS:\n",
+ " sampled_registry = registry_df[registry_df['dataset_uid'].isin(SELECTED_DATASET_UIDS)].copy()\n",
+ " print(f'Using explicit SELECTED_DATASET_UIDS ({len(sampled_registry)} rows).')\n",
+ "elif REUSE_SAMPLED_REGISTRY and RESTART and RESUME_FROM_CHECKPOINT and random10_csv.exists() and not FORCE_RESTART_ALL:\n",
+ " try:\n",
+ " sampled_registry = pd.read_csv(random10_csv)\n",
+ " if sampled_registry.empty:\n",
+ " raise pd.errors.EmptyDataError('Cached sampled registry CSV has 0 rows')\n",
+ " print(f'Reusing sampled registry from checkpoint: {len(sampled_registry)} rows ({random10_csv})')\n",
+ " except (pd.errors.EmptyDataError, pd.errors.ParserError) as e:\n",
+ " print(f'[WARN] Invalid cached sampled registry ({e}); resampling datasets.')\n",
+ " random10_csv.unlink(missing_ok=True)\n",
+ " random10_feather.unlink(missing_ok=True)\n",
+ " n_take = min(DATASET_SAMPLE_SIZE, len(registry_df))\n",
+ " sampled_registry = registry_df.sample(n=n_take, random_state=RANDOM_STATE).copy()\n",
+ " print(f'Sampled {len(sampled_registry)} datasets')\n",
+ "else:\n",
+ " n_take = min(DATASET_SAMPLE_SIZE, len(registry_df))\n",
+ " sampled_registry = registry_df.sample(n=n_take, random_state=RANDOM_STATE).copy()\n",
+ " print(f'Sampled {len(sampled_registry)} datasets')\n",
+ "\n",
+ "sampled_registry = sampled_registry.reset_index(drop=True)\n",
+ "sampled_registry.to_csv(random10_csv, index=False)\n",
+ "try:\n",
+ " sampled_registry.reset_index(drop=True).to_feather(random10_feather)\n",
+ "except Exception as e:\n",
+ " print('[WARN] Feather save failed for sampled registry:', e)\n",
+ "print('Saved sampled registry to:', random10_csv)\n",
+ "\n",
+ "display(sampled_registry.head(20))\n"
+ ],
+ "id": "hD5MMjf5E4dM"
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "P-Y2_0X6E4dN"
+ },
+ "source": [
+ "## 8. Helper functions"
+ ],
+ "id": "P-Y2_0X6E4dN"
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "5bTodyRVE4dN"
+ },
+ "execution_count": 65,
+ "outputs": [],
+ "source": [
+ "# Helper functions\n",
+ "W_LIST = ['W1', 'W2', 'W7', 'W8', 'W9', 'W10']\n",
+ "\n",
+ "def make_safe_slug(text):\n",
+ " text = str(text)\n",
+ " out = ''.join(ch.lower() if ch.isalnum() else '_' for ch in text)\n",
+ " while '__' in out:\n",
+ " out = out.replace('__', '_')\n",
+ " return out.strip('_')[:120]\n",
+ "\n",
+ "def log_error(error_rows, dataset_uid, dataset_slug, stage, exc):\n",
+ " error_rows.append({'dataset_uid': dataset_uid, 'dataset_slug': dataset_slug, 'stage': stage, 'error_type': type(exc).__name__, 'error_message': str(exc), 'traceback': traceback.format_exc(), 'timestamp': pd.Timestamp.utcnow().isoformat()})\n",
+ "\n",
+ "def save_error_log(error_rows, errors_path):\n",
+ " if error_rows:\n",
+ " pd.DataFrame(error_rows).to_csv(errors_path, index=False)\n",
+ "\n",
+ "def make_dataset_paths(dataset_slug):\n",
+ " base = per_dataset_dir / dataset_slug\n",
+ " out = {'base': base, 'results': base / 'results', 'models': base / 'models', 'plots': base / 'plots'}\n",
+ " for p in out.values():\n",
+ " p.mkdir(parents=True, exist_ok=True)\n",
+ " return out\n",
+ "\n",
+ "def detect_task_type(y):\n",
+ " y_np = np.asarray(y)\n",
+ " n_classes = int(len(np.unique(y_np)))\n",
+ " if n_classes < 2:\n",
+ " return 'invalid', n_classes, y_np\n",
+ " return ('binary' if n_classes == 2 else 'multiclass'), n_classes, y_np\n",
+ "\n",
+ "def build_params_for_dataset(base_backend_patch, n_classes):\n",
+ " params = {'max_depth': 4, 'eta': 0.05, 'subsample': 0.9, 'colsample_bytree': 0.8, 'min_child_weight': 5, 'reg_lambda': 5.0, 'reg_alpha': 0.5, 'gamma': 1.0, 'seed': RANDOM_STATE, **base_backend_patch}\n",
+ " if n_classes == 2:\n",
+ " params.update({'objective': 'binary:logistic', 'eval_metric': 'logloss'})\n",
+ " else:\n",
+ " params.update({'objective': 'multi:softprob', 'eval_metric': 'mlogloss', 'num_class': int(n_classes)})\n",
+ " return params\n",
+ "\n",
+ "def prepare_train_test_data(X, y):\n",
+ " stratify = y if len(np.unique(y)) > 1 else None\n",
+ " X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=RANDOM_STATE, stratify=stratify)\n",
+ " scaler = StandardScaler(with_mean=False) if sparse.issparse(X_train) else StandardScaler()\n",
+ " X_train_scaled = scaler.fit_transform(X_train).astype(np.float32)\n",
+ " X_test_scaled = scaler.transform(X_test).astype(np.float32)\n",
+ " y_train_np = np.asarray(y_train).astype(np.int32).reshape(-1)\n",
+ " y_test_np = np.asarray(y_test).astype(np.int32).reshape(-1)\n",
+ " return X_train_scaled, X_test_scaled, y_train_np, y_test_np\n",
+ "\n",
+ "def densify_for_convert_if_safe(X_train_scaled, max_dense_elements):\n",
+ " if sparse.issparse(X_train_scaled):\n",
+ " n_elements = int(X_train_scaled.shape[0]) * int(X_train_scaled.shape[1])\n",
+ " if n_elements > max_dense_elements:\n",
+ " raise MemoryError(f'Refusing to densify convert input: shape={X_train_scaled.shape}, elements={n_elements:,}')\n",
+ " return X_train_scaled.toarray().astype(np.float32)\n",
+ " return np.asarray(X_train_scaled).astype(np.float32)\n",
+ "\n",
+ "def build_layer_for_W(bst, W_name, current_round, X_train_for_convert, y_train_np, params):\n",
+ " multiclass_mode = 'softprob' if int(params.get('num_class', 2)) > 2 else 'error'\n",
+ " return convert(model=bst, data=X_train_for_convert, labels=y_train_np, W=W_name, return_type='torch', nfolds=5, t_points=min(current_round, 160), random_state=RANDOM_STATE, train_params=params, num_boost_round=current_round, multiclass=multiclass_mode)\n",
+ "\n",
+ "def compute_alpha_from_layer(layer):\n",
+ " watcher = ww.WeightWatcher(model=layer)\n",
+ " df = watcher.analyze(randomize=True, detX=True)\n",
+ " return float(df['alpha'].iloc[0])\n",
+ "\n",
+ "def save_dataset_checkpoint(rows, metrics_csv_path, metrics_feather_path, bst, current_round, models_dir, metadata, meta_json_path):\n",
+ " df = pd.DataFrame(rows)\n",
+ " df.to_csv(metrics_csv_path, index=False)\n",
+ " try:\n",
+ " df.reset_index(drop=True).to_feather(metrics_feather_path)\n",
+ " except Exception as e:\n",
+ " print('[WARN] Feather save failed:', e)\n",
+ "\n",
+ " round_model_path = models_dir / f'model_round_{current_round:04d}.json'\n",
+ " latest_model_path = models_dir / 'model_latest.json'\n",
+ " bst.save_model(round_model_path)\n",
+ " bst.save_model(latest_model_path)\n",
+ "\n",
+ " metadata = dict(metadata)\n",
+ " metadata['latest_completed_round'] = int(current_round)\n",
+ " metadata['updated_at'] = pd.Timestamp.utcnow().isoformat()\n",
+ " meta_json_path.write_text(json.dumps(metadata, indent=2))\n",
+ "\n",
+ " print(f\"[CHECKPOINT] round={current_round} saved -> {metrics_csv_path}\")\n",
+ " print(f\"[CHECKPOINT] model_latest -> {latest_model_path}\")\n",
+ "\n",
+ "def load_dataset_checkpoint(metrics_csv_path, latest_model_path, chunk_size, n_steps):\n",
+ " if not (metrics_csv_path.exists() and latest_model_path.exists()):\n",
+ " return None, 1, []\n",
+ " try:\n",
+ " prior_df = pd.read_csv(metrics_csv_path)\n",
+ " except (pd.errors.EmptyDataError, pd.errors.ParserError) as e:\n",
+ " print(f'[WARN] Invalid dataset checkpoint {metrics_csv_path} ({e}); starting this dataset from round 1.')\n",
+ " metrics_csv_path.unlink(missing_ok=True)\n",
+ " return None, 1, []\n",
+ " if prior_df.empty:\n",
+ " metrics_csv_path.unlink(missing_ok=True)\n",
+ " return None, 1, []\n",
+ " max_round = int(prior_df['boosting_round'].max())\n",
+ " if max_round % chunk_size != 0:\n",
+ " raise ValueError(f'Checkpoint round {max_round} incompatible with CHUNK_SIZE={chunk_size}')\n",
+ " completed_steps = max_round // chunk_size\n",
+ " if completed_steps > n_steps:\n",
+ " raise ValueError(f'Checkpoint completed_steps={completed_steps} > N_STEPS={n_steps}')\n",
+ " bst = xgb.Booster()\n",
+ " bst.load_model(latest_model_path)\n",
+ " return bst, completed_steps + 1, prior_df.to_dict('records')\n",
+ "\n",
+ "def plot_dataset_dynamics(df_dataset, dataset_slug, plots_dir):\n",
+ " if df_dataset.empty:\n",
+ " return None, None\n",
+ " long_alpha = df_dataset.melt(id_vars=['boosting_round', 'test_accuracy'], value_vars=[f'alpha_{w}' for w in W_LIST], var_name='alpha_name', value_name='alpha_value')\n",
+ " fig, axes = plt.subplots(1, 3, figsize=(20, 5))\n",
+ " axes[0].plot(df_dataset['boosting_round'], df_dataset['test_accuracy'], marker='o')\n",
+ " axes[0].set_title(f'{dataset_slug}: test accuracy vs round')\n",
+ " sns.lineplot(data=long_alpha, x='boosting_round', y='alpha_value', hue='alpha_name', ax=axes[1])\n",
+ " axes[1].set_title('alpha vs round')\n",
+ " sns.scatterplot(data=long_alpha, x='alpha_value', y='test_accuracy', hue='alpha_name', ax=axes[2])\n",
+ " axes[2].set_title('accuracy vs alpha')\n",
+ " fig.tight_layout()\n",
+ " p1 = plots_dir / f'{dataset_slug}_dynamics.png'\n",
+ " fig.savefig(p1, dpi=140)\n",
+ " plt.close(fig)\n",
+ " fig2, ax2 = plt.subplots(figsize=(7, 5))\n",
+ " for w in W_LIST:\n",
+ " col = f'alpha_{w}'\n",
+ " ax2.scatter(df_dataset[col], df_dataset['test_accuracy'], alpha=0.6, label=w)\n",
+ " ax2.set_title(f'{dataset_slug}: accuracy vs alpha (all W)')\n",
+ " ax2.legend(ncol=2)\n",
+ " fig2.tight_layout()\n",
+ " p2 = plots_dir / f'{dataset_slug}_accuracy_vs_alpha.png'\n",
+ " fig2.savefig(p2, dpi=140)\n",
+ " plt.close(fig2)\n",
+ " return p1, p2\n"
+ ],
+ "id": "5bTodyRVE4dN"
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "vzDVS76QE4dN"
+ },
+ "source": [
+ "## 9. Per-dataset long-run training"
+ ],
+ "id": "vzDVS76QE4dN"
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "bzbd6Jk7E4dN",
+ "outputId": "16bf26a9-68dd-40dd-b70c-60317547a065",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ }
+ },
+ "execution_count": 66,
+ "outputs": [
{
- "cell_type": "code",
- "metadata": {
- "id": "ZwuPipGHE4dN",
- "outputId": "3c913ec7-6a0b-431a-939b-946981c65173",
- "colab": {
- "base_uri": "https://localhost:8080/"
- }
- },
- "execution_count": 68,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "Sampled registry CSV: /content/drive/MyDrive/xgboost2ww_runs/random10_longrun_alpha_tracking/registry/random10_registry.csv\n",
- "Sampled registry Feather: /content/drive/MyDrive/xgboost2ww_runs/random10_longrun_alpha_tracking/registry/random10_registry.feather\n",
- "Error log: /content/drive/MyDrive/xgboost2ww_runs/random10_longrun_alpha_tracking/logs/errors.csv\n",
- "Aggregate metrics: /content/drive/MyDrive/xgboost2ww_runs/random10_longrun_alpha_tracking/aggregate/all_datasets_longrun_metrics.csv\n",
- "Best rows summary: /content/drive/MyDrive/xgboost2ww_runs/random10_longrun_alpha_tracking/aggregate/best_rows_per_dataset.csv\n",
- "Per-dataset folders: /content/drive/MyDrive/xgboost2ww_runs/random10_longrun_alpha_tracking/per_dataset\n"
- ]
- }
- ],
- "source": [
- "# Drive artifact summary\n",
- "print('Sampled registry CSV:', registry_dir / 'random10_registry.csv')\n",
- "print('Sampled registry Feather:', registry_dir / 'random10_registry.feather')\n",
- "print('Error log:', errors_path)\n",
- "print('Aggregate metrics:', aggregate_dir / 'all_datasets_longrun_metrics.csv')\n",
- "print('Best rows summary:', aggregate_dir / 'best_rows_per_dataset.csv')\n",
- "print('Per-dataset folders:', per_dataset_dir)\n"
- ],
- "id": "ZwuPipGHE4dN"
- },
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Done processing sampled datasets.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Per-dataset training loop\n",
+ "all_results, skipped_rows, error_rows = [], [], []\n",
+ "\n",
+ "for i, row in sampled_registry.iterrows():\n",
+ " dataset_uid = row.get('dataset_uid', f'row_{i}')\n",
+ " source = row.get('source', 'unknown')\n",
+ " dataset_slug = make_safe_slug(f\"{dataset_uid}_{source}\")\n",
+ " print(f\"\\n[{i+1}/{len(sampled_registry)}] dataset_uid={dataset_uid} source={source}\")\n",
+ " force_restart_this = FORCE_RESTART_ALL or (dataset_slug in FORCE_RESTART_DATASETS) or (dataset_uid in FORCE_RESTART_DATASETS)\n",
+ "\n",
+ " paths = make_dataset_paths(dataset_slug)\n",
+ " metrics_csv = paths['results'] / f'{dataset_slug}_metrics.csv'\n",
+ " metrics_feather = paths['results'] / f'{dataset_slug}_metrics.feather'\n",
+ " latest_model_path = paths['models'] / 'model_latest.json'\n",
+ " dataset_meta_json = paths['results'] / '_meta.json'\n",
+ "\n",
+ " try:\n",
+ " X, y, meta = load_dataset(dataset_uid, filters=filters)\n",
+ " except Exception as e:\n",
+ " log_error(error_rows, dataset_uid, dataset_slug, 'dataset_load', e)\n",
+ " skipped_rows.append({'dataset_uid': dataset_uid, 'dataset_slug': dataset_slug, 'reason': 'dataset_load_failed'})\n",
+ " save_error_log(error_rows, errors_path)\n",
+ " continue\n",
+ "\n",
+ " try:\n",
+ " task_type, n_classes, y_np = detect_task_type(y)\n",
+ " if task_type == 'invalid' or n_classes > 20:\n",
+ " raise ValueError(f'Invalid class count: n_classes={n_classes}')\n",
+ " X_train_scaled, X_test_scaled, y_train_np, y_test_np = prepare_train_test_data(X, y_np)\n",
+ " X_train_for_convert = densify_for_convert_if_safe(X_train_scaled, MAX_DENSE_ELEMENTS)\n",
+ " except Exception as e:\n",
+ " log_error(error_rows, dataset_uid, dataset_slug, 'split_preprocess', e)\n",
+ " skipped_rows.append({'dataset_uid': dataset_uid, 'dataset_slug': dataset_slug, 'reason': 'split_preprocess_failed'})\n",
+ " save_error_log(error_rows, errors_path)\n",
+ " continue\n",
+ "\n",
+ " dtrain, dtest = xgb.DMatrix(X_train_scaled, label=y_train_np), xgb.DMatrix(X_test_scaled, label=y_test_np)\n",
+ " params = build_params_for_dataset(BACKEND_INFO['params_patch'], n_classes)\n",
+ "\n",
+ " base_meta = {\n",
+ " 'dataset_uid': str(dataset_uid),\n",
+ " 'dataset_slug': dataset_slug,\n",
+ " 'source': str(source),\n",
+ " 'shapes': {\n",
+ " 'X_train': [int(X_train_scaled.shape[0]), int(X_train_scaled.shape[1])],\n",
+ " 'X_test': [int(X_test_scaled.shape[0]), int(X_test_scaled.shape[1])],\n",
+ " },\n",
+ " 'class_count': int(n_classes),\n",
+ " 'selected_params': params,\n",
+ " 'backend_info': BACKEND_INFO,\n",
+ " 'start_time_utc': pd.Timestamp.utcnow().isoformat(),\n",
+ " 'latest_completed_round': 0,\n",
+ " }\n",
+ " dataset_meta_json.write_text(json.dumps(base_meta, indent=2))\n",
+ "\n",
+ " bst, rows, start_step = None, [], 1\n",
+ " if (not force_restart_this) and RESUME_FROM_CHECKPOINT and RESTART:\n",
+ " try:\n",
+ " bst, start_step, rows = load_dataset_checkpoint(metrics_csv, latest_model_path, CHUNK_SIZE, N_STEPS)\n",
+ " if bst is not None:\n",
+ " print(f'[RESUME] step={start_step}/{N_STEPS} from {metrics_csv}')\n",
+ " except Exception as e:\n",
+ " log_error(error_rows, dataset_uid, dataset_slug, 'checkpoint_load', e)\n",
+ " bst, rows, start_step = None, [], 1\n",
+ " elif not RESTART:\n",
+ " print('[RESTART=False] Ignoring checkpoints and overwriting saved outputs for this dataset.')\n",
+ "\n",
+ " dataset_t0 = time.time()\n",
+ " dataset_failed = False\n",
+ " for step in range(start_step, N_STEPS + 1):\n",
+ " current_round = step * CHUNK_SIZE\n",
+ " step_t0 = time.time()\n",
+ " try:\n",
+ " bst = xgb.train(params=params, dtrain=dtrain, num_boost_round=CHUNK_SIZE, xgb_model=bst, verbose_eval=False)\n",
+ " y_prob = bst.predict(dtest)\n",
+ " y_pred = (y_prob >= 0.5).astype(np.int32) if n_classes == 2 else np.argmax(y_prob, axis=1).astype(np.int32)\n",
+ " test_accuracy = float(accuracy_score(y_test_np, y_pred))\n",
+ " except Exception as e:\n",
+ " log_error(error_rows, dataset_uid, dataset_slug, 'train_or_predict', e)\n",
+ " dataset_failed = True\n",
+ " break\n",
+ "\n",
+ " alpha_values, alpha_failures = {}, 0\n",
+ " for w_name in W_LIST:\n",
+ " try:\n",
+ " layer = build_layer_for_W(bst, w_name, current_round, X_train_for_convert, y_train_np, params)\n",
+ " alpha_values[f'alpha_{w_name}'] = compute_alpha_from_layer(layer)\n",
+ " except Exception as e:\n",
+ " alpha_values[f'alpha_{w_name}'] = np.nan\n",
+ " alpha_failures += 1\n",
+ " log_error(error_rows, dataset_uid, dataset_slug, f'alpha_{w_name}', e)\n",
+ "\n",
+ " if alpha_failures == len(W_LIST):\n",
+ " dataset_failed = True\n",
+ " log_error(error_rows, dataset_uid, dataset_slug, 'alpha_all_failed', RuntimeError('All W matrices failed this round'))\n",
+ " break\n",
+ "\n",
+ " rows.append({'dataset_uid': dataset_uid, 'dataset_slug': dataset_slug, 'source': source, 'n_samples': int(np.asarray(y_np).shape[0]), 'n_features': int(X_train_scaled.shape[1]), 'n_classes': int(n_classes), 'boosting_round': int(current_round), 'test_accuracy': test_accuracy, **alpha_values, 'backend': BACKEND_INFO.get('backend_name'), 'tree_method': params.get('tree_method'), 'device': params.get('device', 'cpu'), 'elapsed_seconds_round': float(time.time() - step_t0), 'elapsed_seconds_total': float(time.time() - dataset_t0)})\n",
+ "\n",
+ " if step % CHECKPOINT_EVERY_STEPS == 0:\n",
+ " try:\n",
+ " save_dataset_checkpoint(rows, metrics_csv, metrics_feather, bst, current_round, paths['models'], base_meta, dataset_meta_json)\n",
+ " except Exception as e:\n",
+ " log_error(error_rows, dataset_uid, dataset_slug, 'checkpoint_save', e)\n",
+ "\n",
+ " print(f\" step={step:03d}/{N_STEPS} round={current_round:04d} acc={test_accuracy:.4f}\")\n",
+ "\n",
+ " df_dataset = pd.DataFrame(rows)\n",
+ " if not df_dataset.empty:\n",
+ " all_results.append(df_dataset)\n",
+ " try:\n",
+ " plot_dataset_dynamics(df_dataset, dataset_slug, paths['plots'])\n",
+ " except Exception as e:\n",
+ " log_error(error_rows, dataset_uid, dataset_slug, 'plotting', e)\n",
+ " if dataset_failed and df_dataset.empty:\n",
+ " skipped_rows.append({'dataset_uid': dataset_uid, 'dataset_slug': dataset_slug, 'reason': 'dataset_failed'})\n",
+ "\n",
+ " save_error_log(error_rows, errors_path)\n",
+ "\n",
+ "print('Done processing sampled datasets.')\n"
+ ],
+ "id": "bzbd6Jk7E4dN"
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ZVykdmiqE4dN"
+ },
+ "source": [
+ "## 10. Per-dataset plots\n",
+ "\n",
+ "Per-dataset alpha/accuracy plots are saved under each dataset folder in `per_dataset/<dataset_slug>/plots/` on Drive when running in Colab."
+ ],
+ "id": "ZVykdmiqE4dN"
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "dITJemlaE4dN"
+ },
+ "source": [
+ "## 11. Aggregate summaries"
+ ],
+ "id": "dITJemlaE4dN"
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "R8NPHRR8E4dN",
+ "outputId": "eaa88724-f44f-42a3-e134-623dafff2c9b",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ }
+ },
+ "execution_count": 67,
+ "outputs": [
{
- "cell_type": "code",
- "metadata": {
- "id": "loQQLDXXE4dN",
- "outputId": "e435bd4c-cf9e-4f4c-a35e-e4321be0cda0",
- "colab": {
- "base_uri": "https://localhost:8080/"
- }
- },
- "execution_count": 69,
- "outputs": [
- {
- "output_type": "stream",
- "name": "stdout",
- "text": [
- "Saved zip: /content/drive/MyDrive/xgboost2ww_runs/random10_longrun_alpha_tracking/aggregate/random10_longrun_alpha_tracking_outputs.zip\n"
- ]
- }
- ],
- "source": [
- "# Optional zip export for Colab users (aggregate + logs)\n",
- "import shutil\n",
- "import tempfile\n",
- "\n",
- "zip_output = aggregate_dir / 'random10_longrun_alpha_tracking_outputs.zip'\n",
- "with tempfile.TemporaryDirectory() as td:\n",
- " td_path = Path(td)\n",
- " shutil.copytree(aggregate_dir, td_path / 'aggregate', dirs_exist_ok=True)\n",
- " shutil.copytree(logs_dir, td_path / 'logs', dirs_exist_ok=True)\n",
- " built = shutil.make_archive(str(zip_output).replace('.zip', ''), 'zip', root_dir=td_path)\n",
- "print('Saved zip:', built)\n"
- ],
- "id": "loQQLDXXE4dN"
- },
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "No completed dataset results to aggregate.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Aggregate summary / plots\n",
+ "all_df = pd.concat(all_results, ignore_index=True) if all_results else pd.DataFrame()\n",
+ "\n",
+ "all_csv = aggregate_dir / 'all_datasets_longrun_metrics.csv'\n",
+ "all_feather = aggregate_dir / 'all_datasets_longrun_metrics.feather'\n",
+ "best_csv = aggregate_dir / 'best_rows_per_dataset.csv'\n",
+ "corr_csv = aggregate_dir / 'correlation_summary.csv'\n",
+ "run_summary_json = aggregate_dir / 'run_summary.json'\n",
+ "\n",
+ "if not all_df.empty:\n",
+ " all_df.to_csv(all_csv, index=False)\n",
+ " try:\n",
+ " all_df.reset_index(drop=True).to_feather(all_feather)\n",
+ " except Exception as e:\n",
+ " print('[WARN] Feather save failed for aggregate metrics:', e)\n",
+ "\n",
+ " best_df = all_df.loc[all_df.groupby('dataset_uid')['test_accuracy'].idxmax()].sort_values('test_accuracy', ascending=False)\n",
+ " best_df.to_csv(best_csv, index=False)\n",
+ "\n",
+ " corr_rows = []\n",
+ " for dataset_uid, g in all_df.groupby('dataset_uid'):\n",
+ " row = {'dataset_uid': dataset_uid, 'dataset_slug': g['dataset_slug'].iloc[0]}\n",
+ " for w in W_LIST:\n",
+ " c = g['test_accuracy'].corr(g[f'alpha_{w}'])\n",
+ " row[f'corr_test_accuracy_alpha_{w}'] = float(c) if pd.notnull(c) else np.nan\n",
+ " corr_rows.append(row)\n",
+ " corr_df = pd.DataFrame(corr_rows)\n",
+ " corr_df.to_csv(corr_csv, index=False)\n",
+ "\n",
+ " rank_rows = []\n",
+ " for w in W_LIST:\n",
+ " cvals = corr_df[f'corr_test_accuracy_alpha_{w}'].dropna()\n",
+ " rank_rows.append({'W': w, 'avg_abs_corr': float(np.abs(cvals).mean()) if len(cvals) else np.nan, 'share_lower_alpha_better': float((cvals < 0).mean()) if len(cvals) else np.nan})\n",
+ " rank_df = pd.DataFrame(rank_rows).sort_values('avg_abs_corr', ascending=False)\n",
+ "\n",
+ " run_summary = {'project_root': str(project_root), 'n_sampled_datasets': int(len(sampled_registry)), 'n_completed_datasets': int(all_df['dataset_uid'].nunique()), 'n_rows_total': int(len(all_df)), 'backend': BACKEND_INFO, 'timestamp_utc': pd.Timestamp.utcnow().isoformat()}\n",
+ " run_summary_json.write_text(json.dumps(run_summary, indent=2))\n",
+ "\n",
+ " if skipped_rows:\n",
+ " pd.DataFrame(skipped_rows).to_csv(skipped_path, index=False)\n",
+ "\n",
+ " display(sampled_registry)\n",
+ " display(pd.DataFrame(skipped_rows) if skipped_rows else pd.DataFrame(columns=['dataset_uid', 'dataset_slug', 'reason']))\n",
+ " display(best_df)\n",
+ " display(all_df.sort_values('test_accuracy', ascending=False).head(20))\n",
+ " display(corr_df)\n",
+ " display(rank_df)\n",
+ "else:\n",
+ " print('No completed dataset results to aggregate.')\n"
+ ],
+ "id": "R8NPHRR8E4dN"
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "-YHPQf9EE4dN"
+ },
+ "source": [
+ "## 12. Saved artifacts on Google Drive"
+ ],
+ "id": "-YHPQf9EE4dN"
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "ZwuPipGHE4dN",
+ "outputId": "3c913ec7-6a0b-431a-939b-946981c65173",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ }
+ },
+ "execution_count": 68,
+ "outputs": [
{
- "cell_type": "code",
- "source": [],
- "metadata": {
- "id": "1bsLkS-2Jt37"
- },
- "id": "1bsLkS-2Jt37",
- "execution_count": 69,
- "outputs": []
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Sampled registry CSV: /content/drive/MyDrive/xgboost2ww_runs/random10_longrun_alpha_tracking/registry/random10_registry.csv\n",
+ "Sampled registry Feather: /content/drive/MyDrive/xgboost2ww_runs/random10_longrun_alpha_tracking/registry/random10_registry.feather\n",
+ "Error log: /content/drive/MyDrive/xgboost2ww_runs/random10_longrun_alpha_tracking/logs/errors.csv\n",
+ "Aggregate metrics: /content/drive/MyDrive/xgboost2ww_runs/random10_longrun_alpha_tracking/aggregate/all_datasets_longrun_metrics.csv\n",
+ "Best rows summary: /content/drive/MyDrive/xgboost2ww_runs/random10_longrun_alpha_tracking/aggregate/best_rows_per_dataset.csv\n",
+ "Per-dataset folders: /content/drive/MyDrive/xgboost2ww_runs/random10_longrun_alpha_tracking/per_dataset\n"
+ ]
}
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "name": "python",
- "version": "3.x"
- },
+ ],
+ "source": [
+ "# Drive artifact summary\n",
+ "print('Sampled registry CSV:', registry_dir / 'random10_registry.csv')\n",
+ "print('Sampled registry Feather:', registry_dir / 'random10_registry.feather')\n",
+ "print('Error log:', errors_path)\n",
+ "print('Aggregate metrics:', aggregate_dir / 'all_datasets_longrun_metrics.csv')\n",
+ "print('Best rows summary:', aggregate_dir / 'best_rows_per_dataset.csv')\n",
+ "print('Per-dataset folders:', per_dataset_dir)\n"
+ ],
+ "id": "ZwuPipGHE4dN"
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "loQQLDXXE4dN",
+ "outputId": "e435bd4c-cf9e-4f4c-a35e-e4321be0cda0",
"colab": {
- "provenance": []
+ "base_uri": "https://localhost:8080/"
}
+ },
+ "execution_count": 69,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Saved zip: /content/drive/MyDrive/xgboost2ww_runs/random10_longrun_alpha_tracking/aggregate/random10_longrun_alpha_tracking_outputs.zip\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Optional zip export for Colab users (aggregate + logs)\n",
+ "import shutil\n",
+ "import tempfile\n",
+ "\n",
+ "zip_output = aggregate_dir / 'random10_longrun_alpha_tracking_outputs.zip'\n",
+ "with tempfile.TemporaryDirectory() as td:\n",
+ " td_path = Path(td)\n",
+ " shutil.copytree(aggregate_dir, td_path / 'aggregate', dirs_exist_ok=True)\n",
+ " shutil.copytree(logs_dir, td_path / 'logs', dirs_exist_ok=True)\n",
+ " built = shutil.make_archive(str(zip_output).replace('.zip', ''), 'zip', root_dir=td_path)\n",
+ "print('Saved zip:', built)\n"
+ ],
+ "id": "loQQLDXXE4dN"
+ },
+ {
+ "cell_type": "code",
+ "source": [],
+ "metadata": {
+ "id": "1bsLkS-2Jt37"
+ },
+ "id": "1bsLkS-2Jt37",
+ "execution_count": 69,
+ "outputs": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.x"
},
- "nbformat": 4,
- "nbformat_minor": 5
-}
\ No newline at end of file
+ "colab": {
+ "provenance": []
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}