diff --git a/mmgen/models/architectures/dalle_mini/__init__.py b/mmgen/models/architectures/dalle_mini/__init__.py
new file mode 100644
index 000000000..1b5da56d5
--- /dev/null
+++ b/mmgen/models/architectures/dalle_mini/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .modules import BartDecoderLayer, BartEncoderLayer
+
+__all__ = ['BartDecoderLayer', 'BartEncoderLayer']
diff --git a/mmgen/models/architectures/dalle_mini/modules.py b/mmgen/models/architectures/dalle_mini/modules.py
new file mode 100644
index 000000000..d108bb8f5
--- /dev/null
+++ b/mmgen/models/architectures/dalle_mini/modules.py
@@ -0,0 +1,233 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from mmcv.cnn.bricks import Linear, build_activation_layer, build_norm_layer
+from mmgen.registry import MODULES
+
+
+@MODULES.register_module()
+class GLU(nn.Module):
+    """GLU variant feed-forward block ("GLU Variants Improve Transformer").
+
+    Args:
+        in_out_channels (int): The channel number of the input
+            and the output feature map.
+        mid_channels (int): The channel number of the middle layer feature map.
+    """
+
+    def __init__(self, in_out_channels, mid_channels):
+        super().__init__()
+        _, self.norm1 = build_norm_layer(dict(type='LN'), in_out_channels)
+        _, self.norm2 = build_norm_layer(dict(type='LN'), mid_channels)
+        self.fc1 = Linear(in_out_channels, mid_channels, bias=False)
+        self.fc2 = Linear(in_out_channels, mid_channels, bias=False)
+        self.fc3 = Linear(mid_channels, in_out_channels, bias=False)
+        self.gelu = build_activation_layer(dict(type='GELU'))
+
+    def forward(self, z):
+        """Forward function.
+
+        Args:
+            z (torch.FloatTensor): Input feature map.
+
+        Returns:
+            z (torch.FloatTensor): Output feature map.
+        """
+        z = self.norm1(z)
+        w = self.fc1(z)
+        w = self.gelu(w)
+        v = self.fc2(z)
+        z = self.norm2(w * v)
+        z = self.fc3(z)
+        return z
+
+
+@MODULES.register_module()
+class AttentionBase(nn.Module):
+    """A multi-head attention block used in the Bart model.
+
+    Ref:
+    https://github.com/kuprel/min-dalle/blob/main/min_dalle/models
+
+    Args:
+        in_channels (int): The channel number of the input feature map.
+        num_heads (int): Number of heads in the attention.
+    """
+
+    def __init__(self, in_channels, num_heads):
+        super().__init__()
+        self.in_channels = in_channels
+        self.num_heads = num_heads
+        self.query = Linear(in_channels, in_channels, bias=False)
+        self.key = Linear(in_channels, in_channels, bias=False)
+        self.value = Linear(in_channels, in_channels, bias=False)
+        self.proj = Linear(in_channels, in_channels, bias=False)
+
+    def qkv(self, x):
+        """Calculate queries, keys and values for the embedding map.
+
+        Args:
+            x (torch.FloatTensor): Input feature map.
+
+        Returns:
+            q (torch.FloatTensor): Query feature map.
+            k (torch.FloatTensor): Key feature map.
+            v (torch.FloatTensor): Value feature map.
+        """
+        q = self.query(x)
+        k = self.key(x)
+        v = self.value(x)
+
+        return q, k, v
+
+    def forward(self, q, k, v, attention_mask):
+        """Forward function for attention.
+
+        Args:
+            q (torch.FloatTensor): Query feature map.
+            k (torch.FloatTensor): Key feature map.
+            v (torch.FloatTensor): Value feature map.
+            attention_mask (torch.BoolTensor): The attention mask, where
+                `False` positions are masked out.
+
+        Returns:
+            weights (torch.FloatTensor): Feature map after attention.
+        """
+        q = q.reshape(q.shape[:2] + (self.num_heads, -1))
+        q /= q.shape[-1]**0.5
+        k = k.reshape(k.shape[:2] + (self.num_heads, -1))
+        v = v.reshape(v.shape[:2] + (self.num_heads, -1))
+
+        attention_bias = (1 - attention_mask.to(torch.float32)) * -1e12
+        weights = torch.einsum('bqhc,bkhc->bhqk', q, k)
+        weights += attention_bias
+        weights = torch.softmax(weights, -1)
+        weights = torch.einsum('bhqk,bkhc->bqhc', weights, v)
+        shape = weights.shape[:2] + (self.in_channels, )
+        weights = weights.reshape(shape)
+        weights = self.proj(weights)
+        return weights
+
+
+@MODULES.register_module()
+class BartEncoderLayer(nn.Module):
+    # yapf: disable
+    """EncoderLayer of the Bart model.
+
+    Ref:
+    https://github.com/kuprel/min-dalle/blob/main/min_dalle/models
+
+    Args:
+        in_channels (int): The channel number of the input feature map.
+        head_num (int): Number of heads in the attention.
+        out_channels (int): The channel number of the hidden layer in the GLU.
+    """
+
+    def __init__(self, in_channels, head_num, out_channels):
+        super().__init__()
+        self.attn = AttentionBase(in_channels, head_num)
+        _, self.norm = build_norm_layer(dict(type='LN'), in_channels)
+        self.glu = GLU(in_channels, out_channels)
+
+    def forward(self, x, attention_mask):
+        """Forward function for the encoder layer.
+
+        Args:
+            x (torch.FloatTensor): Input feature map.
+            attention_mask (torch.BoolTensor): The attention mask of the
+                input sequence.
+
+        Returns:
+            x (torch.FloatTensor): Output feature map.
+        """
+
+        h = self.norm(x)
+        q, k, v = self.attn.qkv(h)
+        h = self.attn(q, k, v, attention_mask)
+        h = self.norm(h)
+        x = x + h
+        h = self.glu(x)
+        x = x + h
+        return x
+
+
+@MODULES.register_module()
+class BartDecoderLayer(nn.Module):
+    # yapf: disable
+    """DecoderLayer of the Bart model.
+
+    Ref:
+    https://github.com/kuprel/min-dalle/blob/main/min_dalle/models
+
+    Args:
+        in_channels (int): The channel number of the input feature map.
+        head_num (int): Number of heads in the attention.
+        out_channels (int): The channel number of the hidden layer in the GLU.
+        token_length (int): The length of tokens.
+    """
+
+    def __init__(self, in_channels, head_num, out_channels, token_length=256):
+        super().__init__()
+        self.attn = AttentionBase(in_channels, head_num)
+        self.cross_attn = AttentionBase(in_channels, head_num)
+        _, self.norm = build_norm_layer(dict(type='LN'), in_channels)
+        self.glu = GLU(in_channels, out_channels)
+        self.token_indices = torch.arange(token_length)
+
+    def forward(self, x, encoder_state, attention_state,
+                attention_mask, token_index):
+        """Forward function for the decoder layer.
+
+        Args:
+            x (torch.FloatTensor): Input feature map of
+                the decoder embeddings.
+            encoder_state (torch.FloatTensor): Input feature map of
+                the encoder embeddings.
+            attention_state (torch.FloatTensor): Cached keys and values of
+                the self attention.
+            attention_mask (torch.BoolTensor): The attention mask of the
+                encoder sequence, used in the cross attention.
+            token_index (torch.LongTensor): The index of tokens.
+
+        Returns:
+            x (torch.FloatTensor): Output feature map of
+                the decoder embeddings.
+            attention_state (torch.FloatTensor): Updated cached keys and
+                values of the self attention.
+        """
+
+        # Self Attention
+        token_count = token_index.shape[1]
+        if token_count == 1:
+            self_attn_mask = self.token_indices <= token_index
+            self_attn_mask = self_attn_mask[:, None, None, :]
+        else:
+            self_attn_mask = (self.token_indices[None, None, :token_count] <=
+                              token_index[:, :, None])
+            self_attn_mask = self_attn_mask[:, None, :, :]
+
+        h = self.norm(x)
+        q, k, v = self.attn.qkv(h)
+        token_count = token_index.shape[1]
+        if token_count == 1:
+            batch_count = h.shape[0]
+            attn_state_new = torch.cat([k, v]).to(attention_state.dtype)
+            attention_state[:, token_index[0]] = attn_state_new
+            k = attention_state[:batch_count]
+            v = attention_state[batch_count:]
+        h = self.attn(q, k, v, self_attn_mask)
+        h = self.norm(h)
+        x = x + h
+
+        # Cross Attention: keys and values come from the encoder state
+        h = self.norm(x)
+        q, _, _ = self.cross_attn.qkv(h)
+        _, k, v = self.cross_attn.qkv(encoder_state)
+        h = self.cross_attn(q, k, v, attention_mask)
+        h = self.norm(h)
+        x = x + h
+
+        h = self.glu(x)
+        x = x + h
+
+        return x, attention_state
diff --git a/mmgen/models/architectures/ddpm/modules.py b/mmgen/models/architectures/ddpm/modules.py
index 6405ee60d..04939c9ca 100644
--- a/mmgen/models/architectures/ddpm/modules.py
+++ b/mmgen/models/architectures/ddpm/modules.py
@@ -366,10 +366,13 @@ class DenoisingDownsample(nn.Module):
             downsampled.
         with_conv (bool, optional): Whether use convolution operation for
             downsampling. Defaults to `True`.
+        with_pad (bool, optional): Whether to do asymmetric padding before
+            downsampling. Defaults to `False`.
     """

-    def __init__(self, in_channels, with_conv=True):
+    def __init__(self, in_channels, with_conv=True, with_pad=False):
         super().__init__()
+        self.with_pad = with_pad
         if with_conv:
             self.downsample = nn.Conv2d(in_channels, in_channels, 3, 2, 1)
         else:
@@ -383,6 +386,9 @@ def forward(self, x):
         Returns:
             torch.Tensor: Feature map after downsampling.
         """
+        if self.with_pad:
+            # asymmetric padding: pad the right and bottom edges by one pixel
+            x = F.pad(x, (0, 1, 0, 1), mode='constant', value=0)
         return self.downsample(x)
diff --git a/mmgen/models/architectures/vqgan/__init__.py b/mmgen/models/architectures/vqgan/__init__.py
new file mode 100644
index 000000000..f0618f885
--- /dev/null
+++ b/mmgen/models/architectures/vqgan/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .modules import DiffusionResnetBlock
+
+__all__ = ['DiffusionResnetBlock']
diff --git a/mmgen/models/architectures/vqgan/modules.py b/mmgen/models/architectures/vqgan/modules.py
new file mode 100644
index 000000000..42886cd3d
--- /dev/null
+++ b/mmgen/models/architectures/vqgan/modules.py
@@ -0,0 +1,107 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn.bricks import Linear, build_conv_layer, build_norm_layer
+from mmgen.registry import MODULES
+
+
+@MODULES.register_module()
+class DiffusionResnetBlock(nn.Module):
+    # yapf: disable
+    """Resblock for the diffusion model. If `in_channels` does not equal
+    `out_channels`, a learnable conv shortcut will be added.
+
+    Ref:
+    https://github.com/CompVis/taming-transformers/blob/master/taming/modules
+
+    Args:
+        in_channels (int): Number of channels of the input feature map.
+        out_channels (int, optional): Number of output channels of the
+            ResBlock. If not defined, the output channels will equal
+            `in_channels`. Defaults to `None`.
+        conv_shortcut (bool, optional): Whether to use a 3x3 conv shortcut
+            instead of a 1x1 one. Defaults to `False`.
+        dropout (float): Probability of the dropout layers.
+        temb_channels (int, optional): Number of channels of the input
+            time embedding. Defaults to `512`.
+        norm_cfg (dict, optional): Config for the normalization layers.
+            Defaults to dict(type='GN', num_groups=32, eps=1e-6, affine=True).
+    """
+
+    def __init__(self,
+                 *,
+                 in_channels,
+                 out_channels=None,
+                 conv_shortcut=False,
+                 dropout,
+                 temb_channels=512,
+                 norm_cfg=dict(type='GN', num_groups=32, eps=1e-6,
+                               affine=True)):
+        super().__init__()
+        self.in_channels = in_channels
+        out_channels = in_channels if out_channels is None else out_channels
+        self.out_channels = out_channels
+        self.use_conv_shortcut = conv_shortcut
+        self.silu = nn.SiLU()
+
+        _, self.norm1 = build_norm_layer(norm_cfg, in_channels)
+        self.conv1 = build_conv_layer(None,
+                                      in_channels,
+                                      out_channels,
+                                      kernel_size=3,
+                                      stride=1,
+                                      padding=1)
+        if temb_channels > 0:
+            self.temb_proj = Linear(temb_channels, out_channels)
+        _, self.norm2 = build_norm_layer(norm_cfg, out_channels)
+        self.dropout = nn.Dropout(dropout)
+        self.conv2 = build_conv_layer(None,
+                                      out_channels,
+                                      out_channels,
+                                      kernel_size=3,
+                                      stride=1,
+                                      padding=1)
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                self.conv_shortcut = build_conv_layer(None,
+                                                      in_channels,
+                                                      out_channels,
+                                                      kernel_size=3,
+                                                      stride=1,
+                                                      padding=1)
+            else:
+                self.nin_shortcut = build_conv_layer(None,
+                                                     in_channels,
+                                                     out_channels,
+                                                     kernel_size=1,
+                                                     stride=1,
+                                                     padding=0)
+
+    def forward(self, x, temb):
+        """Forward function.
+
+        Args:
+            x (torch.Tensor): Input feature map tensor.
+            temb (torch.Tensor): Shared time embedding.
+        Returns:
+            torch.Tensor: Output feature map tensor.
+        """
+        h = self.norm1(x)
+        h = self.silu(h)
+        h = self.conv1(h)
+
+        if temb is not None:
+            h = h + self.temb_proj(self.silu(temb))[:, :, None, None]
+
+        h = self.norm2(h)
+        h = self.silu(h)
+        h = self.dropout(h)
+        h = self.conv2(h)
+
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                x = self.conv_shortcut(x)
+            else:
+                x = self.nin_shortcut(x)
+
+        return x + h
diff --git a/mmgen/models/architectures/vqvae/__init__.py b/mmgen/models/architectures/vqvae/__init__.py
new file mode 100644
index 000000000..d851c40f7
--- /dev/null
+++ b/mmgen/models/architectures/vqvae/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .quantizer import GumbelQuantize, VectorQuantizer, VectorQuantizer2
+
+__all__ = ['GumbelQuantize', 'VectorQuantizer', 'VectorQuantizer2']
diff --git a/mmgen/models/architectures/vqvae/quantizer.py b/mmgen/models/architectures/vqvae/quantizer.py
new file mode 100644
index 000000000..bf9fc5091
--- /dev/null
+++ b/mmgen/models/architectures/vqvae/quantizer.py
@@ -0,0 +1,445 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from mmcv.cnn.bricks import build_conv_layer
+from mmgen.registry import MODULES
+from torch import einsum
+
+
+@MODULES.register_module()
+class VectorQuantizer(nn.Module):
+    """Discretization bottleneck part of the VQ-VAE.
+
+    Ref:
+    https://github.com/MishaLaskin/vqvae/blob/master/models/quantizer.py
+    https://github.com/CompVis/taming-transformers/blob/master/taming/modules/vqvae/quantize.py
+
+    Args:
+        in_channels (int): The number of embeddings in the codebook.
+        e_channels (int): The dimension of each codebook embedding.
+        beta (float): Commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2.
+    """
+
+    def __init__(self, in_channels, e_channels, beta):
+        super().__init__()
+        self.in_channels = in_channels
+        self.e_channels = e_channels
+        self.beta = beta
+
+        self.embedding = nn.Embedding(self.in_channels, self.e_channels)
+        self.embedding.weight.data.uniform_(-1.0 / self.in_channels,
+                                            1.0 / self.in_channels)
+
+    def forward(self, z):
+        """Forward function of the quantizer.
+
+        Args:
+            z (torch.FloatTensor): Input latent vectors.
+
+        Returns:
+            z_q (torch.FloatTensor): Quantized latent vectors.
+            loss (torch.FloatTensor): The loss value.
+            perplexity (torch.FloatTensor): The perplexity of the network.
+            min_encodings (torch.FloatTensor): One-hot codes of the closest embeddings.
+            min_encoding_indices (torch.Tensor): The indices of the closest embedding vector.
+        """
+        z = z.permute(0, 2, 3, 1).contiguous()
+        z_flattened = z.view(-1, self.e_channels)
+
+        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+            torch.sum(self.embedding.weight**2, dim=1) - 2 * \
+            torch.matmul(z_flattened, self.embedding.weight.t())
+
+        # find closest encodings and get quantized latent vectors
+        min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
+        min_encodings = torch.zeros(min_encoding_indices.shape[0],
+                                    self.in_channels).to(z)
+        min_encodings.scatter_(1, min_encoding_indices, 1)
+        z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
+
+        # compute loss
+        loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
+            torch.mean((z_q - z.detach()) ** 2)
+
+        # compute perplexity
+        e_mean = torch.mean(min_encodings, dim=0)
+        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
+
+        # reshape back to match original input shape
+        z_q = z + (z_q - z).detach()
+        z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
+
+    def get_codebook_entry(self, indices, shape):
+        """Look up quantized latent vectors for the given codebook indices.
+
+        Args:
+            indices (torch.Tensor): The indices of the codebook entries.
+            shape (tuple): The target shape (batch, height, width, channel).
+
+        Returns:
+            z_q (torch.FloatTensor): Quantized latent vectors.
+        """
+        min_encodings = torch.zeros(indices.shape[0],
+                                    self.in_channels).to(indices)
+        min_encodings.scatter_(1, indices[:, None], 1)
+
+        # get quantized latent vectors
+        z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
+        if shape is not None:
+            z_q = z_q.view(shape)
+            z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+        return z_q
+
+
+@MODULES.register_module()
+class VectorQuantizer2(nn.Module):
+    """Improved version of VectorQuantizer that can be used as a drop-in
+    replacement.
+
+    Ref:
+    https://github.com/MishaLaskin/vqvae/blob/master/models/quantizer.py
+    https://github.com/CompVis/taming-transformers/blob/master/taming/modules/vqvae/quantize.py
+
+    Args:
+        in_channels (int): The number of embeddings in the codebook.
+        e_channels (int): The dimension of each codebook embedding.
+        beta (float): Commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2.
+        remap (str, optional): Path to a .npy file of used codebook indices for remapping. Defaults to None.
+        unknown_index (str | int, optional): How to handle indices missing from the remapping. Defaults to "random".
+        sane_index_shape (bool, optional): Whether to reshape the returned indices to (batch, height, width).
+            Defaults to False.
+        legacy (bool, optional): Whether to use the legacy (buggy) loss weighting. Defaults to True.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 e_channels,
+                 beta,
+                 remap=None,
+                 unknown_index="random",
+                 sane_index_shape=False,
+                 legacy=True):
+        super().__init__()
+        self.in_channels = in_channels
+        self.e_channels = e_channels
+        self.beta = beta
+        self.legacy = legacy
+
+        self.embedding = nn.Embedding(self.in_channels, self.e_channels)
+        self.embedding.weight.data.uniform_(-1.0 / self.in_channels,
+                                            1.0 / self.in_channels)
+
+        self.remap = remap
+        if self.remap is not None:
+            self.register_buffer("used", torch.tensor(np.load(self.remap)))
+            self.re_embed = self.used.shape[0]
+            self.unknown_index = unknown_index  # "random" or "extra" or integer
+            if self.unknown_index == "extra":
+                self.unknown_index = self.re_embed
+                self.re_embed = self.re_embed + 1
+            print(
+                f"Remapping {self.in_channels} indices to {self.re_embed} indices. "
+                f"Using {self.unknown_index} for unknown indices.")
+        else:
+            self.re_embed = in_channels
+
+        self.sane_index_shape = sane_index_shape
+
+    def remap_to_used(self, inds):
+        """Remap the indices.
+
+        Args:
+            inds (torch.Tensor): The indices of the embeddings.
+
+        Returns:
+            (torch.Tensor): remapped indices.
+        """
+        ishape = inds.shape
+        assert len(ishape) > 1
+        inds = inds.reshape(ishape[0], -1)
+        used = self.used.to(inds)
+        match = (inds[:, :, None] == used[None, None, ...]).long()
+        new = match.argmax(-1)
+        unknown = match.sum(2) < 1
+        if self.unknown_index == "random":
+            new[unknown] = torch.randint(
+                0, self.re_embed,
+                size=new[unknown].shape).to(device=new.device)
+        else:
+            new[unknown] = self.unknown_index
+        return new.reshape(ishape)
+
+    def unmap_to_all(self, inds):
+        """Unmap the indices.
+
+        Args:
+            inds (torch.Tensor): The indices of the embeddings.
+
+        Returns:
+            (torch.Tensor): unmapped indices.
+        """
+        ishape = inds.shape
+        assert len(ishape) > 1
+        inds = inds.reshape(ishape[0], -1)
+        used = self.used.to(inds)
+        if self.re_embed > self.used.shape[0]:  # extra token
+            inds[inds >= self.used.shape[0]] = 0  # simply set to zero
+        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
+        return back.reshape(ishape)
+
+    def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
+        """Forward function of the quantizer.
+
+        Args:
+            z (torch.FloatTensor): Input latent vectors.
+            temp (float, optional): Unused, kept for interface compatibility with GumbelQuantize. Defaults to None.
+            rescale_logits (bool, optional): Must be False; kept for interface compatibility. Defaults to False.
+            return_logits (bool, optional): Must be False; kept for interface compatibility. Defaults to False.
+
+        Returns:
+            z_q (torch.FloatTensor): Quantized latent vectors.
+            loss (torch.FloatTensor): The loss value.
+            perplexity: Always None, kept for interface compatibility.
+            min_encodings: Always None, kept for interface compatibility.
+            min_encoding_indices (torch.Tensor): The indices of the closest embedding.
+        """
+        assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
+        assert rescale_logits is False, "Only for interface compatible with Gumbel"
+        assert return_logits is False, "Only for interface compatible with Gumbel"
+
+        z = rearrange(z, 'b c h w -> b h w c').contiguous()
+        z_flattened = z.view(-1, self.e_channels)
+
+        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+            torch.sum(self.embedding.weight**2, dim=1) - 2 * \
+            torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
+
+        min_encoding_indices = torch.argmin(d, dim=1)
+        z_q = self.embedding(min_encoding_indices).view(z.shape)
+        perplexity = None
+        min_encodings = None
+
+        # compute loss for embedding
+        if not self.legacy:
+            loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
+                torch.mean((z_q - z.detach()) ** 2)
+        else:
+            loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
+                torch.mean((z_q - z.detach()) ** 2)
+
+        # reshape back to match original input shape
+        z_q = z + (z_q - z).detach()
+        z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
+
+        if self.remap is not None:
+            min_encoding_indices = min_encoding_indices.reshape(
+                z.shape[0], -1)  # add batch axis
+            min_encoding_indices = self.remap_to_used(min_encoding_indices)
+            min_encoding_indices = min_encoding_indices.reshape(-1,
+                                                                1)  # flatten
+
+        if self.sane_index_shape:
+            min_encoding_indices = min_encoding_indices.reshape(
+                z_q.shape[0], z_q.shape[2], z_q.shape[3])
+
+        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
+
+    def get_codebook_entry(self, indices, shape):
+        """Look up quantized latent vectors for the given codebook indices.
+
+        Args:
+            indices (torch.Tensor): The indices of the codebook entries.
+            shape (tuple): The target shape (batch, height, width, channel).
+
+        Returns:
+            z_q (torch.FloatTensor): Quantized latent vectors.
+        """
+        if self.remap is not None:
+            indices = indices.reshape(shape[0], -1)  # add batch axis
+            indices = self.unmap_to_all(indices)
+            indices = indices.reshape(-1)  # flatten again
+
+        # get quantized latent vectors
+        z_q = self.embedding(indices)
+        if shape is not None:
+            z_q = z_q.view(shape)
+            z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+        return z_q
+
+
+@MODULES.register_module()
+class GumbelQuantize(nn.Module):
+    """Gumbel-Softmax quantizer, following "Categorical Reparameterization
+    with Gumbel-Softmax", Jang et al. 2016, https://arxiv.org/abs/1611.01144.
+
+    Ref:
+    https://github.com/karpathy/deep-vector-quantization/blob/main/dvq/model/quantize.py
+    https://github.com/CompVis/taming-transformers/blob/master/taming/modules/vqvae/quantize.py
+
+    Args:
+        num_hiddens (int): The channel number of the input (hidden) feature map.
+        in_channels (int): The number of embeddings in the codebook.
+        e_channels (int): The dimension of each codebook embedding.
+        straight_through (bool, optional): Whether to use hard (straight-through) sampling in training. Defaults to True.
+        kl_weight (float, optional): The weight of the KL divergence term. Defaults to 5e-4.
+        temp_init (float, optional): The initial temperature of the Gumbel softmax. Defaults to 1.0.
+        use_vqinterface (bool, optional): Whether to return values in the same format as VectorQuantizer. Defaults to True.
+        remap (str, optional): Path to a .npy file of used codebook indices for remapping. Defaults to None.
+        unknown_index (str | int, optional): How to handle indices missing from the remapping. Defaults to "random".
+    """
+
+    def __init__(self,
+                 num_hiddens,
+                 in_channels,
+                 e_channels,
+                 straight_through=True,
+                 kl_weight=5e-4,
+                 temp_init=1.0,
+                 use_vqinterface=True,
+                 remap=None,
+                 unknown_index="random"):
+        super().__init__()
+
+        self.in_channels = in_channels
+        self.e_channels = e_channels
+        self.straight_through = straight_through
+        self.temperature = temp_init
+        self.kl_weight = kl_weight
+        self.proj = build_conv_layer(None, num_hiddens, in_channels, 1)
+        self.embed = nn.Embedding(in_channels, e_channels)
+        self.use_vqinterface = use_vqinterface
+
+        self.remap = remap
+        if self.remap is not None:
+            self.register_buffer("used", torch.tensor(np.load(self.remap)))
+            self.re_embed = self.used.shape[0]
+            self.unknown_index = unknown_index  # "random" or "extra" or integer
+            if self.unknown_index == "extra":
+                self.unknown_index = self.re_embed
+                self.re_embed = self.re_embed + 1
+            print(
+                f"Remapping {self.in_channels} indices to {self.re_embed} indices. "
+                f"Using {self.unknown_index} for unknown indices.")
+        else:
+            self.re_embed = in_channels
+
+    def remap_to_used(self, inds):
+        """Remap the indices.
+
+        Args:
+            inds (torch.Tensor): The indices of the embeddings.
+
+        Returns:
+            (torch.Tensor): remapped indices.
+        """
+        ishape = inds.shape
+        assert len(ishape) > 1
+        inds = inds.reshape(ishape[0], -1)
+        used = self.used.to(inds)
+        match = (inds[:, :, None] == used[None, None, ...]).long()
+        new = match.argmax(-1)
+        unknown = match.sum(2) < 1
+        if self.unknown_index == "random":
+            new[unknown] = torch.randint(
+                0, self.re_embed,
+                size=new[unknown].shape).to(device=new.device)
+        else:
+            new[unknown] = self.unknown_index
+        return new.reshape(ishape)
+
+    def unmap_to_all(self, inds):
+        """Unmap the indices.
+
+        Args:
+            inds (torch.Tensor): The indices of the embeddings.
+
+        Returns:
+            (torch.Tensor): unmapped indices.
+        """
+        ishape = inds.shape
+        assert len(ishape) > 1
+        inds = inds.reshape(ishape[0], -1)
+        used = self.used.to(inds)
+        if self.re_embed > self.used.shape[0]:  # extra token
+            inds[inds >= self.used.shape[0]] = 0  # simply set to zero
+        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
+        return back.reshape(ishape)
+
+    def forward(self, z, temp=None, return_logits=False):
+        """Forward function of the quantizer.
+
+        Args:
+            z (torch.FloatTensor): Input latent vectors.
+            temp (float, optional): The temperature of the Gumbel softmax.
+                Defaults to None.
+            return_logits (bool, optional): Return the logits or not.
+                Defaults to False.
+
+        Returns:
+            z_q (torch.FloatTensor): Quantized latent vectors.
+            loss (torch.FloatTensor): The loss value.
+            inds (torch.Tensor): The indices of the embeddings.
+        """
+
+        # force hard = True when we are in eval mode
+        hard = self.straight_through if self.training else True
+        temp = self.temperature if temp is None else temp
+
+        logits = self.proj(z)
+        if self.remap is not None:
+            # continue only with used logits
+            full_zeros = torch.zeros_like(logits)
+            logits = logits[:, self.used, ...]
+
+        soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
+        if self.remap is not None:
+            # go back to all entries but unused set to zero
+            full_zeros[:, self.used, ...] = soft_one_hot
+            soft_one_hot = full_zeros
+        z_q = einsum('b n h w, n d -> b d h w', soft_one_hot,
+                     self.embed.weight)
+
+        # add kl divergence to the prior loss
+        qy = F.softmax(logits, dim=1)
+        loss = self.kl_weight * torch.sum(
+            qy * torch.log(qy * self.in_channels + 1e-10), dim=1).mean()
+
+        inds = soft_one_hot.argmax(dim=1)
+        if self.remap is not None:
+            inds = self.remap_to_used(inds)
+        if self.use_vqinterface:
+            if return_logits:
+                return z_q, loss, (None, None, inds), logits
+            return z_q, loss, (None, None, inds)
+        return z_q, loss, inds
+
+    def get_codebook_entry(self, indices, shape):
+        """Look up quantized latent vectors for the given codebook indices.
+
+        Args:
+            indices (torch.Tensor): The indices of the codebook entries.
+            shape (tuple): The target shape (batch, height, width, channel).
+
+        Returns:
+            z_q (torch.FloatTensor): Quantized latent vectors.
+        """
+        b, h, w, c = shape
+        assert b * h * w == indices.shape[0]
+        indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w)
+        if self.remap is not None:
+            indices = self.unmap_to_all(indices)
+        one_hot = F.one_hot(indices,
+                            num_classes=self.in_channels).permute(0, 3, 1,
+                                                                  2).float()
+        z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight)
+
+        return z_q
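
For reviewers, a minimal smoke-test sketch of the new Bart layers (not part of the diff). It assumes the import path introduced above and uses illustrative toy sizes; the mask is shaped (batch, 1, 1, key) so it broadcasts against the `'bqhc,bkhc->bhqk'` attention weights, with `False` marking masked positions.

```python
import torch

from mmgen.models.architectures.dalle_mini import BartEncoderLayer

# toy sizes: 1024 channels split over 16 heads, GLU hidden width 2730
batch, seq_len, channels, heads, glu_mid = 2, 64, 1024, 16, 2730

layer = BartEncoderLayer(channels, heads, glu_mid)
x = torch.randn(batch, seq_len, channels)
# all-True mask (nothing masked out), broadcastable to (batch, heads, q, k)
attention_mask = torch.ones(batch, 1, 1, seq_len, dtype=torch.bool)

out = layer(x, attention_mask)
assert out.shape == (batch, seq_len, channels)
```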
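
A similar hedged sketch for `DiffusionResnetBlock` (illustrative values only). The default norm config is GroupNorm with `num_groups=32`, so the channel counts below stay divisible by 32, and all constructor arguments are keyword-only. The new `with_pad` flag on `DenoisingDownsample` only zero-pads the right and bottom edges by one pixel before the stride-2 conv, which appears intended to mirror the asymmetric padding of the taming-transformers downsample blocks.

```python
import torch

from mmgen.models.architectures.vqgan import DiffusionResnetBlock

block = DiffusionResnetBlock(
    in_channels=64, out_channels=128, dropout=0.0, temb_channels=512)

x = torch.randn(2, 64, 32, 32)
temb = torch.randn(2, 512)  # shared time embedding, projected per block

h = block(x, temb)
assert h.shape == (2, 128, 32, 32)  # 1x1 nin_shortcut handles 64 -> 128
```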
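
Finally, a sketch of how the quantizers are called (again illustrative, and it needs `einops` installed since quantizer.py imports it). With the argument names used here, `in_channels` is the codebook size and `e_channels` the embedding dimension, so the tensor fed to `VectorQuantizer` must carry `e_channels` channels, while `GumbelQuantize` takes `num_hiddens` channels and projects them to codebook logits.

```python
import torch

from mmgen.models.architectures.vqvae import GumbelQuantize, VectorQuantizer

codebook_size, embed_dim = 512, 64

vq = VectorQuantizer(in_channels=codebook_size, e_channels=embed_dim,
                     beta=0.25)
z = torch.randn(2, embed_dim, 16, 16)
z_q, vq_loss, (perplexity, one_hot, indices) = vq(z)
assert z_q.shape == z.shape  # straight-through quantized latents

gumbel = GumbelQuantize(num_hiddens=128, in_channels=codebook_size,
                        e_channels=embed_dim)
feat = torch.randn(2, 128, 16, 16)
z_q, kl_loss, (_, _, inds) = gumbel(feat, temp=1.0)
assert z_q.shape == (2, embed_dim, 16, 16)
assert inds.shape == (2, 16, 16)
```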