505 changes: 505 additions & 0 deletions docs/tutorials/optim-reinvent.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions env.yml
@@ -15,6 +15,7 @@ dependencies:
- numpy
- pytorch >=2.0
- transformers
- trl
- datasets
- tokenizers
- accelerate >=0.33 # for accelerator_config update
4 changes: 3 additions & 1 deletion mkdocs.yml
@@ -16,7 +16,9 @@ nav:
- Getting Started: tutorials/getting-started.ipynb
- Molecular design: tutorials/design-with-safe.ipynb
- How it works: tutorials/how-it-works.ipynb
- WANDB support: tutorials/load-from-wandb.ipynb
- Extracting representation (molfeat): tutorials/extracting-representation-molfeat.ipynb
- Optimization with REINVENT: tutorials/optim-reinvent.ipynb
- API:
- SAFE: api/safe.md
- Visualization: api/safe.viz.md
@@ -32,7 +34,7 @@ theme:

extra_javascript:
- assets/js/google-analytics.js

markdown_extensions:
- admonition
- markdown_include.include
3 changes: 3 additions & 0 deletions safe/optim/__init__.py
@@ -0,0 +1,3 @@
from .reinvent_config import REINVENTConfig
from .reinvent import REINVENTTrainer
from ._utils import AutoModelForCausalLM
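For orientation, the new safe.optim package re-exports its public API at the package level, so downstream code needs a single import. A minimal sketch follows (the REINVENTConfig and REINVENTTrainer implementations live in reinvent_config.py and reinvent.py, which are not shown in this excerpt):

# Everything needed for REINVENT-style fine-tuning is importable from safe.optim:
from safe.optim import REINVENTConfig, REINVENTTrainer, AutoModelForCausalLM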
69 changes: 69 additions & 0 deletions safe/optim/_utils.py
@@ -0,0 +1,69 @@
from trl import PreTrainedModelWrapper


class AutoModelForCausalLM(PreTrainedModelWrapper):
    """A thin TRL wrapper around a `transformers` causal language model.

    Unlike TRL's value-head wrappers, no extra head is added: `forward` and
    `generate` delegate directly to the wrapped model.
    """

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        return_past_key_values=False,
        **kwargs,
    ):
        r"""
        Applies a forward pass to the wrapped model and returns its output.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary.
            past_key_values (`tuple(tuple(torch.FloatTensor))`, `optional`):
                Contains pre-computed hidden states (keys and values in the attention blocks) as computed by
                the model (see the `past_key_values` input) to speed up sequential decoding.
            attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`):
                Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
            return_past_key_values (`bool`):
                Accepted for signature compatibility; this wrapper does not use it itself.
            kwargs (`dict`, `optional`):
                Additional keyword arguments that are passed to the wrapped model.
        """
        kwargs["output_hidden_states"] = True  # this had already been set in the LoRA / PEFT examples
        kwargs["past_key_values"] = past_key_values

        # PEFT prefix tuning supplies its own past key values, so drop the explicit argument.
        if (
            self.is_peft_model
            and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING"
        ):
            kwargs.pop("past_key_values")

        return self.pretrained_model.forward(
            input_ids=input_ids, attention_mask=attention_mask, **kwargs
        )

    def generate(self, *args, **kwargs):
        r"""
        A simple wrapper around the `generate` method of the wrapped model.
        Please refer to the [`generate`](https://huggingface.co/docs/transformers/internal/generation_utils)
        method of the wrapped model for more information about the supported arguments.

        Args:
            *args (`list`, *optional*):
                Positional arguments passed to the `generate` method of the wrapped model.
            **kwargs (`dict`, *optional*):
                Keyword arguments passed to the `generate` method of the wrapped model.
        """
        return self.pretrained_model.generate(*args, **kwargs)

    def __getattr__(self, name: str):
        """Forward missing attributes to the wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            if name == "pretrained_model":  # see #1892: prevent infinite recursion if class is not initialized
                raise
            return getattr(self.pretrained_model, name)
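
A minimal usage sketch of the wrapper, for context: it assumes the wrapper is built around an already-loaded `transformers` causal LM (TRL's `PreTrainedModelWrapper` accepts the base model as its first constructor argument), and the checkpoint name is a placeholder; the REINVENT trainer added in this PR may construct and drive the model differently.

import torch
from transformers import AutoModelForCausalLM as HFAutoModelForCausalLM

from safe.optim import AutoModelForCausalLM

# Placeholder checkpoint; substitute the causal LM you actually want to fine-tune.
base_model = HFAutoModelForCausalLM.from_pretrained("gpt2")

# Wrap the base model so TRL-style training code can drive it.
model = AutoModelForCausalLM(base_model)

# `generate` is forwarded untouched to the wrapped model (see the wrapper above).
prompt_ids = torch.tensor([[0, 1, 2, 3]])
samples = model.generate(prompt_ids, max_new_tokens=16, do_sample=True)
print(samples.shape)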