diff --git a/docs/tutorials/optim-reinvent.ipynb b/docs/tutorials/optim-reinvent.ipynb new file mode 100644 index 0000000..1df3f1d --- /dev/null +++ b/docs/tutorials/optim-reinvent.ipynb @@ -0,0 +1,505 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## SAFE for Goal-directed optimization" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Install the key dependencies" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# %%capture\n", + "# ! pip install pytdc\n", + "# ! pip install wandb\n", + "# ! pip install trl" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2024-09-19 15:49:31,613] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to mps (auto detect)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "W0919 15:49:32.087000 8193675072 torch/distributed/elastic/multiprocessing/redirects.py:28] NOTE: Redirects are currently not supported in Windows or MacOs.\n" + ] + } + ], + "source": [ + "import os\n", + "import safe as sf\n", + "import datamol as dm\n", + "import torch\n", + "import numpy as np\n", + "from tqdm.auto import tqdm\n", + "from tdc import Evaluator\n", + "from safe.trainer.model import SAFEDoubleHeadsModel\n", + "from safe.tokenizer import SAFETokenizer\n", + "from safe.converter import encode, decode, SAFEConverter\n", + "from random import choices\n", + "from trl import AutoModelForCausalLMWithValueHead,PreTrainedModelWrapper, create_reference_model\n", + "from safe.sample import SAFEDesign\n", + "from safe.optim import REINVENTConfig, REINVENTTrainer, AutoModelForCausalLM\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Reinvent training process for goal-directed generation" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# This block show the optimization loop for the goal-directed generation with SAFE-GPT model\n", + "\n", + "def REINVENT_train(config, generation_kwargs, model, tokenizer, reward_fn, prefix=None, n_episodes=100):\n", + " \"\"\" Proximal Policy Optimization training for molecules generation\n", + " Args:\n", + " config: finetuning configs.\n", + " generation_kwargs: Additional parameters for generation.\n", + " model: Base model for optimization.\n", + " tokenizer: SAFE tokenizer to tokenize molecule smiles strings.\n", + " oracle: Reward function for training.\n", + " prefix: String prefix for fragment constrained generation.\n", + " n_episodes: Number of episodes to update the policy and value function of the agent.\n", + "\n", + " Returns:\n", + " reinvent_trainer: trained REINVENT trainer\n", + " model: Fine-tuned SAFE model with optmization.\n", + " \"\"\"\n", + " # get the safe string encoder\n", + " if not isinstance(model, PreTrainedModelWrapper):\n", + " model = AutoModelForCausalLM(safe_model)\n", + " safe_encoder = SAFEConverter()\n", + "\n", + " # define the referene model during fine-tuning\n", + " prior = 
create_reference_model(model)\n", + " reinvent_config = REINVENTConfig(**config)\n", + "\n", + " # define evaluation metrics for tracking\n", + " diversity_evaluator = Evaluator(name = 'Diversity')\n", + " uniqueness_evaluator = Evaluator(name = 'Uniqueness')\n", + "\n", + " reinvent_trainer = REINVENTTrainer(reinvent_config, model, prior, tokenizer)\n", + " if isinstance(prefix, str):\n", + "\n", + " encoded_fragment = safe_encoder.encoder(\n", + " prefix,\n", + " canonical=False,\n", + " randomize=True,\n", + " constraints=None,\n", + " allow_empty=True,\n", + " )\n", + " prefix = encoded_fragment.rstrip(\".\") + \".\"\n", + "\n", + " if prefix is None:\n", + " prefix = \"\"\n", + "\n", + " if isinstance(prefix, str):\n", + " prefix = [prefix]\n", + "\n", + " batch_size = config.get(\"batch_size\", 32)\n", + " if len(prefix) < batch_size:\n", + " prefix = choices(prefix, k=batch_size)\n", + "\n", + " for _ in tqdm(range(n_episodes)):\n", + "\n", + " # a new complete sequence of actions for agent to learn\n", + " game_data = {}\n", + " game_data[\"query\"] = prefix\n", + " batch = tokenizer([tokenizer.bos_token+x for x in prefix], return_tensors=\"pt\", add_special_tokens=False).to(model.pretrained_model.device)\n", + " query_tensor = batch[\"input_ids\"]\n", + " # generation\n", + " response_tensor = reinvent_trainer.generate(list(query_tensor), return_prompt=False, **generation_kwargs)\n", + " decoded_safe_mols = tokenizer.batch_decode(response_tensor, skip_special_tokens=True)\n", + " decoded_smiles = [\n", + " decode(x,\n", + " as_mol=False,\n", + " fix=True,\n", + " remove_added_hs=True,\n", + " canonical=True,\n", + " ignore_errors=True,\n", + " remove_dummies=True,\n", + " ) for x in decoded_safe_mols\n", + " ]\n", + " game_data[\"response\"] = decoded_safe_mols\n", + "\n", + " # compute the reward scores\n", + " rewards = np.zeros(len(decoded_smiles), dtype=np.float32)\n", + " valid_position = []\n", + " valid_smiles = []\n", + " valid_position, valid_smiles = zip(*[(i, x) for i, x in enumerate(decoded_smiles) if x is not None])\n", + " valid_smiles = list(valid_smiles)\n", + "\n", + " # get reward function\n", + " batch_reward = [reward_fn(smi) for smi in valid_smiles]\n", + " rewards[np.asarray(valid_position)] = batch_reward\n", + " rewards = torch.from_numpy(rewards).to(device=model.pretrained_model.device)\n", + " rewards = list(rewards)\n", + "\n", + " # get the training stats\n", + " stats = reinvent_trainer.step(list(query_tensor), list(response_tensor), rewards)\n", + " stats[\"validity\"] = (len(valid_position) / batch_size)\n", + "\n", + " # other statistics to track\n", + " if len(valid_smiles) > 0:\n", + " stats[\"uniqueness\"] = uniqueness_evaluator(list(valid_smiles))\n", + " stats[\"diversity\"] = diversity_evaluator(list(valid_smiles))\n", + " reinvent_trainer.log_stats(stats, game_data, rewards)\n", + " return reinvent_trainer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define the SAFE model for fine-tuning" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# get the base safe-gpt model for fine-tuning\n", + "designer = SAFEDesign.load_default()\n", + "safe_tokenizer = designer.tokenizer\n", + "safe_model = designer.model\n", + "tokenizer = safe_tokenizer.get_pretrained()\n", + "\n", + "# wrap the model for training\n", + "model = safe_model\n", + "model.is_peft_model = False\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define the 
reward function that the agent can learn from. \n", + "It can be a single molecular property such as `clogP` or a surrogate function of multiple molecular properties such as `BBB score`. It can also be a scroing function based on a `predictive model` for potency etc. " + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "# In this tutorial, LogP is used for demonstration purpose\n", + "# The desired log P value is 4\n", + "def clogp_reward_fn(mol: str, **kwargs):\n", + " \"\"\" Reward function for optimization\n", + " Args:\n", + " mol: Molecule in SMILES.\n", + " \"\"\"\n", + " mol = dm.to_mol(mol)\n", + " if mol is None:\n", + " return -100\n", + " return dm.descriptors.clogp(mol)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Start the REINVENT training and track the training on Wandb" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"WANDB_SILENT\"] = \"False\"\n", + "os.environ[\"WANDB_LOG_MODEL\"]=\"end\"\n", + "os.environ[\"WANDB_WATCH\"]=\"all\"\n", + "os.environ[\"WANDB_ENTITY\"]=\"valencelabs\"" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "07bf481948e143e8bfdae1de1a988560", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/4 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.16.6" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /Users/emmanuel.noutahi/Code/safe/docs/tutorials/wandb/run-20240919_154934-i2juos5u" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run safe-gpt-reinvent-cLogP_dap to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/valencelabs/safe-reinvent-tutorial" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/valencelabs/safe-reinvent-tutorial/runs/i2juos5u" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "100e8847e76c4b3a98f523e5535145a7", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/25 [00:00 34\u001b[0m trainer_map[strategy] \u001b[38;5;241m=\u001b[39m \u001b[43mREINVENT_train\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconfig\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgeneration_kwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtokenizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclogp_reward_fn\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprefix\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mscaffold\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_episodes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mn_episodes\u001b[49m\u001b[43m)\u001b[49m\n", + "Cell \u001b[0;32mIn[5], line 61\u001b[0m, in \u001b[0;36mREINVENT_train\u001b[0;34m(config, generation_kwargs, model, tokenizer, reward_fn, prefix, n_episodes)\u001b[0m\n\u001b[1;32m 59\u001b[0m query_tensor \u001b[38;5;241m=\u001b[39m batch[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minput_ids\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 60\u001b[0m \u001b[38;5;66;03m# generation\u001b[39;00m\n\u001b[0;32m---> 61\u001b[0m response_tensor \u001b[38;5;241m=\u001b[39m \u001b[43mreinvent_trainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mlist\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mquery_tensor\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreturn_prompt\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mgeneration_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 62\u001b[0m decoded_safe_mols \u001b[38;5;241m=\u001b[39m tokenizer\u001b[38;5;241m.\u001b[39mbatch_decode(response_tensor, skip_special_tokens\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 63\u001b[0m decoded_smiles \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m 64\u001b[0m decode(x,\n\u001b[1;32m 65\u001b[0m as_mol\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 71\u001b[0m ) \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m decoded_safe_mols\n\u001b[1;32m 72\u001b[0m ]\n", + "File \u001b[0;32m~/Code/safe/safe/optim/reinvent.py:332\u001b[0m, in \u001b[0;36mREINVENTTrainer.generate\u001b[0;34m(self, query_tensor, length_sampler, batch_size, return_prompt, **generation_kwargs)\u001b[0m\n\u001b[1;32m 329\u001b[0m attention_mask \u001b[38;5;241m=\u001b[39m (input_ids \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtokenizer\u001b[38;5;241m.\u001b[39mpad_token_id)\u001b[38;5;241m.\u001b[39mlong()\n\u001b[1;32m 331\u001b[0m 
\u001b[38;5;28;01mwith\u001b[39;00m unwrap_model_for_generation(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator) \u001b[38;5;28;01mas\u001b[39;00m unwrapped_model:\n\u001b[0;32m--> 332\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43munwrapped_model\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 333\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mgeneration_kwargs\u001b[49m\n\u001b[1;32m 334\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 336\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m j, output \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(outputs):\n\u001b[1;32m 337\u001b[0m response \u001b[38;5;241m=\u001b[39m output[\u001b[38;5;28mlen\u001b[39m(batch_queries[j]) :] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m return_prompt \u001b[38;5;28;01melse\u001b[39;00m output\n", + "File \u001b[0;32m~/Code/safe/safe/optim/_utils.py:58\u001b[0m, in \u001b[0;36mAutoModelForCausalLM.generate\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mgenerate\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 47\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 48\u001b[0m \u001b[38;5;124;03m A simple wrapper around the `generate` method of the wrapped model.\u001b[39;00m\n\u001b[1;32m 49\u001b[0m \u001b[38;5;124;03m Please refer to the [`generate`](https://huggingface.co/docs/transformers/internal/generation_utils)\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[38;5;124;03m Keyword arguments passed to the `generate` method of the wrapped model.\u001b[39;00m\n\u001b[1;32m 57\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m---> 58\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpretrained_model\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/safe/lib/python3.12/site-packages/torch/utils/_contextlib.py:116\u001b[0m, in \u001b[0;36mcontext_decorator..decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/safe/lib/python3.12/site-packages/transformers/generation/utils.py:2024\u001b[0m, in \u001b[0;36mGenerationMixin.generate\u001b[0;34m(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)\u001b[0m\n\u001b[1;32m 2016\u001b[0m input_ids, model_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_expand_inputs_for_generation(\n\u001b[1;32m 2017\u001b[0m input_ids\u001b[38;5;241m=\u001b[39minput_ids,\n\u001b[1;32m 2018\u001b[0m expand_size\u001b[38;5;241m=\u001b[39mgeneration_config\u001b[38;5;241m.\u001b[39mnum_return_sequences,\n\u001b[1;32m 2019\u001b[0m is_encoder_decoder\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mis_encoder_decoder,\n\u001b[1;32m 2020\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mmodel_kwargs,\n\u001b[1;32m 2021\u001b[0m )\n\u001b[1;32m 2023\u001b[0m \u001b[38;5;66;03m# 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)\u001b[39;00m\n\u001b[0;32m-> 2024\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_sample\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2025\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2026\u001b[0m \u001b[43m \u001b[49m\u001b[43mlogits_processor\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprepared_logits_processor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2027\u001b[0m \u001b[43m \u001b[49m\u001b[43mlogits_warper\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprepared_logits_warper\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2028\u001b[0m \u001b[43m \u001b[49m\u001b[43mstopping_criteria\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprepared_stopping_criteria\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2029\u001b[0m \u001b[43m \u001b[49m\u001b[43mgeneration_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgeneration_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2030\u001b[0m \u001b[43m \u001b[49m\u001b[43msynced_gpus\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msynced_gpus\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2031\u001b[0m \u001b[43m \u001b[49m\u001b[43mstreamer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstreamer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2032\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2033\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2035\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m generation_mode \u001b[38;5;129;01min\u001b[39;00m (GenerationMode\u001b[38;5;241m.\u001b[39mBEAM_SAMPLE, GenerationMode\u001b[38;5;241m.\u001b[39mBEAM_SEARCH):\n\u001b[1;32m 2036\u001b[0m \u001b[38;5;66;03m# 11. 
prepare logits warper\u001b[39;00m\n\u001b[1;32m 2037\u001b[0m prepared_logits_warper \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 2038\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_logits_warper(generation_config, device\u001b[38;5;241m=\u001b[39minput_ids\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[1;32m 2039\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m generation_config\u001b[38;5;241m.\u001b[39mdo_sample\n\u001b[1;32m 2040\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 2041\u001b[0m )\n", + "File \u001b[0;32m~/miniconda3/envs/safe/lib/python3.12/site-packages/transformers/generation/utils.py:2982\u001b[0m, in \u001b[0;36mGenerationMixin._sample\u001b[0;34m(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, logits_warper, **model_kwargs)\u001b[0m\n\u001b[1;32m 2979\u001b[0m model_inputs\u001b[38;5;241m.\u001b[39mupdate({\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124moutput_hidden_states\u001b[39m\u001b[38;5;124m\"\u001b[39m: output_hidden_states} \u001b[38;5;28;01mif\u001b[39;00m output_hidden_states \u001b[38;5;28;01melse\u001b[39;00m {})\n\u001b[1;32m 2981\u001b[0m \u001b[38;5;66;03m# forward pass to get next token\u001b[39;00m\n\u001b[0;32m-> 2982\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_inputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 2984\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m synced_gpus \u001b[38;5;129;01mand\u001b[39;00m this_peer_finished:\n\u001b[1;32m 2985\u001b[0m \u001b[38;5;28;01mcontinue\u001b[39;00m \u001b[38;5;66;03m# don't waste resources running the code we don't need\u001b[39;00m\n", + "File \u001b[0;32m~/miniconda3/envs/safe/lib/python3.12/site-packages/torch/nn/modules/module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/safe/lib/python3.12/site-packages/torch/nn/modules/module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1560\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1561\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1565\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[0;32m~/Code/safe/safe/trainer/model.py:156\u001b[0m, in \u001b[0;36mSAFEDoubleHeadsModel.forward\u001b[0;34m(self, input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, mc_token_ids, labels, mc_labels, use_cache, output_attentions, output_hidden_states, return_dict, inputs, encoder_hidden_states, **kwargs)\u001b[0m\n\u001b[1;32m 139\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 140\u001b[0m \n\u001b[1;32m 141\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 153\u001b[0m \u001b[38;5;124;03m output (GPT2DoubleHeadsModelOutput): output of the model\u001b[39;00m\n\u001b[1;32m 154\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 155\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[0;32m--> 156\u001b[0m transformer_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtransformer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 157\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 158\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 159\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 160\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken_type_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken_type_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 161\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 162\u001b[0m \u001b[43m \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 163\u001b[0m \u001b[43m 
\u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 164\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 165\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 166\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 167\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 168\u001b[0m \u001b[43m \u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 169\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 171\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m transformer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 172\u001b[0m lm_logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlm_head(hidden_states)\n", + "File \u001b[0;32m~/miniconda3/envs/safe/lib/python3.12/site-packages/torch/nn/modules/module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/safe/lib/python3.12/site-packages/torch/nn/modules/module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1560\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1561\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0m 
\u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1565\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[0;32m~/miniconda3/envs/safe/lib/python3.12/site-packages/transformers/models/gpt2/modeling_gpt2.py:1129\u001b[0m, in \u001b[0;36mGPT2Model.forward\u001b[0;34m(self, input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1117\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gradient_checkpointing_func(\n\u001b[1;32m 1118\u001b[0m block\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__call__\u001b[39m,\n\u001b[1;32m 1119\u001b[0m hidden_states,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1126\u001b[0m output_attentions,\n\u001b[1;32m 1127\u001b[0m )\n\u001b[1;32m 1128\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1129\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mblock\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1130\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1131\u001b[0m \u001b[43m \u001b[49m\u001b[43mlayer_past\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlayer_past\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1132\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1133\u001b[0m \u001b[43m \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhead_mask\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1134\u001b[0m \u001b[43m \u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1135\u001b[0m \u001b[43m \u001b[49m\u001b[43mencoder_attention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoder_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1136\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1137\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1138\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1140\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 1141\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m use_cache \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n", + "File \u001b[0;32m~/miniconda3/envs/safe/lib/python3.12/site-packages/torch/nn/modules/module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, 
\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/safe/lib/python3.12/site-packages/torch/nn/modules/module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1560\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1561\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1565\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[0;32m~/miniconda3/envs/safe/lib/python3.12/site-packages/transformers/models/gpt2/modeling_gpt2.py:614\u001b[0m, in \u001b[0;36mGPT2Block.forward\u001b[0;34m(self, hidden_states, layer_past, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions)\u001b[0m\n\u001b[1;32m 612\u001b[0m residual \u001b[38;5;241m=\u001b[39m hidden_states\n\u001b[1;32m 613\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mln_1(hidden_states)\n\u001b[0;32m--> 614\u001b[0m attn_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mattn\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 615\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 616\u001b[0m \u001b[43m \u001b[49m\u001b[43mlayer_past\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlayer_past\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 617\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 618\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 619\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 620\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 621\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 622\u001b[0m attn_output \u001b[38;5;241m=\u001b[39m attn_outputs[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;66;03m# output_attn: a, present, (attentions)\u001b[39;00m\n\u001b[1;32m 623\u001b[0m outputs \u001b[38;5;241m=\u001b[39m attn_outputs[\u001b[38;5;241m1\u001b[39m:]\n", + "File \u001b[0;32m~/miniconda3/envs/safe/lib/python3.12/site-packages/torch/nn/modules/module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/safe/lib/python3.12/site-packages/torch/nn/modules/module.py:1562\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1559\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1560\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1561\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1562\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1564\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1565\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File 
\u001b[0;32m~/miniconda3/envs/safe/lib/python3.12/site-packages/transformers/models/gpt2/modeling_gpt2.py:559\u001b[0m, in \u001b[0;36mGPT2SdpaAttention.forward\u001b[0;34m(self, hidden_states, layer_past, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions)\u001b[0m\n\u001b[1;32m 557\u001b[0m \u001b[38;5;66;03m# Final projection\u001b[39;00m\n\u001b[1;32m 558\u001b[0m attn_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mc_proj(attn_output)\n\u001b[0;32m--> 559\u001b[0m attn_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mresid_dropout\u001b[49m\u001b[43m(\u001b[49m\u001b[43mattn_output\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 561\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m attn_output, present, \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[0;32m~/miniconda3/envs/safe/lib/python3.12/site-packages/torch/nn/modules/module.py:1553\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1551\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/safe/lib/python3.12/site-packages/torch/nn/modules/module.py:1555\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1552\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1553\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m-> 1555\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_call_impl\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 1556\u001b[0m forward_call \u001b[38;5;241m=\u001b[39m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_slow_forward \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_C\u001b[38;5;241m.\u001b[39m_get_tracing_state() \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mforward)\n\u001b[1;32m 1557\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1558\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "# define REINVENT config\n", + "n_episodes = 25\n", + "# scaffold = \"[*:2]N1CCN(CC1)CCCCN[*:1]\"\n", + "scaffold = None # a small number for testing purpose in thßßis tutorial\n", + "\n", + "trainer_map = {}\n", + "for strategy in tqdm([\"dap\", \"sdap\", \"mauli\", \"mascof\"]):\n", + " config = {\n", + " 
\"batch_size\": 32,\n", + " \"mini_batch_size\":32,\n", + " \"log_with\":\"wandb\",\n", + " \"exp_name\": \"safe-gpt-reinvent-cLogP\",\n", + " \"tracker_project_name\": \"safe-reinvent-tutorial\",\n", + " \"reward_model\": \"cLogP\",\n", + " \"sigma\":100,\n", + " \"steps\": n_episodes,\n", + " \"reinvent_epochs\": 2,\n", + " \"max_buffer_size\": 512,\n", + " \"strategy\": strategy\n", + " }\n", + "\n", + " config[\"exp_name\"] += f\"_{strategy}\"\n", + " # generation config\n", + " # see more at https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationConfig\n", + " generation_kwargs = {\n", + " \"min_length\": -1,\n", + " \"do_sample\": True,\n", + " \"pad_token_id\": tokenizer.pad_token_id,\n", + " \"bos_token_id\": tokenizer.bos_token_id,\n", + " \"eos_token_id\": tokenizer.eos_token_id,\n", + " \"max_new_tokens\": 100,\n", + " }\n", + "\n", + " trainer_map[strategy] = REINVENT_train(config, generation_kwargs, model, tokenizer, clogp_reward_fn, prefix=scaffold, n_episodes=n_episodes)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now you can used the fine-tuned model and use it for goal-directed generation\n", + "Below we use de novo generation as an example." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "designer = SAFEDesign(model=trainer.model, tokenizer=safe_tokenizer )\n", + "generated = designer.de_novo_generation(n_samples_per_trial=10)\n", + "valid_position, valid_smiles = zip(*[(i, x) for i, x in enumerate(generated) if x is not None])\n", + "generated_mols = [dm.to_mol(mol) for mol in valid_smiles]\n", + "mol_prop = [dm.descriptors.clogp(mol) for mol in generated_mols ]\n", + "print(np.mean(mol_prop))\n", + "dm.to_image([dm.to_mol(mol) for mol in generated if mol is not None ],\n", + " legends= [f\"clogP: {x:.2f}\" for x in mol_prop])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "safe", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/env.yml b/env.yml index fd3ab49..fd56891 100644 --- a/env.yml +++ b/env.yml @@ -15,6 +15,7 @@ dependencies: - numpy - pytorch >=2.0 - transformers + - trl - datasets - tokenizers - accelerate >=0.33 # for accelerator_config update diff --git a/mkdocs.yml b/mkdocs.yml index 11dd37b..534069d 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -16,7 +16,9 @@ nav: - Getting Started: tutorials/getting-started.ipynb - Molecular design: tutorials/design-with-safe.ipynb - How it works: tutorials/how-it-works.ipynb + - WANDB support: tutorials/load-from-wandb.ipynb - Extracting representation (molfeat): tutorials/extracting-representation-molfeat.ipynb + - Optimization with REINVENT: tutorials/optim-reinvent.ipynb - API: - SAFE: api/safe.md - Visualization: api/safe.viz.md @@ -32,7 +34,7 @@ theme: extra_javascript: - assets/js/google-analytics.js - + markdown_extensions: - admonition - markdown_include.include diff --git a/safe/optim/__init__.py b/safe/optim/__init__.py new file mode 100644 index 0000000..5472c9c --- /dev/null +++ b/safe/optim/__init__.py @@ -0,0 +1,3 @@ +from 
.reinvent_config import REINVENTConfig +from .reinvent import REINVENTTrainer +from safe.optim._utils import AutoModelForCausalLM diff --git a/safe/optim/_utils.py b/safe/optim/_utils.py new file mode 100644 index 0000000..2b241fb --- /dev/null +++ b/safe/optim/_utils.py @@ -0,0 +1,69 @@ +from trl import PreTrainedModelWrapper + + +class AutoModelForCausalLM(PreTrainedModelWrapper): + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + return_past_key_values=False, + **kwargs, + ): + r""" + Applies a forward pass to the wrapped model and returns the logits of the value head. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + past_key_values (`tuple(tuple(torch.FloatTensor))`, `optional`): + Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model + (see `past_key_values` input) to speed up sequential decoding. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + return_past_key_values (bool): A flag indicating if the computed hidden-states should be returned. + kwargs (`dict`, `optional`): + Additional keyword arguments, that are passed to the wrapped model. + """ + + kwargs["output_hidden_states"] = ( + True # this had already been set in the LORA / PEFT examples + ) + kwargs["past_key_values"] = past_key_values + + if ( + self.is_peft_model + and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING" + ): + kwargs.pop("past_key_values") + return self.pretrained_model.forward( + input_ids=input_ids, attention_mask=attention_mask, **kwargs + ) + + def generate(self, *args, **kwargs): + r""" + A simple wrapper around the `generate` method of the wrapped model. + Please refer to the [`generate`](https://huggingface.co/docs/transformers/internal/generation_utils) + method of the wrapped model for more information about the supported arguments. + + Args: + *args (`list`, *optional*): + Positional arguments passed to the `generate` method of the wrapped model. + **kwargs (`dict`, *optional*): + Keyword arguments passed to the `generate` method of the wrapped model. 
+ """ + return self.pretrained_model.generate(*args, **kwargs) + + def __getattr__(self, name: str): + """Forward missing attributes to the wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + if ( + name == "pretrained_model" + ): # see #1892: prevent infinite recursion if class is not initialized + raise + return getattr(self.pretrained_model, name) diff --git a/safe/optim/reinvent.py b/safe/optim/reinvent.py new file mode 100644 index 0000000..0f04389 --- /dev/null +++ b/safe/optim/reinvent.py @@ -0,0 +1,728 @@ +import inspect +import warnings +from contextlib import nullcontext +from typing import Callable, List, Optional, Union + +import numpy as np +import torch +import heapq +import torch.nn.functional as F +from accelerate import Accelerator +from accelerate.utils import ProjectConfiguration, gather_object, is_deepspeed_available +from datasets import Dataset +from torch.optim import Adam +from transformers import ( + DataCollatorForLanguageModeling, + PreTrainedTokenizerBase, + is_torch_npu_available, + is_torch_xpu_available, +) + +from trl.core import ( + PPODecorators, + entropy_from_logits, + masked_mean, + set_seed, +) +from trl.import_utils import is_torch_greater_2_0 +from trl.models import ( + PreTrainedModelWrapper, + create_reference_model, + unwrap_model_for_generation, +) +from trl.trainer import BaseTrainer, PPOConfig, RunningMoments + +from .reinvent_config import REINVENTStrategy, REINVENTConfig + +if is_deepspeed_available(): + import deepspeed + + +class REINVENTTrainer(BaseTrainer): + """ + The REINVENTTrainer implements the REINVENT algorithm for optimizing language models with reinforcement learning. + """ + + def __init__( + self, + config: Optional[REINVENTConfig] = None, + model: Optional[PreTrainedModelWrapper] = None, + ref_model: Optional[PreTrainedModelWrapper] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + dataset: Optional[Union[torch.utils.data.Dataset, Dataset]] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + data_collator: Optional[Callable] = None, + num_shared_layers: Optional[int] = None, + lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + training_data_collator: Optional[Callable] = None, + ): + """ + Initialize REINVENTTrainer. + + Args: + config: Configuration object for REINVENTTrainer. + model: Hugging Face transformer model with a value head. + ref_model: Reference model (prior) to be used in REINVENT. + tokenizer: Hugging Face tokenizer. + dataset: PyTorch dataset or Hugging Face dataset. + optimizer: Optimizer used for training. + data_collator: Data collator function. + num_shared_layers: Number of shared layers between the model and the reference model. + lr_scheduler: Learning rate scheduler used for training. + training_data_collator: Custom data collator used for training. 
+ """ + super().__init__(config) + + # Initial seed for reproducible experiments + set_seed(config.seed) + + # Initialize Accelerator + self.accelerator = Accelerator( + log_with=config.log_with, + gradient_accumulation_steps=config.gradient_accumulation_steps, + project_config=ProjectConfiguration(**config.project_kwargs), + **config.accelerator_kwargs, + ) + + # Runtime variables filled by the accelerator + config.world_size = self.accelerator.num_processes + config.global_batch_size = config.batch_size * config.world_size + + self.model = model + self.model_params = filter(lambda p: p.requires_grad, self.model.parameters()) + self.is_encoder_decoder = hasattr(self.model, "is_encoder_decoder") + self.is_peft_model = getattr(self.model, "is_peft_model", False) + config.is_encoder_decoder = self.is_encoder_decoder + config.is_peft_model = self.is_peft_model + + is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard" + self.accelerator.init_trackers( + config.tracker_project_name, + config=( + {"safe_reinvent_trainer_config": config.to_dict()} + if not is_using_tensorboard + else config.to_dict() + ), + init_kwargs=config.tracker_kwargs, + ) + + # Initialize reference (prior) model + if isinstance(ref_model, PreTrainedModelWrapper): + self.ref_model = ref_model + if num_shared_layers is not None: + warnings.warn( + "num_shared_layers is ignored when ref_model is provided. Two different models are used for the " + "model and the reference model and no layers are shared.", + UserWarning, + ) + elif ref_model is None and not self.is_peft_model: + self.ref_model = create_reference_model(self.model, num_shared_layers=num_shared_layers) + elif self.is_peft_model: + self.ref_model = None + else: + raise ValueError( + f"ref_model must be a PreTrainedModelWrapper or `None`, got {type(ref_model)} - supported " + f"architectures are: {PreTrainedModelWrapper} " + ) + self.optional_peft_ctx = ( + self.accelerator.unwrap_model(self.model).pretrained_model.disable_adapter + if self.is_peft_model + else nullcontext + ) + + if not isinstance(tokenizer, PreTrainedTokenizerBase): + raise ValueError( + "tokenizer must be a transformers.PreTrainedTokenizer or transformers.PreTrainedTokenizerFast" + ) + self.tokenizer = tokenizer + + if dataset is not None and not (isinstance(dataset, (torch.utils.data.Dataset, Dataset))): + raise ValueError("dataset must be a torch.utils.data.Dataset or datasets.Dataset") + if dataset is None: + warnings.warn( + "No dataset is provided. Make sure to set config.batch_size to the correct value before training.", + UserWarning, + ) + self.dataset = dataset + self._signature_columns = None + if self.dataset is not None: + self.dataloader = self.prepare_dataloader(self.dataset, data_collator) + elif self.dataset is None and self.accelerator.num_processes > 1: + warnings.warn( + "No dataset is provided. In a multi-GPU setting, this will lead to an error. You should" + " prepare your dataloader yourself with `dataloader = reinvent_trainer.accelerator.prepare(dataloader)`" + " and using `torch.utils.data.DataLoader`, or pass a dataset to the `REINVENTTrainer`. 
Please " + " refer to the documentation for more details.", + UserWarning, + ) + self.dataloader = None + else: + self.dataloader = None + + # Initialize optimizer and data collator + if training_data_collator is None: + self.data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False) + else: + self.data_collator = training_data_collator + if optimizer is None: + self.optimizer = Adam( + filter(lambda p: p.requires_grad, self.model.parameters()), + lr=self.config.learning_rate, + ) + else: + self.optimizer = optimizer + + # Initialize variables for early stopping and score scaling + self.use_score_norm = config.use_score_norm + self.reinvent_epochs = config.reinvent_epochs + self.mini_batch_size = config.mini_batch_size + + # Ensure mini_batch_size divides batch_size + if self.config.batch_size % self.mini_batch_size != 0: + raise ValueError("`batch_size` must be a multiple of `mini_batch_size`.") + + self.lr_scheduler = lr_scheduler + if self.lr_scheduler is not None: + lr_scheduler_class = ( + torch.optim.lr_scheduler._LRScheduler + if not is_torch_greater_2_0() + else torch.optim.lr_scheduler.LRScheduler + ) + + if not isinstance(self.lr_scheduler, lr_scheduler_class): + raise ValueError( + "lr_scheduler must be a torch.optim.lr_scheduler._LRScheduler or torch.optim.lr_scheduler.LRScheduler (for torch >= 2.0)" + ) + + # Strategy and parameters specific to REINVENT + self.sigma = config.sigma + self.strategy = REINVENTStrategy(config.strategy) + self.is_action_basis = config.is_action_basis + + # Initialize experience replay buffer (if needed) + self.use_experience_replay = config.use_experience_replay + self.experience_buffer = None + if self.use_experience_replay: + self.experience_buffer = [] + self.max_buffer_size = getattr(config, "max_buffer_size", 10000) # Default buffer size + + # Safety checkers for DeepSpeed integration + is_deepspeed_used = self.accelerator.distributed_type == "DEEPSPEED" and hasattr( + self.accelerator.state, "deepspeed_plugin" + ) + + if config.gradient_checkpointing: + self.model.gradient_checkpointing_enable() + + if hasattr(self.model, "enable_input_require_grads"): + self.model.enable_input_require_grads() + else: + # For backward compatibility with older versions of transformers + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + self.model.pretrained_model.get_input_embeddings().register_forward_hook( + make_inputs_require_grad + ) + + ( + self.model, + self.optimizer, + self.data_collator, + self.dataloader, + self.lr_scheduler, + ) = self.accelerator.prepare( + self.model, + self.optimizer, + self.data_collator, + self.dataloader, + self.lr_scheduler, + ) + if is_deepspeed_used: + # Quantized models are already set on the correct device + if not self.is_peft_model and not ( + getattr(self.ref_model.pretrained_model, "is_loaded_in_8bit", False) + or getattr(self.ref_model.pretrained_model, "is_loaded_in_4bit", False) + ): + self.ref_model = self._prepare_deepspeed(self.ref_model) + else: + self.ref_model = self.accelerator.prepare(self.ref_model) + + # In a distributed setup, only logging needs to be performed on the main process + self.is_distributed = self.accelerator.num_processes > 1 + + # Initialize the current step + self.current_step = 0 + + # Device setup + if not getattr(self.model, "is_sequential_parallel", False): + self.current_device = self.accelerator.device + else: + if is_torch_xpu_available(): + self.current_device = torch.device("xpu:0") + elif is_torch_npu_available(): + 
self.current_device = torch.device("npu:0") + else: + self.current_device = torch.device("cuda:0") + + PPODecorators.optimize_device_cache = self.config.optimize_device_cache + self.running = RunningMoments(self.accelerator) + + def prepare_dataloader( + self, dataset: Union[torch.utils.data.Dataset, Dataset], data_collator=None + ): + """ + Prepare the dataloader for training. + + Args: + dataset (Union[`torch.utils.data.Dataset`, `datasets.Dataset`]): + PyTorch dataset or Hugging Face dataset. + data_collator (Optional[function]): + Data collator function. + + Returns: + `torch.utils.data.DataLoader`: PyTorch dataloader + """ + return torch.utils.data.DataLoader( + dataset, + batch_size=self.config.batch_size, + collate_fn=data_collator, + shuffle=True, + drop_last=True, + ) + + def generate( + self, + query_tensor: Union[torch.Tensor, List[torch.Tensor]], + length_sampler: Optional[Callable] = None, + batch_size: int = 4, + return_prompt: bool = True, + **generation_kwargs, + ): + """ + Generate response with the model given the query tensor. + + Args: + query_tensor: A tensor or list of tensors containing query tokens. + length_sampler: Callable that returns the number of newly generated tokens. + batch_size: Batch size used for generation, defaults to `4`. + return_prompt: If set to `False` the prompt is not returned but only the newly generated tokens. + generation_kwargs: Keyword arguments for generation. + + Returns: + List[`torch.LongTensor`]: A list of tensors containing response tokens. + """ + if isinstance(query_tensor, torch.Tensor): + query_tensor = [query_tensor] + + responses = [] + batch_size = min(len(query_tensor), batch_size) + + for i in range(0, len(query_tensor), batch_size): + batch_queries = query_tensor[i : i + batch_size] + if length_sampler is not None: + generation_kwargs["max_new_tokens"] = length_sampler() + + input_ids = torch.nn.utils.rnn.pad_sequence( + batch_queries, batch_first=True, padding_value=self.tokenizer.pad_token_id + ).to(self.current_device) + attention_mask = (input_ids != self.tokenizer.pad_token_id).long() + + with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model: + outputs = unwrapped_model.generate( + input_ids=input_ids, attention_mask=attention_mask, **generation_kwargs + ) + + for j, output in enumerate(outputs): + response = output[len(batch_queries[j]) :] if not return_prompt else output + responses.append(response) + + return responses + + @PPODecorators.empty_device_cache() + def step( + self, + queries: List[torch.LongTensor], + responses: List[torch.LongTensor], + scores: List[float], + ): + """ + Perform a REINVENT training step with multiple epochs and mini-batches. + + Args: + queries: List of tensors containing the encoded queries. + responses: List of tensors containing the generated responses. + scores: List of scores (rewards) for the responses. + + Returns: + A dictionary of training statistics. 
+ """ + + # Ensure tensors are on the correct device + queries = [q.to(self.current_device) for q in queries] + responses = [r.to(self.current_device) for r in responses] + scores = torch.tensor(scores, dtype=torch.float32).to(self.current_device) + + # Score scaling and clipping (same as before) + if self.config.use_score_scaling: + # Score scaling + scores_mean, scores_std = self.running.update(scores) + tensor_to_kwargs = {"dtype": scores.dtype, "device": scores.device} + score_scaling_factor = ( + self.running.std.to(**tensor_to_kwargs) + torch.finfo(scores.dtype).eps + ) + if self.config.use_score_norm: + scores = (scores - self.running.mean.to(**tensor_to_kwargs)) / score_scaling_factor + else: + scores /= score_scaling_factor + + # Optionally clip the scores + if self.config.score_clip is not None: + scores_dtype = scores.dtype + scores = torch.clip(scores.float(), -self.config.score_clip, self.config.score_clip).to( + dtype=scores_dtype + ) + + # Prepare inputs for the agent and prior + model_inputs = self.prepare_model_inputs(queries, responses) + + # Pad inputs if distributed + if self.is_distributed: + pad_first = self.tokenizer.padding_side == "left" + + model_inputs["input_ids"] = self.accelerator.pad_across_processes( + model_inputs["input_ids"], + dim=1, + pad_index=self.tokenizer.pad_token_id, + pad_first=pad_first, + ) + model_inputs["attention_mask"] = self.accelerator.pad_across_processes( + model_inputs["attention_mask"], dim=1, pad_index=0, pad_first=pad_first + ) + if self.is_encoder_decoder: + model_inputs["decoder_input_ids"] = self.accelerator.pad_across_processes( + model_inputs["decoder_input_ids"], + dim=1, + pad_index=self.tokenizer.pad_token_id, + pad_first=pad_first, + ) + model_inputs["decoder_attention_mask"] = self.accelerator.pad_across_processes( + model_inputs["decoder_attention_mask"], + dim=1, + pad_index=0, + pad_first=pad_first, + ) + + # Compute log probabilities and entropies + with torch.no_grad(): + prior_log_probs = self.compute_log_probs(self.ref_model, model_inputs) + prior_log_probs = prior_log_probs.detach() + + # Prepare dataset for mini-batching + dataset = torch.utils.data.TensorDataset( + torch.arange(len(queries)), # Indices + ) + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=self.mini_batch_size, + shuffle=True, + ) + + # Initialize variables for logging + total_loss = 0.0 + + for epoch in range(self.reinvent_epochs): + for batch_indices in dataloader: + idx = batch_indices[0] + + [queries[i] for i in idx] + [responses[i] for i in idx] + mb_scores = scores[idx] + mb_model_inputs = {key: value[idx] for key, value in model_inputs.items()} + mb_prior_log_probs = prior_log_probs[idx] + + agent_log_probs, agent_logits = self.compute_log_probs( + self.model, mb_model_inputs, return_logits=True + ) + + # Compute entropies for the mini-batch + entropies = entropy_from_logits(agent_logits) + + # Adjust log_probs based on is_action_basis + mb_attention_mask = mb_model_inputs["attention_mask"][:, 1:].float() # Shifted mask + if not self.is_action_basis: + agent_log_probs = agent_log_probs * mb_attention_mask + mb_prior_log_probs = mb_prior_log_probs * mb_attention_mask + entropies = entropies * mb_attention_mask + + # Compute loss based on the selected strategy + loss = self.loss( + agent_log_probs, + mb_prior_log_probs, + mb_scores, + entropies, + mb_attention_mask, + ) + + # Backpropagation (rest of your code remains the same) + self.optimizer.zero_grad() + self.accelerator.backward(loss) + if self.config.max_grad_norm is 
not None and self.accelerator.sync_gradients: + self.accelerator.clip_grad_norm_(self.model_params, self.config.max_grad_norm) + self.optimizer.step() + if self.lr_scheduler is not None: + self.lr_scheduler.step() + + total_loss += loss.detach().cpu().item() + + # Update experience buffer if enabled + if self.use_experience_replay: + self.update_experience_buffer(queries, responses, scores) + + # Prepare batch dictionary for logging + batch = { + "query": [self.tokenizer.decode(q, skip_special_tokens=True) for q in queries], + "response": [self.tokenizer.decode(r, skip_special_tokens=True) for r in responses], + } + + # Log statistics + stats = { + "loss": total_loss / (self.reinvent_epochs * len(dataloader)), + "mean_score": scores.mean().detach().cpu().item(), + "mean_entropy": entropies.mean().detach().cpu().item(), + "lr": self.optimizer.param_groups[0]["lr"], + "batch_size": len(queries), + } + self.log_stats(stats, batch, scores) + + # Return statistics + return stats + + def log_stats( + self, + stats: dict, + batch: dict, + rewards: List[torch.FloatTensor], + columns_to_log: Optional[List[str]] = None, + ): + """ + A function that logs all the training stats. Call it at the end of each epoch. + + Args: + stats (dict[str, Any]): + A dictionary of training stats. + batch (dict[str, Any]): + A dictionary of batch data, this contains the queries and responses. + rewards (`List[torch.FloatTensor]`): + A tensor of rewards. + columns_to_log (Optional[List[str]], optional): + Columns to log from the batch. Defaults to ["query", "response"]. + """ + if columns_to_log is None: + columns_to_log = ["query", "response"] + + # Gather rewards across processes + if not isinstance(rewards, torch.Tensor): + rewards = torch.tensor(rewards).to(self.current_device) + rewards = self.accelerator.gather(rewards).flatten() + + # Prepare batch data for logging + if self.config.log_with == "wandb": + import wandb + + if any(column_to_log not in batch for column_to_log in columns_to_log): + raise ValueError( + f"Columns to log {columns_to_log} are not present in the batch {batch.keys()}." + ) + + batch_list = [batch[column_to_log] for column_to_log in columns_to_log] + if self.is_distributed: + gathered_batch_list = [] + for b in batch_list: + flattened = gather_object(b) + gathered_batch_list.append(flattened) + batch_list = gathered_batch_list + + # Log only if we are in the main process + if self.accelerator.is_main_process: + logs = {} + + # Log stats + if "query" not in batch and "response" not in batch: + # Warn the user that the game logs will not be logged + warnings.warn( + "The game logs will not be logged because the batch does not contain the keys 'query' and " + "'response'. 
" + ) + elif self.config.log_with == "wandb": + table_rows = [list(r) for r in zip(*batch_list, rewards.cpu().tolist())] + logs.update( + {"game_log": wandb.Table(columns=[*columns_to_log, "reward"], rows=table_rows)} + ) + + logs.update(stats) + + # Manually cast in fp32 for bf16 torch tensors + for k, v in logs.items(): + if isinstance(v, torch.Tensor) and v.dtype == torch.bfloat16: + logs[k] = v.float() + + logs["env/reward_mean"] = torch.mean(rewards).cpu().numpy().item() + logs["env/reward_std"] = torch.std(rewards).cpu().numpy().item() + logs["env/reward_dist"] = rewards.cpu().numpy() + + if self.config.log_with == "tensorboard": + # Update the current step + self.current_step += 1 + + self.accelerator.log( + logs, + step=self.current_step if self.config.log_with == "tensorboard" else None, + ) + + def prepare_model_inputs(self, queries: List[torch.Tensor], responses: List[torch.Tensor]): + if self.is_encoder_decoder: + input_data = self.data_collator( + [{"input_ids": q, "attention_mask": torch.ones_like(q)} for q in queries] + ).to(self.current_device) + + decoder_inputs = self.data_collator( + [{"input_ids": r, "attention_mask": torch.ones_like(r)} for r in responses] + ).to(self.current_device) + + input_data["decoder_input_ids"] = decoder_inputs["input_ids"] + input_data["decoder_attention_mask"] = decoder_inputs["attention_mask"] + else: + input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)] + input_data = self.data_collator( + [{"input_ids": ids, "attention_mask": torch.ones_like(ids)} for ids in input_ids] + ).to(self.current_device) + + input_data.pop("labels", None) # We don't want to compute LM losses + return input_data + + def compute_log_probs(self, model, inputs, return_logits=False): + outputs = model(**inputs) + logits = outputs.logits + + if self.is_encoder_decoder: + input_ids = inputs["decoder_input_ids"] + attention_mask = inputs["decoder_attention_mask"] + else: + input_ids = inputs["input_ids"] + attention_mask = inputs["attention_mask"] + + # Shift logits and input_ids for log-likelihood calculation + shifted_logits = logits[..., :-1, :].contiguous() + shifted_input_ids = input_ids[..., 1:].contiguous() + attention_mask = attention_mask[..., 1:].contiguous() + + # Compute log probabilities + log_probs = F.log_softmax(shifted_logits, dim=-1) + log_probs = log_probs.gather(-1, shifted_input_ids.unsqueeze(-1)).squeeze(-1) + + if return_logits: + return log_probs, shifted_logits + return log_probs + + def loss(self, agent_log_probs, prior_log_probs, scores, entropies, attention_mask): + strategy = self.strategy + + if strategy.is_dap(): + if strategy == REINVENTStrategy.DAP: + augmented_log_probs = prior_log_probs + self.sigma * scores.unsqueeze(1) + loss = (augmented_log_probs - agent_log_probs).pow(2) + + # strategy == REINVENTStrategy.SDAP + else: + augmented_log_probs = prior_log_probs + self.sigma * scores.unsqueeze(1) + reward = (augmented_log_probs - agent_log_probs).pow(2) + loss = -reward * agent_log_probs + + # Include entropy regularization + if self.config.entropy_coeff > 0: + loss = loss - self.config.entropy_coeff * entropies + + # Apply masking and reduction + loss = masked_mean(loss, attention_mask) if self.is_action_basis else loss.mean() + + else: + if strategy == REINVENTStrategy.MASCOF: + rewards_sum = scores.to(agent_log_probs.device) + + elif strategy == REINVENTStrategy.MAULI: + rewards = prior_log_probs + self.sigma * scores.unsqueeze(1) + rewards_sum = ( + (rewards * attention_mask).sum(dim=1) + if self.is_action_basis + 
else rewards.sum(dim=1) + ) + + if self.is_action_basis: + agent_log_probs_sum = (agent_log_probs * attention_mask).sum(dim=1) + entropies_sum = (entropies * attention_mask).sum(dim=1) + else: + agent_log_probs_sum = agent_log_probs.sum(dim=1) + entropies_sum = entropies.sum(dim=1) + + # Compute loss per sequence + loss = -rewards_sum * agent_log_probs_sum + + if self.config.entropy_coeff > 0: + loss = loss - self.config.entropy_coeff * entropies_sum + loss = loss.mean() + + return loss + + def update_experience_buffer(self, queries, responses, scores): + # Keep the highest-scoring experiences: heapq is a min-heap, so the lowest-scoring entry sits at the root + for q, r, s in zip(queries, responses, scores): + experience = (s.cpu().item(), q.cpu(), r.cpu()) + if len(self.experience_buffer) < self.max_buffer_size: + heapq.heappush(self.experience_buffer, experience) + else: + # Push the new experience and pop the lowest-scoring one + heapq.heappushpop(self.experience_buffer, experience) + + def sample_from_experience_buffer(self, batch_size: int): + """Sample examples from the experience buffer + + Args: + batch_size: Number of examples to sample from the buffer + """ + experiences = self.experience_buffer.copy() + queries, responses, scores = zip(*[(q, r, s) for s, q, r in experiences]) + indices = np.random.choice(len(queries), batch_size, replace=False) + sampled_queries = [queries[i] for i in indices] + sampled_responses = [responses[i] for i in indices] + sampled_scores = [scores[i] for i in indices] + return sampled_queries, sampled_responses, torch.tensor(sampled_scores) + + def _prepare_deepspeed(self, model: PreTrainedModelWrapper): + # Adapted from accelerate: https://github.com/huggingface/accelerate + deepspeed_plugin = self.accelerator.state.deepspeed_plugin + config_kwargs = deepspeed_plugin.deepspeed_config + if model is not None and hasattr(model, "config"): + hidden_size = ( + max(model.config.hidden_sizes) + if getattr(model.config, "hidden_sizes", None) + else getattr(model.config, "hidden_size", None) + ) + if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3: + config_kwargs.update( + { + "zero_optimization.reduce_bucket_size": hidden_size * hidden_size, + "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size, + "zero_optimization.stage3_prefetch_bucket_size": 0.9 + * hidden_size + * hidden_size, + } + ) + + # If ZeRO-3 is used, we shard both the active and reference model.
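+ # (ZeRO stage 3 partitions the parameters across devices, so the frozen prior must also go through + # deepspeed.initialize before it can be used to compute the reference log-probabilities in `step`)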
+ # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0) + if config_kwargs["zero_optimization"]["stage"] != 3: + config_kwargs["zero_optimization"]["stage"] = 0 + model, *_ = deepspeed.initialize(model=model, config=config_kwargs) + model.eval() + return model diff --git a/safe/optim/reinvent_config.py b/safe/optim/reinvent_config.py new file mode 100644 index 0000000..4b0be8e --- /dev/null +++ b/safe/optim/reinvent_config.py @@ -0,0 +1,163 @@ +import warnings +from dataclasses import dataclass, field +from typing import Optional, Dict, Any, Literal +import tyro +import json +from enum import Enum +from typing_extensions import Annotated +from transformers import is_wandb_available +from trl.core import flatten_dict + + +JSONDict = Annotated[Optional[dict], tyro.conf.arg(metavar="JSON", constructor=json.loads)] + + +class REINVENTStrategy(Enum): + """Strategies for REINVENT optimization.""" + + DAP = "dap" + MAULI = "mauli" + MASCOF = "mascof" + SDAP = "sdap" + + def is_dap(self): + """Check if the strategy is DAP or SDAP.""" + return self in (REINVENTStrategy.DAP, REINVENTStrategy.SDAP) + + +@dataclass +class REINVENTConfig: + """ + Configuration class for REINVENTTrainer. + + This class encapsulates all the configuration parameters required for training using the REINVENT algorithm. + It is a standalone class and does not inherit from any other configuration classes. + + Args: + # General Configuration + exp_name: Name of the experiment (used for logging/tracking purposes). Defaults to the name of the script being run. + seed: Random seed for reproducibility. Default=0 + log_with: Logging backend to use. Supported options are: ["wandb", "tensorboard"] + model_name: Name of the model (used for logging/tracking purposes). + reward_model: Name of the reward model (used for logging/tracking purposes). + remove_unused_columns: Whether to remove unused columns from the dataset. Default=True + + # Tracker and Accelerator Configuration + tracker_project_name: Name of the project for tracking. Default to safe-reinvent + tracker_kwargs: Additional keyword arguments for the tracker. + accelerator_kwargs: Additional keyword arguments for the Accelerator. + project_kwargs: Additional keyword arguments for the project configuration (`ProjectConfiguration`) of the accelerator. + + # Training Configuration + steps: Number of training steps. Default to 10000 + learning_rate: Learning rate for the optimizer. Default=1e-3 + batch_size: Number of samples per optimization step. Default to 128 + mini_batch_size: Number of samples optimized in each mini-batch. Default to 128 + gradient_accumulation_steps: Number of gradient accumulation steps. + max_grad_norm: Maximum gradient norm for clipping. Default to None for no gradient clipping + gradient_checkpointing: Whether to use gradient checkpointing. Default to False + optimize_device_cache: Optimize device cache for slightly more memory-efficient training. Default to False + + # REINVENT-Specific Parameters + sigma: Scaling factor for the score. Default to 60.0 + strategy: Strategy to use for optimization. One of ["dap", "sdap", "mauli", "mascof"]. Default to "dap" + entropy_coeff: Entropy regularization coefficient. Increasing the entropy regularization will change + the loss to promote preserving diversity as much as possible, which can however decrease performance. Default to 0 + is_action_basis: Whether to compute loss on an action (token) basis.
Default to False + use_experience_replay: Whether to use experience replay during training. Default to False + max_buffer_size: Maximum size of the experience replay buffer. Default to 10_000 + reinvent_epochs: Number of epochs per step (equivalent to PPO epochs). Default to 1 + score_clip: Value used to clip the scores to the range [-score_clip, +score_clip]. If `None`, no clipping is applied. + use_score_scaling: Whether to scale the scores. Default to False + use_score_norm: Whether to normalize the scores when scaling is used. Default to True + + Attributes: + world_size: Number of processes to use for distributed training. Set by REINVENTTrainer. + global_batch_size: Effective batch size across all processes. Set by REINVENTTrainer. + is_encoder_decoder: Whether the model is an encoder-decoder model. Set by REINVENTTrainer. + is_peft_model: Whether the model is a PEFT (Parameter-Efficient Fine-Tuning) model. Set by REINVENTTrainer. + """ + + # General Configuration + exp_name: Optional[str] = None # Will default to script name if not provided + seed: int = 0 + log_with: Optional[Literal["wandb", "tensorboard"]] = None + model_name: str = "gpt2" + reward_model: Optional[str] = None + remove_unused_columns: bool = True + + # Tracker and Accelerator Configuration + tracker_project_name: str = "safe-reinvent" + tracker_kwargs: Dict[str, Any] = field(default_factory=dict) + accelerator_kwargs: Dict[str, Any] = field(default_factory=dict) + project_kwargs: Dict[str, Any] = field(default_factory=dict) + + # Training Configuration + steps: int = 10000 + learning_rate: float = 1e-3 + batch_size: int = 128 + mini_batch_size: int = 128 + gradient_accumulation_steps: int = 1 + max_grad_norm: Optional[float] = None + gradient_checkpointing: bool = False + optimize_device_cache: bool = False + + # REINVENT-Specific Parameters + sigma: float = 60.0 + strategy: Literal["dap", "sdap", "mauli", "mascof"] = "dap" + entropy_coeff: Optional[float] = None + is_action_basis: bool = False + use_experience_replay: bool = False + max_buffer_size: int = 10000 + reinvent_epochs: int = 1 + score_clip: Optional[float] = None + use_score_scaling: bool = False + use_score_norm: bool = True + + # Internal attributes set by the trainer + world_size: Optional[int] = None + global_batch_size: Optional[int] = None + is_encoder_decoder: Optional[bool] = None + is_peft_model: Optional[bool] = None + + def __post_init__(self): + # Default exp_name to script name if not provided + if self.exp_name is None: + import os + import sys + + self.exp_name = os.path.basename(sys.argv[0])[: -len(".py")] + + if self.entropy_coeff is None: + self.entropy_coeff = 0.0 + + supported_strategies = [strategy.value for strategy in REINVENTStrategy] + if self.strategy not in supported_strategies: + raise ValueError( + f"Strategy needs to be one of {supported_strategies}, got '{self.strategy}'" + ) + + if self.batch_size % self.mini_batch_size != 0: + raise ValueError("`batch_size` must be a multiple of `mini_batch_size`.") + + if self.use_score_scaling and self.score_clip is None: + warnings.warn( + "use_score_scaling is True but score_clip is None. Scores will not be clipped." + ) + + # Check if wandb is installed if logging with wandb + if self.log_with == "wandb" and not is_wandb_available(): + raise ImportError( + "Please install wandb to use wandb logging. You can do this by running `pip install wandb`."
+ ) + + if self.log_with is not None: + self.tracker_kwargs.setdefault(self.log_with, {})["name"] = self.exp_name + + def to_dict(self) -> Dict[str, Any]: + """ + Convert the configuration to a flattened dictionary. + + Returns: + Dict[str, Any]: Flattened dictionary of configuration parameters. + """ + return flatten_dict(self.__dict__) diff --git a/tests/test_notebooks.py b/tests/test_notebooks.py index 355c795..afd0e84 100644 --- a/tests/test_notebooks.py +++ b/tests/test_notebooks.py @@ -12,7 +12,11 @@ NOTEBOOK_PATHS = list(filter(lambda x: x.name not in DISABLE_NOTEBOOKS, NOTEBOOK_PATHS)) # Discard some notebooks -NOTEBOOKS_TO_DISCARD = ["extracting-representation-molfeat.ipynb", "load-from-wandb.ipynb"] +NOTEBOOKS_TO_DISCARD = [ + "extracting-representation-molfeat.ipynb", + "load-from-wandb.ipynb", + "optim-reinvent.ipynb", +] NOTEBOOK_PATHS = list(filter(lambda x: x.name not in NOTEBOOKS_TO_DISCARD, NOTEBOOK_PATHS))