forked from openai/weak-to-strong
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Jeff Wu and Adrien Ecoffet and Manas Joglekar and Jan Hendrik Kirchner and Pavel Izmailov
authored and
WuTheFWasThat
committed
Dec 14, 2023
0 parents
commit 1dbfbd6
Showing
24 changed files
with
1,723 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
dump | ||
*.pyc | ||
*.swp | ||
*.swo |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
Copyright 2023 OpenAI | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
**STATUS**: This codebase is not well tested and does not use the exact same settings we used in the paper, but in our experience gives qualitatively similar results when using large model size gaps and multiple seeds. Expected results can be found for two datasets below. We may update the code significantly in the coming week. | ||
|
||
# Weak-to-strong generalization | ||
|
||
 | ||
|
||
This project contains code for implementing our [paper on weak-to-strong generalization](https://cdn.openai.com/papers/weak-to-strong-generalization.pdf). | ||
|
||
The primary codebase contains a re-implementation of our weak-to-strong learning setup for binary classification tasks. The codebase contains code for fine-tuning pretrained language models, and also training against the labels from another language model. We support various losses described in the paper as well, such as the confidence auxiliary loss. | ||
|
||
The `vision` directory contains stand-alone code for weak-to-strong in the vision models setting (AlexNet -> DINO on ImageNet). | ||
|
||
### Getting Started | ||
|
||
These instructions will get you a copy of the project up and running on your local machine for development and testing purposes. | ||
|
||
#### Installation | ||
|
||
You need to have Python installed on your machine. The project also has some dependencies, which can be installed with pip: | ||
|
||
``` | ||
pip install -r requirements.txt | ||
``` | ||
|
||
#### Running the Script | ||
|
||
The main script of the project is train_weak_to_strong.py. It can be run from the command line using the following command: | ||
``` | ||
python train_weak_to_strong.py | ||
``` | ||
|
||
The script accepts several command-line arguments to customize the training process. Here are some examples: | ||
|
||
``` | ||
python train_weak_to_strong.py --batch_size 32 --max_ctx 512 --ds_name "sciq" --loss "logconf" --n_docs 1000 --n_test_docs 100 --weak_model_size "gpt2-medium" --strong_model_size "gpt2-large" --seed 42 | ||
``` | ||
|
||
#### Expected results | ||
|
||
<img src="notebooks/amazon_polarity_None.png" width="350"> | ||
<br> | ||
<img src="notebooks/sciq_None.png" width="350"> | ||
<br> | ||
<img src="notebooks/Anthropic-hh-rlhf_None.png" width="350"> | ||
|
||
### Authors | ||
|
||
- Adrien Ecoffet | ||
- Manas Joglekar | ||
- Jeffrey Wu | ||
- Jan Hendrik Kirchner | ||
- Pavel Izmailov (vision) | ||
|
||
### License | ||
|
||
This project is licensed under the MIT License - see the LICENSE.md file for details. | ||
|
||
### Acknowledgments | ||
|
||
- Hugging Face for their open-source transformer models |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "eb9a4b5a", | ||
"metadata": {}, | ||
"source": [ | ||
"# Simple Plotting\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "88c7ff9f", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"RESULTS_PATH = \"../../your_sweep_results_path\"\n", | ||
"\n", | ||
"PLOT_ALL_SEEDS = False\n", | ||
"# Full sweep\n", | ||
"MODELS_TO_PLOT = [\"gpt2\", \"gpt2-medium\", \"gpt2-large\", \"gpt2-xl\", \"Qwen/Qwen-1_8B\", \"Qwen/Qwen-7B\", \"Qwen/Qwen-14B\"]\n", | ||
"# Minimal sweep\n", | ||
"# MODELS_TO_PLOT = [\"gpt2\", \"gpt2-medium\", \"gpt2-large\"]\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "00ca073c", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"import pandas as pd\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import seaborn as sns\n", | ||
"sns.set_style('whitegrid')\n", | ||
"\n", | ||
"from IPython.display import display\n", | ||
"\n", | ||
"import os\n", | ||
"import glob\n", | ||
"import json" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "e5caa051", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"records = []\n", | ||
"all_results_folders = ['/'.join(e.split('/')[:-1]) for e in glob.glob(os.path.join(RESULTS_PATH, \"**/*.results_summary.json\"), recursive=True)]\n", | ||
"for result_folder in set(all_results_folders):\n", | ||
" config_file = os.path.join(result_folder, \"config.json\")\n", | ||
" config = json.load(open(config_file, \"r\"))\n", | ||
" if config[\"strong_model_size\"] not in MODELS_TO_PLOT:\n", | ||
" continue\n", | ||
" if 'seed' not in config:\n", | ||
" config['seed'] = 0\n", | ||
" result_filename = (config[\"weak_model_size\"].replace('.', '_') + \"_\" + config[\"strong_model_size\"].replace('.', '_') + \".results_summary.json\").replace('/', '_')\n", | ||
" record = config.copy()\n", | ||
" record.update(json.load(open(config_file.replace('config.json', result_filename))))\n", | ||
" records.append(record)\n", | ||
"\n", | ||
"df = pd.DataFrame.from_records(records).sort_values(['ds_name', 'weak_model_size', 'strong_model_size'])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "2f628577", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"datasets = df.ds_name.unique()\n", | ||
"for dataset in datasets:\n", | ||
" cur_df = df[(df.ds_name == dataset)]\n", | ||
" base_df = pd.concat([\n", | ||
" pd.DataFrame.from_dict({\"strong_model_size\": cur_df['weak_model_size'].to_list(), \"accuracy\": cur_df['weak_acc'].to_list(), \"seed\": cur_df['seed'].to_list()}),\n", | ||
" pd.DataFrame.from_dict({\"strong_model_size\": cur_df['strong_model_size'].to_list(), \"accuracy\": cur_df['strong_acc'].to_list(), \"seed\": cur_df['seed'].to_list()})\n", | ||
" ])\n", | ||
" base_accuracies = base_df.groupby('strong_model_size').agg({'accuracy': 'mean', 'seed': 'count'}).sort_values('accuracy')\n", | ||
" base_accuracy_lookup = base_accuracies['accuracy'].to_dict()\n", | ||
" base_accuracies = base_accuracies.reset_index()\n", | ||
" base_df.reset_index(inplace=True)\n", | ||
" base_df['weak_model_size'] = 'ground truth'\n", | ||
" base_df['loss'] = 'xent'\n", | ||
" base_df['strong_model_accuracy'] = base_df['strong_model_size'].apply(lambda x: base_accuracy_lookup[x])\n", | ||
"\n", | ||
" weak_to_strong = cur_df[['weak_model_size', 'strong_model_size', 'seed'] + [e for e in cur_df.columns if e.startswith('transfer_acc')]]\n", | ||
" weak_to_strong = weak_to_strong.melt(id_vars=['weak_model_size', 'strong_model_size', 'seed'], var_name='loss', value_name='accuracy')\n", | ||
" weak_to_strong = weak_to_strong.dropna(subset=['accuracy'])\n", | ||
" weak_to_strong.reset_index(inplace=True)\n", | ||
" weak_to_strong['loss'] = weak_to_strong['loss'].str.replace('transfer_acc_', '')\n", | ||
" weak_to_strong['strong_model_accuracy'] = weak_to_strong['strong_model_size'].apply(lambda x: base_accuracy_lookup[x])\n", | ||
"\n", | ||
" # Exclude cases where the weak model is better than the strong model from PGR calculation.\n", | ||
" pgr_df = cur_df[(cur_df['weak_model_size'] != cur_df['strong_model_size']) & (cur_df['strong_acc'] > cur_df['weak_acc'])]\n", | ||
" pgr_df = pgr_df.melt(id_vars=[e for e in cur_df.columns if not e.startswith('transfer_acc')], var_name='loss', value_name='transfer_acc')\n", | ||
" pgr_df = pgr_df.dropna(subset=['transfer_acc'])\n", | ||
" pgr_df['loss'] = pgr_df['loss'].str.replace('transfer_acc_', '')\n", | ||
" pgr_df['pgr'] = (pgr_df['transfer_acc'] - pgr_df['weak_acc']) / (pgr_df['strong_acc'] - pgr_df['weak_acc'])\n", | ||
"\n", | ||
" for seed in [None] + (sorted(cur_df['seed'].unique().tolist()) if PLOT_ALL_SEEDS else []):\n", | ||
" plot_df = pd.concat([base_df, weak_to_strong])\n", | ||
" seed_pgr_df = pgr_df\n", | ||
" if seed is not None:\n", | ||
" plot_df = plot_df[plot_df['seed'] == seed]\n", | ||
" # We mean across seeds, this is because sometimes the weak and strong models will have run on different hardware and therefore\n", | ||
" # have slight differences. We want to average these out when filtering by seed.\n", | ||
"\n", | ||
" seed_pgr_df = pgr_df[pgr_df['seed'] == seed]\n", | ||
"\n", | ||
" if seed is not None or cur_df['seed'].nunique() == 1:\n", | ||
" plot_df = plot_df[['strong_model_accuracy', 'weak_model_size', 'loss', 'accuracy']].groupby(['strong_model_accuracy', 'weak_model_size', 'loss']).mean().reset_index().sort_values(['loss', 'weak_model_size'], ascending=False)\n", | ||
"\n", | ||
" print(f\"Dataset: {dataset} (seed: {seed})\")\n", | ||
"\n", | ||
" pgr_results = seed_pgr_df.groupby(['loss']).aggregate({\"pgr\": \"median\"})\n", | ||
" display(pgr_results)\n", | ||
"\n", | ||
" palette = sns.color_palette('colorblind', n_colors=len(plot_df['weak_model_size'].unique()) - 1)\n", | ||
" color_dict = {model: (\"black\" if model == 'ground truth' else palette.pop()) for model in plot_df['weak_model_size'].unique()}\n", | ||
"\n", | ||
" sns.lineplot(data=plot_df, x='strong_model_accuracy', y='accuracy', hue='weak_model_size', style='loss', markers=True, palette=color_dict)\n", | ||
" pd.plotting.table(plt.gca(), pgr_results.round(4), loc='lower right', colWidths=[0.1, 0.1], cellLoc='center', rowLoc='center')\n", | ||
" plt.xticks(ticks=base_accuracies['accuracy'], labels=[f\"{e} ({base_accuracy_lookup[e]:.4f})\" for e in base_accuracies['strong_model_size']], rotation=90)\n", | ||
" plt.title(f\"Dataset: {dataset} (seed: {seed})\")\n", | ||
" plt.legend(loc='upper left')\n", | ||
" plt.savefig(f\"{dataset.replace('/', '-')}_{seed}.png\", dpi=300, bbox_inches='tight')\n", | ||
" plt.show()" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "openai", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.5" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
[build-system] | ||
requires = ["hatchling"] | ||
build-backend = "hatchling.build" | ||
|
||
[project] | ||
name = "weak_to_strong" | ||
version = "0.0.1" | ||
authors = [ | ||
{ name="OpenAI", email="[email protected]" }, | ||
] | ||
description = "Weak-to-strong generalization" | ||
readme = "README.md" | ||
requires-python = ">=3.7" | ||
classifiers = [ | ||
"Programming Language :: Python :: 3", | ||
"License :: OSI Approved :: MIT License", | ||
"Operating System :: OS Independent", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
torch~=2.1 | ||
numpy~=1.24 | ||
transformers~=4.36 | ||
datasets~=2.14 | ||
fire~=0.4 | ||
accelerate~=0.25 | ||
transformers-stream-generator~=0.0.4 | ||
torch_optimizer~=0.3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import setuptools | ||
|
||
setuptools.setup( | ||
name="weak_to_strong", | ||
version="0.1", | ||
description="Weak-to-strong generalization", | ||
url="#", | ||
author="OpenAI", | ||
author_email="[email protected]", | ||
packages=setuptools.find_packages(), | ||
zip_safe=False, | ||
) |
Oops, something went wrong.