Why is the output for the 3rd channel unused in the logistic mixture? #46

@univanxx

Hello! Can anyone explain one thing to me: when computing mean3 for the third channel (blue, I suppose), why don't we use samples[:, 2, :, :, :]?

NVAE/distributions.py

Lines 139 to 152 in 9fc1a28

samples = samples.unsqueeze(4) # B, 3, H, W, 1
samples = samples.expand(-1, -1, -1, -1, self.num_mix).permute(0, 1, 4, 2, 3) # B, 3, M, H, W
mean1 = self.means[:, 0, :, :, :] # B, M, H, W
mean2 = self.means[:, 1, :, :, :] + \
        self.coeffs[:, 0, :, :, :] * samples[:, 0, :, :, :] # B, M, H, W
mean3 = self.means[:, 2, :, :, :] + \
        self.coeffs[:, 1, :, :, :] * samples[:, 0, :, :, :] + \
        self.coeffs[:, 2, :, :, :] * samples[:, 1, :, :, :] # B, M, H, W
mean1 = mean1.unsqueeze(1) # B, 1, M, H, W
mean2 = mean2.unsqueeze(1) # B, 1, M, H, W
mean3 = mean3.unsqueeze(1) # B, 1, M, H, W
means = torch.cat([mean1, mean2, mean3], dim=1) # B, 3, M, H, W
centered = samples - means # B, 3, M, H, W
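
To make the shapes and data flow in the excerpt above easier to check, here is a minimal standalone sketch of the same computation. The function name `conditional_means`, the random inputs, and the tensor sizes are assumptions made for illustration only; they are not taken from the NVAE repository. The sketch suggests that samples[:, 2] does not feed into any mean, but it does enter the result through the final subtraction.

```python
import torch

def conditional_means(samples, means, coeffs):
    """Hypothetical helper mirroring the excerpt.

    samples: (B, 3, H, W); means and coeffs: (B, 3, M, H, W).
    """
    num_mix = means.size(2)
    # Repeat each pixel value across the M mixture components.
    s = samples.unsqueeze(4)                                      # B, 3, H, W, 1
    s = s.expand(-1, -1, -1, -1, num_mix).permute(0, 1, 4, 2, 3)  # B, 3, M, H, W
    mean_r = means[:, 0]                                          # no conditioning
    mean_g = means[:, 1] + coeffs[:, 0] * s[:, 0]                 # depends on red
    mean_b = (means[:, 2]
              + coeffs[:, 1] * s[:, 0]                            # depends on red
              + coeffs[:, 2] * s[:, 1])                           # and on green
    cond = torch.stack([mean_r, mean_g, mean_b], dim=1)           # B, 3, M, H, W
    centered = s - cond                                           # blue sample s[:, 2] is used here
    return cond, centered

# Illustrative sizes and random inputs (assumptions, not values from the repo).
torch.manual_seed(0)
B, M, H, W = 2, 4, 5, 5
samples = torch.rand(B, 3, H, W) * 2 - 1
means = torch.randn(B, 3, M, H, W)
coeffs = torch.randn(B, 3, M, H, W)
cond, centered = conditional_means(samples, means, coeffs)

# Perturb only the blue channel and recompute.
samples_shifted = samples.clone()
samples_shifted[:, 2] += 0.5
cond_shifted, centered_shifted = conditional_means(samples_shifted, means, coeffs)
```

Comparing `cond` with `cond_shifted` and `centered` with `centered_shifted` shows which outputs react to a change in the blue channel alone.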

Also, why do we need to update the means with samples when computing the log prob? For example, in the Tacotron-2 code there are no such updates of the means.
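
For context, one plausible reading of the excerpt, assuming it follows the channel-wise autoregressive parameterization used by PixelCNN++-style discretized logistic mixtures, can be written as below. The symbols are my own notation and do not come from the repository: r, g, b are the channel values of one pixel, m_0, m_1, m_2 stand for self.means[:, 0..2], and c_0, c_1, c_2 stand for self.coeffs[:, 0..2], all for a single mixture component.

```latex
\begin{aligned}
p(r, g, b) &= p(r)\, p(g \mid r)\, p(b \mid r, g) \\
\mu_r &= m_0 \\
\mu_g &= m_1 + c_0\, r \\
\mu_b &= m_2 + c_1\, r + c_2\, g
\end{aligned}
```

Under this reading, b never appears on the right-hand side because no channel comes after blue to be conditioned on it; it is still evaluated against \mu_b when the centered values are formed. The sample-dependent means would then be what ties the three channels together, and a model with a single output channel would have no such cross-channel terms.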
