-
Notifications
You must be signed in to change notification settings - Fork 284
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update conditioners.py #176
Open
RoyalCities
wants to merge
1
commit into
Stability-AI:main
Choose a base branch
from
RoyalCities:fix/t5-base-tokenizer-config
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -275,10 +275,9 @@ def forward(self, audios: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple | |
return [self.proj_out(audio_embedding), torch.ones(audio_embedding.shape[0], 1).to(device)] | ||
|
||
class T5Conditioner(Conditioner): | ||
|
||
T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b", | ||
"google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large", | ||
"google/flan-t5-xl", "google/flan-t5-xxl", "google/t5-v1_1-xl", "google/t5-v1_1-xxl"] | ||
"google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large", | ||
"google/flan-t5-xl", "google/flan-t5-xxl", "google/t5-v1_1-xl", "google/t5-v1_1-xxl"] | ||
|
||
T5_MODEL_DIMS = { | ||
"t5-small": 512, | ||
|
@@ -301,14 +300,14 @@ def __init__( | |
self, | ||
output_dim: int, | ||
t5_model_name: str = "t5-base", | ||
max_length: str = 128, | ||
max_length: int = 128, # Changed from str to int | ||
enable_grad: bool = False, | ||
project_out: bool = False | ||
): | ||
assert t5_model_name in self.T5_MODELS, f"Unknown T5 model name: {t5_model_name}" | ||
super().__init__(self.T5_MODEL_DIMS[t5_model_name], output_dim, project_out=project_out) | ||
|
||
from transformers import T5EncoderModel, AutoTokenizer | ||
from transformers import T5EncoderModel, T5Tokenizer # Changed to T5Tokenizer | ||
|
||
self.max_length = max_length | ||
self.enable_grad = enable_grad | ||
|
@@ -319,10 +318,26 @@ def __init__( | |
with warnings.catch_warnings(): | ||
warnings.simplefilter("ignore") | ||
try: | ||
# self.tokenizer = T5Tokenizer.from_pretrained(t5_model_name, model_max_length = max_length) | ||
# model = T5EncoderModel.from_pretrained(t5_model_name, max_length=max_length).train(enable_grad).requires_grad_(enable_grad) | ||
self.tokenizer = AutoTokenizer.from_pretrained(t5_model_name) | ||
model = T5EncoderModel.from_pretrained(t5_model_name).train(enable_grad).requires_grad_(enable_grad).to(torch.float16) | ||
# Explicit tokenizer configuration with legacy support | ||
self.tokenizer = T5Tokenizer.from_pretrained( | ||
t5_model_name, | ||
model_max_length=max_length, | ||
bos_token="<s>", | ||
eos_token="</s>", | ||
unk_token="<unk>", | ||
pad_token="<pad>", | ||
use_auth_token=False, # Disable auth checks | ||
legacy=True # Handle older T5 models | ||
) | ||
|
||
# Model initialization with conditional precision | ||
model = T5EncoderModel.from_pretrained(t5_model_name) | ||
model = model.train(enable_grad).requires_grad_(enable_grad) | ||
if enable_grad: | ||
model = model.to(torch.float16) | ||
else: | ||
model = model.to(torch.float32) | ||
|
||
finally: | ||
logging.disable(previous_level) | ||
|
||
|
@@ -331,9 +346,7 @@ def __init__( | |
else: | ||
self.__dict__["model"] = model | ||
|
||
|
||
def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: | ||
|
||
self.model.to(device) | ||
self.proj_out.to(device) | ||
|
||
|
@@ -346,25 +359,22 @@ def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> t | |
) | ||
|
||
input_ids = encoded["input_ids"].to(device) | ||
attention_mask = encoded["attention_mask"].to(device).to(torch.bool) | ||
attention_mask = encoded["attention_mask"].to(device).bool() | ||
|
||
self.model.eval() | ||
|
||
with torch.cuda.amp.autocast(dtype=torch.float16) and torch.set_grad_enabled(self.enable_grad): | ||
with torch.cuda.amp.autocast(enabled=self.enable_grad, dtype=torch.float16 if self.enable_grad else torch.float32): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why would the autocast change based on the grad being enabled? I haven't used it in a long time, but I believe enable_grad was about fine-tuning the T5 model, shouldn't be related to autocasting. |
||
embeddings = self.model( | ||
input_ids=input_ids, attention_mask=attention_mask | ||
)["last_hidden_state"] | ||
input_ids=input_ids, | ||
attention_mask=attention_mask | ||
).last_hidden_state | ||
|
||
# Cast embeddings to same type as proj_out, unless proj_out is Identity | ||
if not isinstance(self.proj_out, nn.Identity): | ||
proj_out_dtype = next(self.proj_out.parameters()).dtype | ||
embeddings = embeddings.to(proj_out_dtype) | ||
|
||
embeddings = self.proj_out(embeddings) | ||
embeddings = embeddings.to(next(self.proj_out.parameters()).dtype) | ||
embeddings = self.proj_out(embeddings) | ||
|
||
embeddings = embeddings * attention_mask.unsqueeze(-1).float() | ||
return embeddings * attention_mask.unsqueeze(-1).float(), attention_mask | ||
|
||
return embeddings, attention_mask | ||
|
||
class PhonemeConditioner(Conditioner): | ||
""" | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's this change for?