diff --git a/src/diffwave/inference.py b/src/diffwave/inference.py index 57a0b0f..ce24c14 100644 --- a/src/diffwave/inference.py +++ b/src/diffwave/inference.py @@ -72,6 +72,7 @@ def predict(spectrogram=None, model_dir=None, params=None, device=torch.device(' if len(spectrogram.shape) == 2:# Expand rank 2 tensors by adding a batch dimension. spectrogram = spectrogram.unsqueeze(0) spectrogram = spectrogram.to(device) + spectrogram = model.spectrogram_upsampler(spectrogram) audio = torch.randn(spectrogram.shape[0], model.params.hop_samples * spectrogram.shape[-1], device=device) else: audio = torch.randn(1, params.audio_len, device=device) @@ -80,7 +81,7 @@ def predict(spectrogram=None, model_dir=None, params=None, device=torch.device(' for n in range(len(alpha) - 1, -1, -1): c1 = 1 / alpha[n]**0.5 c2 = beta[n] / (1 - alpha_cum[n])**0.5 - audio = c1 * (audio - c2 * model(audio, torch.tensor([T[n]], device=audio.device), spectrogram).squeeze(1)) + audio = c1 * (audio - c2 * model(audio, torch.tensor([T[n]], device=audio.device), spectrogram, infer=True).squeeze(1)) if n > 0: noise = torch.randn_like(audio) sigma = ((1.0 - alpha_cum[n-1]) / (1.0 - alpha_cum[n]) * beta[n])**0.5 diff --git a/src/diffwave/model.py b/src/diffwave/model.py index 58485e4..9257e94 100644 --- a/src/diffwave/model.py +++ b/src/diffwave/model.py @@ -142,7 +142,7 @@ def __init__(self, params): self.output_projection = Conv1d(params.residual_channels, 1, 1) nn.init.zeros_(self.output_projection.weight) - def forward(self, audio, diffusion_step, spectrogram=None): + def forward(self, audio, diffusion_step, spectrogram=None, infer=False): assert (spectrogram is None and self.spectrogram_upsampler is None) or \ (spectrogram is not None and self.spectrogram_upsampler is not None) x = audio.unsqueeze(1) @@ -150,8 +150,9 @@ def forward(self, audio, diffusion_step, spectrogram=None): x = F.relu(x) diffusion_step = self.diffusion_embedding(diffusion_step) - if self.spectrogram_upsampler: # use conditional model - spectrogram = self.spectrogram_upsampler(spectrogram) + if not infer: + if self.spectrogram_upsampler: # use conditional model + spectrogram = self.spectrogram_upsampler(spectrogram) skip = None for layer in self.residual_layers: