
[magpietts] decoder CE model type #13727


Merged
30 changes: 25 additions & 5 deletions nemo/collections/tts/models/magpietts.py
@@ -128,7 +128,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):

self.model_type = cfg.get('model_type', None)

self.pad_context_text_to_max_duration = self.model_type == 'decoder_context_tts'
self.pad_context_text_to_max_duration = self.model_type in ['decoder_context_tts', 'decoder_ce']
self.use_kv_cache_for_inference = cfg.get('use_kv_cache_for_inference', False)

super().__init__(cfg=cfg, trainer=trainer)
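
For orientation only (not part of the PR), a decoder_ce run would presumably be configured with the same fields the constructor reads here: model_type, use_kv_cache_for_inference, and the encoder/decoder/context_encoder sections. A rough sketch with placeholder values:

from omegaconf import OmegaConf

# Hypothetical config: the keys mirror the cfg accesses visible in this diff,
# the values are illustrative placeholders, not defaults from the repository.
cfg = OmegaConf.create({
    'model_type': 'decoder_ce',              # new value introduced by this PR
    'use_kv_cache_for_inference': False,     # read via cfg.get(...)
    'encoder': {'d_model': 768, 'n_layers': 6},          # transcript encoder
    'decoder': {'d_model': 768, 'n_layers': 12},         # all layers attend to text
    'context_encoder': {'d_model': 768, 'n_layers': 3},  # encodes context audio/text
})

# Matches the changed line above: decoder_ce also pads context text to max duration.
assert cfg.model_type in ['decoder_context_tts', 'decoder_ce']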
@@ -173,6 +173,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
self.encoder = transformer_2501.Transformer(**dict(cfg.encoder))

self.decoder = transformer_2501.Transformer(**dict(cfg.decoder))

self.final_proj = nn.Linear(cfg.decoder.d_model, self.num_audio_codebooks * self.num_all_tokens_per_codebook)

self.local_transformer_type = LocalTransformerType(cfg.get('local_transformer_type', 'none').lower())
@@ -208,6 +209,8 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
)

if self.model_type == 'single_encoder_sv_tts':
# Context audio goes through Titanet to get speaker embedding
# Speaker embedding is added to the transcript encoder output
self._speaker_verification_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(
model_name='titanet_large'
)
@@ -217,6 +220,8 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
idx for idx in range(self.decoder.n_layers)
] # All layers are used for text
elif self.model_type == 'multi_encoder_context_tts':
# Transcript and context audio/text go to different encoders.
# Output of the encoders goes to the decoder through the cross-attention layers
self.transcript_decoder_layers = cfg.get('transcript_decoder_layers', [3, 4, 5, 6, 7, 8])
self.context_decoder_layers = cfg.get(
'context_decoder_layers', [0, 1, 2, 9, 10, 11]
@@ -229,10 +234,20 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
self.multi_encoder_mapping = multi_encoder_mapping
self.context_encoder = transformer_2501.Transformer(**dict(cfg.context_encoder))
elif self.model_type == 'decoder_context_tts':
# Context audio/text goes directly to the decoder (before the target audio codes)
self.transcript_decoder_layers = [
idx for idx in range(self.decoder.n_layers)
] # All layers are used for text
elif self.model_type == 'decoder_ce':
# Similar to decoder_context_tts, but with a context encoder:
# the decoder gets the context encoder's output instead of the raw context token embeddings
self.context_encoder = transformer_2501.Transformer(**dict(cfg.context_encoder))
self.transcript_decoder_layers = [
idx for idx in range(cfg.decoder.n_layers)
] # All layers are used for text

elif self.model_type == 'decoder_pretrain_synthesizer':
# This is for pretraining the decoder only on audio data using next frame prediction loss
assert cfg.alignment_loss_scale == 0.0, "Alignment loss is not supported for decoder pretrain synthesizer"
else:
raise ValueError(f"Unsupported model type {self.model_type}")
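
In short, the difference between decoder_context_tts and decoder_ce is one extra hop: decoder_ce passes the embedded context through a dedicated context encoder and hands the decoder that encoder's output, rather than the raw context token embeddings. A minimal, self-contained sketch of that idea (not the PR's code; nn.TransformerEncoder stands in for transformer_2501.Transformer, and all shapes are made up):

import torch
import torch.nn as nn
from typing import Optional

def build_decoder_context(context_emb: torch.Tensor,
                          model_type: str,
                          context_encoder: Optional[nn.Module] = None) -> torch.Tensor:
    # Returns the context embeddings that are placed in front of the target
    # audio codes in the decoder input; shape (batch, context_frames, d_model).
    if model_type == 'decoder_context_tts':
        return context_emb                    # raw context token embeddings
    if model_type == 'decoder_ce':
        return context_encoder(context_emb)   # context refined by a separate encoder
    raise ValueError(f"Unsupported model type {model_type}")

# Stand-in for the real context encoder, used only for this sketch.
toy_context_encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True), num_layers=3
)
context_emb = torch.randn(2, 40, 512)         # (batch, context frames, d_model)
dec_context = build_decoder_context(context_emb, 'decoder_ce', toy_context_encoder)
print(dec_context.shape)                      # torch.Size([2, 40, 512])

Unlike multi_encoder_context_tts, no cross-attention routing is involved: the encoded context simply precedes the target audio codes in the decoder input, which is why transcript_decoder_layers still spans every decoder layer.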
@@ -888,7 +903,7 @@ def prepare_context_tensors(self, batch):
text_lens = None

# self.model_type must be one of
# [single_encoder_sv_tts, multi_encoder_context_tts, decoder_context_tts, decoder_pretrain_synthesizer]
# [single_encoder_sv_tts, multi_encoder_context_tts, decoder_context_tts, decoder_ce, decoder_pretrain_synthesizer]
if self.model_type != 'decoder_pretrain_synthesizer':
text = batch['text']
text_lens = batch['text_lens']
@@ -907,7 +922,7 @@ def prepare_context_tensors(self, batch):
cond_mask = text_mask
multi_encoder_mapping = None
attn_prior = _attn_prior
elif self.model_type in ['multi_encoder_context_tts', 'decoder_context_tts']:
elif self.model_type in ['multi_encoder_context_tts', 'decoder_context_tts', 'decoder_ce']:
if 'context_audio_codes' in batch:
context_audio_codes = batch['context_audio_codes']
context_audio_codes_lens = batch['context_audio_codes_lens']
@@ -961,9 +976,14 @@ def prepare_context_tensors(self, batch):
multi_encoder_mapping = self.multi_encoder_mapping
attn_prior = [_attn_prior, None]

elif self.model_type == 'decoder_context_tts':
elif self.model_type in ['decoder_context_tts', 'decoder_ce']:
dec_context_size = context_mask.size(1)
context_embeddings = context_input_embedded
if self.model_type == 'decoder_context_tts':
context_embeddings = context_input_embedded
elif self.model_type == 'decoder_ce':
context_embeddings = self.context_encoder(
context_input_embedded, context_mask, cond=None, cond_mask=None
)['output']
attn_prior = _attn_prior
if attn_prior is not None:
# B, audio_timesteps, text_timesteps
5 changes: 5 additions & 0 deletions scripts/magpietts/infer_and_evaluate.py
@@ -63,6 +63,11 @@ def update_config(model_cfg, codecmodel_path, legacy_codebooks=False):
if hasattr(model_cfg, 'decoder') and hasattr(model_cfg.decoder, 'prior_eps'):
# Added to prevent crash after removing arg from transformer_2501.py in https://github.com/blisc/NeMo/pull/56
del model_cfg.decoder.prior_eps
if hasattr(model_cfg, 'use_local_transformer') and model_cfg.use_local_transformer:
# For older checkpoints trained with a different parameter name
model_cfg.local_transformer_type = "autoregressive"
del model_cfg.use_local_transformer

if legacy_codebooks:
# Added to address backward compatibility arising from
# https://github.com/blisc/NeMo/pull/64
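
The infer_and_evaluate.py hunk follows the same backward-compatibility pattern already used for prior_eps: probe the restored config for a retired key, translate it into the current one, then delete it. A standalone sketch of that behaviour on a bare OmegaConf config (the surrounding checkpoint-loading code is assumed, not shown in the PR):

from omegaconf import OmegaConf

def migrate_local_transformer_flag(model_cfg):
    # Mirrors the new lines in update_config(): map the retired boolean
    # use_local_transformer onto the newer local_transformer_type string.
    if hasattr(model_cfg, 'use_local_transformer') and model_cfg.use_local_transformer:
        model_cfg.local_transformer_type = 'autoregressive'
        del model_cfg.use_local_transformer
    return model_cfg

old_cfg = OmegaConf.create({'model_type': 'decoder_ce', 'use_local_transformer': True})
print(OmegaConf.to_yaml(migrate_local_transformer_flag(old_cfg)))
# model_type: decoder_ce
# local_transformer_type: autoregressive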