From 42ce824cf68465e8bf7dc2f50b216d369da0bbcd Mon Sep 17 00:00:00 2001 From: Ilya Borovik Date: Tue, 9 Feb 2021 13:21:53 +0300 Subject: [PATCH] model.py: fixing inputs passed to GST + modules.py: checking input dimensions in ReferenceEncoder --- model.py | 2 +- modules.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/model.py b/model.py index db842267b..8ad0aefcc 100755 --- a/model.py +++ b/model.py @@ -601,7 +601,7 @@ def forward(self, inputs): embedded_inputs = self.embedding(inputs).transpose(1, 2) embedded_text = self.encoder(embedded_inputs, input_lengths) embedded_speakers = self.speaker_embedding(speaker_ids)[:, None] - embedded_gst = self.gst(targets, output_lengths) + embedded_gst = self.gst(targets.transpose(1, 2), output_lengths) embedded_gst = embedded_gst.repeat(1, embedded_text.size(1), 1) embedded_speakers = embedded_speakers.repeat(1, embedded_text.size(1), 1) diff --git a/modules.py b/modules.py index 736c3d586..61e2ecc73 100755 --- a/modules.py +++ b/modules.py @@ -58,6 +58,7 @@ def __init__(self, hp): self.ref_enc_gru_size = hp.ref_enc_gru_size def forward(self, inputs, input_lengths=None): + assert inputs.size(-1) % self.n_mel_channels == 0 out = inputs.view(inputs.size(0), 1, -1, self.n_mel_channels) for conv, bn in zip(self.convs, self.bns): out = conv(out)