```python
samples = samples.unsqueeze(4)  # B, 3, H, W, 1
samples = samples.expand(-1, -1, -1, -1, self.num_mix).permute(0, 1, 4, 2, 3)  # B, 3, M, H, W

mean1 = self.means[:, 0, :, :, :]  # B, M, H, W
mean2 = self.means[:, 1, :, :, :] + \
        self.coeffs[:, 0, :, :, :] * samples[:, 0, :, :, :]  # B, M, H, W
mean3 = self.means[:, 2, :, :, :] + \
        self.coeffs[:, 1, :, :, :] * samples[:, 0, :, :, :] + \
        self.coeffs[:, 2, :, :, :] * samples[:, 1, :, :, :]  # B, M, H, W

mean1 = mean1.unsqueeze(1)  # B, 1, M, H, W
mean2 = mean2.unsqueeze(1)  # B, 1, M, H, W
mean3 = mean3.unsqueeze(1)  # B, 1, M, H, W
means = torch.cat([mean1, mean2, mean3], dim=1)  # B, 3, M, H, W
centered = samples - means  # B, 3, M, H, W
```
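For context, the snippet's shape bookkeeping can be checked standalone with random tensors. This is a minimal sketch, not the NVAE code itself: the sizes `B=2`, `M=3`, `H=W=4` are arbitrary, and `torch.stack(..., dim=1)` is used in place of the snippet's `unsqueeze(1)` + `torch.cat` (they produce the same result).

```python
import torch

# Hypothetical sizes: B images, M mixture components, H x W pixels.
B, M, H, W = 2, 3, 4, 4

samples = torch.rand(B, 3, H, W)                 # B, 3, H, W
means = torch.rand(B, 3, M, H, W)                # B, 3, M, H, W
coeffs = torch.tanh(torch.rand(B, 3, M, H, W))   # B, 3, M, H, W

# Broadcast the samples across the mixture dimension: B, 3, M, H, W
s = samples.unsqueeze(4).expand(-1, -1, -1, -1, M).permute(0, 1, 4, 2, 3)

# Channel-autoregressive means: G is shifted by the R sample,
# B is shifted by the R and G samples.
mean1 = means[:, 0]                                 # R channel
mean2 = means[:, 1] + coeffs[:, 0] * s[:, 0]        # G channel
mean3 = means[:, 2] + coeffs[:, 1] * s[:, 0] \
                    + coeffs[:, 2] * s[:, 1]        # B channel

stacked = torch.stack([mean1, mean2, mean3], dim=1)  # B, 3, M, H, W
print(stacked.shape)  # torch.Size([2, 3, 3, 4, 4])
```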
Hello! Can anyone explain one thing to me: when computing `mean3` for the 3rd channel (blue, I suppose), why don't we use `samples[:, 2, :, :, :]`? The snippet above is from `NVAE/distributions.py`, lines 139 to 152 at commit 9fc1a28.

Also, why do we need to update the means with the samples when computing the log prob? For example, in the Tacotron-2 code there are no such updates of the means.