+
+
+
\ No newline at end of file
diff --git a/docs/data/0.json b/docs/data/0.json
new file mode 100644
index 00000000..1e390613
--- /dev/null
+++ b/docs/data/0.json
@@ -0,0 +1,548 @@
+{
+ "0": {
+ "file_id": 0,
+ "content": "/MANIFEST.in",
+ "type": "filepath"
+ },
+ "1": {
+ "file_id": 0,
+ "content": "This code is specifying the recursive inclusion of all .txt files in the dalle2_pytorch directory for the MANIFEST.in file.",
+ "type": "summary"
+ },
+ "2": {
+ "file_id": 0,
+ "content": "recursive-include dalle2_pytorch *.txt",
+ "type": "code",
+ "location": "/MANIFEST.in:1-1"
+ },
+ "3": {
+ "file_id": 0,
+ "content": "This code is specifying the recursive inclusion of all .txt files in the dalle2_pytorch directory for the MANIFEST.in file.",
+ "type": "comment"
+ },
+ "4": {
+ "file_id": 1,
+ "content": "/Makefile",
+ "type": "filepath"
+ },
+ "5": {
+ "file_id": 1,
+ "content": "Install the updated pip and then install the project in editable mode. Then, run tests for a specific decoder configuration using CUDA visible devices and a provided JSON file.",
+ "type": "summary"
+ },
+ "6": {
+ "file_id": 1,
+ "content": "install:\n\tpip install -U pip\n\tpip install -e .\ntest:\n\tCUDA_VISIBLE_DEVICES= python train_decoder.py --config_file configs/train_decoder_config.test.json",
+ "type": "code",
+ "location": "/Makefile:1-6"
+ },
+ "7": {
+ "file_id": 1,
+ "content": "Install the updated pip and then install the project in editable mode. Then, run tests for a specific decoder configuration using CUDA visible devices and a provided JSON file.",
+ "type": "comment"
+ },
+ "8": {
+ "file_id": 2,
+ "content": "/README.md",
+ "type": "filepath"
+ },
+ "9": {
+ "file_id": 2,
+ "content": "The code enhances DALL-E 2 with additional layers, unconditional generation, and prior models. It provides installation guidance for saving generated images and inpainting through Latent Diffusion. The accompanying code snippet offers BibTeX entries for four research articles published between 2021 and 2022.",
+ "type": "summary"
+ },
+ "10": {
+ "file_id": 2,
+ "content": "\n## DALL-E 2 - Pytorch\nImplementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch.\nYannic Kilcher summary | AssemblyAI explainer\nThe main novelty seems to be an extra layer of indirection with the prior network (whether it is an autoregressive transformer or a diffusion network), which predicts an image embedding based on the text embedding from CLIP. Specifically, this repository will only build out the diffusion prior network, as it is the best performing variant (but which incidentally involves a causal transformer as the denoising network 😂)\nThis model is SOTA for text-to-image for now.\nPlease join if you are interested in helping out with the replication with the LAION community | Yannic Interview\nAs of 5/23/22, it is no longer SOTA. SOTA will be here. Jax versions as well as text-to-video project will be shifted towards the Imagen architecture, as it is way simpler.\n## Status\n- A research group has used the code in this repository to train a functional diffusion prior for their CLIP generations. Will share their work once they release their preprint. This, and Katherine's own experiments, validate OpenAI's finding that the extra prior increases variety of generations.\n- Decoder is now verified working for unconditional generation on my experimental setup for Oxford flowers. 2 researchers have also confirmed Decoder is working for them.\n\n*ongoing at 21k steps*",
+ "type": "code",
+ "location": "/README.md:13-25"
+ },
+ "13": {
+ "file_id": 2,
+ "content": "This code is for a DALL-E 2 model. It has been improved and simplified using the Imagen architecture, and its decoder is verified to work for unconditional generation in Oxford flowers experiments. Researchers have used this code for training a functional diffusion prior, validating its effectiveness.",
+ "type": "comment"
+ },
+ "14": {
+ "file_id": 2,
+ "content": "- Justin Pinkney successfully trained the diffusion prior in the repository for his CLIP to Stylegan2 text-to-image application\n- Romain has scaled up training to 800 GPUs with the available scripts without any issues\n## Pre-Trained Models\n- LAION is training prior models. Checkpoints are available on 🤗huggingface and the training statistics are available on 🐝WANDB.\n- Decoder - In-progress test run 🚧\n- Decoder - Another test run with sparse attention\n- DALL-E 2 🚧 - DALL-E 2 Laion repository",
+ "type": "code",
+ "location": "/README.md:27-36"
+ },
+ "15": {
+ "file_id": 2,
+ "content": "Justin Pinkney successfully trained the diffusion prior for his CLIP to Stylegan2 text-to-image application. Romain scaled up training to 800 GPUs with existing scripts without any issues. LAION is training prior models, available on HuggingFace and WANDB. Decoder testing runs are ongoing. DALL-E 2 repository by LAION is under development.",
+ "type": "comment"
+ },
+ "16": {
+ "file_id": 2,
+ "content": "## Appreciation\nThis library would not have gotten to this working state without the help of\n- Zion for the distributed training code for the diffusion prior\n- Aidan for the distributed training code for the decoder as well as the dataloaders\n- Kumar for working on the initial diffusion training script\n- Romain for the pull request reviews and project management\n- He Cao and xiankgx for the Q&A and for identifying of critical bugs\n- Marunine for identifying issues with resizing of the low resolution conditioner, when training the upsampler, in addition to various other bug fixes\n- MalumaDev for proposing the use of pixel shuffle upsampler for fixing checkboard artifacts",
+ "type": "code",
+ "location": "/README.md:38-48"
+ },
+ "17": {
+ "file_id": 2,
+ "content": "This code block expresses gratitude to the contributors who assisted in developing and improving this library, acknowledging their efforts for distributed training code, bug fixes, Q&A support, and project management.",
+ "type": "comment"
+ },
+ "18": {
+ "file_id": 2,
+ "content": "- Katherine for her advice\n- Stability AI for the generous sponsorship\n- 🤗 Huggingface and in particular Sylvain for the Accelerate library\n- Alex for einops, indispensable tool for tensor manipulation\n... and many others. Thank you! 🙏\n## Install\n```bash\n$ pip install dalle2-pytorch\n```\n## Usage\nTo train DALLE-2 is a 3 step process, with the training of CLIP being the most important\nTo train CLIP, you can either use x-clip package, or join the LAION discord, where a lot of replication efforts are already underway.\nThis repository will demonstrate integration with `x-clip` for starters\n```python",
+ "type": "code",
+ "location": "/README.md:49-70"
+ },
+ "19": {
+ "file_id": 2,
+ "content": "Acknowledgments to Katherine, Stability AI, HuggingFace (Sylvain), and Alex for their contributions; installation instructions with pip command; usage notes mentioning CLIP training, x-clip package, and LAION discord; repository integration with `x-clip` mentioned.",
+ "type": "comment"
+ },
+ "20": {
+ "file_id": 2,
+ "content": "import torch\nfrom dalle2_pytorch import CLIP\nclip = CLIP(\n dim_text = 512,\n dim_image = 512,\n dim_latent = 512,\n num_text_tokens = 49408,\n text_enc_depth = 1,\n text_seq_len = 256,\n text_heads = 8,\n visual_enc_depth = 1,\n visual_image_size = 256,\n visual_patch_size = 32,\n visual_heads = 8,\n use_all_token_embeds = True, # whether to use fine-grained contrastive learning (FILIP)\n decoupled_contrastive_learning = True, # use decoupled contrastive learning (DCL) objective function, removing positive pairs from the denominator of the InfoNCE loss (CLOOB + DCL)\n extra_latent_projection = True, # whether to use separate projections for text-to-image vs image-to-text comparisons (CLOOB)\n use_visual_ssl = True, # whether to do self supervised learning on images\n visual_ssl_type = 'simclr', # can be either 'simclr' or 'simsiam', depending on using DeCLIP or SLIP\n use_mlm = False, # use masked language learning (MLM) on text (DeCLIP)",
+ "type": "code",
+ "location": "/README.md:71-91"
+ },
+ "21": {
+ "file_id": 2,
+ "content": "The code imports necessary libraries and initializes a CLIP model with specific dimensions for text, image, and latent embeddings. It includes various settings such as token counts, encoding depths, image sizes, heads, and learning techniques (FILIP, DCL, CLOOB, DeCLIP, SLIP). It also indicates whether to use masked language learning on text (MLM) or not.",
+ "type": "comment"
+ },
+ "22": {
+ "file_id": 2,
+ "content": " text_ssl_loss_weight = 0.05, # weight for text MLM loss\n image_ssl_loss_weight = 0.05 # weight for image self-supervised learning loss\n).cuda()\n# mock data\ntext = torch.randint(0, 49408, (4, 256)).cuda()\nimages = torch.randn(4, 3, 256, 256).cuda()\n# train\nloss = clip(\n text,\n images,\n return_loss = True # needs to be set to True to return contrastive loss\n)\nloss.backward()\n# do the above with as many texts and images as possible in a loop\n```\nThen, you will need to train the decoder, which learns to generate images based on the image embedding coming from the trained CLIP above\n```python\nimport torch\nfrom dalle2_pytorch import Unet, Decoder, CLIP\n# trained clip from step 1\nclip = CLIP(\n dim_text = 512,\n dim_image = 512,\n dim_latent = 512,\n num_text_tokens = 49408,\n text_enc_depth = 1,\n text_seq_len = 256,\n text_heads = 8,\n visual_enc_depth = 1,\n visual_image_size = 256,\n visual_patch_size = 32,\n visual_heads = 8\n).cuda()\n# unet for the decoder",
+ "type": "code",
+ "location": "/README.md:92-136"
+ },
+ "23": {
+ "file_id": 2,
+ "content": "The code snippet initializes a CLIP model, sets text and image self-supervised loss weights, generates mock data for training, computes the contrastive loss, backpropagates gradients, and trains the decoder using a Unet architecture. The trained CLIP from step 1 is used in this step to train the decoder.",
+ "type": "comment"
+ },
+ "24": {
+ "file_id": 2,
+ "content": "unet = Unet(\n dim = 128,\n image_embed_dim = 512,\n cond_dim = 128,\n channels = 3,\n dim_mults=(1, 2, 4, 8)\n).cuda()\n# decoder, which contains the unet and clip\ndecoder = Decoder(\n unet = unet,\n clip = clip,\n timesteps = 100,\n image_cond_drop_prob = 0.1,\n text_cond_drop_prob = 0.5\n).cuda()\n# mock images (get a lot of this)\nimages = torch.randn(4, 3, 256, 256).cuda()\n# feed images into decoder\nloss = decoder(images)\nloss.backward()\n# do the above for many many many many steps\n# then it will learn to generate images based on the CLIP image embeddings\n```\nFinally, the main contribution of the paper. The repository offers the diffusion prior network. It takes the CLIP text embeddings and tries to generate the CLIP image embeddings. Again, you will need the trained CLIP from the first step\n```python\nimport torch\nfrom dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, CLIP\n# get trained CLIP from step one\nclip = CLIP(\n dim_text = 512,\n dim_image = 512,\n dim_latent = 512,\n num_text_tokens = 49408,",
+ "type": "code",
+ "location": "/README.md:138-181"
+ },
+ "25": {
+ "file_id": 2,
+ "content": "In this code, a U-Net model is created using the provided configuration and then placed on the GPU. A decoder is also created, containing the U-Net and CLIP models, with specific parameters for timesteps, image, and text drop probabilities. The decoder generates images based on CLIP image embeddings after going through many steps of training. Finally, a trained CLIP model from step one is imported for use in the diffusion prior network.",
+ "type": "comment"
+ },
+ "26": {
+ "file_id": 2,
+ "content": " text_enc_depth = 6,\n text_seq_len = 256,\n text_heads = 8,\n visual_enc_depth = 6,\n visual_image_size = 256,\n visual_patch_size = 32,\n visual_heads = 8,\n).cuda()\n# setup prior network, which contains an autoregressive transformer\nprior_network = DiffusionPriorNetwork(\n dim = 512,\n depth = 6,\n dim_head = 64,\n heads = 8\n).cuda()\n# diffusion prior network, which contains the CLIP and network (with transformer) above\ndiffusion_prior = DiffusionPrior(\n net = prior_network,\n clip = clip,\n timesteps = 100,\n cond_drop_prob = 0.2\n).cuda()\n# mock data\ntext = torch.randint(0, 49408, (4, 256)).cuda()\nimages = torch.randn(4, 3, 256, 256).cuda()\n# feed text and images into diffusion prior network\nloss = diffusion_prior(text, images)\nloss.backward()\n# do the above for many many many steps\n# now the diffusion prior can generate image embeddings from the text embeddings\n```\nIn the paper, they actually used a recently discovered technique,",
+ "type": "code",
+ "location": "/README.md:182-223"
+ },
+ "27": {
+ "file_id": 2,
+ "content": "The code sets up a diffusion prior network for generating image embeddings from text embeddings using PyTorch. The network is composed of an autoregressive transformer, CLIP model, and other layers. It also includes a prior_network, random data, and losses are calculated by feeding the text and images into the diffusion prior network before backpropagation. This process is repeated many times to train the network.",
+ "type": "comment"
+ },
+ "28": {
+ "file_id": 2,
+ "content": " from Jonathan Ho himself (original author of DDPMs, the core technique used in DALL-E v2) for high resolution image synthesis.\nThis can easily be used within this framework as so\n```python\nimport torch\nfrom dalle2_pytorch import Unet, Decoder, CLIP\n# trained clip from step 1\nclip = CLIP(\n dim_text = 512,\n dim_image = 512,\n dim_latent = 512,\n num_text_tokens = 49408,\n text_enc_depth = 6,\n text_seq_len = 256,\n text_heads = 8,\n visual_enc_depth = 6,\n visual_image_size = 256,\n visual_patch_size = 32,\n visual_heads = 8\n).cuda()\n# 2 unets for the decoder (a la cascading DDPM)\nunet1 = Unet(\n dim = 32,\n image_embed_dim = 512,\n cond_dim = 128,\n channels = 3,\n dim_mults = (1, 2, 4, 8)\n).cuda()\nunet2 = Unet(\n dim = 32,\n image_embed_dim = 512,\n cond_dim = 128,\n channels = 3,\n dim_mults = (1, 2, 4, 8, 16)\n).cuda()\n# decoder, which contains the unet(s) and clip\ndecoder = Decoder(\n clip = clip,\n unet = (unet1, unet2), ",
+ "type": "code",
+ "location": "/README.md:223-269"
+ },
+ "29": {
+ "file_id": 2,
+ "content": "This code imports necessary modules and initializes a CLIP model, two UNETs for the decoder, and a decoder itself. The CLIP model is trained from a previous step, while the two UNETs are initialized with different dimensions for cascading DDPMs. The decoder contains the CLIP model and both UNETs.",
+ "type": "comment"
+ },
+ "30": {
+ "file_id": 2,
+ "content": " # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)\n image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)\n timesteps = 1000,\n image_cond_drop_prob = 0.1,\n text_cond_drop_prob = 0.5\n).cuda()\n# mock images (get a lot of this)\nimages = torch.randn(4, 3, 512, 512).cuda()\n# feed images into decoder, specifying which unet you want to train\n# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme\nloss = decoder(images, unet_number = 1)\nloss.backward()\nloss = decoder(images, unet_number = 2)\nloss.backward()\n# do the above for many steps for both unets\n```\nFinally, to generate the DALL-E2 images from text. Insert the trained `DiffusionPrior` as well as the `Decoder` (which wraps `CLIP`, the causal transformer, and unet(s))\n```python\nfrom dalle2_pytorch import DALLE2\ndalle2 = DALLE2(\n prior = diffusion_prior,",
+ "type": "code",
+ "location": "/README.md:269-298"
+ },
+ "31": {
+ "file_id": 2,
+ "content": "The code inserts two U-Nets into a decoder model in ascending order of resolution. The images are generated using a specified number of U-Nets, and the loss is calculated for each U-Net separately. Finally, a trained `DiffusionPrior` and `Decoder` (wrapping `CLIP`, a causal transformer, and unet(s)) are inserted to generate DALL-E2 images from text.",
+ "type": "comment"
+ },
+ "32": {
+ "file_id": 2,
+ "content": " decoder = decoder\n)\n# send the text as a string if you want to use the simple tokenizer from DALLE v1\n# or you can do it as token ids, if you have your own tokenizer\ntexts = ['glistening morning dew on a flower petal']\nimages = dalle2(texts) # (1, 3, 256, 256)\n```\nThat's it!\nLet's see the whole script below\n```python\nimport torch\nfrom dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, CLIP\nclip = CLIP(\n dim_text = 512,\n dim_image = 512,\n dim_latent = 512,\n num_text_tokens = 49408,\n text_enc_depth = 6,\n text_seq_len = 256,\n text_heads = 8,\n visual_enc_depth = 6,\n visual_image_size = 256,\n visual_patch_size = 32,\n visual_heads = 8\n).cuda()\n# mock data\ntext = torch.randint(0, 49408, (4, 256)).cuda()\nimages = torch.randn(4, 3, 256, 256).cuda()\n# train\nloss = clip(\n text,\n images,\n return_loss = True\n)\nloss.backward()\n# do above for many steps ...\n# prior networks (with transformer)\nprior_network = DiffusionPriorNetwork(\n dim = 512,\n depth = 6,\n dim_head = 64,",
+ "type": "code",
+ "location": "/README.md:299-353"
+ },
+ "33": {
+ "file_id": 2,
+ "content": "This code is importing necessary modules and creating an instance of DALLE2 model. It then generates images from input text using the model, and performs training on the model by calculating loss and performing backpropagation for multiple steps. The code also creates a prior network with specified dimensions and depth.",
+ "type": "comment"
+ },
+ "34": {
+ "file_id": 2,
+ "content": " heads = 8\n).cuda()\ndiffusion_prior = DiffusionPrior(\n net = prior_network,\n clip = clip,\n timesteps = 1000,\n sample_timesteps = 64,\n cond_drop_prob = 0.2\n).cuda()\nloss = diffusion_prior(text, images)\nloss.backward()\n# do above for many steps ...\n# decoder (with unet)\nunet1 = Unet(\n dim = 128,\n image_embed_dim = 512,\n text_embed_dim = 512,\n cond_dim = 128,\n channels = 3,\n dim_mults=(1, 2, 4, 8),\n cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings\n).cuda()\nunet2 = Unet(\n dim = 16,\n image_embed_dim = 512,\n cond_dim = 128,\n channels = 3,\n dim_mults = (1, 2, 4, 8, 16)\n).cuda()\ndecoder = Decoder(\n unet = (unet1, unet2),\n image_sizes = (128, 256),\n clip = clip,\n timesteps = 100,\n image_cond_drop_prob = 0.1,\n text_cond_drop_prob = 0.5\n).cuda()\nfor unet_number in (1, 2):\n loss = decoder(images, text = text, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much",
+ "type": "code",
+ "location": "/README.md:354-400"
+ },
+ "35": {
+ "file_id": 2,
+ "content": "The code sets up a DALLE-like model using PyTorch, with two Unets for the decoder and trains it by iteratively calculating the loss. The model uses diffusion prior and is conditioned on both text and image encodings. It has different dimensions and timesteps for each Unet. The code is used to train an AI image generation model.",
+ "type": "comment"
+ },
+ "36": {
+ "file_id": 2,
+ "content": " loss.backward()\n# do above for many steps\ndalle2 = DALLE2(\n prior = diffusion_prior,\n decoder = decoder\n)\nimages = dalle2(\n ['cute puppy chasing after a squirrel'],\n cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)\n)\n# save your image (in this example, of size 256x256)\n```\nEverything in this readme should run without error\nYou can also train the decoder on images of greater than the size (say 512x512) at which CLIP was trained (256x256). The images will be resized to CLIP image resolution for the image embeddings\nFor the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.\n## Training on Preprocessed CLIP Embeddings\nIt is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings`\nWorking example below\n```python\nimport torch",
+ "type": "code",
+ "location": "/README.md:401-431"
+ },
+ "37": {
+ "file_id": 2,
+ "content": "The code is initializing a DALLE2 model with specified prior and decoder, generating images from input text using classifier-free guidance (with conditional scale 2), and then saving the generated image of size 256x256. The code also mentions that training will be automated into a CLI tool for small-scale training and that preprocessing images and text into embeddings might be required for scaling up.",
+ "type": "comment"
+ },
+ "38": {
+ "file_id": 2,
+ "content": "from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, CLIP\n# get trained CLIP from step one\nclip = CLIP(\n dim_text = 512,\n dim_image = 512,\n dim_latent = 512,\n num_text_tokens = 49408,\n text_enc_depth = 6,\n text_seq_len = 256,\n text_heads = 8,\n visual_enc_depth = 6,\n visual_image_size = 256,\n visual_patch_size = 32,\n visual_heads = 8,\n).cuda()\n# setup prior network, which contains an autoregressive transformer\nprior_network = DiffusionPriorNetwork(\n dim = 512,\n depth = 6,\n dim_head = 64,\n heads = 8\n).cuda()\n# diffusion prior network, which contains the CLIP and network (with transformer) above\ndiffusion_prior = DiffusionPrior(\n net = prior_network,\n clip = clip,\n timesteps = 100,\n cond_drop_prob = 0.2,\n condition_on_text_encodings = False # this probably should be true, but just to get Laion started\n).cuda()\n# mock data\ntext = torch.randint(0, 49408, (4, 256)).cuda()\nimages = torch.randn(4, 3, 256, 256).cuda()\n# precompute the text and image embeddings",
+ "type": "code",
+ "location": "/README.md:432-474"
+ },
+ "39": {
+ "file_id": 2,
+ "content": "This code is importing modules, initializing a trained CLIP model and setting up a diffusion prior network containing an autoregressive transformer. The diffusion prior contains both the CLIP and network. Mock data is then created for testing purposes.",
+ "type": "comment"
+ },
+ "40": {
+ "file_id": 2,
+ "content": "# here using the diffusion prior class, but could be done with CLIP alone\nclip_image_embeds = diffusion_prior.clip.embed_image(images).image_embed\nclip_text_embeds = diffusion_prior.clip.embed_text(text).text_embed\n# feed text and images into diffusion prior network\nloss = diffusion_prior(\n text_embed = clip_text_embeds,\n image_embed = clip_image_embeds\n)\nloss.backward()\n# do the above for many many many steps\n# now the diffusion prior can generate image embeddings from the text embeddings\n```\nYou can also completely go `CLIP`-less, in which case you will need to pass in the `image_embed_dim` into the `DiffusionPrior` on initialization\n```python\nimport torch\nfrom dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior\n# setup prior network, which contains an autoregressive transformer\nprior_network = DiffusionPriorNetwork(\n dim = 512,\n depth = 6,\n dim_head = 64,\n heads = 8\n).cuda()\n# diffusion prior network, which contains the CLIP and network (with transformer) above\ndiffusion_prior = DiffusionPrior(",
+ "type": "code",
+ "location": "/README.md:475-510"
+ },
+ "41": {
+ "file_id": 2,
+ "content": "The code initializes a diffusion prior network with an autoregressive transformer and uses CLIP for image and text embeddings. It then calculates the loss by feeding the embeddings into the diffusion prior network, backpropagates the gradients, and repeats this process multiple times. Alternatively, CLIP can be excluded from the model initialization by passing the `image_embed_dim` directly to the `DiffusionPrior` class.",
+ "type": "comment"
+ },
+ "42": {
+ "file_id": 2,
+ "content": " net = prior_network,\n image_embed_dim = 512, # this needs to be set\n timesteps = 100,\n cond_drop_prob = 0.2,\n condition_on_text_encodings = False # this probably should be true, but just to get Laion started\n).cuda()\n# mock data\ntext = torch.randint(0, 49408, (4, 256)).cuda()\nimages = torch.randn(4, 3, 256, 256).cuda()\n# precompute the text and image embeddings\n# here using the diffusion prior class, but could be done with CLIP alone\nclip_image_embeds = torch.randn(4, 512).cuda()\nclip_text_embeds = torch.randn(4, 512).cuda()\n# feed text and images into diffusion prior network\nloss = diffusion_prior(\n text_embed = clip_text_embeds,\n image_embed = clip_image_embeds\n)\nloss.backward()\n# do the above for many many many steps\n# now the diffusion prior can generate image embeddings from the text embeddings\n```\n## OpenAI CLIP\nAlthough there is the possibility they are using an unreleased, more powerful CLIP, you can use one of the released ones, if you do not wish to train your o",
+ "type": "code",
+ "location": "/README.md:511-544"
+ },
+ "43": {
+ "file_id": 2,
+ "content": "The code snippet is creating a diffusion model using the provided parameters and utilizing the OpenAI CLIP for image and text embeddings. The text and image embeddings are precomputed, then fed into the diffusion prior network to calculate loss and perform backpropagation. This process is repeated many times to train the model for generating image embeddings from text embeddings.",
+ "type": "comment"
+ },
+ "44": {
+ "file_id": 2,
+ "content": "wn CLIP from scratch. This will also allow the community to more quickly validate the conclusions of the paper.\nTo use a pretrained OpenAI CLIP, simply import `OpenAIClipAdapter` and pass it into the `DiffusionPrior` or `Decoder` like so\n```python\nimport torch\nfrom dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenAIClipAdapter\n# openai pretrained clip - defaults to ViT-B/32\nclip = OpenAIClipAdapter()\n# mock data\ntext = torch.randint(0, 49408, (4, 256)).cuda()\nimages = torch.randn(4, 3, 256, 256).cuda()\n# prior networks (with transformer)\nprior_network = DiffusionPriorNetwork(\n dim = 512,\n depth = 6,\n dim_head = 64,\n heads = 8\n).cuda()\ndiffusion_prior = DiffusionPrior(\n net = prior_network,\n clip = clip,\n timesteps = 100,\n cond_drop_prob = 0.2\n).cuda()\nloss = diffusion_prior(text, images)\nloss.backward()\n# do above for many steps ...\n# decoder (with unet)\nunet1 = Unet(\n dim = 128,\n image_embed_dim = 512,\n cond_dim = 128,\n channels = 3,\n dim_mults=(1, 2, 4, 8),",
+ "type": "code",
+ "location": "/README.md:544-589"
+ },
+ "45": {
+ "file_id": 2,
+ "content": "This code snippet demonstrates how to use OpenAI's CLIP model, pre-trained, within the DALLE2 PyTorch framework. It defines a function `OpenAIClipAdapter` that allows easy integration of pre-trained CLIP with DALLE2's prior and decoder networks. The code provides an example of how to use these networks for training purposes by defining a diffusion prior and unet decoder, and applying them to some mock data.",
+ "type": "comment"
+ },
+ "46": {
+ "file_id": 2,
+ "content": " text_embed_dim = 512,\n cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings (ex. first unet in cascade)\n).cuda()\nunet2 = Unet(\n dim = 16,\n image_embed_dim = 512,\n cond_dim = 128,\n channels = 3,\n dim_mults = (1, 2, 4, 8, 16)\n).cuda()\ndecoder = Decoder(\n unet = (unet1, unet2),\n image_sizes = (128, 256),\n clip = clip,\n timesteps = 1000,\n sample_timesteps = (250, 27),\n image_cond_drop_prob = 0.1,\n text_cond_drop_prob = 0.5\n).cuda()\nfor unet_number in (1, 2):\n loss = decoder(images, text = text, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much\n loss.backward()\n# do above for many steps\ndalle2 = DALLE2(\n prior = diffusion_prior,\n decoder = decoder\n)\nimages = dalle2(\n ['a butterfly trying to escape a tornado'],\n cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)",
+ "type": "code",
+ "location": "/README.md:590-625"
+ },
+ "47": {
+ "file_id": 2,
+ "content": "The code initializes a DALLE2 model and trains it by feeding images and text. It creates Unet layers, a Decoder, and a DALLE2 instance using given dimensions and parameters. The training loop iterates over unet_number, calculates loss, and applies gradient descent to optimize the model. Finally, the DALLE2 model generates images based on input text with conditional scaling.",
+ "type": "comment"
+ },
+ "48": {
+ "file_id": 2,
+ "content": ")\n# save your image (in this example, of size 256x256)\n```\nAlternatively, you can also use Open Clip\n```bash\n$ pip install open-clip-torch\n```\nEx. using the SOTA Open Clip model trained by Romain\n```python\nfrom dalle2_pytorch import OpenClipAdapter\nclip = OpenClipAdapter('ViT-H/14')\n```\nNow you'll just have to worry about training the Prior and the Decoder!\n## Inpainting\nInpainting is also built into the `Decoder`. You simply have to pass in the `inpaint_image` and `inpaint_mask` (boolean tensor where `True` indicates which regions of the inpaint image to keep)\nThis repository uses the formulation put forth by Lugmayr et al. in Repaint\n```python\nimport torch\nfrom dalle2_pytorch import Unet, Decoder, CLIP\n# trained clip from step 1\nclip = CLIP(\n dim_text = 512,\n dim_image = 512,\n dim_latent = 512,\n num_text_tokens = 49408,",
+ "type": "code",
+ "location": "/README.md:626-663"
+ },
+ "49": {
+ "file_id": 2,
+ "content": "The code provides instructions on how to save an image and use Open Clip for image processing. It mentions installing the open-clip-torch package, using a state-of-the-art (SOTA) Open Clip model, initializing the OpenClipAdapter with the desired model, and utilizing the Decoder's built-in inpainting feature, following the formulation presented in Repaint. The code also showcases how to import necessary modules and initialize a CLIP object with specified dimensions.",
+ "type": "comment"
+ },
+ "50": {
+ "file_id": 2,
+ "content": " text_enc_depth = 6,\n text_seq_len = 256,\n text_heads = 8,\n visual_enc_depth = 6,\n visual_image_size = 256,\n visual_patch_size = 32,\n visual_heads = 8\n).cuda()\n# 2 unets for the decoder (a la cascading DDPM)\nunet = Unet(\n dim = 16,\n image_embed_dim = 512,\n cond_dim = 128,\n channels = 3,\n dim_mults = (1, 1, 1, 1)\n).cuda()\n# decoder, which contains the unet(s) and clip\ndecoder = Decoder(\n clip = clip,\n unet = (unet,), # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)\n image_sizes = (256,), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)\n timesteps = 1000,\n image_cond_drop_prob = 0.1,\n text_cond_drop_prob = 0.5\n).cuda()\n# mock images (get a lot of this)\nimages = torch.randn(4, 3, 256, 256).cuda()\n# feed images into decoder, specifying which unet you want to train\n# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme",
+ "type": "code",
+ "location": "/README.md:664-700"
+ },
+ "51": {
+ "file_id": 2,
+ "content": "This code initializes a DALL-E 2 model with specified dimensions for text and visual encoders, along with two UNet models for the decoder. The decoder is then instantiated using these components and a set of image sizes, timesteps, and conditional drop probabilities. Finally, mock images are created for training purposes.",
+ "type": "comment"
+ },
+ "52": {
+ "file_id": 2,
+ "content": "loss = decoder(images, unet_number = 1)\nloss.backward()\n# do the above for many steps for both unets\nmock_image_embed = torch.randn(1, 512).cuda()\n# then to do inpainting\ninpaint_image = torch.randn(1, 3, 256, 256).cuda() # (batch, channels, height, width)\ninpaint_mask = torch.ones(1, 256, 256).bool().cuda() # (batch, height, width)\ninpainted_images = decoder.sample(\n image_embed = mock_image_embed,\n inpaint_image = inpaint_image, # just pass in the inpaint image\n inpaint_mask = inpaint_mask # and the mask\n)\ninpainted_images.shape # (1, 3, 256, 256)\n```\n## Experimental\n### DALL-E2 with Latent Diffusion\nThis repository decides to take the next step and offer DALL-E v2 combined with latent diffusion, from Rombach et al.\nYou can use it as follows. Latent diffusion can be limited to just the first U-Net in the cascade, or to any number you wish.\nThe repository also comes equipped with all the necessary settin",
+ "type": "code",
+ "location": "/README.md:702-731"
+ },
+ "53": {
+ "file_id": 2,
+ "content": "This code initializes a decoder and performs inpainting using DALL-E2 with Latent Diffusion. It generates a mock image embedding, sets the input image and mask for inpainting, then samples the inpainted images from the decoder.",
+ "type": "comment"
+ },
+ "54": {
+ "file_id": 2,
+ "content": "gs to recreate `ViT-VQGan` from the Improved VQGans paper. Furthermore, the vector quantization library also comes equipped to do residual or multi-headed quantization, which I believe will give an even further boost in performance to the autoencoder.\n```python\nimport torch\nfrom dalle2_pytorch import Unet, Decoder, CLIP, VQGanVAE\n# trained clip from step 1\nclip = CLIP(\n dim_text = 512,\n dim_image = 512,\n dim_latent = 512,\n num_text_tokens = 49408,\n text_enc_depth = 1,\n text_seq_len = 256,\n text_heads = 8,\n visual_enc_depth = 1,\n visual_image_size = 256,\n visual_patch_size = 32,\n visual_heads = 8\n)\n# 3 unets for the decoder (a la cascading DDPM)\n# first two unets are doing latent diffusion\n# vqgan-vae must be trained beforehand\nvae1 = VQGanVAE(\n dim = 32,\n image_size = 256,\n layers = 3,\n layer_mults = (1, 2, 4)",
+ "type": "code",
+ "location": "/README.md:731-762"
+ },
+ "55": {
+ "file_id": 2,
+ "content": "The code is importing necessary modules for training a VQGAN-VAE model. It initializes a CLIP model and three Unet models for the decoder, as well as a VQGanVAE model. The CLIP model is pre-trained, while the VQGanVAE needs to be trained beforehand. This code seems to aim at improving the performance of an autoencoder using residual or multi-headed quantization techniques.",
+ "type": "comment"
+ },
+ "56": {
+ "file_id": 2,
+ "content": ")\nvae2 = VQGanVAE(\n dim = 32,\n image_size = 512,\n layers = 3,\n layer_mults = (1, 2, 4)\n)\nunet1 = Unet(\n dim = 32,\n image_embed_dim = 512,\n cond_dim = 128,\n channels = 3,\n sparse_attn = True,\n sparse_attn_window = 2,\n dim_mults = (1, 2, 4, 8)\n)\nunet2 = Unet(\n dim = 32,\n image_embed_dim = 512,\n channels = 3,\n dim_mults = (1, 2, 4, 8, 16),\n cond_on_image_embeds = True,\n cond_on_text_encodings = False\n)\nunet3 = Unet(\n dim = 32,\n image_embed_dim = 512,\n channels = 3,\n dim_mults = (1, 2, 4, 8, 16),\n cond_on_image_embeds = True,\n cond_on_text_encodings = False,\n attend_at_middle = False\n)\n# decoder, which contains the unet(s) and clip\ndecoder = Decoder(\n clip = clip,\n vae = (vae1, vae2), # latent diffusion for unet1 (vae1) and unet2 (vae2), but not for the last unet3\n unet = (unet1, unet2, unet3), # insert unets in order of low resolution to highest resolution (you can have as many stages as you want here)\n image_sizes = (256, 512, 1024), # resolutions, 256 for first unet, 512 for second, 1024 for third",
+ "type": "code",
+ "location": "/README.md:763-807"
+ },
+ "57": {
+ "file_id": 2,
+ "content": "This code sets up a DALLE2 model by creating and configuring various components: VQGanVAE (vae1), Unet models (unet1, unet2, unet3) and a Decoder. The decoder combines the clip and VAEs with corresponding Unets at different resolutions. The image sizes specify the resolutions for each Unet stage, starting from 256 for the first one up to 1024 for the third one.",
+ "type": "comment"
+ },
+ "58": {
+ "file_id": 2,
+ "content": " timesteps = 100,\n image_cond_drop_prob = 0.1,\n text_cond_drop_prob = 0.5\n).cuda()\n# mock images (get a lot of this)\nimages = torch.randn(1, 3, 1024, 1024).cuda()\n# feed images into decoder, specifying which unet you want to train\n# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme\nwith decoder.one_unet_in_gpu(1):\n loss = decoder(images, unet_number = 1)\n loss.backward()\nwith decoder.one_unet_in_gpu(2):\n loss = decoder(images, unet_number = 2)\n loss.backward()\nwith decoder.one_unet_in_gpu(3):\n loss = decoder(images, unet_number = 3)\n loss.backward()\n# do the above for many steps for both unets\n# then it will learn to generate images based on the CLIP image embeddings\n# chaining the unets from lowest resolution to highest resolution (thus cascading)\nmock_image_embed = torch.randn(1, 512).cuda()\nimages = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)\n```\n## Training wrapper\n### Decoder Training\nTraining the `Decoder` may be confusing,",
+ "type": "code",
+ "location": "/README.md:808-846"
+ },
+ "59": {
+ "file_id": 2,
+ "content": "The code demonstrates how to train a `Decoder` with multiple unets using a cascading DDPM scheme. First, it initializes the model's parameters and assigns the required GPU. Then, it creates random images and specifies which unet to train in each iteration by calling the `one_unet_in_gpu()` method. The code trains multiple steps for each unet before moving on to the next one. Finally, a mock image is generated from an embedding using the trained decoder.",
+ "type": "comment"
+ },
+ "60": {
+ "file_id": 2,
+ "content": " as one needs to keep track of an optimizer for each of the `Unet`(s) separately. Each `Unet` will also need its own corresponding exponential moving average. The `DecoderTrainer` hopes to make this simple, as shown below\n```python\nimport torch\nfrom dalle2_pytorch import DALLE2, Unet, Decoder, CLIP, DecoderTrainer\nclip = CLIP(\n dim_text = 512,\n dim_image = 512,\n dim_latent = 512,\n num_text_tokens = 49408,\n text_enc_depth = 6,\n text_seq_len = 256,\n text_heads = 8,\n visual_enc_depth = 6,\n visual_image_size = 256,\n visual_patch_size = 32,\n visual_heads = 8\n).cuda()\n# mock data\ntext = torch.randint(0, 49408, (32, 256)).cuda()\nimages = torch.randn(32, 3, 256, 256).cuda()\n# decoder (with unet)\nunet1 = Unet(\n dim = 128,\n image_embed_dim = 512,\n text_embed_dim = 512,\n cond_dim = 128,\n channels = 3,\n dim_mults=(1, 2, 4, 8),\n cond_on_text_encodings = True,\n).cuda()\nunet2 = Unet(\n dim = 16,\n image_embed_dim = 512,\n cond_dim = 128,\n channels = 3,\n dim_mults = (1, 2, 4, 8, 16),",
+ "type": "code",
+ "location": "/README.md:846-888"
+ },
+ "61": {
+ "file_id": 2,
+ "content": "The code defines a `CLIP` object and creates two `Unet` instances with different architectures. The `CLIP` model is used for text-to-image generation, while the `Unet` models are variational autoencoders that will be trained to generate images based on input text. The `unet1` has a smaller architecture compared to `unet2`, and both use the same embeddings. The code also provides mock data for testing the functionality of the decoder and trainers.",
+ "type": "comment"
+ },
+ "62": {
+ "file_id": 2,
+ "content": ").cuda()\ndecoder = Decoder(\n unet = (unet1, unet2),\n image_sizes = (128, 256),\n clip = clip,\n timesteps = 1000\n).cuda()\ndecoder_trainer = DecoderTrainer(\n decoder,\n lr = 3e-4,\n wd = 1e-2,\n ema_beta = 0.99,\n ema_update_after_step = 1000,\n ema_update_every = 10,\n)\nfor unet_number in (1, 2):\n loss = decoder_trainer(\n images,\n text = text,\n unet_number = unet_number, # which unet to train on\n max_batch_size = 4 # gradient accumulation - this sets the maximum batch size in which to do forward and backwards pass - for this example 32 / 4 == 8 times\n )\n decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average\n# after much training\n# you can sample from the exponentially moving averaged unets as so\nmock_image_embed = torch.randn(32, 512).cuda()\nimages = decoder_trainer.sample(image_embed = mock_image_embed, text = text) # (4, 3, 256, 256)\n```\n### Diffusion Prior Training\nSimilarly, one can use the `Di",
+ "type": "code",
+ "location": "/README.md:889-926"
+ },
+ "63": {
+ "file_id": 2,
+ "content": "This code sets up a decoder, trainer, and trains the unets to generate images based on text input. The trainer updates the unets and their exponential moving averages after each iteration. Finally, it samples from the moving-averaged unets to create new images.",
+ "type": "comment"
+ },
+ "64": {
+ "file_id": 2,
+ "content": "ffusionPriorTrainer` to automatically instantiate and keep track of an exponential moving averaged prior.\n```python\nimport torch\nfrom dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, DiffusionPriorTrainer, Unet, Decoder, CLIP\nclip = CLIP(\n dim_text = 512,\n dim_image = 512,\n dim_latent = 512,\n num_text_tokens = 49408,\n text_enc_depth = 6,\n text_seq_len = 256,\n text_heads = 8,\n visual_enc_depth = 6,\n visual_image_size = 256,\n visual_patch_size = 32,\n visual_heads = 8\n).cuda()\n# mock data\ntext = torch.randint(0, 49408, (512, 256)).cuda()\nimages = torch.randn(512, 3, 256, 256).cuda()\n# prior networks (with transformer)\nprior_network = DiffusionPriorNetwork(\n dim = 512,\n depth = 6,\n dim_head = 64,\n heads = 8\n).cuda()\ndiffusion_prior = DiffusionPrior(\n net = prior_network,\n clip = clip,\n timesteps = 100,\n cond_drop_prob = 0.2\n).cuda()\ndiffusion_prior_trainer = DiffusionPriorTrainer(\n diffusion_prior,\n lr = 3e-4,\n wd = 1e-2,\n ema_beta = 0.99,",
+ "type": "code",
+ "location": "/README.md:926-971"
+ },
+ "65": {
+ "file_id": 2,
+ "content": "This code creates a CLIP model, initializes a diffusion prior network, and sets up a trainer for the diffusion prior. The CLIP model is used to encode text and images into latent representations, while the diffusion prior network is responsible for predicting the future of latent samples. The trainer will automatically update the moving average of the prior network over time.",
+ "type": "comment"
+ },
+ "66": {
+ "file_id": 2,
+ "content": " ema_update_after_step = 1000,\n ema_update_every = 10,\n)\nloss = diffusion_prior_trainer(text, images, max_batch_size = 4)\ndiffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior\n# after much of the above three lines in a loop\n# you can sample from the exponential moving average of the diffusion prior identically to how you do so for DiffusionPrior\nimage_embeds = diffusion_prior_trainer.sample(text, max_batch_size = 4) # (512, 512) - exponential moving averaged image embeddings\n```\n## Bonus\n### Unconditional Training\nThe repository also contains the means to train unconditional DDPM model, or even cascading DDPMs. You simply have to set `unconditional = True` in the `Decoder`\nex.\n```python\nimport torch\nfrom dalle2_pytorch import Unet, Decoder, DecoderTrainer\n# unet for the cascading ddpm\nunet1 = Unet(\n dim = 128,\n dim_mults=(1, 2, 4, 8)\n).cuda()\nunet2 = Unet(\n dim = 32,\n dim_mults = (1, 2, 4, 8, 16)\n).cuda()\n# decoder, which contains the unets",
+ "type": "code",
+ "location": "/README.md:972-1009"
+ },
+ "67": {
+ "file_id": 2,
+ "content": "This code initializes a diffusion prior trainer with exponential moving average (EMA) update parameters, trains the model using diffusion_prior_trainer, updates optimizer and EMA diffusion prior, and finally samples from the EMA of the diffusion prior. The code also mentions that unconditional training or cascading DDPMs can be done by setting `unconditional = True` in the Decoder.",
+ "type": "comment"
+ },
+ "68": {
+ "file_id": 2,
+ "content": "decoder = Decoder(\n unet = (unet1, unet2),\n image_sizes = (256, 512), # first unet up to 256px, then second to 512px\n timesteps = 1000,\n unconditional = True\n).cuda()\n# decoder trainer\ndecoder_trainer = DecoderTrainer(decoder)\n# images (get a lot of this)\nimages = torch.randn(1, 3, 512, 512).cuda()\n# feed images into decoder\nfor i in (1, 2):\n loss = decoder_trainer(images, unet_number = i)\n decoder_trainer.update(unet_number = i)\n# do the above for many many many many images\n# then it will learn to generate images\nimages = decoder_trainer.sample(batch_size = 36, max_batch_size = 4) # (36, 3, 512, 512)\n```\n## Dataloaders\n### Decoder Dataloaders\nIn order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.\n#### Decoder: Image Embedding Dataset\nWhen training the decoder (and up samplers if training together) in isolation, you will need to load images and corresponding image embeddings. This dataset can read two simi",
+ "type": "code",
+ "location": "/README.md:1011-1046"
+ },
+ "69": {
+ "file_id": 2,
+ "content": "The code initializes a decoder, trainer for the decoder, and generates images. It then trains the decoder by feeding images into it, updating the trainer, and repeats this process many times to enable learning. Finally, it uses the trained decoder to generate new images.",
+ "type": "comment"
+ },
+ "70": {
+ "file_id": 2,
+ "content": "lar types of datasets. First, it can read a [webdataset](https://github.com/webdataset/webdataset) that contains `.jpg` and `.npy` files in the `.tar`s that contain the images and associated image embeddings respectively. Alternatively, you can also specify a source for the embeddings outside of the webdataset. In this case, the path to the embeddings should contain `.npy` files with the same shard numbers as the webdataset and there should be a correspondence between the filename of the `.jpg` and the index of the embedding in the `.npy`. So, for example, `0001.tar` from the webdataset with image `00010509.jpg` (the first 4 digits are the shard number and the last 4 are the index) in it should be paralleled by a `img_emb_0001.npy` which contains a NumPy array with the embedding at index 509.\nGenerating a dataset of this type: \n1. Use [img2dataset](https://github.com/rom1504/img2dataset) to generate a webdataset.\n2. Use [clip-retrieval](https://github.com/rom1504/clip-retrieval) to convert the images to embeddings.",
+ "type": "code",
+ "location": "/README.md:1046-1050"
+ },
+ "71": {
+ "file_id": 2,
+ "content": "This code describes a dataset format using webdataset, containing .jpg and .npy files in .tar archives. It allows specifying an external source for embeddings with the same shard numbers and filename-to-index correspondence. The code provides steps to generate this type of dataset using img2dataset and clip-retrieval.",
+ "type": "comment"
+ },
+ "72": {
+ "file_id": 2,
+ "content": "3. Use [embedding-dataset-reordering](https://github.com/Veldrovive/embedding-dataset-reordering) to reorder the embeddings into the expected format.\nUsage:\n```python\nfrom dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embedding_dataloader\n# Create a dataloader directly.\ndataloader = create_image_embedding_dataloader(\n tar_url=\"/path/or/url/to/webdataset/{0000..9999}.tar\", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar\n embeddings_url=\"path/or/url/to/embeddings/folder\", # Included if .npy files are not in webdataset. Left out or set to None otherwise\n num_workers=4,\n batch_size=32,\n shard_width=4, # If a file in the webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index\n shuffle_num=200, # Does a shuffle of the data with a buffer size of 200\n shuffle_shards=True, # Shuffle the order the shards are read in",
+ "type": "code",
+ "location": "/README.md:1051-1066"
+ },
+ "73": {
+ "file_id": 2,
+ "content": "This code snippet demonstrates the usage of the `create_image_embedding_dataloader` function from the DALLE2-pytorch library. It creates an image embedding dataloader by specifying a URL path for the webdataset tar files and optional embeddings folder, setting the number of workers and batch size, defining the shard width, and deciding whether to shuffle the shards or not. The purpose is to reorder the embeddings into the expected format for image generation tasks.",
+ "type": "comment"
+ },
+ "74": {
+ "file_id": 2,
+ "content": " resample_shards=False, # Sample shards with replacement. If true, an epoch will be infinite unless stopped manually\n)\nfor img, emb in dataloader:\n print(img.shape) # torch.Size([32, 3, 256, 256])\n print(emb[\"img\"].shape) # torch.Size([32, 512])\n # Train decoder only as shown above\n# Or create a dataset without a loader so you can configure it manually\ndataset = ImageEmbeddingDataset(\n urls=\"/path/or/url/to/webdataset/{0000..9999}.tar\",\n embedding_folder_url=\"path/or/url/to/embeddings/folder\",\n shard_width=4,\n shuffle_shards=True,\n resample=False\n)\n```\n### Scripts\n#### `train_diffusion_prior.py`\nFor detailed information on training the diffusion prior, please refer to the [dedicated readme](prior.md)\n## Todo\n- [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon\n- [x] add what was proposed in the paper, where DDPM objective for image latent embedding predicts x0 directly (reread vq-diffusion paper and get caught up on that line of work)",
+ "type": "code",
+ "location": "/README.md:1067-1093"
+ },
+ "75": {
+ "file_id": 2,
+ "content": "The code snippet shows how to load an ImageEmbeddingDataset and print its shape. The dataset is loaded from a webdataset at the specified URL, with embedding files located in the given folder. It uses shard_width=4 for sharding the data and sets resample to False. The loader creates images and embeddings which are printed for verification. Additionally, it mentions creating a dataset without a loader if manual configuration is preferred.",
+ "type": "comment"
+ },
+ "76": {
+ "file_id": 2,
+ "content": "- [x] make sure it works end to end to produce an output tensor, taking a single gradient step\n- [x] augment unet so that it can also be conditioned on text encodings (although in paper they hinted this didn't make much a difference)\n- [x] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)\n- [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions\n- [x] add efficient attention in unet\n- [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)\n- [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)\n- [x] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional and compatible with cascading ddpms\n- [x] for decoder, allow ability to customize objective (predict epsilon vs x0), in case latent diffusion does better with prediction of x0",
+ "type": "code",
+ "location": "/README.md:1094-1102"
+ },
+ "77": {
+ "file_id": 2,
+ "content": "The code outlines the steps to create a DDPM model, including conditioning it with text encodings and incorporating a cascade of unets for different resolutions. It also mentions adding efficient attention in unet, allowing customization of conditioning for specific unets, offloading unets to CPU, building latent diffusion architecture, and providing the option for vq-reg variant (vqgan-vae). The decoder objective can be customized between predicting epsilon or x0.",
+ "type": "comment"
+ },
+ "78": {
+ "file_id": 2,
+ "content": "- [x] use attention-based upsampling https://arxiv.org/abs/2112.11435\n- [x] use inheritance just this once for sharing logic between decoder and prior network ddpms\n- [x] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion\n- [x] abstract interface for CLIP adapter class, so other CLIPs can be brought in\n- [x] take care of mixed precision as well as gradient accumulation within decoder trainer\n- [x] just take care of the training for the decoder in a wrapper class, as each unet in the cascade will need its own optimizer\n- [x] bring in tools to train vqgan-vae\n- [x] add convnext backbone for vqgan-vae (in addition to vit [vit-vqgan] + resnet)\n- [x] make sure DDPMs can be run with traditional resnet blocks (but leave convnext as an option for experimentation)\n- [x] make sure for the latter unets in the cascade, one can train on crops for learning super resolution (constrain the unet to be only convolutions in that case, or allow conv-like attention with rel pos bias)\n- [x] ",
+ "type": "code",
+ "location": "/README.md:1103-1113"
+ },
+ "79": {
+ "file_id": 2,
+ "content": "This code is a list of tasks to be completed for the DALLE2-pytorch project. It includes implementing attention-based upsampling, using inheritance, integrating Vit-VQGAN, creating an abstract interface for CLIP adapters, handling mixed precision and gradient accumulation in the decoder trainer, adding a training wrapper class for each unet in the cascade, incorporating convnext backbone for VQGAN-VAE, making sure DDPMs can be run with traditional resnet blocks, enabling super resolution training on crops for latter unets, and allowing conv-like attention with rel pos bias.",
+ "type": "comment"
+ },
+ "80": {
+ "file_id": 2,
+ "content": "offer setting in diffusion prior to split time and image embeddings into multiple tokens, configurable, for more surface area during attention\n- [x] make sure resnet hyperparameters can be configurable across unet depth (groups and expansion factor)\n- [x] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training\n- [x] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images\n- [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14\n- [x] cross embed layers for downsampling, as an option\n- [x] use an experimental tracker agnostic setup, as done here\n- [x] use pydantic for config drive training\n- [x] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)",
+ "type": "code",
+ "location": "/README.md:1113-1121"
+ },
+ "81": {
+ "file_id": 2,
+ "content": "The code is about configuring and training a diffusion prior model for image generation. It includes improvements such as making hyperparameters configurable, incorporating cross-scale embedding, introducing cross embed layers for downsampling, and using an experimental tracker agnostic setup. The code also utilizes pydantic for configuration drive training, saves and restores all exponential moving averaged models for both diffusion prior and decoder.",
+ "type": "comment"
+ },
+ "82": {
+ "file_id": 2,
+ "content": "- [x] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes\n- [x] allow for creation of diffusion prior model off pydantic config classes - consider the same for tracker configs\n- [x] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training (doesnt work well)\n- [x] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697 (keeping, seems to be fine)\n- [x] allow for unet to be able to condition non-cross attention style as well\n- [x] speed up inference, read up on papers (ddim)\n- [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865\n- [x] add the final combination of upsample feature maps, used in unet squared, seems to have an effect in local experiments\n- [ ] consider elucidated dalle2 https://arxiv.org/abs/2206.00364",
+ "type": "code",
+ "location": "/README.md:1122-1130"
+ },
+ "83": {
+ "file_id": 2,
+ "content": "This code lists various tasks and features that have been implemented or are planned for the DALLE2-pytorch model. These include save/load methods, creation of diffusion prior models, skip-layer excitations, grid attention in Cascading DDPM, unet conditioning, speed up inference, resampler from REPAINT paper, final combination of upsample feature maps, and consideration for Elucidated DALLE2.",
+ "type": "comment"
+ },
+ "84": {
+ "file_id": 2,
+ "content": "- [ ] add simple outpainting, text-guided 2x size the image for starters\n- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2\n## Citations\n```bibtex\n@misc{ramesh2022,\n title = {Hierarchical Text-Conditional Image Generation with CLIP Latents}, \n author = {Aditya Ramesh et al},\n year = {2022}\n}\n```\n```bibtex\n@misc{crowson2022,\n author = {Katherine Crowson},\n url = {https://twitter.com/rivershavewings}\n}\n```\n```bibtex\n@misc{rombach2021highresolution,\n title = {High-Resolution Image Synthesis with Latent Diffusion Models}, \n author = {Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Björn Ommer},\n year = {2021},\n eprint = {2112.10752},\n archivePrefix = {arXiv},\n primaryClass = {cs.CV}\n}\n```\n```bibtex\n@article{shen2019efficient,\n author = {Zhuoran Shen and Mingyuan Zhang and Haiyu Zhao and Shuai Yi and Hongsheng Li},\n title = {Efficient Attention: Attention with Linear Complexities},",
+ "type": "code",
+ "location": "/README.md:1131-1165"
+ },
+ "85": {
+ "file_id": 2,
+ "content": "This code chunk appears to be a task list for the DALLE2-pytorch project, followed by citations in BibTeX format. The tasks include implementing simple outpainting and text-guided 2x image size expansion. The project also plans on integrating the VQGAN-VAE, which can be pulled from a pretrained model to test latent diffusion and DALL-E2 integration. The cited works include \"Hierarchical Text-Conditional Image Generation with CLIP Latents\", \"High-Resolution Image Synthesis with Latent Diffusion Models\", \"Efficient Attention: Attention with Linear Complexities\" and a Twitter post by Katherine Crowson.",
+ "type": "comment"
+ },
+ "86": {
+ "file_id": 2,
+ "content": " journal = {CoRR},\n year = {2018},\n url = {http://arxiv.org/abs/1812.01243},\n}\n```\n```bibtex\n@article{Yu2021VectorquantizedIM,\n title = {Vector-quantized Image Modeling with Improved VQGAN},\n author = {Jiahui Yu and Xin Li and Jing Yu Koh and Han Zhang and Ruoming Pang and James Qin and Alexander Ku and Yuanzhong Xu and Jason Baldridge and Yonghui Wu},\n journal = {ArXiv},\n year = {2021},\n volume = {abs/2110.04627}\n}\n```\n```bibtex\n@article{Shleifer2021NormFormerIT,\n title = {NormFormer: Improved Transformer Pretraining with Extra Normalization},\n author = {Sam Shleifer and Jason Weston and Myle Ott},\n journal = {ArXiv},\n year = {2021},\n volume = {abs/2110.09456}\n}\n```\n```bibtex\n@article{Yu2022CoCaCC,\n title = {CoCa: Contrastive Captioners are Image-Text Foundation Models},\n author = {Jiahui Yu and Zirui Wang and Vijay Vasudevan and Legg Yeung and Mojtaba Seyedhosseini and Yonghui Wu},\n journal = {ArXiv},\n year = {2022},\n volume = {abs/2205.01917}",
+ "type": "code",
+ "location": "/README.md:1166-1198"
+ },
+ "87": {
+ "file_id": 2,
+ "content": "The code defines four BibTeX entries for academic papers, providing the paper title, author(s), journal or arXiv, and publication year. These entries can be used to cite the papers in a BibTeX database or bibliography file.",
+ "type": "comment"
+ },
+ "88": {
+ "file_id": 2,
+ "content": "}\n```\n```bibtex\n@misc{wang2021crossformer,\n title = {CrossFormer: A Versatile Vision Transformer Hinging on Cross-scale Attention},\n author = {Wenxiao Wang and Lu Yao and Long Chen and Binbin Lin and Deng Cai and Xiaofei He and Wei Liu},\n year = {2021},\n eprint = {2108.00154},\n archivePrefix = {arXiv},\n primaryClass = {cs.CV}\n}\n```\n```bibtex\n@article{ho2021cascaded,\n title = {Cascaded Diffusion Models for High Fidelity Image Generation},\n author = {Ho, Jonathan and Saharia, Chitwan and Chan, William and Fleet, David J and Norouzi, Mohammad and Salimans, Tim},\n journal = {arXiv preprint arXiv:2106.15282},\n year = {2021}\n}\n```\n```bibtex\n@misc{Saharia2022,\n title = {Imagen: unprecedented photorealism × deep level of language understanding},\n author = {Chitwan Saharia*, William Chan*, Saurabh Saxena†, Lala Li†, Jay Whang†, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S. Sara Mahdavi, Rapha Gontijo Lopes, Tim Salimans, Jonathan Ho†, David Fleet†, Mohammad Norouzi*},",
+ "type": "code",
+ "location": "/README.md:1199-1225"
+ },
+ "89": {
+ "file_id": 2,
+ "content": "This code snippet represents the citation for a research paper called \"CrossFormer: A Versatile Vision Transformer Hinging on Cross-scale Attention\" by Wenxiao Wang et al. The paper is available at arXiv with ID 2108.00154 and focuses on a versatile vision transformer model using cross-scale attention.",
+ "type": "comment"
+ },
+ "90": {
+ "file_id": 2,
+ "content": " year = {2022}\n}\n```\n```bibtex\n@article{Choi2022PerceptionPT,\n title = {Perception Prioritized Training of Diffusion Models},\n author = {Jooyoung Choi and Jungbeom Lee and Chaehun Shin and Sungwon Kim and Hyunwoo J. Kim and Sung-Hoon Yoon},\n journal = {ArXiv},\n year = {2022},\n volume = {abs/2204.00227}\n}\n```\n```bibtex\n@article{Saharia2021PaletteID,\n title = {Palette: Image-to-Image Diffusion Models},\n author = {Chitwan Saharia and William Chan and Huiwen Chang and Chris A. Lee and Jonathan Ho and Tim Salimans and David J. Fleet and Mohammad Norouzi},\n journal = {ArXiv},\n year = {2021},\n volume = {abs/2111.05826}\n}\n```\n```bibtex\n@article{Lugmayr2022RePaintIU,\n title = {RePaint: Inpainting using Denoising Diffusion Probabilistic Models},\n author = {Andreas Lugmayr and Martin Danelljan and Andr{\\'e}s Romero and Fisher Yu and Radu Timofte and Luc Van Gool},\n journal = {ArXiv},\n year = {2022},\n volume = {abs/2201.09865}\n}\n```\n```bibtex\n@misc{chen2022analog,",
+ "type": "code",
+ "location": "/README.md:1226-1261"
+ },
+ "91": {
+ "file_id": 2,
+ "content": "The code includes four BibTeX entries, each representing a different research article. The first entry is for the article titled \"Perception Prioritized Training of Diffusion Models\" by Choi et al., published in 2022. The second entry is for the article titled \"Palette: Image-to-Image Diffusion Models\" by Saharia et al., published in 2021. The third entry is for the article titled \"RePaint: Inpainting using Denoising Diffusion Probabilistic Models\" by Lugmayr et al., published in 2022. The last entry, marked as incomplete, is for an unnamed article by Chen et al. published in 2022.",
+ "type": "comment"
+ },
+ "92": {
+ "file_id": 2,
+ "content": " title = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},\n author = {Ting Chen and Ruixiang Zhang and Geoffrey Hinton},\n year = {2022},\n eprint = {2208.04202},\n archivePrefix = {arXiv},\n primaryClass = {cs.CV}\n}\n```\n```bibtex\n@article{Qiao2019WeightS,\n title = {Weight Standardization},\n author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Loddon Yuille},\n journal = {ArXiv},\n year = {2019},\n volume = {abs/1903.10520}\n}\n```\n```bibtex\n@inproceedings{rogozhnikov2022einops,\n title = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},\n author = {Alex Rogozhnikov},\n booktitle = {International Conference on Learning Representations},\n year = {2022},\n url = {https://openreview.net/forum?id=oapKSVM2bcj}\n}\n```\n```bibtex\n@article{Sunkara2022NoMS,\n title = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects},",
+ "type": "code",
+ "location": "/README.md:1262-1293"
+ },
+ "93": {
+ "file_id": 2,
+ "content": "The code represents BibTeX entries for research articles and conferences. These entries provide information about the title, authors, journals or conferences, publication years, and related identifiers of the cited works.",
+ "type": "comment"
+ },
+ "94": {
+ "file_id": 2,
+ "content": " author = {Raja Sunkara and Tie Luo},\n journal = {ArXiv},\n year = {2022},\n volume = {abs/2208.03641}\n}\n```\n```bibtex\n@article{Salimans2022ProgressiveDF,\n title = {Progressive Distillation for Fast Sampling of Diffusion Models},\n author = {Tim Salimans and Jonathan Ho},\n journal = {ArXiv},\n year = {2022},\n volume = {abs/2202.00512}\n}\n```\n*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's paper",
+ "type": "code",
+ "location": "/README.md:1294-1311"
+ },
+ "95": {
+ "file_id": 2,
+ "content": "This code snippet provides the citation information for two papers in the BibTeX format. The first paper is titled \"Progressive Distillation for Fast Sampling of Diffusion Models\" by Tim Salimans and Jonathan Ho, published in ArXiv in 2022 with volume abs/2202.00512. The second paper is called \"ArXiv:2208.03641\" which doesn't seem to have a title or authors listed, but it was also published on ArXiv in 2022 with the volume abs/2208.03641. Both papers discuss generative modeling techniques using diffusion models.",
+ "type": "comment"
+ },
+ "96": {
+ "file_id": 3,
+ "content": "/configs/README.md",
+ "type": "filepath"
+ },
+ "97": {
+ "file_id": 3,
+ "content": "This code configures DALLE2 model training and logging options in PyTorch, with customizable settings for Unet and Decoder, dataloader, preprocessing, hyperparameters, image metrics, and experiment tracking. It supports various configurations based on selected logger and storage types.",
+ "type": "summary"
+ },
+ "98": {
+ "file_id": 3,
+ "content": "## DALLE2 Training Configurations\nFor more complex configuration, we provide the option of using a configuration file instead of command line arguments.\n### Decoder Trainer\nThe decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.example.json).\n**Unet:**\nThis is a single unet config, which belongs as an array nested under the decoder config as a list of `unets`\n| Option | Required | Default | Description |\n| ------ | -------- | ------- | ----------- |\n| `dim` | Yes | N/A | The starting channels of the unet. |\n| `image_embed_dim` | Yes | N/A | The dimension of the image embeddings. |\n| `dim_mults` | No | `(1, 2, 4, 8)` | The growth factors of the channels. |\nAny parameter from the `Unet` constructor can also be given here.\n**Decoder:**\nDefines the configuration options for the decoder model. The unets defined above will automatically be inserted.\n| Option | Required | Default | Description |",
+ "type": "code",
+ "location": "/configs/README.md:1-24"
+ },
+ "99": {
+ "file_id": 3,
+ "content": "This code provides details on configuring training for DALLE2, a complex model that requires various settings. It includes sections for Unet and Decoder configurations with optional parameters. An example configuration file is also mentioned for easier understanding.",
+ "type": "comment"
+ }
+}
\ No newline at end of file
diff --git a/docs/data/1.json b/docs/data/1.json
new file mode 100644
index 00000000..ac01c43b
--- /dev/null
+++ b/docs/data/1.json
@@ -0,0 +1,549 @@
+{
+ "100": {
+ "file_id": 3,
+ "content": "| ------ | -------- | ------- | ----------- |\n| `unets` | Yes | N/A | A list of unets, using the configuration above |\n| `image_sizes` | Yes | N/A | The resolution of the image after each upsampling step. The length of this array should be the number of unets defined. |\n| `image_size` | Yes | N/A | Not used. Can be any number. |\n| `timesteps` | No | `1000` | The number of diffusion timesteps used for generation. |\n| `loss_type` | No | `l2` | The loss function. Options are `l1`, `huber`, or `l2`. |\n| `beta_schedule` | No | `cosine` | The noising schedule. Options are `cosine`, `linear`, `quadratic`, `jsd`, or `sigmoid`. |\n| `learned_variance` | No | `True` | Whether to learn the variance. |\n| `clip` | No | `None` | The clip model to use if embeddings are being generated on the fly. Takes keys `make` and `model` with defaults `openai` and `ViT-L/14`. |\nAny parameter from the `Decoder` constructor can also be given here.\n**Data:**\nSettings for creation of the dataloaders.\n| Option | Required | Default | Description |",
+ "type": "code",
+ "location": "/configs/README.md:25-40"
+ },
+ "101": {
+ "file_id": 3,
+ "content": "This code appears to be defining the configuration for a machine learning model, specifically one using U-Nets. The configuration includes options for the number of unets, image resolution, timesteps, loss function type, noise schedule, and learned variance. Additionally, there are settings for creating dataloaders for the model's data. The code also notes that any parameter from the `Decoder` constructor can be included in this configuration.",
+ "type": "comment"
+ },
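To make the unet and decoder tables above concrete, here is a minimal sketch of that part of a training config, written as a Python dict mirroring the JSON layout; the numeric values are illustrative placeholders, not values taken from the shipped example file.

```python
# Illustrative sketch of the nested unet/decoder section of a decoder training config.
# Field names follow the tables above; all numeric values are placeholders.
decoder_config = {
    "unets": [
        {
            "dim": 128,                  # starting channels of the unet
            "image_embed_dim": 768,      # dimension of the image embeddings
            "dim_mults": [1, 2, 4, 8],   # channel growth factors per resolution
        }
    ],
    "image_sizes": [64],        # output resolution after each unet (one entry per unet)
    "image_size": 64,           # documented above as unused, but still required
    "timesteps": 1000,          # diffusion timesteps used for generation
    "loss_type": "l2",          # l1 | l2 | huber
    "beta_schedule": "cosine",  # cosine | linear | quadratic | jsd | sigmoid
    "learned_variance": True,
}
```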
+ "102": {
+ "file_id": 3,
+ "content": "| ------ | -------- | ------- | ----------- |\n| `webdataset_base_url` | Yes | N/A | The url of a shard in the webdataset with the shard replaced with `{}`[^1]. |\n| `img_embeddings_url` | No | `None` | The url of the folder containing image embeddings shards. Not required if embeddings are in webdataset or clip is being used. |\n| `text_embeddings_url` | No | `None` | The url of the folder containing text embeddings shards. Not required if embeddings are in webdataset or clip is being used. |\n| `num_workers` | No | `4` | The number of workers used in the dataloader. |\n| `batch_size` | No | `64` | The batch size. |\n| `start_shard` | No | `0` | Defines the start of the shard range the dataset will recall. |\n| `end_shard` | No | `9999999` | Defines the end of the shard range the dataset will recall. |\n| `shard_width` | No | `6` | Defines the width of one webdataset shard number[^2]. |\n| `index_width` | No | `4` | Defines the width of the index of a file inside a shard[^3]. |\n| `splits` | No | `{ \"tra",
+ "type": "code",
+ "location": "/configs/README.md:41-51"
+ },
+ "103": {
+ "file_id": 3,
+ "content": "This code defines various configuration options for a dataloader, including webdataset and embeddings urls, worker numbers, batch size, shard range, and file indexing. The config allows flexibility in handling different types of datasets, with optional embeddings or use of the webdataset library.",
+ "type": "comment"
+ },
+ "104": {
+ "file_id": 3,
+ "content": "in\": 0.75, \"val\": 0.15, \"test\": 0.1 }` | Defines the proportion of shards that will be allocated to the training, validation, and testing datasets. |\n| `shuffle_train` | No | `True` | Whether to shuffle the shards of the training dataset. |\n| `resample_train` | No | `False` | If true, shards will be randomly sampled with replacement from the datasets making the epoch length infinite if a limit is not set. Cannot be enabled if `shuffle_train` is enabled. |\n| `preprocessing` | No | `{ \"ToTensor\": True }` | Defines preprocessing applied to images from the datasets. |\n[^1]: If your shard files have the paths `protocol://path/to/shard/00104.tar`, then the base url would be `protocol://path/to/shard/{}.tar`. If you are using a protocol like `s3`, you need to pipe the tars. For example `pipe:s3cmd get s3://bucket/path/{}.tar -`.\n[^2]: This refers to the string length of the shard number for your webdataset shards. For instance, if your webdataset shard has the filename `00104.tar`, your shard length is 5.",
+ "type": "code",
+ "location": "/configs/README.md:51-58"
+ },
+ "105": {
+ "file_id": 3,
+ "content": "This code defines the proportion of shards allocated to training, validation, and testing datasets as well as whether to shuffle training dataset, preprocessing applied to images from datasets, and details for downloading shard files. It also provides information on how to use protocols like `s3` and calculating the shard length based on filename.",
+ "type": "comment"
+ },
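The three footnotes are easiest to read with a small worked example; the URL and filenames below are hypothetical and only show how `webdataset_base_url`, `shard_width`, and `index_width` relate to concrete shard and sample names.

```python
# Hypothetical shard layout illustrating the footnotes above.
webdataset_base_url = "https://example.org/shards/{}.tar"  # "{}" stands in for the shard number
shard_width = 5   # shard files look like 00104.tar, so the shard number is 5 characters wide
index_width = 4   # keys inside a shard look like 001045945.jpg: the last 4 characters are the index

shard_number = 104
shard_url = webdataset_base_url.format(str(shard_number).zfill(shard_width))
# -> "https://example.org/shards/00104.tar"

sample_key = "001045945"
shard_part, index_part = sample_key[:-index_width], sample_key[-index_width:]
# -> shard "00104", sample index "5945"
```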
+ "106": {
+ "file_id": 3,
+ "content": "[^3]: Inside the webdataset `tar`, you have files named something like `001045945.jpg`. 5 of these characters refer to the shard, and 4 refer to the index of the file in the webdataset (shard is `001041` and index is `5945`). The `index_width` in this case is 4.\n**Train:**\nSettings for controlling the training hyperparameters.\n| Option | Required | Default | Description |\n| ------ | -------- | ------- | ----------- |\n| `epochs` | No | `20` | The number of epochs in the training run. |\n| `lr` | No | `1e-4` | The learning rate. |\n| `wd` | No | `0.01` | The weight decay. |\n| `max_grad_norm`| No | `0.5` | The grad norm clipping. |\n| `save_every_n_samples` | No | `100000` | Samples will be generated and a checkpoint will be saved every `save_every_n_samples` samples. |\n| `cond_scale` | No | `1.0` | Conditioning scale to use for sampling. Can also be an array of values, one for each unet. |\n| `device` | No | `cuda:0` | The device to train on. |\n| `epoch_samples` | No | `None` | Limits the num",
+ "type": "code",
+ "location": "/configs/README.md:60-74"
+ },
+ "107": {
+ "file_id": 3,
+ "content": "The code provides settings for controlling training hyperparameters, such as the number of epochs, learning rate, weight decay, and grad norm clipping. It also allows saving checkpoints at specific intervals and specifying the device to train on. The conditioning scale can be customized for each unet if desired.",
+ "type": "comment"
+ },
+ "108": {
+ "file_id": 3,
+ "content": "ber of samples iterated through in each epoch. This must be set if resampling. None means no limit. |\n| `validation_samples` | No | `None` | The number of samples to use for validation. None mean the entire validation set. |\n| `use_ema` | No | `True` | Whether to use exponential moving average models for sampling. |\n| `ema_beta` | No | `0.99` | The ema coefficient. |\n| `unet_training_mask` | No | `None` | A boolean array of the same length as the number of unets. If false, the unet is frozen. A value of `None` trains all unets. |\n**Evaluate:**\nDefines which evaluation metrics will be used to test the model.\nEach metric can be enabled by setting its configuration. The configuration keys for each metric are defined by the torchmetrics constructors which will be linked.\n| Option | Required | Default | Description |\n| ------ | -------- | ------- | ----------- |\n| `n_evaluation_samples` | No | `1000` | The number of samples to generate to test the model. |\n| `FID` | No | `None` | Setting to",
+ "type": "code",
+ "location": "/configs/README.md:74-87"
+ },
+ "109": {
+ "file_id": 3,
+ "content": "The code snippet defines configurations for training a DALLE2 model in PyTorch. It includes settings such as the number of samples iterated through in each epoch, number of validation samples, whether to use exponential moving average models for sampling, and the ema coefficient. Additionally, it allows defining which evaluation metrics will be used to test the model by setting their configurations using torchmetrics constructors. The number of samples generated to test the model is also specified.",
+ "type": "comment"
+ },
+ "110": {
+ "file_id": 3,
+ "content": " an object enables the [Frechet Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html) metric. \n| `IS` | No | `None` | Setting to an object enables the [Inception Score](https://torchmetrics.readthedocs.io/en/stable/image/inception_score.html) metric.\n| `KID` | No | `None` | Setting to an object enables the [Kernel Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/kernel_inception_distance.html) metric. |\n| `LPIPS` | No | `None` | Setting to an object enables the [Learned Perceptual Image Patch Similarity](https://torchmetrics.readthedocs.io/en/stable/image/learned_perceptual_image_patch_similarity.html) metric. |\n**Tracker:**\nSelects how the experiment will be tracked.\n| Option | Required | Default | Description |\n| ------ | -------- | ------- | ----------- |\n| `data_path` | No | `./.tracker-data` | The path to the folder where temporary tracker data will be saved. |\n| `overwrite_data_path` | No | `False` | If true, the data path will be overwritten. Otherwise, you need to delete it yourself. |",
+ "type": "code",
+ "location": "/configs/README.md:87-98"
+ },
+ "111": {
+ "file_id": 3,
+ "content": "This code snippet is from the configs/README.md file of the DALLE2-pytorch project. It describes how to enable different image metrics and set up experiment tracking. The available metrics are Frechet Inception Distance, Inception Score, Kernel Inception Distance, and Learned Perceptual Image Patch Similarity. The tracker can be configured with data_path and overwrite_data_path options for storing temporary tracking data.",
+ "type": "comment"
+ },
+ "112": {
+ "file_id": 3,
+ "content": "| `log` | Yes | N/A | Logging configuration. |\n| `load` | No | `None` | Checkpoint loading configuration. |\n| `save` | Yes | N/A | Checkpoint/Model saving configuration. |\nTracking is split up into three sections:\n* Log: Where to save run metadata and image output. Options are `console` or `wandb`.\n* Load: Where to load a checkpoint from. Options are `local`, `url`, or `wandb`.\n* Save: Where to save a checkpoint to. Options are `local`, `huggingface`, or `wandb`.\n**Logging:**\nAll loggers have the following keys:\n| Option | Required | Default | Description |\n| ------ | -------- | ------- | ----------- |\n| `log_type` | Yes | N/A | The type of logger class to use. |\n| `resume` | No | `False` | For loggers that have the option to resume an old run, resume it using maually input parameters. |\n| `auto_resume` | No | `False` | If true, the logger will attempt to resume an old run using parameters from that previous run. |\nIf using `console` there is no further configuration than setting `log_type` to `console`.",
+ "type": "code",
+ "location": "/configs/README.md:99-116"
+ },
+ "113": {
+ "file_id": 3,
+ "content": "The code defines configuration settings for logging, loading checkpoints, and saving checkpoints in a DALLE2-pytorch application. The logging section allows specifying where to save run metadata and image output (options: console or wandb). Loading can be from local, URL, or Wandb sources. Saving can be done locally, on HuggingFace, or via Wandb. Loggers have options for resume and auto-resume functions. If using console logging, only the log_type needs to be set as console.",
+ "type": "comment"
+ },
+ "114": {
+ "file_id": 3,
+ "content": "| Option | Required | Default | Description |\n| ------ | -------- | ------- | ----------- |\n| `log_type` | Yes | N/A | Must be `console`. |\nIf using `wandb`\n| Option | Required | Default | Description |\n| ------ | -------- | ------- | ----------- |\n| `log_type` | Yes | N/A | Must be `wandb`. |\n| `wandb_entity` | Yes | N/A | The wandb entity to log to. |\n| `wandb_project` | Yes | N/A | The wandb project save the run to. |\n| `wandb_run_name` | No | `None` | The wandb run name. |\n| `wandb_run_id` | No | `None` | The wandb run id. Used if resuming an old run. |\n**Loading:**\nAll loaders have the following keys:\n| Option | Required | Default | Description |\n| ------ | -------- | ------- | ----------- |\n| `load_from` | Yes | N/A | The type of loader class to use. |\n| `only_auto_resume` | No | `False` | If true, the loader will only load the model if the run is being auto resumed. |\nIf using `local`\n| Option | Required | Default | Description |\n| ------ | -------- | ------- | ----------- |\n| `load_from` | Yes | N/A | Must be `local`. |",
+ "type": "code",
+ "location": "/configs/README.md:117-141"
+ },
+ "115": {
+ "file_id": 3,
+ "content": "This code is defining the configuration options for logging and loading in a DALLE2-pytorch application. The user has to specify the log type (console or wandb) along with other required and optional parameters depending on the selected logger. The loaders have options to specify the loader class type (e.g., local) and whether to only auto resume if the run is being resumed.",
+ "type": "comment"
+ },
+ "116": {
+ "file_id": 3,
+ "content": "| `file_path` | Yes | N/A | The path to the checkpoint file. |\nIf using `url`\n| Option | Required | Default | Description |\n| ------ | -------- | ------- | ----------- |\n| `load_from` | Yes | N/A | Must be `url`. |\n| `url` | Yes | N/A | The url of the checkpoint file. |\nIf using `wandb`\n| Option | Required | Default | Description |\n| ------ | -------- | ------- | ----------- |\n| `load_from` | Yes | N/A | Must be `wandb`. |\n| `wandb_run_path` | No | `None` | The wandb run path. If `None`, uses the run that is being resumed. |\n| `wandb_file_path` | Yes | N/A | The path to the checkpoint file in the W&B file system. |\n**Saving:**\nUnlike `log` and `load`, `save` may be an array of options so that you can save to different locations in a run.\nAll save locations have these configuration options\n| Option | Required | Default | Description |\n| ------ | -------- | ------- | ----------- |\n| `save_to` | Yes | N/A | Must be `local`, `huggingface`, or `wandb`. |\n| `save_latest_to` | No | `None` | Sets the relative path to save the latest model to. |",
+ "type": "code",
+ "location": "/configs/README.md:142-164"
+ },
+ "117": {
+ "file_id": 3,
+ "content": "The code defines the options for loading and saving checkpoint files. It supports loading from a file path, URL or WandB run, with each option having specific required configurations. Saving to different locations is also supported through options like local, huggingface, or wandb, with additional configuration possibilities.",
+ "type": "comment"
+ },
+ "118": {
+ "file_id": 3,
+ "content": "| `save_best_to` | No | `None` | Sets the relative path to save the best model to every time the model has a lower validation loss than all previous models. |\n| `save_meta_to` | No | `None` | The path to save metadata files in. This includes the config files used to start the training. |\n| `save_type` | No | `checkpoint` | The type of save. `checkpoint` saves a checkpoint, `model` saves a model without any fluff (Saves with ema if ema is enabled). |\nIf using `local`\n| Option | Required | Default | Description |\n| ------ | -------- | ------- | ----------- |\n| `save_to` | Yes | N/A | Must be `local`. |\nIf using `huggingface`\n| Option | Required | Default | Description |\n| ------ | -------- | ------- | ----------- |\n| `save_to` | Yes | N/A | Must be `huggingface`. |\n| `huggingface_repo` | Yes | N/A | The huggingface repository to save to. |\n| `token_path` | No | `None` | If logging in with the huggingface cli is not possible, point to a token file instead. |\nIf using `wandb`\n| Option | Required | Default | Description |",
+ "type": "code",
+ "location": "/configs/README.md:165-182"
+ },
+ "119": {
+ "file_id": 3,
+ "content": "This code sets options for saving models and metadata during training. It allows saving to local, huggingface or wandb storage with specific requirements for each option. The save type can be checkpoint or model, and there are additional options like saving best models, token file path, and repository paths.",
+ "type": "comment"
+ },
+ "120": {
+ "file_id": 3,
+ "content": "| ------ | -------- | ------- | ----------- |\n| `save_to` | Yes | N/A | Must be `wandb`. |\n| `wandb_run_path` | No | `None` | The wandb run path. If `None`, uses the current run. You will almost always want this to be `None`. |",
+ "type": "code",
+ "location": "/configs/README.md:183-185"
+ },
+ "121": {
+ "file_id": 3,
+ "content": "The code defines configuration options for saving and interacting with the Weights & Biases (Wandb) run path. If `save_to` is set to `wandb`, the `wandb_run_path` should be `None`. Otherwise, it defaults to the current run if `wandb_run_path` is set to `None`.",
+ "type": "comment"
+ },
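Pulling the tracker tables together, a `tracker` section might look like the sketch below, again written as a Python dict mirroring the JSON layout; the entity, project, repository, and file paths are placeholders, not values from the example config.

```python
# Illustrative "tracker" section; every name and path here is a placeholder.
tracker_config = {
    "data_path": "./.tracker-data",
    "overwrite_data_path": True,
    "log": {
        "log_type": "wandb",            # or "console", which needs no further options
        "wandb_entity": "my-entity",
        "wandb_project": "dalle2-decoder",
        "auto_resume": True,
    },
    "load": {
        "load_from": "local",
        "file_path": "./checkpoints/latest.pth",
        "only_auto_resume": True,       # only load when a run is being auto-resumed
    },
    "save": [                           # "save" may list several destinations
        {
            "save_to": "local",
            "save_latest_to": "./checkpoints/latest.pth",
            "save_best_to": "./checkpoints/best.pth",
            "save_type": "checkpoint",
        },
        {
            "save_to": "huggingface",
            "huggingface_repo": "my-org/dalle2-decoder",
            "save_latest_to": "latest.pth",
            "save_type": "model",       # saves the bare model (with EMA if enabled)
        },
    ],
}
```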
+ "122": {
+ "file_id": 4,
+ "content": "/dalle2_pytorch/__init__.py",
+ "type": "filepath"
+ },
+ "123": {
+ "file_id": 4,
+ "content": "This code is importing modules from the DALLE2-pytorch library, which includes the main DALLE2 class, diffusion prior network, unet, decoder, clip adapters, trainer for the decoder and diffusion prior, and VQGanVAE. The x_clip module also appears to be imported, but its purpose is not explicitly described in this chunk of code.",
+ "type": "summary"
+ },
+ "124": {
+ "file_id": 4,
+ "content": "from dalle2_pytorch.version import __version__\nfrom dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder\nfrom dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter, OpenClipAdapter\nfrom dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer\nfrom dalle2_pytorch.vqgan_vae import VQGanVAE\nfrom x_clip import CLIP",
+ "type": "code",
+ "location": "/dalle2_pytorch/__init__.py:1-7"
+ },
+ "125": {
+ "file_id": 4,
+ "content": "This code is importing modules from the DALLE2-pytorch library, which includes the main DALLE2 class, diffusion prior network, unet, decoder, clip adapters, trainer for the decoder and diffusion prior, and VQGanVAE. The x_clip module also appears to be imported, but its purpose is not explicitly described in this chunk of code.",
+ "type": "comment"
+ },
+ "126": {
+ "file_id": 5,
+ "content": "/dalle2_pytorch/cli.py",
+ "type": "filepath"
+ },
+ "127": {
+ "file_id": 5,
+ "content": "This code imports libraries, defines functions, and parses command-line arguments for model path, conditioning scale, and input text. It loads a DALL-E2 model, generates an image based on the input text, saves it in PIL format, and returns the saved image.",
+ "type": "summary"
+ },
+ "128": {
+ "file_id": 5,
+ "content": "import click\nimport torch\nimport torchvision.transforms as T\nfrom functools import reduce\nfrom pathlib import Path\nfrom dalle2_pytorch import DALLE2, Decoder, DiffusionPrior\ndef safeget(dictionary, keys, default = None):\n return reduce(lambda d, key: d.get(key, default) if isinstance(d, dict) else default, keys.split('.'), dictionary)\ndef simple_slugify(text, max_length = 255):\n return text.replace(\"-\", \"_\").replace(\",\", \"\").replace(\" \", \"_\").replace(\"|\", \"--\").strip('-_')[:max_length]\ndef get_pkg_version():\n from pkg_resources import get_distribution\n return get_distribution('dalle2_pytorch').version\ndef main():\n pass\n@click.command()\n@click.option('--model', default = './dalle2.pt', help = 'path to trained DALL-E2 model')\n@click.option('--cond_scale', default = 2, help = 'conditioning scale (classifier free guidance) in decoder')\n@click.argument('text')\ndef dream(\n model,\n cond_scale,\n text\n):\n model_path = Path(model)\n full_model_path = str(model_path.resolve())\n assert model_path.exists(), f'model not found at {full_model_path}'",
+ "type": "code",
+ "location": "/dalle2_pytorch/cli.py:1-33"
+ },
+ "129": {
+ "file_id": 5,
+ "content": "This code imports necessary libraries, defines some utility functions and a main function. It also includes a command-line argument parser with options for model path, conditioning scale, and the text input. The assert statement ensures that the specified model file exists before proceeding.",
+ "type": "comment"
+ },
+ "130": {
+ "file_id": 5,
+ "content": " loaded = torch.load(str(model_path))\n version = safeget(loaded, 'version')\n print(f'loading DALL-E2 from {full_model_path}, saved at version {version} - current package version is {get_pkg_version()}')\n prior_init_params = safeget(loaded, 'init_params.prior')\n decoder_init_params = safeget(loaded, 'init_params.decoder')\n model_params = safeget(loaded, 'model_params')\n prior = DiffusionPrior(**prior_init_params)\n decoder = Decoder(**decoder_init_params)\n dalle2 = DALLE2(prior, decoder)\n dalle2.load_state_dict(model_params)\n image = dalle2(text, cond_scale = cond_scale)\n pil_image = T.ToPILImage()(image)\n return pil_image.save(f'./{simple_slugify(text)}.png')",
+ "type": "code",
+ "location": "/dalle2_pytorch/cli.py:34-52"
+ },
+ "131": {
+ "file_id": 5,
+ "content": "This code loads a saved DALL-E2 model from a specified path, checks the version, initializes the prior and decoder components, recreates the model using these components, loads its parameters, generates an image based on input text, converts it to PIL format, saves it with a file name derived from the input text, and returns the saved image.",
+ "type": "comment"
+ },
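As a usage sketch, the `dream` command above can be exercised programmatically through click's test runner; the checkpoint path and prompt are placeholders, and this assumes a trained DALL-E 2 checkpoint actually exists at that path.

```python
# Hypothetical invocation of the `dream` click command via click's test runner.
from click.testing import CliRunner
from dalle2_pytorch.cli import dream

runner = CliRunner()
result = runner.invoke(dream, [
    '--model', './dalle2.pt',          # placeholder path to a trained checkpoint
    '--cond_scale', '2',               # classifier-free guidance scale used by the decoder
    'a corgi wearing a top hat',       # the positional text prompt
])
print(result.exit_code)                # 0 on success; the image is written as a slugified .png
```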
+ "132": {
+ "file_id": 6,
+ "content": "/dalle2_pytorch/dalle2_pytorch.py",
+ "type": "filepath"
+ },
+ "133": {
+ "file_id": 6,
+ "content": "The code uses VQGAN-VAE, CLIP, and CoCa libraries for image generation, and includes helper functions, PyTorch CLIP model, neural networks, DALL-E 2 architecture, self-attention layers with normalization and dropout regularization. It initializes efficient DALL-E 2 and Imagen models, utilizes diffusion models for denoising and inpainting images, and incorporates conditional sampling from DALLE2-pytorch model for low-resolution image generation.",
+ "type": "summary"
+ },
+ "134": {
+ "file_id": 6,
+ "content": "import math\nimport random\nfrom tqdm.auto import tqdm\nfrom functools import partial, wraps\nfrom contextlib import contextmanager\nfrom collections import namedtuple\nfrom pathlib import Path\nimport torch\nimport torch.nn.functional as F\nfrom torch.utils.checkpoint import checkpoint\nfrom torch import nn, einsum\nimport torchvision.transforms as T\nfrom einops import rearrange, repeat, reduce, pack, unpack\nfrom einops.layers.torch import Rearrange\nfrom kornia.filters import gaussian_blur2d\nimport kornia.augmentation as K\nfrom dalle2_pytorch.tokenizer import tokenizer\nfrom dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE\nfrom resize_right import resize\n# rotary embeddings\nfrom rotary_embedding_torch import RotaryEmbedding\n# use x-clip\nfrom x_clip import CLIP\nfrom coca_pytorch import CoCa\n# constants\nNAT = 1. / math.log(2.)\nUnetOutput = namedtuple('UnetOutput', ['pred', 'var_interp_frac_unnormalized'])\n# helper functions\ndef exists(val):\n return val is not None\ndef identity(t, *args, **kwargs):\n return t\ndef first(arr, d = None):",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1-49"
+ },
+ "135": {
+ "file_id": 6,
+ "content": "This code imports various libraries and defines functions for data processing, including image resizing, Gaussian blurring, and rotary embeddings. It also utilizes the VQGAN-VAE, CLIP model, and CoCa. The code contains namedtuples, helper functions, and constants relevant to the tasks of image generation and language modeling.",
+ "type": "comment"
+ },
+ "136": {
+ "file_id": 6,
+ "content": " if len(arr) == 0:\n return d\n return arr[0]\ndef maybe(fn):\n @wraps(fn)\n def inner(x, *args, **kwargs):\n if not exists(x):\n return x\n return fn(x, *args, **kwargs)\n return inner\ndef default(val, d):\n if exists(val):\n return val\n return d() if callable(d) else d\ndef cast_tuple(val, length = None, validate = True):\n if isinstance(val, list):\n val = tuple(val)\n out = val if isinstance(val, tuple) else ((val,) * default(length, 1))\n if exists(length) and validate:\n assert len(out) == length\n return out\ndef module_device(module):\n if isinstance(module, nn.Identity):\n return 'cpu' # It doesn't matter\n return next(module.parameters()).device\ndef zero_init_(m):\n nn.init.zeros_(m.weight)\n if exists(m.bias):\n nn.init.zeros_(m.bias)\n@contextmanager\ndef null_context(*args, **kwargs):\n yield\ndef eval_decorator(fn):\n def inner(model, *args, **kwargs):\n was_training = model.training\n model.eval()\n out = fn(model, *args, **kwargs)",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:50-96"
+ },
+ "137": {
+ "file_id": 6,
+ "content": "Function 'if len(arr) == 0: return d' checks if the array is empty and returns the value 'd' if it is.\n'maybe(fn)' function creates a decorator that checks if the input exists, returning it if it does not.\n'default(val, d)' function returns the provided value 'val' if it exists; otherwise, it returns the default value 'd'.\n'cast_tuple(val, length=None, validate=True)' casts its argument to a tuple and optionally checks its length.\n'module_device(module)' retrieves the device of the module, defaulting to CPU for certain types like nn.Identity.\n'zero_init_(m)' initializes the weights and biases of the given module 'm' with zeros.\n'null_context(*args, **kwargs)' is a context manager that does nothing.\n'eval_decorator(fn)' wraps a function to evaluate the model before executing it.",
+ "type": "comment"
+ },
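A tiny self-contained sketch of how a few of these helpers behave; the functions are re-implemented here (copied in simplified form from the chunk above) so the snippet runs on its own.

```python
# Simplified copies of helpers from the chunk above, for illustration only.
def exists(val):
    return val is not None

def default(val, d):
    # return val if it exists, otherwise the default (called if callable)
    if exists(val):
        return val
    return d() if callable(d) else d

def cast_tuple(val, length = None, validate = True):
    # cast lists/scalars to tuples, optionally checking the length
    if isinstance(val, list):
        val = tuple(val)
    out = val if isinstance(val, tuple) else ((val,) * default(length, 1))
    if exists(length) and validate:
        assert len(out) == length
    return out

print(default(None, 3))                 # 3   (falls back to the default)
print(default(5, lambda: 3))            # 5   (an existing value wins; the callable is never used)
print(cast_tuple(2, length = 4))        # (2, 2, 2, 2)
print(cast_tuple([1, 2], length = 2))   # (1, 2)
```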
+ "138": {
+ "file_id": 6,
+ "content": " model.train(was_training)\n return out\n return inner\ndef is_float_dtype(dtype):\n return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])\ndef is_list_str(x):\n if not isinstance(x, (list, tuple)):\n return False\n return all([type(el) == str for el in x])\ndef pad_tuple_to_length(t, length, fillvalue = None):\n remain_length = length - len(t)\n if remain_length <= 0:\n return t\n return (*t, *((fillvalue,) * remain_length))\n# checkpointing helper function\ndef make_checkpointable(fn, **kwargs):\n if isinstance(fn, nn.ModuleList):\n return [maybe(make_checkpointable)(el, **kwargs) for el in fn]\n condition = kwargs.pop('condition', None)\n if exists(condition) and not condition(fn):\n return fn\n @wraps(fn)\n def inner(*args):\n input_needs_grad = any([isinstance(el, torch.Tensor) and el.requires_grad for el in args])\n if not input_needs_grad:\n return fn(*args)\n return checkpoint(fn, *args)",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:97-133"
+ },
+ "139": {
+ "file_id": 6,
+ "content": "This code defines several helper functions for processing lists of strings, padding tuples to a specific length, and creating checkpointable versions of Python functions. It also includes a function to determine if a given dtype is a floating point type, and a conditional wrapper for creating a checkpointable version of a function or module list.",
+ "type": "comment"
+ },
+ "140": {
+ "file_id": 6,
+ "content": " return inner\n# for controlling freezing of CLIP\ndef set_module_requires_grad_(module, requires_grad):\n for param in module.parameters():\n param.requires_grad = requires_grad\ndef freeze_all_layers_(module):\n set_module_requires_grad_(module, False)\ndef unfreeze_all_layers_(module):\n set_module_requires_grad_(module, True)\ndef freeze_model_and_make_eval_(model):\n model.eval()\n freeze_all_layers_(model)\n# tensor helpers\ndef log(t, eps = 1e-12):\n return torch.log(t.clamp(min = eps))\ndef l2norm(t):\n return F.normalize(t, dim = -1)\ndef resize_image_to(\n image,\n target_image_size,\n clamp_range = None,\n nearest = False,\n **kwargs\n):\n orig_image_size = image.shape[-1]\n if orig_image_size == target_image_size:\n return image\n if not nearest:\n scale_factors = target_image_size / orig_image_size\n out = resize(image, scale_factors = scale_factors, **kwargs)\n else:\n out = F.interpolate(image, target_image_size, mode = 'nearest')\n if exists(clamp_range):",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:135-179"
+ },
+ "141": {
+ "file_id": 6,
+ "content": "The code defines functions for controlling the gradient flow in a module, freezing all layers in a model, and making it evaluate only. It also includes helper functions to log a tensor, normalize a tensor using L2 norm, and resize an image to the specified size with optional interpolation method.",
+ "type": "comment"
+ },
+ "142": {
+ "file_id": 6,
+ "content": " out = out.clamp(*clamp_range)\n return out\n# image normalization functions\n# ddpms expect images to be in the range of -1 to 1\n# but CLIP may otherwise\ndef normalize_neg_one_to_one(img):\n return img * 2 - 1\ndef unnormalize_zero_to_one(normed_img):\n return (normed_img + 1) * 0.5\n# clip related adapters\nEmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings'])\nEmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings'])\nclass BaseClipAdapter(nn.Module):\n def __init__(self, clip, **kwargs):\n super().__init__()\n self.clip = clip\n self.overrides = kwargs\n def validate_and_resize_image(self, image):\n image_size = image.shape[-1]\n assert image_size >= self.image_size, f'you are passing in an image of size {image_size} but CLIP requires the image size to be at least {self.image_size}'\n return resize_image_to(image, self.image_size)\n @property\n def dim_latent(self):\n raise NotImplementedError\n @property",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:180-214"
+ },
+ "143": {
+ "file_id": 6,
+ "content": "This code defines a function for normalizing an image to the range of -1 to 1, and another for unnormalizing it back to the 0 to 1 range. It also includes a namedtuple for returning embedded text and image data along with their encodings. The code further defines a base class for clip adapters that takes a CLIP model as an argument and provides methods for validating and resizing images to match CLIP's requirements.",
+ "type": "comment"
+ },
+ "144": {
+ "file_id": 6,
+ "content": " def image_size(self):\n raise NotImplementedError\n @property\n def image_channels(self):\n raise NotImplementedError\n @property\n def max_text_len(self):\n raise NotImplementedError\n def embed_text(self, text):\n raise NotImplementedError\n def embed_image(self, image):\n raise NotImplementedError\nclass XClipAdapter(BaseClipAdapter):\n @property\n def dim_latent(self):\n return self.clip.dim_latent\n @property\n def image_size(self):\n return self.clip.image_size\n @property\n def image_channels(self):\n return self.clip.image_channels\n @property\n def max_text_len(self):\n return self.clip.text_seq_len\n @torch.no_grad()\n def embed_text(self, text):\n text = text[..., :self.max_text_len]\n text_mask = text != 0\n encoder_output = self.clip.text_transformer(text)\n encoder_output_is_cls = encoder_output.ndim == 3\n text_cls, text_encodings = (encoder_output[:, 0], encoder_output[:, 1:]) if encoder_output_is_cls else (encoder_output, None)",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:215-257"
+ },
+ "145": {
+ "file_id": 6,
+ "content": "This code defines a base class `BaseClipAdapter` with four methods that must be implemented by derived classes. The `XClipAdapter` class inherits from `BaseClipAdapter` and provides implementations for the properties of the underlying `clip` object, which is an instance of some clip model. The `embed_text` method takes a text input, truncates it to fit the maximum text length defined by `max_text_len`, applies a text transformer from the `clip` object, and returns the embeddings.",
+ "type": "comment"
+ },
+ "146": {
+ "file_id": 6,
+ "content": " text_embed = self.clip.to_text_latent(text_cls)\n if exists(text_encodings):\n text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)\n return EmbeddedText(l2norm(text_embed), text_encodings)\n @torch.no_grad()\n def embed_image(self, image):\n image = self.validate_and_resize_image(image)\n encoder_output = self.clip.visual_transformer(image)\n image_cls, image_encodings = encoder_output[:, 0], encoder_output[:, 1:]\n image_embed = self.clip.to_visual_latent(image_cls)\n return EmbeddedImage(l2norm(image_embed), image_encodings)\nclass CoCaAdapter(BaseClipAdapter):\n @property\n def dim_latent(self):\n return self.clip.dim\n @property\n def image_size(self):\n assert 'image_size' in self.overrides\n return self.overrides['image_size']\n @property\n def image_channels(self):\n assert 'image_channels' in self.overrides\n return self.overrides['image_channels']\n @property\n def max_text_len(self):",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:258-289"
+ },
+ "147": {
+ "file_id": 6,
+ "content": "This code snippet defines a class called CoCaAdapter, which is a base adapter for the DALL-E 2 PyTorch model. It contains methods to embed text and images, with optional overrides for image size and channels. The dim_latent property returns the dimension of the latent space, while max_text_len is used to set the maximum length for text inputs.",
+ "type": "comment"
+ },
+ "148": {
+ "file_id": 6,
+ "content": " assert 'max_text_len' in self.overrides\n return self.overrides['max_text_len']\n @torch.no_grad()\n def embed_text(self, text):\n text = text[..., :self.max_text_len]\n text_mask = text != 0\n text_embed, text_encodings = self.clip.embed_text(text)\n text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)\n return EmbeddedText(text_embed, text_encodings)\n @torch.no_grad()\n def embed_image(self, image):\n image = self.validate_and_resize_image(image)\n image_embed, image_encodings = self.clip.embed_image(image)\n return EmbeddedImage(image_embed, image_encodings)\nclass OpenAIClipAdapter(BaseClipAdapter):\n def __init__(\n self,\n name = 'ViT-B/32'\n ):\n import clip\n openai_clip, preprocess = clip.load(name)\n super().__init__(openai_clip)\n self.eos_id = 49407 # for handling 0 being also '!'\n text_attention_final = self.find_layer('ln_final')\n self.dim_latent_ = text_attention_final.weight.shape[0]",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:290-319"
+ },
+ "149": {
+ "file_id": 6,
+ "content": "This code is for a text-to-image model that uses CLIP as its base. It has functions to embed texts and images, with the ability to handle maximum text length. It initializes an OpenAIClipAdapter class using CLIP's 'ViT-B/32' model and finds the layer for text attention final output.",
+ "type": "comment"
+ },
+ "150": {
+ "file_id": 6,
+ "content": " self.handle = text_attention_final.register_forward_hook(self._hook)\n self.clip_normalize = preprocess.transforms[-1]\n self.cleared = False\n def find_layer(self, layer):\n modules = dict([*self.clip.named_modules()])\n return modules.get(layer, None)\n def clear(self):\n if self.cleared:\n return\n self.handle()\n def _hook(self, _, inputs, outputs):\n self.text_encodings = outputs\n @property\n def dim_latent(self):\n return self.dim_latent_\n @property\n def image_size(self):\n return self.clip.visual.input_resolution\n @property\n def image_channels(self):\n return 3\n @property\n def max_text_len(self):\n return self.clip.context_length\n @torch.no_grad()\n def embed_text(self, text):\n text = text[..., :self.max_text_len]\n is_eos_id = (text == self.eos_id)\n text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0\n text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:320-360"
+ },
+ "151": {
+ "file_id": 6,
+ "content": "This code is part of a neural network model for text-to-image generation using PyTorch. It includes functions to handle text attention, clear the internal state, and embed input text. The class has properties such as `dim_latent`, `image_size`, `image_channels`, `max_text_len` which are used to define the network's structure and behavior.",
+ "type": "comment"
+ },
+ "152": {
+ "file_id": 6,
+ "content": " text_mask = text_mask & (text != 0)\n assert not self.cleared\n text_embed = self.clip.encode_text(text)\n text_encodings = self.text_encodings\n text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)\n del self.text_encodings\n return EmbeddedText(l2norm(text_embed.float()), text_encodings.float())\n @torch.no_grad()\n def embed_image(self, image):\n assert not self.cleared\n image = self.validate_and_resize_image(image)\n image = self.clip_normalize(image)\n image_embed = self.clip.encode_image(image)\n return EmbeddedImage(l2norm(image_embed.float()), None)\nclass OpenClipAdapter(BaseClipAdapter):\n def __init__(\n self,\n name = 'ViT-B/32',\n pretrained = 'laion400m_e32'\n ):\n import open_clip\n clip, _, preprocess = open_clip.create_model_and_transforms(name, pretrained = pretrained)\n super().__init__(clip)\n self.eos_id = 49407\n text_attention_final = self.find_layer('ln_final')",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:361-390"
+ },
+ "153": {
+ "file_id": 6,
+ "content": "Method to embed text using CLIP model by encoding the input text, applying a mask on text encodings, and returning EmbeddedText object with L2 normalized text embedding and float text encodings.",
+ "type": "comment"
+ },
+ "154": {
+ "file_id": 6,
+ "content": " self._dim_latent = text_attention_final.weight.shape[0]\n self.handle = text_attention_final.register_forward_hook(self._hook)\n self.clip_normalize = preprocess.transforms[-1]\n self.cleared = False\n def find_layer(self, layer):\n modules = dict([*self.clip.named_modules()])\n return modules.get(layer, None)\n def clear(self):\n if self.cleared:\n return\n self.handle()\n def _hook(self, _, inputs, outputs):\n self.text_encodings = outputs\n @property\n def dim_latent(self):\n return self._dim_latent\n @property\n def image_size(self):\n image_size = self.clip.visual.image_size\n if isinstance(image_size, tuple):\n return max(image_size)\n return image_size\n @property\n def image_channels(self):\n return 3\n @property\n def max_text_len(self):\n return self.clip.context_length\n @torch.no_grad()\n def embed_text(self, text):\n text = text[..., :self.max_text_len]\n is_eos_id = (text == self.eos_id)",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:391-433"
+ },
+ "155": {
+ "file_id": 6,
+ "content": "The code represents a class that appears to be a part of a larger model. It has methods for embedding text, clearing internal state, finding layers in the network, and retrieving properties like latent dimension and maximum text length. The class relies on other components such as `preprocess`, `clip`, and `image_size`.",
+ "type": "comment"
+ },
+ "156": {
+ "file_id": 6,
+ "content": " text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0\n text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)\n text_mask = text_mask & (text != 0)\n assert not self.cleared\n text_embed = self.clip.encode_text(text)\n text_encodings = self.text_encodings\n text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)\n del self.text_encodings\n return EmbeddedText(l2norm(text_embed.float()), text_encodings.float())\n @torch.no_grad()\n def embed_image(self, image):\n assert not self.cleared\n image = self.validate_and_resize_image(image)\n image = self.clip_normalize(image)\n image_embed = self.clip.encode_image(image)\n return EmbeddedImage(l2norm(image_embed.float()), None)\n# classifier free guidance functions\ndef prob_mask_like(shape, prob, device):\n if prob == 1:\n return torch.ones(shape, device = device, dtype = torch.bool)\n elif prob == 0:\n return torch.zeros(shape, device = device, dtype = torch.bool)",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:434-459"
+ },
+ "157": {
+ "file_id": 6,
+ "content": "This function takes in a text input and returns an EmbeddedText object containing the embedded text representation and a corresponding mask. It first creates a mask excluding the end of sentence (EOS) token, pads it, and applies the mask to the original mask. Then, it encodes the text using CLIP's encode_text function, and finally normalizes the resulting embeddings. The classifier free guidance functions return a probability mask based on the given probability value for a specific shape and device.",
+ "type": "comment"
+ },
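For illustration, `prob_mask_like` (whose general case appears at the start of the next chunk) is copied below in simplified form; it returns a boolean mask in which each element is independently True with probability `prob`, the kind of mask typically used to randomly keep or drop conditioning for classifier-free guidance.

```python
import torch

# Simplified copy of prob_mask_like from the surrounding chunks, for illustration.
def prob_mask_like(shape, prob, device):
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)
    else:
        return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob

mask = prob_mask_like((8,), 0.75, device = 'cpu')
print(mask)   # e.g. tensor([ True,  True, False,  True, ...]) with roughly 75% True entries
```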
+ "158": {
+ "file_id": 6,
+ "content": " else:\n return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob\n# gaussian diffusion helper functions\ndef extract(a, t, x_shape):\n b, *_ = t.shape\n out = a.gather(-1, t)\n return out.reshape(b, *((1,) * (len(x_shape) - 1)))\ndef meanflat(x):\n return x.mean(dim = tuple(range(1, len(x.shape))))\ndef normal_kl(mean1, logvar1, mean2, logvar2):\n return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2))\ndef approx_standard_normal_cdf(x):\n return 0.5 * (1.0 + torch.tanh(((2.0 / math.pi) ** 0.5) * (x + 0.044715 * (x ** 3))))\ndef discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):\n assert x.shape == means.shape == log_scales.shape\n # attempting to correct nan gradients when learned variance is turned on\n # in the setting of deepspeed fp16\n eps = 1e-12 if x.dtype == torch.float32 else 1e-3\n centered_x = x - means\n inv_stdv = torch.exp(-log_scales)\n plus_in = inv_stdv * (centered_x + 1. / 255.)",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:460-488"
+ },
+ "159": {
+ "file_id": 6,
+ "content": "This code defines several helper functions used in the DALLE2-pytorch model. These functions are involved in tasks such as extracting values, calculating normal KL divergence, approximating the standard normal cumulative distribution function, and computing the discretized Gaussian log likelihood. The code also includes error handling for potential nan gradients when using deepspeed fp16.",
+ "type": "comment"
+ },
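Written out, the `normal_kl` helper in the chunk above is the closed-form KL divergence between two diagonal Gaussians, parameterized by means and log-variances exactly as in the code:

```latex
\mathrm{KL}\!\left(\mathcal{N}(\mu_1, \sigma_1^2)\,\middle\|\,\mathcal{N}(\mu_2, \sigma_2^2)\right)
  = \tfrac{1}{2}\left(-1 + \log\sigma_2^2 - \log\sigma_1^2
  + e^{\log\sigma_1^2 - \log\sigma_2^2}
  + (\mu_1 - \mu_2)^2\, e^{-\log\sigma_2^2}\right)
```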
+ "160": {
+ "file_id": 6,
+ "content": " cdf_plus = approx_standard_normal_cdf(plus_in)\n min_in = inv_stdv * (centered_x - 1. / 255.)\n cdf_min = approx_standard_normal_cdf(min_in)\n log_cdf_plus = log(cdf_plus, eps = eps)\n log_one_minus_cdf_min = log(1. - cdf_min, eps = eps)\n cdf_delta = cdf_plus - cdf_min\n log_probs = torch.where(x < -thres,\n log_cdf_plus,\n torch.where(x > thres,\n log_one_minus_cdf_min,\n log(cdf_delta, eps = eps)))\n return log_probs\ndef cosine_beta_schedule(timesteps, s = 0.008):\n \"\"\"\n cosine schedule\n as proposed in https://openreview.net/forum?id=-NEXDKk8gZ\n \"\"\"\n steps = timesteps + 1\n x = torch.linspace(0, timesteps, steps, dtype = torch.float64)\n alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2\n alphas_cumprod = alphas_cumprod / first(alphas_cumprod)\n betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])\n return torch.clip(betas, 0, 0.999)\ndef linear_beta_schedule(timesteps):\n scale = 1000 / timesteps\n beta_start = scale * 0.0001",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:489-519"
+ },
+ "161": {
+ "file_id": 6,
+ "content": "Function at line 488-518 calculates log probabilities for a given input x, using an adaptive quantile regression approach with a cosine or linear schedule. The cosine_beta_schedule function generates a sequence of beta values using a cosine schedule, and the linear_beta_schedule function generates a sequence of beta values linearly.",
+ "type": "comment"
+ },
+ "162": {
+ "file_id": 6,
+ "content": " beta_end = scale * 0.02\n return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)\ndef quadratic_beta_schedule(timesteps):\n scale = 1000 / timesteps\n beta_start = scale * 0.0001\n beta_end = scale * 0.02\n return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype = torch.float64) ** 2\ndef sigmoid_beta_schedule(timesteps):\n scale = 1000 / timesteps\n beta_start = scale * 0.0001\n beta_end = scale * 0.02\n betas = torch.linspace(-6, 6, timesteps, dtype = torch.float64)\n return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start\nclass NoiseScheduler(nn.Module):\n def __init__(self, *, beta_schedule, timesteps, loss_type, p2_loss_weight_gamma = 0., p2_loss_weight_k = 1):\n super().__init__()\n if beta_schedule == \"cosine\":\n betas = cosine_beta_schedule(timesteps)\n elif beta_schedule == \"linear\":\n betas = linear_beta_schedule(timesteps)\n elif beta_schedule == \"quadratic\":\n betas = quadratic_beta_schedule(timesteps)",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:520-548"
+ },
+ "163": {
+ "file_id": 6,
+ "content": "This code defines three beta scheduling functions (linear, quadratic, cosine) and a class for the NoiseScheduler. The scheduler initializes with a selected beta schedule and timesteps. The beta_schedule parameter determines which function to use for generating the betas, which represent noise scaling factors in the model's training process.",
+ "type": "comment"
+ },
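A standalone sketch of the cosine schedule defined above, re-implemented here so it runs without the package installed; it yields one beta per timestep, with the final values clipped at 0.999:

```python
import torch

# Re-implementation of cosine_beta_schedule from the chunk above, for illustration.
def cosine_beta_schedule(timesteps, s = 0.008):
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype = torch.float64)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)

betas = cosine_beta_schedule(1000)
print(betas.shape)                        # torch.Size([1000])
print(betas[0].item(), betas[-1].item())  # betas start near zero and end at the 0.999 clip
```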
+ "164": {
+ "file_id": 6,
+ "content": " elif beta_schedule == \"jsd\":\n betas = 1.0 / torch.linspace(timesteps, 1, timesteps)\n elif beta_schedule == \"sigmoid\":\n betas = sigmoid_beta_schedule(timesteps)\n else:\n raise NotImplementedError()\n alphas = 1. - betas\n alphas_cumprod = torch.cumprod(alphas, axis = 0)\n alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)\n timesteps, = betas.shape\n self.num_timesteps = int(timesteps)\n if loss_type == 'l1':\n loss_fn = F.l1_loss\n elif loss_type == 'l2':\n loss_fn = F.mse_loss\n elif loss_type == 'huber':\n loss_fn = F.smooth_l1_loss\n else:\n raise NotImplementedError()\n self.loss_type = loss_type\n self.loss_fn = loss_fn\n # register buffer helper function to cast double back to float\n register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))\n register_buffer('betas', betas)\n register_buffer('alphas_cumprod', alphas_cumprod)",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:549-580"
+ },
+ "165": {
+ "file_id": 6,
+ "content": "This code sets the beta schedule and alpha values based on user input, then selects a loss function according to the specified type. The code also registers buffer helper functions for 'betas' and 'alphas_cumprod'.",
+ "type": "comment"
+ },
+ "166": {
+ "file_id": 6,
+ "content": " register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)\n # calculations for diffusion q(x_t | x_{t-1}) and others\n register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))\n register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))\n register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))\n register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))\n register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))\n # calculations for posterior q(x_{t-1} | x_t, x_0)\n posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)\n # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)\n register_buffer('posterior_variance', posterior_variance)\n # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain\n register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:581-601"
+ },
+ "167": {
+ "file_id": 6,
+ "content": "The code is registering various buffers for computations related to diffusion. It calculates the posterior variance and clips the log of the posterior variance to avoid numerical instability at the beginning of the diffusion chain.",
+ "type": "comment"
+ },
+ "168": {
+ "file_id": 6,
+ "content": " register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))\n register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))\n # p2 loss reweighting\n self.has_p2_loss_reweighting = p2_loss_weight_gamma > 0.\n register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma)\n def sample_random_times(self, batch):\n return torch.randint(0, self.num_timesteps, (batch,), device = self.betas.device, dtype = torch.long)\n def q_posterior(self, x_start, x_t, t):\n posterior_mean = (\n extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +\n extract(self.posterior_mean_coef2, t, x_t.shape) * x_t\n )\n posterior_variance = extract(self.posterior_variance, t, x_t.shape)\n posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)\n return posterior_mean, posterior_variance, posterior_log_variance_clipped",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:602-620"
+ },
+ "169": {
+ "file_id": 6,
+ "content": "In this code segment, the author is computing posterior means for a model, performing loss reweighting, generating random times, and calculating posterior values. The posterior means are calculated based on betas and alphas, while the loss reweighting considers p2_loss_weight_gamma. Random times are sampled for a batch of inputs using torch.randint. The q_posterior function calculates posterior mean, variance, and log-variance clipped from these computed values.",
+ "type": "comment"
+ },
+ "170": {
+ "file_id": 6,
+ "content": " def q_sample(self, x_start, t, noise = None):\n noise = default(noise, lambda: torch.randn_like(x_start))\n return (\n extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +\n extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise\n )\n def calculate_v(self, x_start, t, noise = None):\n return (\n extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -\n extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start\n )\n def q_sample_from_to(self, x_from, from_t, to_t, noise = None):\n shape = x_from.shape\n noise = default(noise, lambda: torch.randn_like(x_from))\n alpha = extract(self.sqrt_alphas_cumprod, from_t, shape)\n sigma = extract(self.sqrt_one_minus_alphas_cumprod, from_t, shape)\n alpha_next = extract(self.sqrt_alphas_cumprod, to_t, shape)\n sigma_next = extract(self.sqrt_one_minus_alphas_cumprod, to_t, shape)\n return x_from * (alpha_next / alpha) + noise * (sigma_next * alpha - sigma * alpha_next) / alpha",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:622-645"
+ },
+ "171": {
+ "file_id": 6,
+ "content": "The code defines three functions: `q_sample`, `calculate_v`, and `q_sample_from_to`. These functions are part of a neural network for generating images. `q_sample` combines alpha and noise values to generate a sample, while `calculate_v` calculates the difference between an alpha-blended noise and a one minus alpha-blended image start. The `q_sample_from_to` function samples from one timestep to another by interpolating alphas and sigmas.",
+ "type": "comment"
+ },
+ "172": {
+ "file_id": 6,
+ "content": " def predict_start_from_v(self, x_t, t, v):\n return (\n extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -\n extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v\n )\n def predict_start_from_noise(self, x_t, t, noise):\n return (\n extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -\n extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise\n )\n def predict_noise_from_start(self, x_t, t, x0):\n return (\n (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \\\n extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)\n )\n def p2_reweigh_loss(self, loss, times):\n if not self.has_p2_loss_reweighting:\n return loss\n return loss * extract(self.p2_loss_weight, times, loss.shape)\n# rearrange image to sequence\nclass RearrangeToSequence(nn.Module):\n def __init__(self, fn):\n super().__init__()\n self.fn = fn\n def forward(self, x):",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:647-677"
+ },
+ "173": {
+ "file_id": 6,
+ "content": "The code defines three methods for predicting values from different inputs, including v and noise. It also includes a method to reweight loss using p2_loss_weight and a class to rearrange images into sequences.",
+ "type": "comment"
+ },
+ "174": {
+ "file_id": 6,
+ "content": " x = rearrange(x, 'b c ... -> b ... c')\n x, ps = pack([x], 'b * c')\n x = self.fn(x)\n x, = unpack(x, ps, 'b * c')\n x = rearrange(x, 'b ... c -> b c ...')\n return x\n# diffusion prior\nclass LayerNorm(nn.Module):\n def __init__(self, dim, eps = 1e-5, fp16_eps = 1e-3, stable = False):\n super().__init__()\n self.eps = eps\n self.fp16_eps = fp16_eps\n self.stable = stable\n self.g = nn.Parameter(torch.ones(dim))\n def forward(self, x):\n eps = self.eps if x.dtype == torch.float32 else self.fp16_eps\n if self.stable:\n x = x / x.amax(dim = -1, keepdim = True).detach()\n var = torch.var(x, dim = -1, unbiased = False, keepdim = True)\n mean = torch.mean(x, dim = -1, keepdim = True)\n return (x - mean) * (var + eps).rsqrt() * self.g\nclass ChanLayerNorm(nn.Module):\n def __init__(self, dim, eps = 1e-5, fp16_eps = 1e-3, stable = False):\n super().__init__()\n self.eps = eps\n self.fp16_eps = fp16_eps",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:678-711"
+ },
+ "175": {
+ "file_id": 6,
+ "content": "This function is applying layer normalization to input tensor 'x' and returning the normalized output. The 'LayerNorm' class is a type of layer normalization, while 'ChanLayerNorm' is a channel-wise version. The code includes settings for epsilon, float precision, and stability options.",
+ "type": "comment"
+ },
+ "176": {
+ "file_id": 6,
+ "content": " self.stable = stable\n self.g = nn.Parameter(torch.ones(1, dim, 1, 1))\n def forward(self, x):\n eps = self.eps if x.dtype == torch.float32 else self.fp16_eps\n if self.stable:\n x = x / x.amax(dim = 1, keepdim = True).detach()\n var = torch.var(x, dim = 1, unbiased = False, keepdim = True)\n mean = torch.mean(x, dim = 1, keepdim = True)\n return (x - mean) * (var + eps).rsqrt() * self.g\nclass Residual(nn.Module):\n def __init__(self, fn):\n super().__init__()\n self.fn = fn\n def forward(self, x, **kwargs):\n return self.fn(x, **kwargs) + x\n# mlp\nclass MLP(nn.Module):\n def __init__(\n self,\n dim_in,\n dim_out,\n *,\n expansion_factor = 2.,\n depth = 2,\n norm = False,\n ):\n super().__init__()\n hidden_dim = int(expansion_factor * dim_out)\n norm_fn = lambda: nn.LayerNorm(hidden_dim) if norm else nn.Identity()\n layers = [nn.Sequential(\n nn.Linear(dim_in, hidden_dim),",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:712-750"
+ },
+ "177": {
+ "file_id": 6,
+ "content": "This code defines a Residual class that wraps a function and adds it to the input. It also contains an MLP (Multi-Layer Perceptron) class with optional normalization and activation functions, followed by a series of fully connected layers. The forward method in DALLE2_PyTorch performs normalization, calculates mean and variance, then applies element-wise transformations before returning the output.",
+ "type": "comment"
+ },
+ "178": {
+ "file_id": 6,
+ "content": " nn.SiLU(),\n norm_fn()\n )]\n for _ in range(depth - 1):\n layers.append(nn.Sequential(\n nn.Linear(hidden_dim, hidden_dim),\n nn.SiLU(),\n norm_fn()\n ))\n layers.append(nn.Linear(hidden_dim, dim_out))\n self.net = nn.Sequential(*layers)\n def forward(self, x):\n return self.net(x.float())\n# relative positional bias for causal transformer\nclass RelPosBias(nn.Module):\n def __init__(\n self,\n heads = 8,\n num_buckets = 32,\n max_distance = 128,\n ):\n super().__init__()\n self.num_buckets = num_buckets\n self.max_distance = max_distance\n self.relative_attention_bias = nn.Embedding(num_buckets, heads)\n @staticmethod\n def _relative_position_bucket(\n relative_position,\n num_buckets = 32,\n max_distance = 128\n ):\n n = -relative_position\n n = torch.max(n, torch.zeros_like(n))\n max_exact = num_buckets // 2\n is_small = n < max_exact",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:751-792"
+ },
+ "179": {
+ "file_id": 6,
+ "content": "This code defines a neural network architecture for the DALL-E 2 model. It includes a sequential layer with multiple linear layers, SiLU activation function, and normalization. The forward method performs inference on input data. Another class is defined for relative positional bias in causal transformer. The RelPosBias class initializes an embedding layer to calculate the relative position between elements for attention mechanism. It uses the concept of buckets, where each bucket represents a range of distances between two elements, and computes the relative position bucket based on input data.",
+ "type": "comment"
+ },
+ "180": {
+ "file_id": 6,
+ "content": " val_if_large = max_exact + (torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)).long()\n val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))\n return torch.where(is_small, n, val_if_large)\n def forward(self, i, j, *, device):\n q_pos = torch.arange(i, dtype = torch.long, device = device)\n k_pos = torch.arange(j, dtype = torch.long, device = device)\n rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')\n rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)\n values = self.relative_attention_bias(rp_bucket)\n return rearrange(values, 'i j h -> h i j')\n# feedforward\nclass SwiGLU(nn.Module):\n \"\"\" used successfully in https://arxiv.org/abs/2204.0231 \"\"\"\n def forward(self, x):\n x, gate = x.chunk(2, dim = -1)\n return x * F.silu(gate)\ndef FeedForward(\n dim,\n mult = 4,",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:794-816"
+ },
+ "181": {
+ "file_id": 6,
+ "content": "This code snippet defines a class for DALLE2-pytorch, containing a method to calculate relative position buckets and an attention layer. The attention layer uses the SwiGLU activation function in its FeedForward module. The purpose of this code is to facilitate the calculation and application of positional embeddings in a transformer model.",
+ "type": "comment"
+ },
+ "182": {
+ "file_id": 6,
+ "content": " dropout = 0.,\n post_activation_norm = False\n):\n \"\"\" post-activation norm https://arxiv.org/abs/2110.09456 \"\"\"\n inner_dim = int(mult * dim)\n return nn.Sequential(\n LayerNorm(dim),\n nn.Linear(dim, inner_dim * 2, bias = False),\n SwiGLU(),\n LayerNorm(inner_dim) if post_activation_norm else nn.Identity(),\n nn.Dropout(dropout),\n nn.Linear(inner_dim, dim, bias = False)\n )\n# attention\nclass Attention(nn.Module):\n def __init__(\n self,\n dim,\n *,\n dim_head = 64,\n heads = 8,\n dropout = 0.,\n causal = False,\n rotary_emb = None,\n cosine_sim = True,\n cosine_sim_scale = 16\n ):\n super().__init__()\n self.scale = cosine_sim_scale if cosine_sim else (dim_head ** -0.5)\n self.cosine_sim = cosine_sim\n self.heads = heads\n inner_dim = dim_head * heads\n self.causal = causal\n self.norm = LayerNorm(dim)\n self.dropout = nn.Dropout(dropout)\n self.null_kv = nn.Parameter(torch.randn(2, dim_head))",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:817-858"
+ },
+ "183": {
+ "file_id": 6,
+ "content": "The code defines a module that applies post-activation normalization. It also includes a nested Attention class that performs multi-head attention with optional causal masking and rotary embedding. The main components include layer normalization, dropout regularization, and linear transformations for dimensionality adjustments. The cosine similarity calculation is utilized if specified.",
+ "type": "comment"
+ },
+ "184": {
+ "file_id": 6,
+ "content": " self.to_q = nn.Linear(dim, inner_dim, bias = False)\n self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)\n self.rotary_emb = rotary_emb\n self.to_out = nn.Sequential(\n nn.Linear(inner_dim, dim, bias = False),\n LayerNorm(dim)\n )\n def forward(self, x, mask = None, attn_bias = None):\n b, n, device = *x.shape[:2], x.device\n x = self.norm(x)\n q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))\n q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)\n q = q * self.scale\n # rotary embeddings\n if exists(self.rotary_emb):\n q, k = map(self.rotary_emb.rotate_queries_or_keys, (q, k))\n # add null key / value for classifier free guidance in prior net\n nk, nv = map(lambda t: repeat(t, 'd -> b 1 d', b = b), self.null_kv.unbind(dim = -2))\n k = torch.cat((nk, k), dim = -2)\n v = torch.cat((nv, v), dim = -2)\n # whether to use cosine sim\n if self.cosine_sim:",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:859-891"
+ },
+ "185": {
+ "file_id": 6,
+ "content": "This code defines a self-attention layer for DALL·E 2, initializing linear layers and including the option to use rotary embeddings. It also allows for classifier free guidance by adding null key/value pairs and using cosine similarity if enabled.",
+ "type": "comment"
+ },
+ "186": {
+ "file_id": 6,
+ "content": " q, k = map(l2norm, (q, k))\n q, k = map(lambda t: t * math.sqrt(self.scale), (q, k))\n # calculate query / key similarities\n sim = einsum('b h i d, b j d -> b h i j', q, k)\n # relative positional encoding (T5 style)\n if exists(attn_bias):\n sim = sim + attn_bias\n # masking\n max_neg_value = -torch.finfo(sim.dtype).max\n if exists(mask):\n mask = F.pad(mask, (1, 0), value = True)\n mask = rearrange(mask, 'b j -> b 1 1 j')\n sim = sim.masked_fill(~mask, max_neg_value)\n if self.causal:\n i, j = sim.shape[-2:]\n causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)\n sim = sim.masked_fill(causal_mask, max_neg_value)\n # attention\n attn = sim.softmax(dim = -1, dtype = torch.float32)\n attn = attn.type(sim.dtype)\n attn = self.dropout(attn)\n # aggregate values\n out = einsum('b h i j, b j d -> b h i d', attn, v)",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:892-928"
+ },
+ "187": {
+ "file_id": 6,
+ "content": "This code snippet performs multi-head attention by first normalizing the query and key tensors, calculating their similarities, adding relative positional encoding if available, masking irrelevant values based on a given mask, applying causal masking if specified, and finally computing the attention weights and aggregating the corresponding values.",
+ "type": "comment"
+ },
+ "188": {
+ "file_id": 6,
+ "content": " out = rearrange(out, 'b h n d -> b n (h d)')\n return self.to_out(out)\nclass CausalTransformer(nn.Module):\n def __init__(\n self,\n *,\n dim,\n depth,\n dim_head = 64,\n heads = 8,\n ff_mult = 4,\n norm_in = False,\n norm_out = True,\n attn_dropout = 0.,\n ff_dropout = 0.,\n final_proj = True,\n normformer = False,\n rotary_emb = True\n ):\n super().__init__()\n self.init_norm = LayerNorm(dim) if norm_in else nn.Identity() # from latest BLOOM model and Yandex's YaLM\n self.rel_pos_bias = RelPosBias(heads = heads)\n rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None\n self.layers = nn.ModuleList([])\n for _ in range(depth):\n self.layers.append(nn.ModuleList([\n Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),\n FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:930-961"
+ },
+ "189": {
+ "file_id": 6,
+ "content": "This code defines a `CausalTransformer` class for natural language processing tasks. The class initializes several modules such as LayerNorm, RelPosBias, RotaryEmbedding, and Attention. It also includes a FeedForward layer with configurable parameters like `dim`, `depth`, `dim_head`, `heads`, `ff_mult`, `attn_dropout`, `ff_dropout`, `norm_in`, `norm_out`, `final_proj`, and `rotary_emb`. The code snippet you provided is responsible for rearranging the tensor dimensions and returning it after processing by the `CausalTransformer` model.",
+ "type": "comment"
+ },
+ "190": {
+ "file_id": 6,
+ "content": " ]))\n self.norm = LayerNorm(dim, stable = True) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options\n self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()\n def forward(self, x):\n n, device = x.shape[1], x.device\n x = self.init_norm(x)\n attn_bias = self.rel_pos_bias(n, n + 1, device = device)\n for attn, ff in self.layers:\n x = attn(x, attn_bias = attn_bias) + x\n x = ff(x) + x\n out = self.norm(x)\n return self.project_out(out)\nclass DiffusionPriorNetwork(nn.Module):\n def __init__(\n self,\n dim,\n num_timesteps = None,\n num_time_embeds = 1,\n num_image_embeds = 1,\n num_text_embeds = 1,\n max_text_len = 256,\n self_cond = False,\n **kwargs\n ):\n super().__init__()",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:962-993"
+ },
+ "191": {
+ "file_id": 6,
+ "content": "The code initializes a DiffusionPriorNetwork model with multiple layers, including attention and feed-forward modules. It also includes layer normalization and the option to project the output. The network takes in input of varying dimensions and can condition on time, image, and/or text embeddings. The self_cond parameter determines whether or not to use self-conditioning.",
+ "type": "comment"
+ },
+ "192": {
+ "file_id": 6,
+ "content": " self.dim = dim\n self.num_time_embeds = num_time_embeds\n self.num_image_embeds = num_image_embeds\n self.num_text_embeds = num_text_embeds\n self.to_text_embeds = nn.Sequential(\n nn.Linear(dim, dim * num_text_embeds) if num_text_embeds > 1 else nn.Identity(),\n Rearrange('b (n d) -> b n d', n = num_text_embeds)\n )\n self.continuous_embedded_time = not exists(num_timesteps)\n self.to_time_embeds = nn.Sequential(\n nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP\n Rearrange('b (n d) -> b n d', n = num_time_embeds)\n )\n self.to_image_embeds = nn.Sequential(\n nn.Linear(dim, dim * num_image_embeds) if num_image_embeds > 1 else nn.Identity(),\n Rearrange('b (n d) -> b n d', n = num_image_embeds)\n )\n self.learned_query = nn.Parameter(torch.randn(dim))",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:994-1017"
+ },
+ "193": {
+ "file_id": 6,
+ "content": "This code defines a class with parameters for dimensionality, number of time, image, and text embeddings. It initializes layers to transform input into text, time, and image embeddings. The \"learned_query\" is a learned parameter for the model.",
+ "type": "comment"
+ },
+ "194": {
+ "file_id": 6,
+ "content": " self.causal_transformer = CausalTransformer(dim = dim, **kwargs)\n # dalle1 learned padding strategy\n self.max_text_len = max_text_len\n self.null_text_encodings = nn.Parameter(torch.randn(1, max_text_len, dim))\n self.null_text_embeds = nn.Parameter(torch.randn(1, num_text_embeds, dim))\n self.null_image_embed = nn.Parameter(torch.randn(1, dim))\n # whether to use self conditioning, Hinton's group's new ddpm technique\n self.self_cond = self_cond\n def forward_with_cond_scale(\n self,\n *args,\n cond_scale = 1.,\n **kwargs\n ):\n logits = self.forward(*args, **kwargs)\n if cond_scale == 1:\n return logits\n null_logits = self.forward(*args, text_cond_drop_prob = 1., image_cond_drop_prob = 1, **kwargs)\n return null_logits + (logits - null_logits) * cond_scale\n def forward(\n self,\n image_embed,\n diffusion_timesteps,\n *,\n text_embed,\n text_encodings = None,",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1018-1052"
+ },
+ "195": {
+ "file_id": 6,
+ "content": "The code defines a model with a causal transformer and includes parameters for padding strategy, self-conditioning, and a function to perform forward calculations. The `forward_with_cond_scale` method takes conditional scaling as input and returns the scaled logits by combining original logits with null logits at 100% condition drop probabilities.",
+ "type": "comment"
+ },
+ "196": {
+ "file_id": 6,
+ "content": " self_cond = None,\n text_cond_drop_prob = 0.,\n image_cond_drop_prob = 0.\n ):\n batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype\n num_time_embeds, num_image_embeds, num_text_embeds = self.num_time_embeds, self.num_image_embeds, self.num_text_embeds\n # setup self conditioning\n if self.self_cond:\n self_cond = default(self_cond, lambda: torch.zeros(batch, self.dim, device = device, dtype = dtype))\n self_cond = rearrange(self_cond, 'b d -> b 1 d')\n # in section 2.2, last paragraph\n # \"... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction\"\n text_embed = self.to_text_embeds(text_embed)\n image_embed = self.to_image_embeds(image_embed)\n # classifier free guidance masks\n text_keep_mask = prob_mask_like((batch,), 1 - text_cond_drop_prob, device = device)\n text_keep_mask = rearrange(text_keep_mask, 'b -> b 1 1')",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1053-1076"
+ },
+ "197": {
+ "file_id": 6,
+ "content": "This code initializes a model's parameters based on the given image_embed. It sets up self-conditioning if necessary, converts text and image embeddings to the appropriate format, and creates classifier free guidance masks for both text and image inputs. The model will use these embeddings and masks for prediction.",
+ "type": "comment"
+ },
+ "198": {
+ "file_id": 6,
+ "content": " image_keep_mask = prob_mask_like((batch,), 1 - image_cond_drop_prob, device = device)\n image_keep_mask = rearrange(image_keep_mask, 'b -> b 1 1')\n # make text encodings optional\n # although the paper seems to suggest it is present <--\n if not exists(text_encodings):\n text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)\n mask = torch.any(text_encodings != 0., dim = -1)\n # replace any padding in the text encodings with learned padding tokens unique across position\n text_encodings = text_encodings[:, :self.max_text_len]\n mask = mask[:, :self.max_text_len]\n text_len = text_encodings.shape[-2]\n remainder = self.max_text_len - text_len\n if remainder > 0:\n text_encodings = F.pad(text_encodings, (0, 0, 0, remainder), value = 0.)\n mask = F.pad(mask, (0, remainder), value = False)\n # mask out text encodings with null encodings\n null_text_encodings = self.null_text_encodings.to(text_encodings.dtype)",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1078-1103"
+ },
+ "199": {
+ "file_id": 6,
+ "content": "This code snippet is preparing the input data for a DALL-E 2 model by handling text encodings. It creates an image_keep_mask, makes text encodings optional based on their existence, applies masking to remove padding or null encodings, and ensures that the length of text_encodings matches the expected maximum length.",
+ "type": "comment"
+ }
+}
\ No newline at end of file
diff --git a/docs/data/2.json b/docs/data/2.json
new file mode 100644
index 00000000..fb6a120e
--- /dev/null
+++ b/docs/data/2.json
@@ -0,0 +1,552 @@
+{
+ "200": {
+ "file_id": 6,
+ "content": " text_encodings = torch.where(\n rearrange(mask, 'b n -> b n 1').clone() & text_keep_mask,\n text_encodings,\n null_text_encodings\n )\n # mask out text embeddings with null text embeddings\n null_text_embeds = self.null_text_embeds.to(text_embed.dtype)\n text_embed = torch.where(\n text_keep_mask,\n text_embed,\n null_text_embeds\n )\n # mask out image embeddings with null image embeddings\n null_image_embed = self.null_image_embed.to(image_embed.dtype)\n image_embed = torch.where(\n image_keep_mask,\n image_embed,\n null_image_embed\n )\n # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)\n # but let's just do it right\n if self.continuous_embedded_time:",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1105-1134"
+ },
+ "201": {
+ "file_id": 6,
+ "content": "This code section is applying masking to text, image, and null embeddings based on the `text_keep_mask` and `image_keep_mask`. It uses these masks to decide which embeddings to keep or replace with null embeddings. The embeddings are also being converted to appropriate data types. Additionally, there's a conditional check for continuous embedded time.",
+ "type": "comment"
+ },
+ "202": {
+ "file_id": 6,
+ "content": " diffusion_timesteps = diffusion_timesteps.type(dtype)\n time_embed = self.to_time_embeds(diffusion_timesteps)\n learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)\n if self.self_cond:\n learned_queries = torch.cat((self_cond, learned_queries), dim = -2)\n tokens = torch.cat((\n text_encodings,\n text_embed,\n time_embed,\n image_embed,\n learned_queries\n ), dim = -2)\n # attend\n tokens = self.causal_transformer(tokens)\n # get learned query, which should predict the image embedding (per DDPM timestep)\n pred_image_embed = tokens[..., -1, :]\n return pred_image_embed\nclass DiffusionPrior(nn.Module):\n def __init__(\n self,\n net,\n *,\n clip = None,\n image_embed_dim = None,\n image_size = None,\n image_channels = 3,\n timesteps = 1000,\n sample_timesteps = None,\n cond_drop_prob = 0.,\n text_cond_drop_prob = None,",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1135-1174"
+ },
+ "203": {
+ "file_id": 6,
+ "content": "The code defines a DiffusionPrior class that takes in various inputs such as text encodings, timesteps, and image embeddings. It applies causal transformer to learn the learned_query, which predicts the image embedding per DDPM timestep. The text_cond_drop_prob parameter is optional and if provided, will dropout the text conditioning with a specified probability.",
+ "type": "comment"
+ },
+ "204": {
+ "file_id": 6,
+ "content": " image_cond_drop_prob = None,\n loss_type = \"l2\",\n predict_x_start = True,\n predict_v = False,\n beta_schedule = \"cosine\",\n condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training\n sampling_clamp_l2norm = False, # whether to l2norm clamp the image embed at each denoising iteration (analogous to -1 to 1 clipping for usual DDPMs)\n sampling_final_clamp_l2norm = False, # whether to l2norm the final image embedding output (this is also done for images in ddpm)\n training_clamp_l2norm = False,\n init_image_embed_l2norm = False,\n image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132\n clip_adapter_overrides = dict()\n ):\n super().__init__()",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1175-1188"
+ },
+ "205": {
+ "file_id": 6,
+ "content": "This code snippet initializes a DALLE2 model with various optional parameters for training and sampling. These include loss type, conditioning on text encodings, clamping of image embeddings, scaling the L2-normed image embedding, and adapter overrides for CLIP adapter integration.",
+ "type": "comment"
+ },
+ "206": {
+ "file_id": 6,
+ "content": " self.sample_timesteps = sample_timesteps\n self.noise_scheduler = NoiseScheduler(\n beta_schedule = beta_schedule,\n timesteps = timesteps,\n loss_type = loss_type\n )\n if exists(clip):\n assert image_channels == clip.image_channels, f'channels of image ({image_channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'\n if isinstance(clip, CLIP):\n clip = XClipAdapter(clip, **clip_adapter_overrides)\n elif isinstance(clip, CoCa):\n clip = CoCaAdapter(clip, **clip_adapter_overrides)\n assert isinstance(clip, BaseClipAdapter)\n freeze_model_and_make_eval_(clip)\n self.clip = clip\n else:\n assert exists(image_embed_dim), 'latent dimension must be given, if training prior network without CLIP given'\n self.clip = None\n self.net = net\n self.image_embed_dim = default(image_embed_dim, lambda: clip.dim_latent)",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1190-1214"
+ },
+ "207": {
+ "file_id": 6,
+ "content": "The code is initializing an instance of a model. It sets the sample_timesteps, creates a NoiseScheduler object with specified parameters, checks if CLIP is provided and adapts it if necessary, sets the image_embed_dim if not given, and assigns the network architecture (net) to be used.",
+ "type": "comment"
+ },
+ "208": {
+ "file_id": 6,
+ "content": " assert net.dim == self.image_embed_dim, f'your diffusion prior network has a dimension of {net.dim}, but you set your image embedding dimension (keyword image_embed_dim) on DiffusionPrior to {self.image_embed_dim}'\n assert not exists(clip) or clip.dim_latent == self.image_embed_dim, f'you passed in a CLIP to the diffusion prior with latent dimensions of {clip.dim_latent}, but your image embedding dimension (keyword image_embed_dim) for the DiffusionPrior was set to {self.image_embed_dim}'\n self.channels = default(image_channels, lambda: clip.image_channels)\n self.text_cond_drop_prob = default(text_cond_drop_prob, cond_drop_prob)\n self.image_cond_drop_prob = default(image_cond_drop_prob, cond_drop_prob)\n self.can_classifier_guidance = self.text_cond_drop_prob > 0. and self.image_cond_drop_prob > 0.\n self.condition_on_text_encodings = condition_on_text_encodings\n # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1216-1227"
+ },
+ "209": {
+ "file_id": 6,
+ "content": "The code asserts that the diffusion prior network dimension and the image embedding dimension are consistent, and checks if a CLIP is passed in with correct latent dimensions. It also sets channels, text conditional drop probability, image conditional drop probability, enables classifier guidance if probabilities are greater than 0, and conditions on text encodings. It offers both options to predict noise or x0 directly for image embedding as per the paper's claim of better results.",
+ "type": "comment"
+ },
+ "210": {
+ "file_id": 6,
+ "content": " self.predict_x_start = predict_x_start\n self.predict_v = predict_v # takes precedence over predict_x_start\n # @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132\n self.image_embed_scale = default(image_embed_scale, self.image_embed_dim ** 0.5)\n # whether to force an l2norm, similar to clipping denoised, when sampling\n self.sampling_clamp_l2norm = sampling_clamp_l2norm\n self.sampling_final_clamp_l2norm = sampling_final_clamp_l2norm\n self.training_clamp_l2norm = training_clamp_l2norm\n self.init_image_embed_l2norm = init_image_embed_l2norm\n # device tracker\n self.register_buffer('_dummy', torch.tensor([True]), persistent = False)\n @property\n def device(self):\n return self._dummy.device\n def l2norm_clamp_embed(self, image_embed):\n return l2norm(image_embed) * self.image_embed_scale\n def p_mean_variance(self, x, t, text_cond, self_cond = None, clip_denoised = False, cond_scale = 1.):",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1229-1255"
+ },
+ "211": {
+ "file_id": 6,
+ "content": "The code sets various parameters and properties for an object, including predict_x_start, image_embed_scale, sampling_clamp_l2norm, etc. It also defines the l2norm_clamp_embed function and p_mean_variance function. The device property retrieves the device used by the object, and there's a register_buffer for tracking device usage.",
+ "type": "comment"
+ },
+ "212": {
+ "file_id": 6,
+ "content": " assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'\n pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, self_cond = self_cond, **text_cond)\n if self.predict_v:\n x_start = self.noise_scheduler.predict_start_from_v(x, t = t, v = pred)\n elif self.predict_x_start:\n x_start = pred\n else:\n x_start = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)\n if clip_denoised and not self.predict_x_start:\n x_start.clamp_(-1., 1.)\n if self.predict_x_start and self.sampling_clamp_l2norm:\n x_start = l2norm(x_start) * self.image_embed_scale\n model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t)\n return model_mean, posterior_variance, posterior_log_variance, x_start",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1256-1274"
+ },
+ "213": {
+ "file_id": 6,
+ "content": "This code asserts that the model was not trained with conditional dropout, preventing classifier free guidance if cond_scale is anything other than 1. It then calculates and returns the model mean, posterior variance, posterior log variance, and x_start depending on different conditions.",
+ "type": "comment"
+ },
+ "214": {
+ "file_id": 6,
+ "content": " @torch.no_grad()\n def p_sample(self, x, t, text_cond = None, self_cond = None, clip_denoised = True, cond_scale = 1.):\n b, *_, device = *x.shape, x.device\n model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = t, text_cond = text_cond, self_cond = self_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)\n noise = torch.randn_like(x)\n # no noise when t == 0\n nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))\n pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise\n return pred, x_start\n @torch.no_grad()\n def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1.):\n batch, device = shape[0], self.device\n image_embed = torch.randn(shape, device = device)\n x_start = None # for self-conditioning\n if self.init_image_embed_l2norm:\n image_embed = l2norm(image_embed) * self.image_embed_scale\n for i in tqdm(reversed(range(",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1276-1296"
+ },
+ "215": {
+ "file_id": 6,
+ "content": "This code defines the `p_sample` and `p_sample_loop_ddpm` functions. `p_sample` takes input, generates a model mean and log variance, applies noise based on whether t is zero or not, and returns the prediction and x_start. `p_sample_loop_ddpm` initializes an image embedding, optionally normalizes it, and iterates through a reversed range to perform some unspecified operation for each iteration. The code uses PyTorch's `@torch.no_grad()` decorator to disable gradient computation during these functions' execution.",
+ "type": "comment"
+ },
+ "216": {
+ "file_id": 6,
+ "content": "0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):\n times = torch.full((batch,), i, device = device, dtype = torch.long)\n self_cond = x_start if self.net.self_cond else None\n image_embed, x_start = self.p_sample(image_embed, times, text_cond = text_cond, self_cond = self_cond, cond_scale = cond_scale)\n if self.sampling_final_clamp_l2norm and self.predict_x_start:\n image_embed = self.l2norm_clamp_embed(image_embed)\n return image_embed\n @torch.no_grad()\n def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scale = 1.):\n batch, device, alphas, total_timesteps = shape[0], self.device, self.noise_scheduler.alphas_cumprod_prev, self.noise_scheduler.num_timesteps\n times = torch.linspace(-1., total_timesteps, steps = timesteps + 1)[:-1]\n times = list(reversed(times.int().tolist()))\n time_pairs = list(zip(times[:-1], times[1:]))\n image_embed = torch.randn(shape, device = device)",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1296-1316"
+ },
+ "217": {
+ "file_id": 6,
+ "content": "The code defines the `p_sample` function which samples images and their corresponding embeddings using a loop over time steps. It also includes an optional L2-norm clamping for final image embedding. The `p_sample_loop_ddim` function is a helper method to define shape, times, and time pairs for the sampling loop in DDIM style.",
+ "type": "comment"
+ },
+ "218": {
+ "file_id": 6,
+ "content": " x_start = None # for self-conditioning\n if self.init_image_embed_l2norm:\n image_embed = l2norm(image_embed) * self.image_embed_scale\n for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):\n alpha = alphas[time]\n alpha_next = alphas[time_next]\n time_cond = torch.full((batch,), time, device = device, dtype = torch.long)\n self_cond = x_start if self.net.self_cond else None\n pred = self.net.forward_with_cond_scale(image_embed, time_cond, self_cond = self_cond, cond_scale = cond_scale, **text_cond)\n # derive x0\n if self.predict_v:\n x_start = self.noise_scheduler.predict_start_from_v(image_embed, t = time_cond, v = pred)\n elif self.predict_x_start:\n x_start = pred\n else:\n x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred)\n # clip x0 before maybe predicting noise",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1318-1342"
+ },
+ "219": {
+ "file_id": 6,
+ "content": "The code is iterating through time pairs, calculating alpha values and performing a forward pass in the neural network. It also adjusts x_start based on prediction methods and performs noise scheduling. The purpose seems to be generating an image using conditional sampling with self-conditioning and considering different prediction methods for x_start.",
+ "type": "comment"
+ },
+ "220": {
+ "file_id": 6,
+ "content": " if not self.predict_x_start:\n x_start.clamp_(-1., 1.)\n if self.predict_x_start and self.sampling_clamp_l2norm:\n x_start = self.l2norm_clamp_embed(x_start)\n # predict noise\n pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start)\n if time_next < 0:\n image_embed = x_start\n continue\n c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()\n c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()\n noise = torch.randn_like(image_embed) if time_next > 0 else 0.\n image_embed = x_start * alpha_next.sqrt() + \\\n c1 * noise + \\\n c2 * pred_noise\n if self.predict_x_start and self.sampling_final_clamp_l2norm:\n image_embed = self.l2norm_clamp_embed(image_embed)\n return image_embed\n @torch.no_grad()\n def p_sample_loop(self, *args, timesteps = None, **kwargs):",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1344-1372"
+ },
+ "221": {
+ "file_id": 6,
+ "content": "In this code segment, it checks if predicting x_start is enabled and performs L2-norm clamping if necessary. It then predicts noise using the noise scheduler based on image embeddings, time condition, and x_start. If time_next is less than 0, it sets image_embed to x_start. Calculates coefficients c1 and c2 for RNN sampling and generates noise accordingly. Combines these elements to generate the final image_embed which is then optionally L2-norm clamped if enabled.",
+ "type": "comment"
+ },
+ "222": {
+ "file_id": 6,
+ "content": " timesteps = default(timesteps, self.noise_scheduler.num_timesteps)\n assert timesteps <= self.noise_scheduler.num_timesteps\n is_ddim = timesteps < self.noise_scheduler.num_timesteps\n if not is_ddim:\n normalized_image_embed = self.p_sample_loop_ddpm(*args, **kwargs)\n else:\n normalized_image_embed = self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)\n image_embed = normalized_image_embed / self.image_embed_scale\n return image_embed\n def p_losses(self, image_embed, times, text_cond, noise = None):\n noise = default(noise, lambda: torch.randn_like(image_embed))\n image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)\n self_cond = None\n if self.net.self_cond and random.random() < 0.5:\n with torch.no_grad():\n self_cond = self.net(image_embed_noisy, times, **text_cond).detach()\n pred = self.net(\n image_embed_noisy,",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1373-1396"
+ },
+ "223": {
+ "file_id": 6,
+ "content": "This code is from the DALLE2-pytorch model. It first determines if the timesteps are less than the number of timesteps in the noise scheduler. If so, it uses the p_sample_loop_ddim function to get the normalized image embeddings, otherwise it uses the p_sample_loop_ddpm function. The code then scales the normalized image embeddings by the image_embed_scale and returns the scaled embeddings. The p_losses function generates a noisy version of the input image embedding using the noise scheduler, and optionally conditions the model with self-conditioning if the condition is met. Finally, it passes the noisy embedding to the network for prediction.",
+ "type": "comment"
+ },
+ "224": {
+ "file_id": 6,
+ "content": " times,\n self_cond = self_cond,\n text_cond_drop_prob = self.text_cond_drop_prob,\n image_cond_drop_prob = self.image_cond_drop_prob,\n **text_cond\n )\n if self.predict_x_start and self.training_clamp_l2norm:\n pred = self.l2norm_clamp_embed(pred)\n if self.predict_v:\n target = self.noise_scheduler.calculate_v(image_embed, times, noise)\n elif self.predict_x_start:\n target = image_embed\n else:\n target = noise\n loss = self.noise_scheduler.loss_fn(pred, target)\n return loss\n @torch.no_grad()\n @eval_decorator\n def sample_batch_size(self, batch_size, text_cond, cond_scale = 1.):\n device = self.betas.device\n shape = (batch_size, self.image_embed_dim)\n img = torch.randn(shape, device = device)\n for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = self.noise_scheduler.num_timesteps):",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1397-1425"
+ },
+ "225": {
+ "file_id": 6,
+ "content": "The code defines a method for predicting and calculating loss. It takes in parameters such as times, self_cond, text_cond_drop_prob, image_cond_drop_prob, and text_cond. If certain conditions are met, it performs l2norm clamping on the prediction, sets the target based on whether to predict x or v, then calculates the loss using the noise scheduler's loss function. The code also includes a sample_batch_size method that samples an image batch and iterates over time steps in reverse order for some processing.",
+ "type": "comment"
+ },
+ "226": {
+ "file_id": 6,
+ "content": " img = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond, cond_scale = cond_scale)\n return img\n @torch.no_grad()\n @eval_decorator\n def sample(\n self,\n text,\n num_samples_per_batch = 2,\n cond_scale = 1.,\n timesteps = None\n ):\n timesteps = default(timesteps, self.sample_timesteps)\n # in the paper, what they did was\n # sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP\n text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)\n batch_size = text.shape[0]\n image_embed_dim = self.image_embed_dim\n text_embed, text_encodings = self.clip.embed_text(text)\n text_cond = dict(text_embed = text_embed)\n if self.condition_on_text_encodings:\n text_cond = {**text_cond, 'text_encodings': text_encodings}\n image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale, timesteps = timesteps)",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1426-1454"
+ },
+ "227": {
+ "file_id": 6,
+ "content": "This code is part of a DALL-E 2 model implementation in PyTorch. The sample function generates multiple image embeddings based on provided text, then chooses the most similar one according to CLIP's similarity judgment. The function uses a p_sample_loop method which takes timesteps as input and returns a batch of images with the specified size.",
+ "type": "comment"
+ },
+ "228": {
+ "file_id": 6,
+ "content": " # retrieve original unscaled image embed\n text_embeds = text_cond['text_embed']\n text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)\n image_embeds = rearrange(image_embeds, '(b r) d -> b r d', r = num_samples_per_batch)\n text_image_sims = einsum('b r d, b r d -> b r', l2norm(text_embeds), l2norm(image_embeds))\n top_sim_indices = text_image_sims.topk(k = 1).indices\n top_sim_indices = repeat(top_sim_indices, 'b 1 -> b 1 d', d = image_embed_dim)\n top_image_embeds = image_embeds.gather(1, top_sim_indices)\n return rearrange(top_image_embeds, 'b 1 d -> b d')\n def forward(\n self,\n text = None,\n image = None,\n text_embed = None, # allow for training on preprocessed CLIP text and image embeddings\n image_embed = None,\n text_encodings = None, # as well as CLIP text encodings\n *args,\n **kwargs\n ):\n assert exists(text) ^ exists(text_embed), 'either text or text embedding must be supplied'",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1456-1481"
+ },
+ "229": {
+ "file_id": 6,
+ "content": "This function retrieves the original unscaled image embeddings from the input, rearranges them based on the number of samples per batch, calculates text-image similarities using Euclidean distance, gets the top indices and gathers corresponding embeddings. It allows for training on preprocessed CLIP text and image embeddings or CLIP text encodings. If neither text nor text embedding is supplied, an assertion error will be raised.",
+ "type": "comment"
+ },
+ "230": {
+ "file_id": 6,
+ "content": " assert exists(image) ^ exists(image_embed), 'either image or image embedding must be supplied'\n assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'\n if exists(image):\n image_embed, _ = self.clip.embed_image(image)\n # calculate text conditionings, based on what is passed in\n if exists(text):\n text_embed, text_encodings = self.clip.embed_text(text)\n text_cond = dict(text_embed = text_embed)\n if self.condition_on_text_encodings:\n assert exists(text_encodings), 'text encodings must be present for diffusion prior if specified'\n text_cond = {**text_cond, 'text_encodings': text_encodings}\n # timestep conditioning from ddpm\n batch, device = image_embed.shape[0], image_embed.device\n times = self.noise_scheduler.sample_random_times(batch)\n # scale image embed (Katherine)",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1482-1504"
+ },
+ "231": {
+ "file_id": 6,
+ "content": "The code snippet checks if an image or image embedding is supplied and throws an error if neither exists. It also verifies the presence of text encodings or text based on the specified conditioning during initialization. The code then calculates the text embeddings from the given text using the clip model. If conditioned on text encodings, it includes them in the text_cond dictionary. It samples random times for timestep conditioning from the noise scheduler and scales the image embed (by Katherine).",
+ "type": "comment"
+ },
+ "232": {
+ "file_id": 6,
+ "content": " image_embed *= self.image_embed_scale\n # calculate forward loss\n return self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)\n# decoder\ndef NearestUpsample(dim, dim_out = None):\n dim_out = default(dim_out, dim)\n return nn.Sequential(\n nn.Upsample(scale_factor = 2, mode = 'nearest'),\n nn.Conv2d(dim, dim_out, 3, padding = 1)\n )\nclass PixelShuffleUpsample(nn.Module):\n \"\"\"\n code shared by @MalumaDev at DALLE2-pytorch for addressing checkboard artifacts\n https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf\n \"\"\"\n def __init__(self, dim, dim_out = None):\n super().__init__()\n dim_out = default(dim_out, dim)\n conv = nn.Conv2d(dim, dim_out * 4, 1)\n self.net = nn.Sequential(\n conv,\n nn.SiLU(),\n nn.PixelShuffle(2)\n )\n self.init_conv_(conv)\n def init_conv_(self, conv):\n o, i, h, w = conv.weight.shape\n conv_weight = torch.empty(o // 4, i, h, w)\n nn.init.kaiming_uniform_(conv_weight)",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1506-1543"
+ },
+ "233": {
+ "file_id": 6,
+ "content": "This code contains two classes, `NearestUpsample` and `PixelShuffleUpsample`. `NearestUpsample` performs nearest neighbor upsampling followed by a convolution operation. `PixelShuffleUpsample` applies pixel shuffling after a 1x1 convolution to reduce checkerboard artifacts. Both classes can be used for image upsampling tasks.",
+ "type": "comment"
+ },
+ "234": {
+ "file_id": 6,
+ "content": " conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')\n conv.weight.data.copy_(conv_weight)\n nn.init.zeros_(conv.bias.data)\n def forward(self, x):\n return self.net(x)\ndef Downsample(dim, dim_out = None):\n # https://arxiv.org/abs/2208.03641 shows this is the most optimal way to downsample\n # named SP-conv in the paper, but basically a pixel unshuffle\n dim_out = default(dim_out, dim)\n return nn.Sequential(\n Rearrange('b c (h s1) (w s2) -> b (c s1 s2) h w', s1 = 2, s2 = 2),\n nn.Conv2d(dim * 4, dim_out, 1)\n )\nclass WeightStandardizedConv2d(nn.Conv2d):\n \"\"\"\n https://arxiv.org/abs/1903.10520\n weight standardization purportedly works synergistically with group normalization\n \"\"\"\n def forward(self, x):\n eps = 1e-5 if x.dtype == torch.float32 else 1e-3\n weight = self.weight\n flattened_weights = rearrange(weight, 'o ... -> o (...)')\n mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')\n var = torch.var(flattened_weights, dim = -1, unbiased = False)",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1544-1574"
+ },
+ "235": {
+ "file_id": 6,
+ "content": "The code defines a class called WeightStandardizedConv2d that extends nn.Conv2d and implements weight standardization for improving synergy with group normalization. It also includes a function named Downsample to downsample the input using pixel unshuffle technique, which is optimal according to a reference paper. The forward method in WeightStandardizedConv2d calculates mean and variance of flattened weights and performs weight standardization before applying convolution operations.",
+ "type": "comment"
+ },
+ "236": {
+ "file_id": 6,
+ "content": " var = rearrange(var, 'o -> o 1 1 1')\n weight = (weight - mean) * (var + eps).rsqrt()\n return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)\nclass SinusoidalPosEmb(nn.Module):\n def __init__(self, dim):\n super().__init__()\n self.dim = dim\n def forward(self, x):\n dtype, device = x.dtype, x.device\n assert is_float_dtype(dtype), 'input to sinusoidal pos emb must be a float type'\n half_dim = self.dim // 2\n emb = math.log(10000) / (half_dim - 1)\n emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)\n emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')\n return torch.cat((emb.sin(), emb.cos()), dim = -1).type(dtype)\nclass Block(nn.Module):\n def __init__(\n self,\n dim,\n dim_out,\n groups = 8,\n weight_standardization = False\n ):\n super().__init__()\n conv_klass = nn.Conv2d if not weight_standardization else WeightStandardizedConv2d",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1575-1605"
+ },
+ "237": {
+ "file_id": 6,
+ "content": "This code snippet contains the definition of three classes: `rearrange`, `SinusoidalPosEmb`, and `Block`. The `rearrange` function is used to reshape tensors, `SinusoidalPosEmb` class computes sinusoidal positional embeddings, and `Block` class defines a convolutional block with an option for weight standardization.",
+ "type": "comment"
+ },
+ "238": {
+ "file_id": 6,
+ "content": " self.project = conv_klass(dim, dim_out, 3, padding = 1)\n self.norm = nn.GroupNorm(groups, dim_out)\n self.act = nn.SiLU()\n def forward(self, x, scale_shift = None):\n x = self.project(x)\n x = self.norm(x)\n if exists(scale_shift):\n scale, shift = scale_shift\n x = x * (scale + 1) + shift\n x = self.act(x)\n return x\nclass ResnetBlock(nn.Module):\n def __init__(\n self,\n dim,\n dim_out,\n *,\n cond_dim = None,\n time_cond_dim = None,\n groups = 8,\n weight_standardization = False,\n cosine_sim_cross_attn = False\n ):\n super().__init__()\n self.time_mlp = None\n if exists(time_cond_dim):\n self.time_mlp = nn.Sequential(\n nn.SiLU(),\n nn.Linear(time_cond_dim, dim_out * 2)\n )\n self.cross_attn = None\n if exists(cond_dim):\n self.cross_attn = CrossAttention(\n dim = dim_out,\n context_dim = cond_dim,",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1607-1649"
+ },
+ "239": {
+ "file_id": 6,
+ "content": "This code snippet defines a ResnetBlock class that takes in dimensions and other parameters for its initialization. It includes a project layer, normalization layer, activation function, and optional scale-shift operation. The forward method performs the computation steps involving these layers. Additionally, it checks if time_cond_dim is given to initialize a time MLP and if cond_dim exists to initialize a cross-attention layer.",
+ "type": "comment"
+ },
+ "240": {
+ "file_id": 6,
+ "content": " cosine_sim = cosine_sim_cross_attn\n )\n self.block1 = Block(dim, dim_out, groups = groups, weight_standardization = weight_standardization)\n self.block2 = Block(dim_out, dim_out, groups = groups, weight_standardization = weight_standardization)\n self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()\n def forward(self, x, time_emb = None, cond = None):\n scale_shift = None\n if exists(self.time_mlp) and exists(time_emb):\n time_emb = self.time_mlp(time_emb)\n time_emb = rearrange(time_emb, 'b c -> b c 1 1')\n scale_shift = time_emb.chunk(2, dim = 1)\n h = self.block1(x, scale_shift = scale_shift)\n if exists(self.cross_attn):\n assert exists(cond)\n h = rearrange(h, 'b c ... -> b ... c')\n h, ps = pack([h], 'b * c')\n h = self.cross_attn(h, context = cond) + h\n h, = unpack(h, ps, 'b * c')\n h = rearrange(h, 'b ... c -> b c ...')",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1650-1676"
+ },
+ "241": {
+ "file_id": 6,
+ "content": "This code defines a class for an encoder-decoder architecture with residual connections and cross-attention. It includes blocks, convolutions, time MLP, and optional cross-attention with conditional input. The forward method processes the input, applies blocks, optionally performs time embedding and cross-attention, and returns the output.",
+ "type": "comment"
+ },
+ "242": {
+ "file_id": 6,
+ "content": " h = self.block2(h)\n return h + self.res_conv(x)\nclass CrossAttention(nn.Module):\n def __init__(\n self,\n dim,\n *,\n context_dim = None,\n dim_head = 64,\n heads = 8,\n dropout = 0.,\n norm_context = False,\n cosine_sim = False,\n cosine_sim_scale = 16\n ):\n super().__init__()\n self.cosine_sim = cosine_sim\n self.scale = cosine_sim_scale if cosine_sim else (dim_head ** -0.5)\n self.heads = heads\n inner_dim = dim_head * heads\n context_dim = default(context_dim, dim)\n self.norm = LayerNorm(dim)\n self.norm_context = LayerNorm(context_dim) if norm_context else nn.Identity()\n self.dropout = nn.Dropout(dropout)\n self.null_kv = nn.Parameter(torch.randn(2, dim_head))\n self.to_q = nn.Linear(dim, inner_dim, bias = False)\n self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)\n self.to_out = nn.Sequential(\n nn.Linear(inner_dim, dim, bias = False),",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1678-1711"
+ },
+ "243": {
+ "file_id": 6,
+ "content": "The code defines a CrossAttention class with parameters for dimensionality, context dimension, number of heads, dropout rate, and normalization options. It initializes the necessary layers including linear transformations and layer norms. The cosine similarity scale and null keys are also defined.",
+ "type": "comment"
+ },
+ "244": {
+ "file_id": 6,
+ "content": " LayerNorm(dim)\n )\n def forward(self, x, context, mask = None):\n b, n, device = *x.shape[:2], x.device\n x = self.norm(x)\n context = self.norm_context(context)\n q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))\n q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))\n # add null key / value for classifier free guidance in prior net\n nk, nv = map(lambda t: repeat(t, 'd -> b h 1 d', h = self.heads, b = b), self.null_kv.unbind(dim = -2))\n k = torch.cat((nk, k), dim = -2)\n v = torch.cat((nv, v), dim = -2)\n if self.cosine_sim:\n q, k = map(l2norm, (q, k))\n q, k = map(lambda t: t * math.sqrt(self.scale), (q, k))\n sim = einsum('b h i d, b h j d -> b h i j', q, k)\n max_neg_value = -torch.finfo(sim.dtype).max\n if exists(mask):\n mask = F.pad(mask, (1, 0), value = True)\n mask = rearrange(mask, 'b j -> b 1 1 j')\n sim = sim.masked_fill(~mask, max_neg_value)",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1712-1743"
+ },
+ "245": {
+ "file_id": 6,
+ "content": "This function defines a multi-head attention layer. It normalizes input x and context, splits them into queries (q), keys (k), and values (v). It also includes null key/value pairs for classifier free guidance in the prior net. If cosine_sim is set, it normalizes q and k again. It then computes the attention scores (sim) between q and k, and applies a mask if available, replacing negative values with max_neg_value.",
+ "type": "comment"
+ },
+ "246": {
+ "file_id": 6,
+ "content": " attn = sim.softmax(dim = -1, dtype = torch.float32)\n attn = attn.type(sim.dtype)\n out = einsum('b h i j, b h j d -> b h i d', attn, v)\n out = rearrange(out, 'b h n d -> b n (h d)')\n return self.to_out(out)\nclass LinearAttention(nn.Module):\n def __init__(\n self,\n dim,\n dim_head = 32,\n heads = 8,\n **kwargs\n ):\n super().__init__()\n self.scale = dim_head ** -0.5\n self.heads = heads\n inner_dim = dim_head * heads\n self.norm = ChanLayerNorm(dim)\n self.nonlin = nn.GELU()\n self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)\n self.to_out = nn.Sequential(\n nn.Conv2d(inner_dim, dim, 1, bias = False),\n ChanLayerNorm(dim)\n )\n def forward(self, fmap):\n h, x, y = self.heads, *fmap.shape[-2:]\n seq_len = x * y\n fmap = self.norm(fmap)\n q, k, v = self.to_qkv(fmap).chunk(3, dim = 1)\n q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (q, k, v))",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1745-1780"
+ },
+ "247": {
+ "file_id": 6,
+ "content": "This code defines a LinearAttention module that performs multi-head attention. It normalizes the input, applies convolutions to split input into queries (Q), keys (K), and values (V), then computes attention weights, rearranges output dimensions for efficiency, and finally passes the result through another set of convolutions before returning it.",
+ "type": "comment"
+ },
+ "248": {
+ "file_id": 6,
+ "content": " q = q.softmax(dim = -1)\n k = k.softmax(dim = -2)\n q = q * self.scale\n v = l2norm(v)\n k, v = map(lambda t: t / math.sqrt(seq_len), (k, v))\n context = einsum('b n d, b n e -> b d e', k, v)\n out = einsum('b n d, b d e -> b n e', q, context)\n out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)\n out = self.nonlin(out)\n return self.to_out(out)\nclass CrossEmbedLayer(nn.Module):\n def __init__(\n self,\n dim_in,\n kernel_sizes,\n dim_out = None,\n stride = 2\n ):\n super().__init__()\n assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)])\n dim_out = default(dim_out, dim_in)\n kernel_sizes = sorted(kernel_sizes)\n num_scales = len(kernel_sizes)\n # calculate the dimension at each scale\n dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]\n dim_scales = [*dim_scales, dim_out - sum(dim_scales)]\n self.convs = nn.ModuleList([])",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1782-1816"
+ },
+ "249": {
+ "file_id": 6,
+ "content": "The code calculates and applies attention weights to query (q) and key (k) tensors, normalizes them, scales the vectors, and performs element-wise multiplication. It then applies a linear transformation (nonlin) on the result and rearranges the dimensions of the output tensor using the 'rearrange' function. The code also defines a CrossEmbedLayer class that initializes convolutional layers for feature extraction at multiple scales.",
+ "type": "comment"
+ },
+ "250": {
+ "file_id": 6,
+ "content": " for kernel, dim_scale in zip(kernel_sizes, dim_scales):\n self.convs.append(nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))\n def forward(self, x):\n fmaps = tuple(map(lambda conv: conv(x), self.convs))\n return torch.cat(fmaps, dim = 1)\nclass UpsampleCombiner(nn.Module):\n def __init__(\n self,\n dim,\n *,\n enabled = False,\n dim_ins = tuple(),\n dim_outs = tuple()\n ):\n super().__init__()\n assert len(dim_ins) == len(dim_outs)\n self.enabled = enabled\n if not self.enabled:\n self.dim_out = dim\n return\n self.fmap_convs = nn.ModuleList([Block(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)])\n self.dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0)\n def forward(self, x, fmaps = None):\n target_size = x.shape[-1]\n fmaps = default(fmaps, tuple())\n if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0:",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1817-1849"
+ },
+ "251": {
+ "file_id": 6,
+ "content": "The code defines a convolutional network with adjustable kernel sizes and applies an upsampling combiner to combine feature maps. The enabled flag controls whether the upsampling combiner is active, and it can be customized with different input/output dimensions.",
+ "type": "comment"
+ },
+ "252": {
+ "file_id": 6,
+ "content": " return x\n fmaps = [resize_image_to(fmap, target_size) for fmap in fmaps]\n outs = [conv(fmap) for fmap, conv in zip(fmaps, self.fmap_convs)]\n return torch.cat((x, *outs), dim = 1)\nclass Unet(nn.Module):\n def __init__(\n self,\n dim,\n *,\n image_embed_dim = None,\n text_embed_dim = None,\n cond_dim = None,\n num_image_tokens = 4,\n num_time_tokens = 2,\n out_dim = None,\n dim_mults=(1, 2, 4, 8),\n channels = 3,\n channels_out = None,\n self_attn = False,\n attn_dim_head = 32,\n attn_heads = 16,\n lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/\n lowres_noise_cond = False, # for conditioning on low resolution noising, based on Imagen\n self_cond = False, # set this to True to use the self-conditioning technique from - https://arxiv.org/abs/2208.04202\n sparse_attn = False,\n cosine_sim_cross_attn = False,",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1850-1877"
+ },
+ "253": {
+ "file_id": 6,
+ "content": "This code defines a Unet model with multiple components including fmaps, convolutions, image and text embeddings, dimensions, conditional parameters, and attention mechanisms. It also includes options for lowres_cond, self_attn, lowres_noise_cond, sparse_attn, and cosine_sim_cross_attn.",
+ "type": "comment"
+ },
+ "254": {
+ "file_id": 6,
+ "content": " cosine_sim_self_attn = False,\n attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)\n cond_on_text_encodings = False,\n max_text_len = 256,\n cond_on_image_embeds = False,\n add_image_embeds_to_time = True, # alerted by @mhh0318 to a phrase in the paper - \"Specifically, we modify the architecture described in Nichol et al. (2021) by projecting and adding CLIP embeddings to the existing timestep embedding\"\n init_dim = None,\n init_conv_kernel_size = 7,\n resnet_groups = 8,\n resnet_weight_standardization = False,\n num_resnet_blocks = 2,\n init_cross_embed = True,\n init_cross_embed_kernel_sizes = (3, 7, 15),\n cross_embed_downsample = False,\n cross_embed_downsample_kernel_sizes = (2, 4),\n memory_efficient = False,\n scale_skip_connection = False,\n pixel_shuffle_upsample = True,",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1878-1895"
+ },
+ "255": {
+ "file_id": 6,
+ "content": "The code defines various settings for the DALLE2 model, including whether to use cosine similarity self-attention, if a layer of attention should be at the bottleneck, and whether to condition on text or image embeddings. It also includes options for initializing embeddings, resnet blocks, cross embeddings, and more. These settings allow for customization and optimization in the DALLE2 model's architecture.",
+ "type": "comment"
+ },
+ "256": {
+ "file_id": 6,
+ "content": " final_conv_kernel_size = 1,\n combine_upsample_fmaps = False, # whether to combine the outputs of all upsample blocks, as in unet squared paper\n checkpoint_during_training = False,\n **kwargs\n ):\n super().__init__()\n # save locals to take care of some hyperparameters for cascading DDPM\n self._locals = locals()\n del self._locals['self']\n del self._locals['__class__']\n # for eventual cascading diffusion\n self.lowres_cond = lowres_cond\n # whether to do self conditioning\n self.self_cond = self_cond\n # determine dimensions\n self.channels = channels\n self.channels_out = default(channels_out, channels)\n # initial number of channels depends on\n # (1) low resolution conditioning from cascading ddpm paper, conditioned on previous unet output in the cascade\n # (2) self conditioning (bit diffusion paper)\n init_channels = channels * (1 + int(lowres_cond) + int(self_cond))\n init_dim = default(init_dim, dim)",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1896-1927"
+ },
+ "257": {
+ "file_id": 6,
+ "content": "The code initializes a DDPM model with specified parameters such as number of channels, output channels, low resolution conditioning and self-conditioning. It determines the dimensions and initial number of channels based on these inputs and saves the hyperparameters for possible cascading DDPM in the future.",
+ "type": "comment"
+ },
+ "258": {
+ "file_id": 6,
+ "content": " self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) if init_cross_embed else nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)\n dims = [init_dim, *map(lambda m: dim * m, dim_mults)]\n in_out = list(zip(dims[:-1], dims[1:]))\n num_stages = len(in_out)\n # time, image embeddings, and optional text encoding\n cond_dim = default(cond_dim, dim)\n time_cond_dim = dim * 4\n self.to_time_hiddens = nn.Sequential(\n SinusoidalPosEmb(dim),\n nn.Linear(dim, time_cond_dim),\n nn.GELU()\n )\n self.to_time_tokens = nn.Sequential(\n nn.Linear(time_cond_dim, cond_dim * num_time_tokens),\n Rearrange('b (r d) -> b r d', r = num_time_tokens)\n )\n self.to_time_cond = nn.Sequential(\n nn.Linear(time_cond_dim, time_cond_dim)\n )\n self.image_to_tokens = nn.Sequential(",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1929-1956"
+ },
+ "259": {
+ "file_id": 6,
+ "content": "This code initializes layers for processing time and image inputs. It creates a CrossEmbedLayer or Conv2d layer for the initial input, sets the dimensions for subsequent stages, defines layers to transform time-based data into conditioning tokens, and initializes an image-to-tokens sequence of layers. These layers will be used in a DALL-E 2 model for processing text, image, and time-based inputs for generating images.",
+ "type": "comment"
+ },
+ "260": {
+ "file_id": 6,
+ "content": " nn.Linear(image_embed_dim, cond_dim * num_image_tokens),\n Rearrange('b (n d) -> b n d', n = num_image_tokens)\n ) if cond_on_image_embeds and image_embed_dim != cond_dim else nn.Identity()\n self.to_image_hiddens = nn.Sequential(\n nn.Linear(image_embed_dim, time_cond_dim),\n nn.GELU()\n ) if cond_on_image_embeds and add_image_embeds_to_time else None\n self.norm_cond = nn.LayerNorm(cond_dim)\n self.norm_mid_cond = nn.LayerNorm(cond_dim)\n # text encoding conditioning (optional)\n self.text_to_cond = None\n self.text_embed_dim = None\n if cond_on_text_encodings:\n assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text_encodings is True'\n self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)\n self.text_embed_dim = text_embed_dim\n # low resolution noise conditiong, based on Imagen's upsampler training technique\n self.lowres_noise_cond = lowres_noise_cond",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1957-1981"
+ },
+ "261": {
+ "file_id": 6,
+ "content": "The code defines the architecture of a model. It includes linear layers, layer normalization, GELU activation function, and conditioning options for image embeddings, text encodings, and low resolution noise. These components are used to transform inputs and generate conditions based on optional parameters.",
+ "type": "comment"
+ },
+ "262": {
+ "file_id": 6,
+ "content": " self.to_lowres_noise_cond = nn.Sequential(\n SinusoidalPosEmb(dim),\n nn.Linear(dim, time_cond_dim),\n nn.GELU(),\n nn.Linear(time_cond_dim, time_cond_dim)\n ) if lowres_noise_cond else None\n # finer control over whether to condition on image embeddings and text encodings\n # so one can have the latter unets in the cascading DDPMs only focus on super-resoluting\n self.cond_on_text_encodings = cond_on_text_encodings\n self.cond_on_image_embeds = cond_on_image_embeds\n # for classifier free guidance\n self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))\n self.null_image_hiddens = nn.Parameter(torch.randn(1, time_cond_dim))\n self.max_text_len = max_text_len\n self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))\n # whether to scale skip connection, adopted in Imagen\n self.skip_connect_scale = 1. if not scale_skip_connection else (2 ** -0.5)",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:1983-2006"
+ },
+ "263": {
+ "file_id": 6,
+ "content": "This code initializes various components of a model. It creates an optional sequential layer for low-res noise conditioning based on a flag, allows fine control over whether to condition on image embeddings and text encodings, and sets up parameters for classifier-free guidance. The skip connection scale is set either to 1 or scaled as per Imagen's approach.",
+ "type": "comment"
+ },
+ "264": {
+ "file_id": 6,
+ "content": " # attention related params\n attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head, cosine_sim = cosine_sim_self_attn)\n self_attn = cast_tuple(self_attn, num_stages)\n create_self_attn = lambda dim: RearrangeToSequence(Residual(Attention(dim, **attn_kwargs)))\n # resnet block klass\n resnet_groups = cast_tuple(resnet_groups, num_stages)\n top_level_resnet_group = first(resnet_groups)\n num_resnet_blocks = cast_tuple(num_resnet_blocks, num_stages)\n # downsample klass\n downsample_klass = Downsample\n if cross_embed_downsample:\n downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)\n # upsample klass\n upsample_klass = NearestUpsample if not pixel_shuffle_upsample else PixelShuffleUpsample\n # prepare resnet klass\n resnet_block = partial(ResnetBlock, cosine_sim_cross_attn = cosine_sim_cross_attn, weight_standardization = resnet_weight_standardization)",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2008-2035"
+ },
+ "265": {
+ "file_id": 6,
+ "content": "This code initializes various parameters and classes for the DALL-E 2 model. It sets up attention, resnet block, downsampling, and upsampling functions based on user inputs. The code uses partial function applications to customize the resnet blocks and other components according to specific settings.",
+ "type": "comment"
+ },
+ "266": {
+ "file_id": 6,
+ "content": " # give memory efficient unet an initial resnet block\n self.init_resnet_block = resnet_block(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group) if memory_efficient else None\n # layers\n self.downs = nn.ModuleList([])\n self.ups = nn.ModuleList([])\n num_resolutions = len(in_out)\n skip_connect_dims = [] # keeping track of skip connection dimensions\n upsample_combiner_dims = [] # keeping track of dimensions for final upsample feature map combiner\n for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks, self_attn)):\n is_first = ind == 0\n is_last = ind >= (num_resolutions - 1)\n layer_cond_dim = cond_dim if not is_first else None\n dim_layer = dim_out if memory_efficient else dim_in\n skip_connect_dims.append(dim_layer)\n attention = nn.Identity()\n if layer_self_attn:",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2037-2059"
+ },
+ "267": {
+ "file_id": 6,
+ "content": "The code initializes the memory efficient UNet with an initial resnet block, and creates two lists for downsampling and upsampling layers. It also keeps track of skip connection dimensions and dimensions for final upsample feature map combiner. The code iterates over different layer configurations, including whether to use self-attention or not.",
+ "type": "comment"
+ },
+ "268": {
+ "file_id": 6,
+ "content": " attention = create_self_attn(dim_layer)\n elif sparse_attn:\n attention = Residual(LinearAttention(dim_layer, **attn_kwargs))\n self.downs.append(nn.ModuleList([\n downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None,\n resnet_block(dim_layer, dim_layer, time_cond_dim = time_cond_dim, groups = groups),\n nn.ModuleList([resnet_block(dim_layer, dim_layer, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),\n attention,\n downsample_klass(dim_layer, dim_out = dim_out) if not is_last and not memory_efficient else nn.Conv2d(dim_layer, dim_out, 1)\n ]))\n mid_dim = dims[-1]\n self.mid_block1 = resnet_block(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])\n self.mid_attn = create_self_attn(mid_dim)\n self.mid_block2 = re",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2060-2076"
+ },
+ "269": {
+ "file_id": 6,
+ "content": "This code initializes a module for a neural network. It adds downsampling modules, resnet blocks, attention layers, and convolutional layers based on the given parameters. The last block of the code initializes two additional blocks and an attention layer for further processing.",
+ "type": "comment"
+ },
+ "270": {
+ "file_id": 6,
+ "content": "snet_block(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])\n for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks), reversed(self_attn))):\n is_last = ind >= (len(in_out) - 1)\n layer_cond_dim = cond_dim if not is_last else None\n skip_connect_dim = skip_connect_dims.pop()\n attention = nn.Identity()\n if layer_self_attn:\n attention = create_self_attn(dim_out)\n elif sparse_attn:\n attention = Residual(LinearAttention(dim_out, **attn_kwargs))\n upsample_combiner_dims.append(dim_out)\n self.ups.append(nn.ModuleList([\n resnet_block(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),\n nn.ModuleList([resnet_block(dim_out + skip_connect_dim,",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2076-2094"
+ },
+ "271": {
+ "file_id": 6,
+ "content": "The code is defining a ResNet-based architecture with optional self-attention layers. It iterates through the input and output dimensions, groups, number of resnet blocks, and self-attention usage to create a series of resnet blocks, optionally including an identity or linear attention layer after each block.",
+ "type": "comment"
+ },
+ "272": {
+ "file_id": 6,
+ "content": " dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),\n attention,\n upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity()\n ]))\n # whether to combine outputs from all upsample blocks for final resnet block\n self.upsample_combiner = UpsampleCombiner(\n dim = dim,\n enabled = combine_upsample_fmaps,\n dim_ins = upsample_combiner_dims,\n dim_outs = (dim,) * len(upsample_combiner_dims)\n )\n # a final resnet block\n self.final_resnet_block = resnet_block(self.upsample_combiner.dim_out + dim, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)\n out_dim_in = dim + (channels if lowres_cond else 0)\n self.to_out = nn.Conv2d(out_dim_in, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)\n zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2094-2116"
+ },
+ "273": {
+ "file_id": 6,
+ "content": "This code defines a DALL·E 2 model architecture. It includes multiple resnet blocks, an upsampling sequence, and a final convolution layer. The number of resnet blocks is determined by the `layer_num_resnet_blocks` parameter. The upsample sequence combines outputs from all upsample blocks if `combine_upsample_fmaps` is set to True. The final resnet block takes in the combined output and the model's channels, with time conditioning (`time_cond_dim`) and top-level resnet grouping (`top_level_resnet_group`). Finally, a convolution layer converts the output to the desired channel size (`channels_out`). The `zero_init_` function initializes the final convolution layer with zero values.",
+ "type": "comment"
+ },
+ "274": {
+ "file_id": 6,
+ "content": " # whether to checkpoint during training\n self.checkpoint_during_training = checkpoint_during_training\n # if the current settings for the unet are not correct\n # for cascading DDPM, then reinit the unet with the right settings\n def cast_model_parameters(\n self,\n *,\n lowres_cond,\n lowres_noise_cond,\n channels,\n channels_out,\n cond_on_image_embeds,\n cond_on_text_encodings,\n ):\n if lowres_cond == self.lowres_cond and \\\n channels == self.channels and \\\n cond_on_image_embeds == self.cond_on_image_embeds and \\\n cond_on_text_encodings == self.cond_on_text_encodings and \\\n lowres_noise_cond == self.lowres_noise_cond and \\\n channels_out == self.channels_out:\n return self\n updated_kwargs = dict(\n lowres_cond = lowres_cond,\n channels = channels,\n channels_out = channels_out,\n cond_on_image_embeds = cond_on_image_embeds,",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2118-2146"
+ },
+ "275": {
+ "file_id": 6,
+ "content": "This code function checks if the current unet model parameters are correct for cascading DDPM. If not, it reinitializes the unet with the new settings. The parameters being checked include lowres_cond, channels, cond_on_image_embeds, and cond_on_text_encodings.",
+ "type": "comment"
+ },
+ "276": {
+ "file_id": 6,
+ "content": " cond_on_text_encodings = cond_on_text_encodings,\n lowres_noise_cond = lowres_noise_cond\n )\n return self.__class__(**{**self._locals, **updated_kwargs})\n def forward_with_cond_scale(\n self,\n *args,\n cond_scale = 1.,\n **kwargs\n ):\n logits = self.forward(*args, **kwargs)\n if cond_scale == 1:\n return logits\n null_logits = self.forward(*args, text_cond_drop_prob = 1., image_cond_drop_prob = 1., **kwargs)\n return null_logits + (logits - null_logits) * cond_scale\n def forward(\n self,\n x,\n time,\n *,\n image_embed,\n lowres_cond_img = None,\n lowres_noise_level = None,\n text_encodings = None,\n image_cond_drop_prob = 0.,\n text_cond_drop_prob = 0.,\n blur_sigma = None,\n blur_kernel_size = None,\n disable_checkpoint = False,\n self_cond = None\n ):\n batch_size, device = x.shape[0], x.device\n # add low resolution conditioning, if present",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2147-2185"
+ },
+ "277": {
+ "file_id": 6,
+ "content": "This code defines a class with forward, forward_with_cond_scale methods that take various parameters and perform image processing operations. The forward method calculates logits based on input images, time, image embeddings, and other optional parameters. The forward_with_cond_scale method applies conditional scaling to the logits calculated by the forward method.",
+ "type": "comment"
+ },
+ "278": {
+ "file_id": 6,
+ "content": " assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'\n # concat self conditioning, if needed\n if self.self_cond:\n self_cond = default(self_cond, lambda: torch.zeros_like(x))\n x = torch.cat((x, self_cond), dim = 1)\n # concat low resolution conditioning\n if exists(lowres_cond_img):\n x = torch.cat((x, lowres_cond_img), dim = 1)\n # initial convolution\n x = self.init_conv(x)\n r = x.clone() # final residual\n # time conditioning\n time = time.type_as(x)\n time_hiddens = self.to_time_hiddens(time)\n time_tokens = self.to_time_tokens(time_hiddens)\n t = self.to_time_cond(time_hiddens)\n # low res noise conditioning (similar to time above)\n if exists(lowres_noise_level):\n assert exists(self.to_lowres_noise_cond), 'lowres_noise_cond must be set to True on instantiation of the unet in order to conditiong on lowres noise'",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2187-2216"
+ },
+ "279": {
+ "file_id": 6,
+ "content": "The code checks if low resolution conditioning image exists and appends it to the input. It then concatenates self-conditioning, initializes a convolution, clones the input for residual calculations, performs time conditioning, and applies low resolution noise conditioning (if enabled).",
+ "type": "comment"
+ },
+ "280": {
+ "file_id": 6,
+ "content": " lowres_noise_level = lowres_noise_level.type_as(x)\n t = t + self.to_lowres_noise_cond(lowres_noise_level)\n # conditional dropout\n image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)\n text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device)\n text_keep_mask = rearrange(text_keep_mask, 'b -> b 1 1')\n # image embedding to be summed to time embedding\n # discovered by @mhh0318 in the paper\n if exists(image_embed) and exists(self.to_image_hiddens):\n image_hiddens = self.to_image_hiddens(image_embed)\n image_keep_mask_hidden = rearrange(image_keep_mask, 'b -> b 1')\n null_image_hiddens = self.null_image_hiddens.to(image_hiddens.dtype)\n image_hiddens = torch.where(\n image_keep_mask_hidden,\n image_hiddens,\n null_image_hiddens\n )\n t = t + image_hiddens\n # mask out image embedding depending on condition dropout",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2217-2243"
+ },
+ "281": {
+ "file_id": 6,
+ "content": "This code performs conditional dropout by maintaining image and text masks, checks if an image embedding exists, applies a conditional dropout to the image embedding based on the masks, and adds it to the time embedding.",
+ "type": "comment"
+ },
+ "282": {
+ "file_id": 6,
+ "content": " # for classifier free guidance\n image_tokens = None\n if self.cond_on_image_embeds:\n image_keep_mask_embed = rearrange(image_keep_mask, 'b -> b 1 1')\n image_tokens = self.image_to_tokens(image_embed)\n null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working\n image_tokens = torch.where(\n image_keep_mask_embed,\n image_tokens,\n null_image_embed\n )\n # take care of text encodings (optional)\n text_tokens = None\n if exists(text_encodings) and self.cond_on_text_encodings:\n assert text_encodings.shape[0] == batch_size, f'the text encodings being passed into the unet does not have the proper batch size - text encoding shape {text_encodings.shape} - required batch size is {batch_size}'\n assert self.text_embed_dim == text_encodings.shape[-1], f'the text encodings you are passing in have a dimension of {text_encodings.shape[-1]}, but the unet was created with text_embed_dim of {self.text_embed_dim}.'",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2244-2265"
+ },
+ "283": {
+ "file_id": 6,
+ "content": "This code chunk is setting up the input for a classifier-free guidance model. It checks if the image and text encodings are provided, and if so, prepares them for the model's input. If both the image embeddings and text encodings are present, it applies conditional guidance by masking the image tokens with the image_keep_mask and nullifying where needed. It asserts that the text encodings match the batch size and the expected embedding dimension of the model.",
+ "type": "comment"
+ },
+ "284": {
+ "file_id": 6,
+ "content": " text_mask = torch.any(text_encodings != 0., dim = -1)\n text_tokens = self.text_to_cond(text_encodings)\n text_tokens = text_tokens[:, :self.max_text_len]\n text_mask = text_mask[:, :self.max_text_len]\n text_tokens_len = text_tokens.shape[1]\n remainder = self.max_text_len - text_tokens_len\n if remainder > 0:\n text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))\n text_mask = F.pad(text_mask, (0, remainder), value = False)\n text_mask = rearrange(text_mask, 'b n -> b n 1')\n assert text_mask.shape[0] == text_keep_mask.shape[0], f'text_mask has shape of {text_mask.shape} while text_keep_mask has shape {text_keep_mask.shape}. text encoding is of shape {text_encodings.shape}'\n text_keep_mask = text_mask & text_keep_mask\n null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working\n text_tokens = torch.where(",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2267-2288"
+ },
+ "285": {
+ "file_id": 6,
+ "content": "This code snippet is preparing text_tokens for the model by applying padding and ensuring correct shape. It creates a binary mask (text_mask) from the non-zero elements in text_encodings to indicate which tokens are present, then applies this mask to both text_tokens and text_keep_mask. The code also checks if there's remaining space in the max_text_len and pads text_tokens accordingly. Lastly, it asserts that the shapes of text_mask and text_keep_mask match before combining them using a logical AND operation.",
+ "type": "comment"
+ },
+ "286": {
+ "file_id": 6,
+ "content": " text_keep_mask,\n text_tokens,\n null_text_embed\n )\n # main conditioning tokens (c)\n c = time_tokens\n if exists(image_tokens):\n c = torch.cat((c, image_tokens), dim = -2)\n # text and image conditioning tokens (mid_c)\n # to save on compute, only do cross attention based conditioning on the inner most layers of the Unet\n mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2)\n # normalize conditioning tokens\n c = self.norm_cond(c)\n mid_c = self.norm_mid_cond(mid_c)\n # gradient checkpointing\n can_checkpoint = self.training and self.checkpoint_during_training and not disable_checkpoint\n apply_checkpoint_fn = make_checkpointable if can_checkpoint else identity\n # make checkpointable modules\n init_resnet_block, mid_block1, mid_attn, mid_block2, final_resnet_block = [maybe(apply_checkpoint_fn)(module) for module in (self.init_resnet_block, self.mid_block1, self.mid_attn, self.mid_block2, self.final_resnet_block)]",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2289-2318"
+ },
+ "287": {
+ "file_id": 6,
+ "content": "This code snippet is part of the DALLE2-pytorch model, responsible for handling conditioning tokens (main and auxiliary) for image and text inputs. The code normalizes these tokens using `self.norm_cond` and `self.norm_mid_cond`, applies gradient checkpointing, and makes certain modules (e.g., `self.init_resnet_block`) checkpointable based on training parameters. This helps to optimize the model's computation during inference and improve its performance.",
+ "type": "comment"
+ },
+ "288": {
+ "file_id": 6,
+ "content": " can_checkpoint_cond = lambda m: isinstance(m, ResnetBlock)\n downs, ups = [maybe(apply_checkpoint_fn)(m, condition = can_checkpoint_cond) for m in (self.downs, self.ups)]\n # initial resnet block\n if exists(init_resnet_block):\n x = init_resnet_block(x, t)\n # go through the layers of the unet, down and up\n down_hiddens = []\n up_hiddens = []\n for pre_downsample, init_block, resnet_blocks, attn, post_downsample in downs:\n if exists(pre_downsample):\n x = pre_downsample(x)\n x = init_block(x, t, c)\n for resnet_block in resnet_blocks:\n x = resnet_block(x, t, c)\n down_hiddens.append(x.contiguous())\n x = attn(x)\n down_hiddens.append(x.contiguous())\n if exists(post_downsample):\n x = post_downsample(x)\n x = mid_block1(x, t, mid_c)\n if exists(mid_attn):\n x = mid_attn(x)\n x = mid_block2(x, t, mid_c)\n ",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2320-2356"
+ },
+ "289": {
+ "file_id": 6,
+ "content": "This code initializes a U-Net model by iterating over its components. It applies pre-downsample, initial block, and resnet blocks to the input x. Then, it adds hidden representations of down and up stages into separate lists. After that, it passes x through an attention module and potentially post-downsample. Finally, it processes x with two more blocks, possibly applies mid-attention, and returns the final result.",
+ "type": "comment"
+ },
+ "290": {
+ "file_id": 6,
+ "content": " connect_skip = lambda fmap: torch.cat((fmap, down_hiddens.pop() * self.skip_connect_scale), dim = 1)\n for init_block, resnet_blocks, attn, upsample in ups:\n x = connect_skip(x)\n x = init_block(x, t, c)\n for resnet_block in resnet_blocks:\n x = connect_skip(x)\n x = resnet_block(x, t, c)\n x = attn(x)\n up_hiddens.append(x.contiguous())\n x = upsample(x)\n x = self.upsample_combiner(x, up_hiddens)\n x = torch.cat((x, r), dim = 1)\n x = final_resnet_block(x, t)\n if exists(lowres_cond_img):\n x = torch.cat((x, lowres_cond_img), dim = 1)\n return self.to_out(x)\nclass LowresConditioner(nn.Module):\n def __init__(\n self,\n downsample_first = True,\n use_blur = True,\n blur_prob = 0.5,\n blur_sigma = 0.6,\n blur_kernel_size = 3,\n use_noise = False,\n input_image_range = None,\n normalize_img_fn = identity,\n unnormalize_img_fn = identity",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2356-2393"
+ },
+ "291": {
+ "file_id": 6,
+ "content": "This code defines a class for processing input images, which consists of an upscaling network and a low-resolution conditioner. The upscaling network takes in a low-resolution image and upscales it using skip connections and residual blocks. The low-resolution conditioner can optionally take a low-resolution version of the input image as additional input. The final output is passed through an activation function before being returned.",
+ "type": "comment"
+ },
+ "292": {
+ "file_id": 6,
+ "content": " ):\n super().__init__()\n self.downsample_first = downsample_first\n self.input_image_range = input_image_range\n self.use_blur = use_blur\n self.blur_prob = blur_prob\n self.blur_sigma = blur_sigma\n self.blur_kernel_size = blur_kernel_size\n self.use_noise = use_noise\n self.normalize_img = normalize_img_fn\n self.unnormalize_img = unnormalize_img_fn\n self.noise_scheduler = NoiseScheduler(beta_schedule = 'linear', timesteps = 1000, loss_type = 'l2') if use_noise else None\n def noise_image(self, cond_fmap, noise_levels = None):\n assert exists(self.noise_scheduler)\n batch = cond_fmap.shape[0]\n cond_fmap = self.normalize_img(cond_fmap)\n random_noise_levels = default(noise_levels, lambda: self.noise_scheduler.sample_random_times(batch))\n cond_fmap = self.noise_scheduler.q_sample(cond_fmap, t = random_noise_levels, noise = torch.randn_like(cond_fmap))\n cond_fmap = self.unnormalize_img(cond_fmap)",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2394-2418"
+ },
+ "293": {
+ "file_id": 6,
+ "content": "This code initializes an object with various parameters, including downsampling, image range, and noise-related options. It also includes methods for generating noise images based on the given parameters. The class utilizes normalization and denormalization functions as well as a NoiseScheduler instance to apply noise to the input condition maps.",
+ "type": "comment"
+ },
+ "294": {
+ "file_id": 6,
+ "content": " return cond_fmap, random_noise_levels\n def forward(\n self,\n cond_fmap,\n *,\n target_image_size,\n downsample_image_size = None,\n should_blur = True,\n blur_sigma = None,\n blur_kernel_size = None\n ):\n if self.downsample_first and exists(downsample_image_size):\n cond_fmap = resize_image_to(cond_fmap, downsample_image_size, clamp_range = self.input_image_range, nearest = True)\n # blur is only applied 50% of the time\n # section 3.1 in https://arxiv.org/abs/2106.15282\n if self.use_blur and should_blur and random.random() < self.blur_prob:\n # when training, blur the low resolution conditional image\n blur_sigma = default(blur_sigma, self.blur_sigma)\n blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)\n # allow for drawing a random sigma between lo and hi float values\n if isinstance(blur_sigma, tuple):\n blur_sigma = tuple(map(float, blur_sigma))",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2419-2447"
+ },
+ "295": {
+ "file_id": 6,
+ "content": "This function takes a conditional feature map and optional parameters to resize, blur, and downsample the image. The code checks if downsampling is needed first, then decides whether to apply blurring based on a probability setting. Blur sigma and kernel size are also set based on default values or user input.",
+ "type": "comment"
+ },
+ "296": {
+ "file_id": 6,
+ "content": " blur_sigma = random.uniform(*blur_sigma)\n # allow for drawing a random kernel size between lo and hi int values\n if isinstance(blur_kernel_size, tuple):\n blur_kernel_size = tuple(map(int, blur_kernel_size))\n kernel_size_lo, kernel_size_hi = blur_kernel_size\n blur_kernel_size = random.randrange(kernel_size_lo, kernel_size_hi + 1)\n cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))\n # resize to target image size\n cond_fmap = resize_image_to(cond_fmap, target_image_size, clamp_range = self.input_image_range, nearest = True)\n # noise conditioning, as done in Imagen\n # as a replacement for the BSR noising, and potentially replace blurring for first stage too\n random_noise_levels = None\n if self.use_noise:\n cond_fmap, random_noise_levels = self.noise_image(cond_fmap)\n # return conditioning feature map, as well as the augmentation noise levels",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2448-2471"
+ },
+ "297": {
+ "file_id": 6,
+ "content": "This code performs image conditioning by applying Gaussian blur and noise addition, then resizes the image to a target size. The blurring and noise addition are optional depending on the use_noise flag, and the final result is returned along with any applied random noise levels.",
+ "type": "comment"
+ },
+ "298": {
+ "file_id": 6,
+ "content": " return cond_fmap, random_noise_levels\nclass Decoder(nn.Module):\n def __init__(\n self,\n unet,\n *,\n clip = None,\n image_size = None,\n channels = 3,\n vae = tuple(),\n timesteps = 1000,\n sample_timesteps = None,\n image_cond_drop_prob = 0.1,\n text_cond_drop_prob = 0.5,\n loss_type = 'l2',\n beta_schedule = None,\n predict_x_start = False,\n predict_v = False,\n predict_x_start_for_latent_diffusion = False,\n image_sizes = None, # for cascading ddpm, image size at each stage\n random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)\n use_noise_for_lowres_cond = False, # whether to use Imagen-like noising for low resolution conditioning \n use_blur_for_lowres_cond = True, # whether to use the blur conditioning used in the original cascading ddpm paper, as well as DALL-E2",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2473-2496"
+ },
+ "299": {
+ "file_id": 6,
+ "content": "The code defines a Decoder class that takes various parameters like unet, clip, image_size, channels, vae, timesteps, sample_timesteps, image_cond_drop_prob, text_cond_drop_prob, loss_type, beta_schedule, predict_x_start, predict_v, predict_x_start_for_latent_diffusion, image_sizes, random_crop_sizes, use_noise_for_lowres_cond, and use_blur_for_lowres_cond. It returns cond_fmap and random_noise_levels.",
+ "type": "comment"
+ }
+}
\ No newline at end of file
diff --git a/docs/data/3.json b/docs/data/3.json
new file mode 100644
index 00000000..7e531895
--- /dev/null
+++ b/docs/data/3.json
@@ -0,0 +1,549 @@
+{
+ "300": {
+ "file_id": 6,
+ "content": " lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur\n blur_prob = 0.5, # cascading ddpm - when training, the gaussian blur is only applied 50% of the time\n blur_sigma = 0.6, # cascading ddpm - blur sigma\n blur_kernel_size = 3, # cascading ddpm - blur kernel size\n lowres_noise_sample_level = 0.2, # in imagen paper, they use a 0.2 noise level at sample time for low resolution conditioning\n clip_denoised = True,\n clip_x_start = True,\n clip_adapter_overrides = dict(),\n learned_variance = True,\n learned_variance_constrain_frac = False,\n vb_loss_weight = 0.001,\n unconditional = False, # set to True for generating images without conditioning\n auto_normalize_img = True, # whether to take care of normalizing the i",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2497-2509"
+ },
+ "301": {
+ "file_id": 6,
+ "content": "This code snippet is responsible for configuring the settings for a denoising diffusion probabilistic model (DDPM) in the DALLE2-pytorch project. The settings include cascading DDPM parameters, noise level at sample time, clip options, learned variance configuration, and unconditional image generation toggles.",
+ "type": "comment"
+ },
+ "302": {
+ "file_id": 6,
+ "content": "mage from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader\n use_dynamic_thres = False, # from the Imagen paper\n dynamic_thres_percentile = 0.95,\n p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended\n p2_loss_weight_k = 1,\n ddim_sampling_eta = 0. # can be set to 0. for deterministic sampling afaict\n ):\n super().__init__()\n # clip\n self.clip = None\n if exists(clip):\n assert not unconditional, 'clip must not be given if doing unconditional image training'\n assert channels == clip.image_channels, f'channels of image ({channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'\n if isinstance(clip, CLIP):\n clip = XClipAdapter(clip, **clip_adapter_overrides)",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2509-2526"
+ },
+ "303": {
+ "file_id": 6,
+ "content": "The code initializes an object with various parameters such as use_dynamic_thres, dynamic_thres_percentile, p2_loss_weight_gamma, p2_loss_weight_k, ddim_sampling_eta, and clip. It also checks if the 'clip' parameter is given and performs necessary assertions. If 'clip' exists and unconditional image training is not being done, it ensures the channels match with CLIP's accepted channels. It also uses XClipAdapter for compatibility with additional overrides.",
+ "type": "comment"
+ },
+ "304": {
+ "file_id": 6,
+ "content": " elif isinstance(clip, CoCa):\n clip = CoCaAdapter(clip, **clip_adapter_overrides)\n freeze_model_and_make_eval_(clip)\n assert isinstance(clip, BaseClipAdapter)\n self.clip = clip\n # determine image size, with image_size and image_sizes taking precedence\n if exists(image_size) or exists(image_sizes):\n assert exists(image_size) ^ exists(image_sizes), 'only one of image_size or image_sizes must be given'\n image_size = default(image_size, lambda: image_sizes[-1])\n elif exists(clip):\n image_size = clip.image_size\n else:\n raise Error('either image_size, image_sizes, or clip must be given to decoder')\n # channels\n self.channels = channels\n # normalize and unnormalize image functions\n self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity\n self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity\n # verify conditioning method",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2527-2555"
+ },
+ "305": {
+ "file_id": 6,
+ "content": "The code checks the input 'clip' type and applies the CoCaAdapter if it's an instance of CoCa. It then freezes the model for evaluation, ensures 'clip' is a BaseClipAdapter instance, and assigns it to self.clip. The image_size is determined from either 'image_size', 'image_sizes', or 'clip'. It sets the 'channels', 'normalize_img', and 'unnormalize_img' based on given parameters.",
+ "type": "comment"
+ },
+ "306": {
+ "file_id": 6,
+ "content": " unets = cast_tuple(unet)\n num_unets = len(unets)\n self.num_unets = num_unets\n self.unconditional = unconditional\n # automatically take care of ensuring that first unet is unconditional\n # while the rest of the unets are conditioned on the low resolution image produced by previous unet\n vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))\n # whether to use learned variance, defaults to True for the first unet in the cascade, as in paper\n learned_variance = pad_tuple_to_length(cast_tuple(learned_variance), len(unets), fillvalue = False)\n self.learned_variance = learned_variance\n self.learned_variance_constrain_frac = learned_variance_constrain_frac # whether to constrain the output of the network (the interpolation fraction) from 0 to 1\n self.vb_loss_weight = vb_loss_weight\n # default and validate conditioning parameters\n use_noise_for_lowres_cond = cast_tuple(use_noise_for_lowres_cond, num_unets - 1, validate = False)",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2557-2577"
+ },
+ "307": {
+ "file_id": 6,
+ "content": "This code initializes the U-Nets and VAEs for a DALL-E 2 model. It sets the number of unets, whether they are unconditional or conditioned on previous unets, and their learned variance. It also sets default parameters for conditioning with noise and constrains the output of the network from 0 to 1.",
+ "type": "comment"
+ },
+ "308": {
+ "file_id": 6,
+ "content": " use_blur_for_lowres_cond = cast_tuple(use_blur_for_lowres_cond, num_unets - 1, validate = False)\n if len(use_noise_for_lowres_cond) < num_unets:\n use_noise_for_lowres_cond = (False, *use_noise_for_lowres_cond)\n if len(use_blur_for_lowres_cond) < num_unets:\n use_blur_for_lowres_cond = (False, *use_blur_for_lowres_cond)\n assert not use_noise_for_lowres_cond[0], 'first unet will never need low res noise conditioning'\n assert not use_blur_for_lowres_cond[0], 'first unet will never need low res blur conditioning'\n assert num_unets == 1 or all((use_noise or use_blur) for use_noise, use_blur in zip(use_noise_for_lowres_cond[1:], use_blur_for_lowres_cond[1:]))\n # construct unets and vaes\n self.unets = nn.ModuleList([])\n self.vaes = nn.ModuleList([])\n for ind, (one_unet, one_vae, one_unet_learned_var, lowres_noise_cond) in enumerate(zip(unets, vaes, learned_variance, use_noise_for_lowres_cond)):\n assert isinstance(one_unet, Unet)",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2578-2597"
+ },
+ "309": {
+ "file_id": 6,
+ "content": "This code is setting up Unets and Vaes for a model. It ensures that the lists of noise conditions and blur conditions are long enough to correspond to each Unet, adds the Unets and Vaes to module lists, and asserts that at least one Unet will not need low res noise or blur conditioning.",
+ "type": "comment"
+ },
+ "310": {
+ "file_id": 6,
+ "content": " assert isinstance(one_vae, (VQGanVAE, NullVQGanVAE))\n is_first = ind == 0\n latent_dim = one_vae.encoded_dim if exists(one_vae) else None\n unet_channels = default(latent_dim, self.channels)\n unet_channels_out = unet_channels * (1 if not one_unet_learned_var else 2)\n one_unet = one_unet.cast_model_parameters(\n lowres_cond = not is_first,\n lowres_noise_cond = lowres_noise_cond,\n cond_on_image_embeds = not unconditional and is_first,\n cond_on_text_encodings = not unconditional and one_unet.cond_on_text_encodings,\n channels = unet_channels,\n channels_out = unet_channels_out\n )\n self.unets.append(one_unet)\n self.vaes.append(one_vae.copy_for_eval())\n # sampling timesteps, defaults to non-ddim with full timesteps sampling\n self.sample_timesteps = cast_tuple(sample_timesteps, num_unets)\n self.ddim_sampling_eta = ddim_sampling_eta",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2598-2621"
+ },
+ "311": {
+ "file_id": 6,
+ "content": "This code block appends a new VAE instance to the list of VAEs and a copied evaluation version of that VAE to the VAEs list. The code also sets the sampling timesteps and ddim_sampling_eta based on the input parameters.",
+ "type": "comment"
+ },
+ "312": {
+ "file_id": 6,
+ "content": " # create noise schedulers per unet\n if not exists(beta_schedule):\n beta_schedule = ('cosine', *(('cosine',) * max(num_unets - 2, 0)), *(('linear',) * int(num_unets > 1)))\n beta_schedule = cast_tuple(beta_schedule, num_unets)\n p2_loss_weight_gamma = cast_tuple(p2_loss_weight_gamma, num_unets)\n self.noise_schedulers = nn.ModuleList([])\n for ind, (unet_beta_schedule, unet_p2_loss_weight_gamma, sample_timesteps) in enumerate(zip(beta_schedule, p2_loss_weight_gamma, self.sample_timesteps)):\n assert not exists(sample_timesteps) or sample_timesteps <= timesteps, f'sampling timesteps {sample_timesteps} must be less than or equal to the number of training timesteps {timesteps} for unet {ind + 1}'\n noise_scheduler = NoiseScheduler(\n beta_schedule = unet_beta_schedule,\n timesteps = timesteps,\n loss_type = loss_type,\n p2_loss_weight_gamma = unet_p2_loss_weight_gamma,\n p2_loss_weight_k = p2_loss_weight_k",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2623-2641"
+ },
+ "313": {
+ "file_id": 6,
+ "content": "This code creates noise schedulers for each unet, based on the provided beta schedule and loss weight gamma. It asserts that sampling timesteps must be less than or equal to the number of training timesteps, and initializes a NoiseScheduler object with the specified parameters for each unet.",
+ "type": "comment"
+ },
+ "314": {
+ "file_id": 6,
+ "content": " )\n self.noise_schedulers.append(noise_scheduler)\n # unet image sizes\n image_sizes = default(image_sizes, (image_size,))\n image_sizes = tuple(sorted(set(image_sizes)))\n assert self.num_unets == len(image_sizes), f'you did not supply the correct number of u-nets ({self.num_unets}) for resolutions {image_sizes}'\n self.image_sizes = image_sizes\n self.sample_channels = cast_tuple(self.channels, len(image_sizes))\n # random crop sizes (for super-resoluting unets at the end of cascade?)\n self.random_crop_sizes = cast_tuple(random_crop_sizes, len(image_sizes))\n assert not exists(self.random_crop_sizes[0]), 'you would not need to randomly crop the image for the base unet'\n # predict x0 config\n self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))\n # predict v\n self.predict_v = cast_tuple(predict_v, len(unets))",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2642-2666"
+ },
+ "315": {
+ "file_id": 6,
+ "content": "This code is setting up the parameters for a model. It creates noise schedulers, defines image sizes and crop sizes for different resolutions, and configures predicting x0 and v values. These settings will be used to train or use the model. The code also performs assertions to ensure that the correct number of unets and vaes are provided for each resolution.",
+ "type": "comment"
+ },
+ "316": {
+ "file_id": 6,
+ "content": " # input image range\n self.input_image_range = (-1. if not auto_normalize_img else 0., 1.)\n # cascading ddpm related stuff\n lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))\n assert lowres_conditions == (False, *((True,) * (num_unets - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'\n self.lowres_conds = nn.ModuleList([])\n for unet_index, use_noise, use_blur in zip(range(num_unets), use_noise_for_lowres_cond, use_blur_for_lowres_cond):\n if unet_index == 0:\n self.lowres_conds.append(None)\n continue\n lowres_cond = LowresConditioner(\n downsample_first = lowres_downsample_first,\n use_blur = use_blur,\n use_noise = use_noise,\n blur_prob = blur_prob,\n blur_sigma = blur_sigma,\n blur_kernel_size = blur_kernel_size,\n input_image_range = self.input_image_range,",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2668-2691"
+ },
+ "317": {
+ "file_id": 6,
+ "content": "The code initializes the input image range and handles lowres_cond for each unet in the model. It ensures that the first unet is unconditioned, while the rest have `lowres_cond` set to True. The `LowresConditioner` class is used with specified parameters for downsampling, blurring, and input image range.",
+ "type": "comment"
+ },
+ "318": {
+ "file_id": 6,
+ "content": " normalize_img_fn = self.normalize_img,\n unnormalize_img_fn = self.unnormalize_img\n )\n self.lowres_conds.append(lowres_cond)\n self.lowres_noise_sample_level = lowres_noise_sample_level\n # classifier free guidance\n self.image_cond_drop_prob = image_cond_drop_prob\n self.text_cond_drop_prob = text_cond_drop_prob\n self.can_classifier_guidance = image_cond_drop_prob > 0. or text_cond_drop_prob > 0.\n # whether to clip when sampling\n self.clip_denoised = clip_denoised\n self.clip_x_start = clip_x_start\n # dynamic thresholding settings, if clipping denoised during sampling\n self.use_dynamic_thres = use_dynamic_thres\n self.dynamic_thres_percentile = dynamic_thres_percentile\n # device tracker\n self.register_buffer('_dummy', torch.Tensor([True]), persistent = False)\n @property\n def device(self):\n return self._dummy.device\n @property\n def condition_on_text_encodings(self):",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2692-2725"
+ },
+ "319": {
+ "file_id": 6,
+ "content": "This code is setting up parameters and functions for an image generation model. It includes normalization and unnormalization functions, lowres noise sample level, classifier free guidance settings, clipping options during sampling, dynamic thresholding settings, and device management. The model can condition on text encodings and uses a device tracker to keep track of device information.",
+ "type": "comment"
+ },
+ "320": {
+ "file_id": 6,
+ "content": " return any([unet.cond_on_text_encodings for unet in self.unets if isinstance(unet, Unet)])\n def get_unet(self, unet_number):\n assert 0 < unet_number <= self.num_unets\n index = unet_number - 1\n return self.unets[index]\n def parse_unet_output(self, learned_variance, output):\n var_interp_frac_unnormalized = None\n if learned_variance:\n output, var_interp_frac_unnormalized = output.chunk(2, dim = 1)\n return UnetOutput(output, var_interp_frac_unnormalized)\n @contextmanager\n def one_unet_in_gpu(self, unet_number = None, unet = None):\n assert exists(unet_number) ^ exists(unet)\n if exists(unet_number):\n unet = self.get_unet(unet_number)\n # devices\n cuda, cpu = torch.device('cuda'), torch.device('cpu')\n self.cuda()\n devices = [module_device(unet) for unet in self.unets]\n self.unets.to(cpu)\n unet.to(cuda)\n yield\n for unet, device in zip(self.unets, devices):\n unet.to(device)",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2726-2762"
+ },
+ "321": {
+ "file_id": 6,
+ "content": "This code defines methods for working with a collection of UNET models. The `get_unet` method retrieves a specific UNET based on its number, ensuring it is within the valid range. `parse_unet_output` parses the output of a UNET, interpreting learned variance if present. The `one_unet_in_gpu` context manager allows running inference for one UNET on the GPU while keeping other UNETs on the CPU.",
+ "type": "comment"
+ },
+ "322": {
+ "file_id": 6,
+ "content": " def dynamic_threshold(self, x):\n \"\"\" proposed in https://arxiv.org/abs/2205.11487 as an improved clamping in the setting of classifier free guidance \"\"\"\n # s is the threshold amount\n # static thresholding would just be s = 1\n s = 1.\n if self.use_dynamic_thres:\n s = torch.quantile(\n rearrange(x, 'b ... -> b (...)').abs(),\n self.dynamic_thres_percentile,\n dim = -1\n )\n s.clamp_(min = 1.)\n s = s.view(-1, *((1,) * (x.ndim - 1)))\n # clip by threshold, depending on whether static or dynamic\n x = x.clamp(-s, s) / s\n return x\n def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, predict_v = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):\n assert not (cond_scale != 1. and not self.",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2764-2785"
+ },
+ "323": {
+ "file_id": 6,
+ "content": "This code defines two methods. dynamic_threshold clamps a predicted x0: with static thresholding the threshold s is 1, while with use_dynamic_thres enabled s is set per sample to a quantile of the absolute values (clamped to at least 1), and the result is clipped to [-s, s] and divided by s, as proposed in the Imagen paper (https://arxiv.org/abs/2205.11487). The p_mean_variance method then begins; it computes the reverse-step mean and variance for a unet, with options for x0/v/noise prediction, learned variance, conditioning inputs, and classifier-free guidance scaling.",
+ "type": "comment"
+ },
+ "324": {
+ "file_id": 6,
+ "content": "can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'\n model_output = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))\n pred, var_interp_frac_unnormalized = self.parse_unet_output(learned_variance, model_output)\n if predict_v:\n x_start = noise_scheduler.predict_start_from_v(x, t = t, v = pred)\n elif predict_x_start:\n x_start = pred\n else:\n x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)\n if clip_denoised:\n x_start = self.dynamic_threshold(x_start)\n model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t)\n if learned_variance:",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2785-2803"
+ },
+ "325": {
+ "file_id": 6,
+ "content": "This code runs the unet (through forward_with_cond_scale when classifier-free guidance is used, which requires the decoder to have been trained with conditional dropout), recovers x_start from the prediction according to the chosen parameterization (v, x0, or noise), optionally clips it with dynamic thresholding, and computes the posterior mean and variance via the noise scheduler's q_posterior.",
+ "type": "comment"
+ },
+ "326": {
+ "file_id": 6,
+ "content": " # if learned variance, posterio variance and posterior log variance are predicted by the network\n # by an interpolation of the max and min log beta values\n # eq 15 - https://arxiv.org/abs/2102.09672\n min_log = extract(noise_scheduler.posterior_log_variance_clipped, t, x.shape)\n max_log = extract(torch.log(noise_scheduler.betas), t, x.shape)\n var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)\n if self.learned_variance_constrain_frac:\n var_interp_frac = var_interp_frac.sigmoid()\n posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log\n posterior_variance = posterior_log_variance.exp()\n return model_mean, posterior_variance, posterior_log_variance, x_start\n @torch.no_grad()\n def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2804-2820"
+ },
+ "327": {
+ "file_id": 6,
+ "content": "When the variance is learned, the posterior log variance is an interpolation between the maximum and minimum log beta values, with the interpolation fraction predicted by the network and optionally constrained through a sigmoid, following Eq. 15 of the Improved DDPM paper (https://arxiv.org/abs/2102.09672). The method returns the model mean, posterior variance, posterior log variance, and x_start. The p_sample method signature then begins.",
+ "type": "comment"
+ },
+ "328": {
+ "file_id": 6,
+ "content": "start = False, predict_v = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):\n b, *_, device = *x.shape, x.device\n model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, predict_v = predict_v, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)\n noise = torch.randn_like(x)\n # no noise when t == 0\n nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))\n pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise\n return pred, x_start\n @torch.no_grad()\n def p_sample_loop_ddpm(\n self,\n unet,\n shape,\n image_embed,\n noise_scheduler,\n predict_x_start = False,",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2820-2836"
+ },
+ "329": {
+ "file_id": 6,
+ "content": "p_sample performs one reverse DDPM step: it calls p_mean_variance to obtain model_mean, model_log_variance, and x_start, then returns model_mean plus noise scaled by the predicted standard deviation, with the noise term masked out when t == 0. It returns both the new sample and x_start (used for self-conditioning). The p_sample_loop_ddpm signature then begins.",
+ "type": "comment"
+ },
+ "330": {
+ "file_id": 6,
+ "content": " predict_v = False,\n learned_variance = False,\n clip_denoised = True,\n lowres_cond_img = None,\n text_encodings = None,\n cond_scale = 1,\n is_latent_diffusion = False,\n lowres_noise_level = None,\n inpaint_image = None,\n inpaint_mask = None,\n inpaint_resample_times = 5\n ):\n device = self.device\n b = shape[0]\n img = torch.randn(shape, device = device)\n x_start = None # for self-conditioning\n is_inpaint = exists(inpaint_image)\n resample_times = inpaint_resample_times if is_inpaint else 1\n if is_inpaint:\n inpaint_image = self.normalize_img(inpaint_image)\n inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True)\n inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float()\n inpaint_mask = resize_image_to(inpaint_mask, shape[-1], nearest = True)\n inpaint_mask = inpaint_mask.bool()\n if not is_latent_diffusion:",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2837-2866"
+ },
+ "331": {
+ "file_id": 6,
+ "content": "These are the remaining arguments and setup of the DDPM sampling loop: the image is initialized from Gaussian noise and x_start is reset for self-conditioning. If an inpaint image is given, it is normalized and resized to the target resolution, the mask is reshaped, resized with nearest-neighbor interpolation, and cast to boolean, and several RePaint resampling passes are enabled. When not doing latent diffusion, the low-resolution conditioning image is normalized in the next block.",
+ "type": "comment"
+ },
+ "332": {
+ "file_id": 6,
+ "content": " lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)\n for time in tqdm(reversed(range(0, noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = noise_scheduler.num_timesteps):\n is_last_timestep = time == 0\n for r in reversed(range(0, resample_times)):\n is_last_resample_step = r == 0\n times = torch.full((b,), time, device = device, dtype = torch.long)\n if is_inpaint:\n # following the repaint paper\n # https://arxiv.org/abs/2201.09865\n noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = times)\n img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)\n self_cond = x_start if unet.self_cond else None\n img, x_start = self.p_sample(\n unet,\n img,\n times,\n image_embed = image_embed,\n text_encodings = text_encodings,",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2867-2890"
+ },
+ "333": {
+ "file_id": 6,
+ "content": "This is the main DDPM sampling loop. It iterates over the timesteps in reverse and, for inpainting, repeats each step several times following the RePaint paper (https://arxiv.org/abs/2201.09865): the known region of the inpaint image is noised to the current timestep and pasted into the image before denoising. The previous x_start is passed as self-conditioning when the unet supports it, and the actual reverse step is performed by p_sample.",
+ "type": "comment"
+ },
+ "334": {
+ "file_id": 6,
+ "content": " cond_scale = cond_scale,\n self_cond = self_cond,\n lowres_cond_img = lowres_cond_img,\n lowres_noise_level = lowres_noise_level,\n predict_x_start = predict_x_start,\n predict_v = predict_v,\n noise_scheduler = noise_scheduler,\n learned_variance = learned_variance,\n clip_denoised = clip_denoised\n )\n if is_inpaint and not (is_last_timestep or is_last_resample_step):\n # in repaint, you renoise and resample up to 10 times every step\n img = noise_scheduler.q_sample_from_to(img, times - 1, times)\n if is_inpaint:\n img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask)\n unnormalize_img = self.unnormalize_img(img)\n return unnormalize_img\n @torch.no_grad()\n def p_sample_loop_ddim(\n self,\n unet,\n shape,\n image_embed,",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2891-2917"
+ },
+ "335": {
+ "file_id": 6,
+ "content": "This code passes the remaining conditioning arguments (cond_scale, self-conditioning, low-resolution conditioning, parameterization flags, noise scheduler, learned variance, clipping) to p_sample. For inpainting, the image is re-noised from t-1 back to t between RePaint resampling passes, and after the loop the known region of the inpaint image is pasted back in via the mask. The result is unnormalized and returned; the DDIM sampling loop is defined next.",
+ "type": "comment"
+ },
+ "336": {
+ "file_id": 6,
+ "content": " noise_scheduler,\n timesteps,\n eta = 1.,\n predict_x_start = False,\n predict_v = False,\n learned_variance = False,\n clip_denoised = True,\n lowres_cond_img = None,\n text_encodings = None,\n cond_scale = 1,\n is_latent_diffusion = False,\n lowres_noise_level = None,\n inpaint_image = None,\n inpaint_mask = None,\n inpaint_resample_times = 5\n ):\n batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod, self.ddim_sampling_eta\n times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]\n times = list(reversed(times.int().tolist()))\n time_pairs = list(zip(times[:-1], times[1:]))\n time_pairs = list(filter(lambda t: t[0] > t[1], time_pairs))\n is_inpaint = exists(inpaint_image)\n resample_times = inpaint_resample_times if is_inpaint else 1\n if is_inpaint:\n inpaint_image = self.normalize_img(inpaint_image)",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2918-2946"
+ },
+ "337": {
+ "file_id": 6,
+ "content": "These are the arguments and setup of the DDIM sampling loop: it gathers the batch size, device, total number of timesteps, cumulative alphas, and the ddim_sampling_eta, then builds the strided list of (time, time_next) pairs to step through. If an inpaint image is provided, the RePaint resampling count is enabled and the inpaint image is normalized.",
+ "type": "comment"
+ },
+ "338": {
+ "file_id": 6,
+ "content": " inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True)\n inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float()\n inpaint_mask = resize_image_to(inpaint_mask, shape[-1], nearest = True)\n inpaint_mask = inpaint_mask.bool()\n img = torch.randn(shape, device = device)\n x_start = None # for self-conditioning\n if not is_latent_diffusion:\n lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)\n for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):\n is_last_timestep = time_next == 0\n for r in reversed(range(0, resample_times)):\n is_last_resample_step = r == 0\n alpha = alphas[time]\n alpha_next = alphas[time_next]\n time_cond = torch.full((batch,), time, device = device, dtype = torch.long)\n if is_inpaint:\n # following the repaint paper\n # https://arxiv.org/abs/2201.09865",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2947-2972"
+ },
+ "339": {
+ "file_id": 6,
+ "content": "This code continues the DDIM setup and loop: the inpaint image and mask are resized to the target resolution, the image is initialized from Gaussian noise, the low-resolution conditioning image is normalized when sampling in pixel space, and the loop iterates over the timestep pairs, with extra resampling passes for inpainting following the RePaint paper (https://arxiv.org/abs/2201.09865).",
+ "type": "comment"
+ },
+ "340": {
+ "file_id": 6,
+ "content": " noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = time_cond)\n img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)\n self_cond = x_start if unet.self_cond else None\n unet_output = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, self_cond = self_cond, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)\n pred, _ = self.parse_unet_output(learned_variance, unet_output)\n # predict x0\n if predict_v:\n x_start = noise_scheduler.predict_start_from_v(img, t = time_cond, v = pred)\n elif predict_x_start:\n x_start = pred\n else:\n x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)\n # maybe clip x0\n if clip_denoised:\n x_start = self.dynamic_threshold(x_start)",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2973-2994"
+ },
+ "341": {
+ "file_id": 6,
+ "content": "Inside the DDIM loop, the known region of the inpaint image is noised to the current timestep and pasted in, then the unet is run via forward_with_cond_scale with the image embedding, text encodings, self-conditioning, and low-resolution conditioning. x0 is recovered from the prediction (from v, directly, or from predicted noise depending on the parameterization) and optionally clipped with dynamic thresholding.",
+ "type": "comment"
+ },
+ "342": {
+ "file_id": 6,
+ "content": " # predict noise\n pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = x_start)\n c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()\n c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()\n noise = torch.randn_like(img) if not is_last_timestep else 0.\n img = x_start * alpha_next.sqrt() + \\\n c1 * noise + \\\n c2 * pred_noise\n if is_inpaint and not (is_last_timestep or is_last_resample_step):\n # in repaint, you renoise and resample up to 10 times every step\n time_next_cond = torch.full((batch,), time_next, device = device, dtype = torch.long)\n img = noise_scheduler.q_sample_from_to(img, time_next_cond, time_cond)\n if exists(inpaint_image):\n img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask)\n img = self.unnormalize_img(img)\n return img",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:2996-3017"
+ },
+ "343": {
+ "file_id": 6,
+ "content": "This completes the DDIM step: the noise is re-derived from the predicted x0, the coefficients c1 and c2 are computed from the alphas and eta, and the next image is x_start * sqrt(alpha_next) + c1 * noise + c2 * pred_noise, with no fresh noise on the last timestep. For inpainting, the image is re-noised between timesteps for the extra RePaint passes and the known region is pasted back in at the end; the result is unnormalized and returned.",
+ "type": "comment"
+ },
+ "344": {
+ "file_id": 6,
+ "content": " @torch.no_grad()\n def p_sample_loop(self, *args, noise_scheduler, timesteps = None, **kwargs):\n num_timesteps = noise_scheduler.num_timesteps\n timesteps = default(timesteps, num_timesteps)\n assert timesteps <= num_timesteps\n is_ddim = timesteps < num_timesteps\n if not is_ddim:\n return self.p_sample_loop_ddpm(*args, noise_scheduler = noise_scheduler, **kwargs)\n return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs)\n def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, predict_v = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False, lowres_noise_level = None):\n noise = default(noise, lambda: torch.randn_like(x_start))\n # normalize to [-1, 1]\n if not is_latent_diffusion:\n x_start = self.normalize_img(x_start)\n lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:3019-3039"
+ },
+ "345": {
+ "file_id": 6,
+ "content": "p_sample_loop dispatches between samplers: it uses the full DDPM loop when the requested number of timesteps equals the scheduler's, and the DDIM loop when fewer timesteps are requested. p_losses then begins: noise defaults to Gaussian noise, and unless latent diffusion is used the input image and low-resolution conditioning image are normalized to [-1, 1].",
+ "type": "comment"
+ },
+ "346": {
+ "file_id": 6,
+ "content": " # get x_t\n x_noisy = noise_scheduler.q_sample(x_start = x_start, t = times, noise = noise)\n # unet kwargs\n unet_kwargs = dict(\n image_embed = image_embed,\n text_encodings = text_encodings,\n lowres_cond_img = lowres_cond_img,\n lowres_noise_level = lowres_noise_level,\n )\n # self conditioning\n self_cond = None\n if unet.self_cond and random.random() < 0.5:\n with torch.no_grad():\n unet_output = unet(x_noisy, times, **unet_kwargs)\n self_cond, _ = self.parse_unet_output(learned_variance, unet_output)\n self_cond = self_cond.detach()\n # forward to get model prediction\n unet_output = unet(\n x_noisy,\n times,\n **unet_kwargs,\n self_cond = self_cond,\n image_cond_drop_prob = self.image_cond_drop_prob,\n text_cond_drop_prob = self.text_cond_drop_prob,\n )\n pred, _ = self.parse_unet_output(learned_variance, unet_output)",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:3041-3075"
+ },
+ "347": {
+ "file_id": 6,
+ "content": "This part of p_losses noises x_start to x_t with the scheduler, then runs the unet: with probability 0.5 a no-grad forward pass produces a detached self-conditioning input, and the main forward pass makes the prediction while applying the configured image and text conditioning dropout probabilities used to enable classifier-free guidance.",
+ "type": "comment"
+ },
+ "348": {
+ "file_id": 6,
+ "content": " if predict_v:\n target = noise_scheduler.calculate_v(x_start, times, noise)\n elif predict_x_start:\n target = x_start\n else:\n target = noise\n loss = noise_scheduler.loss_fn(pred, target, reduction = 'none')\n loss = reduce(loss, 'b ... -> b (...)', 'mean')\n loss = noise_scheduler.p2_reweigh_loss(loss, times)\n loss = loss.mean()\n if not learned_variance:\n # return simple loss if not using learned variance\n return loss\n # most of the code below is transcribed from\n # https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/diffusion_utils_2.py\n # the Improved DDPM paper then further modified it so that the mean is detached (shown a couple lines before), and weighted to be smaller than the l1 or l2 \"simple\" loss\n # it is questionable whether this is really needed, looking at some of the figures in the paper, but may as well stay faithful to their implementation",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:3077-3098"
+ },
+ "349": {
+ "file_id": 6,
+ "content": "The regression target depends on the parameterization: v when predict_v, x_start when predict_x_start, and otherwise the noise. The scheduler's loss function is applied elementwise, averaged per sample, reweighted with the P2 weighting based on the timestep, and reduced to a scalar mean. If the variance is not learned, this simple loss is returned directly; otherwise the variational-bound terms computed below are added as well.",
+ "type": "comment"
+ },
+ "350": {
+ "file_id": 6,
+ "content": " # if learning the variance, also include the extra weight kl loss\n true_mean, _, true_log_variance_clipped = noise_scheduler.q_posterior(x_start = x_start, x_t = x_noisy, t = times)\n model_mean, _, model_log_variance, _ = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = unet_output)\n # kl loss with detached model predicted mean, for stability reasons as in paper\n detached_model_mean = model_mean.detach()\n kl = normal_kl(true_mean, true_log_variance_clipped, detached_model_mean, model_log_variance)\n kl = meanflat(kl) * NAT\n decoder_nll = -discretized_gaussian_log_likelihood(x_start, means = detached_model_mean, log_scales = 0.5 * model_log_variance)\n decoder_nll = meanflat(decoder_nll) * NAT\n # at the first timestep return the decoder NLL, otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:3100-3115"
+ },
+ "351": {
+ "file_id": 6,
+ "content": "When the variance is learned, this code computes the variational-bound terms: the KL divergence between the true posterior q(x_{t-1}|x_t, x_0) and the model's predicted posterior, and the discretized Gaussian decoder negative log-likelihood. The model-predicted mean is detached for stability, as in the Improved DDPM paper. At the first timestep the decoder NLL is used; otherwise the KL term.",
+ "type": "comment"
+ },
+ "352": {
+ "file_id": 6,
+ "content": " vb_losses = torch.where(times == 0, decoder_nll, kl)\n # weight the vb loss smaller, for stability, as in the paper (recommended 0.001)\n vb_loss = vb_losses.mean() * self.vb_loss_weight\n return loss + vb_loss\n @torch.no_grad()\n @eval_decorator\n def sample(\n self,\n image = None,\n image_embed = None,\n text = None,\n text_encodings = None,\n batch_size = 1,\n cond_scale = 1.,\n start_at_unet_number = 1,\n stop_at_unet_number = None,\n distributed = False,\n inpaint_image = None,\n inpaint_mask = None,\n inpaint_resample_times = 5,\n one_unet_in_gpu_at_time = True\n ):\n assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally'\n if not self.unconditional:\n batch_size = image_embed.shape[0]\n if exists(text) and not exists(text_encodings) and not self.unconditional:\n assert exists(self.clip)",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:3117-3149"
+ },
+ "353": {
+ "file_id": 6,
+ "content": "The per-sample variational-bound loss (decoder NLL at t == 0, KL otherwise) is averaged, scaled by vb_loss_weight for stability, and added to the simple loss. The sample method then begins: it accepts an image embedding (or an image when starting partway up the cascade), optional text or text encodings, batch size, conditioning scale, start/stop unet numbers, and inpainting arguments, asserting that an image embedding is present unless the decoder was trained unconditionally.",
+ "type": "comment"
+ },
+ "354": {
+ "file_id": 6,
+ "content": " _, text_encodings = self.clip.embed_text(text)\n assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'\n assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'\n assert not (exists(inpaint_image) ^ exists(inpaint_mask)), 'inpaint_image and inpaint_mask (boolean mask of [batch, height, width]) must be both given for inpainting'\n img = None\n if start_at_unet_number > 1:\n # Then we are not generating the first image and one must have been passed in\n assert exists(image), 'image must be passed in if starting at unet number > 1'\n assert image.shape[0] == batch_size, 'image must have batch size of {} if starting at unet number > 1'.format(batch_size)\n prev_unet_output_size = self.image_sizes[start_at_unet_number - 2]\n img = resize_image_to(image, prev_unet_output_size, nearest = True)",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:3150-3163"
+ },
+ "355": {
+ "file_id": 6,
+ "content": "This code validates the sampling inputs: raw text is embedded with CLIP when text encodings are needed, text encodings must be supplied exactly when the decoder conditions on them, and inpaint_image and inpaint_mask must be given together. When sampling starts at a unet number greater than 1, an input image with the correct batch size must be provided and is resized to the previous unet's output resolution with nearest-neighbor interpolation.",
+ "type": "comment"
+ },
+ "356": {
+ "file_id": 6,
+ "content": " is_cuda = next(self.parameters()).is_cuda\n num_unets = self.num_unets\n cond_scale = cast_tuple(cond_scale, num_unets)\n for unet_number, unet, vae, channel, image_size, predict_x_start, predict_v, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.predict_v, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):\n if unet_number < start_at_unet_number:\n continue # It's the easiest way to do it\n context = self.one_unet_in_gpu(unet = unet) if is_cuda and one_unet_in_gpu_at_time else null_context()\n with context:\n # prepare low resolution conditioning for upsamplers\n lowres_cond_img = lowres_noise_level = None\n shape = (batch_size, channel, image_size, image_size)\n if unet.lowres_cond:",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:3165-3182"
+ },
+ "357": {
+ "file_id": 6,
+ "content": "This code iterates through the cascade of unets together with their VAEs, noise schedulers, image sizes, and per-unet settings, skipping any unet below start_at_unet_number. When running on CUDA, the one_unet_in_gpu context manager keeps only the active unet on the GPU. For upsampler unets, the low-resolution conditioning image is prepared in the following block.",
+ "type": "comment"
+ },
+ "358": {
+ "file_id": 6,
+ "content": " lowres_cond_img = resize_image_to(img, target_image_size = image_size, clamp_range = self.input_image_range, nearest = True)\n if lowres_cond.use_noise:\n lowres_noise_level = torch.full((batch_size,), int(self.lowres_noise_sample_level * 1000), dtype = torch.long, device = self.device)\n lowres_cond_img, _ = lowres_cond.noise_image(lowres_cond_img, lowres_noise_level)\n # latent diffusion\n is_latent_diffusion = isinstance(vae, VQGanVAE)\n image_size = vae.get_encoded_fmap_size(image_size)\n shape = (batch_size, vae.encoded_dim, image_size, image_size)\n lowres_cond_img = maybe(vae.encode)(lowres_cond_img)\n # denoising loop for image\n img = self.p_sample_loop(\n unet,\n shape,\n image_embed = image_embed,\n text_encodings = text_encodings,\n cond_scale = unet_cond_scale,",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:3183-3204"
+ },
+ "359": {
+ "file_id": 6,
+ "content": "For upsampler unets, the previous stage's output is resized to serve as the low-resolution conditioning image, and noise conditioning augmentation is applied at the configured lowres_noise_sample_level when enabled. If the stage uses a VQGanVAE (latent diffusion), the sampling shape switches to the encoded feature-map size and the conditioning image is encoded. The denoising p_sample_loop is then started to generate this stage's image.",
+ "type": "comment"
+ },
+ "360": {
+ "file_id": 6,
+ "content": " predict_x_start = predict_x_start,\n predict_v = predict_v,\n learned_variance = learned_variance,\n clip_denoised = not is_latent_diffusion,\n lowres_cond_img = lowres_cond_img,\n lowres_noise_level = lowres_noise_level,\n is_latent_diffusion = is_latent_diffusion,\n noise_scheduler = noise_scheduler,\n timesteps = sample_timesteps,\n inpaint_image = inpaint_image,\n inpaint_mask = inpaint_mask,\n inpaint_resample_times = inpaint_resample_times\n )\n img = vae.decode(img)\n if exists(stop_at_unet_number) and stop_at_unet_number == unet_number:\n break\n return img\n def forward(\n self,\n image,\n text = None,\n image_embed = None,\n text_encodings = None,\n unet_number = None,\n return_lowres_",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:3205-3233"
+ },
+ "361": {
+ "file_id": 6,
+ "content": "This code completes the sampling cascade: the remaining per-stage options (parameterization, learned variance, clipping, low-resolution conditioning, sample timesteps, and inpainting arguments) are passed to p_sample_loop, the result is decoded by the stage's VAE, and the loop stops early when stop_at_unet_number is reached before the final image is returned. The Decoder's forward method signature then begins; it takes an image, optional text or precomputed embeddings, and the unet number to train.",
+ "type": "comment"
+ },
+ "362": {
+ "file_id": 6,
+ "content": "cond_image = False # whether to return the low resolution conditioning images, for debugging upsampler purposes\n ):\n assert not (self.num_unets > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {self.num_unets}, if you are training cascading DDPM (multiple unets)'\n unet_number = default(unet_number, 1)\n unet_index = unet_number - 1\n unet = self.get_unet(unet_number)\n vae = self.vaes[unet_index]\n noise_scheduler = self.noise_schedulers[unet_index]\n lowres_conditioner = self.lowres_conds[unet_index]\n target_image_size = self.image_sizes[unet_index]\n predict_x_start = self.predict_x_start[unet_index]\n predict_v = self.predict_v[unet_index]\n random_crop_size = self.random_crop_sizes[unet_index]\n learned_variance = self.learned_variance[unet_index]\n b, c, h, w, device, = *image.shape, image.device\n assert image.shape[1] == self.channels",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:3233-3251"
+ },
+ "363": {
+ "file_id": 6,
+ "content": "This function is initializing variables for a specific U-Net in the model, based on the provided unet_number. It assigns the corresponding U-Net, VAE, noise scheduler, lowres conditioner, target image size, predict x start, predict v, random crop size, and learned variance from predefined lists for that U-Net index. It also ensures the image shape aligns with the expected number of channels.",
+ "type": "comment"
+ },
+ "364": {
+ "file_id": 6,
+ "content": " assert h >= target_image_size and w >= target_image_size\n times = torch.randint(0, noise_scheduler.num_timesteps, (b,), device = device, dtype = torch.long)\n if not exists(image_embed) and not self.unconditional:\n assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'\n image_embed, _ = self.clip.embed_image(image)\n if exists(text) and not exists(text_encodings) and not self.unconditional:\n assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder'\n _, text_encodings = self.clip.embed_text(text)\n assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'\n assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'\n ",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:3252-3267"
+ },
+ "365": {
+ "file_id": 6,
+ "content": "This code samples random training timesteps, then ensures the required conditioning exists: if no image embedding is supplied and the decoder is not unconditional, CLIP must be available to embed the image, and raw text is embedded into text encodings the same way. It then asserts that text encodings are provided exactly when the decoder is configured to condition on them.",
+ "type": "comment"
+ },
+ "366": {
+ "file_id": 6,
+ "content": "lowres_cond_img, lowres_noise_level = lowres_conditioner(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if exists(lowres_conditioner) else (None, None)\n image = resize_image_to(image, target_image_size, nearest = True)\n if exists(random_crop_size):\n aug = K.RandomCrop((random_crop_size, random_crop_size), p = 1.)\n # make sure low res conditioner and image both get augmented the same way\n # detailed https://kornia.readthedocs.io/en/latest/augmentation.module.html?highlight=randomcrop#kornia.augmentation.RandomCrop\n image = aug(image)\n lowres_cond_img = aug(lowres_cond_img, params = aug._params)\n is_latent_diffusion = not isinstance(vae, NullVQGanVAE)\n vae.eval()\n with torch.no_grad():\n image = vae.encode(image)\n lowres_cond_img = maybe(vae.encode)(lowres_cond_img)\n losses = self.p_losses(unet, image, times, image_embed = image",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:3267-3285"
+ },
+ "367": {
+ "file_id": 6,
+ "content": "This code builds the low-resolution conditioning image with the lowres_conditioner (for upsampler unets), resizes the input image to the target size, and applies Kornia's RandomCrop with identical parameters to the image and its low-resolution counterpart so they stay aligned. Both are then encoded with the stage's VAE before p_losses is called to compute the training loss.",
+ "type": "comment"
+ },
+ "368": {
+ "file_id": 6,
+ "content": "_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, predict_v = predict_v, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level)\n if not return_lowres_cond_image:\n return losses\n return losses, lowres_cond_img\n# main class\nclass DALLE2(nn.Module):\n def __init__(\n self,\n *,\n prior,\n decoder,\n prior_num_samples = 2\n ):\n super().__init__()\n assert isinstance(prior, DiffusionPrior)\n assert isinstance(decoder, Decoder)\n self.prior = prior\n self.decoder = decoder\n self.prior_num_samples = prior_num_samples\n self.decoder_need_text_cond = self.decoder.condition_on_text_encodings\n self.to_pil = T.ToPILImage()\n @torch.no_grad()\n @eval_decorator\n def forward(\n self,\n text,\n cond_scale = 1.,\n prior_cond_scale = 1.,",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:3285-3319"
+ },
+ "369": {
+ "file_id": 6,
+ "content": "The Decoder's forward pass finishes here: p_losses is called with the embeddings and conditioning, and either the loss alone or the loss together with the low-resolution conditioning image is returned, depending on return_lowres_cond_image. The main DALLE2 class is then defined: it wraps a DiffusionPrior and a Decoder, stores how many samples to draw per text from the prior, and records whether the decoder needs text conditioning. Its forward method signature begins at the end of this block.",
+ "type": "comment"
+ },
+ "370": {
+ "file_id": 6,
+ "content": " return_pil_images = False\n ):\n device = module_device(self)\n one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1)\n if isinstance(text, str) or is_list_str(text):\n text = [text] if not isinstance(text, (list, tuple)) else text\n text = tokenizer.tokenize(text).to(device)\n image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples, cond_scale = prior_cond_scale)\n text_cond = text if self.decoder_need_text_cond else None\n images = self.decoder.sample(image_embed = image_embed, text = text_cond, cond_scale = cond_scale)\n if return_pil_images:\n images = list(map(self.to_pil, images.unbind(dim = 0)))\n if one_text:\n return first(images)\n return images",
+ "type": "code",
+ "location": "/dalle2_pytorch/dalle2_pytorch.py:3320-3340"
+ },
+ "371": {
+ "file_id": 6,
+ "content": "This function takes text as input, tokenizes it if necessary, and uses a prior model to generate image embeddings. It then passes these embeddings along with the text (if required) to a decoder model to generate images. Optionally, it converts the images to PIL format and returns them. If only one text is given, it returns the first generated image.",
+ "type": "comment"
+ },
+ "372": {
+ "file_id": 7,
+ "content": "/dalle2_pytorch/dataloaders/README.md",
+ "type": "filepath"
+ },
+ "373": {
+ "file_id": 7,
+ "content": "This README documents the dataloaders: an image-embedding dataset and dataloader for decoder training built on webdataset, with instructions for generating the data using img2dataset, clip-retrieval, and embedding-dataset-reordering, and a PriorEmbeddingDataset for prior training whose get_reader and make_splits helpers create train/eval/test DataLoaders, including per-rank splits for distributed runs.",
+ "type": "summary"
+ },
+ "374": {
+ "file_id": 7,
+ "content": "## Dataloaders\nIn order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.\n### Decoder: Image Embedding Dataset\nWhen training the decoder (and up samplers if training together) in isolation, you will need to load images and corresponding image embeddings. This dataset can read two similar types of datasets. First, it can read a [webdataset](https://github.com/webdataset/webdataset) that contains `.jpg` and `.npy` files in the `.tar`s that contain the images and associated image embeddings respectively. Alternatively, you can also specify a source for the embeddings outside of the webdataset. In this case, the path to the embeddings should contain `.npy` files with the same shard numbers as the webdataset and there should be a correspondence between the filename of the `.jpg` and the index of the embedding in the `.npy`. So, for example, `0001.tar` from the webdataset with image `00010509.jpg` (the first 4 digit",
+ "type": "code",
+ "location": "/dalle2_pytorch/dataloaders/README.md:1-5"
+ },
+ "375": {
+ "file_id": 7,
+ "content": "This section describes the general dataloaders used to train parts of the network, focusing on the decoder. Two dataset layouts are supported: a webdataset whose .tar shards contain both the .jpg images and the corresponding .npy embeddings, or a separate embeddings source whose .npy files share shard numbers with the webdataset, with the image filename index giving the embedding's position within the array.",
+ "type": "comment"
+ },
+ "376": {
+ "file_id": 7,
+ "content": "s are the shard number and the last 4 are the index) in it should be paralleled by a `img_emb_0001.npy` which contains a NumPy array with the embedding at index 509.\nGenerating a dataset of this type:\n1. Use [img2dataset](https://github.com/rom1504/img2dataset) to generate a webdataset.\n2. Use [clip-retrieval](https://github.com/rom1504/clip-retrieval) to convert the images to embeddings.\n3. Use [embedding-dataset-reordering](https://github.com/Veldrovive/embedding-dataset-reordering) to reorder the embeddings into the expected format.\nUsage:\n```python\nfrom dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embedding_dataloader\n# Create a dataloader directly.\ndataloader = create_image_embedding_dataloader(\n tar_url=\"/path/or/url/to/webdataset/{0000..9999}.tar\", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar\n embeddings_url=\"path/or/url/to/embeddings/folder\", # Included if .npy files are not in webdataset. Left out or set to None otherwise",
+ "type": "code",
+ "location": "/dalle2_pytorch/dataloaders/README.md:5-19"
+ },
+ "377": {
+ "file_id": 7,
+ "content": "This section explains how to generate such a dataset (img2dataset to build the webdataset, clip-retrieval to compute embeddings, and embedding-dataset-reordering to put them into the expected shard layout) and shows the usage of create_image_embedding_dataloader, which takes the webdataset tar URL and, when the embeddings are stored separately, the embeddings folder URL, and returns a ready-to-use dataloader.",
+ "type": "comment"
+ },
+ "378": {
+ "file_id": 7,
+ "content": " num_workers=4,\n batch_size=32,\n shard_width=4, # If a file in the webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index\n shuffle_num=200, # Does a shuffle of the data with a buffer size of 200\n shuffle_shards=True, # Shuffle the order the shards are read in\n resample_shards=False, # Sample shards with replacement. If true, an epoch will be infinite unless stopped manually\n)\nfor img, emb in dataloader:\n print(img.shape) # torch.Size([32, 3, 256, 256])\n print(emb.shape) # torch.Size([32, 512])\n # Train decoder only as shown above\n# Or create a dataset without a loader so you can configure it manually\ndataset = ImageEmbeddingDataset(\n urls=\"/path/or/url/to/webdataset/{0000..9999}.tar\",\n embedding_folder_url=\"path/or/url/to/embeddings/folder\",\n shard_width=4,\n shuffle_shards=True,",
+ "type": "code",
+ "location": "/dalle2_pytorch/dataloaders/README.md:20-37"
+ },
+ "379": {
+ "file_id": 7,
+ "content": "This example configures the dataloader with the number of workers, batch size, shard width, and shuffle settings, then iterates over it, printing the image batch shape ([32, 3, 256, 256]) and the embedding batch shape ([32, 512]). An ImageEmbeddingDataset can also be created directly, without the convenience loader, for manual configuration.",
+ "type": "comment"
+ },
+ "380": {
+ "file_id": 7,
+ "content": " resample=False\n)\n```\n### Diffusion Prior: Prior Embedding Dataset\nWhen training the prior it is much more efficient to work with pre-computed embeddings. The `PriorEmbeddingDataset` class enables you to leverage the same script (with minimal modification) for both embedding-only and text-conditioned prior training. This saves you from having to worry about a lot of the boilerplate code.\nTo utilize the `PriorEmbeddingDataset`, all you need to do is make a single call to `get_reader()` which will create `EmbeddingReader` object(s) for you. Afterwards, you can utilize `make_splits()` to cleanly create DataLoader objects from for your training run.\nIf you are training in a distributed manner, `make_splits()` accepts `rank` and `world_size` arguments to properly distribute to each process. The defaults for these values are `rank=0` and `world_size=1`, so single-process training can safely ignore these parameters.\nUsage:\n```python\nfrom dalle2_pytorch.dataloaders import get_reader, make_splits\n# grab embeddings from some specified location",
+ "type": "code",
+ "location": "/dalle2_pytorch/dataloaders/README.md:38-53"
+ },
+ "381": {
+ "file_id": 7,
+ "content": "The resample=False argument closes the manual ImageEmbeddingDataset example; it disables sampling webdataset shards with replacement. The section then introduces the PriorEmbeddingDataset for prior training on pre-computed embeddings: a single call to get_reader() creates the EmbeddingReader object(s), and make_splits() turns them into DataLoaders, accepting rank and world_size arguments (defaults rank=0 and world_size=1) for distributed training.",
+ "type": "comment"
+ },
+ "382": {
+ "file_id": 7,
+ "content": "IMG_URL = \"data/img_emb/\"\nMETA_URL = \"data/meta/\"\nreader = get_reader(text_conditioned=True, img_url=IMG_URL, meta_url=META_URL)\n# some config for training\nTRAIN_ARGS = {\n \"world_size\": 3,\n \"text_conditioned\": True,\n \"start\": 0,\n \"num_data_points\": 10000,\n \"batch_size\": 2,\n \"train_split\": 0.5,\n \"eval_split\": 0.25,\n \"image_reader\": reader,\n}\n# specifying a rank will handle allocation internally\nrank0_train, rank0_eval, rank0_test = make_splits(rank=0, **TRAIN_ARGS)\nrank1_train, rank1_eval, rank1_test = make_splits(rank=1, **TRAIN_ARGS)\nrank2_train, rank2_eval, rank2_test = make_splits(rank=2, **TRAIN_ARGS)\n```",
+ "type": "code",
+ "location": "/dalle2_pytorch/dataloaders/README.md:54-75"
+ },
+ "383": {
+ "file_id": 7,
+ "content": "The example sets up training, evaluation, and testing splits for three ranks (0, 1, 2) using the TRAIN_ARGS config. get_reader creates a text-conditioned reader over the image-embedding and metadata URLs, and make_splits divides the data into train, eval, and test DataLoaders for each rank in a distributed run.",
+ "type": "comment"
+ },
+ "384": {
+ "file_id": 8,
+ "content": "/dalle2_pytorch/dataloaders/__init__.py",
+ "type": "filepath"
+ },
+ "385": {
+ "file_id": 8,
+ "content": "This code imports necessary classes for ImageEmbeddingDataset and PriorEmbeddingDataset from their respective modules in the DALLE2-pytorch library. These datasets are used to load data for the model's training and inference.",
+ "type": "summary"
+ },
+ "386": {
+ "file_id": 8,
+ "content": "from dalle2_pytorch.dataloaders.decoder_loader import ImageEmbeddingDataset, create_image_embedding_dataloader\nfrom dalle2_pytorch.dataloaders.prior_loader import make_splits, get_reader, PriorEmbeddingDataset",
+ "type": "code",
+ "location": "/dalle2_pytorch/dataloaders/__init__.py:1-2"
+ },
+ "387": {
+ "file_id": 8,
+ "content": "This code imports necessary classes for ImageEmbeddingDataset and PriorEmbeddingDataset from their respective modules in the DALLE2-pytorch library. These datasets are used to load data for the model's training and inference.",
+ "type": "comment"
+ },
+ "388": {
+ "file_id": 9,
+ "content": "/dalle2_pytorch/dataloaders/decoder_loader.py",
+ "type": "filepath"
+ },
+ "389": {
+ "file_id": 9,
+ "content": "The code defines functions for retrieving embeddings, combining image and text embeddings, creating image embedding datasets, and handling exceptions in webdataset tar files. It also includes support for preprocessing, resampling, shuffling, package checks, and dataloaders.",
+ "type": "summary"
+ },
+ "390": {
+ "file_id": 9,
+ "content": "import os\nimport webdataset as wds\nimport torch\nfrom torch.utils.data import DataLoader\nimport numpy as np\nimport fsspec\nimport shutil\ndef get_shard(filename):\n \"\"\"\n Filenames with shards in them have a consistent structure that we can take advantage of\n Standard structure: path/to/file/prefix_string_00001.ext\n \"\"\"\n try:\n return filename.split(\"_\")[-1].split(\".\")[0]\n except ValueError:\n raise RuntimeError(f\"Could not find shard for filename {filename}\")\ndef get_example_file(fs, path, file_format):\n \"\"\"\n Given a file system and a file extension, return the example file\n \"\"\"\n return fs.glob(os.path.join(path, f\"*.{file_format}\"))[0]\ndef embedding_inserter(samples, embeddings_url, index_width, sample_key='npy', handler=wds.handlers.reraise_exception):\n \"\"\"Given a datum of {\"__key__\": str, \"__url__\": str, ...} adds the cooresponding embedding and yields\"\"\"\n previous_tar_url = None\n current_embeddings = None\n # Get a reference to an abstract file system where the embeddings are stored",
+ "type": "code",
+ "location": "/dalle2_pytorch/dataloaders/decoder_loader.py:1-29"
+ },
+ "391": {
+ "file_id": 9,
+ "content": "This code defines three functions: get_shard extracts the shard number from a filename, get_example_file returns an example file for a given file system and extension, and embedding_inserter adds the corresponding embedding to each webdataset sample, given the samples, the embeddings URL, the index width, a sample key, and an exception handler.",
+ "type": "comment"
+ },
+ "392": {
+ "file_id": 9,
+ "content": " embeddings_fs, embeddings_path = fsspec.core.url_to_fs(embeddings_url)\n example_embedding_file = get_example_file(embeddings_fs, embeddings_path, \"npy\")\n example_embedding_shard = get_shard(example_embedding_file)\n emb_shard_width = len(example_embedding_shard)\n # Easier to get the basename without the shard once than search through for the correct file every time\n embedding_file_basename = '_'.join(example_embedding_file.split(\"_\")[:-1]) + \"_\"\n def load_corresponding_embeds(tar_url):\n \"\"\"Finds and reads the npy files that contains embeddings for the given webdataset tar\"\"\"\n shard = int(tar_url.split(\"/\")[-1].split(\".\")[0])\n embedding_url = embedding_file_basename + str(shard).zfill(emb_shard_width) + '.npy'\n with embeddings_fs.open(embedding_url) as f:\n data = np.load(f)\n return torch.from_numpy(data)\n for sample in samples:\n try:\n tar_url = sample[\"__url__\"]\n key = sample[\"__key__\"]\n if tar_url != previous_tar_url:",
+ "type": "code",
+ "location": "/dalle2_pytorch/dataloaders/decoder_loader.py:30-49"
+ },
+ "393": {
+ "file_id": 9,
+ "content": "This code segment retrieves and loads embeddings from a webdataset tar file using the given URL. It identifies the correct npy file containing the embeddings by extracting the shard number from the URL, then opens and loads the data into a torch tensor.",
+ "type": "comment"
+ },
+ "394": {
+ "file_id": 9,
+ "content": " # If the tar changed, we need to download new embeddings\n # This means if we shuffle before inserting it will load many more files than we expect and be very inefficient.\n previous_tar_url = tar_url\n current_embeddings = load_corresponding_embeds(tar_url)\n embedding_index = int(key[-index_width:])\n embedding = current_embeddings[embedding_index]\n # We need to check if this sample is nonzero. If it is, this embedding is not valid and we should continue to the next loop\n if torch.count_nonzero(embedding) == 0:\n raise RuntimeError(f\"Webdataset had a sample, but no embedding was found. ImgShard: {key[:-index_width]} - Index: {key[-index_width:]}\")\n sample[sample_key] = embedding\n yield sample\n except Exception as exn: # From wds implementation\n if handler(exn):\n continue\n else:\n break\ninsert_embedding = wds.filters.pipelinefilter(embedding_inserter)",
+ "type": "code",
+ "location": "/dalle2_pytorch/dataloaders/decoder_loader.py:50-67"
+ },
+ "395": {
+ "file_id": 9,
+ "content": "When a sample comes from a new tar shard, the corresponding embedding file is loaded; the embedding at the sample's index is attached to the sample, and an all-zero embedding (meaning no embedding exists for that image) raises a RuntimeError. insert_embedding wraps embedding_inserter as a webdataset pipeline filter.",
+ "type": "comment"
+ },
+ "396": {
+ "file_id": 9,
+ "content": "def unassociated_shard_skipper(tarfiles, embeddings_url, handler=wds.handlers.reraise_exception):\n \"\"\"Finds if the is a corresponding embedding for the tarfile at { url: [URL] }\"\"\"\n embeddings_fs, embeddings_path = fsspec.core.url_to_fs(embeddings_url)\n embedding_files = embeddings_fs.ls(embeddings_path)\n get_embedding_shard = lambda embedding_file: int(embedding_file.split(\"_\")[-1].split(\".\")[0])\n embedding_shards = set([get_embedding_shard(filename) for filename in embedding_files]) # Sets have O(1) check for member\n get_tar_shard = lambda tar_file: int(tar_file.split(\"/\")[-1].split(\".\")[0])\n for tarfile in tarfiles:\n try:\n webdataset_shard = get_tar_shard(tarfile[\"url\"])\n # If this shard has an associated embeddings file, we pass it through. Otherwise we iterate until we do have one\n if webdataset_shard in embedding_shards:\n yield tarfile\n except Exception as exn: # From wds implementation\n if handler(exn):",
+ "type": "code",
+ "location": "/dalle2_pytorch/dataloaders/decoder_loader.py:69-84"
+ },
+ "397": {
+ "file_id": 9,
+ "content": "This function filters tarfiles to those that have a corresponding embedding shard. It builds a set of available embedding shard numbers from embeddings_url, then yields each tarfile whose shard number is in that set and skips the rest, handling exceptions with the provided handler.",
+ "type": "comment"
+ },
+ "398": {
+ "file_id": 9,
+ "content": " continue\n else:\n break\nskip_unassociated_shards = wds.filters.pipelinefilter(unassociated_shard_skipper)\ndef join_embeddings(samples, handler=wds.handlers.reraise_exception):\n \"\"\"\n Takes the img_emb and text_emb keys and turns them into one key \"emb\": { \"text\": text_emb, \"img\": img_emb }\n either or both of text_emb and img_emb may not be in the sample so we only add the ones that exist\n \"\"\"\n for sample in samples:\n try:\n sample['emb'] = {}\n if 'text_emb' in sample:\n sample['emb']['text'] = sample['text_emb']\n if 'img_emb' in sample:\n sample['emb']['img'] = sample['img_emb']\n yield sample\n except Exception as exn: # From wds implementation\n if handler(exn):\n continue\n else:\n break\ndef verify_keys(samples, required_keys, handler=wds.handlers.reraise_exception):\n \"\"\"\n Requires that both the image and embedding are present in the sample",
+ "type": "code",
+ "location": "/dalle2_pytorch/dataloaders/decoder_loader.py:85-111"
+ },
+ "399": {
+ "file_id": 9,
+ "content": "The code defines two functions: join_embeddings() combines the img_emb and text_emb keys (whichever exist) into a single emb dictionary with optional text and img entries, and verify_keys() asserts that every required key is present in a sample, which is used to ensure both the image and its embedding exist. On failure, each either continues or stops depending on the exception handler.",
+ "type": "comment"
+ }
+}
\ No newline at end of file
diff --git a/docs/data/4.json b/docs/data/4.json
new file mode 100644
index 00000000..6c171e52
--- /dev/null
+++ b/docs/data/4.json
@@ -0,0 +1,547 @@
+{
+ "400": {
+ "file_id": 9,
+ "content": " This is important to do as a user may forget they do not have embeddings in their webdataset and neglect to add them using the embedding_folder_url parameter.\n \"\"\"\n for sample in samples:\n try:\n for key in required_keys:\n assert key in sample, f\"Sample {sample['__key__']} missing {key}. Has keys {sample.keys()}\"\n yield sample\n except Exception as exn: # From wds implementation\n if handler(exn):\n continue\n else:\n break\nkey_verifier = wds.filters.pipelinefilter(verify_keys)\nclass ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):\n \"\"\"\n A fluid interface wrapper for DataPipline that returns image embedding pairs\n Reads embeddings as npy files from the webdataset if they exist. If embedding_folder_url is set, they will be inserted in from the alternate source.\n \"\"\"\n def __init__(\n self,\n urls,\n img_embedding_folder_url=None,\n text_embedding_folder_url=None,",
+ "type": "code",
+ "location": "/dalle2_pytorch/dataloaders/decoder_loader.py:112-136"
+ },
+ "401": {
+ "file_id": 9,
+ "content": "This code asserts that every required key is present in each sample before yielding it, and wraps verify_keys as the key_verifier pipeline filter. The ImageEmbeddingDataset class is then declared: a fluid-interface wrapper around wds.DataPipeline that returns image/embedding pairs, reading embeddings from the webdataset itself or inserting them from an alternate source given by the embedding folder URLs.",
+ "type": "comment"
+ },
+ "402": {
+ "file_id": 9,
+ "content": " index_width=None,\n img_preproc=None,\n extra_keys=[],\n handler=wds.handlers.reraise_exception,\n resample=False,\n shuffle_shards=True\n ):\n \"\"\"\n Modeled directly off of the WebDataset constructor\n :param urls: A url pointing to the tar files of the webdataset formatted as /path/to/webdataset/{0000..9999}.tar\n :param embedding_folder_url: Required if webdataset does not contain embeddings. A url pointing to the npy files of the embeddings. Should have the same number of shards as the webdataset.\n Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros.\n :param index_width: The number of digits in the index. This is used to align the embedding index with the image index.\n For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard is 4 digits and the last 3 digits are the index_width.",
+ "type": "code",
+ "location": "/dalle2_pytorch/dataloaders/decoder_loader.py:137-151"
+ },
+ "403": {
+ "file_id": 9,
+ "content": "These are the constructor parameters of ImageEmbeddingDataset, modeled on the WebDataset constructor: urls points to the webdataset tar shards (using brace-expansion notation), an embedding folder URL is required when the embeddings are not stored inside the webdataset, and index_width gives the number of digits in the per-shard image index so image and embedding indices can be aligned (missing images must have an all-zero embedding).",
+ "type": "comment"
+ },
+ "404": {
+ "file_id": 9,
+ "content": " :param img_preproc: This function is run on the img before it is batched and returned. Useful for data augmentation or converting to torch tensor.\n :param handler: A webdataset handler.\n :param resample: If true, resample webdataset shards with replacement. You need to set your own epoch size if this is true since it will resample infinitely.\n :param shuffle_shards: If true, shuffle the shards before resampling. This cannot be true if resample is true.\n \"\"\"\n super().__init__()\n keys = [\"jpg\", \"emb\"] + extra_keys\n # if img_embedding_folder_url is not None:\n # keys.append(\"img_emb\")\n # if text_embedding_folder_url is not None:\n # keys.append(\"text_emb\")\n # keys.extend(extra_keys)\n self.key_map = {key: i for i, key in enumerate(keys)}\n self.resampling = resample\n self.img_preproc = img_preproc\n # If s3, check if s3fs is installed and s3cmd is installed and check if the data is piped instead of straight up",
+ "type": "code",
+ "location": "/dalle2_pytorch/dataloaders/decoder_loader.py:152-169"
+ },
+ "405": {
+ "file_id": 9,
+ "content": "This continues the constructor: the docstring covers img_preproc (applied to each image before batching), the webdataset handler, resample (sample shards with replacement, requiring a manually set epoch size since it resamples indefinitely), and shuffle_shards (incompatible with resample). The sample keys (jpg and emb plus any extra keys) are mapped to tuple indices, with the per-modality img_emb/text_emb key handling left commented out. The resampling flag and preprocessing function are stored, and the next block checks the required tooling when the data lives on S3.",
+ "type": "comment"
+ },
+ "406": {
+ "file_id": 9,
+ "content": " if (isinstance(urls, str) and \"s3:\" in urls) or (isinstance(urls, list) and any([\"s3:\" in url for url in urls])):\n # Then this has an s3 link for the webdataset and we need extra packages\n if shutil.which(\"s3cmd\") is None:\n raise RuntimeError(\"s3cmd is required for s3 webdataset\")\n if (img_embedding_folder_url is not None and \"s3:\" in img_embedding_folder_url) or (text_embedding_folder_url is not None and \"s3:\" in text_embedding_folder_url):\n # Then the embeddings are being loaded from s3 and fsspec requires s3fs\n try:\n import s3fs\n except ImportError:\n raise RuntimeError(\"s3fs is required to load embeddings from s3\")\n # Add the shardList and randomize or resample if requested\n if resample:\n assert not shuffle_shards, \"Cannot both resample and shuffle\"\n self.append(wds.ResampledShards(urls))\n else:\n self.append(wds.SimpleShardList(urls))",
+ "type": "code",
+ "location": "/dalle2_pytorch/dataloaders/decoder_loader.py:170-185"
+ },
+ "407": {
+ "file_id": 9,
+ "content": "This code checks whether any provided URL contains an s3: scheme: s3cmd must be installed to read a webdataset from S3, and s3fs is required when the embeddings themselves are loaded from S3, otherwise a RuntimeError is raised. It then appends either a ResampledShards or a SimpleShardList stage, asserting that resampling and shard shuffling are not requested together.",
+ "type": "comment"
+ },
+ "408": {
+ "file_id": 9,
+ "content": " if shuffle_shards:\n self.append(wds.filters.shuffle(1000))\n if img_embedding_folder_url is not None:\n # There may be webdataset shards that do not have a embedding shard associated with it. If we do not skip these, they would cause issues.\n self.append(skip_unassociated_shards(embeddings_url=img_embedding_folder_url, handler=handler))\n if text_embedding_folder_url is not None:\n self.append(skip_unassociated_shards(embeddings_url=text_embedding_folder_url, handler=handler))\n self.append(wds.tarfile_to_samples(handler=handler))\n self.append(wds.decode(\"pilrgb\", handler=handler))\n if img_embedding_folder_url is not None:\n # Then we are loading image embeddings for a remote source\n assert index_width is not None, \"Reading embeddings separately requires index width length to be given\"\n self.append(insert_embedding(embeddings_url=img_embedding_folder_url, index_width=index_width, sample_key='img_emb', handler=handler))",
+ "type": "code",
+ "location": "/dalle2_pytorch/dataloaders/decoder_loader.py:186-200"
+ },
+ "409": {
+ "file_id": 9,
+ "content": "The code configures a decoder loader for DALLE2-pytorch. It shuffles 1000 filters and skips unassociated shards if necessary, loads embeddings from URLs, converts to samples, and decodes images as PILRGB.",
+ "type": "comment"
+ },
+ "410": {
+ "file_id": 9,
+ "content": " if text_embedding_folder_url is not None:\n # Then we are loading image embeddings for a remote source\n assert index_width is not None, \"Reading embeddings separately requires index width length to be given\"\n self.append(insert_embedding(embeddings_url=text_embedding_folder_url, index_width=index_width, sample_key='text_emb', handler=handler))\n self.append(join_embeddings)\n self.append(key_verifier(required_keys=keys, handler=handler))\n # Apply preprocessing\n self.append(wds.map(self.preproc))\n self.append(wds.to_tuple(*keys))\n def preproc(self, sample):\n \"\"\"Applies the preprocessing for images\"\"\"\n if self.img_preproc is not None:\n sample[\"jpg\"] = self.img_preproc(sample[\"jpg\"])\n return sample\ndef create_image_embedding_dataloader(\n tar_url,\n num_workers,\n batch_size,\n img_embeddings_url=None,\n text_embeddings_url=None,\n index_width=None,\n shuffle_num = None,\n shuffle_shards = True,",
+ "type": "code",
+ "location": "/dalle2_pytorch/dataloaders/decoder_loader.py:201-225"
+ },
+ "411": {
+ "file_id": 9,
+ "content": "This code creates an image embedding dataloader. If a text embedding folder URL is provided, it loads image embeddings for remote sources based on the given index width. It then applies preprocessing and joins the embeddings before returning the tuple of keys. The preproc function applies image preprocessing if available.",
+ "type": "comment"
+ },
+ "412": {
+ "file_id": 9,
+ "content": " resample_shards = False, \n img_preproc=None,\n extra_keys=[],\n handler=wds.handlers.reraise_exception#warn_and_continue\n):\n \"\"\"\n Convenience function to create an image embedding dataseta and dataloader in one line\n :param tar_url: A url pointing to the tar files of the webdataset formatted as /path/to/webdataset/{0000..9999}.tar\n :param num_workers: The number of workers to use for the dataloader\n :param batch_size: The batch size to use for the dataloader\n :param embeddings_url: Required if webdataset does not contain embeddings. A url pointing to the npy files of the embeddings. Should have the same number of shards as the webdataset.\n Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros.\n :param index_width: The number of digits in the index. This is used to align the embedding index with the image index.\n For example, if a file in the webdataset sh",
+ "type": "code",
+ "location": "/dalle2_pytorch/dataloaders/decoder_loader.py:226-240"
+ },
+ "413": {
+ "file_id": 9,
+ "content": "This code creates an image embedding dataset and dataloader in one line, accepting parameters such as tar_url, num_workers, batch_size, embeddings_url, and index_width. The function is designed for webdataset format and requires the same number of shards for both the webdataset images and their corresponding embeddings. It also supports handling exceptions using a specified handler.",
+ "type": "comment"
+ },
+ "414": {
+ "file_id": 9,
+ "content": "ard 3 is named 0003039.jpg, we know the shard is 4 digits and the last 3 digits are the index_width.\n :param shuffle_num: If not None, shuffle the dataset with this size buffer after sampling.\n :param shuffle_shards: If true, shuffle the shards before sampling. This cannot be true if resample is true.\n :param resample_shards: If true, resample webdataset shards with replacement. You need to set your own epoch size if this is true since it will resample infinitely.\n :param handler: A webdataset handler.\n \"\"\"\n ds = ImageEmbeddingDataset(\n tar_url,\n img_embedding_folder_url=img_embeddings_url,\n text_embedding_folder_url=text_embeddings_url,\n index_width=index_width,\n shuffle_shards=shuffle_shards,\n resample=resample_shards,\n extra_keys=extra_keys,\n img_preproc=img_preproc,\n handler=handler\n )\n if shuffle_num is not None and shuffle_num > 0:\n ds.shuffle(1000)\n return DataLoader(\n ds,\n num_workers=num_workers,",
+ "type": "code",
+ "location": "/dalle2_pytorch/dataloaders/decoder_loader.py:240-261"
+ },
+ "415": {
+ "file_id": 9,
+ "content": "This code defines a function that takes in parameters like tar_url, img_embedding_folder_url, text_embeddings_url, index_width, extra_keys, img_preproc, and handler. It creates an ImageEmbeddingDataset and optionally shuffles it based on the given shuffle_num. Then, it returns a DataLoader for further processing.",
+ "type": "comment"
+ },
+ "416": {
+ "file_id": 9,
+ "content": " batch_size=batch_size,\n prefetch_factor=2, # This might be good to have high so the next npy file is prefetched\n pin_memory=True,\n shuffle=False\n )",
+ "type": "code",
+ "location": "/dalle2_pytorch/dataloaders/decoder_loader.py:262-266"
+ },
+ "417": {
+ "file_id": 9,
+ "content": "This code creates a data loader for the decoder model. It sets batch size, prefetch factor (for efficient loading), pin memory (for faster GPU transfers), and disables shuffling.",
+ "type": "comment"
+ },
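+ "417a": {
+ "file_id": 9,
+ "content": "Illustrative usage sketch of create_image_embedding_dataloader (not from the repository; the paths, worker count, batch size, and index width below are assumed):\n\nimport torchvision.transforms as T\nfrom dalle2_pytorch.dataloaders.decoder_loader import create_image_embedding_dataloader\n\ndataloader = create_image_embedding_dataloader(\n    tar_url='/data/webdataset/{0000..0099}.tar',  # assumed shard pattern\n    num_workers=4,\n    batch_size=32,\n    img_embeddings_url='/data/img_emb',  # assumed embedding shards aligned with the webdataset\n    index_width=4,  # assumed number of digits in the per-shard sample index\n    shuffle_num=1000,\n    img_preproc=T.ToTensor(),  # convert PIL images to tensors so they can be batched\n)\nfor img, emb in dataloader:\n    # each sample is a tuple in key order ('jpg', 'emb', then any extra keys)\n    break",
+ "type": "comment"
+ },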
+ "418": {
+ "file_id": 10,
+ "content": "/dalle2_pytorch/dataloaders/prior_loader.py",
+ "type": "filepath"
+ },
+ "419": {
+ "file_id": 10,
+ "content": "This code offers efficient data retrieval classes for DALL-E 2, supports text conditioning and MPI distribution. It divides embedding reader objects into training, evaluation, and test sets using PyTorch Dataloaders, without specifying batch sizes.",
+ "type": "summary"
+ },
+ "420": {
+ "file_id": 10,
+ "content": "from math import ceil\nfrom clip import tokenize\nfrom embedding_reader import EmbeddingReader\nfrom torch import from_numpy\nfrom torch.utils.data import IterableDataset, DataLoader\nclass PriorEmbeddingDataset(IterableDataset):\n \"\"\"\n PriorEmbeddingDataset is a wrapper of EmbeddingReader.\n It enables one to simplify the logic necessary to yield samples from\n the different EmbeddingReader configurations available.\n \"\"\"\n def __init__(\n self,\n text_conditioned: bool,\n batch_size: int,\n start: int,\n stop: int,\n image_reader,\n text_reader: EmbeddingReader = None,\n ) -> None:\n super(PriorEmbeddingDataset).__init__()\n self.text_conditioned = text_conditioned\n if not self.text_conditioned:\n self.text_reader = text_reader\n self.image_reader = image_reader\n self.start = start\n self.stop = stop\n self.batch_size = batch_size\n def __len__(self):\n return self.stop - self.start\n def __iter__(self):",
+ "type": "code",
+ "location": "/dalle2_pytorch/dataloaders/prior_loader.py:1-40"
+ },
+ "421": {
+ "file_id": 10,
+ "content": "The code defines a class called PriorEmbeddingDataset that wraps the EmbeddingReader class. It allows for simplified sample retrieval from various configurations of EmbeddingReader by enabling batch-based access to prior data, where text_conditioned and batch_size are parameters, along with start and stop indices for the range of data to be loaded.",
+ "type": "comment"
+ },
+ "422": {
+ "file_id": 10,
+ "content": " # D.R.Y loader args\n loader_args = dict(\n batch_size=self.batch_size,\n start=self.start,\n end=self.stop,\n show_progress=False,\n )\n # if the data requested is text conditioned, only load images\n if self.text_conditioned:\n self.loader = self.image_reader(**loader_args)\n # otherwise, include text embeddings and bypass metadata\n else:\n self.loader = zip(\n self.image_reader(**loader_args), self.text_reader(**loader_args)\n )\n # return the data loader in its formatted state\n return self\n def __next__(self):\n try:\n return self.get_sample()\n except StopIteration:\n raise StopIteration\n def __str__(self):\n return f\"\"\n def set_start(self, start):\n \"\"\"\n Adjust the starting point within the reader, useful for resuming an epoch",
+ "type": "code",
+ "location": "/dalle2_pytorch/dataloaders/prior_loader.py:41-72"
+ },
+ "423": {
+ "file_id": 10,
+ "content": "The code defines a PriorEmbeddingDataset class for data loading in DALLE2-pytorch. It uses an image_reader and text_reader to load data in a batch, with optional text conditioning. It includes a __next__ method for iterating through the dataset and a set_start method for adjusting the starting point within the reader.",
+ "type": "comment"
+ },
+ "424": {
+ "file_id": 10,
+ "content": " \"\"\"\n self.start = start\n def get_start(self):\n return self.start\n def get_sample(self):\n \"\"\"\n pre-proocess data from either reader into a common format\n \"\"\"\n if self.text_conditioned:\n image_embedding, caption = next(self.loader)\n image_embedding = from_numpy(image_embedding)\n tokenized_caption = tokenize(caption[\"caption\"].to_list(), truncate=True)\n return image_embedding, tokenized_caption\n else:\n (image_embedding, _), (text_embedding, _) = next(self.loader)\n image_embedding = from_numpy(image_embedding)\n text_embedding = from_numpy(text_embedding)\n return image_embedding, text_embedding\n# helper functions\ndef distribute_to_rank(start, stop, rank, world_size):\n \"\"\"\n Distribute data to each rank given the world size.\n Return:\n - New start and stop points for this rank.\n \"\"\"\n num_samples = int(stop - start)\n per_rank = int(ceil((num_samples) / float(world_size)))",
+ "type": "code",
+ "location": "/dalle2_pytorch/dataloaders/prior_loader.py:73-112"
+ },
+ "425": {
+ "file_id": 10,
+ "content": "This code defines a class with methods to manage data loading and distribution for the DALL-E 2 model. It supports text-conditioned or unconditioned data, preprocesses input into a common format, and distributes data across multiple ranks using MPI.",
+ "type": "comment"
+ },
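+ "425a": {
+ "file_id": 10,
+ "content": "Worked example of the rank arithmetic (values chosen for illustration): distribute_to_rank(start=0, stop=100, rank=1, world_size=4) gives per_rank = ceil(100 / 4) = 25, rank_start = 0 + 1 * 25 = 25, and rank_stop = min(25 + 25, 100) = 50, so rank 1 reads the half-open range [25, 50).",
+ "type": "comment"
+ },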
+ "426": {
+ "file_id": 10,
+ "content": " assert (\n per_rank > 0\n ), f\"Number of samples per rank must be larger than 0, (found: {per_rank})\"\n rank_start = start + rank * per_rank\n rank_stop = min(rank_start + per_rank, stop)\n new_length = rank_stop - rank_start\n assert (\n new_length > 0\n ), \"Calculated start and stop points result in a length of zero for this rank.\"\n return rank_start, rank_stop\ndef get_reader(\n text_conditioned: bool, img_url: str, meta_url: str = None, txt_url: str = None\n):\n \"\"\"\n Create an EmbeddingReader object from the specified URLs\n get_reader() will always expect a url to image embeddings.\n If text-conditioned, it will also expect a meta_url for the captions.\n Otherwise, it will need txt_url for the matching text embeddings.\n Returns an image_reader object if text-conditioned.\n Otherwise it returns both an image_reader and a text_reader\n \"\"\"\n assert img_url is not None, \"Must supply a image url\"\n if text_conditioned:\n assert meta_url is not None, \"Must supply meta url if text-conditioned\"",
+ "type": "code",
+ "location": "/dalle2_pytorch/dataloaders/prior_loader.py:114-149"
+ },
+ "427": {
+ "file_id": 10,
+ "content": "The code is defining functions that calculate the start and stop points for a given rank, and another function to create an EmbeddingReader object based on URLs. It asserts that certain inputs are not None before proceeding, ensuring necessary information is provided.",
+ "type": "comment"
+ },
+ "428": {
+ "file_id": 10,
+ "content": " image_reader = EmbeddingReader(\n embeddings_folder=img_url,\n file_format=\"parquet_npy\",\n # will assume the caption column exists and is the only one requested\n meta_columns=[\"caption\"],\n metadata_folder=meta_url,\n )\n return image_reader\n # otherwise we will require text embeddings as well and return two readers\n assert (\n txt_url is not None\n ), \"Must supply text embedding url if not text-conditioning\"\n image_reader = EmbeddingReader(img_url, file_format=\"npy\")\n text_reader = EmbeddingReader(txt_url, file_format=\"npy\")\n return image_reader, text_reader\ndef make_splits(\n text_conditioned: bool,\n batch_size: int,\n num_data_points: int,\n train_split: float,\n eval_split: float,\n image_reader: EmbeddingReader,\n text_reader: EmbeddingReader = None,\n start=0,\n rank=0,\n world_size=1,\n):\n \"\"\"\n Split an embedding reader object as needed.\n NOTE: make_splits() will infer the test set size from your train and eval.",
+ "type": "code",
+ "location": "/dalle2_pytorch/dataloaders/prior_loader.py:151-187"
+ },
+ "429": {
+ "file_id": 10,
+ "content": "This code defines a function to split an embedding reader object into training, evaluation, and optional test sets. It takes in the text conditioned flag, batch size, number of data points, and train/eval splits as input parameters. If text-conditioning is not enabled, it requires text embedding URLs as well and returns two readers.",
+ "type": "comment"
+ },
+ "430": {
+ "file_id": 10,
+ "content": " Input:\n - text_conditioned: whether to prepare text-conditioned training data\n - batch_size: the batch size for a single gpu\n - num_data_points: the total number of data points you wish to train on\n - train_split: the percentage of data you wish to train on\n - eval_split: the percentage of data you wish to validate on\n - image_reader: the image_reader you wish to split\n - text_reader: the text_reader you want to split (if !text_conditioned)\n - start: the starting point within your dataset\n - rank: the rank of your worker\n - world_size: the total world size of your distributed training run\n Returns:\n - PyTorch Dataloaders that yield tuples of (img, txt) data.\n \"\"\"\n assert start < image_reader.count, \"start position cannot exceed reader count.\"\n # verify that the num_data_points does not exceed the max points\n if num_data_points > (image_reader.count - start):\n print(\n \"Specified count is larger than what's available...defaulting to reader's count.\"",
+ "type": "code",
+ "location": "/dalle2_pytorch/dataloaders/prior_loader.py:189-210"
+ },
+ "431": {
+ "file_id": 10,
+ "content": "This function takes various inputs like batch size, train and eval splits, readers, and starting point to create PyTorch Dataloaders for image-text pairs. It ensures the start position is within the reader's count, and if the specified data points count exceeds the available ones, it defaults to the remaining count.",
+ "type": "comment"
+ },
+ "432": {
+ "file_id": 10,
+ "content": " )\n num_data_points = image_reader.count\n # compute split points\n train_set_size = int(train_split * num_data_points)\n eval_set_size = int(eval_split * num_data_points)\n eval_start = train_set_size\n eval_stop = int(eval_start + eval_set_size)\n assert (\n train_split + eval_split\n ) < 1.0, \"Specified train and eval split is too large to infer a test split.\"\n # distribute to rank\n rank_train_start, rank_train_stop = distribute_to_rank(\n start, train_set_size, rank, world_size\n )\n rank_eval_start, rank_eval_stop = distribute_to_rank(\n train_set_size, eval_stop, rank, world_size\n )\n rank_test_start, rank_test_stop = distribute_to_rank(\n eval_stop, num_data_points, rank, world_size\n )\n # wrap up splits into a dict\n train_split_args = dict(\n start=rank_train_start, stop=rank_train_stop, batch_size=batch_size\n )\n eval_split_args = dict(\n start=rank_eval_start, stop=rank_eval_stop, batch_size=batch_size\n )\n test_split_args = dict(",
+ "type": "code",
+ "location": "/dalle2_pytorch/dataloaders/prior_loader.py:211-242"
+ },
+ "433": {
+ "file_id": 10,
+ "content": "Computing split points for training and evaluation data sets based on the specified splits. Distributing the data to ranks according to the world size. Wrapping up the splits into a dictionary with start, stop, and batch_size parameters.",
+ "type": "comment"
+ },
+ "434": {
+ "file_id": 10,
+ "content": " start=rank_test_start, stop=rank_test_stop, batch_size=batch_size\n )\n if text_conditioned:\n # add the text-conditioned args to a unified dict\n reader_args = dict(\n text_conditioned=text_conditioned,\n image_reader=image_reader,\n )\n train_split_args = dict(**reader_args, **train_split_args)\n eval_split_args = dict(**reader_args, **eval_split_args)\n test_split_args = dict(**reader_args, **test_split_args)\n train = PriorEmbeddingDataset(**train_split_args)\n val = PriorEmbeddingDataset(**eval_split_args)\n test = PriorEmbeddingDataset(**test_split_args)\n else:\n # add the non-conditioned args to a unified dict\n reader_args = dict(\n text_conditioned=text_conditioned,\n image_reader=image_reader,\n text_reader=text_reader,\n )\n train_split_args = dict(**reader_args, **train_split_args)\n eval_split_args = dict(**reader_args, **eval_split_args)\n test_split_args = dict(**reader_args, **test_split_args)",
+ "type": "code",
+ "location": "/dalle2_pytorch/dataloaders/prior_loader.py:243-271"
+ },
+ "435": {
+ "file_id": 10,
+ "content": "Code is creating a PriorEmbeddingDataset for train, validation, and test datasets based on given arguments. If text_conditioned, it creates separate dictionaries for each dataset and passes them to the PriorEmbeddingDataset class; otherwise, it adds additional non-conditioned arguments for the same process.",
+ "type": "comment"
+ },
+ "436": {
+ "file_id": 10,
+ "content": " train = PriorEmbeddingDataset(**train_split_args)\n val = PriorEmbeddingDataset(**eval_split_args)\n test = PriorEmbeddingDataset(**test_split_args)\n # true batch size is specifed in the PriorEmbeddingDataset\n train_loader = DataLoader(train, batch_size=None)\n eval_loader = DataLoader(val, batch_size=None)\n test_loader = DataLoader(test, batch_size=None)\n return train_loader, eval_loader, test_loader",
+ "type": "code",
+ "location": "/dalle2_pytorch/dataloaders/prior_loader.py:273-282"
+ },
+ "437": {
+ "file_id": 10,
+ "content": "This code creates train, val, and test datasets using PriorEmbeddingDataset with specific args. DataLoaders are created without specifying batch sizes, so the true batch size is determined in PriorEmbeddingDataset. The loaders and datasets are returned for further processing.",
+ "type": "comment"
+ },
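+ "437a": {
+ "file_id": 10,
+ "content": "Minimal sketch of building text-conditioned prior dataloaders (not from the repository; URLs, counts, and splits are assumed):\n\nfrom dalle2_pytorch.dataloaders.prior_loader import get_reader, make_splits\n\nimage_reader = get_reader(\n    text_conditioned=True,\n    img_url='s3://bucket/img_emb',    # assumed image embedding folder\n    meta_url='s3://bucket/metadata',  # assumed metadata folder holding the 'caption' column\n)\ntrain_loader, eval_loader, test_loader = make_splits(\n    text_conditioned=True,\n    batch_size=256,\n    num_data_points=1000000,\n    train_split=0.8,\n    eval_split=0.1,  # the test split is inferred as the remaining 0.1\n    image_reader=image_reader,\n    rank=0,\n    world_size=1,\n)\nfor image_embed, tokenized_caption in train_loader:\n    break",
+ "type": "comment"
+ },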
+ "438": {
+ "file_id": 11,
+ "content": "/dalle2_pytorch/dataloaders/simple_image_only_dataloader.py",
+ "type": "filepath"
+ },
+ "439": {
+ "file_id": 11,
+ "content": "This code defines a Dataset class and get_images_dataloader function for loading image data. The Dataset class initializes with a folder path, image size, and extensions to consider. The get_images_dataloader function returns a DataLoader object for the specified folder with optional parameters like batch size, shuffle, cycle_dl, and pin_memory.",
+ "type": "summary"
+ },
+ "440": {
+ "file_id": 11,
+ "content": "from pathlib import Path\nimport torch\nfrom torch.utils import data\nfrom torchvision import transforms, utils\nfrom PIL import Image\n# helpers functions\ndef cycle(dl):\n while True:\n for data in dl:\n yield data\n# dataset and dataloader\nclass Dataset(data.Dataset):\n def __init__(\n self,\n folder,\n image_size,\n exts = ['jpg', 'jpeg', 'png']\n ):\n super().__init__()\n self.folder = folder\n self.image_size = image_size\n self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]\n self.transform = transforms.Compose([\n transforms.Resize(image_size),\n transforms.RandomHorizontalFlip(),\n transforms.CenterCrop(image_size),\n transforms.ToTensor()\n ])\n def __len__(self):\n return len(self.paths)\n def __getitem__(self, index):\n path = self.paths[index]\n img = Image.open(path)\n return self.transform(img)\ndef get_images_dataloader(\n folder,\n *,",
+ "type": "code",
+ "location": "/dalle2_pytorch/dataloaders/simple_image_only_dataloader.py:1-47"
+ },
+ "441": {
+ "file_id": 11,
+ "content": "This code defines a Dataset class and get_images_dataloader function for loading image data. The Dataset class initializes with a folder path, image size, and extensions to consider. It uses transforms to apply resizing, horizontal flipping, centercropping, and converting images to tensors. The get_images_dataloader function returns a data loader object for the specified folder.",
+ "type": "comment"
+ },
+ "442": {
+ "file_id": 11,
+ "content": " batch_size,\n image_size,\n shuffle = True,\n cycle_dl = True,\n pin_memory = True\n):\n ds = Dataset(folder, image_size)\n dl = data.DataLoader(ds, batch_size = batch_size, shuffle = shuffle, pin_memory = pin_memory)\n if cycle_dl:\n dl = cycle(dl)\n return dl",
+ "type": "code",
+ "location": "/dalle2_pytorch/dataloaders/simple_image_only_dataloader.py:48-59"
+ },
+ "443": {
+ "file_id": 11,
+ "content": "This function takes parameters such as folder, batch size, image size, shuffle, cycle_dl, and pin_memory. It creates a dataset from the provided folder using a given image size. Then, it uses DataLoader to create a data loader with the specified batch size, shuffle, and pin memory settings. If cycle_dl is True, it applies cyclic permutations to the data loader. Finally, it returns the data loader.",
+ "type": "comment"
+ },
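+ "443a": {
+ "file_id": 11,
+ "content": "Illustrative call (folder path and sizes are assumed):\n\nfrom dalle2_pytorch.dataloaders.simple_image_only_dataloader import get_images_dataloader\n\ndl = get_images_dataloader('/path/to/images', batch_size=16, image_size=256)\nimages = next(dl)  # with cycle_dl=True (the default) the loader yields batches indefinitely",
+ "type": "comment"
+ },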
+ "444": {
+ "file_id": 12,
+ "content": "/dalle2_pytorch/optimizer.py",
+ "type": "filepath"
+ },
+ "445": {
+ "file_id": 12,
+ "content": "This code defines two functions, `separate_weight_decayable_params` and `get_optimizer`. The `get_optimizer` function takes parameters, learning rate, weight decay, and other options to create an optimizer object. It filters the parameters based on `requires_grad`, separates weight-decayable parameters, and uses either Adam or AdamW optimizer depending on the weight decay value.",
+ "type": "summary"
+ },
+ "446": {
+ "file_id": 12,
+ "content": "from torch.optim import AdamW, Adam\ndef separate_weight_decayable_params(params):\n wd_params, no_wd_params = [], []\n for param in params:\n param_list = no_wd_params if param.ndim < 2 else wd_params\n param_list.append(param)\n return wd_params, no_wd_params\ndef get_optimizer(\n params,\n lr = 1e-4,\n wd = 1e-2,\n betas = (0.9, 0.99),\n eps = 1e-8,\n filter_by_requires_grad = False,\n group_wd_params = True,\n **kwargs\n):\n if filter_by_requires_grad:\n params = list(filter(lambda t: t.requires_grad, params))\n if wd == 0:\n return Adam(params, lr = lr, betas = betas, eps = eps)\n if group_wd_params:\n wd_params, no_wd_params = separate_weight_decayable_params(params)\n params = [\n {'params': wd_params},\n {'params': no_wd_params, 'weight_decay': 0},\n ]\n return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)",
+ "type": "code",
+ "location": "/dalle2_pytorch/optimizer.py:1-34"
+ },
+ "447": {
+ "file_id": 12,
+ "content": "This code defines two functions, `separate_weight_decayable_params` and `get_optimizer`. The `get_optimizer` function takes parameters, learning rate, weight decay, and other options to create an optimizer object. It filters the parameters based on `requires_grad`, separates weight-decayable parameters, and uses either Adam or AdamW optimizer depending on the weight decay value.",
+ "type": "comment"
+ },
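+ "447a": {
+ "file_id": 12,
+ "content": "Minimal sketch of get_optimizer (the nn.Linear model is a stand-in, not from the repository):\n\nimport torch.nn as nn\nfrom dalle2_pytorch.optimizer import get_optimizer\n\nmodel = nn.Linear(10, 10)\nopt = get_optimizer(model.parameters(), lr=1e-4, wd=1e-2)\n# with wd > 0 and group_wd_params=True this returns AdamW, with parameters of ndim < 2\n# (biases, norm scales) placed in a group whose weight_decay is 0",
+ "type": "comment"
+ },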
+ "448": {
+ "file_id": 13,
+ "content": "/dalle2_pytorch/tokenizer.py",
+ "type": "filepath"
+ },
+ "449": {
+ "file_id": 13,
+ "content": "The code simplifies DALL-E2 text tokenization by offering a PyTorch BPE tokenizer implementation with features for whitespace cleanup, formatting fixes, human-readable conversion, and handling context length limitations.",
+ "type": "summary"
+ },
+ "450": {
+ "file_id": 13,
+ "content": "# take from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py\n# to give users a quick easy start to training DALL-E without doing BPE\nimport torch\nimport html\nimport os\nimport ftfy\nimport regex as re\nfrom functools import lru_cache\nfrom pathlib import Path\nfrom dalle2_pytorch.utils import import_or_print_error\n# OpenAI simple tokenizer\n@lru_cache()\ndef default_bpe():\n return os.path.join(os.path.dirname(os.path.abspath(__file__)), \"data/bpe_simple_vocab_16e6.txt\")\n@lru_cache()\ndef bytes_to_unicode():\n bs = list(range(ord(\"!\"), ord(\"~\") + 1)) + list(range(ord(\"¡\"), ord(\"¬\") + 1)) + list(range(ord(\"®\"), ord(\"ÿ\") + 1))\n cs = bs[:]\n n = 0\n for b in range(2 ** 8):\n if b not in bs:\n bs.append(b)\n cs.append(2 ** 8 + n)\n n += 1\n cs = [chr(n) for n in cs]\n return dict(zip(bs, cs))\ndef get_pairs(word):\n pairs = set()\n prev_char = word[0]\n for char in word[1:]:\n pairs.add((prev_char, char))\n prev_char = char\n return pairs\ndef basic_clean(text):",
+ "type": "code",
+ "location": "/dalle2_pytorch/tokenizer.py:1-42"
+ },
+ "451": {
+ "file_id": 13,
+ "content": "This code imports necessary libraries and defines functions for tokenization, specifically for the DALL-E2 model. It uses OpenAI's simple tokenizer, a byte-to-unicode conversion, and a function to generate character pairs from a given word. The code is meant to provide users with an easy way to start training DALL-E without implementing BPE (Byte Pair Encoding).",
+ "type": "comment"
+ },
+ "452": {
+ "file_id": 13,
+ "content": " text = ftfy.fix_text(text)\n text = html.unescape(html.unescape(text))\n return text.strip()\ndef whitespace_clean(text):\n text = re.sub(r'\\s+', ' ', text)\n text = text.strip()\n return text\nclass SimpleTokenizer(object):\n def __init__(self, bpe_path = default_bpe()):\n self.byte_encoder = bytes_to_unicode()\n self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}\n merges = Path(bpe_path).read_text(encoding='utf8').split('\\n')\n merges = merges[1:49152 - 256 - 2 + 1]\n merges = [tuple(merge.split()) for merge in merges]\n vocab = list(bytes_to_unicode().values())\n vocab = vocab + [v + '' for v in vocab]\n for merge in merges:\n vocab.append(''.join(merge))\n vocab.extend(['<|startoftext|>', '<|endoftext|>'])\n self.vocab_size = 49408\n self.encoder = dict(zip(vocab, range(len(vocab))))\n self.decoder = {v: k for k, v in self.encoder.items()}\n self.bpe_ranks = dict(zip(merges, range(len(merges))))",
+ "type": "code",
+ "location": "/dalle2_pytorch/tokenizer.py:43-69"
+ },
+ "453": {
+ "file_id": 13,
+ "content": "This code is a Python class for a tokenizer that utilizes byte encoding and decoding, along with byte-pair encoding (BPE) to convert text into tokens. The class also includes methods for cleaning whitespace and fixing text formatting issues. The BPE merges are loaded from a specified file path, and the vocabulary is expanded by adding special tokens like \"<|startoftext|>\" and \"<|endoftext|>\".",
+ "type": "comment"
+ },
+ "454": {
+ "file_id": 13,
+ "content": " self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}\n self.pat = re.compile(\n r\"\"\"<\\|startoftext\\|>|<\\|endoftext\\|>|'s|'t|'re|'ve|'m|'ll|'d|[\\p{L}]+|[\\p{N}]|[^\\s\\p{L}\\p{N}]+\"\"\",\n re.IGNORECASE)\n def bpe(self, token):\n if token in self.cache:\n return self.cache[token]\n word = tuple(token[:-1]) + (token[-1] + '',)\n pairs = get_pairs(word)\n if not pairs:\n return token + ''\n while True:\n bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))\n if bigram not in self.bpe_ranks:\n break\n first, second = bigram\n new_word = []\n i = 0\n while i < len(word):\n try:\n j = word.index(first, i)\n new_word.extend(word[i:j])\n i = j\n except:\n new_word.extend(word[i:])\n break",
+ "type": "code",
+ "location": "/dalle2_pytorch/tokenizer.py:70-98"
+ },
+ "455": {
+ "file_id": 13,
+ "content": "The code defines a tokenizer that uses byte-pair encoding (BPE) for text. It compiles a regular expression pattern to match words and special tokens like \"<|startoftext|>\" and \"<|endoftext|>\". The `bpe` method takes a token, checks if it's in the cache, and if not, processes it using BPE by splitting it into smaller parts until no more splits are possible.",
+ "type": "comment"
+ },
+ "456": {
+ "file_id": 13,
+ "content": " if word[i] == first and i < len(word) - 1 and word[i + 1] == second:\n new_word.append(first + second)\n i += 2\n else:\n new_word.append(word[i])\n i += 1\n new_word = tuple(new_word)\n word = new_word\n if len(word) == 1:\n break\n else:\n pairs = get_pairs(word)\n word = ' '.join(word)\n self.cache[token] = word\n return word\n def encode(self, text):\n bpe_tokens = []\n text = whitespace_clean(basic_clean(text)).lower()\n for token in re.findall(self.pat, text):\n token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))\n bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))\n return bpe_tokens\n def decode(self, tokens, remove_start_end = True, pad_tokens = set()):\n if torch.is_tensor(tokens):\n tokens = tokens.tolist()",
+ "type": "code",
+ "location": "/dalle2_pytorch/tokenizer.py:100-126"
+ },
+ "457": {
+ "file_id": 13,
+ "content": "Code snippet is from a byte-pair encoding (BPE) tokenizer implementation in PyTorch. The code encodes input text into BPE tokens, performs wordpiece tokenization, and caches the mapping between tokens and words for decoding. The encode() function processes the input text by applying preprocessing steps, performing BPE, and extending tokens list with BPE tokens. The decode() function allows decoding of encoded tokens back to words using cached mappings.",
+ "type": "comment"
+ },
+ "458": {
+ "file_id": 13,
+ "content": " if remove_start_end:\n tokens = [token for token in tokens if token not in (49406, 40407, 0)]\n text = ''.join([self.decoder[token] for token in tokens if token not in pad_tokens])\n text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=\"replace\").replace('', ' ')\n return text\n def tokenize(self, texts, context_length = 256, truncate_text = False):\n if isinstance(texts, str):\n texts = [texts]\n all_tokens = [self.encode(text) for text in texts]\n result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)\n for i, tokens in enumerate(all_tokens):\n if len(tokens) > context_length:\n if truncate_text:\n tokens = tokens[:context_length]\n else:\n raise RuntimeError(f\"Input {texts[i]} is too long for context length {context_length}\")\n result[i, :len(tokens)] = torch.tensor(tokens)\n return result\ntokenizer = SimpleTokenizer()",
+ "type": "code",
+ "location": "/dalle2_pytorch/tokenizer.py:128-151"
+ },
+ "459": {
+ "file_id": 13,
+ "content": "The code defines a SimpleTokenizer class that tokenizes input texts using an encoding scheme and provides a method to convert encoded tokens into human-readable text. It also includes a tokenize function to process multiple input texts, considering context length limitations and handling truncation. The provided code snippet focuses on the process of converting encoded tokens into text.",
+ "type": "comment"
+ },
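+ "459a": {
+ "file_id": 13,
+ "content": "Illustrative use of the module-level SimpleTokenizer instance (the caption text is arbitrary):\n\nfrom dalle2_pytorch.tokenizer import tokenizer\n\ntokens = tokenizer.tokenize(['a corgi wearing sunglasses'], context_length=256, truncate_text=True)\nprint(tokens.shape)  # torch.Size([1, 256]); unused positions are zero-padded\ntext = tokenizer.decode(tokens[0])  # approximately recovers the caption",
+ "type": "comment"
+ },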
+ "460": {
+ "file_id": 13,
+ "content": "# YTTM tokenizer\nclass YttmTokenizer:\n def __init__(self, bpe_path = None):\n bpe_path = Path(bpe_path)\n assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist'\n self.yttm = import_or_print_error('youtokentome', 'you need to install youtokentome by `pip install youtokentome`')\n tokenizer = self.yttm.BPE(model = str(bpe_path))\n self.tokenizer = tokenizer\n self.vocab_size = tokenizer.vocab_size()\n def decode(self, tokens, pad_tokens = set()):\n if torch.is_tensor(tokens):\n tokens = tokens.tolist()\n return self.tokenizer.decode(tokens, ignore_ids = pad_tokens.union({0}))\n def encode(self, texts):\n encoded = self.tokenizer.encode(texts, output_type = self.yttm.OutputType.ID)\n return list(map(torch.tensor, encoded))\n def tokenize(self, texts, context_length = 256, truncate_text = False):\n if isinstance(texts, str):\n texts = [texts]\n all_tokens = self.encode(texts)\n result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)",
+ "type": "code",
+ "location": "/dalle2_pytorch/tokenizer.py:153-182"
+ },
+ "461": {
+ "file_id": 13,
+ "content": "This code defines a YTTM tokenizer class in PyTorch. The constructor loads the BPE model from the specified path and initializes the tokenizer instance, which can decode and encode text sequences. The decode function converts tokenized lists to human-readable strings, while the encode function transforms input texts into tokenized lists. The tokenize method takes a list of texts, encodes them, and returns a tensor of shape (number_of_texts, context_length) for further processing.",
+ "type": "comment"
+ },
+ "462": {
+ "file_id": 13,
+ "content": " for i, tokens in enumerate(all_tokens):\n if len(tokens) > context_length:\n if truncate_text:\n tokens = tokens[:context_length]\n else:\n raise RuntimeError(f\"Input {texts[i]} is too long for context length {context_length}\")\n result[i, :len(tokens)] = torch.tensor(tokens)\n return result",
+ "type": "code",
+ "location": "/dalle2_pytorch/tokenizer.py:183-191"
+ },
+ "463": {
+ "file_id": 13,
+ "content": "This code segment iterates through all tokens in a list, truncating any token sequence longer than the specified context length. If truncation is not allowed and an input text is too long, it raises a RuntimeError. The truncated or original tokens are then converted to torch tensors and stored in a result array.",
+ "type": "comment"
+ },
+ "464": {
+ "file_id": 14,
+ "content": "/dalle2_pytorch/trackers.py",
+ "type": "filepath"
+ },
+ "465": {
+ "file_id": 14,
+ "content": "The code initializes trackers and loggers, provides methods for logging data, saving configurations, and metadata. It saves states and models, manages loading/saving checkpoints, and handles errors with a \"recall()\" function.",
+ "type": "summary"
+ },
+ "466": {
+ "file_id": 14,
+ "content": "import urllib.request\nimport os\nimport json\nfrom pathlib import Path\nimport shutil\nfrom itertools import zip_longest\nfrom typing import Any, Optional, List, Union\nfrom pydantic import BaseModel\nimport torch\nfrom dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior\nfrom dalle2_pytorch.utils import import_or_print_error\nfrom dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer\nfrom dalle2_pytorch.version import __version__\nfrom packaging import version\n# constants\nDEFAULT_DATA_PATH = './.tracker-data'\n# helper functions\ndef exists(val):\n return val is not None\nclass BaseLogger:\n \"\"\"\n An abstract class representing an object that can log data.\n Parameters:\n data_path (str): A file path for storing temporary data.\n verbose (bool): Whether of not to always print logs to the console.\n \"\"\"\n def __init__(self, data_path: str, resume: bool = False, auto_resume: bool = False, verbose: bool = False, **kwargs):\n self.data_path = Path(data_path)\n self.resume = resume",
+ "type": "code",
+ "location": "/dalle2_pytorch/trackers.py:1-35"
+ },
+ "467": {
+ "file_id": 14,
+ "content": "This code is from the \"trackers.py\" file in the DALLE2-pytorch library, containing a class for base logger objects that can log data with optional data storage path and verbosity control. The class initializes with specified parameters like data_path, resume, auto_resume, and verbose. It uses Pathlib for path manipulation and supports temporary data storage.",
+ "type": "comment"
+ },
+ "468": {
+ "file_id": 14,
+ "content": " self.auto_resume = auto_resume\n self.verbose = verbose\n def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:\n \"\"\"\n Initializes the logger.\n Errors if the logger is invalid.\n full_config is the config file dict while extra_config is anything else from the script that is not defined the config file.\n \"\"\"\n raise NotImplementedError\n def log(self, log, **kwargs) -> None:\n raise NotImplementedError\n def log_images(self, images, captions=[], image_section=\"images\", **kwargs) -> None:\n raise NotImplementedError\n def log_file(self, file_path, **kwargs) -> None:\n raise NotImplementedError\n def log_error(self, error_string, **kwargs) -> None:\n raise NotImplementedError\n def get_resume_data(self, **kwargs) -> dict:\n \"\"\"\n Sets tracker attributes that along with { \"resume\": True } will be used to resume training.\n It is assumed that after init is called this data will be complete.",
+ "type": "code",
+ "location": "/dalle2_pytorch/trackers.py:36-62"
+ },
+ "469": {
+ "file_id": 14,
+ "content": "The code defines a logger class with methods for logging different types of data, and an initialization method to set up the logger. The logger raises a NotImplementedError for each method, which means they need to be implemented in child classes. The get_resume_data method sets tracker attributes used to resume training if needed.",
+ "type": "comment"
+ },
+ "470": {
+ "file_id": 14,
+ "content": " If the logger does not have any resume functionality, it should return an empty dict.\n \"\"\"\n raise NotImplementedError\nclass ConsoleLogger(BaseLogger):\n def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:\n print(\"Logging to console\")\n def log(self, log, **kwargs) -> None:\n print(log)\n def log_images(self, images, captions=[], image_section=\"images\", **kwargs) -> None:\n pass\n def log_file(self, file_path, **kwargs) -> None:\n pass\n def log_error(self, error_string, **kwargs) -> None:\n print(error_string)\n def get_resume_data(self, **kwargs) -> dict:\n return {}\nclass WandbLogger(BaseLogger):\n \"\"\"\n Logs to a wandb run.\n Parameters:\n data_path (str): A file path for storing temporary data.\n wandb_entity (str): The wandb entity to log to.\n wandb_project (str): The wandb project to log to.\n wandb_run_id (str): The wandb run id to resume.\n wandb_run_name (str): The wandb run name to use.",
+ "type": "code",
+ "location": "/dalle2_pytorch/trackers.py:63-94"
+ },
+ "471": {
+ "file_id": 14,
+ "content": "This code defines two logger classes, ConsoleLogger and WandbLogger, which inherit from the BaseLogger class. The ConsoleLogger logs to the console while the WandbLogger logs data to a Weights & Biases (WandB) run. Both loggers have methods for logging different types of data such as logs, images, files, and errors. The ConsoleLogger returns an empty dictionary if resuming is not supported, whereas the WandbLogger requires additional parameters like wandb_entity, wandb_project, wandb_run_id, and wandb_run_name for proper functioning.",
+ "type": "comment"
+ },
+ "472": {
+ "file_id": 14,
+ "content": " \"\"\"\n def __init__(self,\n data_path: str,\n wandb_entity: str,\n wandb_project: str,\n wandb_run_id: Optional[str] = None,\n wandb_run_name: Optional[str] = None,\n **kwargs\n ):\n super().__init__(data_path, **kwargs)\n self.entity = wandb_entity\n self.project = wandb_project\n self.run_id = wandb_run_id\n self.run_name = wandb_run_name\n def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:\n assert self.entity is not None, \"wandb_entity must be specified for wandb logger\"\n assert self.project is not None, \"wandb_project must be specified for wandb logger\"\n self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb logger')\n os.environ[\"WANDB_SILENT\"] = \"true\"\n # Initializes the wandb run\n init_object = {\n \"entity\": self.entity,\n \"project\": self.project,\n \"config\": {**full_config.dict(), **extra_config}\n }",
+ "type": "code",
+ "location": "/dalle2_pytorch/trackers.py:95-120"
+ },
+ "473": {
+ "file_id": 14,
+ "content": "This code is a Python class for creating and initializing a WandB logger. It requires a data path, WandB entity, and project parameters. The class also supports additional configuration options. If the WandB entity or project are not specified, an error will be raised.",
+ "type": "comment"
+ },
+ "474": {
+ "file_id": 14,
+ "content": " if self.run_name is not None:\n init_object['name'] = self.run_name\n if self.resume:\n assert self.run_id is not None, '`wandb_run_id` must be provided if `wandb_resume` is True'\n if self.run_name is not None:\n print(\"You are renaming a run. I hope that is what you intended.\")\n init_object['resume'] = 'must'\n init_object['id'] = self.run_id\n self.wandb.init(**init_object)\n print(f\"Logging to wandb run {self.wandb.run.path}-{self.wandb.run.name}\")\n def log(self, log, **kwargs) -> None:\n if self.verbose:\n print(log)\n self.wandb.log(log, **kwargs)\n def log_images(self, images, captions=[], image_section=\"images\", **kwargs) -> None:\n \"\"\"\n Takes a tensor of images and a list of captions and logs them to wandb.\n \"\"\"\n wandb_images = [self.wandb.Image(image, caption=caption) for image, caption in zip_longest(images, captions)]\n self.wandb.log({ image_section: wandb_images }, **kwargs)",
+ "type": "code",
+ "location": "/dalle2_pytorch/trackers.py:121-143"
+ },
+ "475": {
+ "file_id": 14,
+ "content": "This code initializes a Wandb tracker, allowing for easy logging of data to a specific run. If `run_id` is provided and `wandb_resume` is True, the run is resumed with a warning about renaming. The code then logs various types of data including logs, images with captions, using the Wandb API. Verbose output is also supported for logs.",
+ "type": "comment"
+ },
+ "476": {
+ "file_id": 14,
+ "content": " def log_file(self, file_path, base_path: Optional[str] = None, **kwargs) -> None:\n if base_path is None:\n # Then we take the basepath as the parent of the file_path\n base_path = Path(file_path).parent\n self.wandb.save(str(file_path), base_path = str(base_path))\n def log_error(self, error_string, step=None, **kwargs) -> None:\n if self.verbose:\n print(error_string)\n self.wandb.log({\"error\": error_string, **kwargs}, step=step)\n def get_resume_data(self, **kwargs) -> dict:\n # In order to resume, we need wandb_entity, wandb_project, and wandb_run_id\n return {\n \"entity\": self.entity,\n \"project\": self.project,\n \"run_id\": self.wandb.run.id\n }\nlogger_type_map = {\n 'console': ConsoleLogger,\n 'wandb': WandbLogger,\n}\ndef create_logger(logger_type: str, data_path: str, **kwargs) -> BaseLogger:\n if logger_type == 'custom':\n raise NotImplementedError('Custom loggers are not supported yet. Please use a different logger type.')",
+ "type": "code",
+ "location": "/dalle2_pytorch/trackers.py:145-170"
+ },
+ "477": {
+ "file_id": 14,
+ "content": "The code defines a class with three methods: `log_file`, `log_error`, and `get_resume_data`. The `log_file` method logs a file path, `log_error` logs an error string, and `get_resume_data` returns a dictionary containing essential resume information. Additionally, there is a function `create_logger` which creates a logger of type 'console' or 'wandb'. For now, custom loggers are not supported.",
+ "type": "comment"
+ },
+ "478": {
+ "file_id": 14,
+ "content": " try:\n logger_class = logger_type_map[logger_type]\n except KeyError:\n raise ValueError(f'Unknown logger type: {logger_type}. Must be one of {list(logger_type_map.keys())}')\n return logger_class(data_path, **kwargs)\nclass BaseLoader:\n \"\"\"\n An abstract class representing an object that can load a model checkpoint.\n Parameters:\n data_path (str): A file path for storing temporary data.\n \"\"\"\n def __init__(self, data_path: str, only_auto_resume: bool = False, **kwargs):\n self.data_path = Path(data_path)\n self.only_auto_resume = only_auto_resume\n def init(self, logger: BaseLogger, **kwargs) -> None:\n raise NotImplementedError\n def recall() -> dict:\n raise NotImplementedError\nclass UrlLoader(BaseLoader):\n \"\"\"\n A loader that downloads the file from a url and loads it\n Parameters:\n data_path (str): A file path for storing temporary data.\n url (str): The url to download the file from.\n \"\"\"\n def __init__(self, data_path: str, url: str, **kwargs):",
+ "type": "code",
+ "location": "/dalle2_pytorch/trackers.py:171-200"
+ },
+ "479": {
+ "file_id": 14,
+ "content": "Function tries to create an instance of a logger class based on the given type, otherwise it raises a ValueError. BaseLoader is an abstract class that can be used to load model checkpoints with data_path and optionally other parameters. UrlLoader extends BaseLoader by allowing loading files from URLs instead of local file paths.",
+ "type": "comment"
+ },
+ "480": {
+ "file_id": 14,
+ "content": " super().__init__(data_path, **kwargs)\n self.url = url\n def init(self, logger: BaseLogger, **kwargs) -> None:\n # Makes sure the file exists to be downloaded\n pass # TODO: Actually implement that\n def recall(self) -> dict:\n # Download the file\n save_path = self.data_path / 'loaded_checkpoint.pth'\n urllib.request.urlretrieve(self.url, str(save_path))\n # Load the file\n return torch.load(str(save_path), map_location='cpu')\nclass LocalLoader(BaseLoader):\n \"\"\"\n A loader that loads a file from a local path\n Parameters:\n data_path (str): A file path for storing temporary data.\n file_path (str): The path to the file to load.\n \"\"\"\n def __init__(self, data_path: str, file_path: str, **kwargs):\n super().__init__(data_path, **kwargs)\n self.file_path = Path(file_path)\n def init(self, logger: BaseLogger, **kwargs) -> None:\n # Makes sure the file exists to be loaded\n if not self.file_path.exists() and not self.only_auto_resume:",
+ "type": "code",
+ "location": "/dalle2_pytorch/trackers.py:201-229"
+ },
+ "481": {
+ "file_id": 14,
+ "content": "The code defines a base class, \"BaseLoader\", which is responsible for loading files from a given data path. It initializes the class by setting the URL and has an init method to check if the file exists. The \"recall\" method downloads the file and loads it into memory. Additionally, there is a subclass called \"LocalLoader\" that loads files from local paths, checking if the file exists before loading it.",
+ "type": "comment"
+ },
+ "482": {
+ "file_id": 14,
+ "content": " raise FileNotFoundError(f'Model not found at {self.file_path}')\n def recall(self) -> dict:\n # Load the file\n return torch.load(str(self.file_path), map_location='cpu')\nclass WandbLoader(BaseLoader):\n \"\"\"\n A loader that loads a model from an existing wandb run\n \"\"\"\n def __init__(self, data_path: str, wandb_file_path: str, wandb_run_path: Optional[str] = None, **kwargs):\n super().__init__(data_path, **kwargs)\n self.run_path = wandb_run_path\n self.file_path = wandb_file_path\n def init(self, logger: BaseLogger, **kwargs) -> None:\n self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')\n # Make sure the file can be downloaded\n if self.wandb.run is not None and self.run_path is None:\n self.run_path = self.wandb.run.path\n assert self.run_path is not None, 'wandb run was not found to load from. If not using the wandb logger must specify the `wandb_run_path`.'\n assert self.run_path is not None, '`wandb_run_path` must be provided for the wandb loader'",
+ "type": "code",
+ "location": "/dalle2_pytorch/trackers.py:230-251"
+ },
+ "483": {
+ "file_id": 14,
+ "content": "This code defines a class `WandbLoader` that loads a model from an existing W&B (Weights & Biases) run. It requires a data path, a file path within the W&B run, and optionally a W&B run path. The `__init__` method initializes the object, the `init` method ensures the file can be downloaded, and the `recall` method loads the model using `torch.load`. If a W&B run is available but the run path is not specified, it sets the run path to the current run's path. The code also imports the 'wandb' library if it is missing.",
+ "type": "comment"
+ },
+ "484": {
+ "file_id": 14,
+ "content": " assert self.file_path is not None, '`wandb_file_path` must be provided for the wandb loader'\n os.environ[\"WANDB_SILENT\"] = \"true\"\n pass # TODO: Actually implement that\n def recall(self) -> dict:\n file_reference = self.wandb.restore(self.file_path, run_path=self.run_path)\n return torch.load(file_reference.name, map_location='cpu')\nloader_type_map = {\n 'url': UrlLoader,\n 'local': LocalLoader,\n 'wandb': WandbLoader,\n}\ndef create_loader(loader_type: str, data_path: str, **kwargs) -> BaseLoader:\n if loader_type == 'custom':\n raise NotImplementedError('Custom loaders are not supported yet. Please use a different loader type.')\n try:\n loader_class = loader_type_map[loader_type]\n except KeyError:\n raise ValueError(f'Unknown loader type: {loader_type}. Must be one of {list(loader_type_map.keys())}')\n return loader_class(data_path, **kwargs)\nclass BaseSaver:\n def __init__(self,\n data_path: str,\n save_latest_to: Optional[Union[str, bool]] = None,",
+ "type": "code",
+ "location": "/dalle2_pytorch/trackers.py:252-278"
+ },
+ "485": {
+ "file_id": 14,
+ "content": "This code defines a `BaseSaver` class with an optional parameter for saving the latest data to a specified location. It also includes a function `create_loader()` that creates different types of loaders (url, local, wandb) based on the provided loader type and data path. The WandbLoader is used to restore data from a specified file path using Weights & Biases environment.",
+ "type": "comment"
+ },
+ "486": {
+ "file_id": 14,
+ "content": " save_best_to: Optional[Union[str, bool]] = None,\n save_meta_to: Optional[str] = None,\n save_type: str = 'checkpoint',\n **kwargs\n ):\n self.data_path = Path(data_path)\n self.save_latest_to = save_latest_to\n self.saving_latest = save_latest_to is not None and save_latest_to is not False\n self.save_best_to = save_best_to\n self.saving_best = save_best_to is not None and save_best_to is not False\n self.save_meta_to = save_meta_to\n self.saving_meta = save_meta_to is not None\n self.save_type = save_type\n assert save_type in ['checkpoint', 'model'], '`save_type` must be one of `checkpoint` or `model`'\n assert self.saving_latest or self.saving_best or self.saving_meta, 'At least one saving option must be specified'\n def init(self, logger: BaseLogger, **kwargs) -> None:\n raise NotImplementedError\n def save_file(self, local_path: Path, save_path: str, is_best=False, is_latest=False, **kwargs) -> None:\n \"\"\"",
+ "type": "code",
+ "location": "/dalle2_pytorch/trackers.py:279-299"
+ },
+ "487": {
+ "file_id": 14,
+ "content": "This code defines a tracker class that handles saving of data to specified locations. It allows saving the latest, best, and meta information, with options for file type and paths. The `save_file` method is used to save files with optional flags for best and latest status. An assertion ensures that the save type is either 'checkpoint' or 'model'. A final assertion requires at least one saving option to be specified.",
+ "type": "comment"
+ },
+ "488": {
+ "file_id": 14,
+ "content": " Save a general file under save_meta_to\n \"\"\"\n raise NotImplementedError\nclass LocalSaver(BaseSaver):\n def __init__(self,\n data_path: str,\n **kwargs\n ):\n super().__init__(data_path, **kwargs)\n def init(self, logger: BaseLogger, **kwargs) -> None:\n # Makes sure the directory exists to be saved to\n print(f\"Saving {self.save_type} locally\")\n if not self.data_path.exists():\n self.data_path.mkdir(parents=True)\n def save_file(self, local_path: str, save_path: str, **kwargs) -> None:\n # Copy the file to save_path\n save_path_file_name = Path(save_path).name\n # Make sure parent directory exists\n save_path_parent = Path(save_path).parent\n if not save_path_parent.exists():\n save_path_parent.mkdir(parents=True)\n print(f\"Saving {save_path_file_name} {self.save_type} to local path {save_path}\")\n shutil.copy(local_path, save_path)\nclass WandbSaver(BaseSaver):\n def __init__(self, data_path: str, wandb_run_path: Optional[str] = None, **kwargs):",
+ "type": "code",
+ "location": "/dalle2_pytorch/trackers.py:300-328"
+ },
+ "489": {
+ "file_id": 14,
+ "content": "This code defines two classes, LocalSaver and WandbSaver, which inherit from BaseSaver. Both classes are responsible for saving files in different locations. The LocalSaver saves files locally to a specified data_path, ensuring the directory exists beforehand. The WandbSaver is optional and requires a wandb_run_path parameter.",
+ "type": "comment"
+ },
+ "490": {
+ "file_id": 14,
+ "content": " super().__init__(data_path, **kwargs)\n self.run_path = wandb_run_path\n def init(self, logger: BaseLogger, **kwargs) -> None:\n self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb logger')\n os.environ[\"WANDB_SILENT\"] = \"true\"\n # Makes sure that the user can upload tot his run\n if self.run_path is not None:\n entity, project, run_id = self.run_path.split(\"/\")\n self.run = self.wandb.init(entity=entity, project=project, id=run_id)\n else:\n assert self.wandb.run is not None, 'You must be using the wandb logger if you are saving to wandb and have not set `wandb_run_path`'\n self.run = self.wandb.run\n # TODO: Now actually check if upload is possible\n print(f\"Saving to wandb run {self.run.path}-{self.run.name}\")\n def save_file(self, local_path: Path, save_path: str, **kwargs) -> None:\n # In order to log something in the correct place in wandb, we need to have the same file structure here",
+ "type": "code",
+ "location": "/dalle2_pytorch/trackers.py:329-346"
+ },
+ "491": {
+ "file_id": 14,
+ "content": "This code initializes a W&B run based on the `wandb_run_path` provided. It imports the W&B library, sets up the environment for uploading to W&B runs, and checks if the user has access to save files in the specified W&B run path.",
+ "type": "comment"
+ },
+ "492": {
+ "file_id": 14,
+ "content": " save_path_file_name = Path(save_path).name\n print(f\"Saving {save_path_file_name} {self.save_type} to wandb run {self.run.path}-{self.run.name}\")\n save_path = Path(self.data_path) / save_path\n save_path.parent.mkdir(parents=True, exist_ok=True)\n shutil.copy(local_path, save_path)\n self.run.save(str(save_path), base_path = str(self.data_path), policy='now')\nclass HuggingfaceSaver(BaseSaver):\n def __init__(self, data_path: str, huggingface_repo: str, token_path: Optional[str] = None, **kwargs):\n super().__init__(data_path, **kwargs)\n self.huggingface_repo = huggingface_repo\n self.token_path = token_path\n def init(self, logger: BaseLogger, **kwargs):\n # Makes sure this user can upload to the repo\n self.hub = import_or_print_error('huggingface_hub', '`pip install huggingface_hub` to use the huggingface saver')\n try:\n identity = self.hub.whoami() # Errors if not logged in\n # Then we are logged in",
+ "type": "code",
+ "location": "/dalle2_pytorch/trackers.py:347-365"
+ },
+ "493": {
+ "file_id": 14,
+ "content": "This code defines a `HuggingfaceSaver` class that saves files to a Hugging Face repository. It initializes the instance with a data path, Hugging Face repo, and optional token path. The `init` method checks if the user is logged in to the Hugging Face hub and then saves the file specified by `save_path` using `self.hub.upload`.",
+ "type": "comment"
+ },
+ "494": {
+ "file_id": 14,
+ "content": " except:\n # We are not logged in. Use the token_path to set the token.\n if not os.path.exists(self.token_path):\n raise Exception(\"Not logged in to huggingface and no token_path specified. Please login with `huggingface-cli login` or if that does not work set the token_path.\")\n with open(self.token_path, \"r\") as f:\n token = f.read().strip()\n self.hub.HfApi.set_access_token(token)\n identity = self.hub.whoami()\n print(f\"Saving to huggingface repo {self.huggingface_repo}\")\n def save_file(self, local_path: Path, save_path: str, **kwargs) -> None:\n # Saving to huggingface is easy, we just need to upload the file with the correct name\n save_path_file_name = Path(save_path).name\n print(f\"Saving {save_path_file_name} {self.save_type} to huggingface repo {self.huggingface_repo}\")\n self.hub.upload_file(\n path_or_fileobj=str(local_path),\n path_in_repo=str(save_path),",
+ "type": "code",
+ "location": "/dalle2_pytorch/trackers.py:366-382"
+ },
+ "495": {
+ "file_id": 14,
+ "content": "This code handles saving a file to the HuggingFace repo. If not logged in, it checks for a token path and uses it if available, or throws an exception. It then prints the saving path, logs in with the token (if provided), and finally uploads the file to the specified HuggingFace repo.",
+ "type": "comment"
+ },
+ "496": {
+ "file_id": 14,
+ "content": " repo_id=self.huggingface_repo\n )\nsaver_type_map = {\n 'local': LocalSaver,\n 'wandb': WandbSaver,\n 'huggingface': HuggingfaceSaver\n}\ndef create_saver(saver_type: str, data_path: str, **kwargs) -> BaseSaver:\n if saver_type == 'custom':\n raise NotImplementedError('Custom savers are not supported yet. Please use a different saver type.')\n try:\n saver_class = saver_type_map[saver_type]\n except KeyError:\n raise ValueError(f'Unknown saver type: {saver_type}. Must be one of {list(saver_type_map.keys())}')\n return saver_class(data_path, **kwargs)\nclass Tracker:\n def __init__(self, data_path: Optional[str] = DEFAULT_DATA_PATH, overwrite_data_path: bool = False, dummy_mode: bool = False):\n self.data_path = Path(data_path)\n if not dummy_mode:\n if not overwrite_data_path:\n assert not self.data_path.exists(), f'Data path {self.data_path} already exists. Set overwrite_data_path to True to overwrite.'\n if not self.data_path.exists():",
+ "type": "code",
+ "location": "/dalle2_pytorch/trackers.py:383-407"
+ },
+ "497": {
+ "file_id": 14,
+ "content": "Function create_saver takes a saver type and data path, returns a BaseSaver object. It supports 'local', 'wandb', and 'huggingface' saver types. If the saver type is 'custom', it raises an error since custom savers aren't supported yet. Tracker initializes with optional data_path, overwrite_data_path (to overwrite existing path), and dummy_mode (if running in simulation mode). If not in dummy mode, asserts that the data path doesn't exist unless overwrite_data_path is True.",
+ "type": "comment"
+ },
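+ A minimal usage sketch (not from the source; the data path is hypothetical) showing how create_saver dispatches through saver_type_map and rejects unknown saver types:
+
+     from dalle2_pytorch.trackers import create_saver, saver_type_map
+
+     print(list(saver_type_map.keys()))        # ['local', 'wandb', 'huggingface']
+
+     try:
+         create_saver('s3', '.tracker_data/example')
+     except ValueError as err:
+         print(err)                            # Unknown saver type: s3. Must be one of [...]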
+ "498": {
+ "file_id": 14,
+ "content": " self.data_path.mkdir(parents=True)\n self.logger: BaseLogger = None\n self.loader: Optional[BaseLoader] = None\n self.savers: List[BaseSaver]= []\n self.dummy_mode = dummy_mode\n def _load_auto_resume(self) -> bool:\n # If the file does not exist, we return False. If autoresume is enabled we print a warning so that the user can know that this is the first run.\n if not self.auto_resume_path.exists():\n if self.logger.auto_resume:\n print(\"Auto_resume is enabled but no auto_resume.json file exists. Assuming this is the first run.\")\n return False\n # Now we know that the autoresume file exists, but if we are not auto resuming we should remove it so that we don't accidentally load it next time\n if not self.logger.auto_resume:\n print(f'Removing auto_resume.json because auto_resume is not enabled in the config')\n self.auto_resume_path.unlink()\n return False\n # Otherwise we read the json into a dictionary will will override parts of logger.__dict__",
+ "type": "code",
+ "location": "/dalle2_pytorch/trackers.py:408-427"
+ },
+ "499": {
+ "file_id": 14,
+ "content": "This code initializes a tracker object, handling the data path creation, base logger and loader setup, saving list initialization, and dummy mode. It also includes a method to load auto-resume configuration if it exists, printing warnings for first run or removing the file if auto-resume is not enabled.",
+ "type": "comment"
+ }
+}
\ No newline at end of file
diff --git a/docs/data/5.json b/docs/data/5.json
new file mode 100644
index 00000000..53d4f05d
--- /dev/null
+++ b/docs/data/5.json
@@ -0,0 +1,550 @@
+{
+ "500": {
+ "file_id": 14,
+ "content": " with open(self.auto_resume_path, 'r') as f:\n auto_resume_dict = json.load(f)\n # Check if the logger is of the same type as the autoresume save\n if auto_resume_dict[\"logger_type\"] != self.logger.__class__.__name__:\n raise Exception(f'The logger type in the auto_resume file is {auto_resume_dict[\"logger_type\"]} but the current logger is {self.logger.__class__.__name__}. Either use the original logger type, set `auto_resume` to `False`, or delete your existing tracker-data folder.')\n # Then we are ready to override the logger with the autoresume save\n self.logger.__dict__[\"resume\"] = True\n print(f\"Updating {self.logger.__dict__} with {auto_resume_dict}\")\n self.logger.__dict__.update(auto_resume_dict)\n return True\n def _save_auto_resume(self):\n # Gets the autoresume dict from the logger and adds \"logger_type\" to it then saves it to the auto_resume file\n auto_resume_dict = self.logger.get_resume_data()\n auto_resume_dict['logger_type'] = self.logger.__class__.__name__",
+ "type": "code",
+ "location": "/dalle2_pytorch/trackers.py:428-442"
+ },
+ "501": {
+ "file_id": 14,
+ "content": "This code reads a previously saved state from the \"auto_resume_path\" and checks if the logger type matches the current logger. If they don't match, it raises an exception with instructions on how to proceed. Otherwise, it updates the logger with the auto-resume data and returns True.",
+ "type": "comment"
+ },
+ "502": {
+ "file_id": 14,
+ "content": " with open(self.auto_resume_path, 'w') as f:\n json.dump(auto_resume_dict, f)\n def init(self, full_config: BaseModel, extra_config: dict):\n self.auto_resume_path = self.data_path / 'auto_resume.json'\n # Check for resuming the run\n self.did_auto_resume = self._load_auto_resume()\n if self.did_auto_resume:\n print(f'\\n\\nWARNING: RUN HAS BEEN AUTO-RESUMED WITH THE LOGGER TYPE {self.logger.__class__.__name__}.\\nIf this was not your intention, stop this run and set `auto_resume` to `False` in the config.\\n\\n')\n print(f\"New logger config: {self.logger.__dict__}\")\n self.save_metadata = dict(\n version = version.parse(__version__)\n ) # Data that will be saved alongside the checkpoint or model\n self.blacklisted_checkpoint_metadata_keys = ['scaler', 'optimizer', 'model', 'version', 'step', 'steps'] # These keys would cause us to error if we try to save them as metadata\n assert self.logger is not None, '`logger` must be set before `init` is called'",
+ "type": "code",
+ "location": "/dalle2_pytorch/trackers.py:443-459"
+ },
+ "503": {
+ "file_id": 14,
+ "content": "This code is initializing a tracker object. It sets the auto_resume path, checks for resuming the run and prints a warning if it was automatically resumed. The save_metadata dictionary is created with version information and some keys are blacklisted from being saved as metadata to avoid errors during saving. The logger must be set before calling init method.",
+ "type": "comment"
+ },
+ "504": {
+ "file_id": 14,
+ "content": " if self.dummy_mode:\n # The only thing we need is a loader\n if self.loader is not None:\n self.loader.init(self.logger)\n return\n assert len(self.savers) > 0, '`savers` must be set before `init` is called'\n self.logger.init(full_config, extra_config)\n if self.loader is not None:\n self.loader.init(self.logger)\n for saver in self.savers:\n saver.init(self.logger)\n if self.logger.auto_resume:\n # Then we need to save the autoresume file. It is assumed after logger.init is called that the logger is ready to be saved.\n self._save_auto_resume()\n def add_logger(self, logger: BaseLogger):\n self.logger = logger\n def add_loader(self, loader: BaseLoader):\n self.loader = loader\n def add_saver(self, saver: BaseSaver):\n self.savers.append(saver)\n def log(self, *args, **kwargs):\n if self.dummy_mode:\n return\n self.logger.log(*args, **kwargs)",
+ "type": "code",
+ "location": "/dalle2_pytorch/trackers.py:460-489"
+ },
+ "505": {
+ "file_id": 14,
+ "content": "This code initializes trackers by first checking if in dummy mode, then initializing loaders and savers. The logger is initialized only if the `savers` list has items, and if `auto_resume` is enabled, it saves an autoresume file. The `add_logger`, `add_loader`, `add_saver`, and `log` methods are provided to interact with trackers' components.",
+ "type": "comment"
+ },
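+ A minimal sketch (hypothetical data path) of the dummy-mode behaviour described above, where logging calls return immediately and no logger needs to be configured:
+
+     from dalle2_pytorch.trackers import Tracker
+
+     tracker = Tracker(data_path='.tracker_data/example', dummy_mode=True)
+     tracker.log({'loss': 0.0}, step=0)     # no-op: dummy_mode short-circuits before self.logger is used
+     tracker.log_images(None, step=0)       # likewise a no-op in dummy mode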
+ "506": {
+ "file_id": 14,
+ "content": " def log_images(self, *args, **kwargs):\n if self.dummy_mode:\n return\n self.logger.log_images(*args, **kwargs)\n def log_file(self, *args, **kwargs):\n if self.dummy_mode:\n return\n self.logger.log_file(*args, **kwargs)\n def save_config(self, current_config_path: str, config_name = 'config.json'):\n if self.dummy_mode:\n return\n # Save the config under config_name in the root folder of data_path\n shutil.copy(current_config_path, self.data_path / config_name)\n for saver in self.savers:\n if saver.saving_meta:\n remote_path = Path(saver.save_meta_to) / config_name\n saver.save_file(current_config_path, str(remote_path))\n def add_save_metadata(self, state_dict_key: str, metadata: Any):\n \"\"\"\n Adds a new piece of metadata that will be saved along with the model or decoder.\n \"\"\"\n self.save_metadata[state_dict_key] = metadata\n def _save_state_dict(self,",
+ "type": "code",
+ "location": "/dalle2_pytorch/trackers.py:491-517"
+ },
+ "507": {
+ "file_id": 14,
+ "content": "This code is from the DALLE2-pytorch library and it contains several methods for logging images, files, saving configurations, and adding save metadata. The dummy_mode check prevents unnecessary actions when in a test mode. The save_config method copies the current config file to the root folder of the data_path and saves it remotely if specified by the saver. The add_save_metadata method adds new metadata that will be saved along with the model or decoder.",
+ "type": "comment"
+ },
+ "508": {
+ "file_id": 14,
+ "content": " trainer: Union[DiffusionPriorTrainer, DecoderTrainer], save_type: str, file_path: str, **kwargs) -> Path:\n \"\"\"\n Gets the state dict to be saved and writes it to file_path.\n If save_type is 'checkpoint', we save the entire trainer state dict.\n If save_type is 'model', we save only the model state dict.\n \"\"\"\n assert save_type in ['checkpoint', 'model']\n if save_type == 'checkpoint':\n # Create a metadata dict without the blacklisted keys so we do not error when we create the state dict\n metadata = {k: v for k, v in self.save_metadata.items() if k not in self.blacklisted_checkpoint_metadata_keys}\n trainer.save(file_path, overwrite=True, **kwargs, **metadata)\n elif save_type == 'model':\n if isinstance(trainer, DiffusionPriorTrainer):\n prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior\n prior: DiffusionPrior = trainer.accelerator.unwrap_model(prior)",
+ "type": "code",
+ "location": "/dalle2_pytorch/trackers.py:517-531"
+ },
+ "509": {
+ "file_id": 14,
+ "content": "This function saves the trainer's state dict, depending on the 'save_type' parameter. If 'checkpoint', it saves the entire trainer state without blacklisted metadata keys. If 'model', it saves only the model state if the trainer is a DiffusionPriorTrainer.",
+ "type": "comment"
+ },
+ "510": {
+ "file_id": 14,
+ "content": " # Remove CLIP if it is part of the model\n original_clip = prior.clip\n prior.clip = None\n model_state_dict = prior.state_dict()\n prior.clip = original_clip\n elif isinstance(trainer, DecoderTrainer):\n decoder: Decoder = trainer.accelerator.unwrap_model(trainer.decoder)\n # Remove CLIP if it is part of the model\n original_clip = decoder.clip\n decoder.clip = None\n if trainer.use_ema:\n trainable_unets = decoder.unets\n decoder.unets = trainer.unets # Swap EMA unets in\n model_state_dict = decoder.state_dict()\n decoder.unets = trainable_unets # Swap back\n else:\n model_state_dict = decoder.state_dict()\n decoder.clip = original_clip\n else:\n raise NotImplementedError('Saving this type of model with EMA mode enabled is not yet implemented. Actually, how did you get here?')",
+ "type": "code",
+ "location": "/dalle2_pytorch/trackers.py:532-551"
+ },
+ "511": {
+ "file_id": 14,
+ "content": "This code checks the type of trainer and removes CLIP from the model if it is part of it. It then saves the state dictionary for the model, and optionally swaps EMA unets in or out depending on the use_ema flag. Finally, it restores the original CLIP state.",
+ "type": "comment"
+ },
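+ A generic illustration (not the library's own code; names are made up) of the pattern described above: temporarily detach a large frozen submodule such as CLIP before calling state_dict(), so the checkpoint only contains the trainable weights, then restore it afterwards:
+
+     import torch
+     from torch import nn
+
+     class ToyDecoder(nn.Module):
+         def __init__(self):
+             super().__init__()
+             self.clip = nn.Linear(4, 4)    # stand-in for a large frozen CLIP model
+             self.unet = nn.Linear(4, 4)
+
+     model = ToyDecoder()
+     original_clip = model.clip
+     model.clip = None                      # nn.Module allows nulling out a registered submodule
+     state = model.state_dict()             # only 'unet.*' keys remain
+     model.clip = original_clip             # restore after serialization
+     torch.save(state, '/tmp/model_only.pth')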
+ "512": {
+ "file_id": 14,
+ "content": " state_dict = {\n **self.save_metadata,\n 'model': model_state_dict\n }\n torch.save(state_dict, file_path)\n return Path(file_path)\n def save(self, trainer, is_best: bool, is_latest: bool, **kwargs):\n if self.dummy_mode:\n return\n if not is_best and not is_latest:\n # Nothing to do\n return\n # Save the checkpoint and model to data_path\n checkpoint_path = self.data_path / 'checkpoint.pth'\n self._save_state_dict(trainer, 'checkpoint', checkpoint_path, **kwargs)\n model_path = self.data_path / 'model.pth'\n self._save_state_dict(trainer, 'model', model_path, **kwargs)\n print(\"Saved cached models\")\n # Call the save methods on the savers\n for saver in self.savers:\n local_path = checkpoint_path if saver.save_type == 'checkpoint' else model_path\n if saver.saving_latest and is_latest:\n latest_checkpoint_path = saver.save_latest_to.format(**kwargs)",
+ "type": "code",
+ "location": "/dalle2_pytorch/trackers.py:552-575"
+ },
+ "513": {
+ "file_id": 14,
+ "content": "This code saves the model and checkpoint to specified file paths. If not in dummy mode, it checks if the 'is_best' or 'is_latest' flag is set before proceeding with saving the state dictionary for 'checkpoint' and 'model'. It then prints a message confirming the saved cached models. Lastly, it calls save methods on savers, considering the 'saving_latest' flag and appropriate file paths.",
+ "type": "comment"
+ },
+ "514": {
+ "file_id": 14,
+ "content": " try:\n saver.save_file(local_path, latest_checkpoint_path, is_latest=True, **kwargs)\n except Exception as e:\n self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)\n print(f'Error saving checkpoint: {e}')\n if saver.saving_best and is_best:\n best_checkpoint_path = saver.save_best_to.format(**kwargs)\n try:\n saver.save_file(local_path, best_checkpoint_path, is_best=True, **kwargs)\n except Exception as e:\n self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)\n print(f'Error saving checkpoint: {e}')\n @property\n def can_recall(self):\n # Defines whether a recall can be performed.\n return self.loader is not None and (not self.loader.only_auto_resume or self.did_auto_resume)\n def recall(self):\n if self.can_recall:\n return self.loader.recall()\n else:\n",
+ "type": "code",
+ "location": "/dalle2_pytorch/trackers.py:576-598"
+ },
+ "515": {
+ "file_id": 14,
+ "content": "This code appears to be part of a class that manages loading and saving checkpoints for a model. It has a property called \"can_recall\" which determines if a recall (loading a previously saved checkpoint) can be performed based on whether the loader is not None and certain conditions about the loader's properties. If a recall is possible, the \"recall()\" function is called to perform the actual recall. Any errors that occur during saving are logged and printed.",
+ "type": "comment"
+ },
+ "516": {
+ "file_id": 14,
+ "content": " raise ValueError('Tried to recall, but no loader was set or auto-resume was not performed.')",
+ "type": "code",
+ "location": "/dalle2_pytorch/trackers.py:598-598"
+ },
+ "517": {
+ "file_id": 14,
+ "content": "Raises an error when no loader is set and auto-resume was not performed.",
+ "type": "comment"
+ },
+ "518": {
+ "file_id": 15,
+ "content": "/dalle2_pytorch/train_configs.py",
+ "type": "filepath"
+ },
+ "519": {
+ "file_id": 15,
+ "content": "The code sets up DALL-E 2 PyTorch training configurations, provides utility functions and tracker configuration, defines a class for model training/evaluation, and suggests potential efficiency improvements.",
+ "type": "summary"
+ },
+ "520": {
+ "file_id": 15,
+ "content": "import json\nfrom torchvision import transforms as T\nfrom pydantic import BaseModel, validator, model_validator\nfrom typing import List, Optional, Union, Tuple, Dict, Any, TypeVar\nfrom x_clip import CLIP as XCLIP\nfrom open_clip import list_pretrained\nfrom coca_pytorch import CoCa\nfrom dalle2_pytorch.dalle2_pytorch import (\n CoCaAdapter,\n OpenAIClipAdapter,\n OpenClipAdapter,\n Unet,\n Decoder,\n DiffusionPrior,\n DiffusionPriorNetwork,\n XClipAdapter\n)\nfrom dalle2_pytorch.trackers import Tracker, create_loader, create_logger, create_saver\n# helper functions\ndef exists(val):\n return val is not None\ndef default(val, d):\n return val if exists(val) else d\nInnerType = TypeVar('InnerType')\nListOrTuple = Union[List[InnerType], Tuple[InnerType]]\nSingularOrIterable = Union[InnerType, ListOrTuple[InnerType]]\n# general pydantic classes\nclass TrainSplitConfig(BaseModel):\n train: float = 0.75\n val: float = 0.15\n test: float = 0.1\n @model_validator(mode = 'after')\n def validate_all(self, m):\n actual_sum = sum([*dict(self).values()])",
+ "type": "code",
+ "location": "/dalle2_pytorch/train_configs.py:1-43"
+ },
+ "521": {
+ "file_id": 15,
+ "content": "This code is defining various classes and functions for training configurations in a machine learning application, specifically related to the DALL-E 2 PyTorch model. It includes importing necessary modules, setting up pydantic models for train splits, and creating utility functions like `default` and `exists`.",
+ "type": "comment"
+ },
+ "522": {
+ "file_id": 15,
+ "content": " if actual_sum != 1.:\n raise ValueError(f'{dict(self).keys()} must sum to 1.0. Found: {actual_sum}')\n return self\nclass TrackerLogConfig(BaseModel):\n log_type: str = 'console'\n resume: bool = False # For logs that are saved to unique locations, resume a previous run\n auto_resume: bool = False # If the process crashes and restarts, resume from the run that crashed\n verbose: bool = False\n class Config:\n # Each individual log type has it's own arguments that will be passed through the config\n extra = \"allow\"\n def create(self, data_path: str):\n kwargs = self.dict()\n return create_logger(self.log_type, data_path, **kwargs)\nclass TrackerLoadConfig(BaseModel):\n load_from: Optional[str] = None\n only_auto_resume: bool = False # Only attempt to load if the logger is auto-resuming\n class Config:\n extra = \"allow\"\n def create(self, data_path: str):\n kwargs = self.dict()\n if self.load_from is None:\n return None",
+ "type": "code",
+ "location": "/dalle2_pytorch/train_configs.py:44-73"
+ },
+ "523": {
+ "file_id": 15,
+ "content": "The code defines two classes, `TrackerLogConfig` and `TrackerLoadConfig`, which inherit from `BaseModel`. These classes have various attributes such as `log_type`, `resume`, `auto_resume`, and `verbose`. They also have a method called `create` that takes in a `data_path` parameter and returns a logger object. The classes ensure their attributes sum up to 1.0, and allow additional arguments for each individual log type. The `TrackerLoadConfig` class has an optional attribute `load_from`, which determines if the logger should load from a previous run. If `load_from` is set to `None`, it returns None instead of loading.",
+ "type": "comment"
+ },
+ "524": {
+ "file_id": 15,
+ "content": " return create_loader(self.load_from, data_path, **kwargs)\nclass TrackerSaveConfig(BaseModel):\n save_to: str = 'local'\n save_all: bool = False\n save_latest: bool = True\n save_best: bool = True\n class Config:\n extra = \"allow\"\n def create(self, data_path: str):\n kwargs = self.dict()\n return create_saver(self.save_to, data_path, **kwargs)\nclass TrackerConfig(BaseModel):\n data_path: str = '.tracker_data'\n overwrite_data_path: bool = False\n log: TrackerLogConfig\n load: Optional[TrackerLoadConfig] = None\n save: Union[List[TrackerSaveConfig], TrackerSaveConfig]\n def create(self, full_config: BaseModel, extra_config: dict, dummy_mode: bool = False) -> Tracker:\n tracker = Tracker(self.data_path, dummy_mode=dummy_mode, overwrite_data_path=self.overwrite_data_path)\n # Add the logger\n tracker.add_logger(self.log.create(self.data_path))\n # Add the loader\n if self.load is not None:\n tracker.add_loader(self.load.create(self.data_path))",
+ "type": "code",
+ "location": "/dalle2_pytorch/train_configs.py:74-102"
+ },
+ "525": {
+ "file_id": 15,
+ "content": "This code defines classes for tracker configuration and load/save operations. The TrackerConfig class contains information about the data path, overwrite option, logger settings, and optional load configurations. The create method of TrackerConfig initializes a new Tracker object and adds a logger if present in the configuration. If there is a defined load configuration, it also adds a loader to the tracker.",
+ "type": "comment"
+ },
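+ A minimal sketch (values are hypothetical) of building a TrackerConfig from a plain dict; pydantic parses the nested log and save sections into TrackerLogConfig and TrackerSaveConfig:
+
+     from dalle2_pytorch.train_configs import TrackerConfig
+
+     tracker_config = TrackerConfig(**{
+         'data_path': '.tracker_data/example',
+         'overwrite_data_path': True,
+         'log': {'log_type': 'console'},
+         'save': {'save_to': 'local'},      # a list of save sections is also accepted
+     })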
+ "526": {
+ "file_id": 15,
+ "content": " # Add the saver or savers\n if isinstance(self.save, list):\n for save_config in self.save:\n tracker.add_saver(save_config.create(self.data_path))\n else:\n tracker.add_saver(self.save.create(self.data_path))\n # Initialize all the components and verify that all data is valid\n tracker.init(full_config, extra_config)\n return tracker\n# diffusion prior pydantic classes\nclass AdapterConfig(BaseModel):\n make: str = \"openai\"\n model: str = \"ViT-L/14\"\n base_model_kwargs: Optional[Dict[str, Any]] = None\n def create(self):\n if self.make == \"openai\":\n return OpenAIClipAdapter(self.model)\n elif self.make == \"open_clip\":\n pretrained = dict(list_pretrained())\n checkpoint = pretrained[self.model]\n return OpenClipAdapter(name=self.model, pretrained=checkpoint)\n elif self.make == \"x-clip\":\n return XClipAdapter(XCLIP(**self.base_model_kwargs))\n elif self.make == \"coca\":",
+ "type": "code",
+ "location": "/dalle2_pytorch/train_configs.py:103-129"
+ },
+ "527": {
+ "file_id": 15,
+ "content": "This code defines a function that initializes and returns a tracker object, which is responsible for managing savers and components of the model. It also includes classes for different types of adapters used in the model. The tracker object verifies data validity after initialization.",
+ "type": "comment"
+ },
+ "528": {
+ "file_id": 15,
+ "content": " return CoCaAdapter(CoCa(**self.base_model_kwargs))\n else:\n raise AttributeError(\"No adapter with that name is available.\")\nclass DiffusionPriorNetworkConfig(BaseModel):\n dim: int\n depth: int\n max_text_len: Optional[int] = None\n num_timesteps: Optional[int] = None\n num_time_embeds: int = 1\n num_image_embeds: int = 1\n num_text_embeds: int = 1\n dim_head: int = 64\n heads: int = 8\n ff_mult: int = 4\n norm_in: bool = False\n norm_out: bool = True\n attn_dropout: float = 0.\n ff_dropout: float = 0.\n final_proj: bool = True\n normformer: bool = False\n rotary_emb: bool = True\n class Config:\n extra = \"allow\"\n def create(self):\n kwargs = self.dict()\n return DiffusionPriorNetwork(**kwargs)\nclass DiffusionPriorConfig(BaseModel):\n clip: Optional[AdapterConfig] = None\n net: DiffusionPriorNetworkConfig\n image_embed_dim: int\n image_size: int\n image_channels: int = 3\n timesteps: int = 1000\n sample_timesteps: Optional[int] = None",
+ "type": "code",
+ "location": "/dalle2_pytorch/train_configs.py:130-167"
+ },
+ "529": {
+ "file_id": 15,
+ "content": "This code defines configurations for a neural network model. It includes classes for adapters, diffusion prior networks, and diffusion prior models. The adapter class takes in base_model_kwargs and returns an instance of either CoCaAdapter or raises AttributeError if no matching adapter found. DiffusionPriorNetworkConfig defines the architecture specifications like dimensions, depth, and dropout rates. DiffusionPriorConfig handles configurations for clip adapters, diffusion prior networks, image embedding dimensions, image size, and number of timesteps. The create() function returns an instance of the model based on its configuration.",
+ "type": "comment"
+ },
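+ A short sketch of the adapter dispatch (the defaults shown above are make='openai', model='ViT-L/14'); note that create() loads the underlying CLIP weights, so this is illustrative rather than something to run in a unit test:
+
+     from dalle2_pytorch.train_configs import AdapterConfig
+
+     clip_adapter = AdapterConfig(make='openai', model='ViT-L/14').create()   # OpenAIClipAdapter
+
+     try:
+         AdapterConfig(make='does-not-exist').create()
+     except AttributeError as err:
+         print(err)                          # No adapter with that name is available.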
+ "530": {
+ "file_id": 15,
+ "content": " cond_drop_prob: float = 0.\n loss_type: str = 'l2'\n predict_x_start: bool = True\n beta_schedule: str = 'cosine'\n condition_on_text_encodings: bool = True\n class Config:\n extra = \"allow\"\n def create(self):\n kwargs = self.dict()\n has_clip = exists(kwargs.pop('clip'))\n kwargs.pop('net')\n clip = None\n if has_clip:\n clip = self.clip.create()\n diffusion_prior_network = self.net.create()\n return DiffusionPrior(net = diffusion_prior_network, clip = clip, **kwargs)\nclass DiffusionPriorTrainConfig(BaseModel):\n epochs: int = 1\n lr: float = 1.1e-4\n wd: float = 6.02e-2\n max_grad_norm: float = 0.5\n use_ema: bool = True\n ema_beta: float = 0.99\n amp: bool = False\n warmup_steps: Optional[int] = None # number of warmup steps\n save_every_seconds: int = 3600 # how often to save\n eval_timesteps: List[int] = [64] # which sampling timesteps to evaluate with\n best_validation_loss: float = 1e9 # the current best valudation loss observed",
+ "type": "code",
+ "location": "/dalle2_pytorch/train_configs.py:168-201"
+ },
+ "531": {
+ "file_id": 15,
+ "content": "The code defines a class for training configurations, including epochs, learning rate, weight decay, and other parameters. It also contains functions to create instances of diffusion prior networks and conditioning models. The class is part of the DALLE2-pytorch framework and is used for training the model.",
+ "type": "comment"
+ },
+ "532": {
+ "file_id": 15,
+ "content": " current_epoch: int = 0 # the current epoch\n num_samples_seen: int = 0 # the current number of samples seen\n random_seed: int = 0 # manual seed for torch\nclass DiffusionPriorDataConfig(BaseModel):\n image_url: str # path to embeddings folder\n meta_url: str # path to metadata (captions) for images\n splits: TrainSplitConfig # define train, validation, test splits for your dataset\n batch_size: int # per-gpu batch size used to train the model\n num_data_points: int = 25e7 # total number of datapoints to train on\n eval_every_seconds: int = 3600 # validation statistics will be performed this often\nclass TrainDiffusionPriorConfig(BaseModel):\n prior: DiffusionPriorConfig\n data: DiffusionPriorDataConfig\n train: DiffusionPriorTrainConfig\n tracker: TrackerConfig\n @classmethod\n def from_json_path(cls, json_path):\n with open(json_path) as f:\n config = json.load(f)",
+ "type": "code",
+ "location": "/dalle2_pytorch/train_configs.py:202-223"
+ },
+ "533": {
+ "file_id": 15,
+ "content": "The code defines a configuration class for training the DiffusionPrior model, which contains details such as the data source, batch size, total number of datapoints to train on, and validation frequency. It also has methods to load configurations from JSON files.",
+ "type": "comment"
+ },
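+ A minimal sketch (the JSON path is hypothetical) of loading a prior training config and instantiating the model from it:
+
+     from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig
+
+     config = TrainDiffusionPriorConfig.from_json_path('configs/train_prior_config.example.json')
+     diffusion_prior = config.prior.create()    # builds the DiffusionPriorNetwork (+ optional clip) and wraps it
+     print(config.data.batch_size, config.train.lr)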
+ "534": {
+ "file_id": 15,
+ "content": " return cls(**config)\n# decoder pydantic classes\nclass UnetConfig(BaseModel):\n dim: int\n dim_mults: ListOrTuple[int]\n image_embed_dim: Optional[int] = None\n text_embed_dim: Optional[int] = None\n cond_on_text_encodings: Optional[bool] = None\n cond_dim: Optional[int] = None\n channels: int = 3\n self_attn: SingularOrIterable[bool] = False\n attn_dim_head: int = 32\n attn_heads: int = 16\n init_cross_embed: bool = True\n class Config:\n extra = \"allow\"\nclass DecoderConfig(BaseModel):\n unets: ListOrTuple[UnetConfig]\n image_size: Optional[int] = None\n image_sizes: ListOrTuple[int] = None\n clip: Optional[AdapterConfig] = None # The clip model to use if embeddings are not provided\n channels: int = 3\n timesteps: int = 1000\n sample_timesteps: Optional[SingularOrIterable[Optional[int]]] = None\n loss_type: str = 'l2'\n beta_schedule: Optional[ListOrTuple[str]] = None # None means all cosine\n learned_variance: SingularOrIterable[bool] = True\n image_cond_drop_prob: float = 0.1",
+ "type": "code",
+ "location": "/dalle2_pytorch/train_configs.py:224-255"
+ },
+ "535": {
+ "file_id": 15,
+ "content": "The code defines two Pydantic classes, UnetConfig and DecoderConfig, which represent the configurations for the DALL-E 2 model. The UnetConfig class handles the configuration of the UNet transformer in the decoder while the DecoderConfig class includes various settings like the number of UNet blocks, image size, clip model, timesteps, loss type, and more.",
+ "type": "comment"
+ },
+ "536": {
+ "file_id": 15,
+ "content": " text_cond_drop_prob: float = 0.5\n def create(self):\n decoder_kwargs = self.dict()\n unet_configs = decoder_kwargs.pop('unets')\n unets = [Unet(**config) for config in unet_configs]\n has_clip = exists(decoder_kwargs.pop('clip'))\n clip = None\n if has_clip:\n clip = self.clip.create()\n return Decoder(unets, clip=clip, **decoder_kwargs)\n @validator('image_sizes')\n def check_image_sizes(cls, image_sizes, values):\n if exists(values.get('image_size')) ^ exists(image_sizes):\n return image_sizes\n raise ValueError('either image_size or image_sizes is required, but not both')\n class Config:\n extra = \"allow\"\nclass DecoderDataConfig(BaseModel):\n webdataset_base_url: str # path to a webdataset with jpg images\n img_embeddings_url: Optional[str] = None # path to .npy files with embeddings\n text_embeddings_url: Optional[str] = None # path to .npy files with embeddings\n num_workers: int = 4",
+ "type": "code",
+ "location": "/dalle2_pytorch/train_configs.py:256-284"
+ },
+ "537": {
+ "file_id": 15,
+ "content": "This code defines a class \"TrainConfigs\" that creates a decoder for DALL-E 2 training. It uses the Unet architecture, optionally includes CLIP for visual guidance, and allows specifying image sizes through 'image_size' or list of 'image_sizes'. The class also provides configurations for loading data from webdataset with jpg images, embedding files, and setting the number of workers for data loading.",
+ "type": "comment"
+ },
+ "538": {
+ "file_id": 15,
+ "content": " batch_size: int = 64\n start_shard: int = 0\n end_shard: int = 9999999\n shard_width: int = 6\n index_width: int = 4\n splits: TrainSplitConfig\n shuffle_train: bool = True\n resample_train: bool = False\n preprocessing: Dict[str, Any] = {'ToTensor': True}\n @property\n def img_preproc(self):\n def _get_transformation(transformation_name, **kwargs):\n if transformation_name == \"RandomResizedCrop\":\n return T.RandomResizedCrop(**kwargs)\n elif transformation_name == \"RandomHorizontalFlip\":\n return T.RandomHorizontalFlip()\n elif transformation_name == \"ToTensor\":\n return T.ToTensor()\n transforms = []\n for transform_name, transform_kwargs_or_bool in self.preprocessing.items():\n transform_kwargs = {} if not isinstance(transform_kwargs_or_bool, dict) else transform_kwargs_or_bool\n transforms.append(_get_transformation(transform_name, **transform_kwargs))\n return T.Compose(transforms)",
+ "type": "code",
+ "location": "/dalle2_pytorch/train_configs.py:285-309"
+ },
+ "539": {
+ "file_id": 15,
+ "content": "This code defines a training configuration with batch size, sharding settings, transformation preprocessing, and boolean flags for shuffling and resampling. It also includes a property method to generate the image preprocessing transforms based on provided names and optional arguments.",
+ "type": "comment"
+ },
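+ A sketch (URL and sizes are placeholders) of how the preprocessing mapping above becomes a torchvision pipeline through the img_preproc property:
+
+     from dalle2_pytorch.train_configs import DecoderDataConfig
+
+     data_config = DecoderDataConfig(
+         webdataset_base_url='pipe:cat shards/{}.tar',   # placeholder webdataset URL
+         splits={'train': 0.75, 'val': 0.15, 'test': 0.1},
+         preprocessing={
+             'RandomResizedCrop': {'size': 64},
+             'RandomHorizontalFlip': True,
+             'ToTensor': True,
+         },
+     )
+     transform = data_config.img_preproc    # T.Compose([RandomResizedCrop(64), RandomHorizontalFlip(), ToTensor()])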
+ "540": {
+ "file_id": 15,
+ "content": "class DecoderTrainConfig(BaseModel):\n epochs: int = 20\n lr: SingularOrIterable[float] = 1e-4\n wd: SingularOrIterable[float] = 0.01\n warmup_steps: Optional[SingularOrIterable[int]] = None\n find_unused_parameters: bool = True\n static_graph: bool = True\n max_grad_norm: SingularOrIterable[float] = 0.5\n save_every_n_samples: int = 100000\n n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset\n cond_scale: Union[float, List[float]] = 1.0\n device: str = 'cuda:0'\n epoch_samples: Optional[int] = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.\n validation_samples: Optional[int] = None # Same as above but for validation.\n save_immediately: bool = False\n use_ema: bool = True\n ema_beta: float = 0.999\n amp: bool = False\n unet_training_mask: Optional[ListOrTuple[bool]] = None # If None, use all unets",
+ "type": "code",
+ "location": "/dalle2_pytorch/train_configs.py:311-329"
+ },
+ "541": {
+ "file_id": 15,
+ "content": "This code defines a DecoderTrainConfig class with various configuration options for training the decoder model in DALLE2. The class includes settings for epochs, learning rate, weight decay, warmup steps, finding unused parameters, static graph usage, gradient clipping, saving samples, generating example images, scaling conditions, device selection, sample limits per epoch and validation, saving immediately, using exponential moving average (EMA), EMA beta value, using mixed precision training (AMP), and unet training masks.",
+ "type": "comment"
+ },
+ "542": {
+ "file_id": 15,
+ "content": "class DecoderEvaluateConfig(BaseModel):\n n_evaluation_samples: int = 1000\n FID: Optional[Dict[str, Any]] = None\n IS: Optional[Dict[str, Any]] = None\n KID: Optional[Dict[str, Any]] = None\n LPIPS: Optional[Dict[str, Any]] = None\nclass TrainDecoderConfig(BaseModel):\n decoder: DecoderConfig\n data: DecoderDataConfig\n train: DecoderTrainConfig\n evaluate: DecoderEvaluateConfig\n tracker: TrackerConfig\n seed: int = 0\n @classmethod\n def from_json_path(cls, json_path):\n with open(json_path) as f:\n config = json.load(f)\n print(config)\n return cls(**config)\n @model_validator(mode = 'after')\n def check_has_embeddings(self, m):\n # Makes sure that enough information is provided to get the embeddings specified for training\n values = dict(self)\n data_config, decoder_config = values.get('data'), values.get('decoder')\n if not exists(data_config) or not exists(decoder_config):\n # Then something else errored and we should just pass through",
+ "type": "code",
+ "location": "/dalle2_pytorch/train_configs.py:331-361"
+ },
+ "543": {
+ "file_id": 15,
+ "content": "This code defines two classes, \"DecoderEvaluateConfig\" and \"TrainDecoderConfig\", which inherit from the \"BaseModel\" class. The \"DecoderEvaluateConfig\" class specifies evaluation metrics like FID, IS, KID, and LPIPS, while the \"TrainDecoderConfig\" class combines various configuration elements including a decoder, data, training settings, evaluation settings, tracker, and seed. The \"from_json_path\" method loads configuration from a JSON file, and the \"check_has_embeddings\" validator ensures that enough information is provided to get the embeddings for training.",
+ "type": "comment"
+ },
+ "544": {
+ "file_id": 15,
+ "content": " return values\n using_text_embeddings = any([unet.cond_on_text_encodings for unet in decoder_config.unets])\n using_clip = exists(decoder_config.clip)\n img_emb_url = data_config.img_embeddings_url\n text_emb_url = data_config.text_embeddings_url\n if using_text_embeddings:\n # Then we need some way to get the embeddings\n assert using_clip or exists(text_emb_url), 'If text conditioning, either clip or text_embeddings_url must be provided'\n if using_clip:\n if using_text_embeddings:\n assert not exists(text_emb_url) or not exists(img_emb_url), 'Loaded clip, but also provided text_embeddings_url and img_embeddings_url. This is redundant. Remove the clip model or the text embeddings'\n else:\n assert not exists(img_emb_url), 'Loaded clip, but also provided img_embeddings_url. This is redundant. Remove the clip model or the embeddings'\n if text_emb_url:\n assert using_te",
+ "type": "code",
+ "location": "/dalle2_pytorch/train_configs.py:362-380"
+ },
+ "545": {
+ "file_id": 15,
+ "content": "This code checks if the text embeddings and/or CLIP model are being used, ensuring that only one of these is provided to avoid redundancy. It asserts that either the CLIP or text embeddings URL must be present if text conditioning is enabled, and if only the CLIP model is loaded, it asserts that neither the text embeddings nor image embeddings URL should be provided.",
+ "type": "comment"
+ },
+ "546": {
+ "file_id": 15,
+ "content": "xt_embeddings, \"Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason.\"\n return m",
+ "type": "code",
+ "location": "/dalle2_pytorch/train_configs.py:380-382"
+ },
+ "547": {
+ "file_id": 15,
+ "content": "This code snippet indicates that text embeddings are being loaded but are not necessary for the task, causing unnecessary slowdown in the dataloader. It is recommended to remove this step for efficiency.",
+ "type": "comment"
+ },
+ "548": {
+ "file_id": 16,
+ "content": "/dalle2_pytorch/trainer.py",
+ "type": "filepath"
+ },
+ "549": {
+ "file_id": 16,
+ "content": "The code initializes DeepSpeed's trainer, sets model parameters, distributes the model, and handles precision. It also initializes optimizers and schedulers, prepares dataloaders, validates compatibility, performs computations, and returns total loss.",
+ "type": "summary"
+ },
+ "550": {
+ "file_id": 16,
+ "content": "import time\nimport copy\nfrom pathlib import Path\nfrom math import ceil\nfrom functools import partial, wraps\nfrom contextlib import nullcontext\nfrom collections.abc import Iterable\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR\nfrom torch.cuda.amp import autocast, GradScaler\nfrom dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior\nfrom dalle2_pytorch.optimizer import get_optimizer\nfrom dalle2_pytorch.version import __version__\nfrom packaging import version\nimport pytorch_warmup as warmup\nfrom ema_pytorch import EMA\nfrom accelerate import Accelerator, DistributedType\nimport numpy as np\n# helper functions\ndef exists(val):\n return val is not None\ndef default(val, d):\n if exists(val):\n return val\n return d() if callable(d) else d\ndef cast_tuple(val, length = 1):\n return val if isinstance(val, tuple) else ((val,) * length)\ndef pick_and_pop(keys, d):\n values = list(map(lambda key: d.pop(key), keys))\n return dict(zip(keys, values))",
+ "type": "code",
+ "location": "/dalle2_pytorch/trainer.py:1-43"
+ },
+ "551": {
+ "file_id": 16,
+ "content": "The code imports various libraries and defines several utility functions for working with tensors, learning rates, optimizers, and distributed training. It also includes helper functions to handle default values and handle dictionaries. These utilities are likely used throughout the codebase to train and evaluate models efficiently.",
+ "type": "comment"
+ },
+ "552": {
+ "file_id": 16,
+ "content": "def group_dict_by_key(cond, d):\n return_val = [dict(),dict()]\n for key in d.keys():\n match = bool(cond(key))\n ind = int(not match)\n return_val[ind][key] = d[key]\n return (*return_val,)\ndef string_begins_with(prefix, str):\n return str.startswith(prefix)\ndef group_by_key_prefix(prefix, d):\n return group_dict_by_key(partial(string_begins_with, prefix), d)\ndef groupby_prefix_and_trim(prefix, d):\n kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)\n kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))\n return kwargs_without_prefix, kwargs\ndef num_to_groups(num, divisor):\n groups = num // divisor\n remainder = num % divisor\n arr = [divisor] * groups\n if remainder > 0:\n arr.append(remainder)\n return arr\n# decorators\ndef cast_torch_tensor(fn):\n @wraps(fn)\n def inner(model, *args, **kwargs):\n device = kwargs.pop('_device', next(model.parameters()).device)\n cast_device = kwargs.pop('_cast_device', True)",
+ "type": "code",
+ "location": "/dalle2_pytorch/trainer.py:45-78"
+ },
+ "553": {
+ "file_id": 16,
+ "content": "group_dict_by_key: Creates two dictionaries, one for keys that match the condition and another for those that do not, grouping by key.\nstring_begins_with: Returns a boolean value indicating whether a given string starts with a specified prefix.\ngroup_by_key_prefix: Groups dictionary items based on whether their keys start with a certain prefix.\ngroupby_prefix_and_trim: Similar to group_by_key_prefix, but also trims the common prefix from the keys and returns two dictionaries.\nnum_to_groups: Divides a given number into groups based on a specified divisor, appending any remainder to the last group.\ncast_torch_tensor: A decorator that wraps a function to cast its input and output tensors to specific devices.",
+ "type": "comment"
+ },
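+ A few concrete calls (illustrative only) for the helpers described above:
+
+     from dalle2_pytorch.trainer import groupby_prefix_and_trim, num_to_groups
+
+     ema_kwargs, rest = groupby_prefix_and_trim('ema_', {'ema_beta': 0.99, 'lr': 1e-4})
+     print(ema_kwargs, rest)        # {'beta': 0.99} {'lr': 0.0001}
+
+     print(num_to_groups(10, 4))    # [4, 4, 2]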
+ "554": {
+ "file_id": 16,
+ "content": " cast_deepspeed_precision = kwargs.pop('_cast_deepspeed_precision', True)\n kwargs_keys = kwargs.keys()\n all_args = (*args, *kwargs.values())\n split_kwargs_index = len(all_args) - len(kwargs_keys)\n all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))\n if cast_device:\n all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))\n if cast_deepspeed_precision:\n try:\n accelerator = model.accelerator\n if accelerator is not None and accelerator.distributed_type == DistributedType.DEEPSPEED:\n cast_type_map = {\n \"fp16\": torch.half,\n \"bf16\": torch.bfloat16,\n \"no\": torch.float\n }\n precision_type = cast_type_map[accelerator.mixed_precision]\n all_args = tuple(map(lambda t: t.to(precision_type) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))",
+ "type": "code",
+ "location": "/dalle2_pytorch/trainer.py:79-99"
+ },
+ "555": {
+ "file_id": 16,
+ "content": "This code handles argument casting and device assignment for a DeepSpeed-accelerated PyTorch model. It first checks if arguments are DeepSpeed precision types, then casts the tensors to the appropriate type if necessary. This ensures that the model's arguments are correctly prepared for training or evaluation within a DeepSpeed framework.",
+ "type": "comment"
+ },
+ "556": {
+ "file_id": 16,
+ "content": " except AttributeError:\n # Then this model doesn't have an accelerator\n pass\n args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]\n kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))\n out = fn(model, *args, **kwargs)\n return out\n return inner\n# gradient accumulation functions\ndef split_iterable(it, split_size):\n accum = []\n for ind in range(ceil(len(it) / split_size)):\n start_index = ind * split_size\n accum.append(it[start_index: (start_index + split_size)])\n return accum\ndef split(t, split_size = None):\n if not exists(split_size):\n return t\n if isinstance(t, torch.Tensor):\n return t.split(split_size, dim = 0)\n if isinstance(t, Iterable):\n return split_iterable(t, split_size)\n return TypeError\ndef find_first(cond, arr):\n for el in arr:\n if cond(el):\n return el\n return None\ndef split_args_and_kwargs(*args, split_size = None, **kwargs):",
+ "type": "code",
+ "location": "/dalle2_pytorch/trainer.py:100-138"
+ },
+ "557": {
+ "file_id": 16,
+ "content": "This code defines functions for splitting arguments and keywords, as well as handling gradient accumulation. It includes a function to split an iterable into chunks of specified size (`split_iterable`), a `split` function for tensors and iterables, and a `find_first` function to find the first item in an array that meets a given condition. The last function defined is `split_args_and_kwargs`, which splits arguments and keywords based on a specified size.",
+ "type": "comment"
+ },
+ "558": {
+ "file_id": 16,
+ "content": " all_args = (*args, *kwargs.values())\n len_all_args = len(all_args)\n first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)\n assert exists(first_tensor)\n batch_size = len(first_tensor)\n split_size = default(split_size, batch_size)\n num_chunks = ceil(batch_size / split_size)\n dict_len = len(kwargs)\n dict_keys = kwargs.keys()\n split_kwargs_index = len_all_args - dict_len\n split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args]\n chunk_sizes = tuple(map(len, split_all_args[0]))\n for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):\n chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:]\n chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))\n chunk_size_frac = chunk_size / batch_size\n yield chunk_size_frac, (chunked_args, chunked_kwargs)",
+ "type": "code",
+ "location": "/dalle2_pytorch/trainer.py:139-159"
+ },
+ "559": {
+ "file_id": 16,
+ "content": "This code splits the input arguments and keyword arguments into chunks based on batch size, split size, and dictionary keys. It then yields the chunk size fraction and the split chunked arguments and keyword arguments for further processing.",
+ "type": "comment"
+ },
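+ A small sketch (shapes are illustrative) of the chunking generator described above:
+
+     import torch
+     from dalle2_pytorch.trainer import split_args_and_kwargs
+
+     images = torch.randn(10, 3, 64, 64)
+     for frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(images, split_size=4, text=None):
+         (image_chunk,) = chunked_args
+         print(frac, image_chunk.shape, chunked_kwargs)   # fractions 0.4, 0.4, 0.2 with chunk sizes 4, 4, 2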
+ "560": {
+ "file_id": 16,
+ "content": "# diffusion prior trainer\ndef prior_sample_in_chunks(fn):\n @wraps(fn)\n def inner(self, *args, max_batch_size = None, **kwargs):\n if not exists(max_batch_size):\n return fn(self, *args, **kwargs)\n outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]\n return torch.cat(outputs, dim = 0)\n return inner\nclass DiffusionPriorTrainer(nn.Module):\n def __init__(\n self,\n diffusion_prior,\n accelerator = None,\n use_ema = True,\n lr = 3e-4,\n wd = 1e-2,\n eps = 1e-6,\n max_grad_norm = None,\n group_wd_params = True,\n warmup_steps = None,\n cosine_decay_max_steps = None,\n **kwargs\n ):\n super().__init__()\n assert isinstance(diffusion_prior, DiffusionPrior)\n ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)\n accelerator_kwargs, kwargs = groupby_prefix_and_trim('accelerator_', kwargs)",
+ "type": "code",
+ "location": "/dalle2_pytorch/trainer.py:161-192"
+ },
+ "561": {
+ "file_id": 16,
+ "content": "This code defines a `DiffusionPriorTrainer` class that takes in a `diffusion_prior`, and allows for training with different batch sizes by splitting arguments and keywords into chunks. It also supports optional accelerator, learning rate, weight decay, epsilon, max gradient norm, grouped weight decay parameters, warmup steps, and cosine decay maximum steps.",
+ "type": "comment"
+ },
+ "562": {
+ "file_id": 16,
+ "content": " if not exists(accelerator):\n accelerator = Accelerator(**accelerator_kwargs)\n # assign some helpful member vars\n self.accelerator = accelerator\n self.text_conditioned = diffusion_prior.condition_on_text_encodings\n # setting the device\n self.device = accelerator.device\n diffusion_prior.to(self.device)\n # save model\n self.diffusion_prior = diffusion_prior\n # mixed precision checks\n if (\n exists(self.accelerator) \n and self.accelerator.distributed_type == DistributedType.DEEPSPEED \n and self.diffusion_prior.clip is not None\n ):\n # Then we need to make sure clip is using the correct precision or else deepspeed will error\n cast_type_map = {\n \"fp16\": torch.half,\n \"bf16\": torch.bfloat16,\n \"no\": torch.float\n }\n precision_type = cast_type_map[accelerator.mixed_precision]\n assert precision",
+ "type": "code",
+ "location": "/dalle2_pytorch/trainer.py:194-225"
+ },
+ "563": {
+ "file_id": 16,
+ "content": "Checking if an accelerator is specified, assigning member variables for helpful operations, setting device and transferring model to that device, saving the diffusion prior model, and checking mixed precision settings if applicable.",
+ "type": "comment"
+ },
+ "564": {
+ "file_id": 16,
+ "content": "_type == torch.float, \"DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip\"\n self.diffusion_prior.clip.to(precision_type)\n # optimizer stuff\n self.optim_kwargs = dict(lr=lr, wd=wd, eps=eps, group_wd_params=group_wd_params)\n self.optimizer = get_optimizer(\n self.diffusion_prior.parameters(),\n **self.optim_kwargs,\n **kwargs\n )\n if exists(cosine_decay_max_steps):\n self.scheduler = CosineAnnealingLR(self.optimizer, T_max = cosine_decay_max_steps)\n else:\n self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)\n self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None\n # distribute the model if using HFA\n self.diffusion_prior, self.optimizer, self.scheduler = self.accelerator.prepare(self.diffusion_prior, self.optimizer, self.scheduler)\n # exponential moving average stuff",
+ "type": "code",
+ "location": "/dalle2_pytorch/trainer.py:225-249"
+ },
+ "565": {
+ "file_id": 16,
+ "content": "This code initializes the trainer for DeepSpeed, setting precision, optimizer, and scheduler. It checks if on-the-fly embedding generation from CLIP is supported and changes precision accordingly. It also distributes the model using HFA and applies exponential moving average techniques.",
+ "type": "comment"
+ },
+ "566": {
+ "file_id": 16,
+ "content": " self.use_ema = use_ema\n if self.use_ema:\n self.ema_diffusion_prior = EMA(self.accelerator.unwrap_model(self.diffusion_prior), **ema_kwargs)\n # gradient clipping if needed\n self.max_grad_norm = max_grad_norm\n # track steps internally\n self.register_buffer('step', torch.tensor([0], device = self.device))\n # utility\n def save(self, path, overwrite = True, **kwargs):\n # only save on the main process\n if self.accelerator.is_main_process:\n print(f\"Saving checkpoint at step: {self.step.item()}\")\n path = Path(path)\n assert not (path.exists() and not overwrite)\n path.parent.mkdir(parents = True, exist_ok = True)\n # FIXME: LambdaLR can't be saved due to pickling issues\n save_obj = dict(\n optimizer = self.optimizer.state_dict(),\n scheduler = self.scheduler.state_dict(),\n warmup_scheduler = self.warmup_scheduler,\n model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(),",
+ "type": "code",
+ "location": "/dalle2_pytorch/trainer.py:251-280"
+ },
+ "567": {
+ "file_id": 16,
+ "content": "The code snippet initializes a trainer object with an option for exponential moving average (EMA), gradient clipping, and tracks steps internally. It also defines a save method to save the optimizer, scheduler, model state dictionaries, and warmup scheduler on the main process. Note that LambdaLR cannot be saved due to pickling issues.",
+ "type": "comment"
+ },
+ "568": {
+ "file_id": 16,
+ "content": " version = version.parse(__version__),\n step = self.step,\n **kwargs\n )\n if self.use_ema:\n save_obj = {\n **save_obj,\n 'ema': self.ema_diffusion_prior.state_dict(),\n 'ema_model': self.ema_diffusion_prior.ema_model.state_dict() # save the ema model specifically for easy ema-only reload\n }\n torch.save(save_obj, str(path))\n def load(self, path_or_state, overwrite_lr = True, strict = True):\n \"\"\"\n Load a checkpoint of a diffusion prior trainer.\n Will load the entire trainer, including the optimizer and EMA.\n Params:\n - path_or_state (str | torch): a path to the DiffusionPriorTrainer checkpoint file\n - overwrite_lr (bool): wether or not to overwrite the stored LR with the LR specified in the new trainer\n - strict (bool): kwarg for `torch.nn.Module.load_state_dict`, will force an exact checkpoint match",
+ "type": "code",
+ "location": "/dalle2_pytorch/trainer.py:281-304"
+ },
+ "569": {
+ "file_id": 16,
+ "content": "This code saves and loads a checkpoint for a diffusion prior trainer. It also handles saving the EMA (Exponential Moving Average) model separately for easy ema-only reload, and allows overwriting the learning rate if needed. The `load` method loads an entire trainer, including its optimizer and EMA.",
+ "type": "comment"
+ },
+ "570": {
+ "file_id": 16,
+ "content": " Returns:\n loaded_obj (dict): The loaded checkpoint dictionary\n \"\"\"\n # all processes need to load checkpoint. no restriction here\n if isinstance(path_or_state, str):\n path = Path(path_or_state)\n assert path.exists()\n loaded_obj = torch.load(str(path), map_location=self.device)\n elif isinstance(path_or_state, dict):\n loaded_obj = path_or_state\n if version.parse(__version__) != loaded_obj['version']:\n print(f'loading saved diffusion prior at version {loaded_obj[\"version\"]} but current package version is at {__version__}')\n # unwrap the model when loading from checkpoint\n self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)\n self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device))\n self.optimizer.load_state_dict(loaded_obj['optimizer'])\n self.scheduler.load_state_dict(loaded_obj['scheduler'])",
+ "type": "code",
+ "location": "/dalle2_pytorch/trainer.py:306-327"
+ },
+ "571": {
+ "file_id": 16,
+ "content": "This function loads a checkpoint from a specified path or dictionary, handling both string paths and existing dictionaries. It checks if the loaded version matches the current package version, then unwraps and loads the model's state dict, sets step values, and loads optimizer and scheduler states as well.",
+ "type": "comment"
+ },
+ "572": {
+ "file_id": 16,
+ "content": " # set warmupstep\n if exists(self.warmup_scheduler):\n self.warmup_scheduler.last_step = self.step.item()\n # ensure new lr is used if different from old one\n if overwrite_lr:\n new_lr = self.optim_kwargs[\"lr\"]\n for group in self.optimizer.param_groups:\n group[\"lr\"] = new_lr if group[\"lr\"] > 0.0 else 0.0\n if self.use_ema:\n assert 'ema' in loaded_obj\n self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)\n # below might not be necessary, but I had a suspicion that this wasn't being loaded correctly\n self.ema_diffusion_prior.ema_model.load_state_dict(loaded_obj[\"ema_model\"])\n return loaded_obj\n # model functionality\n def update(self):\n if exists(self.max_grad_norm):\n self.accelerator.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)\n self.optimizer.step()\n self.optimizer.zero_grad()\n # accelerator will ocassionally skip optimizer steps in a \"dynamic loss scaling strategy\"",
+ "type": "code",
+ "location": "/dalle2_pytorch/trainer.py:329-358"
+ },
+ "573": {
+ "file_id": 16,
+ "content": "This function handles the warmup step, updating the learning rate if needed, loading EMA diffusion prior state from a checkpoint, and performing model update with optimization.",
+ "type": "comment"
+ },
+ "574": {
+ "file_id": 16,
+ "content": " if not self.accelerator.optimizer_step_was_skipped:\n sched_context = self.warmup_scheduler.dampening if exists(self.warmup_scheduler) else nullcontext\n with sched_context():\n self.scheduler.step()\n if self.use_ema:\n self.ema_diffusion_prior.update()\n self.step += 1\n @torch.no_grad()\n @cast_torch_tensor\n @prior_sample_in_chunks\n def p_sample_loop(self, *args, **kwargs):\n model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior\n return model.p_sample_loop(*args, **kwargs)\n @torch.no_grad()\n @cast_torch_tensor\n @prior_sample_in_chunks\n def sample(self, *args, **kwargs):\n model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior\n return model.sample(*args, **kwargs)\n @torch.no_grad()\n def sample_batch_size(self, *args, **kwargs):\n model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior\n return model.sample_batch_size(*args, **kwargs)",
+ "type": "code",
+ "location": "/dalle2_pytorch/trainer.py:359-386"
+ },
+ "575": {
+ "file_id": 16,
+ "content": "The code defines several methods for using the diffusion prior model to generate samples. It uses exponential moving average (EMA) for model averaging, if `use_ema` is enabled. The `p_sample_loop`, `sample`, and `sample_batch_size` methods use `torch.no_grad()` for performance optimization, and `cast_torch_tensor` and `prior_sample_in_chunks` decorators are used to process data in chunks.",
+ "type": "comment"
+ },
+ "576": {
+ "file_id": 16,
+ "content": " @torch.no_grad()\n @cast_torch_tensor\n @prior_sample_in_chunks\n def embed_text(self, *args, **kwargs):\n return self.accelerator.unwrap_model(self.diffusion_prior).clip.embed_text(*args, **kwargs)\n @cast_torch_tensor\n def forward(\n self,\n *args,\n max_batch_size = None,\n **kwargs\n ):\n total_loss = 0.\n for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):\n with self.accelerator.autocast():\n loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)\n loss = loss * chunk_size_frac\n total_loss += loss.item()\n if self.training:\n self.accelerator.backward(loss)\n return total_loss\n# decoder trainer\ndef decoder_sample_in_chunks(fn):\n @wraps(fn)\n def inner(self, *args, max_batch_size = None, **kwargs):\n if not exists(max_batch_size):\n return fn(self, *args, **kwargs)\n if self.decoder.unconditional:",
+ "type": "code",
+ "location": "/dalle2_pytorch/trainer.py:388-423"
+ },
+ "577": {
+ "file_id": 16,
+ "content": "This code defines a trainer with a function `embed_text` that uses the unwrapped model for embedding text, and a `forward` method that performs forward pass in chunks to handle large batch sizes. The `decoder_sample_in_chunks` decorator enables chunking when sample decoding.",
+ "type": "comment"
+ },
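+ A generic sketch (not the trainer itself) of the chunked-forward pattern: each sub-batch loss is weighted by its fraction of the full batch, so the accumulated gradients match a single full-batch step:
+
+     import torch
+     import torch.nn.functional as F
+     from torch import nn
+
+     model = nn.Linear(8, 1)
+     x, y = torch.randn(10, 8), torch.randn(10, 1)
+     max_batch_size = 4
+
+     total_loss = 0.
+     for chunk_x, chunk_y in zip(x.split(max_batch_size), y.split(max_batch_size)):
+         frac = chunk_x.shape[0] / x.shape[0]
+         loss = F.mse_loss(model(chunk_x), chunk_y) * frac
+         total_loss += loss.item()
+         loss.backward()             # gradients accumulate across chunks
+     print(total_loss)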
+ "578": {
+ "file_id": 16,
+ "content": " batch_size = kwargs.get('batch_size')\n batch_sizes = num_to_groups(batch_size, max_batch_size)\n outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes]\n else:\n outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]\n return torch.cat(outputs, dim = 0)\n return inner\nclass DecoderTrainer(nn.Module):\n def __init__(\n self,\n decoder,\n accelerator = None,\n dataloaders = None,\n use_ema = True,\n lr = 1e-4,\n wd = 1e-2,\n eps = 1e-8,\n warmup_steps = None,\n cosine_decay_max_steps = None,\n max_grad_norm = 0.5,\n amp = False,\n group_wd_params = True,\n **kwargs\n ):\n super().__init__()\n assert isinstance(decoder, Decoder)\n ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)\n self.accelerator = default(accelerator, Accelerator)",
+ "type": "code",
+ "location": "/dalle2_pytorch/trainer.py:424-454"
+ },
+ "579": {
+ "file_id": 16,
+ "content": "The function is a trainer that takes a decoder, accelerator, and other parameters. It can handle batching the inputs or splitting arguments and keywords to train the decoder in chunks, depending on the size of input data. The returned inner function is used for training the model using the provided configuration.",
+ "type": "comment"
+ },
+ "580": {
+ "file_id": 16,
+ "content": " self.num_unets = len(decoder.unets)\n self.use_ema = use_ema\n self.ema_unets = nn.ModuleList([])\n self.amp = amp\n # be able to finely customize learning rate, weight decay\n # per unet\n lr, wd, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps, cosine_decay_max_steps))\n assert all([unet_lr <= 1e-2 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'\n optimizers = []\n schedulers = []\n warmup_schedulers = []\n for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps in zip(decoder.unets, lr, wd, eps, warmup_steps, cosine_decay_max_steps):\n if isinstance(unet, nn.Identity):\n optimizers.append(None)\n schedulers.append(None)\n warmup_schedulers.append(None)\n else:\n optimizer = get_optimizer(\n unet.parameters(),",
+ "type": "code",
+ "location": "/dalle2_pytorch/trainer.py:456-481"
+ },
+ "581": {
+ "file_id": 16,
+ "content": "The code initializes the trainer with specific configurations for each UNET in the decoder. It checks learning rate, weight decay, warmup steps, and cosine decay max steps for each UNET. If a UNET is an identity, it assigns no optimizer or scheduler. Otherwise, it gets an appropriate optimizer for the UNET's parameters.",
+ "type": "comment"
+ },
+ "582": {
+ "file_id": 16,
+ "content": " lr = unet_lr,\n wd = unet_wd,\n eps = unet_eps,\n group_wd_params = group_wd_params,\n **kwargs\n )\n optimizers.append(optimizer)\n if exists(unet_cosine_decay_max_steps):\n scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)\n else:\n scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)\n warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None\n warmup_schedulers.append(warmup_scheduler)\n schedulers.append(scheduler)\n if self.use_ema:\n self.ema_unets.append(EMA(unet, **ema_kwargs))\n # gradient clipping if needed\n self.max_grad_norm = max_grad_norm\n self.register_buffer('steps', torch.tensor([0] * self.num_unets))\n if self.accelerator.distributed_type == DistributedType.DEEPSPEED and decoder.clip is not None:",
+ "type": "code",
+ "location": "/dalle2_pytorch/trainer.py:482-510"
+ },
+ "583": {
+ "file_id": 16,
+ "content": "The code initializes optimizers, optionally schedulers for learning rate adjustments, and an exponential moving average (EMA) for the UNETs. It also registers a buffer for tracking steps and handles gradient clipping if needed based on distributed type.",
+ "type": "comment"
+ },
+ "584": {
+ "file_id": 16,
+ "content": " # Then we need to make sure clip is using the correct precision or else deepspeed will error\n cast_type_map = {\n \"fp16\": torch.half,\n \"bf16\": torch.bfloat16,\n \"no\": torch.float\n }\n precision_type = cast_type_map[accelerator.mixed_precision]\n assert precision_type == torch.float, \"DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip\"\n clip = decoder.clip\n clip.to(precision_type)\n decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))\n self.decoder = decoder\n # prepare dataloaders\n train_loader = val_loader = None\n if exists(dataloaders):\n train_loader, val_loader = self.accelerator.prepare(dataloaders[\"train\"], dataloaders[\"val\"])\n self.train_loader = train_loader\n self.val_loader = val_loader\n # store optimizers\n for opt_ind, optimizer in zip(range(len(optimizers)), optimizers):",
+ "type": "code",
+ "location": "/dalle2_pytorch/trainer.py:511-537"
+ },
+ "585": {
+ "file_id": 16,
+ "content": "This code ensures that the correct precision is used by DeepSpeed and prepares the decoder, optimizers, and dataloaders for training. It converts the clip to the specified precision type, then prepares them using DeepSpeed's accelerator. The train_loader and val_loader are stored for later use.",
+ "type": "comment"
+ },
+ "586": {
+ "file_id": 16,
+ "content": " setattr(self, f'optim{opt_ind}', optimizer)\n # store schedulers\n for sched_ind, scheduler in zip(range(len(schedulers)), schedulers):\n setattr(self, f'sched{sched_ind}', scheduler)\n # store warmup schedulers\n self.warmup_schedulers = warmup_schedulers\n def validate_and_return_unet_number(self, unet_number = None):\n if self.num_unets == 1:\n unet_number = default(unet_number, 1)\n assert exists(unet_number) and 1 <= unet_number <= self.num_unets\n return unet_number\n def num_steps_taken(self, unet_number = None):\n unet_number = self.validate_and_return_unet_number(unet_number)\n return self.steps[unet_number - 1].item()\n def save(self, path, overwrite = True, **kwargs):\n path = Path(path)\n assert not (path.exists() and not overwrite)\n path.parent.mkdir(parents = True, exist_ok = True)\n save_obj = dict(\n model = self.accelerator.unwrap_model(self.decoder).state_dict(),",
+ "type": "code",
+ "location": "/dalle2_pytorch/trainer.py:538-566"
+ },
+ "587": {
+ "file_id": 16,
+ "content": "This code defines a class with optimizers, schedulers, and warmup schedulers. It also validates the unet number and returns the number of steps taken by a specific unet. The save function saves the model's state dict to a specified path.",
+ "type": "comment"
+ },
+ "588": {
+ "file_id": 16,
+ "content": " version = __version__,\n steps = self.steps.cpu(),\n **kwargs\n )\n for ind in range(0, self.num_unets):\n optimizer_key = f'optim{ind}'\n scheduler_key = f'sched{ind}'\n optimizer = getattr(self, optimizer_key)\n scheduler = getattr(self, scheduler_key)\n optimizer_state_dict = optimizer.state_dict() if exists(optimizer) else None\n scheduler_state_dict = scheduler.state_dict() if exists(scheduler) else None\n save_obj = {**save_obj, optimizer_key: optimizer_state_dict, scheduler_key: scheduler_state_dict}\n if self.use_ema:\n save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}\n self.accelerator.save(save_obj, str(path))\n def load_state_dict(self, loaded_obj, only_model = False, strict = True):\n if version.parse(__version__) != version.parse(loaded_obj['version']):\n self.accelerator.print(f'loading saved decoder at version {loaded_obj[\"version\"]}, but current package version is {__version__}')",
+ "type": "code",
+ "location": "/dalle2_pytorch/trainer.py:567-591"
+ },
+ "589": {
+ "file_id": 16,
+ "content": "This code snippet saves the model state, optimizer state, and scheduler state if they exist, and an optional Exponential Moving Average (EMA) state. It checks the version compatibility before loading the saved state dictionary.",
+ "type": "comment"
+ },
+ "590": {
+ "file_id": 16,
+ "content": " self.accelerator.unwrap_model(self.decoder).load_state_dict(loaded_obj['model'], strict = strict)\n self.steps.copy_(loaded_obj['steps'])\n if only_model:\n return loaded_obj\n for ind, last_step in zip(range(0, self.num_unets), self.steps.tolist()):\n optimizer_key = f'optim{ind}'\n optimizer = getattr(self, optimizer_key)\n scheduler_key = f'sched{ind}'\n scheduler = getattr(self, scheduler_key)\n warmup_scheduler = self.warmup_schedulers[ind]\n if exists(optimizer):\n optimizer.load_state_dict(loaded_obj[optimizer_key])\n if exists(scheduler):\n scheduler.load_state_dict(loaded_obj[scheduler_key])\n if exists(warmup_scheduler):\n warmup_scheduler.last_step = last_step\n if self.use_ema:\n assert 'ema' in loaded_obj\n self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)\n def load(self, path, only_model = False, strict = True):",
+ "type": "code",
+ "location": "/dalle2_pytorch/trainer.py:593-622"
+ },
+ "591": {
+ "file_id": 16,
+ "content": "This code loads a model and its associated optimizers, schedulers, and warmup schedulers from the given path. It also checks if early-stopping (ema) was used and loads that as well. The function returns the loaded state of each component if only_model is True, otherwise it continues with training.",
+ "type": "comment"
+ },
+ "592": {
+ "file_id": 16,
+ "content": " path = Path(path)\n assert path.exists()\n loaded_obj = torch.load(str(path), map_location = 'cpu')\n self.load_state_dict(loaded_obj, only_model = only_model, strict = strict)\n return loaded_obj\n @property\n def unets(self):\n return nn.ModuleList([ema.ema_model for ema in self.ema_unets])\n def increment_step(self, unet_number):\n assert 1 <= unet_number <= self.num_unets\n unet_index_tensor = torch.tensor(unet_number - 1, device = self.steps.device)\n self.steps += F.one_hot(unet_index_tensor, num_classes = len(self.steps))\n def update(self, unet_number = None):\n unet_number = self.validate_and_return_unet_number(unet_number)\n index = unet_number - 1\n optimizer = getattr(self, f'optim{index}')\n scheduler = getattr(self, f'sched{index}')\n if exists(self.max_grad_norm):\n self.accelerator.clip_grad_norm_(self.decoder.parameters(), self.max_grad_norm) # Automatically unscales gradients\n optimizer.step()",
+ "type": "code",
+ "location": "/dalle2_pytorch/trainer.py:623-652"
+ },
+ "593": {
+ "file_id": 16,
+ "content": "This function loads a saved state and returns it. It also provides access to the unets (U-Nets) in the model and allows incrementing the step of a specific unet. The update method updates the optimizer and scheduler for a specified unet.",
+ "type": "comment"
+ },
+ "594": {
+ "file_id": 16,
+ "content": " optimizer.zero_grad()\n warmup_scheduler = self.warmup_schedulers[index]\n scheduler_context = warmup_scheduler.dampening if exists(warmup_scheduler) else nullcontext\n with scheduler_context():\n scheduler.step()\n if self.use_ema:\n ema_unet = self.ema_unets[index]\n ema_unet.update()\n self.increment_step(unet_number)\n @torch.no_grad()\n @cast_torch_tensor\n @decoder_sample_in_chunks\n def sample(self, *args, **kwargs):\n distributed = self.accelerator.num_processes > 1\n base_decoder = self.accelerator.unwrap_model(self.decoder)\n was_training = base_decoder.training\n base_decoder.eval()\n if kwargs.pop('use_non_ema', False) or not self.use_ema:\n out = base_decoder.sample(*args, **kwargs, distributed = distributed)\n base_decoder.train(was_training)\n return out\n trainable_unets = self.accelerator.unwrap_model(self.decoder).unets\n base_decoder.unets = self.unets # swap in exponential moving averaged unets for sampling",
+ "type": "code",
+ "location": "/dalle2_pytorch/trainer.py:653-683"
+ },
+ "595": {
+ "file_id": 16,
+ "content": "This code is responsible for the sampling process in a specific model. It uses gradient descent to optimize the model and updates the exponential moving average (EMA) unets if ema is enabled. The sample function enables evaluation mode, handles non-ema usage or disabled use_ema, and returns the output based on the input arguments. The distributed argument is used for multi-process sampling.",
+ "type": "comment"
+ },
+ "596": {
+ "file_id": 16,
+ "content": " output = base_decoder.sample(*args, **kwargs, distributed = distributed)\n base_decoder.unets = trainable_unets # restore original training unets\n # cast the ema_model unets back to original device\n for ema in self.ema_unets:\n ema.restore_ema_model_device()\n base_decoder.train(was_training)\n return output\n @torch.no_grad()\n @cast_torch_tensor\n @prior_sample_in_chunks\n def embed_text(self, *args, **kwargs):\n return self.accelerator.unwrap_model(self.decoder).clip.embed_text(*args, **kwargs)\n @torch.no_grad()\n @cast_torch_tensor\n @prior_sample_in_chunks\n def embed_image(self, *args, **kwargs):\n return self.accelerator.unwrap_model(self.decoder).clip.embed_image(*args, **kwargs)\n @cast_torch_tensor\n def forward(\n self,\n *args,\n unet_number = None,\n max_batch_size = None,\n return_lowres_cond_image=False,\n **kwargs\n ):\n unet_number = self.validate_and_return_unet_number(unet_number)",
+ "type": "code",
+ "location": "/dalle2_pytorch/trainer.py:685-717"
+ },
+ "597": {
+ "file_id": 16,
+ "content": "This code defines a function for embedding text and image using the decoder's CLIP module. It also restores the original training unets, casts torch tensors, validates and returns the correct unet number, and allows for conditional lowres image return.",
+ "type": "comment"
+ },
+ "598": {
+ "file_id": 16,
+ "content": " total_loss = 0.\n cond_images = []\n for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):\n with self.accelerator.autocast():\n loss_obj = self.decoder(*chunked_args, unet_number = unet_number, return_lowres_cond_image=return_lowres_cond_image, **chunked_kwargs)\n # loss_obj may be a tuple with loss and cond_image\n if return_lowres_cond_image:\n loss, cond_image = loss_obj\n else:\n loss = loss_obj\n cond_image = None\n loss = loss * chunk_size_frac\n if cond_image is not None:\n cond_images.append(cond_image)\n total_loss += loss.item()\n if self.training:\n self.accelerator.backward(loss)\n if return_lowres_cond_image:\n return total_loss, torch.stack(cond_images)\n else:\n return total_loss",
+ "type": "code",
+ "location": "/dalle2_pytorch/trainer.py:719-742"
+ },
+ "599": {
+ "file_id": 16,
+ "content": "This code chunk splits the input arguments and keywords into multiple smaller chunks, then iterates over them to perform computations with auto-cast enabled. The resulting losses are accumulated, and if conditional images are returned, they are stacked together. Finally, the total loss is returned.",
+ "type": "comment"
+ }
+}
\ No newline at end of file
diff --git a/docs/data/6.json b/docs/data/6.json
new file mode 100644
index 00000000..96e2586b
--- /dev/null
+++ b/docs/data/6.json
@@ -0,0 +1,547 @@
+{
+ "600": {
+ "file_id": 17,
+ "content": "/dalle2_pytorch/utils.py",
+ "type": "filepath"
+ },
+ "601": {
+ "file_id": 17,
+ "content": "This code snippet includes helper functions for time, print, and import operations. It defines a Timer class for measuring elapsed time, a print_ribbon function to format print statements with a banner, and an import_or_print_error function to handle module imports, displaying an error message if necessary and exiting the program.",
+ "type": "summary"
+ },
+ "602": {
+ "file_id": 17,
+ "content": "import time\nimport importlib\n# helper functions\ndef exists(val):\n return val is not None\n# time helpers\nclass Timer:\n def __init__(self):\n self.reset()\n def reset(self):\n self.last_time = time.time()\n def elapsed(self):\n return time.time() - self.last_time\n# print helpers\ndef print_ribbon(s, symbol = '=', repeat = 40):\n flank = symbol * repeat\n return f'{flank} {s} {flank}'\n# import helpers\ndef import_or_print_error(pkg_name, err_str = None):\n try:\n return importlib.import_module(pkg_name)\n except ModuleNotFoundError as e:\n if exists(err_str):\n print(err_str)\n exit()",
+ "type": "code",
+ "location": "/dalle2_pytorch/utils.py:1-35"
+ },
+ "603": {
+ "file_id": 17,
+ "content": "This code snippet includes helper functions for time, print, and import operations. It defines a Timer class for measuring elapsed time, a print_ribbon function to format print statements with a banner, and an import_or_print_error function to handle module imports, displaying an error message if necessary and exiting the program.",
+ "type": "comment"
+ },
+ "604": {
+ "file_id": 18,
+ "content": "/dalle2_pytorch/version.py",
+ "type": "filepath"
+ },
+ "605": {
+ "file_id": 18,
+ "content": "This code defines the version number of the DALLE2-pytorch library, currently set as '1.15.6'.",
+ "type": "summary"
+ },
+ "606": {
+ "file_id": 18,
+ "content": "__version__ = '1.15.6'",
+ "type": "code",
+ "location": "/dalle2_pytorch/version.py:1-1"
+ },
+ "607": {
+ "file_id": 18,
+ "content": "This code defines the version number of the DALLE2-pytorch library, currently set as '1.15.6'.",
+ "type": "comment"
+ },
+ "608": {
+ "file_id": 19,
+ "content": "/dalle2_pytorch/vqgan_vae.py",
+ "type": "filepath"
+ },
+ "609": {
+ "file_id": 19,
+ "content": "Code describes VQGAN-VAE and Vision Transformer architectures for image generation models, including convolutional layers, self-attention mechanisms, layer normalization, initializes model, calculates losses, determines adaptive weight, applies clamp function, calculates combined loss, returns reconstructed feature maps if required.",
+ "type": "summary"
+ },
+ "610": {
+ "file_id": 19,
+ "content": "import copy\nimport math\nfrom math import sqrt\nfrom functools import partial, wraps\nfrom vector_quantize_pytorch import VectorQuantize as VQ\nimport torch\nfrom torch import nn, einsum\nimport torch.nn.functional as F\nfrom torch.autograd import grad as torch_grad\nimport torchvision\nfrom einops import rearrange, reduce, repeat, pack, unpack\nfrom einops.layers.torch import Rearrange\n# constants\nMList = nn.ModuleList\n# helper functions\ndef exists(val):\n return val is not None\ndef default(val, d):\n return val if exists(val) else d\n# decorators\ndef eval_decorator(fn):\n def inner(model, *args, **kwargs):\n was_training = model.training\n model.eval()\n out = fn(model, *args, **kwargs)\n model.train(was_training)\n return out\n return inner\ndef remove_vgg(fn):\n @wraps(fn)\n def inner(self, *args, **kwargs):\n has_vgg = hasattr(self, 'vgg')\n if has_vgg:\n vgg = self.vgg\n delattr(self, 'vgg')\n out = fn(self, *args, **kwargs)\n if has_vgg:\n self.vgg = vgg",
+ "type": "code",
+ "location": "/dalle2_pytorch/vqgan_vae.py:1-51"
+ },
+ "611": {
+ "file_id": 19,
+ "content": "This code imports various libraries and defines several constants, helper functions, and decorators for use in a deep learning model. It also sets up a class for a Vector Quantize module using PyTorch, with functionality to evaluate the model and remove the VGG feature if present.",
+ "type": "comment"
+ },
+ "612": {
+ "file_id": 19,
+ "content": " return out\n return inner\n# keyword argument helpers\ndef pick_and_pop(keys, d):\n values = list(map(lambda key: d.pop(key), keys))\n return dict(zip(keys, values))\ndef group_dict_by_key(cond, d):\n return_val = [dict(),dict()]\n for key in d.keys():\n match = bool(cond(key))\n ind = int(not match)\n return_val[ind][key] = d[key]\n return (*return_val,)\ndef string_begins_with(prefix, string_input):\n return string_input.startswith(prefix)\ndef group_by_key_prefix(prefix, d):\n return group_dict_by_key(partial(string_begins_with, prefix), d)\ndef groupby_prefix_and_trim(prefix, d):\n kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)\n kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))\n return kwargs_without_prefix, kwargs\n# tensor helper functions\ndef log(t, eps = 1e-10):\n return torch.log(t + eps)\ndef gradient_penalty(images, output, weight = 10):\n batch_size = images.shape[0]",
+ "type": "code",
+ "location": "/dalle2_pytorch/vqgan_vae.py:53-87"
+ },
+ "613": {
+ "file_id": 19,
+ "content": "This code contains various utility functions. \"pick_and_pop\" removes and returns keys from a dictionary, \"group_dict_by_key\" groups dictionary items by key condition, \"string_begins_with\" checks if a string begins with a given prefix, \"group_by_key_prefix\" groups dictionary items based on a key prefix, and \"groupby_prefix_and_trim\" trims key prefixes before grouping. Lastly, the \"log\" function calculates the natural logarithm of an input tensor, and the \"gradient_penalty\" function is used to calculate a gradient penalty for image generation tasks.",
+ "type": "comment"
+ },
+ "614": {
+ "file_id": 19,
+ "content": " gradients = torch_grad(outputs = output, inputs = images,\n grad_outputs = torch.ones(output.size(), device = images.device),\n create_graph = True, retain_graph = True, only_inputs = True)[0]\n gradients = rearrange(gradients, 'b ... -> b (...)')\n return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()\ndef l2norm(t):\n return F.normalize(t, dim = -1)\ndef leaky_relu(p = 0.1):\n return nn.LeakyReLU(0.1)\ndef stable_softmax(t, dim = -1, alpha = 32 ** 2):\n t = t / alpha\n t = t - torch.amax(t, dim = dim, keepdim = True).detach()\n return (t * alpha).softmax(dim = dim)\ndef safe_div(numer, denom, eps = 1e-8):\n return numer / (denom + eps)\n# gan losses\ndef hinge_discr_loss(fake, real):\n return (F.relu(1 + fake) + F.relu(1 - real)).mean()\ndef hinge_gen_loss(fake):\n return -fake.mean()\ndef bce_discr_loss(fake, real):\n return (-log(1 - torch.sigmoid(fake)) - log(torch.sigmoid(real))).mean()\ndef bce_gen_loss(fake):\n return -log(torch.sigmoid(fake)).mean()",
+ "type": "code",
+ "location": "/dalle2_pytorch/vqgan_vae.py:88-121"
+ },
+ "615": {
+ "file_id": 19,
+ "content": "This code contains several utility functions and loss functions used in the VQ-VAE-GAN model. It includes functions for gradient calculations, normalization, activation functions, and various GAN losses. The functions are defined to be reusable throughout the codebase.",
+ "type": "comment"
+ },
+ "616": {
+ "file_id": 19,
+ "content": "def grad_layer_wrt_loss(loss, layer):\n return torch_grad(\n outputs = loss,\n inputs = layer,\n grad_outputs = torch.ones_like(loss),\n retain_graph = True\n )[0].detach()\n# vqgan vae\nclass LayerNormChan(nn.Module):\n def __init__(\n self,\n dim,\n eps = 1e-5\n ):\n super().__init__()\n self.eps = eps\n self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))\n def forward(self, x):\n var = torch.var(x, dim = 1, unbiased = False, keepdim = True)\n mean = torch.mean(x, dim = 1, keepdim = True)\n return (x - mean) / (var + self.eps).sqrt() * self.gamma\n# discriminator\nclass Discriminator(nn.Module):\n def __init__(\n self,\n dims,\n channels = 3,\n groups = 16,\n init_kernel_size = 5\n ):\n super().__init__()\n dim_pairs = zip(dims[:-1], dims[1:])\n self.layers = MList([nn.Sequential(nn.Conv2d(channels, dims[0], init_kernel_size, padding = init_kernel_size // 2), leaky_relu())])\n for dim_in, dim_out in dim_pairs:",
+ "type": "code",
+ "location": "/dalle2_pytorch/vqgan_vae.py:123-163"
+ },
+ "617": {
+ "file_id": 19,
+ "content": "The code defines a function to compute gradients of a layer wrt the loss, and introduces two custom modules: LayerNormChan for layer normalization and Discriminator for a convolutional network. The discriminator consists of multiple layers with decreasing kernel sizes, each followed by a leaky ReLU activation function. These components are part of the VQGAN-VAE architecture in DALLE2-pytorch.",
+ "type": "comment"
+ },
+ "618": {
+ "file_id": 19,
+ "content": " self.layers.append(nn.Sequential(\n nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1),\n nn.GroupNorm(groups, dim_out),\n leaky_relu()\n ))\n dim = dims[-1]\n self.to_logits = nn.Sequential( # return 5 x 5, for PatchGAN-esque training\n nn.Conv2d(dim, dim, 1),\n leaky_relu(),\n nn.Conv2d(dim, 1, 4)\n )\n def forward(self, x):\n for net in self.layers:\n x = net(x)\n return self.to_logits(x)\n# positional encoding\nclass ContinuousPositionBias(nn.Module):\n \"\"\" from https://arxiv.org/abs/2111.09883 \"\"\"\n def __init__(self, *, dim, heads, layers = 2):\n super().__init__()\n self.net = MList([])\n self.net.append(nn.Sequential(nn.Linear(2, dim), leaky_relu()))\n for _ in range(layers - 1):\n self.net.append(nn.Sequential(nn.Linear(dim, dim), leaky_relu()))\n self.net.append(nn.Linear(dim, heads))\n self.register_buffer('rel_pos', None, persistent = False)",
+ "type": "code",
+ "location": "/dalle2_pytorch/vqgan_vae.py:164-197"
+ },
+ "619": {
+ "file_id": 19,
+ "content": "The code defines a VQGAN-VAE model. It uses convolutional layers and group normalization for downsampling the input image, followed by linear layers and leaky ReLU activation functions in a sequential manner to generate logits. The `ContinuousPositionBias` class is used for positional encoding in the model.",
+ "type": "comment"
+ },
+ "620": {
+ "file_id": 19,
+ "content": " def forward(self, x):\n n, device = x.shape[-1], x.device\n fmap_size = int(sqrt(n))\n if not exists(self.rel_pos):\n pos = torch.arange(fmap_size, device = device)\n grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))\n grid = rearrange(grid, 'c i j -> (i j) c')\n rel_pos = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c')\n rel_pos = torch.sign(rel_pos) * torch.log(rel_pos.abs() + 1)\n self.register_buffer('rel_pos', rel_pos, persistent = False)\n rel_pos = self.rel_pos.float()\n for layer in self.net:\n rel_pos = layer(rel_pos)\n bias = rearrange(rel_pos, 'i j h -> h i j')\n return x + bias\n# resnet encoder / decoder\nclass ResnetEncDec(nn.Module):\n def __init__(\n self,\n dim,\n *,\n channels = 3,\n layers = 4,\n layer_mults = None,\n num_resnet_blocks = 1,\n resnet_groups = 16,\n first_conv_kernel_size = 5,\n use_attn = True,",
+ "type": "code",
+ "location": "/dalle2_pytorch/vqgan_vae.py:199-232"
+ },
+ "621": {
+ "file_id": 19,
+ "content": "The code defines a VQ-VAE implementation with a resnet encoder/decoder for image generation. The function calculates relative positional embeddings and applies them to the input, then passes the result through a resnet encoder/decoder network before returning the transformed input. The ResnetEncDec class creates an instance of the resnet encoder/decoder with optional parameters such as dimensions, channels, layers, layer_mults, num_resnet_blocks, resnet_groups, first_conv_kernel_size, and use_attn.",
+ "type": "comment"
+ },
+ "622": {
+ "file_id": 19,
+ "content": " attn_dim_head = 64,\n attn_heads = 8,\n attn_dropout = 0.,\n ):\n super().__init__()\n assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)'\n self.layers = layers\n self.encoders = MList([])\n self.decoders = MList([])\n layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))\n assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'\n layer_dims = [dim * mult for mult in layer_mults]\n dims = (dim, *layer_dims)\n self.encoded_dim = dims[-1]\n dim_pairs = zip(dims[:-1], dims[1:])\n append = lambda arr, t: arr.append(t)\n prepend = lambda arr, t: arr.insert(0, t)\n if not isinstance(num_resnet_blocks, tuple):\n num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)\n if not isinstance(use_attn, tuple):\n use_attn = (*((False,) * (layers - 1)), use_attn)",
+ "type": "code",
+ "location": "/dalle2_pytorch/vqgan_vae.py:233-262"
+ },
+ "623": {
+ "file_id": 19,
+ "content": "This code defines a class with specified parameters for layers, encoders, and decoders. It ensures the dimension is divisible by resnet_groups. The layer multipliers are stored in a list and used to determine the dimensions of each layer. num_resnet_blocks and use_attn are checked to make sure they match the designated number of layers.",
+ "type": "comment"
+ },
+ "624": {
+ "file_id": 19,
+ "content": " assert len(num_resnet_blocks) == layers, 'number of resnet blocks config must be equal to number of layers'\n assert len(use_attn) == layers\n for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):\n append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))\n prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))\n if layer_use_attn:\n prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))\n for _ in range(layer_num_resnet_blocks):\n append(self.encoders, ResBlock(dim_out, groups = resnet_groups))\n prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups))\n if layer_use_attn:\n append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))",
+ "type": "code",
+ "location": "/dalle2_pytorch/vqgan_vae.py:264-279"
+ },
+ "625": {
+ "file_id": 19,
+ "content": "This code creates encoder and decoder blocks for a VQ-VAE model. It asserts that the number of resnet blocks and use_attn match the layers, then iterates over each layer creating convolutional layers, LeakyReLU activation functions, optionally adding attention modules, and repeating a specific number of residual blocks in both encoders and decoders.",
+ "type": "comment"
+ },
+ "626": {
+ "file_id": 19,
+ "content": " prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))\n append(self.decoders, nn.Conv2d(dim, channels, 1))\n def get_encoded_fmap_size(self, image_size):\n return image_size // (2 ** self.layers)\n @property\n def last_dec_layer(self):\n return self.decoders[-1].weight\n def encode(self, x):\n for enc in self.encoders:\n x = enc(x)\n return x\n def decode(self, x):\n for dec in self.decoders:\n x = dec(x)\n return x\nclass GLUResBlock(nn.Module):\n def __init__(self, chan, groups = 16):\n super().__init__()\n self.net = nn.Sequential(\n nn.Conv2d(chan, chan * 2, 3, padding = 1),\n nn.GLU(dim = 1),\n nn.GroupNorm(groups, chan),\n nn.Conv2d(chan, chan * 2, 3, padding = 1),\n nn.GLU(dim = 1),\n nn.GroupNorm(groups, chan),\n nn.Conv2d(chan, chan, 1)\n )\n def forward(self, x):\n return self.net(x) + x",
+ "type": "code",
+ "location": "/dalle2_pytorch/vqgan_vae.py:281-315"
+ },
+ "627": {
+ "file_id": 19,
+ "content": "The code defines a class for a VQGAN-VAE model. It consists of encoder and decoder blocks, along with a GLUResBlock for the residual connections in the decoder. The encoder and decoder are composed of convolutional layers that reduce and increase image size respectively. The encoded image size is defined as the original image size divided by 2 to the power of the number of layers. The model can encode and decode images using the encoder and decoder blocks, and the last decoder layer's weights can be accessed separately.",
+ "type": "comment"
+ },
+ "628": {
+ "file_id": 19,
+ "content": "class ResBlock(nn.Module):\n def __init__(self, chan, groups = 16):\n super().__init__()\n self.net = nn.Sequential(\n nn.Conv2d(chan, chan, 3, padding = 1),\n nn.GroupNorm(groups, chan),\n leaky_relu(),\n nn.Conv2d(chan, chan, 3, padding = 1),\n nn.GroupNorm(groups, chan),\n leaky_relu(),\n nn.Conv2d(chan, chan, 1)\n )\n def forward(self, x):\n return self.net(x) + x\n# vqgan attention layer\nclass VQGanAttention(nn.Module):\n def __init__(\n self,\n *,\n dim,\n dim_head = 64,\n heads = 8,\n dropout = 0.\n ):\n super().__init__()\n self.heads = heads\n self.scale = dim_head ** -0.5\n inner_dim = heads * dim_head\n self.dropout = nn.Dropout(dropout)\n self.pre_norm = LayerNormChan(dim)\n self.cpb = ContinuousPositionBias(dim = dim // 4, heads = heads)\n self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)\n self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)",
+ "type": "code",
+ "location": "/dalle2_pytorch/vqgan_vae.py:317-354"
+ },
+ "629": {
+ "file_id": 19,
+ "content": "This code defines a residual block and a VQGAN attention layer for image processing. The ResBlock consists of multiple 2D convolutions and GroupNorm layers, followed by leaky ReLU activation functions. The VQGANAttention class is responsible for self-attention in the VQGAN model, using continuous position bias and multi-head attention with dropout regularization.",
+ "type": "comment"
+ },
+ "630": {
+ "file_id": 19,
+ "content": " def forward(self, x):\n h = self.heads\n height, width, residual = *x.shape[-2:], x.clone()\n x = self.pre_norm(x)\n q, k, v = self.to_qkv(x).chunk(3, dim = 1)\n q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = h), (q, k, v))\n sim = einsum('b h c i, b h c j -> b h i j', q, k) * self.scale\n sim = self.cpb(sim)\n attn = stable_softmax(sim, dim = -1)\n attn = self.dropout(attn)\n out = einsum('b h i j, b h c j -> b h c i', attn, v)\n out = rearrange(out, 'b h c (x y) -> b (h c) x y', x = height, y = width)\n out = self.to_out(out)\n return out + residual\n# ViT encoder / decoder\nclass RearrangeImage(nn.Module):\n def forward(self, x):\n n = x.shape[1]\n w = h = int(sqrt(n))\n return rearrange(x, 'b (h w) ... -> b h w ...', h = h, w = w)\nclass Attention(nn.Module):\n def __init__(\n self,\n dim,\n *,\n heads = 8,\n dim_head = 32\n ):\n super().__init__()\n self.norm = nn.LayerNorm(dim)",
+ "type": "code",
+ "location": "/dalle2_pytorch/vqgan_vae.py:356-396"
+ },
+ "631": {
+ "file_id": 19,
+ "content": "This code defines a class for the Attention module in a ViT (Vision Transformer) model. It performs multi-head attention using key, query, and value tensors, followed by a softmax function to compute attention weights. The output is then passed through a linear layer and layer normalization before being added back to the input with residual connection.",
+ "type": "comment"
+ },
+ "632": {
+ "file_id": 19,
+ "content": " self.heads = heads\n self.scale = dim_head ** -0.5\n inner_dim = dim_head * heads\n self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)\n self.to_out = nn.Linear(inner_dim, dim)\n def forward(self, x):\n h = self.heads\n x = self.norm(x)\n q, k, v = self.to_qkv(x).chunk(3, dim = -1)\n q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))\n q = q * self.scale\n sim = einsum('b h i d, b h j d -> b h i j', q, k)\n sim = sim - sim.amax(dim = -1, keepdim = True).detach()\n attn = sim.softmax(dim = -1)\n out = einsum('b h i j, b h j d -> b h i d', attn, v)\n out = rearrange(out, 'b h n d -> b n (h d)')\n return self.to_out(out)\ndef FeedForward(dim, mult = 4):\n return nn.Sequential(\n nn.LayerNorm(dim),\n nn.Linear(dim, dim * mult, bias = False),\n nn.GELU(),\n nn.Linear(dim * mult, dim, bias = False)\n )\nclass Transformer(nn.Module):\n def __init__(\n self,",
+ "type": "code",
+ "location": "/dalle2_pytorch/vqgan_vae.py:397-433"
+ },
+ "633": {
+ "file_id": 19,
+ "content": "This code defines a MultiHeadAttention module for a transformer model. It initializes the attention head count and scale, calculates inner dimension based on head count and input dimension. The forward function performs multi-head attention by splitting input into query, key, value tensors, scaling query tensor, computing similarity between query and key, subtracting maximum similarity to avoid zero gradients, performing softmax on attention scores, and finally producing output tensor through weighted sum of value tensors. The FeedForward function defines a feedforward network for the transformer model, consisting of layer normalization, linear layers with GELU activation function.",
+ "type": "comment"
+ },
+ "634": {
+ "file_id": 19,
+ "content": " dim,\n *,\n layers,\n dim_head = 32,\n heads = 8,\n ff_mult = 4\n ):\n super().__init__()\n self.layers = nn.ModuleList([])\n for _ in range(layers):\n self.layers.append(nn.ModuleList([\n Attention(dim = dim, dim_head = dim_head, heads = heads),\n FeedForward(dim = dim, mult = ff_mult)\n ]))\n self.norm = nn.LayerNorm(dim)\n def forward(self, x):\n for attn, ff in self.layers:\n x = attn(x) + x\n x = ff(x) + x\n return self.norm(x)\nclass ViTEncDec(nn.Module):\n def __init__(\n self,\n dim,\n channels = 3,\n layers = 4,\n patch_size = 8,\n dim_head = 32,\n heads = 8,\n ff_mult = 4\n ):\n super().__init__()\n self.encoded_dim = dim\n self.patch_size = patch_size\n input_dim = channels * (patch_size ** 2)\n self.encoder = nn.Sequential(\n Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),",
+ "type": "code",
+ "location": "/dalle2_pytorch/vqgan_vae.py:434-476"
+ },
+ "635": {
+ "file_id": 19,
+ "content": "The code defines a class for an encoder-decoder architecture, which is part of the Vision Transformer (ViT) model. It utilizes attention and feedforward layers, and includes layer normalization in its forward pass. The encoder section takes input images, reshapes them into patches, and passes them through multiple attention and feedforward layers.",
+ "type": "comment"
+ },
+ "636": {
+ "file_id": 19,
+ "content": " nn.Linear(input_dim, dim),\n Transformer(\n dim = dim,\n dim_head = dim_head,\n heads = heads,\n ff_mult = ff_mult,\n layers = layers\n ),\n RearrangeImage(),\n Rearrange('b h w c -> b c h w')\n )\n self.decoder = nn.Sequential(\n Rearrange('b c h w -> b (h w) c'),\n Transformer(\n dim = dim,\n dim_head = dim_head,\n heads = heads,\n ff_mult = ff_mult,\n layers = layers\n ),\n nn.Sequential(\n nn.Linear(dim, dim * 4, bias = False),\n nn.Tanh(),\n nn.Linear(dim * 4, input_dim, bias = False),\n ),\n RearrangeImage(),\n Rearrange('b h w (p1 p2 c) -> b c (h p1) (w p2)', p1 = patch_size, p2 = patch_size)\n )\n def get_encoded_fmap_size(self, image_size):\n return image_size // self.patch_size\n @property",
+ "type": "code",
+ "location": "/dalle2_pytorch/vqgan_vae.py:477-510"
+ },
+ "637": {
+ "file_id": 19,
+ "content": "The code defines a VQ-VAE model for image generation, consisting of an encoder and decoder. The encoder processes the input image and outputs a compressed codebook index followed by a positional embedding. The decoder then reconstructs the original image from these inputs using a series of transformers and linear layers. The get_encoded_fmap_size function calculates the encoded feature map size based on the input image size.",
+ "type": "comment"
+ },
+ "638": {
+ "file_id": 19,
+ "content": " def last_dec_layer(self):\n return self.decoder[-3][-1].weight\n def encode(self, x):\n return self.encoder(x)\n def decode(self, x):\n return self.decoder(x)\n# main vqgan-vae classes\nclass NullVQGanVAE(nn.Module):\n def __init__(\n self,\n *,\n channels\n ):\n super().__init__()\n self.encoded_dim = channels\n self.layers = 0\n def get_encoded_fmap_size(self, size):\n return size\n def copy_for_eval(self):\n return self\n def encode(self, x):\n return x\n def decode(self, x):\n return x\nclass VQGanVAE(nn.Module):\n def __init__(\n self,\n *,\n dim,\n image_size,\n channels = 3,\n layers = 4,\n l2_recon_loss = False,\n use_hinge_loss = True,\n vgg = None,\n vq_codebook_dim = 256,\n vq_codebook_size = 512,\n vq_decay = 0.8,\n vq_commitment_weight = 1.,\n vq_kmeans_init = True,\n vq_use_cosine_sim = True,\n use_vgg_and_gan = True,\n vae_type = 'resnet',",
+ "type": "code",
+ "location": "/dalle2_pytorch/vqgan_vae.py:511-562"
+ },
+ "639": {
+ "file_id": 19,
+ "content": "This code defines two classes: NullVQGanVAE and VQGanVAE. The NullVQGanVAE is a placeholder class without any specific layers or functionality, while the VQGanVAE class represents a variant of the VAE model with optional features like VGG loss, GAN integration, and customizable parameters for codebook dimensions and layers.",
+ "type": "comment"
+ },
+ "640": {
+ "file_id": 19,
+ "content": " discr_layers = 4,\n **kwargs\n ):\n super().__init__()\n vq_kwargs, kwargs = groupby_prefix_and_trim('vq_', kwargs)\n encdec_kwargs, kwargs = groupby_prefix_and_trim('encdec_', kwargs)\n self.image_size = image_size\n self.channels = channels\n self.codebook_size = vq_codebook_size\n if vae_type == 'resnet':\n enc_dec_klass = ResnetEncDec\n elif vae_type == 'vit':\n enc_dec_klass = ViTEncDec\n else:\n raise ValueError(f'{vae_type} not valid')\n self.enc_dec = enc_dec_klass(\n dim = dim,\n channels = channels,\n layers = layers,\n **encdec_kwargs\n )\n self.vq = VQ(\n dim = self.enc_dec.encoded_dim,\n codebook_dim = vq_codebook_dim,\n codebook_size = vq_codebook_size,\n decay = vq_decay,\n commitment_weight = vq_commitment_weight,\n accept_image_fmap = True,\n kmeans_init = vq_kmeans_init,\n use_cosine_sim = vq_use_cosine_sim,",
+ "type": "code",
+ "location": "/dalle2_pytorch/vqgan_vae.py:563-596"
+ },
+ "641": {
+ "file_id": 19,
+ "content": "This code initializes a VQ-VAE model with given parameters. It uses a specified encoder-decoder network (ResNet or ViT), codebook size, and other VQ-specific options. The VQ module is initialized based on the dimensionality of the encoder-decoder's encoded output, and the codebook size and related options. If an invalid VAE type is given, a ValueError is raised.",
+ "type": "comment"
+ },
+ "642": {
+ "file_id": 19,
+ "content": " **vq_kwargs\n )\n # reconstruction loss\n self.recon_loss_fn = F.mse_loss if l2_recon_loss else F.l1_loss\n # turn off GAN and perceptual loss if grayscale\n self.vgg = None\n self.discr = None\n self.use_vgg_and_gan = use_vgg_and_gan\n if not use_vgg_and_gan:\n return\n # preceptual loss\n if exists(vgg):\n self.vgg = vgg\n else:\n self.vgg = torchvision.models.vgg16(pretrained = True)\n self.vgg.classifier = nn.Sequential(*self.vgg.classifier[:-2])\n # gan related losses\n layer_mults = list(map(lambda t: 2 ** t, range(discr_layers)))\n layer_dims = [dim * mult for mult in layer_mults]\n dims = (dim, *layer_dims)\n self.discr = Discriminator(dims = dims, channels = channels)\n self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss\n self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss\n @property\n def encoded_dim(self):",
+ "type": "code",
+ "location": "/dalle2_pytorch/vqgan_vae.py:597-633"
+ },
+ "643": {
+ "file_id": 19,
+ "content": "This code defines a VQGAN-VAE model with optional GAN and perceptual loss components. It initializes the VGG model, Discriminator, and sets the reconstruction and generator losses based on provided arguments. The encoded_dim property returns the dimension of the encoded images.",
+ "type": "comment"
+ },
+ "644": {
+ "file_id": 19,
+ "content": " return self.enc_dec.encoded_dim\n def get_encoded_fmap_size(self, image_size):\n return self.enc_dec.get_encoded_fmap_size(image_size)\n def copy_for_eval(self):\n device = next(self.parameters()).device\n vae_copy = copy.deepcopy(self.cpu())\n if vae_copy.use_vgg_and_gan:\n del vae_copy.discr\n del vae_copy.vgg\n vae_copy.eval()\n return vae_copy.to(device)\n @remove_vgg\n def state_dict(self, *args, **kwargs):\n return super().state_dict(*args, **kwargs)\n @remove_vgg\n def load_state_dict(self, *args, **kwargs):\n return super().load_state_dict(*args, **kwargs)\n @property\n def codebook(self):\n return self.vq.codebook\n def encode(self, fmap):\n fmap = self.enc_dec.encode(fmap)\n return fmap\n def decode(self, fmap, return_indices_and_loss = False):\n fmap, indices, commit_loss = self.vq(fmap)\n fmap = self.enc_dec.decode(fmap)\n if not return_indices_and_loss:\n return fmap",
+ "type": "code",
+ "location": "/dalle2_pytorch/vqgan_vae.py:634-672"
+ },
+ "645": {
+ "file_id": 19,
+ "content": "This code defines a class with methods to get encoded dimensions, calculate encoded frame map size, copy the model for evaluation, save and load state dictionary while removing VGG, encode input frames, and decode encoded frames.",
+ "type": "comment"
+ },
+ "646": {
+ "file_id": 19,
+ "content": " return fmap, indices, commit_loss\n def forward(\n self,\n img,\n return_loss = False,\n return_discr_loss = False,\n return_recons = False,\n add_gradient_penalty = True\n ):\n batch, channels, height, width, device = *img.shape, img.device\n assert height == self.image_size and width == self.image_size, 'height and width of input image must be equal to {self.image_size}'\n assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE'\n fmap = self.encode(img)\n fmap, indices, commit_loss = self.decode(fmap, return_indices_and_loss = True)\n if not return_loss and not return_discr_loss:\n return fmap\n assert return_loss ^ return_discr_loss, 'you should either return autoencoder loss or discriminator loss, but not both'\n # whether to return discriminator loss\n if return_discr_loss:\n assert exists(self.discr), 'discriminator must exist to train it'",
+ "type": "code",
+ "location": "/dalle2_pytorch/vqgan_vae.py:674-700"
+ },
+ "647": {
+ "file_id": 19,
+ "content": "This function encodes an input image, decodes it, and can optionally return autoencoder or discriminator losses. It expects the image to have the specified dimensions and number of channels. The code asserts that the image's height, width, and number of channels match the expected values, and that only one type of loss is returned at a time.",
+ "type": "comment"
+ },
+ "648": {
+ "file_id": 19,
+ "content": " fmap.detach_()\n img.requires_grad_()\n fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))\n discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)\n if add_gradient_penalty:\n gp = gradient_penalty(img, img_discr_logits)\n loss = discr_loss + gp\n if return_recons:\n return loss, fmap\n return loss\n # reconstruction loss\n recon_loss = self.recon_loss_fn(fmap, img)\n # early return if training on grayscale\n if not self.use_vgg_and_gan:\n if return_recons:\n return recon_loss, fmap\n return recon_loss\n # perceptual loss\n img_vgg_input = img\n fmap_vgg_input = fmap\n if img.shape[1] == 1:\n # handle grayscale for vgg\n img_vgg_input, fmap_vgg_input = map(lambda t: repeat(t, 'b 1 ... -> b c ...', c = 3), (img_vgg_input, fmap_vgg_input))\n img_vgg_feats = self.vgg(img_vgg_input)",
+ "type": "code",
+ "location": "/dalle2_pytorch/vqgan_vae.py:702-739"
+ },
+ "649": {
+ "file_id": 19,
+ "content": "The code is calculating the reconstruction and perceptual loss for an image generation model. It also includes gradient penalty for the discriminator loss, and optionally returns the reconstructed feature map.",
+ "type": "comment"
+ },
+ "650": {
+ "file_id": 19,
+ "content": " recon_vgg_feats = self.vgg(fmap_vgg_input)\n perceptual_loss = F.mse_loss(img_vgg_feats, recon_vgg_feats)\n # generator loss\n gen_loss = self.gen_loss(self.discr(fmap))\n # calculate adaptive weight\n last_dec_layer = self.enc_dec.last_dec_layer\n norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)\n norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)\n adaptive_weight = safe_div(norm_grad_wrt_perceptual_loss, norm_grad_wrt_gen_loss)\n adaptive_weight.clamp_(max = 1e4)\n # combine losses\n loss = recon_loss + perceptual_loss + commit_loss + adaptive_weight * gen_loss\n if return_recons:\n return loss, fmap\n return loss",
+ "type": "code",
+ "location": "/dalle2_pytorch/vqgan_vae.py:740-764"
+ },
+ "651": {
+ "file_id": 19,
+ "content": "This code calculates a combination of losses, including reconstruction, perceptual, and commitment. The adaptive weight is determined based on the gradients of these losses. A clamp function limits the adaptive weight to prevent extreme values. Finally, the combined loss is calculated and returned. If return_recons is True, fmap is also returned.",
+ "type": "comment"
+ },
+ "652": {
+ "file_id": 20,
+ "content": "/dalle2_pytorch/vqgan_vae_trainer.py",
+ "type": "filepath"
+ },
+ "653": {
+ "file_id": 20,
+ "content": "This code defines ImageDataset and VQGanVAETrainer classes for loading image data and training a VAE model, setting parameters, optimizers, and creating loaders. It trains the model, logs losses, saves models, and tracks progress in a results folder.",
+ "type": "summary"
+ },
+ "654": {
+ "file_id": 20,
+ "content": "from math import sqrt\nimport copy\nfrom random import choice\nfrom pathlib import Path\nfrom shutil import rmtree\nfrom PIL import Image\nimport torch\nfrom torch import nn\nfrom torch.cuda.amp import autocast, GradScaler\nfrom torch.utils.data import Dataset, DataLoader, random_split\nimport torchvision.transforms as T\nfrom torchvision.datasets import ImageFolder\nfrom torchvision.utils import make_grid, save_image\nfrom einops import rearrange\nfrom dalle2_pytorch.vqgan_vae import VQGanVAE\nfrom dalle2_pytorch.optimizer import get_optimizer\nfrom ema_pytorch import EMA\n# helpers\ndef exists(val):\n return val is not None\ndef noop(*args, **kwargs):\n pass\ndef cycle(dl):\n while True:\n for data in dl:\n yield data\ndef cast_tuple(t):\n return t if isinstance(t, (tuple, list)) else (t,)\ndef yes_or_no(question):\n answer = input(f'{question} (y/n) ')\n return answer.lower() in ('yes', 'y')\ndef accum_log(log, new_logs):\n for key, new_value in new_logs.items():\n old_value = log.get(key, 0.)\n log[key] = old_value + new_value",
+ "type": "code",
+ "location": "/dalle2_pytorch/vqgan_vae_trainer.py:1-47"
+ },
+ "655": {
+ "file_id": 20,
+ "content": "This code contains several utility functions and helper methods. It includes import statements for various libraries, classes for data handling and model training, as well as custom functions for logging, looping, and user input.",
+ "type": "comment"
+ },
+ "656": {
+ "file_id": 20,
+ "content": " return log\n# classes\nclass ImageDataset(Dataset):\n def __init__(\n self,\n folder,\n image_size,\n exts = ['jpg', 'jpeg', 'png']\n ):\n super().__init__()\n self.folder = folder\n self.image_size = image_size\n self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]\n print(f'{len(self.paths)} training samples found at {folder}')\n self.transform = T.Compose([\n T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),\n T.Resize(image_size),\n T.RandomHorizontalFlip(),\n T.CenterCrop(image_size),\n T.ToTensor()\n ])\n def __len__(self):\n return len(self.paths)\n def __getitem__(self, index):\n path = self.paths[index]\n img = Image.open(path)\n return self.transform(img)\n# main trainer class\nclass VQGanVAETrainer(nn.Module):\n def __init__(\n self,\n vae,\n *,\n num_train_steps,\n lr,\n batch_size,",
+ "type": "code",
+ "location": "/dalle2_pytorch/vqgan_vae_trainer.py:48-91"
+ },
+ "657": {
+ "file_id": 20,
+ "content": "The code defines a class \"ImageDataset\" for loading and transforming image data, and a main trainer class \"VQGanVAETrainer\" for training a VAE model. The \"ImageDataset\" class initializes with a folder path, image size, and extension types to filter the images, then applies image transformations like converting to RGB mode, resizing, horizontal flipping, cropping, and tensor conversion. The \"VQGanVAETrainer\" class initializes with parameters like the VAE model, number of training steps, learning rate, and batch size for the training process.",
+ "type": "comment"
+ },
+ "658": {
+ "file_id": 20,
+ "content": " folder,\n grad_accum_every,\n wd = 0.,\n save_results_every = 100,\n save_model_every = 1000,\n results_folder = './results',\n valid_frac = 0.05,\n random_split_seed = 42,\n ema_beta = 0.995,\n ema_update_after_step = 500,\n ema_update_every = 10,\n apply_grad_penalty_every = 4,\n amp = False\n ):\n super().__init__()\n assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'\n image_size = vae.image_size\n self.vae = vae\n self.ema_vae = EMA(vae, update_after_step = ema_update_after_step, update_every = ema_update_every)\n self.register_buffer('steps', torch.Tensor([0]))\n self.num_train_steps = num_train_steps\n self.batch_size = batch_size\n self.grad_accum_every = grad_accum_every\n all_parameters = set(vae.parameters())\n discr_parameters = set(vae.discr.parameters())\n vae_parameters = all_parameters - discr_parameters\n self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)",
+ "type": "code",
+ "location": "/dalle2_pytorch/vqgan_vae_trainer.py:92-123"
+ },
+ "659": {
+ "file_id": 20,
+ "content": "The code initializes an instance of a VQGanVAE and sets up various parameters for training. It checks if the provided vae is of type VQGanVAE, then assigns image size, creates an EMA model with specified update steps and intervals, registers a buffer for tracking steps, sets number of train steps, batch size, grad accumulation every, and initializes optimizer with specified learning rate and weight decay.",
+ "type": "comment"
+ },
+ "660": {
+ "file_id": 20,
+ "content": " self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)\n self.amp = amp\n self.scaler = GradScaler(enabled = amp)\n self.discr_scaler = GradScaler(enabled = amp)\n # create dataset\n self.ds = ImageDataset(folder, image_size = image_size)\n # split for validation\n if valid_frac > 0:\n train_size = int((1 - valid_frac) * len(self.ds))\n valid_size = len(self.ds) - train_size\n self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))\n print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples')\n else:\n self.valid_ds = self.ds\n print(f'training with shared training and valid dataset of {len(self.ds)} samples')\n # dataloader\n self.dl = cycle(DataLoader(\n self.ds,\n batch_size = batch_size,\n shuffle = True",
+ "type": "code",
+ "location": "/dalle2_pytorch/vqgan_vae_trainer.py:124-150"
+ },
+ "661": {
+ "file_id": 20,
+ "content": "This code initializes a Discriminator optimizer, Amplitude Signed-Precision (AMP) for mixed precision training, GradScaler for handling gradients, creates an ImageDataset from the given folder and image size, splits the dataset into training and validation if valid_frac is greater than 0, creates DataLoader for the dataset with specified batch_size and shuffle set to True.",
+ "type": "comment"
+ },
+ "662": {
+ "file_id": 20,
+ "content": " ))\n self.valid_dl = cycle(DataLoader(\n self.valid_ds,\n batch_size = batch_size,\n shuffle = True\n ))\n self.save_model_every = save_model_every\n self.save_results_every = save_results_every\n self.apply_grad_penalty_every = apply_grad_penalty_every\n self.results_folder = Path(results_folder)\n if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):\n rmtree(str(self.results_folder))\n self.results_folder.mkdir(parents = True, exist_ok = True)\n def train_step(self):\n device = next(self.vae.parameters()).device\n steps = int(self.steps.item())\n apply_grad_penalty = not (steps % self.apply_grad_penalty_every)\n self.vae.train()\n # logs\n logs = {}\n # update vae (generator)\n for _ in range(self.grad_accum_every):\n img = next(self.dl)\n img = img.to(device)\n with autocast(enabled = self.amp):",
+ "type": "code",
+ "location": "/dalle2_pytorch/vqgan_vae_trainer.py:151-188"
+ },
+ "663": {
+ "file_id": 20,
+ "content": "The code initializes the valid data loader and sets parameters for saving models, results, and applying gradient penalty. It checks if previous experiment checkpoints and results should be cleared, creates the results folder if needed, and defines the train_step function for training the VAE (generator).",
+ "type": "comment"
+ },
+ "664": {
+ "file_id": 20,
+ "content": " loss = self.vae(\n img,\n return_loss = True,\n apply_grad_penalty = apply_grad_penalty\n )\n self.scaler.scale(loss / self.grad_accum_every).backward()\n accum_log(logs, {'loss': loss.item() / self.grad_accum_every})\n self.scaler.step(self.optim)\n self.scaler.update()\n self.optim.zero_grad()\n # update discriminator\n if exists(self.vae.discr):\n discr_loss = 0\n for _ in range(self.grad_accum_every):\n img = next(self.dl)\n img = img.to(device)\n with autocast(enabled = self.amp):\n loss = self.vae(img, return_discr_loss = True)\n self.discr_scaler.scale(loss / self.grad_accum_every).backward()\n accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})\n self.discr_scaler.step(self.discr_optim)\n self.discr_scaler.update()\n self.discr_optim.zero_grad()",
+ "type": "code",
+ "location": "/dalle2_pytorch/vqgan_vae_trainer.py:189-221"
+ },
+ "665": {
+ "file_id": 20,
+ "content": "This code trains a VAE model and updates the discriminator. It uses scaling, accumulation, and gradients for efficient backpropagation. The loss is calculated and logged for both VAE and discriminator, then optimizers are updated.",
+ "type": "comment"
+ },
+ "666": {
+ "file_id": 20,
+ "content": " # log\n print(f\"{steps}: vae loss: {logs['loss']} - discr loss: {logs['discr_loss']}\")\n # update exponential moving averaged generator\n self.ema_vae.update()\n # sample results every so often\n if not (steps % self.save_results_every):\n for model, filename in ((self.ema_vae.ema_model, f'{steps}.ema'), (self.vae, str(steps))):\n model.eval()\n imgs = next(self.dl)\n imgs = imgs.to(device)\n recons = model(imgs)\n nrows = int(sqrt(self.batch_size))\n imgs_and_recons = torch.stack((imgs, recons), dim = 0)\n imgs_and_recons = rearrange(imgs_and_recons, 'r b ... -> (b r) ...')\n imgs_and_recons = imgs_and_recons.detach().cpu().float().clamp(0., 1.)\n grid = make_grid(imgs_and_recons, nrow = 2, normalize = True, value_range = (0, 1))\n logs['reconstructions'] = grid\n save_image(grid, str(self.results_folder / f'{filename}.png'))",
+ "type": "code",
+ "location": "/dalle2_pytorch/vqgan_vae_trainer.py:223-251"
+ },
+ "667": {
+ "file_id": 20,
+ "content": "This code snippet logs the VAE and discriminator losses, updates the exponential moving average (EMA) generator model, saves models every save_results_every steps, and generates and saves reconstruction images for training.",
+ "type": "comment"
+ },
+ "668": {
+ "file_id": 20,
+ "content": " print(f'{steps}: saving to {str(self.results_folder)}')\n # save model every so often\n if not (steps % self.save_model_every):\n state_dict = self.vae.state_dict()\n model_path = str(self.results_folder / f'vae.{steps}.pt')\n torch.save(state_dict, model_path)\n ema_state_dict = self.ema_vae.state_dict()\n model_path = str(self.results_folder / f'vae.{steps}.ema.pt')\n torch.save(ema_state_dict, model_path)\n print(f'{steps}: saving model to {str(self.results_folder)}')\n self.steps += 1\n return logs\n def train(self, log_fn = noop):\n device = next(self.vae.parameters()).device\n while self.steps < self.num_train_steps:\n logs = self.train_step()\n log_fn(logs)\n print('training complete')",
+ "type": "code",
+ "location": "/dalle2_pytorch/vqgan_vae_trainer.py:253-278"
+ },
+ "669": {
+ "file_id": 20,
+ "content": "Saves the VAE model and EMA-VAE model periodically during training, tracking progress in specified results folder.",
+ "type": "comment"
+ },
+ "670": {
+ "file_id": 21,
+ "content": "/prior.md",
+ "type": "filepath"
+ },
+ "671": {
+ "file_id": 21,
+ "content": "This code uses diffusion prior and CLIP to generate images from text prompts, implements pre-trained decoders, compares EMA models, checks image embeddings in DALLE2-pytorch, and discusses overfitting and running diffusion model training scripts.",
+ "type": "summary"
+ },
+ "672": {
+ "file_id": 21,
+ "content": "# Diffusion Prior\nThis readme serves as an introduction to the diffusion prior.\n## Intro\nA properly trained prior will allow you to translate between two embedding spaces. If you know *a priori* that two embeddings are connected some way—then ability the translate between them could extremely helpful.\n### Motivation\nBefore we dive into the model, let’s look at a quick example of where the model may be helpful.\nFor demonstration purposes we will imagine that we wish to generate images from text using CLIP and a Decoder.\n> [CLIP](https://openai.com/blog/clip/) is a contrastive model that learns to maximize the cosine similarity between a given image and caption, however, there is no guarantee that these embeddings are in the same space. While the embeddings generated are ***close*** the image and text embeddings occupy two disjoint sets.\n```python\n# Load Models\nclip_model = clip.load(\"ViT-L/14\")\ndecoder = Decoder(checkpoint=\"best.pth\") # A decoder trained on CLIP Image embeddings\n# Retrieve prompt from user and encode with CLIP",
+ "type": "code",
+ "location": "/prior.md:1-21"
+ },
+ "673": {
+ "file_id": 21,
+ "content": "This code introduces the concept of a diffusion prior, which is a trained model that allows translation between two embedding spaces. It motivates the use case of generating images from text using CLIP and a Decoder when embeddings are not guaranteed to be in the same space. The code loads CLIP and a pre-trained decoder, then retrieves a prompt from the user and encodes it with CLIP for further processing.",
+ "type": "comment"
+ },
+ "674": {
+ "file_id": 21,
+ "content": "prompt = \"A corgi wearing sunglasses\"\ntokenized_text = tokenize(prompt)\ntext_embedding = clip_model.encode_text(tokenized_text)\n# Now, pass the text embedding to the decoder\npredicted_image = decoder.sample(text_embedding)\n```\n> **Question**: *Can you spot the issue here?*\n>\n> **Answer**: *We’re trying to generate an image from a text embedding!*\nUnfortunately, we run into the issue previously mentioned--the image embeddings and the text embeddings are not interchangeable! Now let's look at a better solution\n```python\n# Load Models\nprior= Prior(checkpoint=\"prior.pth\") # A decoder trained to go from: text-> clip text emb -> clip img emb\ndecoder = Decoder(checkpoint=\"decoder.pth\") # A decoder trained on CLIP Image embeddings\n# Retrieve prompt from user and encode with a prior\nprompt = \"A corgi wearing sunglasses\"\ntokenized_text = tokenize(prompt)\ntext_embedding = prior.sample(tokenized_text) # <-- now we get an embedding in the same space as images!\n# Now, pass the predicted image embedding to the decoder\npredicted_image = decoder.sample(text_embedding)",
+ "type": "code",
+ "location": "/prior.md:22-47"
+ },
+ "675": {
+ "file_id": 21,
+ "content": "This code snippet demonstrates the process of generating an image from a text prompt using deep learning models. The decoder model is trained to convert text into embeddings that are in the same space as CLIP image embeddings. First, we load two models: Prior and Decoder. Then, we retrieve a user-inputted prompt, tokenize it, and use the Prior model to sample a text embedding in the same space as images. Finally, we pass this text embedding into the Decoder model to generate an image.",
+ "type": "comment"
+ },
+ "676": {
+ "file_id": 21,
+ "content": "```\nWith the prior we are able to successfully generate embeddings *within* CLIP's image space! For this reason, the decoder will perform much better as it receives input that is much closer to its training data.\n> **You may be asking yourself the following question:**\n>\n> *\"Why don't you just train the decoder on clip text embeddings instead of image embeddings?\"*\n>\n> OpenAI covers this topic in their [DALLE-2 paper](https://arxiv.org/abs/2204.06125). The TL;DR is *\"it doesn't work as well as decoders trained on image embeddings\"*...also...its just an example :smile:\n## Usage\nTo utilize a pre-trained prior, it’s quite simple.\n### Loading Checkpoints\n```python\nimport torch\nfrom dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter\nfrom dalle2_pytorch.trainer import DiffusionPriorTrainer\ndef load_diffusion_model(dprior_path):\n prior_network = DiffusionPriorNetwork(\n dim=768,\n depth=24,\n dim_head=64,\n heads=32,\n normformer=True,\n attn_dropout=5e-2,",
+ "type": "code",
+ "location": "/prior.md:48-76"
+ },
+ "677": {
+ "file_id": 21,
+ "content": "The code demonstrates how to load a pre-trained prior model for use in generating embeddings within CLIP's image space, enhancing the performance of the decoder. The usage section outlines the necessary steps to load a checkpoint from a specific path using `load_diffusion_model()`.",
+ "type": "comment"
+ },
+ "678": {
+ "file_id": 21,
+ "content": " ff_dropout=5e-2,\n num_time_embeds=1,\n num_image_embeds=1,\n num_text_embeds=1,\n num_timesteps=1000,\n ff_mult=4\n )\n diffusion_prior = DiffusionPrior(\n net=prior_network,\n clip=OpenAIClipAdapter(\"ViT-L/14\"),\n image_embed_dim=768,\n timesteps=1000,\n cond_drop_prob=0.1,\n loss_type=\"l2\",\n condition_on_text_encodings=True,\n )\n trainer = DiffusionPriorTrainer(\n diffusion_prior=diffusion_prior,\n lr=1.1e-4,\n wd=6.02e-2,\n max_grad_norm=0.5,\n amp=False,\n group_wd_params=True,\n use_ema=True,\n device=device,\n accelerator=None,\n )\n trainer.load(dprior_path)\n return trainer\n```\n Here we instantiate a model matches the configuration it was trained with, and then load the weights (*just like any other PyTorch model!*)\n### Sampling\nOnce we have a pre-trained model, generating embeddings is quite simple!\n```python\n# tokenize the text\ntokenized_text = clip.tokenize(\"\")",
+ "type": "code",
+ "location": "/prior.md:77-119"
+ },
+ "679": {
+ "file_id": 21,
+ "content": "Here, a pre-trained model is instantiated and its weights are loaded. This can be done just like any other PyTorch model. To generate embeddings from text, first tokenize the input text using `clip.tokenize()`.",
+ "type": "comment"
+ },
+ "680": {
+ "file_id": 21,
+ "content": "# predict an embedding\npredicted_embedding = prior.sample(tokenized_text, n_samples_per_batch=2, cond_scale=1.0)\n```\nThe resulting tensor returned from `.sample()` is of the same shape as your training data along the non-batch dimension(s). For example, a prior trained on `ViT-L/14` embeddings will predict an embedding of shape (1, 768).\n> For CLIP priors, this is quite handy as it means that you can use prior.sample(tokenizer_text) as a drop in replacement for clip.encode_text().\n**Some things to note:**\n* It is possible to specify the number of embeddings to sample from (the default suggested by OpenAI is `n=2`). Put simply, the idea here is that you avoid getting unlucky with a bad embedding generation by creating two; and selecting the one with the higher cosine similarity with the prompt.\n* You may specify a higher conditioning scale than the default (`1.0`). It is unclear whether OpenAI uses a higher value for the prior specifically, or only on the decoder. Local testing has shown poor results with anything higher than `1.0` but *ymmv*.",
+ "type": "code",
+ "location": "/prior.md:120-130"
+ },
+ "681": {
+ "file_id": 21,
+ "content": "The code snippet is predicting an embedding using the prior's sample function, which returns a tensor of the same shape as the training data. The number of embeddings to sample can be specified and conditioning scale can be adjusted for better results. It serves as a replacement for clip.encode_text() in CLIP priors.",
+ "type": "comment"
+ },
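+ "681a": {
+ "file_id": 21,
+ "content": "A hedged sketch (not part of prior.md) of the re-ranking idea behind sampling several candidates, assuming the `prior`, `clip_model`, and `tokenized_text` objects from the examples above; in practice prior.sample with n_samples_per_batch performs this selection internally:\n\nimport torch\nimport torch.nn.functional as F\n\nwith torch.no_grad():\n    text_embedding = clip_model.encode_text(tokenized_text)                   # (1, 768)\n    candidates = torch.cat([prior.sample(tokenized_text) for _ in range(2)])  # (2, 768)\n    sims = F.cosine_similarity(candidates, text_embedding.expand_as(candidates), dim=-1)\n    predicted_embedding = candidates[sims.argmax()].unsqueeze(0)              # keep the closer candidate",
+ "type": "comment"
+ },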
+ "682": {
+ "file_id": 21,
+ "content": "---\n## Training\n### Overview\nTraining the prior is a relatively straightforward process thanks to the Trainer base class. The major step that is required of you is preparing a dataset in the format that EmbeddingReader expects. Having pre-computed embeddings massively increases training efficiency and is generally recommended as you will likely benefit from having them on hand for other tasks as well. Once you have a dataset, you are ready to move onto configuration\n## Dataset\nTo train the prior, it is highly recommended to use precomputed embeddings for the images. To obtain these for a custom dataset, you can leverage [img2datset](https://github.com/rom1504/img2dataset) to pull images from a list of URLs and [clip_retrieval](https://github.com/rom1504/clip-retrieval#clip-inference) for generating the actual embeddings that can be used in the prior's dataloader.\n## Configuration\nThe configuration file allows for you to easily track and reproduce experiments. It is a simple JSON file that wil",
+ "type": "code",
+ "location": "/prior.md:132-146"
+ },
+ "683": {
+ "file_id": 21,
+ "content": "Training the prior involves preparing a dataset in the format expected by EmbeddingReader. Precomputed embeddings for images significantly increase training efficiency and are beneficial for other tasks as well. To obtain precomputed embeddings, you can use img2dataset and clip_retrieval. The configuration file enables tracking and reproducing experiments.",
+ "type": "comment"
+ },
+ "684": {
+ "file_id": 21,
+ "content": "l specify the architecture, dataset, and training parameters. For more information and specifics please see the configuration README.\n## Distributed Training\nIf you would like to train in a distributed manner we have opted to leverage huggingface’ new Accelerate library. HFA makes it extremely simple to distribute work across multiple GPU’s and nodes. All that is required of you is to follow the simple CLI configuration tool [more information here](https://huggingface.co/docs/accelerate/accelerator).\n## Evaluation\nThere are a variety of metrics available to you when training the prior. You can read a brief description of each in the table below:\n| Metric | Description | Comments ",
+ "type": "code",
+ "location": "/prior.md:146-155"
+ },
+ "685": {
+ "file_id": 21,
+ "content": "This code describes the architecture, dataset, and training parameters for a specific task. It also mentions distributed training using HuggingFace's Accelerate library and various evaluation metrics available during training.",
+ "type": "comment"
+ },
+ "686": {
+ "file_id": 21,
+ "content": " |\n| ----------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |\n| Online Model Validation | The validation loss associated ",
+ "type": "code",
+ "location": "/prior.md:155-157"
+ },
+ "687": {
+ "file_id": 21,
+ "content": "This code is for calculating the validation loss associated with the online model validation process. The calculated validation loss will be used to evaluate the performance of the trained model during inference.",
+ "type": "comment"
+ },
+ "688": {
+ "file_id": 21,
+ "content": "with your online model. | Ideally validation loss will be as low as possible. Using L2 loss, values as low as `0.1` and lower are possible after around 1 Billion samples seen. |\n| EMA Validation | This metric measures the validation loss associated with your EMA model. | This will likely lag behind your \"online\" model's validation loss, but should outperform in the long-term. ",
+ "type": "code",
+ "location": "/prior.md:157-158"
+ },
+ "689": {
+ "file_id": 21,
+ "content": "This code is discussing the usage of an Exponential Moving Average (EMA) model in a machine learning context. The EMA model's performance is compared to the online model, specifically focusing on validation loss as a metric. The lower the validation loss, the better the model's performance, with values around 0.1 achievable after billions of samples. However, the EMA validation loss might lag behind but should outperform in the long term.",
+ "type": "comment"
+ },
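+ "689a": {
+ "file_id": 21,
+ "content": "For reference, a minimal illustrative sketch (not the trainer's actual implementation) of the exponential-moving-average update that an EMA model applies after each optimizer step:\n\nimport copy\nimport torch\n\n# ema_model typically starts as a deep copy of the online model\n# ema_model = copy.deepcopy(online_model)\n\ndef ema_update(online_model, ema_model, decay=0.999):\n    # ema_weight <- decay * ema_weight + (1 - decay) * online_weight\n    with torch.no_grad():\n        for ema_p, online_p in zip(ema_model.parameters(), online_model.parameters()):\n            ema_p.mul_(decay).add_(online_p, alpha=1. - decay)",
+ "type": "comment"
+ },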
+ "690": {
+ "file_id": 21,
+ "content": " |\n| Baseline Similarity | Baseline similarity refers to the similarity between your dataset's prompts and associated image embeddings. This will serve as a guide for your prior's performance in cosine similarity. | Generally `0.3` is considered a good cosine similarity for caption similarity. |\n| Similarity With Original Image | This metric will measure the cosine similarity between your prior's predicted image embedding and the actual image",
+ "type": "code",
+ "location": "/prior.md:158-160"
+ },
+ "691": {
+ "file_id": 21,
+ "content": "This code snippet is explaining the concept of baseline similarity in the context of DALLE2-pytorch, where it refers to the similarity between dataset prompts and image embeddings. It also mentions that generally, a cosine similarity value of 0.3 is considered good for caption similarity. Additionally, there's information about another metric - similarity with original image, which measures cosine similarity between the prior's predicted image embedding and the actual image.",
+ "type": "comment"
+ },
+ "692": {
+ "file_id": 21,
+ "content": " that the caption was associated with. This is useful for determining wether your prior is generating images with the right contents. | Values around `0.75`+ are obtainable. This metric should improve rapidly in the early stages of training and plateau with diminishing increases over time. If it takes hundreds of millions of samples to reach above `0.5`/`0.6` similarity--then you likely are suffering from some kind of training error or inefficiency (i.e. not using EMA) |\n| Difference From Baseline Similarity | Sometimes its useful to visualize a metric in another light. This metric will show you how your prior's predicted image embeddings match up with the baseline similarity measured in your dataset. | This value should float around `0.0` with some room for variation. After a billion samples seen, values are within `0.01`+/- of `0.0`. If this climbs to high, (~>`0.02`) then this may be a sign that your model is overfitting ",
+ "type": "code",
+ "location": "/prior.md:160-161"
+ },
+ "693": {
+ "file_id": 21,
+ "content": "The code provides information about the similarity metric between generated images and captions, as well as the difference from baseline similarity. The values should improve rapidly in early stages of training and plateau over time, while staying around 0 for the difference metric. Values above 0.5/0.6 or climbing to high values may indicate issues with training efficiency or overfitting, respectively.",
+ "type": "comment"
+ },
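+ "693a": {
+ "file_id": 21,
+ "content": "A small illustrative sketch (hypothetical tensors, not the repository's evaluation code) of how the baseline similarity and the difference-from-baseline metric described above can be computed:\n\nimport torch.nn.functional as F\n\n# text_embeds, image_embeds, predicted_image_embeds: hypothetical (batch, 768) CLIP embeddings\nbaseline = F.cosine_similarity(text_embeds, image_embeds, dim=-1).mean()\npredicted = F.cosine_similarity(predicted_image_embeds, image_embeds, dim=-1).mean()\ndifference_from_baseline = predicted - baseline  # should hover near 0.0; drifting past ~0.02 may signal overfitting",
+ "type": "comment"
+ },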
+ "694": {
+ "file_id": 21,
+ "content": "somehow. |\n| Similarity With Text | This metric is your bread and butter cosine similarity between the predicted image embedding and the original caption given to the prior. Monitoring this metric will be on of your main focuses and is probably the second most important behind your loss. | As mentioned, this value should be close to baseline similarity. We have observed early rapid increase with diminishing returns as the prior learns to generate valid image embeddings. If this value increases too far beyond the baseline similarity--it could be an indication that your model is overfitting. |\n| Similarity With Unrelated Caption | This metric will attempt to exposed an overfit prior by feeding it arbitrary prompts (from your dataset) and then measure the similarity of this predicted embedding with some other image. ",
+ "type": "code",
+ "location": "/prior.md:161-163"
+ },
+ "695": {
+ "file_id": 21,
+ "content": "The code measures the cosine similarity between predicted image embeddings and original captions, as well as with unrelated captions to detect overfitting. Monitoring these metrics is crucial for model performance, as they indicate how well the model is learning from captions and generating valid image embeddings.",
+ "type": "comment"
+ },
+ "696": {
+ "file_id": 21,
+ "content": " | Early on we found that a poorly trained/modeled prior could effectively fool CLIP into believing that the cosine similarity between two images were high (when in fact the caption and image were completely unrelated). With this in mind--a low value is ideal, anything below `0.1` is probably safe. |\n## Launching the script\nNow that you’ve done all the prep it’s time for the easy part! 🚀\nTo actually launch the script, you will either use `accelerate launch train_diffusion_prior.py --config_path ` to launch with distributed training & huggingface accelerate or `python train_diffusion_prior.py` if you would like to train on your gpu/cpu without huggingface accelerate.\n## Checkpointing\nCheckpoints will be saved to the directory specified in your configuration file.\nAdditionally, a final checkpoint is saved before running the test split. This file will be saved to the same directory and",
+ "type": "code",
+ "location": "/prior.md:163-175"
+ },
+ "697": {
+ "file_id": 21,
+ "content": "The code provides instructions on how to launch the training script for a diffusion model using either distributed training with HuggingFace Accelerate or without it. It also mentions that checkpoints will be saved in the directory specified in the configuration file, and an additional final checkpoint will be saved before running the test split. The prior value should ideally be kept low to avoid fooling CLIP into believing unrelated captions and images have high cosine similarity.",
+ "type": "comment"
+ },
+ "698": {
+ "file_id": 21,
+ "content": " titled “latest.pth”. This is to avoid problems where your `save_every` configuration does not overlap with the number of steps required to do a complete pass through the data.\n## Things To Keep In Mind\nThe prior has not been trained for tasks other than the traditional CLIP embedding translation…at least yet.\nAs we finalize the replication of unCLIP, there will almost assuredly be experiments attempting to apply the prior network to other tasks.\nWith that in mind, you are more or less a pioneer in embedding-translation if you are reading this and attempting something you don’t see documentation for!",
+ "type": "code",
+ "location": "/prior.md:175-183"
+ },
+ "699": {
+ "file_id": 21,
+ "content": "This code snippet is providing information about the \"latest.pth\" file and its purpose to avoid potential problems with `save_every` configuration not overlapping with data requirements. It also mentions that the prior network has not been trained for tasks other than traditional CLIP embedding translation, hinting at future experiments applying the prior network to other tasks.",
+ "type": "comment"
+ }
+}
\ No newline at end of file
diff --git a/docs/data/7.json b/docs/data/7.json
new file mode 100644
index 00000000..7c5887dc
--- /dev/null
+++ b/docs/data/7.json
@@ -0,0 +1,549 @@
+{
+ "700": {
+ "file_id": 22,
+ "content": "/setup.py",
+ "type": "filepath"
+ },
+ "701": {
+ "file_id": 22,
+ "content": "This code is a setup script for the dalle2-pytorch package using setuptools, defining project details and dependencies like PyTorch, Torchvision, and more. It's a Python project with beta development status, targeting developers in AI field, licensed under MIT, requires Python 3.6.",
+ "type": "summary"
+ },
+ "702": {
+ "file_id": 22,
+ "content": "from setuptools import setup, find_packages\nexec(open('dalle2_pytorch/version.py').read())\nsetup(\n name = 'dalle2-pytorch',\n packages = find_packages(exclude=[]),\n include_package_data = True,\n entry_points={\n 'console_scripts': [\n 'dalle2_pytorch = dalle2_pytorch.cli:main',\n 'dream = dalle2_pytorch.cli:dream'\n ],\n },\n version = __version__,\n license='MIT',\n description = 'DALL-E 2',\n author = 'Phil Wang',\n author_email = 'lucidrains@gmail.com',\n long_description_content_type = 'text/markdown',\n url = 'https://github.com/lucidrains/dalle2-pytorch',\n keywords = [\n 'artificial intelligence',\n 'deep learning',\n 'text to image'\n ],\n install_requires=[\n 'accelerate',\n 'click',\n 'open-clip-torch>=2.0.0,<3.0.0',\n 'clip-anytorch>=2.5.2',\n 'coca-pytorch>=0.0.5',\n 'ema-pytorch>=0.0.7',\n 'einops>=0.7.0',\n 'embedding-reader',\n 'kornia>=0.5.4',\n 'numpy',\n 'packaging',\n 'pillow',\n 'pydantic>=2',\n 'pytorch-warmup',\n 'resize-right>=0.0.2',\n 'rotary-embedding-torch',",
+ "type": "code",
+ "location": "/setup.py:1-42"
+ },
+ "703": {
+ "file_id": 22,
+ "content": "This code is a setup script for the dalle2-pytorch package using setuptools. It defines the name, packages, entry points, version, license, description, author, URL, keywords, and install_requires. The script imports necessary modules and sets up dependencies for installation.",
+ "type": "comment"
+ },
+ "704": {
+ "file_id": 22,
+ "content": " 'torch>=1.10',\n 'torchvision',\n 'tqdm',\n 'vector-quantize-pytorch',\n 'x-clip>=0.4.4',\n 'webdataset>=0.2.5',\n 'fsspec>=2022.1.0',\n 'torchmetrics[image]>=0.8.0'\n ],\n classifiers=[\n 'Development Status :: 4 - Beta',\n 'Intended Audience :: Developers',\n 'Topic :: Scientific/Engineering :: Artificial Intelligence',\n 'License :: OSI Approved :: MIT License',\n 'Programming Language :: Python :: 3.6',\n ],\n)",
+ "type": "code",
+ "location": "/setup.py:43-59"
+ },
+ "705": {
+ "file_id": 22,
+ "content": "This is a Python project setup file, using setuptools. It depends on PyTorch >= 1.10, Torchvision, Tqdm, VectorQuantizePytorch, X-Clip >= 0.4.4, Webdataset >= 0.2.5, FSSpec >= 2022.1.0, and TorchMetrics[image] >= 0.8.0. The project has a beta development status, is intended for developers, relates to artificial intelligence, is licensed under MIT, and requires Python 3.6.",
+ "type": "comment"
+ },
+ "706": {
+ "file_id": 23,
+ "content": "/train_decoder.py",
+ "type": "filepath"
+ },
+ "707": {
+ "file_id": 23,
+ "content": "This code divides shards, initializes training, and trains UNet models for DALL-E 2 using PyTorch. It also supports distributed training and executes as a standalone program.",
+ "type": "summary"
+ },
+ "708": {
+ "file_id": 23,
+ "content": "from pathlib import Path\nfrom typing import List\nfrom datetime import timedelta\nfrom dalle2_pytorch.trainer import DecoderTrainer\nfrom dalle2_pytorch.dataloaders import create_image_embedding_dataloader\nfrom dalle2_pytorch.trackers import Tracker\nfrom dalle2_pytorch.train_configs import DecoderConfig, TrainDecoderConfig\nfrom dalle2_pytorch.utils import Timer, print_ribbon\nfrom dalle2_pytorch.dalle2_pytorch import Decoder, resize_image_to\nfrom clip import tokenize\nimport torchvision\nimport torch\nfrom torch import nn\nfrom torchmetrics.image.fid import FrechetInceptionDistance\nfrom torchmetrics.image.inception import InceptionScore\nfrom torchmetrics.image.kid import KernelInceptionDistance\nfrom torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity\nfrom accelerate import Accelerator, DistributedDataParallelKwargs, InitProcessGroupKwargs\nfrom accelerate.utils import dataclasses as accelerate_dataclasses\nimport webdataset as wds\nimport click\n# constants\nTRAIN_CALC_LOSS_EVERY_ITERS = 10\nVALID_CALC_LOSS_EVERY_ITERS = 10",
+ "type": "code",
+ "location": "/train_decoder.py:1-28"
+ },
+ "709": {
+ "file_id": 23,
+ "content": "This code imports various modules and defines constants for training a decoder model in the DALLE2-pytorch framework. It uses DecoderTrainer, dataloaders, trackers, train configs, utilities, and models from the dalle2_pytorch package. It also includes metrics such as FrechetInceptionDistance, InceptionScore, KernelInceptionDistance, and LearnedPerceptualImagePatchSimilarity for evaluation. Accelerate is used for accelerated training, and webdataset is used for data loading.",
+ "type": "comment"
+ },
+ "710": {
+ "file_id": 23,
+ "content": "# helpers functions\ndef exists(val):\n return val is not None\n# main functions\ndef create_dataloaders(\n available_shards,\n webdataset_base_url,\n img_embeddings_url=None,\n text_embeddings_url=None,\n shard_width=6,\n num_workers=4,\n batch_size=32,\n n_sample_images=6,\n shuffle_train=True,\n resample_train=False,\n img_preproc = None,\n index_width=4,\n train_prop = 0.75,\n val_prop = 0.15,\n test_prop = 0.10,\n seed = 0,\n **kwargs\n):\n \"\"\"\n Randomly splits the available shards into train, val, and test sets and returns a dataloader for each\n \"\"\"\n assert train_prop + test_prop + val_prop == 1\n num_train = round(train_prop*len(available_shards))\n num_test = round(test_prop*len(available_shards))\n num_val = len(available_shards) - num_train - num_test\n assert num_train + num_test + num_val == len(available_shards), f\"{num_train} + {num_test} + {num_val} = {num_train + num_test + num_val} != {len(available_shards)}\"\n train_split, test_split, val_split =",
+ "type": "code",
+ "location": "/train_decoder.py:30-64"
+ },
+ "711": {
+ "file_id": 23,
+ "content": "This function takes available shards, URLs for embeddings, and other parameters to randomly split them into train, validation, and test sets, then returns dataloaders for each. It asserts that the proportions of splits sum up to 1, calculates the actual number of samples in each split based on the proportion, and checks if the sum of splits matches the total number of available shards.",
+ "type": "comment"
+ },
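+ "711a": {
+ "file_id": 23,
+ "content": "A hedged worked example of the proportional split described above (illustrative shard count only):\n\n# e.g. 100 available shards with train/val/test proportions of 0.75 / 0.15 / 0.10\navailable_shards = list(range(100))\ntrain_prop, val_prop, test_prop = 0.75, 0.15, 0.10\n\nnum_train = round(train_prop * len(available_shards))    # 75\nnum_test = round(test_prop * len(available_shards))      # 10\nnum_val = len(available_shards) - num_train - num_test   # 15, the remainder, so the three splits always sum to 100",
+ "type": "comment"
+ },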
+ "712": {
+ "file_id": 23,
+ "content": " torch.utils.data.random_split(available_shards, [num_train, num_test, num_val], generator=torch.Generator().manual_seed(seed))\n # The shard number in the webdataset file names has a fixed width. We zero pad the shard numbers so they correspond to a filename.\n train_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in train_split]\n test_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in test_split]\n val_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in val_split]\n create_dataloader = lambda tar_urls, shuffle=False, resample=False, for_sampling=False: create_image_embedding_dataloader(\n tar_url=tar_urls,\n num_workers=num_workers,\n batch_size=batch_size if not for_sampling else n_sample_images,\n img_embeddings_url=img_embeddings_url,\n text_embeddings_url=text_embeddings_url,\n index_width=index_width,\n shuffle_num = None,\n extra_keys= [\"txt\"],\n shuffle_shards = shuffle,",
+ "type": "code",
+ "location": "/train_decoder.py:64-80"
+ },
+ "713": {
+ "file_id": 23,
+ "content": "This code randomly splits available shards into training, testing, and validation sets. It then generates corresponding URLs for each set by zero-padding the shard numbers to match the filename format. A lambda function is created to handle creating a dataloader for image embeddings using these URLs, considering various parameters like batch size and number of workers.",
+ "type": "comment"
+ },
+ "714": {
+ "file_id": 23,
+ "content": " resample_shards = resample, \n img_preproc=img_preproc,\n handler=wds.handlers.warn_and_continue\n )\n train_dataloader = create_dataloader(train_urls, shuffle=shuffle_train, resample=resample_train)\n train_sampling_dataloader = create_dataloader(train_urls, shuffle=False, for_sampling=True)\n val_dataloader = create_dataloader(val_urls, shuffle=False)\n test_dataloader = create_dataloader(test_urls, shuffle=False)\n test_sampling_dataloader = create_dataloader(test_urls, shuffle=False, for_sampling=True)\n return {\n \"train\": train_dataloader,\n \"train_sampling\": train_sampling_dataloader,\n \"val\": val_dataloader,\n \"test\": test_dataloader,\n \"test_sampling\": test_sampling_dataloader\n }\ndef get_dataset_keys(dataloader):\n \"\"\"\n It is sometimes neccesary to get the keys the dataloader is returning. Since the dataset is burried in the dataloader, we need to do a process to recover it.\n \"\"\"\n # If the dataloader is actually a WebLoader, we need to extract the real dataloader",
+ "type": "code",
+ "location": "/train_decoder.py:81-103"
+ },
+ "715": {
+ "file_id": 23,
+ "content": "The code creates multiple data loaders for training, validation, and testing datasets. It returns a dictionary with each dataset's corresponding dataloader. The `get_dataset_keys` function extracts the real dataloader if the input is a WebLoader.",
+ "type": "comment"
+ },
+ "716": {
+ "file_id": 23,
+ "content": " if isinstance(dataloader, wds.WebLoader):\n dataloader = dataloader.pipeline[0]\n return dataloader.dataset.key_map\ndef get_example_data(dataloader, device, n=5):\n \"\"\"\n Samples the dataloader and returns a zipped list of examples\n \"\"\"\n images = []\n img_embeddings = []\n text_embeddings = []\n captions = []\n for img, emb, txt in dataloader:\n img_emb, text_emb = emb.get('img'), emb.get('text')\n if img_emb is not None:\n img_emb = img_emb.to(device=device, dtype=torch.float)\n img_embeddings.extend(list(img_emb))\n else:\n # Then we add None img.shape[0] times\n img_embeddings.extend([None]*img.shape[0])\n if text_emb is not None:\n text_emb = text_emb.to(device=device, dtype=torch.float)\n text_embeddings.extend(list(text_emb))\n else:\n # Then we add None img.shape[0] times\n text_embeddings.extend([None]*img.shape[0])\n img = img.to(device=device, dtype=torch.float)",
+ "type": "code",
+ "location": "/train_decoder.py:104-130"
+ },
+ "717": {
+ "file_id": 23,
+ "content": "The code samples the dataloader and returns a zipped list of examples. It iterates through each image, extracts its embedding, converts it to the device's format, extends the respective lists for images and text embeddings, and finally returns them.",
+ "type": "comment"
+ },
+ "718": {
+ "file_id": 23,
+ "content": " images.extend(list(img))\n captions.extend(list(txt))\n if len(images) >= n:\n break\n return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))\ndef generate_samples(trainer, example_data, clip=None, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=\"\", match_image_size=True):\n \"\"\"\n Takes example data and generates images from the embeddings\n Returns three lists: real images, generated images, and captions\n \"\"\"\n real_images, img_embeddings, text_embeddings, txts = zip(*example_data)\n sample_params = {}\n if img_embeddings[0] is None:\n # Generate image embeddings from clip\n imgs_tensor = torch.stack(real_images)\n assert clip is not None, \"clip is None, but img_embeddings is None\"\n imgs_tensor.to(device=device)\n img_embeddings, img_encoding = clip.embed_image(imgs_tensor)\n sample_params[\"image_embed\"] = img_embeddings\n else:\n # Then we are using precomputed image embeddings",
+ "type": "code",
+ "location": "/train_decoder.py:131-152"
+ },
+ "719": {
+ "file_id": 23,
+ "content": "This function generates samples by taking example data and creating real images, generated images, and captions. If image embeddings are None, it generates them using the clip model. It returns three lists: real images, generated images, and captions.",
+ "type": "comment"
+ },
+ "720": {
+ "file_id": 23,
+ "content": " img_embeddings = torch.stack(img_embeddings)\n sample_params[\"image_embed\"] = img_embeddings\n if condition_on_text_encodings:\n if text_embeddings[0] is None:\n # Generate text embeddings from text\n assert clip is not None, \"clip is None, but text_embeddings is None\"\n tokenized_texts = tokenize(txts, truncate=True).to(device=device)\n text_embed, text_encodings = clip.embed_text(tokenized_texts)\n sample_params[\"text_encodings\"] = text_encodings\n else:\n # Then we are using precomputed text embeddings\n text_embeddings = torch.stack(text_embeddings)\n sample_params[\"text_encodings\"] = text_embeddings\n sample_params[\"start_at_unet_number\"] = start_unet\n sample_params[\"stop_at_unet_number\"] = end_unet\n if start_unet > 1:\n # If we are only training upsamplers\n sample_params[\"image\"] = torch.stack(real_images)\n if device is not None:\n sample_params[\"_device\"] = device",
+ "type": "code",
+ "location": "/train_decoder.py:153-172"
+ },
+ "721": {
+ "file_id": 23,
+ "content": "This code is responsible for preparing training samples by stacking image and text embeddings, setting parameters for start and stop U-net layers, and handling the case where real images are provided. If real images exist, it stacks them as part of the sample. The code also considers whether to generate text embeddings or use precomputed ones and ensures everything is on the specified device.",
+ "type": "comment"
+ },
+ "722": {
+ "file_id": 23,
+ "content": " samples = trainer.sample(**sample_params, _cast_deepspeed_precision=False) # At sampling time we don't want to cast to FP16\n generated_images = list(samples)\n captions = [text_prepend + txt for txt in txts]\n if match_image_size:\n generated_image_size = generated_images[0].shape[-1]\n real_images = [resize_image_to(image, generated_image_size, clamp_range=(0, 1)) for image in real_images]\n return real_images, generated_images, captions\ndef generate_grid_samples(trainer, examples, clip=None, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=\"\"):\n \"\"\"\n Generates samples and uses torchvision to put them in a side by side grid for easy viewing\n \"\"\"\n real_images, generated_images, captions = generate_samples(trainer, examples, clip, start_unet, end_unet, condition_on_text_encodings, cond_scale, device, text_prepend)\n grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]",
+ "type": "code",
+ "location": "/train_decoder.py:173-186"
+ },
+ "723": {
+ "file_id": 23,
+ "content": "This function generates samples, combines them with real images in a grid format for easy viewing. It first calls `generate_samples` to get the real and generated images along with their corresponding captions. Then it uses `torchvision.utils.make_grid` to create grids of original and generated images.",
+ "type": "comment"
+ },
+ "724": {
+ "file_id": 23,
+ "content": " return grid_images, captions\ndef evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, clip=None, condition_on_text_encodings=False, cond_scale=1.0, inference_device=None, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):\n \"\"\"\n Computes evaluation metrics for the decoder\n \"\"\"\n metrics = {}\n # Prepare the data\n examples = get_example_data(dataloader, device, n_evaluation_samples)\n if len(examples) == 0:\n print(\"No data to evaluate. Check that your dataloader has shards.\")\n return metrics\n real_images, generated_images, captions = generate_samples(trainer, examples, clip, start_unet, end_unet, condition_on_text_encodings, cond_scale, inference_device)\n real_images = torch.stack(real_images).to(device=device, dtype=torch.float)\n generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)\n # Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8\n int_real_images = real_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)",
+ "type": "code",
+ "location": "/train_decoder.py:187-203"
+ },
+ "725": {
+ "file_id": 23,
+ "content": "This function computes evaluation metrics for a decoder. It prepares data, generates samples using the trainer and start/end unets, converts images from [0, 1] to [0, 255], and types them as uint8. The generated and real images are then stored in variables for further evaluation metrics calculations.",
+ "type": "comment"
+ },
+ "726": {
+ "file_id": 23,
+ "content": " int_generated_images = generated_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)\n def null_sync(t, *args, **kwargs):\n return [t]\n if exists(FID):\n fid = FrechetInceptionDistance(**FID, dist_sync_fn=null_sync)\n fid.to(device=device)\n fid.update(int_real_images, real=True)\n fid.update(int_generated_images, real=False)\n metrics[\"FID\"] = fid.compute().item()\n if exists(IS):\n inception = InceptionScore(**IS, dist_sync_fn=null_sync)\n inception.to(device=device)\n inception.update(int_real_images)\n is_mean, is_std = inception.compute()\n metrics[\"IS_mean\"] = is_mean.item()\n metrics[\"IS_std\"] = is_std.item()\n if exists(KID):\n kernel_inception = KernelInceptionDistance(**KID, dist_sync_fn=null_sync)\n kernel_inception.to(device=device)\n kernel_inception.update(int_real_images, real=True)\n kernel_inception.update(int_generated_images, real=False)\n kid_mean, kid_std = kernel_inception.compute()",
+ "type": "code",
+ "location": "/train_decoder.py:204-227"
+ },
+ "727": {
+ "file_id": 23,
+ "content": "This code calculates and stores metrics for the quality of generated images, including Frechet Inception Distance (FID), Inception Score (IS), and Kernel Inception Distance (KID). It first scales the generated images, then checks if specific configuration files exist for each metric. If they do, it creates an instance of the corresponding metric class, sets it up on the device, updates with real and generated images, and computes the metric values. The computed metrics are stored in the \"metrics\" dictionary.",
+ "type": "comment"
+ },
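+ "727a": {
+ "file_id": 23,
+ "content": "A minimal, self-contained sketch of the torchmetrics FID usage pattern referenced above (random uint8 tensors stand in for real data; not the training script itself):\n\nimport torch\nfrom torchmetrics.image.fid import FrechetInceptionDistance\n\nfid = FrechetInceptionDistance(feature=64)  # small Inception feature layer keeps the example quick\nreal = torch.randint(0, 256, (8, 3, 64, 64), dtype=torch.uint8)\nfake = torch.randint(0, 256, (8, 3, 64, 64), dtype=torch.uint8)\nfid.update(real, real=True)\nfid.update(fake, real=False)\nprint(fid.compute())",
+ "type": "comment"
+ },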
+ "728": {
+ "file_id": 23,
+ "content": " metrics[\"KID_mean\"] = kid_mean.item()\n metrics[\"KID_std\"] = kid_std.item()\n if exists(LPIPS):\n # Convert from [0, 1] to [-1, 1]\n renorm_real_images = real_images.mul(2).sub(1).clamp(-1,1)\n renorm_generated_images = generated_images.mul(2).sub(1).clamp(-1,1)\n lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS, dist_sync_fn=null_sync)\n lpips.to(device=device)\n lpips.update(renorm_real_images, renorm_generated_images)\n metrics[\"LPIPS\"] = lpips.compute().item()\n if trainer.accelerator.num_processes > 1:\n # Then we should sync the metrics\n metrics_order = sorted(metrics.keys())\n metrics_tensor = torch.zeros(1, len(metrics), device=device, dtype=torch.float)\n for i, metric_name in enumerate(metrics_order):\n metrics_tensor[0, i] = metrics[metric_name]\n metrics_tensor = trainer.accelerator.gather(metrics_tensor)\n metrics_tensor = metrics_tensor.mean(dim=0)\n for i, metric_name in enumerate(metrics_order):",
+ "type": "code",
+ "location": "/train_decoder.py:228-247"
+ },
+ "729": {
+ "file_id": 23,
+ "content": "This code calculates metrics such as KID and LPIPS for a model's performance. It stores the values in a dictionary, normalizes the images if LPIPS is present, applies the LearnedPerceptualImagePatchSimilarity function, and syncs the calculated metrics across processes using accelerator functions.",
+ "type": "comment"
+ },
+ "730": {
+ "file_id": 23,
+ "content": " metrics[metric_name] = metrics_tensor[i].item()\n return metrics\ndef save_trainer(tracker: Tracker, trainer: DecoderTrainer, epoch: int, sample: int, next_task: str, validation_losses: List[float], samples_seen: int, is_latest=True, is_best=False):\n \"\"\"\n Logs the model with an appropriate method depending on the tracker\n \"\"\"\n tracker.save(trainer, is_best=is_best, is_latest=is_latest, epoch=epoch, sample=sample, next_task=next_task, validation_losses=validation_losses, samples_seen=samples_seen)\ndef recall_trainer(tracker: Tracker, trainer: DecoderTrainer):\n \"\"\"\n Loads the model with an appropriate method depending on the tracker\n \"\"\"\n trainer.accelerator.print(print_ribbon(f\"Loading model from {type(tracker.loader).__name__}\"))\n state_dict = tracker.recall()\n trainer.load_state_dict(state_dict, only_model=False, strict=True)\n return state_dict.get(\"epoch\", 0), state_dict.get(\"validation_losses\", []), state_dict.get(\"next_task\", \"train\"), state_dict.get(\"sample\", 0), state_dict.get(\"samples_seen\", 0)",
+ "type": "code",
+ "location": "/train_decoder.py:248-264"
+ },
+ "731": {
+ "file_id": 23,
+ "content": "This code contains three functions: 1) `train_decoder`, which updates metrics based on the current metric; 2) `save_trainer`, which logs the model using an appropriate method according to the tracker; and 3) `recall_trainer`, which loads the model using the tracker. The code is part of a larger system that likely involves training a machine learning model, tracking its progress, and recalling it for further use or evaluation.",
+ "type": "comment"
+ },
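+ "731a": {
+ "file_id": 23,
+ "content": "For context, a generic PyTorch checkpointing sketch of the kind the tracker wraps (hypothetical `model`, paths, and bookkeeping fields):\n\nimport torch\n\n# saving: bundle the weights with training progress\ntorch.save({\n    'model': model.state_dict(),\n    'epoch': epoch,\n    'samples_seen': samples_seen,\n}, 'checkpoints/latest.pth')\n\n# recalling: restore the weights and resume the bookkeeping\nstate = torch.load('checkpoints/latest.pth', map_location='cpu')\nmodel.load_state_dict(state['model'])\nepoch = state.get('epoch', 0)\nsamples_seen = state.get('samples_seen', 0)",
+ "type": "comment"
+ },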
+ "732": {
+ "file_id": 23,
+ "content": "def train(\n dataloaders,\n decoder: Decoder,\n accelerator: Accelerator,\n tracker: Tracker,\n inference_device,\n clip=None,\n evaluate_config=None,\n epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch\n validation_samples = None,\n save_immediately=False,\n epochs = 20,\n n_sample_images = 5,\n save_every_n_samples = 100000,\n unet_training_mask=None,\n condition_on_text_encodings=False,\n cond_scale=1.0,\n **kwargs\n):\n \"\"\"\n Trains a decoder on a dataset.\n \"\"\"\n is_master = accelerator.process_index == 0\n if not exists(unet_training_mask):\n # Then the unet mask should be true for all unets in the decoder\n unet_training_mask = [True] * len(decoder.unets)\n assert len(unet_training_mask) == len(decoder.unets), f\"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}\"\n trainable_unet_numbers = [i+1 for i, trainable in enumerate(unet_training_mask) if trainable]",
+ "type": "code",
+ "location": "/train_decoder.py:266-294"
+ },
+ "733": {
+ "file_id": 23,
+ "content": "The function trains a decoder on a dataset, using the specified dataloaders, Decoder instance, and Accelerator. It also has optional arguments for clip, evaluate_config, epoch_samples, validation_samples, save_immediately, epochs, n_sample_images, save_every_n_samples, unet_training_mask, condition_on_text_encodings, and cond_scale. The function checks if the unet_training_mask exists and asserts that its length matches the number of unets in the decoder. It also assigns trainable unet numbers to a list.",
+ "type": "comment"
+ },
+ "734": {
+ "file_id": 23,
+ "content": " first_trainable_unet = trainable_unet_numbers[0]\n last_trainable_unet = trainable_unet_numbers[-1]\n def move_unets(unet_training_mask):\n for i in range(len(decoder.unets)):\n if not unet_training_mask[i]:\n # Replace the unet from the module list with a nn.Identity(). This training script never uses unets that aren't being trained so this is fine.\n decoder.unets[i] = nn.Identity().to(inference_device)\n # Remove non-trainable unets\n move_unets(unet_training_mask)\n trainer = DecoderTrainer(\n decoder=decoder,\n accelerator=accelerator,\n dataloaders=dataloaders,\n **kwargs\n )\n # Set up starting model and parameters based on a recalled state dict\n start_epoch = 0\n validation_losses = []\n next_task = 'train'\n sample = 0\n samples_seen = 0\n val_sample = 0\n step = lambda: int(trainer.num_steps_taken(unet_number=first_trainable_unet))\n if tracker.can_recall:\n start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)",
+ "type": "code",
+ "location": "/train_decoder.py:295-322"
+ },
+ "735": {
+ "file_id": 23,
+ "content": "The code is removing non-trainable UNet modules and setting up a trainer for the given task. It also checks if the state can be recalled from a previous training session and updates relevant variables accordingly.",
+ "type": "comment"
+ },
+ "736": {
+ "file_id": 23,
+ "content": " if next_task == 'train':\n sample = recalled_sample\n if next_task == 'val':\n val_sample = recalled_sample\n accelerator.print(f\"Loaded model from {type(tracker.loader).__name__} on epoch {start_epoch} having seen {samples_seen} samples with minimum validation loss {min(validation_losses) if len(validation_losses) > 0 else 'N/A'}\")\n accelerator.print(f\"Starting training from task {next_task} at sample {sample} and validation sample {val_sample}\")\n trainer.to(device=inference_device)\n accelerator.print(print_ribbon(\"Generating Example Data\", repeat=40))\n accelerator.print(\"This can take a while to load the shard lists...\")\n if is_master:\n train_example_data = get_example_data(dataloaders[\"train_sampling\"], inference_device, n_sample_images)\n accelerator.print(\"Generated training examples\")\n test_example_data = get_example_data(dataloaders[\"test_sampling\"], inference_device, n_sample_images)\n accelerator.print(\"Generated testing examples\")",
+ "type": "code",
+ "location": "/train_decoder.py:323-337"
+ },
+ "737": {
+ "file_id": 23,
+ "content": "The code loads a model and starts training from the specified task, either 'train' or 'val'. It prints the details of the loaded model, including epoch, samples seen, and minimum validation loss. The trainer is moved to the inference device. Example data for both training and testing is generated using get_example_data function with the specified number of sample images.",
+ "type": "comment"
+ },
+ "738": {
+ "file_id": 23,
+ "content": " send_to_device = lambda arr: [x.to(device=inference_device, dtype=torch.float) for x in arr]\n sample_length_tensor = torch.zeros(1, dtype=torch.int, device=inference_device)\n unet_losses_tensor = torch.zeros(TRAIN_CALC_LOSS_EVERY_ITERS, trainer.num_unets, dtype=torch.float, device=inference_device)\n for epoch in range(start_epoch, epochs):\n accelerator.print(print_ribbon(f\"Starting epoch {epoch}\", repeat=40))\n timer = Timer()\n last_sample = sample\n last_snapshot = sample\n if next_task == 'train':\n for i, (img, emb, txt) in enumerate(dataloaders[\"train\"]):\n # We want to count the total number of samples across all processes\n sample_length_tensor[0] = len(img)\n all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.\n total_samples = all_samples.sum().item()\n sample += total_samples\n samples_seen += total_samples",
+ "type": "code",
+ "location": "/train_decoder.py:339-357"
+ },
+ "739": {
+ "file_id": 23,
+ "content": "Iterating over epochs in training mode, counting the total number of samples across all processes. Gathering sample length tensors using accelerator's gather function and summing them up to get the total samples seen. Updating sample and samples_seen variables accordingly.",
+ "type": "comment"
+ },
+ "740": {
+ "file_id": 23,
+ "content": " img_emb = emb.get('img')\n has_img_embedding = img_emb is not None\n if has_img_embedding:\n img_emb, = send_to_device((img_emb,))\n text_emb = emb.get('text')\n has_text_embedding = text_emb is not None\n if has_text_embedding:\n text_emb, = send_to_device((text_emb,))\n img, = send_to_device((img,))\n trainer.train()\n for unet in range(1, trainer.num_unets+1):\n # Check if this is a unet we are training\n if not unet_training_mask[unet-1]: # Unet index is the unet number - 1\n continue\n forward_params = {}\n if has_img_embedding:\n forward_params['image_embed'] = img_emb\n else:\n # Forward pass automatically generates embedding\n assert clip is not None\n img_embed, img_encoding = clip.embed_image(img)",
+ "type": "code",
+ "location": "/train_decoder.py:358-380"
+ },
+ "741": {
+ "file_id": 23,
+ "content": "This code checks if there are image or text embeddings available, sends them to the device, and then trains a model. It also performs a forward pass for image embedding generation if necessary.",
+ "type": "comment"
+ },
+ "742": {
+ "file_id": 23,
+ "content": " forward_params['image_embed'] = img_embed\n if condition_on_text_encodings:\n if has_text_embedding:\n forward_params['text_encodings'] = text_emb\n else:\n # Then we need to pass the text instead\n assert clip is not None\n tokenized_texts = tokenize(txt, truncate=True).to(inference_device)\n assert tokenized_texts.shape[0] == len(img), f\"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})\"\n text_embed, text_encodings = clip.embed_text(tokenized_texts)\n forward_params['text_encodings'] = text_encodings\n loss = trainer.forward(img, **forward_params, unet_number=unet, _device=inference_device)\n trainer.update(unet_number=unet)\n unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss",
+ "type": "code",
+ "location": "/train_decoder.py:381-394"
+ },
+ "743": {
+ "file_id": 23,
+ "content": "This code chunk is for training the DALL-E 2 model's decoder. It first checks if image and text embeddings are provided, and if not, it tokenizes the text and generates text embeddings using the CLIP model. Then, it passes the required parameters to the trainer and updates the model, storing the loss for each unit in the unet_losses_tensor array.",
+ "type": "comment"
+ },
+ "744": {
+ "file_id": 23,
+ "content": " samples_per_sec = (sample - last_sample) / timer.elapsed()\n timer.reset()\n last_sample = sample\n if i % TRAIN_CALC_LOSS_EVERY_ITERS == 0:\n # We want to average losses across all processes\n unet_all_losses = accelerator.gather(unet_losses_tensor)\n mask = unet_all_losses != 0\n unet_average_loss = (unet_all_losses * mask).sum(dim=0) / mask.sum(dim=0)\n loss_map = { f\"Unet {index} Training Loss\": loss.item() for index, loss in enumerate(unet_average_loss) if unet_training_mask[index] }\n # gather decay rate on each UNet\n ema_decay_list = {f\"Unet {index} EMA Decay\": ema_unet.get_current_decay() for index, ema_unet in enumerate(trainer.ema_unets) if unet_training_mask[index]}\n log_data = {\n \"Epoch\": epoch,\n \"Sample\": sample,\n \"Step\": i,",
+ "type": "code",
+ "location": "/train_decoder.py:396-413"
+ },
+ "745": {
+ "file_id": 23,
+ "content": "This code is calculating the samples per second and resetting timers, then averaging the losses across all processes for a UNet model. It gathers the decay rate on each UNet, logs epoch, sample, and step information.",
+ "type": "comment"
+ },
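+ "745a": {
+ "file_id": 23,
+ "content": "A small sketch of the masked averaging idea used above (made-up numbers; zero entries mark unets that logged no loss in an iteration):\n\nimport torch\n\n# rows: logged iterations (possibly gathered from several processes), columns: unets\nunet_losses = torch.tensor([\n    [0.50, 0.00, 0.30],\n    [0.40, 0.20, 0.00],\n])\nmask = unet_losses != 0\n# average only over entries that actually contain a loss\naverage_per_unet = (unet_losses * mask).sum(dim=0) / mask.sum(dim=0)\nprint(average_per_unet)  # tensor([0.4500, 0.2000, 0.3000])",
+ "type": "comment"
+ },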
+ "746": {
+ "file_id": 23,
+ "content": " \"Samples per second\": samples_per_sec,\n \"Samples Seen\": samples_seen,\n **ema_decay_list,\n **loss_map\n }\n if is_master:\n tracker.log(log_data, step=step())\n if is_master and (last_snapshot + save_every_n_samples < sample or (save_immediately and i == 0)): # This will miss by some amount every time, but it's not a big deal... I hope\n # It is difficult to gather this kind of info on the accelerator, so we have to do it on the master\n print(\"Saving snapshot\")\n last_snapshot = sample\n # We need to know where the model should be saved\n save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen)\n if exists(n_sample_images) and n_sample_images > 0:\n trainer.eval()\n ",
+ "type": "code",
+ "location": "/train_decoder.py:414-431"
+ },
+ "747": {
+ "file_id": 23,
+ "content": "This code snippet is logging data and saving a snapshot of the model at specific intervals. It logs samples per second, samples seen, EMA decay parameters, and loss metrics. The snapshot is saved if the current sample meets certain conditions or every time an immediate save command is issued. The code prints \"Saving snapshot\" when a snapshot is taken.",
+ "type": "comment"
+ },
+ "748": {
+ "file_id": 23,
+ "content": " train_images, train_captions = generate_grid_samples(trainer, train_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, \"Train: \")\n tracker.log_images(train_images, captions=train_captions, image_section=\"Train Samples\", step=step())\n if epoch_samples is not None and sample >= epoch_samples:\n break\n next_task = 'val'\n sample = 0\n all_average_val_losses = None\n if next_task == 'val':\n trainer.eval()\n accelerator.print(print_ribbon(f\"Starting Validation {epoch}\", repeat=40))\n last_val_sample = val_sample\n val_sample_length_tensor = torch.zeros(1, dtype=torch.int, device=inference_device)\n average_val_loss_tensor = torch.zeros(1, trainer.num_unets, dtype=torch.float, device=inference_device)\n timer = Timer()\n accelerator.wait_for_everyone()\n i = 0",
+ "type": "code",
+ "location": "/train_decoder.py:431-448"
+ },
+ "749": {
+ "file_id": 23,
+ "content": "This code is used for training a model and validating it. It generates samples from the training dataset, logs them, checks if it should stop based on sample count, switches to validation mode, and initializes variables for validation.",
+ "type": "comment"
+ },
+ "750": {
+ "file_id": 23,
+ "content": " for i, (img, emb, txt) in enumerate(dataloaders['val']): # Use the accelerate prepared loader\n val_sample_length_tensor[0] = len(img)\n all_samples = accelerator.gather(val_sample_length_tensor)\n total_samples = all_samples.sum().item()\n val_sample += total_samples\n img_emb = emb.get('img')\n has_img_embedding = img_emb is not None\n if has_img_embedding:\n img_emb, = send_to_device((img_emb,))\n text_emb = emb.get('text')\n has_text_embedding = text_emb is not None\n if has_text_embedding:\n text_emb, = send_to_device((text_emb,))\n img, = send_to_device((img,))\n for unet in range(1, len(decoder.unets)+1):\n if not unet_training_mask[unet-1]: # Unet index is the unet number - 1\n # No need to evaluate an unchanging unet\n continue",
+ "type": "code",
+ "location": "/train_decoder.py:449-467"
+ },
+ "751": {
+ "file_id": 23,
+ "content": "This code is part of the DALLE2-pytorch training process. It iterates over the validation dataloader, gathers sample lengths, calculates total samples, and checks for image and text embeddings. If available, it sends these embeddings along with images to the device for further processing. This code ensures that all necessary data is properly prepared and sent to the device for evaluation.",
+ "type": "comment"
+ },
+ "752": {
+ "file_id": 23,
+ "content": " forward_params = {}\n if has_img_embedding:\n forward_params['image_embed'] = img_emb.float()\n else:\n # Forward pass automatically generates embedding\n assert clip is not None\n img_embed, img_encoding = clip.embed_image(img)\n forward_params['image_embed'] = img_embed\n if condition_on_text_encodings:\n if has_text_embedding:\n forward_params['text_encodings'] = text_emb.float()\n else:\n # Then we need to pass the text instead\n assert clip is not None\n tokenized_texts = tokenize(txt, truncate=True).to(device=inference_device)\n assert tokenized_texts.shape[0] == len(img), f\"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})\"",
+ "type": "code",
+ "location": "/train_decoder.py:469-484"
+ },
+ "753": {
+ "file_id": 23,
+ "content": "This code segment checks if image and text embeddings are provided. If not, it automatically generates image embedding or passes the text instead based on the condition. It also asserts the number of texts should be equal to the number of images for consistency.",
+ "type": "comment"
+ },
+ "754": {
+ "file_id": 23,
+ "content": " text_embed, text_encodings = clip.embed_text(tokenized_texts)\n forward_params['text_encodings'] = text_encodings\n loss = trainer.forward(img.float(), **forward_params, unet_number=unet, _device=inference_device)\n average_val_loss_tensor[0, unet-1] += loss\n if i % VALID_CALC_LOSS_EVERY_ITERS == 0:\n samples_per_sec = (val_sample - last_val_sample) / timer.elapsed()\n timer.reset()\n last_val_sample = val_sample\n accelerator.print(f\"Epoch {epoch}/{epochs} Val Step {i} - Sample {val_sample} - {samples_per_sec:.2f} samples/sec\")\n accelerator.print(f\"Loss: {(average_val_loss_tensor / (i+1))}\")\n accelerator.print(\"\")\n if validation_samples is not None and val_sample >= validation_samples:\n break\n print(f\"Rank {accelerator.state.process_index} finished validation after {i} steps\")",
+ "type": "code",
+ "location": "/train_decoder.py:485-500"
+ },
+ "755": {
+ "file_id": 23,
+ "content": "This code snippet is part of a larger model training process. It calculates the loss based on input images and text, updates the average validation loss, prints validation progress including samples per second and loss, and eventually breaks the loop when the specified number of validation samples have been processed. The code uses the PyTorch framework and the DALLE2 library for embedding text.",
+ "type": "comment"
+ },
+ "756": {
+ "file_id": 23,
+ "content": " accelerator.wait_for_everyone()\n average_val_loss_tensor /= i+1\n # Gather all the average loss tensors\n all_average_val_losses = accelerator.gather(average_val_loss_tensor)\n if is_master:\n unet_average_val_loss = all_average_val_losses.mean(dim=0)\n val_loss_map = { f\"Unet {index} Validation Loss\": loss.item() for index, loss in enumerate(unet_average_val_loss) if loss != 0 }\n tracker.log(val_loss_map, step=step())\n next_task = 'eval'\n if next_task == 'eval':\n if exists(evaluate_config):\n accelerator.print(print_ribbon(f\"Starting Evaluation {epoch}\", repeat=40))\n evaluation = evaluate_trainer(trainer, dataloaders[\"val\"], inference_device, first_trainable_unet, last_trainable_unet, clip=clip, inference_device=inference_device, **evaluate_config.model_dump(), condition_on_text_encodings=condition_on_text_encodings, cond_scale=cond_scale)\n if is_master:",
+ "type": "code",
+ "location": "/train_decoder.py:501-515"
+ },
+ "757": {
+ "file_id": 23,
+ "content": "This code is used for averaging the validation losses and logging them during training. It also starts the evaluation process if it's time to do so, printing a message to indicate this. The average_val_loss_tensor is gathered by the accelerator, and then the mean of all the average loss tensors is calculated if the current task is 'eval'. If there are no zeros in the unet_average_val_loss, the validation losses are logged.",
+ "type": "comment"
+ },
+ "758": {
+ "file_id": 23,
+ "content": " tracker.log(evaluation, step=step())\n next_task = 'sample'\n val_sample = 0\n if next_task == 'sample':\n if is_master:\n # Generate examples and save the model if we are the master\n # Generate sample images\n print(print_ribbon(f\"Sampling Set {epoch}\", repeat=40))\n test_images, test_captions = generate_grid_samples(trainer, test_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, \"Test: \")\n train_images, train_captions = generate_grid_samples(trainer, train_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, \"Train: \")\n tracker.log_images(test_images, captions=test_captions, image_section=\"Test Samples\", step=step())\n tracker.log_images(train_images, captions=train_captions, image_section=\"Train Samples\", step=step())",
+ "type": "code",
+ "location": "/train_decoder.py:516-528"
+ },
+ "759": {
+ "file_id": 23,
+ "content": "The code is generating sample images and saving the model if it is the master process. It prints a ribbon and then generates grid samples from both train and test example data, conditioning on text encodings. Finally, it logs the generated images using the tracker, with labels indicating whether they are test or train samples.",
+ "type": "comment"
+ },
+ "760": {
+ "file_id": 23,
+ "content": " print(print_ribbon(f\"Starting Saving {epoch}\", repeat=40))\n is_best = False\n if all_average_val_losses is not None:\n average_loss = all_average_val_losses.mean(dim=0).sum() / sum(unet_training_mask)\n if len(validation_losses) == 0 or average_loss < min(validation_losses):\n is_best = True\n validation_losses.append(average_loss)\n save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen, is_best=is_best)\n next_task = 'train'\ndef create_tracker(accelerator: Accelerator, config: TrainDecoderConfig, config_path: str, dummy: bool = False) -> Tracker:\n tracker_config = config.tracker\n accelerator_config = {\n \"Distributed\": accelerator.distributed_type != accelerate_dataclasses.DistributedType.NO,\n \"DistributedType\": accelerator.distributed_type,\n \"NumProcesses\": accelerator.num_processes,\n \"MixedPrecision\": accelerator.mixed_precision",
+ "type": "code",
+ "location": "/train_decoder.py:530-546"
+ },
+ "761": {
+ "file_id": 23,
+ "content": "The code checks if the average validation loss is lower than previous min, and saves the trainer if it's a new minimum. It's part of a function called create_tracker that creates a tracker object with accelerator, config, and dummy parameters.",
+ "type": "comment"
+ },
+ "762": {
+ "file_id": 23,
+ "content": " }\n accelerator.wait_for_everyone() # If nodes arrive at this point at different times they might try to autoresume the current run which makes no sense and will cause errors\n tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy)\n tracker.save_config(config_path, config_name='decoder_config.json')\n tracker.add_save_metadata(state_dict_key='config', metadata=config.model_dump())\n return tracker\ndef initialize_training(config: TrainDecoderConfig, config_path):\n # Make sure if we are not loading, distributed models are initialized to the same values\n torch.manual_seed(config.seed)\n # Set up accelerator for configurable distributed training\n ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters, static_graph=config.train.static_graph)\n init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=60*60))\n accelerator = Accelerator(kwargs_handlers=[ddp_kwargs, init_kwargs])\n if accelerator.num_processes > 1:",
+ "type": "code",
+ "location": "/train_decoder.py:547-563"
+ },
+ "763": {
+ "file_id": 23,
+ "content": "This code initializes distributed training for DALLE2, sets manual seed, and creates an accelerator for parallel processing with optional arguments. The function returns a tracker object to save configuration.",
+ "type": "comment"
+ },
+ "764": {
+ "file_id": 23,
+ "content": " # We are using distributed training and want to immediately ensure all can connect\n accelerator.print(\"Waiting for all processes to connect...\")\n accelerator.wait_for_everyone()\n accelerator.print(\"All processes online and connected\")\n # If we are in deepspeed fp16 mode, we must ensure learned variance is off\n if accelerator.mixed_precision == \"fp16\" and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED and config.decoder.learned_variance:\n raise ValueError(\"DeepSpeed fp16 mode does not support learned variance\")\n # Set up data\n all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))\n world_size = accelerator.num_processes\n rank = accelerator.process_index\n shards_per_process = len(all_shards) // world_size\n assert shards_per_process > 0, \"Not enough shards to split evenly\"\n my_shards = all_shards[rank * shards_per_process: (rank + 1) * shards_per_process]\n dataloaders = create_dataloaders (",
+ "type": "code",
+ "location": "/train_decoder.py:564-581"
+ },
+ "765": {
+ "file_id": 23,
+ "content": "This code snippet is part of a distributed training process where it checks the accelerator settings, data sharding, and creates dataloaders for training. It ensures all processes are connected, handles DeepSpeed mixed precision mode without learned variance, splits data shards evenly across processes, and finally creates the necessary dataloaders for the training process.",
+ "type": "comment"
+ },
+ "766": {
+ "file_id": 23,
+ "content": " available_shards=my_shards,\n img_preproc = config.data.img_preproc,\n train_prop = config.data.splits.train,\n val_prop = config.data.splits.val,\n test_prop = config.data.splits.test,\n n_sample_images=config.train.n_sample_images,\n **config.data.model_dump(),\n rank = rank,\n seed = config.seed,\n )\n # If clip is in the model, we need to remove it for compatibility with deepspeed\n clip = None\n if config.decoder.clip is not None:\n clip = config.decoder.clip.create() # Of course we keep it to use it during training, just not in the decoder as that causes issues\n config.decoder.clip = None\n # Create the decoder model and print basic info\n decoder = config.decoder.create()\n get_num_parameters = lambda model, only_training=False: sum(p.numel() for p in model.parameters() if (p.requires_grad or not only_training))\n # Create and initialize the tracker if we are the master\n tracker = create_tracker(accelerator, config, config_path, dummy = rank!=0)",
+ "type": "code",
+ "location": "/train_decoder.py:582-603"
+ },
+ "767": {
+ "file_id": 23,
+ "content": "The code initializes the decoder model with specified parameters, removes clip if present for compatibility, and creates a tracker if the current rank is not the master. It also calculates the number of parameters in the model and prepares it for training.",
+ "type": "comment"
+ },
+ "768": {
+ "file_id": 23,
+ "content": " has_img_embeddings = config.data.img_embeddings_url is not None\n has_text_embeddings = config.data.text_embeddings_url is not None\n conditioning_on_text = any([unet.cond_on_text_encodings for unet in config.decoder.unets])\n has_clip_model = clip is not None\n data_source_string = \"\"\n if has_img_embeddings:\n data_source_string += \"precomputed image embeddings\"\n elif has_clip_model:\n data_source_string += \"clip image embeddings generation\"\n else:\n raise ValueError(\"No image embeddings source specified\")\n if conditioning_on_text:\n if has_text_embeddings:\n data_source_string += \" and precomputed text embeddings\"\n elif has_clip_model:\n data_source_string += \" and clip text encoding generation\"\n else:\n raise ValueError(\"No text embeddings source specified\")\n accelerator.print(print_ribbon(\"Loaded Config\", repeat=40))\n accelerator.print(f\"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training\")",
+ "type": "code",
+ "location": "/train_decoder.py:605-627"
+ },
+ "769": {
+ "file_id": 23,
+ "content": "This code checks if image and/or text embeddings are available, either precomputed or generated using CLIP model. It then prints a message indicating the source of embeddings used for training.",
+ "type": "comment"
+ },
+ "770": {
+ "file_id": 23,
+ "content": " accelerator.print(f\"Training using {data_source_string}. {'conditioned on text' if conditioning_on_text else 'not conditioned on text'}\")\n accelerator.print(f\"Number of parameters: {get_num_parameters(decoder)} total; {get_num_parameters(decoder, only_training=True)} training\")\n for i, unet in enumerate(decoder.unets):\n accelerator.print(f\"Unet {i} has {get_num_parameters(unet)} total; {get_num_parameters(unet, only_training=True)} training\")\n train(dataloaders, decoder, accelerator,\n clip=clip,\n tracker=tracker,\n inference_device=accelerator.device,\n evaluate_config=config.evaluate,\n condition_on_text_encodings=conditioning_on_text,\n **config.train.model_dump(),\n )\n# Create a simple click command line interface to load the config and start the training\n@click.command()\n@click.option(\"--config_file\", default=\"./train_decoder_config.json\", help=\"Path to config file\")\ndef main(config_file):\n config_file_path = Path(config_file)\n config = TrainDecoderConfig.from_json_path(str(config_file_path))",
+ "type": "code",
+ "location": "/train_decoder.py:628-647"
+ },
+ "771": {
+ "file_id": 23,
+ "content": "Training of the decoder is being executed using the specified data source, with or without conditioning on text. The number of parameters in total and for training are displayed, along with similar information for each Unet. The train function is called with dataloaders, decoder, accelerator, clip, tracker, inference_device, evaluate_config, and condition_on_text_encodings as arguments. A simple click command line interface is created to load the config and start training, using a default configuration file path and allowing for an alternative path to be specified with the --config_file option.",
+ "type": "comment"
+ },
+ "772": {
+ "file_id": 23,
+ "content": " initialize_training(config, config_path=config_file_path)\nif __name__ == \"__main__\":\n main()",
+ "type": "code",
+ "location": "/train_decoder.py:648-651"
+ },
+ "773": {
+ "file_id": 23,
+ "content": "This code snippet initializes training and then calls the main function if the script is run directly. It ensures proper execution when running the script as a standalone program.",
+ "type": "comment"
+ },
+ "774": {
+ "file_id": 24,
+ "content": "/train_diffusion_prior.py",
+ "type": "filepath"
+ },
+ "775": {
+ "file_id": 24,
+ "content": "This code trains a Diffusion Prior model using PyTorch and DALLE2-pytorch library, with functions for creating the model, training, data loading, acceleration, evaluation, text-image similarity comparison, backpropagation, logging, saving best models, measuring speed, resetting validation timers, handling errors, saving models, and initializing training with data loaders and HFA setup.",
+ "type": "summary"
+ },
+ "776": {
+ "file_id": 24,
+ "content": "import click\nimport torch\nfrom torch import nn\nfrom typing import List\nfrom accelerate import Accelerator\nfrom accelerate.utils import set_seed\nfrom torch.utils.data import DataLoader\nfrom embedding_reader import EmbeddingReader\nfrom accelerate.utils import dataclasses as accelerate_dataclasses\nfrom dalle2_pytorch.utils import Timer\nfrom dalle2_pytorch.trackers import Tracker\nfrom dalle2_pytorch import DiffusionPriorTrainer\nfrom dalle2_pytorch.dataloaders import get_reader, make_splits\nfrom dalle2_pytorch.train_configs import (\n DiffusionPriorConfig,\n DiffusionPriorTrainConfig,\n TrainDiffusionPriorConfig,\n)\n# helpers\ncos = nn.CosineSimilarity(dim=1, eps=1e-6)\ndef exists(val):\n return val is not None\ndef all_between(values: list, lower_bound, upper_bound):\n for value in values:\n if value < lower_bound or value > upper_bound:\n return False\n return True\ndef make_model(\n prior_config: DiffusionPriorConfig,\n train_config: DiffusionPriorTrainConfig,\n device: str = None,\n accelerator: Accelerator = None,",
+ "type": "code",
+ "location": "/train_diffusion_prior.py:1-45"
+ },
+ "777": {
+ "file_id": 24,
+ "content": "This code is for training a Diffusion Prior model using PyTorch and the DALLE2-pytorch library. It defines functions to create the model, configure the training process, and load data. The cosine similarity function is used for comparison, and there are helper functions to check if values exist and if they fall within specified bounds. The code also uses accelerate for efficient training and allows for device specification (CPU or GPU) and an optional accelerator instance for further optimization.",
+ "type": "comment"
+ },
+ "778": {
+ "file_id": 24,
+ "content": "):\n # create model from config\n diffusion_prior = prior_config.create()\n # instantiate the trainer\n trainer = DiffusionPriorTrainer(\n diffusion_prior=diffusion_prior,\n lr=train_config.lr,\n wd=train_config.wd,\n max_grad_norm=train_config.max_grad_norm,\n amp=train_config.amp,\n use_ema=train_config.use_ema,\n device=device,\n accelerator=accelerator,\n warmup_steps=train_config.warmup_steps,\n )\n return trainer\ndef create_tracker(\n accelerator: Accelerator,\n config: TrainDiffusionPriorConfig,\n config_path: str,\n dummy: bool = False,\n) -> Tracker:\n tracker_config = config.tracker\n accelerator_config = {\n \"Distributed\": accelerator.distributed_type\n != accelerate_dataclasses.DistributedType.NO,\n \"DistributedType\": accelerator.distributed_type,\n \"NumProcesses\": accelerator.num_processes,\n \"MixedPrecision\": accelerator.mixed_precision,\n }\n tracker: Tracker = tracker_config.create(\n config, accelerator_config, dummy_mode=dummy",
+ "type": "code",
+ "location": "/train_diffusion_prior.py:46-83"
+ },
+ "779": {
+ "file_id": 24,
+ "content": "This code defines a function `create_trainer` that takes in a `prior_config`, and creates a `DiffusionPriorTrainer` object with specified parameters. It also defines the `create_tracker` function, which creates a `Tracker` object based on the provided configuration. The functions return the created objects.",
+ "type": "comment"
+ },
+ "780": {
+ "file_id": 24,
+ "content": " )\n tracker.save_config(config_path, config_name=\"prior_config.json\")\n return tracker\ndef pad_gather_reduce(trainer: DiffusionPriorTrainer, x, method=\"mean\"):\n \"\"\"\n pad a value or tensor across all processes and gather\n params:\n - trainer: a trainer that carries an accelerator object\n - x: a number or torch tensor to reduce\n - method: \"mean\", \"sum\", \"max\", \"min\"\n return:\n - the average tensor after maskin out 0's\n - None if the gather resulted in an empty tensor\n \"\"\"\n assert method in [\n \"mean\",\n \"sum\",\n \"max\",\n \"min\",\n ], \"This function has limited capabilities [sum, mean, max, min]\"\n assert type(x) is not None, \"Cannot reduce a None type object\"\n # wait for everyone to arrive here before gathering\n if type(x) is not torch.Tensor:\n x = torch.tensor([x])\n # verify that the tensor is on the proper device\n x = x.to(trainer.device)\n # pad across processes\n padded_x = trainer.accelerator.pad_across_processes(x, dim=0)",
+ "type": "code",
+ "location": "/train_diffusion_prior.py:84-122"
+ },
+ "781": {
+ "file_id": 24,
+ "content": "This function pads a value or tensor across all processes, gathers them and reduces them to a single average. It works with tensors of type \"mean\", \"sum\", \"max\", and \"min\". If the resulting tensor is empty, it returns None. It first waits for everyone to arrive before gathering, converts the input to a tensor if it's not already, and ensures that the tensor is on the proper device.",
+ "type": "comment"
+ },
+ "782": {
+ "file_id": 24,
+ "content": " # gather across all procesess\n gathered_x = trainer.accelerator.gather(padded_x)\n # mask out zeros\n masked_x = gathered_x[gathered_x != 0]\n # if the tensor is empty, warn and return None\n if len(masked_x) == 0:\n click.secho(\n f\"The call to this method resulted in an empty tensor after masking out zeros. The gathered tensor was this: {gathered_x} and the original value passed was: {x}.\",\n fg=\"red\",\n )\n return None\n if method == \"mean\":\n return torch.mean(masked_x)\n elif method == \"sum\":\n return torch.sum(masked_x)\n elif method == \"max\":\n return torch.max(masked_x)\n elif method == \"min\":\n return torch.min(masked_x)\ndef save_trainer(\n tracker: Tracker,\n trainer: DiffusionPriorTrainer,\n is_latest: bool,\n is_best: bool,\n epoch: int,\n samples_seen: int,\n best_validation_loss: float,\n):\n \"\"\"\n Logs the model with an appropriate method depending on the tracker\n \"\"\"\n trainer.accelerator.wait_for_everyone()",
+ "type": "code",
+ "location": "/train_diffusion_prior.py:124-160"
+ },
+ "783": {
+ "file_id": 24,
+ "content": "The code gathers tensor data across all processes, masks out zeros, and handles empty tensors. It then calculates the mean, sum, maximum, or minimum of the masked tensor depending on the method specified. The save_trainer function logs the model with an appropriate method based on the tracker.",
+ "type": "comment"
+ },
+ "784": {
+ "file_id": 24,
+ "content": " if trainer.accelerator.is_main_process:\n click.secho(\n f\"RANK:{trainer.accelerator.process_index} | Saving Model | Best={is_best} | Latest={is_latest}\",\n fg=\"magenta\",\n )\n tracker.save(\n trainer=trainer,\n is_best=is_best,\n is_latest=is_latest,\n epoch=int(epoch),\n samples_seen=int(samples_seen),\n best_validation_loss=best_validation_loss,\n )\ndef recall_trainer(tracker: Tracker, trainer: DiffusionPriorTrainer):\n \"\"\"\n Loads the model with an appropriate method depending on the tracker\n \"\"\"\n if trainer.accelerator.is_main_process:\n click.secho(f\"Loading model from {type(tracker.loader).__name__}\", fg=\"yellow\")\n state_dict = tracker.recall()\n trainer.load(state_dict, strict=True)\n return (\n int(state_dict.get(\"epoch\", 0)),\n state_dict.get(\"best_validation_loss\", 0),\n int(state_dict.get(\"samples_seen\", 0)),\n )\n# eval functions\ndef report_validation_loss(\n trainer: DiffusionPriorTrainer,",
+ "type": "code",
+ "location": "/train_diffusion_prior.py:162-201"
+ },
+ "785": {
+ "file_id": 24,
+ "content": "This code is part of a model training process. It saves the model at certain intervals and loads it later depending on the tracker type. The save function reports whether the saved model is best or latest, and the recall_trainer function loads the model with an appropriate method based on the tracker's loader type. Additionally, there are functions for evaluating validation loss.",
+ "type": "comment"
+ },
+ "786": {
+ "file_id": 24,
+ "content": " dataloader: DataLoader,\n text_conditioned: bool,\n use_ema: bool,\n tracker: Tracker,\n split: str,\n tracker_folder: str,\n loss_type: str,\n):\n \"\"\"\n Compute the validation loss on a given subset of data.\n \"\"\"\n if trainer.accelerator.is_main_process:\n click.secho(\n f\"Measuring performance on {use_ema}-{split} split\",\n fg=\"green\",\n blink=True,\n )\n total_loss = torch.zeros(1, dtype=torch.float, device=trainer.device)\n for image_embeddings, text_data in dataloader:\n image_embeddings = image_embeddings.to(trainer.device)\n text_data = text_data.to(trainer.device)\n input_args = dict(image_embed=image_embeddings)\n if text_conditioned:\n input_args = dict(**input_args, text=text_data)\n else:\n input_args = dict(**input_args, text_embed=text_data)\n if use_ema:\n loss = trainer.ema_diffusion_prior(**input_args)\n else:\n loss = trainer(**input_args)\n total_loss += loss",
+ "type": "code",
+ "location": "/train_diffusion_prior.py:202-239"
+ },
+ "787": {
+ "file_id": 24,
+ "content": "This code measures validation loss on a given data subset, using an optional EMA model and text conditioning. It iterates through a dataloader, computes losses for each batch, accumulates them in total_loss, and finally returns the average loss. The progress is echoed if the process is the main one.",
+ "type": "comment"
+ },
+ "788": {
+ "file_id": 24,
+ "content": " # compute the average loss across all processes\n avg_loss = pad_gather_reduce(trainer, total_loss, method=\"mean\")\n stats = {f\"{tracker_folder}/{loss_type}-loss\": avg_loss}\n # print and log results on main process\n tracker.log(stats, step=trainer.step.item() + 1)\n return avg_loss\ndef report_cosine_sims(\n trainer: DiffusionPriorTrainer,\n dataloader: DataLoader,\n text_conditioned: bool,\n tracker: Tracker,\n split: str,\n timesteps: int,\n tracker_folder: str,\n):\n trainer.eval()\n if trainer.accelerator.is_main_process:\n click.secho(\n f\"Measuring Cosine-Similarity on {split} split with {timesteps} timesteps\",\n fg=\"green\",\n blink=True,\n )\n for test_image_embeddings, text_data in dataloader:\n test_image_embeddings = test_image_embeddings.to(trainer.device)\n text_data = text_data.to(trainer.device)\n # we are text conditioned, we produce an embedding from the tokenized text\n if text_conditioned:\n text_embedding, text_encodings = trainer.embed_text(text_data)",
+ "type": "code",
+ "location": "/train_diffusion_prior.py:241-275"
+ },
+ "789": {
+ "file_id": 24,
+ "content": "This code measures the cosine similarity on a given split with specified timesteps. It first sets the trainer to evaluation mode and then iterates through each batch of data from the dataloader. Within this loop, it moves both test image embeddings and text data to the device used by the trainer. If the model is text-conditioned, it generates an embedding from the tokenized text using the `embed_text` function provided by the trainer. This information can be useful for understanding how this code measures cosine similarity in a given context.",
+ "type": "comment"
+ },
+ "790": {
+ "file_id": 24,
+ "content": " text_cond = dict(text_embed=text_embedding, text_encodings=text_encodings)\n else:\n text_embedding = text_data\n text_cond = dict(text_embed=text_embedding)\n # make a copy of the text embeddings for shuffling\n text_embed_shuffled = text_embedding.clone()\n # roll the text to simulate \"unrelated\" captions\n rolled_idx = torch.roll(torch.arange(text_embedding.shape[0]), 1)\n text_embed_shuffled = text_embed_shuffled[rolled_idx]\n text_embed_shuffled = text_embed_shuffled / text_embed_shuffled.norm(\n dim=1, keepdim=True\n )\n if text_conditioned:\n text_encodings_shuffled = text_encodings[rolled_idx]\n else:\n text_encodings_shuffled = None\n text_cond_shuffled = dict(\n text_embed=text_embed_shuffled, text_encodings=text_encodings_shuffled\n )\n # prepare the text embedding\n text_embed = text_embedding / text_embedding.norm(dim=1, keepdim=True)\n # prepare image embeddings",
+ "type": "code",
+ "location": "/train_diffusion_prior.py:276-303"
+ },
+ "791": {
+ "file_id": 24,
+ "content": "This code shuffles text embeddings and encodings to simulate \"unrelated\" captions for training the diffusion model. If text-conditioned, it also shuffles the text condition. It prepares both text and image embeddings.",
+ "type": "comment"
+ },
+ "792": {
+ "file_id": 24,
+ "content": " test_image_embeddings = test_image_embeddings / test_image_embeddings.norm(\n dim=1, keepdim=True\n )\n # predict on the unshuffled text embeddings\n predicted_image_embeddings = trainer.p_sample_loop(\n test_image_embeddings.shape,\n text_cond,\n timesteps=timesteps,\n )\n predicted_image_embeddings = (\n predicted_image_embeddings\n / predicted_image_embeddings.norm(dim=1, keepdim=True)\n )\n # predict on the shuffled embeddings\n predicted_unrelated_embeddings = trainer.p_sample_loop(\n test_image_embeddings.shape,\n text_cond_shuffled,\n timesteps=timesteps,\n )\n predicted_unrelated_embeddings = (\n predicted_unrelated_embeddings\n / predicted_unrelated_embeddings.norm(dim=1, keepdim=True)\n )\n # calculate similarities\n orig_sim = pad_gather_reduce(\n trainer, cos(text_embed, test_image_embeddings), method=\"mean\"",
+ "type": "code",
+ "location": "/train_diffusion_prior.py:304-334"
+ },
+ "793": {
+ "file_id": 24,
+ "content": "This code calculates the similarity between text embeddings and image embeddings, then shuffles the text embeddings to create unrelated pairs. It uses diffusion models for prediction and normalizes the embeddings. The final step is calculating the similarities using cosine similarity and mean reduction method.",
+ "type": "comment"
+ },
+ "794": {
+ "file_id": 24,
+ "content": " )\n pred_sim = pad_gather_reduce(\n trainer, cos(text_embed, predicted_image_embeddings), method=\"mean\"\n )\n unrel_sim = pad_gather_reduce(\n trainer, cos(text_embed, predicted_unrelated_embeddings), method=\"mean\"\n )\n pred_img_sim = pad_gather_reduce(\n trainer,\n cos(test_image_embeddings, predicted_image_embeddings),\n method=\"mean\",\n )\n stats = {\n f\"{tracker_folder}/baseline similarity [steps={timesteps}]\": orig_sim,\n f\"{tracker_folder}/similarity with text [steps={timesteps}]\": pred_sim,\n f\"{tracker_folder}/similarity with original image [steps={timesteps}]\": pred_img_sim,\n f\"{tracker_folder}/similarity with unrelated caption [steps={timesteps}]\": unrel_sim,\n f\"{tracker_folder}/difference from baseline similarity [steps={timesteps}]\": pred_sim\n - orig_sim,\n }\n tracker.log(stats, step=trainer.step.item() + 1)\ndef eval_model(",
+ "type": "code",
+ "location": "/train_diffusion_prior.py:335-360"
+ },
+ "795": {
+ "file_id": 24,
+ "content": "This code calculates similarity scores between embeddings of text, predicted images, and original images. It then logs these scores for various steps in the training process to track progress.",
+ "type": "comment"
+ },
+ "796": {
+ "file_id": 24,
+ "content": " trainer: DiffusionPriorTrainer,\n dataloader: DataLoader,\n text_conditioned: bool,\n split: str,\n tracker: Tracker,\n use_ema: bool,\n report_cosine: bool,\n report_loss: bool,\n timesteps: List[int],\n loss_type: str = None,\n):\n \"\"\"\n Run evaluation on a model and track metrics\n returns: loss if requested\n \"\"\"\n trainer.eval()\n use_ema = \"ema\" if use_ema else \"online\"\n tracker_folder = f\"metrics/{use_ema}-{split}\"\n # detemine if valid timesteps are passed\n min_timesteps = trainer.accelerator.unwrap_model(\n trainer.diffusion_prior\n ).sample_timesteps\n max_timesteps = trainer.accelerator.unwrap_model(\n trainer.diffusion_prior\n ).noise_scheduler.num_timesteps\n assert all_between(\n timesteps, lower_bound=min_timesteps, upper_bound=max_timesteps\n ), f\"all timesteps values must be between {min_timesteps} and {max_timesteps}: got {timesteps}\"\n # measure cosine metrics across various eta and timesteps\n if report_cosine:\n for timestep in timesteps:",
+ "type": "code",
+ "location": "/train_diffusion_prior.py:361-398"
+ },
+ "797": {
+ "file_id": 24,
+ "content": "This function runs evaluation on a model, tracks metrics, and returns the loss if requested. It uses DiffusionPriorTrainer and DataLoader. The use_ema parameter is used to differentiate between an Exponential Moving Average (EMA) model and an online (current) model. It checks whether the timesteps are valid for the model's noise scheduler. It also measures cosine metrics across various eta and timesteps if report_cosine is set to True.",
+ "type": "comment"
+ },
+ "798": {
+ "file_id": 24,
+ "content": " report_cosine_sims(\n trainer,\n dataloader=dataloader,\n text_conditioned=text_conditioned,\n tracker=tracker,\n split=split,\n timesteps=timestep,\n tracker_folder=tracker_folder,\n )\n # measure loss on a seperate split of data\n if report_loss:\n loss = report_validation_loss(\n trainer=trainer,\n dataloader=dataloader,\n text_conditioned=text_conditioned,\n use_ema=use_ema,\n tracker=tracker,\n split=split,\n tracker_folder=tracker_folder,\n loss_type=loss_type,\n )\n return loss\n# training script\ndef train(\n trainer: DiffusionPriorTrainer,\n tracker: Tracker,\n train_loader: DataLoader,\n eval_loader: DataLoader,\n test_loader: DataLoader,\n config: DiffusionPriorTrainConfig,\n):\n # init timers\n save_timer = Timer() # when to save\n samples_timer = Timer() # samples/sec\n validation_profiler = Timer() # how long is validation taking",
+ "type": "code",
+ "location": "/train_diffusion_prior.py:399-440"
+ },
+ "799": {
+ "file_id": 24,
+ "content": "This code measures cosine similarity on a separate dataset and reports the loss on another split of data in a training script. It also initializes timers for saving, measuring samples per second, and tracking validation time.",
+ "type": "comment"
+ }
+}
\ No newline at end of file
diff --git a/docs/data/8.json b/docs/data/8.json
new file mode 100644
index 00000000..c034f729
--- /dev/null
+++ b/docs/data/8.json
@@ -0,0 +1,112 @@
+{
+ "800": {
+ "file_id": 24,
+ "content": " validation_countdown = Timer() # when to perform evalutation\n # keep track of best validation loss\n best_validation_loss = config.train.best_validation_loss\n samples_seen = config.train.num_samples_seen\n # do training\n start_epoch = config.train.current_epoch\n for epoch in range(start_epoch, config.train.epochs):\n # if we finished out an old epoch, reset the distribution to be a full epoch\n tracker.log({\"tracking/epoch\": epoch}, step=trainer.step.item())\n if train_loader.dataset.get_start() > 0 and epoch == start_epoch+1:\n if trainer.accelerator.is_main_process:\n click.secho(f\"Finished resumed epoch...resetting dataloader.\")\n train_loader.dataset.set_start(0)\n for img, txt in train_loader:\n # setup things every step\n trainer.train()\n current_step = trainer.step.item()\n samples_timer.reset()\n # place data on device\n img = img.to(trainer.device)\n txt = txt.to(trainer.device)",
+ "type": "code",
+ "location": "/train_diffusion_prior.py:441-471"
+ },
+ "801": {
+ "file_id": 24,
+ "content": "The code sets up a training loop that iterates over epochs and resets the dataloader if it was paused mid-epoch. It places data on the device, tracks the best validation loss, and keeps track of samples seen.",
+ "type": "comment"
+ },
+ "802": {
+ "file_id": 24,
+ "content": " # pass to model\n loss = trainer(text=txt, image_embed=img)\n # perform backprop & apply EMA updates\n trainer.update()\n # gather info about training step\n all_loss = pad_gather_reduce(trainer, loss, method=\"mean\")\n num_samples = pad_gather_reduce(trainer, len(txt), method=\"sum\")\n samples_per_sec = num_samples / samples_timer.elapsed()\n samples_seen += num_samples\n ema_decay = trainer.ema_diffusion_prior.get_current_decay()\n # log\n tracker.log(\n {\n \"tracking/samples-sec\": samples_per_sec,\n \"tracking/samples-seen\": samples_seen,\n \"tracking/ema-decay\": ema_decay,\n f\"tracking/training-{config.prior.loss_type}\": all_loss,\n },\n step=current_step,\n )\n # Metric Tracking @ Timed Intervals\n eval_delta = pad_gather_reduce(\n trainer, validation_countdown.elapsed(), method=\"min\"",
+ "type": "code",
+ "location": "/train_diffusion_prior.py:473-504"
+ },
+ "803": {
+ "file_id": 24,
+ "content": "This code is performing backpropagation, updating the exponential moving average (EMA), logging training metrics, and tracking evaluation intervals. It calculates the loss from text and image embeddings using the trainer model and updates the EMA diffusion prior. Metrics like samples per second, number of samples seen, EMA decay, and a specific loss type are logged at each step, while evaluating the validation countdown time interval for metrics tracking.",
+ "type": "comment"
+ },
+ "804": {
+ "file_id": 24,
+ "content": " )\n if eval_delta != None and eval_delta > config.data.eval_every_seconds:\n # begin timing how long this takes\n validation_profiler.reset()\n # package kwargs for evaluation\n eval_kwargs = {\n \"trainer\": trainer,\n \"tracker\": tracker,\n \"text_conditioned\": config.prior.condition_on_text_encodings,\n \"timesteps\": config.train.eval_timesteps,\n }\n # ONLINE MODEL : COSINE : LOSS : VALIDATION SPLIT\n eval_model(\n dataloader=eval_loader,\n loss_type=config.prior.loss_type,\n split=\"validation\",\n use_ema=False,\n report_cosine=False,\n report_loss=True,\n **eval_kwargs,\n )\n # EMA MODEL : COSINE : LOSS : VALIDATION DATA\n ema_val_loss = eval_model(\n dataloader=eval_loader,",
+ "type": "code",
+ "location": "/train_diffusion_prior.py:505-536"
+ },
+ "805": {
+ "file_id": 24,
+ "content": "This code is evaluating the model on validation data with specified options. It checks if it's time to evaluate, resets the profiler for timing, packages evaluation kwargs, and calls eval_model function with dataloader, loss type, split (validation), use_ema, report_cosine, report_loss, and eval_kwargs. It also evaluates the ema model separately.",
+ "type": "comment"
+ },
+ "806": {
+ "file_id": 24,
+ "content": " loss_type=config.prior.loss_type,\n split=\"validation\",\n use_ema=True,\n report_cosine=True,\n report_loss=True,\n **eval_kwargs,\n )\n tracker.log(\n {\n \"tracking/validation length (minutes)\": validation_profiler.elapsed()\n / 60\n }\n )\n # check if the ema validation is the lowest seen yet\n if ema_val_loss < best_validation_loss:\n best_validation_loss = ema_val_loss\n # go save the model as best\n save_trainer(\n trainer=trainer,\n tracker=tracker,\n is_best=True,\n is_latest=False,\n samples_seen=samples_seen,\n epoch=epoch,\n best_validation_loss=best_validation_loss,",
+ "type": "code",
+ "location": "/train_diffusion_prior.py:537-566"
+ },
+ "807": {
+ "file_id": 24,
+ "content": "In this code, a validation process is executed using ema (exponential moving average) to calculate the loss. The lowest ema validation loss seen so far is stored in `best_validation_loss` and if the current validation loss is lower than the previous best, the model is saved as the 'best' model. This code also logs the time taken for the validation process.",
+ "type": "comment"
+ },
+ "808": {
+ "file_id": 24,
+ "content": " )\n # reset timer for validaiton\n validation_countdown.reset()\n elif eval_delta is None:\n click.secho(\n f\"Error occured reading the eval time on rank: {trainer.device}\",\n fg=\"yellow\",\n )\n # save as latest model on schedule\n save_delta = pad_gather_reduce(trainer, save_timer.elapsed(), method=\"min\")\n if save_delta != None and save_delta >= config.train.save_every_seconds:\n save_trainer(\n trainer=trainer,\n tracker=tracker,\n is_best=False,\n is_latest=True,\n samples_seen=samples_seen,\n epoch=epoch,\n best_validation_loss=best_validation_loss,\n )\n save_timer.reset()\n elif save_delta is None:\n click.secho(\n f\"Error occured reading the save time on rank: {trainer.device}\",",
+ "type": "code",
+ "location": "/train_diffusion_prior.py:567-598"
+ },
+ "809": {
+ "file_id": 24,
+ "content": "This code segment resets the validation timer and handles errors in reading eval and save times. It saves the latest model if the elapsed time meets a certain condition, and resets the save timer. This helps keep track of the training progress and ensures timely saving of models for later use.",
+ "type": "comment"
+ },
+ "810": {
+ "file_id": 24,
+ "content": " fg=\"yellow\",\n )\n # evaluate on test data\n if trainer.accelerator.is_main_process:\n click.secho(f\"Starting Test\", fg=\"red\")\n # save one last time as latest before beginning validation\n save_trainer(\n tracker=tracker,\n trainer=trainer,\n is_best=False,\n is_latest=True,\n samples_seen=samples_seen,\n epoch=epoch,\n best_validation_loss=best_validation_loss,\n )\n test_loss = eval_model(\n trainer=trainer,\n dataloader=test_loader,\n text_conditioned=config.prior.condition_on_text_encodings,\n split=\"test\",\n tracker=tracker,\n use_ema=True,\n report_cosine=False,\n report_loss=True,\n timesteps=config.train.eval_timesteps,\n loss_type=config.prior.loss_type,\n )\n if test_loss < best_validation_loss:\n best_validation_loss = test_loss\n # go save the model as best\n save_trainer(\n trainer=trainer,\n tracker=tracker,\n is_best=True,",
+ "type": "code",
+ "location": "/train_diffusion_prior.py:599-640"
+ },
+ "811": {
+ "file_id": 24,
+ "content": "Starting test phase and saving the last model as latest before validation. If test loss is lower than previous best validation loss, it will be saved as the new best model.",
+ "type": "comment"
+ },
+ "812": {
+ "file_id": 24,
+ "content": " is_latest=False,\n samples_seen=samples_seen,\n epoch=epoch,\n best_validation_loss=test_loss,\n )\ndef initialize_training(config_file, accelerator):\n \"\"\"\n Parse the configuration file, and prepare everything necessary for training\n \"\"\"\n # load the configuration file\n if accelerator.is_main_process:\n click.secho(f\"Loading configuration from {config_file}\", fg=\"green\")\n config = TrainDiffusionPriorConfig.from_json_path(config_file)\n # seed\n set_seed(config.train.random_seed)\n # get a device\n device = accelerator.device\n # make the trainer (will automatically distribute if possible & configured)\n trainer: DiffusionPriorTrainer = make_model(\n config.prior, config.train, device, accelerator\n ).to(device)\n # create a tracker\n tracker = create_tracker(\n accelerator, config, config_file, dummy=accelerator.process_index != 0\n )\n # reload from chcekpoint\n if tracker.can_recall:\n current_epoch, best_validation_loss, samples_seen = recall_trainer(",
+ "type": "code",
+ "location": "/train_diffusion_prior.py:641-681"
+ },
+ "813": {
+ "file_id": 24,
+ "content": "The function initialize_training is responsible for loading the configuration file, setting the seed, getting a device, making the trainer, and creating a tracker. The trainer is automatically distributed if possible and configured. Additionally, the function checks whether it can recall from a checkpoint.",
+ "type": "comment"
+ },
+ "814": {
+ "file_id": 24,
+ "content": " tracker=tracker, trainer=trainer\n )\n # display best values\n if trainer.accelerator.is_main_process:\n click.secho(f\"Current Epoch: {current_epoch} | Best Val Loss: {best_validation_loss} | Samples Seen: {samples_seen}\", fg=\"yellow\")\n # update config to reflect recalled values\n config.train.num_samples_seen = samples_seen\n config.train.current_epoch = current_epoch\n config.train.best_validation_loss = best_validation_loss\n # fetch and prepare data\n if trainer.accelerator.is_main_process:\n click.secho(\"Grabbing data...\", fg=\"blue\", blink=True)\n trainer.accelerator.wait_for_everyone()\n img_reader = get_reader(\n text_conditioned=trainer.text_conditioned,\n img_url=config.data.image_url,\n meta_url=config.data.meta_url,\n )\n # calculate start point within epoch\n trainer.accelerator.wait_for_everyone()\n train_loader, eval_loader, test_loader = make_splits(\n text_conditioned=trainer.text_conditioned,",
+ "type": "code",
+ "location": "/train_diffusion_prior.py:682-711"
+ },
+ "815": {
+ "file_id": 24,
+ "content": "This code block displays the current epoch, best validation loss, and samples seen, updates configuration with recalled values, fetches and prepares data for training by creating a loader, and calculates the start point within the epoch.",
+ "type": "comment"
+ },
+ "816": {
+ "file_id": 24,
+ "content": " batch_size=config.data.batch_size,\n num_data_points=config.data.num_data_points,\n train_split=config.data.splits.train,\n eval_split=config.data.splits.val,\n image_reader=img_reader,\n rank=accelerator.state.process_index,\n world_size=accelerator.state.num_processes,\n start=0,\n )\n # update the start point to finish out the epoch on a resumed run\n if tracker.can_recall:\n samples_seen = config.train.num_samples_seen\n length = (\n config.data.num_data_points\n if samples_seen <= img_reader.count\n else img_reader.count\n )\n scaled_samples = length * config.train.current_epoch\n start_point = (\n scaled_samples - samples_seen if scaled_samples > samples_seen else samples_seen\n )\n if trainer.accelerator.is_main_process:\n click.secho(f\"Resuming at sample: {start_point}\", fg=\"yellow\")\n train_loader.dataset.set_start(start_point)\n # start training\n if trainer.accelerator.is_main_process:",
+ "type": "code",
+ "location": "/train_diffusion_prior.py:712-743"
+ },
+ "817": {
+ "file_id": 24,
+ "content": "This code initializes a data loader and sets the start point for resuming training if necessary. It ensures that the training continues from where it left off in a previous run by adjusting the number of samples seen based on the total number of data points and the current epoch. The main process prints a message indicating the resumption sample count.",
+ "type": "comment"
+ },
+ "818": {
+ "file_id": 24,
+ "content": " click.secho(\n f\"Beginning Prior Training : Distributed={accelerator.state.distributed_type != accelerate_dataclasses.DistributedType.NO}\",\n fg=\"yellow\",\n )\n train(\n trainer=trainer,\n tracker=tracker,\n train_loader=train_loader,\n eval_loader=eval_loader,\n test_loader=test_loader,\n config=config,\n )\n@click.command()\n@click.option(\"--config_file\", default=\"configs/train_prior_config.example.json\")\ndef main(config_file):\n # start HFA\n accelerator = Accelerator()\n # setup training\n initialize_training(config_file, accelerator)\nif __name__ == \"__main__\":\n main()",
+ "type": "code",
+ "location": "/train_diffusion_prior.py:744-770"
+ },
+ "819": {
+ "file_id": 24,
+ "content": "Beginning Prior Training message with distributed status. Then, initiates training process using provided configurations and loaders for trainer, tracker, train_loader, eval_loader, and test_loader. Finally, executes main function with the specified config file to start Heterogeneous Fusion Acceleration (HFA) and set up the training environment.",
+ "type": "comment"
+ }
+}
\ No newline at end of file
diff --git a/docs/data/titles/0.json b/docs/data/titles/0.json
new file mode 100644
index 00000000..47b0cd74
--- /dev/null
+++ b/docs/data/titles/0.json
@@ -0,0 +1,302 @@
+{
+ "/MANIFEST.in": "Including Dalle2 PyTorch .txt Files",
+ "/Makefile": "CUDA Test Installation",
+ "/README.md": "Enhancing DALL-E 2 with Layers and Inpainting",
+ "/README.md:1-13": "DALL-E 2 PyTorch Implementation: SOTA Text-to-Image Synthesis",
+ "/README.md:1011-1046": "Decoder Training and Image Generation",
+ "/README.md:1046-1050": "WebDataset Embedding Dataset Format",
+ "/README.md:1051-1066": "Create Image Embedding Dataloader with URL Path and Optional Folder",
+ "/README.md:1067-1093": "Load and Verify ImageEmbeddingDataset",
+ "/README.md:1094-1102": "DDPM Model: Building and Customizing",
+ "/README.md:1103-1113": "Comprehensive DALLE2-pytorch Tasklist",
+ "/README.md:1113-1121": "Diffusion Prior Model for Image Gen: Hyperparams, Cross-Scale Embedding",
+ "/README.md:1122-1130": "DALLE2-pytorch: Implemented Tasks and Features",
+ "/README.md:1131-1165": "Task List for DALLE2-pytorch Project",
+ "/README.md:1166-1198": "BibTeX Entry Template\nThe title: BibTeX Entry Template",
+ "/README.md:1199-1225": "CrossFormer: Vision Transformer with Cross-Scale Attention",
+ "/README.md:1226-1261": "BibTeX Entries for Recent Diffusion Model Research",
+ "/README.md:1262-1293": "BibTeX Entry Structure",
+ "/README.md:1294-1311": "BibTeX Citation Format for Two Papers",
+ "/README.md:13-25": "Simplified DALL-E 2 Model with Imagen Decoder",
+ "/README.md:138-181": "U-Net Model Creation and Decoder for Image Generation",
+ "/README.md:182-223": "Diffusion Prior Network for Image-Text Generation",
+ "/README.md:223-269": "CLIP Model with Cascaded UNETs Initialization",
+ "/README.md:269-298": "U-Net Decoder with Diffusion Prior for DALL-E2",
+ "/README.md:27-36": "Diffusion Models on Steroids",
+ "/README.md:299-353": "DALLE2 Image Generation and Training Code",
+ "/README.md:354-400": "Training DALLE-like Image Gen Model with PyTorch Unets",
+ "/README.md:38-48": "Acknowledging Contributors",
+ "/README.md:401-431": "DALLE2 Model Training and Image Generation",
+ "/README.md:432-474": "Diffusion Prior Setup for CLIP Model",
+ "/README.md:475-510": "Diffusion Prior with AR Transformer and CLIP Embeddings",
+ "/README.md:49-70": "Installation and Usage Guide",
+ "/README.md:511-544": "Diffusion Model Training with OpenAI CLIP",
+ "/README.md:544-589": "Integrating CLIP with DALLE2 for Training",
+ "/README.md:590-625": "Training and Generating with DALLE2",
+ "/README.md:626-663": "Open Clip Image Processing with Code Examples",
+ "/README.md:664-700": "Initialize DALL-E 2 Model Components",
+ "/README.md:702-731": "DALL-E2 Inpainting with Latent Diffusion",
+ "/README.md:71-91": "CLIP Model Setup and Configurations",
+ "/README.md:731-762": "Enhancing Autoencoder with VQGAN-VAE",
+ "/README.md:763-807": "Setting Up DALLE2 Components: VQGanVAE, Unets, and Decoder",
+ "/README.md:808-846": "Training Multi-UNET Decoder with Cascading DDPM",
+ "/README.md:846-888": "Text-to-Image Generation with CLIP and Unet",
+ "/README.md:889-926": "Training Unets for Image Generation",
+ "/README.md:92-136": "Training CLIP Decoder Using Unet",
+ "/README.md:926-971": "Diffusion Prior for CLIP Model Training",
+ "/README.md:972-1009": "Train Diffusion Prior with EMA and Sampling",
+ "/configs/README.md": "Customizable DALLE2 Training Configs",
+ "/configs/README.md:1-24": "DALLE2 Training Configuration Guide",
+ "/configs/README.md:117-141": "Configuring DALLE2-pytorch: Logging and Loading Options",
+ "/configs/README.md:142-164": "Checkpoint File Loading and Saving Options",
+ "/configs/README.md:165-182": "Saving Options for Models and Metadata",
+ "/configs/README.md:183-185": "Configuring Wandb Run Paths",
+ "/configs/README.md:25-40": "U-Net Model Configuration Options",
+ "/configs/README.md:41-51": "Flexible Dataloader Configuration",
+ "/configs/README.md:51-58": "Shard Allocation and Dataset Configurations",
+ "/configs/README.md:60-74": "Training Hyperparameters Control",
+ "/configs/README.md:74-87": "DALLE2 Model Training Configurations",
+ "/configs/README.md:87-98": "Tracking and Metrics for DALLE2-pytorch",
+ "/configs/README.md:99-116": "DALLE2-PyTorch Configuration Settings",
+ "/dalle2_pytorch/__init__.py": "DALLE2-PyTorch Library Imports",
+ "/dalle2_pytorch/cli.py": "DALL-E2 Image Generation CLI",
+ "/dalle2_pytorch/cli.py:1-33": "DALL-E 2 Command-Line Interface",
+ "/dalle2_pytorch/cli.py:34-52": "DALL-E2 Image Generation",
+ "/dalle2_pytorch/dalle2_pytorch.py": "Dalle2-Python: AI Image Generation",
+ "/dalle2_pytorch/dalle2_pytorch.py:1-49": "Image Processing and Modeling Library",
+ "/dalle2_pytorch/dalle2_pytorch.py:1018-1052": "Scaled Causal Transformer Model",
+ "/dalle2_pytorch/dalle2_pytorch.py:1053-1076": "Model Initialization for DALLE-2",
+ "/dalle2_pytorch/dalle2_pytorch.py:1078-1103": "Prepping DALL-E 2 Input Data",
+ "/dalle2_pytorch/dalle2_pytorch.py:1105-1134": "Masking and Type Conversion in DALL-E 2 PyTorch",
+ "/dalle2_pytorch/dalle2_pytorch.py:1135-1174": "DiffusionPrior: Causal Transformers for DDPM",
+ "/dalle2_pytorch/dalle2_pytorch.py:1175-1188": "DALLE2 Model Initialization",
+ "/dalle2_pytorch/dalle2_pytorch.py:1190-1214": "Model Initialization",
+ "/dalle2_pytorch/dalle2_pytorch.py:1216-1227": "Dalle2 PyTorch: Dimension Consistency and CLIP Check",
+ "/dalle2_pytorch/dalle2_pytorch.py:1229-1255": "Dalle2 PyTorch Parameters and Properties",
+ "/dalle2_pytorch/dalle2_pytorch.py:1256-1274": "Conditional Dropout Assertion",
+ "/dalle2_pytorch/dalle2_pytorch.py:1276-1296": "DDPMLearner: PyTorch Image Sampling Functions",
+ "/dalle2_pytorch/dalle2_pytorch.py:1296-1316": "Dalle2 PyTorch: Image Sampling and Embedding Functions",
+ "/dalle2_pytorch/dalle2_pytorch.py:1318-1342": "Neural Network Image Generation",
+ "/dalle2_pytorch/dalle2_pytorch.py:1344-1372": "L2-Norm Clamping for Dalle2 PyTorch Noise Prediction",
+ "/dalle2_pytorch/dalle2_pytorch.py:135-179": "Gradient Control and Model Utilities",
+ "/dalle2_pytorch/dalle2_pytorch.py:1373-1396": "DALLE2-pytorch Embedding Scaling",
+ "/dalle2_pytorch/dalle2_pytorch.py:1397-1425": "Loss Calculation Method in Dalle2_PyTorch",
+ "/dalle2_pytorch/dalle2_pytorch.py:1426-1454": "DALL-E 2 Image Embedding Sampling",
+ "/dalle2_pytorch/dalle2_pytorch.py:1456-1481": "Retrieve Image Embeddings and Compute Similarity",
+ "/dalle2_pytorch/dalle2_pytorch.py:1482-1504": "Image and Text Encoding Verification",
+ "/dalle2_pytorch/dalle2_pytorch.py:1506-1543": "Upsampling Techniques in Dalle2 PyTorch",
+ "/dalle2_pytorch/dalle2_pytorch.py:1544-1574": "WeightStandardizedConv2d: Fusion of Group Norm & Conv2D",
+ "/dalle2_pytorch/dalle2_pytorch.py:1575-1605": "3 Classes: Reshape, Embedding, ConvBlock",
+ "/dalle2_pytorch/dalle2_pytorch.py:1607-1649": "ResnetBlock: Compute, Project, Normalize",
+ "/dalle2_pytorch/dalle2_pytorch.py:1650-1676": "Encoder-Decoder Architecture with Residual Connections and Cross-Attention",
+ "/dalle2_pytorch/dalle2_pytorch.py:1678-1711": "CrossAttention Class Initialization",
+ "/dalle2_pytorch/dalle2_pytorch.py:1712-1743": "Multi-Head Attention Layer Definition",
+ "/dalle2_pytorch/dalle2_pytorch.py:1745-1780": "Multi-Head Linear Attention Module",
+ "/dalle2_pytorch/dalle2_pytorch.py:1782-1816": "Attention and Transformation Layer",
+ "/dalle2_pytorch/dalle2_pytorch.py:180-214": "Image Normalization and CLIP Model Adapter in DALLE2-PyTorch",
+ "/dalle2_pytorch/dalle2_pytorch.py:1817-1849": "Convolutional Network with Adjustable Upsampling Combiner",
+ "/dalle2_pytorch/dalle2_pytorch.py:1850-1877": "Unet Model: Dalle2 PyTorch Implementation",
+ "/dalle2_pytorch/dalle2_pytorch.py:1878-1895": "DALLE2 Model Customization and Optimization",
+ "/dalle2_pytorch/dalle2_pytorch.py:1896-1927": "DDPM Model Initialization: Parameters and Hyperparameters",
+ "/dalle2_pytorch/dalle2_pytorch.py:1929-1956": "DALL-E 2 Input Processing Layers",
+ "/dalle2_pytorch/dalle2_pytorch.py:1957-1981": "DALL-E 2 Model Architecture: Linear, Layer Normalization, GELU",
+ "/dalle2_pytorch/dalle2_pytorch.py:1983-2006": "Initializing DALL-E 2 Model Components",
+ "/dalle2_pytorch/dalle2_pytorch.py:2008-2035": "DALL-E 2 Model Initialization Code",
+ "/dalle2_pytorch/dalle2_pytorch.py:2037-2059": "Efficient UNet Initialization",
+ "/dalle2_pytorch/dalle2_pytorch.py:2060-2076": "Neural Network Module Initializer",
+ "/dalle2_pytorch/dalle2_pytorch.py:2076-2094": "ResNet Blocks with Attention Layers",
+ "/dalle2_pytorch/dalle2_pytorch.py:2094-2116": "DALL·E 2 Model Architecture in PyTorch",
+ "/dalle2_pytorch/dalle2_pytorch.py:2118-2146": "Uneting DDPM Parameters Checker",
+ "/dalle2_pytorch/dalle2_pytorch.py:2147-2185": "Conditional Scaling in Dalle2 PyTorch",
+ "/dalle2_pytorch/dalle2_pytorch.py:215-257": "Clip Model Embedding Adapter",
+ "/dalle2_pytorch/dalle2_pytorch.py:2187-2216": "Low Resolution Conditioning in DALL-E 2",
+ "/dalle2_pytorch/dalle2_pytorch.py:2217-2243": "Conditional Dropout in DALL-E 2 PyTorch",
+ "/dalle2_pytorch/dalle2_pytorch.py:2244-2265": "Preparing Input for Classifier-Free Guidance Model",
+ "/dalle2_pytorch/dalle2_pytorch.py:2267-2288": "Text Padding and Masking in DALLE-PyTorch",
+ "/dalle2_pytorch/dalle2_pytorch.py:2289-2318": "Conditioning Token Handling in DALLE2-pytorch",
+ "/dalle2_pytorch/dalle2_pytorch.py:2320-2356": "U-Net Model Initialization in Dalle2PyTorch",
+ "/dalle2_pytorch/dalle2_pytorch.py:2356-2393": "Upscaling Network and Conditioner",
+ "/dalle2_pytorch/dalle2_pytorch.py:2394-2418": "Generate Noise Image Object with Parameters",
+ "/dalle2_pytorch/dalle2_pytorch.py:2419-2447": "Conditional Resize, Blur, Downsample Function",
+ "/dalle2_pytorch/dalle2_pytorch.py:2448-2471": "Gaussian Blur and Noise Image Conditioning",
+ "/dalle2_pytorch/dalle2_pytorch.py:2473-2496": "Decoder Class Parameters",
+ "/dalle2_pytorch/dalle2_pytorch.py:2497-2509": "DDPM Configuration in DALLE2-PyTorch",
+ "/dalle2_pytorch/dalle2_pytorch.py:2509-2526": "Initializing Dalle2 Object with Parameters",
+ "/dalle2_pytorch/dalle2_pytorch.py:2527-2555": "CoCa Adapter and Model Freezing in Dalle2 PyTorch",
+ "/dalle2_pytorch/dalle2_pytorch.py:2557-2577": "Initializing DALL-E 2 Networks",
+ "/dalle2_pytorch/dalle2_pytorch.py:2578-2597": "Setting Up Unets and Vaes for Dalle2 Model",
+ "/dalle2_pytorch/dalle2_pytorch.py:258-289": "CoCaAdapter: DALL-E 2 Base Adapter",
+ "/dalle2_pytorch/dalle2_pytorch.py:2598-2621": "VAE Instance Appending and Sampling",
+ "/dalle2_pytorch/dalle2_pytorch.py:2623-2641": "Noise Scheduler Creation for UNETs",
+ "/dalle2_pytorch/dalle2_pytorch.py:2642-2666": "Model Parameters Configuration",
+ "/dalle2_pytorch/dalle2_pytorch.py:2668-2691": "Lowres Unet Initialization",
+ "/dalle2_pytorch/dalle2_pytorch.py:2692-2725": "Dynamic Image Generation Model Setup",
+ "/dalle2_pytorch/dalle2_pytorch.py:2726-2762": "UNET Collection Management and Inference",
+ "/dalle2_pytorch/dalle2_pytorch.py:2764-2785": "Dynamic Threshold and Classifier-Free Guidance",
+ "/dalle2_pytorch/dalle2_pytorch.py:2785-2803": "Image Decoding with Pre-trained UNET",
+ "/dalle2_pytorch/dalle2_pytorch.py:2804-2820": "Posterior Variance Calculator",
+ "/dalle2_pytorch/dalle2_pytorch.py:2820-2836": "DDPM: Predicting Values in Denoising Diffusion Models",
+ "/dalle2_pytorch/dalle2_pytorch.py:2837-2866": "Initialize Image and Variables\"\nor \n\"Image and Variable Initialization",
+ "/dalle2_pytorch/dalle2_pytorch.py:2867-2890": "Inpainting DALLE 2 Image Sequences",
+ "/dalle2_pytorch/dalle2_pytorch.py:2891-2917": "Diffusion-Based Image Denoising with DALLE2",
+ "/dalle2_pytorch/dalle2_pytorch.py:290-319": "CLIP-Based Text-to-Image Model",
+ "/dalle2_pytorch/dalle2_pytorch.py:2918-2946": "DDIM Sampling Function",
+ "/dalle2_pytorch/dalle2_pytorch.py:2947-2972": "Diffusion Model Inpainting with DALLE-2 PyTorch",
+ "/dalle2_pytorch/dalle2_pytorch.py:2973-2994": "Conditional Image Generation Model",
+ "/dalle2_pytorch/dalle2_pytorch.py:2996-3017": "Predictive Noise Processing in DALL-E 2",
+ "/dalle2_pytorch/dalle2_pytorch.py:3019-3039": "DDPM or DDIM Sampling in p_sample_loop",
+ "/dalle2_pytorch/dalle2_pytorch.py:3041-3075": "DALLE2-pytorch Self-Conditioning Sampling",
+ "/dalle2_pytorch/dalle2_pytorch.py:3077-3098": "Calculating Loss with Variable Targets",
+ "/dalle2_pytorch/dalle2_pytorch.py:3100-3115": "KL Divergence and Decoder Loss Calculation",
+ "/dalle2_pytorch/dalle2_pytorch.py:3117-3149": "Variational Bayes Loss Sampling Function",
+ "/dalle2_pytorch/dalle2_pytorch.py:3150-3163": "Valid Input Checker",
+ "/dalle2_pytorch/dalle2_pytorch.py:3165-3182": "CUDA-Powered Unet Processing",
+ "/dalle2_pytorch/dalle2_pytorch.py:3183-3204": "Denoising Diffusion Model Image Processing",
+ "/dalle2_pytorch/dalle2_pytorch.py:320-360": "Neural Text-to-Image Generation with PyTorch",
+ "/dalle2_pytorch/dalle2_pytorch.py:3205-3233": "UNet Image Generation",
+ "/dalle2_pytorch/dalle2_pytorch.py:3233-3251": "U-Net Initialization in Dalle2 PyTorch",
+ "/dalle2_pytorch/dalle2_pytorch.py:3252-3267": "Input Validation for CLIP Model and Text Encodings",
+ "/dalle2_pytorch/dalle2_pytorch.py:3267-3285": "Conditional VAE Image Processing",
+ "/dalle2_pytorch/dalle2_pytorch.py:3285-3319": "DALLE2 Diffusion Model: Losses and Lowres Conditional Images",
+ "/dalle2_pytorch/dalle2_pytorch.py:3320-3340": "Text-to-Image Model: DALL-E 2 PyTorch Implementation",
+ "/dalle2_pytorch/dalle2_pytorch.py:361-390": "CLIP Embedding for Text",
+ "/dalle2_pytorch/dalle2_pytorch.py:391-433": "Dalle2 PyTorch: Text Class and Model Components",
+ "/dalle2_pytorch/dalle2_pytorch.py:434-459": "Embedded Text Generation with CLIP",
+ "/dalle2_pytorch/dalle2_pytorch.py:460-488": "DALLE2-PyTorch Helper Functions",
+ "/dalle2_pytorch/dalle2_pytorch.py:489-519": "Adaptive Quantile Regression with Cosine or Linear Schedule",
+ "/dalle2_pytorch/dalle2_pytorch.py:50-96": "Utility Functions",
+ "/dalle2_pytorch/dalle2_pytorch.py:520-548": "Noise Scheduler Beta Functions",
+ "/dalle2_pytorch/dalle2_pytorch.py:549-580": "Beta Schedule and Loss Function Registration",
+ "/dalle2_pytorch/dalle2_pytorch.py:581-601": "Diffusion Buffers Calculation and Clipping",
+ "/dalle2_pytorch/dalle2_pytorch.py:602-620": "Loss Reweighting and Posterior Calculation",
+ "/dalle2_pytorch/dalle2_pytorch.py:622-645": "Neural Image Generation Functions",
+ "/dalle2_pytorch/dalle2_pytorch.py:647-677": "Reweighting Loss and Predicting Values",
+ "/dalle2_pytorch/dalle2_pytorch.py:678-711": "Layer Normalization Implementation",
+ "/dalle2_pytorch/dalle2_pytorch.py:712-750": "Residual MLP with Normalization and Activation in DALLE2 PyTorch",
+ "/dalle2_pytorch/dalle2_pytorch.py:751-792": "DALL-E 2 Neural Architecture PyTorch",
+ "/dalle2_pytorch/dalle2_pytorch.py:794-816": "DALLE2-pytorch Attention Layer",
+ "/dalle2_pytorch/dalle2_pytorch.py:817-858": "Post-Activation Normalization with Nested Attention",
+ "/dalle2_pytorch/dalle2_pytorch.py:859-891": "Implementing DALL·E 2 Self-Attention Layer",
+ "/dalle2_pytorch/dalle2_pytorch.py:892-928": "Multi-Head Attention Calculation",
+ "/dalle2_pytorch/dalle2_pytorch.py:930-961": "CausalTransformer Dimensional Manipulation",
+ "/dalle2_pytorch/dalle2_pytorch.py:962-993": "Diffusion Prior Network Model",
+ "/dalle2_pytorch/dalle2_pytorch.py:97-133": "Helper Functions for Python Lists and Functions",
+ "/dalle2_pytorch/dalle2_pytorch.py:994-1017": "Dalle2 PyTorch Model Architecture",
+ "/dalle2_pytorch/dataloaders/README.md": "DALLE2 PyTorch Datasets & Dataloaders",
+ "/dalle2_pytorch/dataloaders/README.md:1-5": "Efficient Dataloaders for Decoder Training",
+ "/dalle2_pytorch/dataloaders/README.md:20-37": "Dataloader Configuration and Dataset Creation",
+ "/dalle2_pytorch/dataloaders/README.md:38-53": "Disable Resampling in Prior Embedding Dataset",
+ "/dalle2_pytorch/dataloaders/README.md:5-19": "Create Image Embedding Dataloader",
+ "/dalle2_pytorch/dataloaders/README.md:54-75": "Data Splitter for Distributed DALLE2 Training",
+ "/dalle2_pytorch/dataloaders/__init__.py": "Loading Datasets for DALLE2-pytorch",
+ "/dalle2_pytorch/dataloaders/decoder_loader.py": "Dalle2 Dataset Loaders: Exception Handling",
+ "/dalle2_pytorch/dataloaders/decoder_loader.py:1-29": "Decoder Loader: Functions for DALLE-PyTorch",
+ "/dalle2_pytorch/dataloaders/decoder_loader.py:112-136": "Decoder Loader Data Pipeline",
+ "/dalle2_pytorch/dataloaders/decoder_loader.py:137-151": "WebDataset Loader Function",
+ "/dalle2_pytorch/dataloaders/decoder_loader.py:152-169": "Decoder Loader: WebDataset Handler",
+ "/dalle2_pytorch/dataloaders/decoder_loader.py:170-185": "S3 Link Check and Shuffling for Dataset",
+ "/dalle2_pytorch/dataloaders/decoder_loader.py:186-200": "DALLE2 Decoder Loader Implementation",
+ "/dalle2_pytorch/dataloaders/decoder_loader.py:201-225": "Image Embedding Dataloader Creation",
+ "/dalle2_pytorch/dataloaders/decoder_loader.py:226-240": "One-Line Image Embedding Dataset Dataloader",
+ "/dalle2_pytorch/dataloaders/decoder_loader.py:240-261": "Decoder Data Loader Function",
+ "/dalle2_pytorch/dataloaders/decoder_loader.py:262-266": "Decoder Data Loader Configuration",
+ "/dalle2_pytorch/dataloaders/decoder_loader.py:30-49": "Load Embeddings from WebDataset Tar File",
+ "/dalle2_pytorch/dataloaders/decoder_loader.py:50-67": "Embedding Loader for Dalle2 PyTorch",
+ "/dalle2_pytorch/dataloaders/decoder_loader.py:69-84": "Checks Tarfile Embeddings Matching",
+ "/dalle2_pytorch/dataloaders/decoder_loader.py:85-111": "Combine Image and Text Embeddings",
+ "/dalle2_pytorch/dataloaders/prior_loader.py": "Efficient DALL-E 2 Data Loaders",
+ "/dalle2_pytorch/dataloaders/prior_loader.py:1-40": "Simplified Embedding Dataset Loader",
+ "/dalle2_pytorch/dataloaders/prior_loader.py:114-149": "Defining EmbeddingReader Functions",
+ "/dalle2_pytorch/dataloaders/prior_loader.py:151-187": "Split Embedding Reader",
+ "/dalle2_pytorch/dataloaders/prior_loader.py:189-210": "Creating PyTorch Dataloaders for Image-Text Pairs",
+ "/dalle2_pytorch/dataloaders/prior_loader.py:211-242": "Split and Wrap Data Loader",
+ "/dalle2_pytorch/dataloaders/prior_loader.py:243-271": "Creating PriorEmbeddingDataset for Datasets",
+ "/dalle2_pytorch/dataloaders/prior_loader.py:273-282": "Prior Embedding Datasets Creation",
+ "/dalle2_pytorch/dataloaders/prior_loader.py:41-72": "Custom DALLE2 Dataset Class",
+ "/dalle2_pytorch/dataloaders/prior_loader.py:73-112": "DALL-E 2 Data Loader",
+ "/dalle2_pytorch/dataloaders/simple_image_only_dataloader.py": "Simple Image Dataloader",
+ "/dalle2_pytorch/dataloaders/simple_image_only_dataloader.py:1-47": "Simple Image Dataset Loader",
+ "/dalle2_pytorch/dataloaders/simple_image_only_dataloader.py:48-59": "Simple Image Dataset DataLoader",
+ "/dalle2_pytorch/optimizer.py": "Optimizer Creation and Filtering",
+ "/dalle2_pytorch/tokenizer.py": "Streamlined DALL-E2 Tokenization",
+ "/dalle2_pytorch/tokenizer.py:1-42": "DALL-E2 Tokenizer: Easy BPE-Free Training",
+ "/dalle2_pytorch/tokenizer.py:100-126": "BPE Tokenizer Implementation",
+ "/dalle2_pytorch/tokenizer.py:128-151": "SimpleTokenizer: Converting Encoded Tokens to Text",
+ "/dalle2_pytorch/tokenizer.py:153-182": "YTTM Tokenizer in PyTorch",
+ "/dalle2_pytorch/tokenizer.py:183-191": "Token Truncation and Conversion",
+ "/dalle2_pytorch/tokenizer.py:43-69": "Python BPE Tokenizer Class",
+ "/dalle2_pytorch/tokenizer.py:70-98": "Byte-Pair Encoding Tokenizer",
+ "/dalle2_pytorch/trackers.py": "Trackers and Loggers Management",
+ "/dalle2_pytorch/trackers.py:1-35": "Base Logger Class in DALLE2-pytorch",
+ "/dalle2_pytorch/trackers.py:121-143": "Wandb Tracker: Resume and Log Data",
+ "/dalle2_pytorch/trackers.py:145-170": "Logging and Resume Data Tracking Functionality",
+ "/dalle2_pytorch/trackers.py:171-200": "Logger Function for Model Checkpoints",
+ "/dalle2_pytorch/trackers.py:201-229": "Base Class File Loader",
+ "/dalle2_pytorch/trackers.py:230-251": "WandbLoader: Loading Models from W&B Runs",
+ "/dalle2_pytorch/trackers.py:252-278": "Wandb Data Loader Integration",
+ "/dalle2_pytorch/trackers.py:279-299": "PyTorch Tracker Class: Versatile Data Saving",
+ "/dalle2_pytorch/trackers.py:300-328": "Local and Wandb File Savers",
+ "/dalle2_pytorch/trackers.py:329-346": "W&B Run Initialization",
+ "/dalle2_pytorch/trackers.py:347-365": "HuggingFace Saver Class",
+ "/dalle2_pytorch/trackers.py:36-62": "Logger Class with Methods for Data Logging",
+ "/dalle2_pytorch/trackers.py:366-382": "HuggingFace Repo File Saver",
+ "/dalle2_pytorch/trackers.py:383-407": "Creating and Initializing Savers and Trackers",
+ "/dalle2_pytorch/trackers.py:408-427": "Auto-Resume Tracker Setup",
+ "/dalle2_pytorch/trackers.py:428-442": "Auto-Resume Data Validator",
+ "/dalle2_pytorch/trackers.py:443-459": "Auto Resume Tracker Initialization",
+ "/dalle2_pytorch/trackers.py:460-489": "Initializing Trackers in Dalle2PyTorch",
+ "/dalle2_pytorch/trackers.py:491-517": "Tracking Dalle2-PyTorch Data: Saving, Logging, and Metadata",
+ "/dalle2_pytorch/trackers.py:517-531": "Saving Trainer State: Checkpoint vs Model",
+ "/dalle2_pytorch/trackers.py:532-551": "Remove CLIP from Model State",
+ "/dalle2_pytorch/trackers.py:552-575": "Saving and Checkpointing Model Files",
+ "/dalle2_pytorch/trackers.py:576-598": "Checkpoint Recall Manager",
+ "/dalle2_pytorch/trackers.py:598-598": "No Loader, No Resume Error",
+ "/dalle2_pytorch/trackers.py:63-94": "Logger Classes for PyTorch Tracking",
+ "/dalle2_pytorch/trackers.py:95-120": "WandB Logger Python Class",
+ "/dalle2_pytorch/train_configs.py": "DALL-E 2 PyTorch Training Utilities",
+ "/dalle2_pytorch/train_configs.py:1-43": "Training Configs for DALL-E 2 PyTorch",
+ "/dalle2_pytorch/train_configs.py:103-129": "Tracker Function: Saver and Component Management",
+ "/dalle2_pytorch/train_configs.py:130-167": "Neural Network Configurations for Diffusion Models",
+ "/dalle2_pytorch/train_configs.py:168-201": "Training Config Class",
+ "/dalle2_pytorch/train_configs.py:202-223": "Train Configs for Diffusion Prior Model",
+ "/dalle2_pytorch/train_configs.py:224-255": "Dalle-2 Model Config Classes",
+ "/dalle2_pytorch/train_configs.py:256-284": "Train DALL-E 2 with Unet and CLIP",
+ "/dalle2_pytorch/train_configs.py:285-309": "Training Config Generator",
+ "/dalle2_pytorch/train_configs.py:311-329": "DecoderTrainConfig Class Options",
+ "/dalle2_pytorch/train_configs.py:331-361": "Decoding Training Configs",
+ "/dalle2_pytorch/train_configs.py:362-380": "Ensuring Non-Redundant Model Usage",
+ "/dalle2_pytorch/train_configs.py:380-382": "Unnecessary Text Embedding Slowdown",
+ "/dalle2_pytorch/train_configs.py:44-73": "Tracker Log Config Classes",
+ "/dalle2_pytorch/train_configs.py:74-102": "Tracker Config and Loader",
+ "/dalle2_pytorch/trainer.py": "DeepSpeed Trainer Initialization and Computations",
+ "/dalle2_pytorch/trainer.py:1-43": "Trainer Utilities",
+ "/dalle2_pytorch/trainer.py:100-138": "Gradient Accumulation and Splitting Functions",
+ "/dalle2_pytorch/trainer.py:139-159": "Chunked Arguments Yielding",
+ "/dalle2_pytorch/trainer.py:161-192": "Diffusion Prior Trainer Class",
+ "/dalle2_pytorch/trainer.py:194-225": "Accelerator Check and Model Transfer",
+ "/dalle2_pytorch/trainer.py:225-249": "DeepSpeed Trainer Initialization",
+ "/dalle2_pytorch/trainer.py:251-280": "Saveable Trainer with LambdaLR",
+ "/dalle2_pytorch/trainer.py:281-304": "Efficient Checkpoint Saving for Diffusion Trainers",
+ "/dalle2_pytorch/trainer.py:306-327": "Checkpoint Loading Function",
+ "/dalle2_pytorch/trainer.py:329-358": "Warmup, Learning Rate Update, and Optimization",
+ "/dalle2_pytorch/trainer.py:359-386": "Diffusion Prior Model Sampling",
+ "/dalle2_pytorch/trainer.py:388-423": "Embedding Trainer with Chunking",
+ "/dalle2_pytorch/trainer.py:424-454": "Batch-Based Decoder Trainer",
+ "/dalle2_pytorch/trainer.py:45-78": "Grouping and Casting Operations",
+ "/dalle2_pytorch/trainer.py:456-481": "Configuring UNETs in Trainer Initialization",
+ "/dalle2_pytorch/trainer.py:482-510": "Optimizer and EMA Initialization",
+ "/dalle2_pytorch/trainer.py:511-537": "Precision Conversion and Decoder Preparation",
+ "/dalle2_pytorch/trainer.py:538-566": "Optimizers, Schedulers, and Warmup: UNET Validation",
+ "/dalle2_pytorch/trainer.py:567-591": "Save Model States",
+ "/dalle2_pytorch/trainer.py:593-622": "Load and Train Model State from Path",
+ "/dalle2_pytorch/trainer.py:623-652": "Load and Update Saved State: Unet Trainer",
+ "/dalle2_pytorch/trainer.py:653-683": "Sampling Process for Model Optimization",
+ "/dalle2_pytorch/trainer.py:685-717": "Embedding Text and Images with CLIP Unets",
+ "/dalle2_pytorch/trainer.py:719-742": "Accumulating Losses in DALLE2 Trainer",
+ "/dalle2_pytorch/trainer.py:79-99": "DeepSpeed PyTorch Model Argument Casting",
+ "/dalle2_pytorch/utils.py": "Utility Functions for Time, Print, and Import",
+ "/dalle2_pytorch/version.py": "DALLE2-PyTorch Version: 1.15.6",
+ "/dalle2_pytorch/vqgan_vae.py": "VQGAN-VAE Image Generation: Architectures and Algorithms",
+ "/dalle2_pytorch/vqgan_vae.py:1-51": "Vector Quantize Module Setup",
+ "/dalle2_pytorch/vqgan_vae.py:123-163": "Discriminator Design in VQGAN-VAE Architecture"
+}
\ No newline at end of file
diff --git a/docs/data/titles/1.json b/docs/data/titles/1.json
new file mode 100644
index 00000000..ad7a48b1
--- /dev/null
+++ b/docs/data/titles/1.json
@@ -0,0 +1,105 @@
+{
+ "/dalle2_pytorch/vqgan_vae.py:164-197": "VQGAN-VAE Model: Convolutional and Group Normalization",
+ "/dalle2_pytorch/vqgan_vae.py:199-232": "Resnet Encoder/Decoder VQ-VAE for Image Gen",
+ "/dalle2_pytorch/vqgan_vae.py:233-262": "VQGAN-VAE Class Definition",
+ "/dalle2_pytorch/vqgan_vae.py:264-279": "VQ-VAE Model Encoder and Decoder Blocks",
+ "/dalle2_pytorch/vqgan_vae.py:281-315": "VQGAN-VAE Model: Encoder, Decoder, GLUResBlock",
+ "/dalle2_pytorch/vqgan_vae.py:317-354": "VQGAN-ResBlock and Attention Layer: Image Processing",
+ "/dalle2_pytorch/vqgan_vae.py:356-396": "Multi-Head Attention Layer in ViT",
+ "/dalle2_pytorch/vqgan_vae.py:397-433": "Multi-Head Attention for Transformers",
+ "/dalle2_pytorch/vqgan_vae.py:434-476": "Encoder-Decoder ViT Architecture",
+ "/dalle2_pytorch/vqgan_vae.py:477-510": "VQ-VAE Model for Image Generation",
+ "/dalle2_pytorch/vqgan_vae.py:511-562": "VQGAN-VAE: GAN Integrated Variant of VAE",
+ "/dalle2_pytorch/vqgan_vae.py:53-87": "Utility Functions and Tensor Operations",
+ "/dalle2_pytorch/vqgan_vae.py:563-596": "Initialize VQ-VAE Model with Specified Parameters",
+ "/dalle2_pytorch/vqgan_vae.py:597-633": "VQGAN-VAE Model with GAN and Losses",
+ "/dalle2_pytorch/vqgan_vae.py:634-672": "VQGAN-VAE Encoding and Decoding",
+ "/dalle2_pytorch/vqgan_vae.py:674-700": "VQGAN-VAE: Image Encoding and Decoding with Losses",
+ "/dalle2_pytorch/vqgan_vae.py:702-739": "VQ-VAE: Reconstruction and Perceptual Loss",
+ "/dalle2_pytorch/vqgan_vae.py:740-764": "Combined Loss Calculation",
+ "/dalle2_pytorch/vqgan_vae.py:88-121": "Utility Functions for VQ-VAE-GAN Model",
+ "/dalle2_pytorch/vqgan_vae_trainer.py": "VAE Model Training with PyTorch",
+ "/dalle2_pytorch/vqgan_vae_trainer.py:1-47": "Utility Functions and Helper Methods",
+ "/dalle2_pytorch/vqgan_vae_trainer.py:124-150": "Mixed Precision Pytorch Dataset Initializer",
+ "/dalle2_pytorch/vqgan_vae_trainer.py:151-188": "VAE Trainer: Initializing Data Loader and Parameters",
+ "/dalle2_pytorch/vqgan_vae_trainer.py:189-221": "Efficient VAE Training with Discriminator Update",
+ "/dalle2_pytorch/vqgan_vae_trainer.py:223-251": "VAE Discriminator Loss Tracking and Saving",
+ "/dalle2_pytorch/vqgan_vae_trainer.py:253-278": "Periodic VAE Model Saving",
+ "/dalle2_pytorch/vqgan_vae_trainer.py:48-91": "VQGAN-VAE Training Classes",
+ "/dalle2_pytorch/vqgan_vae_trainer.py:92-123": "VQGAN VAE Training Setup",
+ "/prior.md": "CLIP and Diffusion Prior Image Generation",
+ "/prior.md:1-21": "Diffusion Prior for Cross-Space Image Generation",
+ "/prior.md:120-130": "Efficient Embedding Prediction in CLIP Priors",
+ "/prior.md:132-146": "Preparing Embeddings for Training Efficiency",
+ "/prior.md:146-155": "Distributed Training with Accelerate and Metrics",
+ "/prior.md:155-157": "Validation Loss Calculation",
+ "/prior.md:157-158": "EMA vs Online Model: Validation Loss Comparison",
+ "/prior.md:158-160": "Baseline Similarity in DALLE2-PyTorch",
+ "/prior.md:160-161": "Training Efficiency and Overfitting Metrics",
+ "/prior.md:161-163": "Monitoring Cosine Similarity for Overfitting Prevention",
+ "/prior.md:163-175": "Training Diffusion Model: Launch and Save Checkpoints",
+ "/prior.md:175-183": "Latest.pth: Avoiding Save_every Overlaps",
+ "/prior.md:22-47": "Generate Image from Text with Deep Learning Models",
+ "/prior.md:48-76": "Load Pre-trained Prior Model for Enhanced CLIP Embeddings",
+ "/prior.md:77-119": "Instantiating Pre-Trained Model and Tokenization",
+ "/setup.py": "Dalle2-PyTorch Setup Script",
+ "/setup.py:1-42": "Dalle2-PyTorch Setup Script",
+ "/setup.py:43-59": "Python AI Project Setup File",
+ "/train_decoder.py": "DALL-E 2 UNet Model Training Script",
+ "/train_decoder.py:1-28": "Train Decoder Model in DALLE2-pytorch Framework",
+ "/train_decoder.py:104-130": "Zipped List Dataset Extractor",
+ "/train_decoder.py:131-152": "Generate Sample Images and Captions",
+ "/train_decoder.py:153-172": "Preparing Training Samples with Image-Text Embeddings",
+ "/train_decoder.py:173-186": "Grid Image Samples Generator",
+ "/train_decoder.py:187-203": "Decoder Evaluation Metrics",
+ "/train_decoder.py:204-227": "Metrics Calculation and Storage",
+ "/train_decoder.py:228-247": "Metrics Calculation and Normalization for Model Performance",
+ "/train_decoder.py:248-264": "Trainer Functions: Train, Save, Recall",
+ "/train_decoder.py:266-294": "Train Decoder Function",
+ "/train_decoder.py:295-322": "Training Decoder Module",
+ "/train_decoder.py:30-64": "Splitting Datasets for Training, Validation & Testing",
+ "/train_decoder.py:323-337": "Model Loading and Training Progress",
+ "/train_decoder.py:339-357": "Epoch Sample Counting in Train Decoder",
+ "/train_decoder.py:358-380": "Train Decoder Model with Image and Text Embeddings",
+ "/train_decoder.py:381-394": "Training DALL-E 2 Decoder with CLIP Embeddings",
+ "/train_decoder.py:396-413": "Averaging Losses: Unet Decoder Training",
+ "/train_decoder.py:414-431": "Model Snapshotting and Logging",
+ "/train_decoder.py:431-448": "Training and Validation Code Generation",
+ "/train_decoder.py:449-467": "Prepare Data for Evaluation",
+ "/train_decoder.py:469-484": "Auto-generating Image Embeddings in Train Decoder",
+ "/train_decoder.py:485-500": "Validation Loss Calculator",
+ "/train_decoder.py:501-515": "Averaging and Logging Validation Losses",
+ "/train_decoder.py:516-528": "Train Decoder: Sample Generation and Model Saving",
+ "/train_decoder.py:530-546": "Saving Trainer with New Minimum Loss",
+ "/train_decoder.py:547-563": "DALLE2 Distributed Training Initialization",
+ "/train_decoder.py:564-581": "Distributed Training Utilities",
+ "/train_decoder.py:582-603": "Initialize Decoder Model and Tracker",
+ "/train_decoder.py:605-627": "CLIP Embedding Checker",
+ "/train_decoder.py:628-647": "Train Decoder: Click CLI with Config Options",
+ "/train_decoder.py:64-80": "Random Shard Dataset Splitter",
+ "/train_decoder.py:648-651": "Init and Call Main",
+ "/train_decoder.py:81-103": "Multi-Dataset Dataloaders",
+ "/train_diffusion_prior.py": "Train Diffusion Prior with PyTorch",
+ "/train_diffusion_prior.py:1-45": "Training Diffusion Prior Model with PyTorch",
+ "/train_diffusion_prior.py:124-160": "Tensor Data Processing and Model Logging",
+ "/train_diffusion_prior.py:162-201": "Model Training Tracker and Evaluation",
+ "/train_diffusion_prior.py:202-239": "Validation Loss Calculator",
+ "/train_diffusion_prior.py:241-275": "Cosine Similarity Measurement in Diffusion Model",
+ "/train_diffusion_prior.py:276-303": "Shuffled Embeddings for Diffusion Training",
+ "/train_diffusion_prior.py:304-334": "Text-Image Similarity with Diffusion Models",
+ "/train_diffusion_prior.py:335-360": "Embedding Similarity Tracker in Diffusion Model",
+ "/train_diffusion_prior.py:361-398": "Evaluate Diffusion Prior Model",
+ "/train_diffusion_prior.py:399-440": "Cosine Similarity Measurement and Training Script",
+ "/train_diffusion_prior.py:441-471": "Epoch-wise Dataset Reset and Tracking",
+ "/train_diffusion_prior.py:46-83": "Create Trainer and Tracker for Diffusion Prior",
+ "/train_diffusion_prior.py:473-504": "Backpropagation and EMA Updating",
+ "/train_diffusion_prior.py:505-536": "Evaluating Model on Validation Data",
+ "/train_diffusion_prior.py:537-566": "Best Model Validation and Time Logging",
+ "/train_diffusion_prior.py:567-598": "Reset Validation Timer and Handle Errors",
+ "/train_diffusion_prior.py:599-640": "Saving Best Model with Lower Test Loss",
+ "/train_diffusion_prior.py:641-681": "Trainer Initialization and Distribution",
+ "/train_diffusion_prior.py:682-711": "Epoch Tracker Code Snippet",
+ "/train_diffusion_prior.py:712-743": "Resume Training Data Loader",
+ "/train_diffusion_prior.py:744-770": "Initiating Heterogeneous Fusion Training",
+ "/train_diffusion_prior.py:84-122": "Process-Aware Aggregation Function"
+}
\ No newline at end of file
diff --git a/docs/doc/17c1111a-5b83-460f-9a69-fe3153f8f0a4.json b/docs/doc/17c1111a-5b83-460f-9a69-fe3153f8f0a4.json
new file mode 100644
index 00000000..2b2f7e37
--- /dev/null
+++ b/docs/doc/17c1111a-5b83-460f-9a69-fe3153f8f0a4.json
@@ -0,0 +1,10 @@
+{
+ "summary": "This code is specifying the recursive inclusion of all .txt files in the dalle2_pytorch directory for the MANIFEST.in file.",
+ "details": [
+ {
+ "comment": "This code is specifying the recursive inclusion of all .txt files in the dalle2_pytorch directory for the MANIFEST.in file.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/MANIFEST.in\":0-0",
+ "content": "recursive-include dalle2_pytorch *.txt"
+ }
+ ]
+}
\ No newline at end of file
diff --git a/docs/doc/1e021c3d-2b35-4049-a4dd-f79ad7b30708.json b/docs/doc/1e021c3d-2b35-4049-a4dd-f79ad7b30708.json
new file mode 100644
index 00000000..5a81179d
--- /dev/null
+++ b/docs/doc/1e021c3d-2b35-4049-a4dd-f79ad7b30708.json
@@ -0,0 +1,30 @@
+{
+ "summary": "The README shows how to create a dataloader for an image embedding dataset (generated with the img2dataset, clip-retrieval, and embedding-dataset-reordering tools) and how to set up training, evaluation, and testing splits for three ranks using the provided TRAIN_ARGS config.",
+ "details": [
+ {
+ "comment": "This code snippet describes the usage of general dataloaders for efficient data loading and training portions of the network, particularly focusing on the decoder. It supports two types of datasets: a webdataset containing .jpg and .npy files in .tar formats or an external source where .npy files correspond to .jpg filenames from the webdataset.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dataloaders/README.md\":0-4",
+ "content": "## Dataloaders\nIn order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.\n### Decoder: Image Embedding Dataset\nWhen training the decoder (and up samplers if training together) in isolation, you will need to load images and corresponding image embeddings. This dataset can read two similar types of datasets. First, it can read a [webdataset](https://github.com/webdataset/webdataset) that contains `.jpg` and `.npy` files in the `.tar`s that contain the images and associated image embeddings respectively. Alternatively, you can also specify a source for the embeddings outside of the webdataset. In this case, the path to the embeddings should contain `.npy` files with the same shard numbers as the webdataset and there should be a correspondence between the filename of the `.jpg` and the index of the embedding in the `.npy`. So, for example, `0001.tar` from the webdataset with image `00010509.jpg` (the first 4 digit"
+ },
+ {
+ "comment": "This code demonstrates how to create a dataloader for an image embedding dataset. It utilizes three separate tools: img2dataset, clip-retrieval, and embedding-dataset-reordering. The user must provide the appropriate URLs for the webdataset and embeddings folder in order to generate the dataloader. The code snippet also highlights the usage of create_image_embedding_dataloader function which takes in URL parameters and returns a dataloader object.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dataloaders/README.md\":4-18",
+ "content": "s are the shard number and the last 4 are the index) in it should be paralleled by a `img_emb_0001.npy` which contains a NumPy array with the embedding at index 509.\nGenerating a dataset of this type:\n1. Use [img2dataset](https://github.com/rom1504/img2dataset) to generate a webdataset.\n2. Use [clip-retrieval](https://github.com/rom1504/clip-retrieval) to convert the images to embeddings.\n3. Use [embedding-dataset-reordering](https://github.com/Veldrovive/embedding-dataset-reordering) to reorder the embeddings into the expected format.\nUsage:\n```python\nfrom dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embedding_dataloader\n# Create a dataloader directly.\ndataloader = create_image_embedding_dataloader(\n tar_url=\"/path/or/url/to/webdataset/{0000..9999}.tar\", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar\n embeddings_url=\"path/or/url/to/embeddings/folder\", # Included if .npy files are not in webdataset. Left out or set to None otherwise"
+ },
+ {
+ "comment": "This code initializes a dataloader with parameters such as number of workers, batch size, shard width, and shuffle settings. It loads images and their corresponding embeddings from webdataset files. The images' shapes are printed for a single epoch. An ImageEmbeddingDataset is also created without a loader for manual configuration.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dataloaders/README.md\":19-36",
+ "content": " num_workers=4,\n batch_size=32,\n shard_width=4, # If a file in the webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index\n shuffle_num=200, # Does a shuffle of the data with a buffer size of 200\n shuffle_shards=True, # Shuffle the order the shards are read in\n resample_shards=False, # Sample shards with replacement. If true, an epoch will be infinite unless stopped manually\n)\nfor img, emb in dataloader:\n print(img.shape) # torch.Size([32, 3, 256, 256])\n print(emb.shape) # torch.Size([32, 512])\n # Train decoder only as shown above\n# Or create a dataset without a loader so you can configure it manually\ndataset = ImageEmbeddingDataset(\n urls=\"/path/or/url/to/webdataset/{0000..9999}.tar\",\n embedding_folder_url=\"path/or/url/to/embeddings/folder\",\n shard_width=4,\n shuffle_shards=True,"
+ },
+ {
+ "comment": "The `resample=False` argument disables shard resampling (sampling shards with replacement) for the ImageEmbeddingDataset. The chunk then introduces the PriorEmbeddingDataset, which works with pre-computed embeddings so the same script can serve both embedding-only and text-conditioned prior training, using `get_reader()` to build EmbeddingReader objects and `make_splits()` to create DataLoaders, with `rank` and `world_size` arguments for distributed runs.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dataloaders/README.md\":37-52",
+ "content": " resample=False\n)\n```\n### Diffusion Prior: Prior Embedding Dataset\nWhen training the prior it is much more efficient to work with pre-computed embeddings. The `PriorEmbeddingDataset` class enables you to leverage the same script (with minimal modification) for both embedding-only and text-conditioned prior training. This saves you from having to worry about a lot of the boilerplate code.\nTo utilize the `PriorEmbeddingDataset`, all you need to do is make a single call to `get_reader()` which will create `EmbeddingReader` object(s) for you. Afterwards, you can utilize `make_splits()` to cleanly create DataLoader objects from for your training run.\nIf you are training in a distributed manner, `make_splits()` accepts `rank` and `world_size` arguments to properly distribute to each process. The defaults for these values are `rank=0` and `world_size=1`, so single-process training can safely ignore these parameters.\nUsage:\n```python\nfrom dalle2_pytorch.dataloaders import get_reader, make_splits\n# grab embeddings from some specified location"
+ },
+ {
+ "comment": "The code sets up training, evaluation, and testing splits for three different ranks (0, 1, 2) using the provided TRAIN_ARGS config. It uses the get_reader function to load image embeddings and metadata from the specified URLs, and the make_splits function to divide the data into train, eval, and test sets for distributed training.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dataloaders/README.md\":53-74",
+ "content": "IMG_URL = \"data/img_emb/\"\nMETA_URL = \"data/meta/\"\nreader = get_reader(text_conditioned=True, img_url=IMG_URL, meta_url=META_URL)\n# some config for training\nTRAIN_ARGS = {\n \"world_size\": 3,\n \"text_conditioned\": True,\n \"start\": 0,\n \"num_data_points\": 10000,\n \"batch_size\": 2,\n \"train_split\": 0.5,\n \"eval_split\": 0.25,\n \"image_reader\": reader,\n}\n# specifying a rank will handle allocation internally\nrank0_train, rank0_eval, rank0_test = make_splits(rank=0, **TRAIN_ARGS)\nrank1_train, rank1_eval, rank1_test = make_splits(rank=1, **TRAIN_ARGS)\nrank2_train, rank2_eval, rank2_test = make_splits(rank=2, **TRAIN_ARGS)\n```"
+ }
+ ]
+}
\ No newline at end of file
diff --git a/docs/doc/27ca1601-0cbd-4541-8b0a-809d085f31b2.json b/docs/doc/27ca1601-0cbd-4541-8b0a-809d085f31b2.json
new file mode 100644
index 00000000..955079d5
--- /dev/null
+++ b/docs/doc/27ca1601-0cbd-4541-8b0a-809d085f31b2.json
@@ -0,0 +1,135 @@
+{
+ "summary": "The code defines trackers, loggers, loaders, and savers: it logs data, images, and errors, saves configurations and metadata, saves model and trainer states, and loads checkpoints through each loader's recall() method.",
+ "details": [
+ {
+ "comment": "This code is from the \"trackers.py\" file in the DALLE2-pytorch library, containing a class for base logger objects that can log data with optional data storage path and verbosity control. The class initializes with specified parameters like data_path, resume, auto_resume, and verbose. It uses Pathlib for path manipulation and supports temporary data storage.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trackers.py\":0-34",
+ "content": "import urllib.request\nimport os\nimport json\nfrom pathlib import Path\nimport shutil\nfrom itertools import zip_longest\nfrom typing import Any, Optional, List, Union\nfrom pydantic import BaseModel\nimport torch\nfrom dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior\nfrom dalle2_pytorch.utils import import_or_print_error\nfrom dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer\nfrom dalle2_pytorch.version import __version__\nfrom packaging import version\n# constants\nDEFAULT_DATA_PATH = './.tracker-data'\n# helper functions\ndef exists(val):\n return val is not None\nclass BaseLogger:\n \"\"\"\n An abstract class representing an object that can log data.\n Parameters:\n data_path (str): A file path for storing temporary data.\n verbose (bool): Whether of not to always print logs to the console.\n \"\"\"\n def __init__(self, data_path: str, resume: bool = False, auto_resume: bool = False, verbose: bool = False, **kwargs):\n self.data_path = Path(data_path)\n self.resume = resume"
+ },
+ {
+ "comment": "The code defines a logger class with methods for logging different types of data, and an initialization method to set up the logger. The logger raises a NotImplementedError for each method, which means they need to be implemented in child classes. The get_resume_data method sets tracker attributes used to resume training if needed.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trackers.py\":35-61",
+ "content": " self.auto_resume = auto_resume\n self.verbose = verbose\n def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:\n \"\"\"\n Initializes the logger.\n Errors if the logger is invalid.\n full_config is the config file dict while extra_config is anything else from the script that is not defined the config file.\n \"\"\"\n raise NotImplementedError\n def log(self, log, **kwargs) -> None:\n raise NotImplementedError\n def log_images(self, images, captions=[], image_section=\"images\", **kwargs) -> None:\n raise NotImplementedError\n def log_file(self, file_path, **kwargs) -> None:\n raise NotImplementedError\n def log_error(self, error_string, **kwargs) -> None:\n raise NotImplementedError\n def get_resume_data(self, **kwargs) -> dict:\n \"\"\"\n Sets tracker attributes that along with { \"resume\": True } will be used to resume training.\n It is assumed that after init is called this data will be complete."
+ },
+ {
+ "comment": "This code defines two logger classes, ConsoleLogger and WandbLogger, which inherit from the BaseLogger class. The ConsoleLogger logs to the console while the WandbLogger logs data to a Weights & Biases (WandB) run. Both loggers have methods for logging different types of data such as logs, images, files, and errors. The ConsoleLogger returns an empty dictionary if resuming is not supported, whereas the WandbLogger requires additional parameters like wandb_entity, wandb_project, wandb_run_id, and wandb_run_name for proper functioning.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trackers.py\":62-93",
+ "content": " If the logger does not have any resume functionality, it should return an empty dict.\n \"\"\"\n raise NotImplementedError\nclass ConsoleLogger(BaseLogger):\n def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:\n print(\"Logging to console\")\n def log(self, log, **kwargs) -> None:\n print(log)\n def log_images(self, images, captions=[], image_section=\"images\", **kwargs) -> None:\n pass\n def log_file(self, file_path, **kwargs) -> None:\n pass\n def log_error(self, error_string, **kwargs) -> None:\n print(error_string)\n def get_resume_data(self, **kwargs) -> dict:\n return {}\nclass WandbLogger(BaseLogger):\n \"\"\"\n Logs to a wandb run.\n Parameters:\n data_path (str): A file path for storing temporary data.\n wandb_entity (str): The wandb entity to log to.\n wandb_project (str): The wandb project to log to.\n wandb_run_id (str): The wandb run id to resume.\n wandb_run_name (str): The wandb run name to use."
+ },
+ {
+ "comment": "This code is a Python class for creating and initializing a WandB logger. It requires a data path, WandB entity, and project parameters. The class also supports additional configuration options. If the WandB entity or project are not specified, an error will be raised.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trackers.py\":94-119",
+ "content": " \"\"\"\n def __init__(self,\n data_path: str,\n wandb_entity: str,\n wandb_project: str,\n wandb_run_id: Optional[str] = None,\n wandb_run_name: Optional[str] = None,\n **kwargs\n ):\n super().__init__(data_path, **kwargs)\n self.entity = wandb_entity\n self.project = wandb_project\n self.run_id = wandb_run_id\n self.run_name = wandb_run_name\n def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:\n assert self.entity is not None, \"wandb_entity must be specified for wandb logger\"\n assert self.project is not None, \"wandb_project must be specified for wandb logger\"\n self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb logger')\n os.environ[\"WANDB_SILENT\"] = \"true\"\n # Initializes the wandb run\n init_object = {\n \"entity\": self.entity,\n \"project\": self.project,\n \"config\": {**full_config.dict(), **extra_config}\n }"
+ },
+ {
+ "comment": "This code initializes the wandb run for logging. When resuming, a `wandb_run_id` must be provided, and a warning is printed if the run is also being renamed. The class then logs dictionaries and images with captions through the wandb API, optionally echoing logs to the console when verbose is enabled.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trackers.py\":120-142",
+ "content": " if self.run_name is not None:\n init_object['name'] = self.run_name\n if self.resume:\n assert self.run_id is not None, '`wandb_run_id` must be provided if `wandb_resume` is True'\n if self.run_name is not None:\n print(\"You are renaming a run. I hope that is what you intended.\")\n init_object['resume'] = 'must'\n init_object['id'] = self.run_id\n self.wandb.init(**init_object)\n print(f\"Logging to wandb run {self.wandb.run.path}-{self.wandb.run.name}\")\n def log(self, log, **kwargs) -> None:\n if self.verbose:\n print(log)\n self.wandb.log(log, **kwargs)\n def log_images(self, images, captions=[], image_section=\"images\", **kwargs) -> None:\n \"\"\"\n Takes a tensor of images and a list of captions and logs them to wandb.\n \"\"\"\n wandb_images = [self.wandb.Image(image, caption=caption) for image, caption in zip_longest(images, captions)]\n self.wandb.log({ image_section: wandb_images }, **kwargs)"
+ },
+ {
+ "comment": "The code defines a class with three methods: `log_file`, `log_error`, and `get_resume_data`. The `log_file` method logs a file path, `log_error` logs an error string, and `get_resume_data` returns a dictionary containing essential resume information. Additionally, there is a function `create_logger` which creates a logger of type 'console' or 'wandb'. For now, custom loggers are not supported.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trackers.py\":144-169",
+ "content": " def log_file(self, file_path, base_path: Optional[str] = None, **kwargs) -> None:\n if base_path is None:\n # Then we take the basepath as the parent of the file_path\n base_path = Path(file_path).parent\n self.wandb.save(str(file_path), base_path = str(base_path))\n def log_error(self, error_string, step=None, **kwargs) -> None:\n if self.verbose:\n print(error_string)\n self.wandb.log({\"error\": error_string, **kwargs}, step=step)\n def get_resume_data(self, **kwargs) -> dict:\n # In order to resume, we need wandb_entity, wandb_project, and wandb_run_id\n return {\n \"entity\": self.entity,\n \"project\": self.project,\n \"run_id\": self.wandb.run.id\n }\nlogger_type_map = {\n 'console': ConsoleLogger,\n 'wandb': WandbLogger,\n}\ndef create_logger(logger_type: str, data_path: str, **kwargs) -> BaseLogger:\n if logger_type == 'custom':\n raise NotImplementedError('Custom loggers are not supported yet. Please use a different logger type.')"
+ },
+ {
+ "comment": "The create_logger function looks up the logger class for the given type and raises a ValueError for unknown types. BaseLoader is an abstract class for loading model checkpoints, configured with a data_path and optional parameters, and UrlLoader extends it to load files from URLs instead of local file paths.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trackers.py\":170-199",
+ "content": " try:\n logger_class = logger_type_map[logger_type]\n except KeyError:\n raise ValueError(f'Unknown logger type: {logger_type}. Must be one of {list(logger_type_map.keys())}')\n return logger_class(data_path, **kwargs)\nclass BaseLoader:\n \"\"\"\n An abstract class representing an object that can load a model checkpoint.\n Parameters:\n data_path (str): A file path for storing temporary data.\n \"\"\"\n def __init__(self, data_path: str, only_auto_resume: bool = False, **kwargs):\n self.data_path = Path(data_path)\n self.only_auto_resume = only_auto_resume\n def init(self, logger: BaseLogger, **kwargs) -> None:\n raise NotImplementedError\n def recall() -> dict:\n raise NotImplementedError\nclass UrlLoader(BaseLoader):\n \"\"\"\n A loader that downloads the file from a url and loads it\n Parameters:\n data_path (str): A file path for storing temporary data.\n url (str): The url to download the file from.\n \"\"\"\n def __init__(self, data_path: str, url: str, **kwargs):"
+ },
+ {
+ "comment": "This chunk completes the UrlLoader: its constructor stores the URL, init is meant to verify that the file can be downloaded, and recall downloads the checkpoint and loads it with torch.load. It also defines LocalLoader, which loads a file from a local path and checks that the file exists before loading it.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trackers.py\":200-228",
+ "content": " super().__init__(data_path, **kwargs)\n self.url = url\n def init(self, logger: BaseLogger, **kwargs) -> None:\n # Makes sure the file exists to be downloaded\n pass # TODO: Actually implement that\n def recall(self) -> dict:\n # Download the file\n save_path = self.data_path / 'loaded_checkpoint.pth'\n urllib.request.urlretrieve(self.url, str(save_path))\n # Load the file\n return torch.load(str(save_path), map_location='cpu')\nclass LocalLoader(BaseLoader):\n \"\"\"\n A loader that loads a file from a local path\n Parameters:\n data_path (str): A file path for storing temporary data.\n file_path (str): The path to the file to load.\n \"\"\"\n def __init__(self, data_path: str, file_path: str, **kwargs):\n super().__init__(data_path, **kwargs)\n self.file_path = Path(file_path)\n def init(self, logger: BaseLogger, **kwargs) -> None:\n # Makes sure the file exists to be loaded\n if not self.file_path.exists() and not self.only_auto_resume:"
+ },
+ {
+ "comment": "This code defines a class `WandbLoader` that loads a model from an existing W&B (Weights & Biases) run. It requires a data path, a file path within the W&B run, and optionally a W&B run path. The `__init__` method initializes the object, the `init` method ensures the file can be downloaded, and the `recall` method loads the model using `torch.load`. If a W&B run is available but the run path is not specified, it sets the run path to the current run's path. The code also imports the 'wandb' library if it is missing.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trackers.py\":229-250",
+ "content": " raise FileNotFoundError(f'Model not found at {self.file_path}')\n def recall(self) -> dict:\n # Load the file\n return torch.load(str(self.file_path), map_location='cpu')\nclass WandbLoader(BaseLoader):\n \"\"\"\n A loader that loads a model from an existing wandb run\n \"\"\"\n def __init__(self, data_path: str, wandb_file_path: str, wandb_run_path: Optional[str] = None, **kwargs):\n super().__init__(data_path, **kwargs)\n self.run_path = wandb_run_path\n self.file_path = wandb_file_path\n def init(self, logger: BaseLogger, **kwargs) -> None:\n self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')\n # Make sure the file can be downloaded\n if self.wandb.run is not None and self.run_path is None:\n self.run_path = self.wandb.run.path\n assert self.run_path is not None, 'wandb run was not found to load from. If not using the wandb logger must specify the `wandb_run_path`.'\n assert self.run_path is not None, '`wandb_run_path` must be provided for the wandb loader'"
+ },
+ {
+ "comment": "This code defines a `BaseSaver` class with an optional parameter for saving the latest data to a specified location. It also includes a function `create_loader()` that creates different types of loaders (url, local, wandb) based on the provided loader type and data path. The WandbLoader is used to restore data from a specified file path using Weights & Biases environment.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trackers.py\":251-277",
+ "content": " assert self.file_path is not None, '`wandb_file_path` must be provided for the wandb loader'\n os.environ[\"WANDB_SILENT\"] = \"true\"\n pass # TODO: Actually implement that\n def recall(self) -> dict:\n file_reference = self.wandb.restore(self.file_path, run_path=self.run_path)\n return torch.load(file_reference.name, map_location='cpu')\nloader_type_map = {\n 'url': UrlLoader,\n 'local': LocalLoader,\n 'wandb': WandbLoader,\n}\ndef create_loader(loader_type: str, data_path: str, **kwargs) -> BaseLoader:\n if loader_type == 'custom':\n raise NotImplementedError('Custom loaders are not supported yet. Please use a different loader type.')\n try:\n loader_class = loader_type_map[loader_type]\n except KeyError:\n raise ValueError(f'Unknown loader type: {loader_type}. Must be one of {list(loader_type_map.keys())}')\n return loader_class(data_path, **kwargs)\nclass BaseSaver:\n def __init__(self,\n data_path: str,\n save_latest_to: Optional[Union[str, bool]] = None,"
+ },
+ {
+ "comment": "This code defines a tracker class that handles saving of data to specified locations. It allows saving the latest, best, and meta information, with options for file type and paths. The `save_file` method is used to save files with optional flags for best and latest status. An assertion ensures that the save type is either 'checkpoint' or 'model'. A final assertion requires at least one saving option to be specified.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trackers.py\":278-298",
+ "content": " save_best_to: Optional[Union[str, bool]] = None,\n save_meta_to: Optional[str] = None,\n save_type: str = 'checkpoint',\n **kwargs\n ):\n self.data_path = Path(data_path)\n self.save_latest_to = save_latest_to\n self.saving_latest = save_latest_to is not None and save_latest_to is not False\n self.save_best_to = save_best_to\n self.saving_best = save_best_to is not None and save_best_to is not False\n self.save_meta_to = save_meta_to\n self.saving_meta = save_meta_to is not None\n self.save_type = save_type\n assert save_type in ['checkpoint', 'model'], '`save_type` must be one of `checkpoint` or `model`'\n assert self.saving_latest or self.saving_best or self.saving_meta, 'At least one saving option must be specified'\n def init(self, logger: BaseLogger, **kwargs) -> None:\n raise NotImplementedError\n def save_file(self, local_path: Path, save_path: str, is_best=False, is_latest=False, **kwargs) -> None:\n \"\"\""
+ },
+ {
+ "comment": "This code defines two classes, LocalSaver and WandbSaver, which inherit from BaseSaver. Both classes are responsible for saving files in different locations. The LocalSaver saves files locally to a specified data_path, ensuring the directory exists beforehand. The WandbSaver is optional and requires a wandb_run_path parameter.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trackers.py\":299-327",
+ "content": " Save a general file under save_meta_to\n \"\"\"\n raise NotImplementedError\nclass LocalSaver(BaseSaver):\n def __init__(self,\n data_path: str,\n **kwargs\n ):\n super().__init__(data_path, **kwargs)\n def init(self, logger: BaseLogger, **kwargs) -> None:\n # Makes sure the directory exists to be saved to\n print(f\"Saving {self.save_type} locally\")\n if not self.data_path.exists():\n self.data_path.mkdir(parents=True)\n def save_file(self, local_path: str, save_path: str, **kwargs) -> None:\n # Copy the file to save_path\n save_path_file_name = Path(save_path).name\n # Make sure parent directory exists\n save_path_parent = Path(save_path).parent\n if not save_path_parent.exists():\n save_path_parent.mkdir(parents=True)\n print(f\"Saving {save_path_file_name} {self.save_type} to local path {save_path}\")\n shutil.copy(local_path, save_path)\nclass WandbSaver(BaseSaver):\n def __init__(self, data_path: str, wandb_run_path: Optional[str] = None, **kwargs):"
+ },
+ {
+ "comment": "This code initializes a W&B run based on the `wandb_run_path` provided. It imports the W&B library, sets up the environment for uploading to W&B runs, and checks if the user has access to save files in the specified W&B run path.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trackers.py\":328-345",
+ "content": " super().__init__(data_path, **kwargs)\n self.run_path = wandb_run_path\n def init(self, logger: BaseLogger, **kwargs) -> None:\n self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb logger')\n os.environ[\"WANDB_SILENT\"] = \"true\"\n # Makes sure that the user can upload tot his run\n if self.run_path is not None:\n entity, project, run_id = self.run_path.split(\"/\")\n self.run = self.wandb.init(entity=entity, project=project, id=run_id)\n else:\n assert self.wandb.run is not None, 'You must be using the wandb logger if you are saving to wandb and have not set `wandb_run_path`'\n self.run = self.wandb.run\n # TODO: Now actually check if upload is possible\n print(f\"Saving to wandb run {self.run.path}-{self.run.name}\")\n def save_file(self, local_path: Path, save_path: str, **kwargs) -> None:\n # In order to log something in the correct place in wandb, we need to have the same file structure here"
+ },
+ {
+ "comment": "This code defines a `HuggingfaceSaver` class that saves files to a Hugging Face repository. The instance is initialized with a data path, a Hugging Face repo, and an optional token path. The `init` method checks that the user is logged in to the Hugging Face Hub, and `save_file` uploads the file to the repo under `save_path`.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trackers.py\":346-364",
+ "content": " save_path_file_name = Path(save_path).name\n print(f\"Saving {save_path_file_name} {self.save_type} to wandb run {self.run.path}-{self.run.name}\")\n save_path = Path(self.data_path) / save_path\n save_path.parent.mkdir(parents=True, exist_ok=True)\n shutil.copy(local_path, save_path)\n self.run.save(str(save_path), base_path = str(self.data_path), policy='now')\nclass HuggingfaceSaver(BaseSaver):\n def __init__(self, data_path: str, huggingface_repo: str, token_path: Optional[str] = None, **kwargs):\n super().__init__(data_path, **kwargs)\n self.huggingface_repo = huggingface_repo\n self.token_path = token_path\n def init(self, logger: BaseLogger, **kwargs):\n # Makes sure this user can upload to the repo\n self.hub = import_or_print_error('huggingface_hub', '`pip install huggingface_hub` to use the huggingface saver')\n try:\n identity = self.hub.whoami() # Errors if not logged in\n # Then we are logged in"
+ },
+ {
+ "comment": "This code handles saving a file to the HuggingFace repo. If not logged in, it checks for a token path and uses it if available, or throws an exception. It then prints the saving path, logs in with the token (if provided), and finally uploads the file to the specified HuggingFace repo.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trackers.py\":365-381",
+ "content": " except:\n # We are not logged in. Use the token_path to set the token.\n if not os.path.exists(self.token_path):\n raise Exception(\"Not logged in to huggingface and no token_path specified. Please login with `huggingface-cli login` or if that does not work set the token_path.\")\n with open(self.token_path, \"r\") as f:\n token = f.read().strip()\n self.hub.HfApi.set_access_token(token)\n identity = self.hub.whoami()\n print(f\"Saving to huggingface repo {self.huggingface_repo}\")\n def save_file(self, local_path: Path, save_path: str, **kwargs) -> None:\n # Saving to huggingface is easy, we just need to upload the file with the correct name\n save_path_file_name = Path(save_path).name\n print(f\"Saving {save_path_file_name} {self.save_type} to huggingface repo {self.huggingface_repo}\")\n self.hub.upload_file(\n path_or_fileobj=str(local_path),\n path_in_repo=str(save_path),"
+ },
+ {
+ "comment": "Function create_saver takes a saver type and data path, returns a BaseSaver object. It supports 'local', 'wandb', and 'huggingface' saver types. If the saver type is 'custom', it raises an error since custom savers aren't supported yet. Tracker initializes with optional data_path, overwrite_data_path (to overwrite existing path), and dummy_mode (if running in simulation mode). If not in dummy mode, asserts that the data path doesn't exist unless overwrite_data_path is True.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trackers.py\":382-406",
+ "content": " repo_id=self.huggingface_repo\n )\nsaver_type_map = {\n 'local': LocalSaver,\n 'wandb': WandbSaver,\n 'huggingface': HuggingfaceSaver\n}\ndef create_saver(saver_type: str, data_path: str, **kwargs) -> BaseSaver:\n if saver_type == 'custom':\n raise NotImplementedError('Custom savers are not supported yet. Please use a different saver type.')\n try:\n saver_class = saver_type_map[saver_type]\n except KeyError:\n raise ValueError(f'Unknown saver type: {saver_type}. Must be one of {list(saver_type_map.keys())}')\n return saver_class(data_path, **kwargs)\nclass Tracker:\n def __init__(self, data_path: Optional[str] = DEFAULT_DATA_PATH, overwrite_data_path: bool = False, dummy_mode: bool = False):\n self.data_path = Path(data_path)\n if not dummy_mode:\n if not overwrite_data_path:\n assert not self.data_path.exists(), f'Data path {self.data_path} already exists. Set overwrite_data_path to True to overwrite.'\n if not self.data_path.exists():"
+ },
+ {
+ "comment": "This code initializes a tracker object, handling the data path creation, base logger and loader setup, saving list initialization, and dummy mode. It also includes a method to load auto-resume configuration if it exists, printing warnings for first run or removing the file if auto-resume is not enabled.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trackers.py\":407-426",
+ "content": " self.data_path.mkdir(parents=True)\n self.logger: BaseLogger = None\n self.loader: Optional[BaseLoader] = None\n self.savers: List[BaseSaver]= []\n self.dummy_mode = dummy_mode\n def _load_auto_resume(self) -> bool:\n # If the file does not exist, we return False. If autoresume is enabled we print a warning so that the user can know that this is the first run.\n if not self.auto_resume_path.exists():\n if self.logger.auto_resume:\n print(\"Auto_resume is enabled but no auto_resume.json file exists. Assuming this is the first run.\")\n return False\n # Now we know that the autoresume file exists, but if we are not auto resuming we should remove it so that we don't accidentally load it next time\n if not self.logger.auto_resume:\n print(f'Removing auto_resume.json because auto_resume is not enabled in the config')\n self.auto_resume_path.unlink()\n return False\n # Otherwise we read the json into a dictionary will will override parts of logger.__dict__"
+ },
+ {
+ "comment": "This code reads a previously saved state from the \"auto_resume_path\" and checks if the logger type matches the current logger. If they don't match, it raises an exception with instructions on how to proceed. Otherwise, it updates the logger with the auto-resume data and returns True.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trackers.py\":427-441",
+ "content": " with open(self.auto_resume_path, 'r') as f:\n auto_resume_dict = json.load(f)\n # Check if the logger is of the same type as the autoresume save\n if auto_resume_dict[\"logger_type\"] != self.logger.__class__.__name__:\n raise Exception(f'The logger type in the auto_resume file is {auto_resume_dict[\"logger_type\"]} but the current logger is {self.logger.__class__.__name__}. Either use the original logger type, set `auto_resume` to `False`, or delete your existing tracker-data folder.')\n # Then we are ready to override the logger with the autoresume save\n self.logger.__dict__[\"resume\"] = True\n print(f\"Updating {self.logger.__dict__} with {auto_resume_dict}\")\n self.logger.__dict__.update(auto_resume_dict)\n return True\n def _save_auto_resume(self):\n # Gets the autoresume dict from the logger and adds \"logger_type\" to it then saves it to the auto_resume file\n auto_resume_dict = self.logger.get_resume_data()\n auto_resume_dict['logger_type'] = self.logger.__class__.__name__"
+ },
+ {
+ "comment": "This code is initializing a tracker object. It sets the auto_resume path, checks for resuming the run and prints a warning if it was automatically resumed. The save_metadata dictionary is created with version information and some keys are blacklisted from being saved as metadata to avoid errors during saving. The logger must be set before calling init method.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trackers.py\":442-458",
+ "content": " with open(self.auto_resume_path, 'w') as f:\n json.dump(auto_resume_dict, f)\n def init(self, full_config: BaseModel, extra_config: dict):\n self.auto_resume_path = self.data_path / 'auto_resume.json'\n # Check for resuming the run\n self.did_auto_resume = self._load_auto_resume()\n if self.did_auto_resume:\n print(f'\\n\\nWARNING: RUN HAS BEEN AUTO-RESUMED WITH THE LOGGER TYPE {self.logger.__class__.__name__}.\\nIf this was not your intention, stop this run and set `auto_resume` to `False` in the config.\\n\\n')\n print(f\"New logger config: {self.logger.__dict__}\")\n self.save_metadata = dict(\n version = version.parse(__version__)\n ) # Data that will be saved alongside the checkpoint or model\n self.blacklisted_checkpoint_metadata_keys = ['scaler', 'optimizer', 'model', 'version', 'step', 'steps'] # These keys would cause us to error if we try to save them as metadata\n assert self.logger is not None, '`logger` must be set before `init` is called'"
+ },
+ {
+ "comment": "This code initializes trackers by first checking if in dummy mode, then initializing loaders and savers. The logger is initialized only if the `savers` list has items, and if `auto_resume` is enabled, it saves an autoresume file. The `add_logger`, `add_loader`, `add_saver`, and `log` methods are provided to interact with trackers' components.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trackers.py\":459-488",
+ "content": " if self.dummy_mode:\n # The only thing we need is a loader\n if self.loader is not None:\n self.loader.init(self.logger)\n return\n assert len(self.savers) > 0, '`savers` must be set before `init` is called'\n self.logger.init(full_config, extra_config)\n if self.loader is not None:\n self.loader.init(self.logger)\n for saver in self.savers:\n saver.init(self.logger)\n if self.logger.auto_resume:\n # Then we need to save the autoresume file. It is assumed after logger.init is called that the logger is ready to be saved.\n self._save_auto_resume()\n def add_logger(self, logger: BaseLogger):\n self.logger = logger\n def add_loader(self, loader: BaseLoader):\n self.loader = loader\n def add_saver(self, saver: BaseSaver):\n self.savers.append(saver)\n def log(self, *args, **kwargs):\n if self.dummy_mode:\n return\n self.logger.log(*args, **kwargs)"
+ },
+ {
+ "comment": "This code is from the DALLE2-pytorch library and it contains several methods for logging images, files, saving configurations, and adding save metadata. The dummy_mode check prevents unnecessary actions when in a test mode. The save_config method copies the current config file to the root folder of the data_path and saves it remotely if specified by the saver. The add_save_metadata method adds new metadata that will be saved along with the model or decoder.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trackers.py\":490-516",
+ "content": " def log_images(self, *args, **kwargs):\n if self.dummy_mode:\n return\n self.logger.log_images(*args, **kwargs)\n def log_file(self, *args, **kwargs):\n if self.dummy_mode:\n return\n self.logger.log_file(*args, **kwargs)\n def save_config(self, current_config_path: str, config_name = 'config.json'):\n if self.dummy_mode:\n return\n # Save the config under config_name in the root folder of data_path\n shutil.copy(current_config_path, self.data_path / config_name)\n for saver in self.savers:\n if saver.saving_meta:\n remote_path = Path(saver.save_meta_to) / config_name\n saver.save_file(current_config_path, str(remote_path))\n def add_save_metadata(self, state_dict_key: str, metadata: Any):\n \"\"\"\n Adds a new piece of metadata that will be saved along with the model or decoder.\n \"\"\"\n self.save_metadata[state_dict_key] = metadata\n def _save_state_dict(self,"
+ },
+ {
+ "comment": "This function saves the trainer's state dict, depending on the 'save_type' parameter. If 'checkpoint', it saves the entire trainer state without blacklisted metadata keys. If 'model', it saves only the model state if the trainer is a DiffusionPriorTrainer.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trackers.py\":516-530",
+ "content": " trainer: Union[DiffusionPriorTrainer, DecoderTrainer], save_type: str, file_path: str, **kwargs) -> Path:\n \"\"\"\n Gets the state dict to be saved and writes it to file_path.\n If save_type is 'checkpoint', we save the entire trainer state dict.\n If save_type is 'model', we save only the model state dict.\n \"\"\"\n assert save_type in ['checkpoint', 'model']\n if save_type == 'checkpoint':\n # Create a metadata dict without the blacklisted keys so we do not error when we create the state dict\n metadata = {k: v for k, v in self.save_metadata.items() if k not in self.blacklisted_checkpoint_metadata_keys}\n trainer.save(file_path, overwrite=True, **kwargs, **metadata)\n elif save_type == 'model':\n if isinstance(trainer, DiffusionPriorTrainer):\n prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior\n prior: DiffusionPrior = trainer.accelerator.unwrap_model(prior)"
+ },
+ {
+ "comment": "This code checks the type of trainer and removes CLIP from the model if it is part of it. It then saves the state dictionary for the model, and optionally swaps EMA unets in or out depending on the use_ema flag. Finally, it restores the original CLIP state.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trackers.py\":531-550",
+ "content": " # Remove CLIP if it is part of the model\n original_clip = prior.clip\n prior.clip = None\n model_state_dict = prior.state_dict()\n prior.clip = original_clip\n elif isinstance(trainer, DecoderTrainer):\n decoder: Decoder = trainer.accelerator.unwrap_model(trainer.decoder)\n # Remove CLIP if it is part of the model\n original_clip = decoder.clip\n decoder.clip = None\n if trainer.use_ema:\n trainable_unets = decoder.unets\n decoder.unets = trainer.unets # Swap EMA unets in\n model_state_dict = decoder.state_dict()\n decoder.unets = trainable_unets # Swap back\n else:\n model_state_dict = decoder.state_dict()\n decoder.clip = original_clip\n else:\n raise NotImplementedError('Saving this type of model with EMA mode enabled is not yet implemented. Actually, how did you get here?')"
+ },
+ {
+ "comment": "This code saves the model and checkpoint to specified file paths. If not in dummy mode, it checks if the 'is_best' or 'is_latest' flag is set before proceeding with saving the state dictionary for 'checkpoint' and 'model'. It then prints a message confirming the saved cached models. Lastly, it calls save methods on savers, considering the 'saving_latest' flag and appropriate file paths.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trackers.py\":551-574",
+ "content": " state_dict = {\n **self.save_metadata,\n 'model': model_state_dict\n }\n torch.save(state_dict, file_path)\n return Path(file_path)\n def save(self, trainer, is_best: bool, is_latest: bool, **kwargs):\n if self.dummy_mode:\n return\n if not is_best and not is_latest:\n # Nothing to do\n return\n # Save the checkpoint and model to data_path\n checkpoint_path = self.data_path / 'checkpoint.pth'\n self._save_state_dict(trainer, 'checkpoint', checkpoint_path, **kwargs)\n model_path = self.data_path / 'model.pth'\n self._save_state_dict(trainer, 'model', model_path, **kwargs)\n print(\"Saved cached models\")\n # Call the save methods on the savers\n for saver in self.savers:\n local_path = checkpoint_path if saver.save_type == 'checkpoint' else model_path\n if saver.saving_latest and is_latest:\n latest_checkpoint_path = saver.save_latest_to.format(**kwargs)"
+ },
+ {
+ "comment": "This code appears to be part of a class that manages loading and saving checkpoints for a model. It has a property called \"can_recall\" which determines if a recall (loading a previously saved checkpoint) can be performed based on whether the loader is not None and certain conditions about the loader's properties. If a recall is possible, the \"recall()\" function is called to perform the actual recall. Any errors that occur during saving are logged and printed.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trackers.py\":575-597",
+ "content": " try:\n saver.save_file(local_path, latest_checkpoint_path, is_latest=True, **kwargs)\n except Exception as e:\n self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)\n print(f'Error saving checkpoint: {e}')\n if saver.saving_best and is_best:\n best_checkpoint_path = saver.save_best_to.format(**kwargs)\n try:\n saver.save_file(local_path, best_checkpoint_path, is_best=True, **kwargs)\n except Exception as e:\n self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)\n print(f'Error saving checkpoint: {e}')\n @property\n def can_recall(self):\n # Defines whether a recall can be performed.\n return self.loader is not None and (not self.loader.only_auto_resume or self.did_auto_resume)\n def recall(self):\n if self.can_recall:\n return self.loader.recall()\n else:\n"
+ },
+ {
+ "comment": "Raises an error when no loader is set and auto-resume was not performed.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trackers.py\":597-597",
+ "content": " raise ValueError('Tried to recall, but no loader was set or auto-resume was not performed.')"
+ }
+ ]
+}
\ No newline at end of file
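The trackers.py entries above describe how savers are resolved by type and how a Tracker is assembled from a logger, an optional loader, and one or more savers before init is called. The sketch below is not taken from the repository: the keyword arguments passed to create_logger, create_loader, and create_saver, the metric dict passed to log, and the full_config and trainer placeholders are assumptions based on the methods and config options documented in this diff.

```python
# Hypothetical wiring of a Tracker, mirroring the flow described above.
# `full_config` and `trainer` stand in for the real pydantic training config
# and a DecoderTrainer; the keyword arguments mirror the documented saver/loader options.
from dalle2_pytorch.trackers import Tracker, create_logger, create_loader, create_saver

tracker = Tracker('.tracker-data', overwrite_data_path=True)

# Console logging needs no extra configuration.
tracker.add_logger(create_logger('console', tracker.data_path))

# Optionally recall an existing local checkpoint (hypothetical path).
tracker.add_loader(create_loader('local', tracker.data_path, file_path='checkpoint.pth'))

# At least one saver must be registered before init() is called.
tracker.add_saver(create_saver('local', tracker.data_path,
                               save_latest_to='latest.pth', save_type='checkpoint'))

# init() wires the logger into the loader and savers and, if the logger has
# auto_resume enabled, writes the auto_resume.json file.
tracker.init(full_config, extra_config={})

# During training: forward metrics to the logger (arguments are logger-specific),
# save the latest checkpoint, and recall a previous state when possible.
tracker.log({'loss': 0.123}, step=10)
tracker.save(trainer, is_best=False, is_latest=True)
if tracker.can_recall:
    state_dict = tracker.recall()
```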
diff --git a/docs/doc/3c38450c-ae87-469a-b2ab-13ef3445459b.json b/docs/doc/3c38450c-ae87-469a-b2ab-13ef3445459b.json
new file mode 100644
index 00000000..01aec2d8
--- /dev/null
+++ b/docs/doc/3c38450c-ae87-469a-b2ab-13ef3445459b.json
@@ -0,0 +1,65 @@
+{
+ "summary": "This code configures DALLE2 model training and logging options in PyTorch, with customizable settings for Unet and Decoder, dataloader, preprocessing, hyperparameters, image metrics, and experiment tracking. It supports various configurations based on selected logger and storage types.",
+ "details": [
+ {
+ "comment": "This code provides details on configuring training for DALLE2, a complex model that requires various settings. It includes sections for Unet and Decoder configurations with optional parameters. An example configuration file is also mentioned for easier understanding.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/configs/README.md\":0-23",
+ "content": "## DALLE2 Training Configurations\nFor more complex configuration, we provide the option of using a configuration file instead of command line arguments.\n### Decoder Trainer\nThe decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.example.json).\n**Unet:**\nThis is a single unet config, which belongs as an array nested under the decoder config as a list of `unets`\n| Option | Required | Default | Description |\n| ------ | -------- | ------- | ----------- |\n| `dim` | Yes | N/A | The starting channels of the unet. |\n| `image_embed_dim` | Yes | N/A | The dimension of the image embeddings. |\n| `dim_mults` | No | `(1, 2, 4, 8)` | The growth factors of the channels. |\nAny parameter from the `Unet` constructor can also be given here.\n**Decoder:**\nDefines the configuration options for the decoder model. The unets defined above will automatically be inserted.\n| Option | Required | Default | Description |"
+ },
+ {
+ "comment": "This code appears to be defining the configuration for a machine learning model, specifically one using U-Nets. The configuration includes options for the number of unets, image resolution, timesteps, loss function type, noise schedule, and learned variance. Additionally, there are settings for creating dataloaders for the model's data. The code also notes that any parameter from the `Decoder` constructor can be included in this configuration.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/configs/README.md\":24-39",
+ "content": "| ------ | -------- | ------- | ----------- |\n| `unets` | Yes | N/A | A list of unets, using the configuration above |\n| `image_sizes` | Yes | N/A | The resolution of the image after each upsampling step. The length of this array should be the number of unets defined. |\n| `image_size` | Yes | N/A | Not used. Can be any number. |\n| `timesteps` | No | `1000` | The number of diffusion timesteps used for generation. |\n| `loss_type` | No | `l2` | The loss function. Options are `l1`, `huber`, or `l2`. |\n| `beta_schedule` | No | `cosine` | The noising schedule. Options are `cosine`, `linear`, `quadratic`, `jsd`, or `sigmoid`. |\n| `learned_variance` | No | `True` | Whether to learn the variance. |\n| `clip` | No | `None` | The clip model to use if embeddings are being generated on the fly. Takes keys `make` and `model` with defaults `openai` and `ViT-L/14`. |\nAny parameter from the `Decoder` constructor can also be given here.\n**Data:**\nSettings for creation of the dataloaders.\n| Option | Required | Default | Description |"
+ },
+ {
+ "comment": "This code defines various configuration options for a dataloader, including webdataset and embeddings urls, worker numbers, batch size, shard range, and file indexing. The config allows flexibility in handling different types of datasets, with optional embeddings or use of the webdataset library.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/configs/README.md\":40-50",
+ "content": "| ------ | -------- | ------- | ----------- |\n| `webdataset_base_url` | Yes | N/A | The url of a shard in the webdataset with the shard replaced with `{}`[^1]. |\n| `img_embeddings_url` | No | `None` | The url of the folder containing image embeddings shards. Not required if embeddings are in webdataset or clip is being used. |\n| `text_embeddings_url` | No | `None` | The url of the folder containing text embeddings shards. Not required if embeddings are in webdataset or clip is being used. |\n| `num_workers` | No | `4` | The number of workers used in the dataloader. |\n| `batch_size` | No | `64` | The batch size. |\n| `start_shard` | No | `0` | Defines the start of the shard range the dataset will recall. |\n| `end_shard` | No | `9999999` | Defines the end of the shard range the dataset will recall. |\n| `shard_width` | No | `6` | Defines the width of one webdataset shard number[^2]. |\n| `index_width` | No | `4` | Defines the width of the index of a file inside a shard[^3]. |\n| `splits` | No | `{ \"tra"
+ },
+ {
+ "comment": "This code defines the proportion of shards allocated to training, validation, and testing datasets as well as whether to shuffle training dataset, preprocessing applied to images from datasets, and details for downloading shard files. It also provides information on how to use protocols like `s3` and calculating the shard length based on filename.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/configs/README.md\":50-57",
+ "content": "in\": 0.75, \"val\": 0.15, \"test\": 0.1 }` | Defines the proportion of shards that will be allocated to the training, validation, and testing datasets. |\n| `shuffle_train` | No | `True` | Whether to shuffle the shards of the training dataset. |\n| `resample_train` | No | `False` | If true, shards will be randomly sampled with replacement from the datasets making the epoch length infinite if a limit is not set. Cannot be enabled if `shuffle_train` is enabled. |\n| `preprocessing` | No | `{ \"ToTensor\": True }` | Defines preprocessing applied to images from the datasets. |\n[^1]: If your shard files have the paths `protocol://path/to/shard/00104.tar`, then the base url would be `protocol://path/to/shard/{}.tar`. If you are using a protocol like `s3`, you need to pipe the tars. For example `pipe:s3cmd get s3://bucket/path/{}.tar -`.\n[^2]: This refers to the string length of the shard number for your webdataset shards. For instance, if your webdataset shard has the filename `00104.tar`, your shard length is 5."
+ },
+ {
+ "comment": "The code provides settings for controlling training hyperparameters, such as the number of epochs, learning rate, weight decay, and grad norm clipping. It also allows saving checkpoints at specific intervals and specifying the device to train on. The conditioning scale can be customized for each unet if desired.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/configs/README.md\":59-73",
+ "content": "[^3]: Inside the webdataset `tar`, you have files named something like `001045945.jpg`. 5 of these characters refer to the shard, and 4 refer to the index of the file in the webdataset (shard is `001041` and index is `5945`). The `index_width` in this case is 4.\n**Train:**\nSettings for controlling the training hyperparameters.\n| Option | Required | Default | Description |\n| ------ | -------- | ------- | ----------- |\n| `epochs` | No | `20` | The number of epochs in the training run. |\n| `lr` | No | `1e-4` | The learning rate. |\n| `wd` | No | `0.01` | The weight decay. |\n| `max_grad_norm`| No | `0.5` | The grad norm clipping. |\n| `save_every_n_samples` | No | `100000` | Samples will be generated and a checkpoint will be saved every `save_every_n_samples` samples. |\n| `cond_scale` | No | `1.0` | Conditioning scale to use for sampling. Can also be an array of values, one for each unet. |\n| `device` | No | `cuda:0` | The device to train on. |\n| `epoch_samples` | No | `None` | Limits the num"
+ },
+ {
+ "comment": "The code snippet defines configurations for training a DALLE2 model in PyTorch. It includes settings such as the number of samples iterated through in each epoch, number of validation samples, whether to use exponential moving average models for sampling, and the ema coefficient. Additionally, it allows defining which evaluation metrics will be used to test the model by setting their configurations using torchmetrics constructors. The number of samples generated to test the model is also specified.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/configs/README.md\":73-86",
+ "content": "ber of samples iterated through in each epoch. This must be set if resampling. None means no limit. |\n| `validation_samples` | No | `None` | The number of samples to use for validation. None mean the entire validation set. |\n| `use_ema` | No | `True` | Whether to use exponential moving average models for sampling. |\n| `ema_beta` | No | `0.99` | The ema coefficient. |\n| `unet_training_mask` | No | `None` | A boolean array of the same length as the number of unets. If false, the unet is frozen. A value of `None` trains all unets. |\n**Evaluate:**\nDefines which evaluation metrics will be used to test the model.\nEach metric can be enabled by setting its configuration. The configuration keys for each metric are defined by the torchmetrics constructors which will be linked.\n| Option | Required | Default | Description |\n| ------ | -------- | ------- | ----------- |\n| `n_evaluation_samples` | No | `1000` | The number of samples to generate to test the model. |\n| `FID` | No | `None` | Setting to"
+ },
+ {
+ "comment": "This code snippet is from the configs/README.md file of the DALLE2-pytorch project. It describes how to enable different image metrics and set up experiment tracking. The available metrics are Frechet Inception Distance, Inception Score, Kernel Inception Distance, and Learned Perceptual Image Patch Similarity. The tracker can be configured with data_path and overwrite_data_path options for storing temporary tracking data.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/configs/README.md\":86-97",
+ "content": " an object enables the [Frechet Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html) metric. \n| `IS` | No | `None` | Setting to an object enables the [Inception Score](https://torchmetrics.readthedocs.io/en/stable/image/inception_score.html) metric.\n| `KID` | No | `None` | Setting to an object enables the [Kernel Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/kernel_inception_distance.html) metric. |\n| `LPIPS` | No | `None` | Setting to an object enables the [Learned Perceptual Image Patch Similarity](https://torchmetrics.readthedocs.io/en/stable/image/learned_perceptual_image_patch_similarity.html) metric. |\n**Tracker:**\nSelects how the experiment will be tracked.\n| Option | Required | Default | Description |\n| ------ | -------- | ------- | ----------- |\n| `data_path` | No | `./.tracker-data` | The path to the folder where temporary tracker data will be saved. |\n| `overwrite_data_path` | No | `False` | If true, the data path will be overwritten. Otherwise, you need to delete it yourself. |"
+ },
+ {
+ "comment": "The code defines configuration settings for logging, loading checkpoints, and saving checkpoints in a DALLE2-pytorch application. The logging section allows specifying where to save run metadata and image output (options: console or wandb). Loading can be from local, URL, or Wandb sources. Saving can be done locally, on HuggingFace, or via Wandb. Loggers have options for resume and auto-resume functions. If using console logging, only the log_type needs to be set as console.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/configs/README.md\":98-115",
+ "content": "| `log` | Yes | N/A | Logging configuration. |\n| `load` | No | `None` | Checkpoint loading configuration. |\n| `save` | Yes | N/A | Checkpoint/Model saving configuration. |\nTracking is split up into three sections:\n* Log: Where to save run metadata and image output. Options are `console` or `wandb`.\n* Load: Where to load a checkpoint from. Options are `local`, `url`, or `wandb`.\n* Save: Where to save a checkpoint to. Options are `local`, `huggingface`, or `wandb`.\n**Logging:**\nAll loggers have the following keys:\n| Option | Required | Default | Description |\n| ------ | -------- | ------- | ----------- |\n| `log_type` | Yes | N/A | The type of logger class to use. |\n| `resume` | No | `False` | For loggers that have the option to resume an old run, resume it using maually input parameters. |\n| `auto_resume` | No | `False` | If true, the logger will attempt to resume an old run using parameters from that previous run. |\nIf using `console` there is no further configuration than setting `log_type` to `console`."
+ },
+ {
+ "comment": "This code is defining the configuration options for logging and loading in a DALLE2-pytorch application. The user has to specify the log type (console or wandb) along with other required and optional parameters depending on the selected logger. The loaders have options to specify the loader class type (e.g., local) and whether to only auto resume if the run is being resumed.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/configs/README.md\":116-140",
+ "content": "| Option | Required | Default | Description |\n| ------ | -------- | ------- | ----------- |\n| `log_type` | Yes | N/A | Must be `console`. |\nIf using `wandb`\n| Option | Required | Default | Description |\n| ------ | -------- | ------- | ----------- |\n| `log_type` | Yes | N/A | Must be `wandb`. |\n| `wandb_entity` | Yes | N/A | The wandb entity to log to. |\n| `wandb_project` | Yes | N/A | The wandb project save the run to. |\n| `wandb_run_name` | No | `None` | The wandb run name. |\n| `wandb_run_id` | No | `None` | The wandb run id. Used if resuming an old run. |\n**Loading:**\nAll loaders have the following keys:\n| Option | Required | Default | Description |\n| ------ | -------- | ------- | ----------- |\n| `load_from` | Yes | N/A | The type of loader class to use. |\n| `only_auto_resume` | No | `False` | If true, the loader will only load the model if the run is being auto resumed. |\nIf using `local`\n| Option | Required | Default | Description |\n| ------ | -------- | ------- | ----------- |\n| `load_from` | Yes | N/A | Must be `local`. |"
+ },
+ {
+ "comment": "The code defines the options for loading and saving checkpoint files. It supports loading from a file path, URL or WandB run, with each option having specific required configurations. Saving to different locations is also supported through options like local, huggingface, or wandb, with additional configuration possibilities.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/configs/README.md\":141-163",
+ "content": "| `file_path` | Yes | N/A | The path to the checkpoint file. |\nIf using `url`\n| Option | Required | Default | Description |\n| ------ | -------- | ------- | ----------- |\n| `load_from` | Yes | N/A | Must be `url`. |\n| `url` | Yes | N/A | The url of the checkpoint file. |\nIf using `wandb`\n| Option | Required | Default | Description |\n| ------ | -------- | ------- | ----------- |\n| `load_from` | Yes | N/A | Must be `wandb`. |\n| `wandb_run_path` | No | `None` | The wandb run path. If `None`, uses the run that is being resumed. |\n| `wandb_file_path` | Yes | N/A | The path to the checkpoint file in the W&B file system. |\n**Saving:**\nUnlike `log` and `load`, `save` may be an array of options so that you can save to different locations in a run.\nAll save locations have these configuration options\n| Option | Required | Default | Description |\n| ------ | -------- | ------- | ----------- |\n| `save_to` | Yes | N/A | Must be `local`, `huggingface`, or `wandb`. |\n| `save_latest_to` | No | `None` | Sets the relative path to save the latest model to. |"
+ },
+ {
+ "comment": "This code sets options for saving models and metadata during training. It allows saving to local, huggingface or wandb storage with specific requirements for each option. The save type can be checkpoint or model, and there are additional options like saving best models, token file path, and repository paths.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/configs/README.md\":164-181",
+ "content": "| `save_best_to` | No | `None` | Sets the relative path to save the best model to every time the model has a lower validation loss than all previous models. |\n| `save_meta_to` | No | `None` | The path to save metadata files in. This includes the config files used to start the training. |\n| `save_type` | No | `checkpoint` | The type of save. `checkpoint` saves a checkpoint, `model` saves a model without any fluff (Saves with ema if ema is enabled). |\nIf using `local`\n| Option | Required | Default | Description |\n| ------ | -------- | ------- | ----------- |\n| `save_to` | Yes | N/A | Must be `local`. |\nIf using `huggingface`\n| Option | Required | Default | Description |\n| ------ | -------- | ------- | ----------- |\n| `save_to` | Yes | N/A | Must be `huggingface`. |\n| `huggingface_repo` | Yes | N/A | The huggingface repository to save to. |\n| `token_path` | No | `None` | If logging in with the huggingface cli is not possible, point to a token file instead. |\nIf using `wandb`\n| Option | Required | Default | Description |"
+ },
+ {
+ "comment": "The code defines configuration options for saving and interacting with the Weights & Biases (Wandb) run path. If `save_to` is set to `wandb`, the `wandb_run_path` should be `None`. Otherwise, it defaults to the current run if `wandb_run_path` is set to `None`.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/configs/README.md\":182-184",
+ "content": "| ------ | -------- | ------- | ----------- |\n| `save_to` | Yes | N/A | Must be `wandb`. |\n| `wandb_run_path` | No | `None` | The wandb run path. If `None`, uses the current run. You will almost always want this to be `None`. |"
+ }
+ ]
+}
\ No newline at end of file
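The README entries above enumerate the decoder training options section by section. To make the overall shape concrete, here is a minimal, partial config skeleton assembled in Python from the documented option names; the values, the top-level key casing, and the output filename are illustrative assumptions, and the repository's train_decoder_config.example.json remains the authoritative reference.

```python
# A partial, illustrative decoder-training config built from the options
# documented above. Values are placeholders, not recommendations.
import json

config = {
    "decoder": {
        "unets": [
            {"dim": 128, "image_embed_dim": 768, "dim_mults": [1, 2, 4, 8]},
        ],
        "image_sizes": [64],   # one resolution per unet
        "image_size": 64,      # documented as unused but required
        "timesteps": 1000,
        "loss_type": "l2",
        "beta_schedule": "cosine",
    },
    "data": {
        # Per footnote [^1]: shard number replaced with {}; pipe tars when using s3.
        "webdataset_base_url": "pipe:s3cmd get s3://bucket/path/{}.tar -",
        "img_embeddings_url": "s3://bucket/img-embeddings/",
        "num_workers": 4,
        "batch_size": 64,
        "shard_width": 6,
        "index_width": 4,
    },
    "train": {"epochs": 20, "lr": 1e-4, "wd": 0.01, "use_ema": True},
    "evaluate": {"n_evaluation_samples": 1000, "FID": {}},
    "tracker": {
        "data_path": "./.tracker-data",
        "log": {"log_type": "console"},
        "save": [{"save_to": "local", "save_latest_to": "latest.pth", "save_type": "checkpoint"}],
    },
}

with open("my_train_decoder_config.json", "w") as f:  # hypothetical filename
    json.dump(config, f, indent=2)
```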
diff --git a/docs/doc/3d87071b-e428-45c7-a06d-930a410fe347.json b/docs/doc/3d87071b-e428-45c7-a06d-930a410fe347.json
new file mode 100644
index 00000000..d58ea865
--- /dev/null
+++ b/docs/doc/3d87071b-e428-45c7-a06d-930a410fe347.json
@@ -0,0 +1,10 @@
+{
+ "summary": "This code defines the version number of the DALLE2-pytorch library, currently set as '1.15.6'.",
+ "details": [
+ {
+ "comment": "This code defines the version number of the DALLE2-pytorch library, currently set as '1.15.6'.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/version.py\":0-0",
+ "content": "__version__ = '1.15.6'"
+ }
+ ]
+}
\ No newline at end of file
diff --git a/docs/doc/47dfeaa2-468a-4d0f-af79-99dac1cfa7ab.json b/docs/doc/47dfeaa2-468a-4d0f-af79-99dac1cfa7ab.json
new file mode 100644
index 00000000..5d086b22
--- /dev/null
+++ b/docs/doc/47dfeaa2-468a-4d0f-af79-99dac1cfa7ab.json
@@ -0,0 +1,75 @@
+{
+ "summary": "The code defines functions for retrieving embeddings, combining image and text embeddings, creating image embedding datasets, and handling exceptions in webdataset tar files. It also includes support for preprocessing, resampling, shuffling, package checks, and dataloaders.",
+ "details": [
+ {
+ "comment": "This code defines three functions: `get_shard`, `get_example_file`, and `embedding_inserter`. The first function extracts the shard number from a filename. The second function returns an example file given a file system and a file format. Lastly, the third function inserts embeddings into a dataset, given samples, embedding URL, index width, sample key, and a handler to handle exceptions.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dataloaders/decoder_loader.py\":0-28",
+ "content": "import os\nimport webdataset as wds\nimport torch\nfrom torch.utils.data import DataLoader\nimport numpy as np\nimport fsspec\nimport shutil\ndef get_shard(filename):\n \"\"\"\n Filenames with shards in them have a consistent structure that we can take advantage of\n Standard structure: path/to/file/prefix_string_00001.ext\n \"\"\"\n try:\n return filename.split(\"_\")[-1].split(\".\")[0]\n except ValueError:\n raise RuntimeError(f\"Could not find shard for filename {filename}\")\ndef get_example_file(fs, path, file_format):\n \"\"\"\n Given a file system and a file extension, return the example file\n \"\"\"\n return fs.glob(os.path.join(path, f\"*.{file_format}\"))[0]\ndef embedding_inserter(samples, embeddings_url, index_width, sample_key='npy', handler=wds.handlers.reraise_exception):\n \"\"\"Given a datum of {\"__key__\": str, \"__url__\": str, ...} adds the cooresponding embedding and yields\"\"\"\n previous_tar_url = None\n current_embeddings = None\n # Get a reference to an abstract file system where the embeddings are stored"
+ },
+ {
+ "comment": "This code segment retrieves and loads embeddings from a webdataset tar file using the given URL. It identifies the correct npy file containing the embeddings by extracting the shard number from the URL, then opens and loads the data into a torch tensor.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dataloaders/decoder_loader.py\":29-48",
+ "content": " embeddings_fs, embeddings_path = fsspec.core.url_to_fs(embeddings_url)\n example_embedding_file = get_example_file(embeddings_fs, embeddings_path, \"npy\")\n example_embedding_shard = get_shard(example_embedding_file)\n emb_shard_width = len(example_embedding_shard)\n # Easier to get the basename without the shard once than search through for the correct file every time\n embedding_file_basename = '_'.join(example_embedding_file.split(\"_\")[:-1]) + \"_\"\n def load_corresponding_embeds(tar_url):\n \"\"\"Finds and reads the npy files that contains embeddings for the given webdataset tar\"\"\"\n shard = int(tar_url.split(\"/\")[-1].split(\".\")[0])\n embedding_url = embedding_file_basename + str(shard).zfill(emb_shard_width) + '.npy'\n with embeddings_fs.open(embedding_url) as f:\n data = np.load(f)\n return torch.from_numpy(data)\n for sample in samples:\n try:\n tar_url = sample[\"__url__\"]\n key = sample[\"__key__\"]\n if tar_url != previous_tar_url:"
+ },
+ {
+ "comment": "The code checks if a tar file changed and loads corresponding embeddings. If the sample has no embedding, it raises an error. The insert_embedding variable is assigned a pipeline filter with the embedding inserter function.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dataloaders/decoder_loader.py\":49-66",
+ "content": " # If the tar changed, we need to download new embeddings\n # This means if we shuffle before inserting it will load many more files than we expect and be very inefficient.\n previous_tar_url = tar_url\n current_embeddings = load_corresponding_embeds(tar_url)\n embedding_index = int(key[-index_width:])\n embedding = current_embeddings[embedding_index]\n # We need to check if this sample is nonzero. If it is, this embedding is not valid and we should continue to the next loop\n if torch.count_nonzero(embedding) == 0:\n raise RuntimeError(f\"Webdataset had a sample, but no embedding was found. ImgShard: {key[:-index_width]} - Index: {key[-index_width:]}\")\n sample[sample_key] = embedding\n yield sample\n except Exception as exn: # From wds implementation\n if handler(exn):\n continue\n else:\n break\ninsert_embedding = wds.filters.pipelinefilter(embedding_inserter)"
+ },
+ {
+ "comment": "This function checks if there are corresponding embeddings for the given tarfiles. It first retrieves a set of embedding shards from the embeddings_url, then iterates through the tarfiles. If a tarfile's shard is in the set of embedding shards, it yields the tarfile. Otherwise, it will continue to iterate until it finds a matching shard. Exceptions are handled using the provided handler function.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dataloaders/decoder_loader.py\":68-83",
+ "content": "def unassociated_shard_skipper(tarfiles, embeddings_url, handler=wds.handlers.reraise_exception):\n \"\"\"Finds if the is a corresponding embedding for the tarfile at { url: [URL] }\"\"\"\n embeddings_fs, embeddings_path = fsspec.core.url_to_fs(embeddings_url)\n embedding_files = embeddings_fs.ls(embeddings_path)\n get_embedding_shard = lambda embedding_file: int(embedding_file.split(\"_\")[-1].split(\".\")[0])\n embedding_shards = set([get_embedding_shard(filename) for filename in embedding_files]) # Sets have O(1) check for member\n get_tar_shard = lambda tar_file: int(tar_file.split(\"/\")[-1].split(\".\")[0])\n for tarfile in tarfiles:\n try:\n webdataset_shard = get_tar_shard(tarfile[\"url\"])\n # If this shard has an associated embeddings file, we pass it through. Otherwise we iterate until we do have one\n if webdataset_shard in embedding_shards:\n yield tarfile\n except Exception as exn: # From wds implementation\n if handler(exn):"
+ },
+ {
+ "comment": "The code defines two functions: `join_embeddings()` and `verify_keys()`. The first function combines the `img_emb` and `text_emb` keys into a single \"emb\" key in each sample, only including existing embeddings. The second function ensures that both image and embedding are present in each sample. If not, it either continues or breaks depending on the exception handler.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dataloaders/decoder_loader.py\":84-110",
+ "content": " continue\n else:\n break\nskip_unassociated_shards = wds.filters.pipelinefilter(unassociated_shard_skipper)\ndef join_embeddings(samples, handler=wds.handlers.reraise_exception):\n \"\"\"\n Takes the img_emb and text_emb keys and turns them into one key \"emb\": { \"text\": text_emb, \"img\": img_emb }\n either or both of text_emb and img_emb may not be in the sample so we only add the ones that exist\n \"\"\"\n for sample in samples:\n try:\n sample['emb'] = {}\n if 'text_emb' in sample:\n sample['emb']['text'] = sample['text_emb']\n if 'img_emb' in sample:\n sample['emb']['img'] = sample['img_emb']\n yield sample\n except Exception as exn: # From wds implementation\n if handler(exn):\n continue\n else:\n break\ndef verify_keys(samples, required_keys, handler=wds.handlers.reraise_exception):\n \"\"\"\n Requires that both the image and embedding are present in the sample"
+ },
+ {
+ "comment": "This code checks if required keys are present in each sample, asserts if missing and yields the sample. It uses a key_verifier filter and a fluid interface for DataPipeline to return image embedding pairs. Embeddings can be read from webdataset or inserted from an alternate source based on embedding_folder_url.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dataloaders/decoder_loader.py\":111-135",
+ "content": " This is important to do as a user may forget they do not have embeddings in their webdataset and neglect to add them using the embedding_folder_url parameter.\n \"\"\"\n for sample in samples:\n try:\n for key in required_keys:\n assert key in sample, f\"Sample {sample['__key__']} missing {key}. Has keys {sample.keys()}\"\n yield sample\n except Exception as exn: # From wds implementation\n if handler(exn):\n continue\n else:\n break\nkey_verifier = wds.filters.pipelinefilter(verify_keys)\nclass ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):\n \"\"\"\n A fluid interface wrapper for DataPipline that returns image embedding pairs\n Reads embeddings as npy files from the webdataset if they exist. If embedding_folder_url is set, they will be inserted in from the alternate source.\n \"\"\"\n def __init__(\n self,\n urls,\n img_embedding_folder_url=None,\n text_embedding_folder_url=None,"
+ },
+ {
+ "comment": "The code defines a function to load data from webdatasets and embeddings for a model. It takes URLs as input, where each URL points to tar files of the webdataset. If embeddings are not included in the dataset, an embedding_folder_URL is required. The index width specifies the number of digits in the index, used to align image and embedding indices. The handler handles exceptions, while resample can be set for resampling data. The shuffle_shards flag determines whether to shuffle shards during loading.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dataloaders/decoder_loader.py\":136-150",
+ "content": " index_width=None,\n img_preproc=None,\n extra_keys=[],\n handler=wds.handlers.reraise_exception,\n resample=False,\n shuffle_shards=True\n ):\n \"\"\"\n Modeled directly off of the WebDataset constructor\n :param urls: A url pointing to the tar files of the webdataset formatted as /path/to/webdataset/{0000..9999}.tar\n :param embedding_folder_url: Required if webdataset does not contain embeddings. A url pointing to the npy files of the embeddings. Should have the same number of shards as the webdataset.\n Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros.\n :param index_width: The number of digits in the index. This is used to align the embedding index with the image index.\n For example, if a file in the webdataset shard 3 is named 0003039.jpg, we know the shard is 4 digits and the last 3 digits are the index_width."
+ },
+ {
+ "comment": "This function is a webdataset handler that takes parameters for img_preproc, resample, and shuffle_shards. It initializes the keys for data loading and maps them to their respective indices. If img_embedding_folder_url or text_embedding_folder_url is not None, \"img_emb\" and \"text_emb\" will be added as keys. The function also checks if s3fs and s3cmd are installed, and handles data piping.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dataloaders/decoder_loader.py\":151-168",
+ "content": " :param img_preproc: This function is run on the img before it is batched and returned. Useful for data augmentation or converting to torch tensor.\n :param handler: A webdataset handler.\n :param resample: If true, resample webdataset shards with replacement. You need to set your own epoch size if this is true since it will resample infinitely.\n :param shuffle_shards: If true, shuffle the shards before resampling. This cannot be true if resample is true.\n \"\"\"\n super().__init__()\n keys = [\"jpg\", \"emb\"] + extra_keys\n # if img_embedding_folder_url is not None:\n # keys.append(\"img_emb\")\n # if text_embedding_folder_url is not None:\n # keys.append(\"text_emb\")\n # keys.extend(extra_keys)\n self.key_map = {key: i for i, key in enumerate(keys)}\n self.resampling = resample\n self.img_preproc = img_preproc\n # If s3, check if s3fs is installed and s3cmd is installed and check if the data is piped instead of straight up"
+ },
+ {
+ "comment": "Code checks if the URLs provided for webdataset contain \"s3:\" indicating S3 links. If so, it requires 's3cmd' and 's3fs' packages to be installed or raises an error. It also adds shardList and allows resampling or shuffling of shards based on user input.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dataloaders/decoder_loader.py\":169-184",
+ "content": " if (isinstance(urls, str) and \"s3:\" in urls) or (isinstance(urls, list) and any([\"s3:\" in url for url in urls])):\n # Then this has an s3 link for the webdataset and we need extra packages\n if shutil.which(\"s3cmd\") is None:\n raise RuntimeError(\"s3cmd is required for s3 webdataset\")\n if (img_embedding_folder_url is not None and \"s3:\" in img_embedding_folder_url) or (text_embedding_folder_url is not None and \"s3:\" in text_embedding_folder_url):\n # Then the embeddings are being loaded from s3 and fsspec requires s3fs\n try:\n import s3fs\n except ImportError:\n raise RuntimeError(\"s3fs is required to load embeddings from s3\")\n # Add the shardList and randomize or resample if requested\n if resample:\n assert not shuffle_shards, \"Cannot both resample and shuffle\"\n self.append(wds.ResampledShards(urls))\n else:\n self.append(wds.SimpleShardList(urls))"
+ },
+ {
+ "comment": "The code configures a decoder loader for DALLE2-pytorch. It shuffles 1000 filters and skips unassociated shards if necessary, loads embeddings from URLs, converts to samples, and decodes images as PILRGB.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dataloaders/decoder_loader.py\":185-199",
+ "content": " if shuffle_shards:\n self.append(wds.filters.shuffle(1000))\n if img_embedding_folder_url is not None:\n # There may be webdataset shards that do not have a embedding shard associated with it. If we do not skip these, they would cause issues.\n self.append(skip_unassociated_shards(embeddings_url=img_embedding_folder_url, handler=handler))\n if text_embedding_folder_url is not None:\n self.append(skip_unassociated_shards(embeddings_url=text_embedding_folder_url, handler=handler))\n self.append(wds.tarfile_to_samples(handler=handler))\n self.append(wds.decode(\"pilrgb\", handler=handler))\n if img_embedding_folder_url is not None:\n # Then we are loading image embeddings for a remote source\n assert index_width is not None, \"Reading embeddings separately requires index width length to be given\"\n self.append(insert_embedding(embeddings_url=img_embedding_folder_url, index_width=index_width, sample_key='img_emb', handler=handler))"
+ },
+ {
+ "comment": "This code creates an image embedding dataloader. If a text embedding folder URL is provided, it loads image embeddings for remote sources based on the given index width. It then applies preprocessing and joins the embeddings before returning the tuple of keys. The preproc function applies image preprocessing if available.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dataloaders/decoder_loader.py\":200-224",
+ "content": " if text_embedding_folder_url is not None:\n # Then we are loading image embeddings for a remote source\n assert index_width is not None, \"Reading embeddings separately requires index width length to be given\"\n self.append(insert_embedding(embeddings_url=text_embedding_folder_url, index_width=index_width, sample_key='text_emb', handler=handler))\n self.append(join_embeddings)\n self.append(key_verifier(required_keys=keys, handler=handler))\n # Apply preprocessing\n self.append(wds.map(self.preproc))\n self.append(wds.to_tuple(*keys))\n def preproc(self, sample):\n \"\"\"Applies the preprocessing for images\"\"\"\n if self.img_preproc is not None:\n sample[\"jpg\"] = self.img_preproc(sample[\"jpg\"])\n return sample\ndef create_image_embedding_dataloader(\n tar_url,\n num_workers,\n batch_size,\n img_embeddings_url=None,\n text_embeddings_url=None,\n index_width=None,\n shuffle_num = None,\n shuffle_shards = True,"
+ },
+ {
+ "comment": "This code creates an image embedding dataset and dataloader in one line, accepting parameters such as tar_url, num_workers, batch_size, embeddings_url, and index_width. The function is designed for webdataset format and requires the same number of shards for both the webdataset images and their corresponding embeddings. It also supports handling exceptions using a specified handler.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dataloaders/decoder_loader.py\":225-239",
+ "content": " resample_shards = False, \n img_preproc=None,\n extra_keys=[],\n handler=wds.handlers.reraise_exception#warn_and_continue\n):\n \"\"\"\n Convenience function to create an image embedding dataseta and dataloader in one line\n :param tar_url: A url pointing to the tar files of the webdataset formatted as /path/to/webdataset/{0000..9999}.tar\n :param num_workers: The number of workers to use for the dataloader\n :param batch_size: The batch size to use for the dataloader\n :param embeddings_url: Required if webdataset does not contain embeddings. A url pointing to the npy files of the embeddings. Should have the same number of shards as the webdataset.\n Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros.\n :param index_width: The number of digits in the index. This is used to align the embedding index with the image index.\n For example, if a file in the webdataset sh"
+ },
+ {
+ "comment": "This code defines a function that takes in parameters like tar_url, img_embedding_folder_url, text_embeddings_url, index_width, extra_keys, img_preproc, and handler. It creates an ImageEmbeddingDataset and optionally shuffles it based on the given shuffle_num. Then, it returns a DataLoader for further processing.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dataloaders/decoder_loader.py\":239-260",
+ "content": "ard 3 is named 0003039.jpg, we know the shard is 4 digits and the last 3 digits are the index_width.\n :param shuffle_num: If not None, shuffle the dataset with this size buffer after sampling.\n :param shuffle_shards: If true, shuffle the shards before sampling. This cannot be true if resample is true.\n :param resample_shards: If true, resample webdataset shards with replacement. You need to set your own epoch size if this is true since it will resample infinitely.\n :param handler: A webdataset handler.\n \"\"\"\n ds = ImageEmbeddingDataset(\n tar_url,\n img_embedding_folder_url=img_embeddings_url,\n text_embedding_folder_url=text_embeddings_url,\n index_width=index_width,\n shuffle_shards=shuffle_shards,\n resample=resample_shards,\n extra_keys=extra_keys,\n img_preproc=img_preproc,\n handler=handler\n )\n if shuffle_num is not None and shuffle_num > 0:\n ds.shuffle(1000)\n return DataLoader(\n ds,\n num_workers=num_workers,"
+ },
+ {
+ "comment": "This code creates a data loader for the decoder model. It sets batch size, prefetch factor (for efficient loading), pin memory (for faster GPU transfers), and disables shuffling.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dataloaders/decoder_loader.py\":261-265",
+ "content": " batch_size=batch_size,\n prefetch_factor=2, # This might be good to have high so the next npy file is prefetched\n pin_memory=True,\n shuffle=False\n )"
+ }
+ ]
+}
\ No newline at end of file
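The decoder_loader.py entries above end with create_image_embedding_dataloader, the convenience wrapper around ImageEmbeddingDataset. A short usage sketch follows; the shard path, embedding folder, and preprocessing pipeline are hypothetical, and only parameters shown in the docstrings above are used.

```python
# Hypothetical usage of the dataloader factory documented above.
from torchvision import transforms as T
from dalle2_pytorch.dataloaders.decoder_loader import create_image_embedding_dataloader

dataloader = create_image_embedding_dataloader(
    tar_url="/data/webdataset/{00000..00099}.tar",  # webdataset shards (placeholder path)
    num_workers=4,
    batch_size=64,
    img_embeddings_url="/data/img-embeddings/",     # one .npy shard per tar shard
    index_width=4,                                  # digits of the per-shard sample index
    shuffle_num=1000,                               # enable an in-memory shuffle after sampling
    img_preproc=T.Compose([T.Resize((64, 64)), T.ToTensor()]),
)

# Each sample is converted to a tuple of the dataset keys ("jpg", "emb"),
# so a batch unpacks into images and an embedding dict with an "img" entry.
for images, embeddings in dataloader:
    print(images.shape)             # (batch_size, 3, 64, 64) after img_preproc
    print(embeddings["img"].shape)  # embeddings read from the npy shards
    break
```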
diff --git a/docs/doc/488b3fbd-91ea-4582-bbd9-d564446bc9d2.json b/docs/doc/488b3fbd-91ea-4582-bbd9-d564446bc9d2.json
new file mode 100644
index 00000000..a9b8b3d3
--- /dev/null
+++ b/docs/doc/488b3fbd-91ea-4582-bbd9-d564446bc9d2.json
@@ -0,0 +1,15 @@
+{
+ "summary": "This code is a setup script for the dalle2-pytorch package using setuptools, defining project details and dependencies like PyTorch, Torchvision, and more. It's a Python project with beta development status, targeting developers in AI field, licensed under MIT, requires Python 3.6.",
+ "details": [
+ {
+ "comment": "This code is a setup script for the dalle2-pytorch package using setuptools. It defines the name, packages, entry points, version, license, description, author, URL, keywords, and install_requires. The script imports necessary modules and sets up dependencies for installation.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/setup.py\":0-41",
+ "content": "from setuptools import setup, find_packages\nexec(open('dalle2_pytorch/version.py').read())\nsetup(\n name = 'dalle2-pytorch',\n packages = find_packages(exclude=[]),\n include_package_data = True,\n entry_points={\n 'console_scripts': [\n 'dalle2_pytorch = dalle2_pytorch.cli:main',\n 'dream = dalle2_pytorch.cli:dream'\n ],\n },\n version = __version__,\n license='MIT',\n description = 'DALL-E 2',\n author = 'Phil Wang',\n author_email = 'lucidrains@gmail.com',\n long_description_content_type = 'text/markdown',\n url = 'https://github.com/lucidrains/dalle2-pytorch',\n keywords = [\n 'artificial intelligence',\n 'deep learning',\n 'text to image'\n ],\n install_requires=[\n 'accelerate',\n 'click',\n 'open-clip-torch>=2.0.0,<3.0.0',\n 'clip-anytorch>=2.5.2',\n 'coca-pytorch>=0.0.5',\n 'ema-pytorch>=0.0.7',\n 'einops>=0.7.0',\n 'embedding-reader',\n 'kornia>=0.5.4',\n 'numpy',\n 'packaging',\n 'pillow',\n 'pydantic>=2',\n 'pytorch-warmup',\n 'resize-right>=0.0.2',\n 'rotary-embedding-torch',"
+ },
+ {
+ "comment": "This is a Python project setup file, using setuptools. It depends on PyTorch >= 1.10, Torchvision, Tqdm, VectorQuantizePytorch, X-Clip >= 0.4.4, Webdataset >= 0.2.5, FSSpec >= 2022.1.0, and TorchMetrics[image] >= 0.8.0. The project has a beta development status, is intended for developers, relates to artificial intelligence, is licensed under MIT, and requires Python 3.6.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/setup.py\":42-58",
+ "content": " 'torch>=1.10',\n 'torchvision',\n 'tqdm',\n 'vector-quantize-pytorch',\n 'x-clip>=0.4.4',\n 'webdataset>=0.2.5',\n 'fsspec>=2022.1.0',\n 'torchmetrics[image]>=0.8.0'\n ],\n classifiers=[\n 'Development Status :: 4 - Beta',\n 'Intended Audience :: Developers',\n 'Topic :: Scientific/Engineering :: Artificial Intelligence',\n 'License :: OSI Approved :: MIT License',\n 'Programming Language :: Python :: 3.6',\n ],\n)"
+ }
+ ]
+}
\ No newline at end of file
diff --git a/docs/doc/49fa1d18-27c6-4ae1-90d0-36ce67ebcb9d.json b/docs/doc/49fa1d18-27c6-4ae1-90d0-36ce67ebcb9d.json
new file mode 100644
index 00000000..aebec431
--- /dev/null
+++ b/docs/doc/49fa1d18-27c6-4ae1-90d0-36ce67ebcb9d.json
@@ -0,0 +1,75 @@
+{
+ "summary": "The code sets up DALL-E 2 PyTorch training configurations, provides utility functions and tracker configuration, defines a class for model training/evaluation, and suggests potential efficiency improvements.",
+ "details": [
+ {
+ "comment": "This code is defining various classes and functions for training configurations in a machine learning application, specifically related to the DALL-E 2 PyTorch model. It includes importing necessary modules, setting up pydantic models for train splits, and creating utility functions like `default` and `exists`.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/train_configs.py\":0-42",
+ "content": "import json\nfrom torchvision import transforms as T\nfrom pydantic import BaseModel, validator, model_validator\nfrom typing import List, Optional, Union, Tuple, Dict, Any, TypeVar\nfrom x_clip import CLIP as XCLIP\nfrom open_clip import list_pretrained\nfrom coca_pytorch import CoCa\nfrom dalle2_pytorch.dalle2_pytorch import (\n CoCaAdapter,\n OpenAIClipAdapter,\n OpenClipAdapter,\n Unet,\n Decoder,\n DiffusionPrior,\n DiffusionPriorNetwork,\n XClipAdapter\n)\nfrom dalle2_pytorch.trackers import Tracker, create_loader, create_logger, create_saver\n# helper functions\ndef exists(val):\n return val is not None\ndef default(val, d):\n return val if exists(val) else d\nInnerType = TypeVar('InnerType')\nListOrTuple = Union[List[InnerType], Tuple[InnerType]]\nSingularOrIterable = Union[InnerType, ListOrTuple[InnerType]]\n# general pydantic classes\nclass TrainSplitConfig(BaseModel):\n train: float = 0.75\n val: float = 0.15\n test: float = 0.1\n @model_validator(mode = 'after')\n def validate_all(self, m):\n actual_sum = sum([*dict(self).values()])"
+ },
+ {
+ "comment": "The code defines two classes, `TrackerLogConfig` and `TrackerLoadConfig`, which inherit from `BaseModel`. These classes have various attributes such as `log_type`, `resume`, `auto_resume`, and `verbose`. They also have a method called `create` that takes in a `data_path` parameter and returns a logger object. The classes ensure their attributes sum up to 1.0, and allow additional arguments for each individual log type. The `TrackerLoadConfig` class has an optional attribute `load_from`, which determines if the logger should load from a previous run. If `load_from` is set to `None`, it returns None instead of loading.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/train_configs.py\":43-72",
+ "content": " if actual_sum != 1.:\n raise ValueError(f'{dict(self).keys()} must sum to 1.0. Found: {actual_sum}')\n return self\nclass TrackerLogConfig(BaseModel):\n log_type: str = 'console'\n resume: bool = False # For logs that are saved to unique locations, resume a previous run\n auto_resume: bool = False # If the process crashes and restarts, resume from the run that crashed\n verbose: bool = False\n class Config:\n # Each individual log type has it's own arguments that will be passed through the config\n extra = \"allow\"\n def create(self, data_path: str):\n kwargs = self.dict()\n return create_logger(self.log_type, data_path, **kwargs)\nclass TrackerLoadConfig(BaseModel):\n load_from: Optional[str] = None\n only_auto_resume: bool = False # Only attempt to load if the logger is auto-resuming\n class Config:\n extra = \"allow\"\n def create(self, data_path: str):\n kwargs = self.dict()\n if self.load_from is None:\n return None"
+ },
+ {
+ "comment": "This code defines classes for tracker configuration and load/save operations. The TrackerConfig class contains information about the data path, overwrite option, logger settings, and optional load configurations. The create method of TrackerConfig initializes a new Tracker object and adds a logger if present in the configuration. If there is a defined load configuration, it also adds a loader to the tracker.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/train_configs.py\":73-101",
+ "content": " return create_loader(self.load_from, data_path, **kwargs)\nclass TrackerSaveConfig(BaseModel):\n save_to: str = 'local'\n save_all: bool = False\n save_latest: bool = True\n save_best: bool = True\n class Config:\n extra = \"allow\"\n def create(self, data_path: str):\n kwargs = self.dict()\n return create_saver(self.save_to, data_path, **kwargs)\nclass TrackerConfig(BaseModel):\n data_path: str = '.tracker_data'\n overwrite_data_path: bool = False\n log: TrackerLogConfig\n load: Optional[TrackerLoadConfig] = None\n save: Union[List[TrackerSaveConfig], TrackerSaveConfig]\n def create(self, full_config: BaseModel, extra_config: dict, dummy_mode: bool = False) -> Tracker:\n tracker = Tracker(self.data_path, dummy_mode=dummy_mode, overwrite_data_path=self.overwrite_data_path)\n # Add the logger\n tracker.add_logger(self.log.create(self.data_path))\n # Add the loader\n if self.load is not None:\n tracker.add_loader(self.load.create(self.data_path))"
+ },
+ {
+ "comment": "This code defines a function that initializes and returns a tracker object, which is responsible for managing savers and components of the model. It also includes classes for different types of adapters used in the model. The tracker object verifies data validity after initialization.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/train_configs.py\":102-128",
+ "content": " # Add the saver or savers\n if isinstance(self.save, list):\n for save_config in self.save:\n tracker.add_saver(save_config.create(self.data_path))\n else:\n tracker.add_saver(self.save.create(self.data_path))\n # Initialize all the components and verify that all data is valid\n tracker.init(full_config, extra_config)\n return tracker\n# diffusion prior pydantic classes\nclass AdapterConfig(BaseModel):\n make: str = \"openai\"\n model: str = \"ViT-L/14\"\n base_model_kwargs: Optional[Dict[str, Any]] = None\n def create(self):\n if self.make == \"openai\":\n return OpenAIClipAdapter(self.model)\n elif self.make == \"open_clip\":\n pretrained = dict(list_pretrained())\n checkpoint = pretrained[self.model]\n return OpenClipAdapter(name=self.model, pretrained=checkpoint)\n elif self.make == \"x-clip\":\n return XClipAdapter(XCLIP(**self.base_model_kwargs))\n elif self.make == \"coca\":"
+ },
+ {
+ "comment": "This code defines configurations for a neural network model. It includes classes for adapters, diffusion prior networks, and diffusion prior models. The adapter class takes in base_model_kwargs and returns an instance of either CoCaAdapter or raises AttributeError if no matching adapter found. DiffusionPriorNetworkConfig defines the architecture specifications like dimensions, depth, and dropout rates. DiffusionPriorConfig handles configurations for clip adapters, diffusion prior networks, image embedding dimensions, image size, and number of timesteps. The create() function returns an instance of the model based on its configuration.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/train_configs.py\":129-166",
+ "content": " return CoCaAdapter(CoCa(**self.base_model_kwargs))\n else:\n raise AttributeError(\"No adapter with that name is available.\")\nclass DiffusionPriorNetworkConfig(BaseModel):\n dim: int\n depth: int\n max_text_len: Optional[int] = None\n num_timesteps: Optional[int] = None\n num_time_embeds: int = 1\n num_image_embeds: int = 1\n num_text_embeds: int = 1\n dim_head: int = 64\n heads: int = 8\n ff_mult: int = 4\n norm_in: bool = False\n norm_out: bool = True\n attn_dropout: float = 0.\n ff_dropout: float = 0.\n final_proj: bool = True\n normformer: bool = False\n rotary_emb: bool = True\n class Config:\n extra = \"allow\"\n def create(self):\n kwargs = self.dict()\n return DiffusionPriorNetwork(**kwargs)\nclass DiffusionPriorConfig(BaseModel):\n clip: Optional[AdapterConfig] = None\n net: DiffusionPriorNetworkConfig\n image_embed_dim: int\n image_size: int\n image_channels: int = 3\n timesteps: int = 1000\n sample_timesteps: Optional[int] = None"
+ },
+ {
+ "comment": "The code defines a class for training configurations, including epochs, learning rate, weight decay, and other parameters. It also contains functions to create instances of diffusion prior networks and conditioning models. The class is part of the DALLE2-pytorch framework and is used for training the model.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/train_configs.py\":167-200",
+ "content": " cond_drop_prob: float = 0.\n loss_type: str = 'l2'\n predict_x_start: bool = True\n beta_schedule: str = 'cosine'\n condition_on_text_encodings: bool = True\n class Config:\n extra = \"allow\"\n def create(self):\n kwargs = self.dict()\n has_clip = exists(kwargs.pop('clip'))\n kwargs.pop('net')\n clip = None\n if has_clip:\n clip = self.clip.create()\n diffusion_prior_network = self.net.create()\n return DiffusionPrior(net = diffusion_prior_network, clip = clip, **kwargs)\nclass DiffusionPriorTrainConfig(BaseModel):\n epochs: int = 1\n lr: float = 1.1e-4\n wd: float = 6.02e-2\n max_grad_norm: float = 0.5\n use_ema: bool = True\n ema_beta: float = 0.99\n amp: bool = False\n warmup_steps: Optional[int] = None # number of warmup steps\n save_every_seconds: int = 3600 # how often to save\n eval_timesteps: List[int] = [64] # which sampling timesteps to evaluate with\n best_validation_loss: float = 1e9 # the current best valudation loss observed"
+ },
+ {
+ "comment": "The code defines a configuration class for training the DiffusionPrior model, which contains details such as the data source, batch size, total number of datapoints to train on, and validation frequency. It also has methods to load configurations from JSON files.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/train_configs.py\":201-222",
+ "content": " current_epoch: int = 0 # the current epoch\n num_samples_seen: int = 0 # the current number of samples seen\n random_seed: int = 0 # manual seed for torch\nclass DiffusionPriorDataConfig(BaseModel):\n image_url: str # path to embeddings folder\n meta_url: str # path to metadata (captions) for images\n splits: TrainSplitConfig # define train, validation, test splits for your dataset\n batch_size: int # per-gpu batch size used to train the model\n num_data_points: int = 25e7 # total number of datapoints to train on\n eval_every_seconds: int = 3600 # validation statistics will be performed this often\nclass TrainDiffusionPriorConfig(BaseModel):\n prior: DiffusionPriorConfig\n data: DiffusionPriorDataConfig\n train: DiffusionPriorTrainConfig\n tracker: TrackerConfig\n @classmethod\n def from_json_path(cls, json_path):\n with open(json_path) as f:\n config = json.load(f)"
+ },
+ {
+ "comment": "The code defines two Pydantic classes, UnetConfig and DecoderConfig, which represent the configurations for the DALL-E 2 model. The UnetConfig class handles the configuration of the UNet transformer in the decoder while the DecoderConfig class includes various settings like the number of UNet blocks, image size, clip model, timesteps, loss type, and more.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/train_configs.py\":223-254",
+ "content": " return cls(**config)\n# decoder pydantic classes\nclass UnetConfig(BaseModel):\n dim: int\n dim_mults: ListOrTuple[int]\n image_embed_dim: Optional[int] = None\n text_embed_dim: Optional[int] = None\n cond_on_text_encodings: Optional[bool] = None\n cond_dim: Optional[int] = None\n channels: int = 3\n self_attn: SingularOrIterable[bool] = False\n attn_dim_head: int = 32\n attn_heads: int = 16\n init_cross_embed: bool = True\n class Config:\n extra = \"allow\"\nclass DecoderConfig(BaseModel):\n unets: ListOrTuple[UnetConfig]\n image_size: Optional[int] = None\n image_sizes: ListOrTuple[int] = None\n clip: Optional[AdapterConfig] = None # The clip model to use if embeddings are not provided\n channels: int = 3\n timesteps: int = 1000\n sample_timesteps: Optional[SingularOrIterable[Optional[int]]] = None\n loss_type: str = 'l2'\n beta_schedule: Optional[ListOrTuple[str]] = None # None means all cosine\n learned_variance: SingularOrIterable[bool] = True\n image_cond_drop_prob: float = 0.1"
+ },
+ {
+ "comment": "This code defines a class \"TrainConfigs\" that creates a decoder for DALL-E 2 training. It uses the Unet architecture, optionally includes CLIP for visual guidance, and allows specifying image sizes through 'image_size' or list of 'image_sizes'. The class also provides configurations for loading data from webdataset with jpg images, embedding files, and setting the number of workers for data loading.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/train_configs.py\":255-283",
+ "content": " text_cond_drop_prob: float = 0.5\n def create(self):\n decoder_kwargs = self.dict()\n unet_configs = decoder_kwargs.pop('unets')\n unets = [Unet(**config) for config in unet_configs]\n has_clip = exists(decoder_kwargs.pop('clip'))\n clip = None\n if has_clip:\n clip = self.clip.create()\n return Decoder(unets, clip=clip, **decoder_kwargs)\n @validator('image_sizes')\n def check_image_sizes(cls, image_sizes, values):\n if exists(values.get('image_size')) ^ exists(image_sizes):\n return image_sizes\n raise ValueError('either image_size or image_sizes is required, but not both')\n class Config:\n extra = \"allow\"\nclass DecoderDataConfig(BaseModel):\n webdataset_base_url: str # path to a webdataset with jpg images\n img_embeddings_url: Optional[str] = None # path to .npy files with embeddings\n text_embeddings_url: Optional[str] = None # path to .npy files with embeddings\n num_workers: int = 4"
+ },
+ {
+ "comment": "This code defines a training configuration with batch size, sharding settings, transformation preprocessing, and boolean flags for shuffling and resampling. It also includes a property method to generate the image preprocessing transforms based on provided names and optional arguments.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/train_configs.py\":284-308",
+ "content": " batch_size: int = 64\n start_shard: int = 0\n end_shard: int = 9999999\n shard_width: int = 6\n index_width: int = 4\n splits: TrainSplitConfig\n shuffle_train: bool = True\n resample_train: bool = False\n preprocessing: Dict[str, Any] = {'ToTensor': True}\n @property\n def img_preproc(self):\n def _get_transformation(transformation_name, **kwargs):\n if transformation_name == \"RandomResizedCrop\":\n return T.RandomResizedCrop(**kwargs)\n elif transformation_name == \"RandomHorizontalFlip\":\n return T.RandomHorizontalFlip()\n elif transformation_name == \"ToTensor\":\n return T.ToTensor()\n transforms = []\n for transform_name, transform_kwargs_or_bool in self.preprocessing.items():\n transform_kwargs = {} if not isinstance(transform_kwargs_or_bool, dict) else transform_kwargs_or_bool\n transforms.append(_get_transformation(transform_name, **transform_kwargs))\n return T.Compose(transforms)"
+ },
+ {
+ "comment": "This code defines a DecoderTrainConfig class with various configuration options for training the decoder model in DALLE2. The class includes settings for epochs, learning rate, weight decay, warmup steps, finding unused parameters, static graph usage, gradient clipping, saving samples, generating example images, scaling conditions, device selection, sample limits per epoch and validation, saving immediately, using exponential moving average (EMA), EMA beta value, using mixed precision training (AMP), and unet training masks.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/train_configs.py\":310-328",
+ "content": "class DecoderTrainConfig(BaseModel):\n epochs: int = 20\n lr: SingularOrIterable[float] = 1e-4\n wd: SingularOrIterable[float] = 0.01\n warmup_steps: Optional[SingularOrIterable[int]] = None\n find_unused_parameters: bool = True\n static_graph: bool = True\n max_grad_norm: SingularOrIterable[float] = 0.5\n save_every_n_samples: int = 100000\n n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset\n cond_scale: Union[float, List[float]] = 1.0\n device: str = 'cuda:0'\n epoch_samples: Optional[int] = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.\n validation_samples: Optional[int] = None # Same as above but for validation.\n save_immediately: bool = False\n use_ema: bool = True\n ema_beta: float = 0.999\n amp: bool = False\n unet_training_mask: Optional[ListOrTuple[bool]] = None # If None, use all unets"
+ },
+ {
+ "comment": "This code defines two classes, \"DecoderEvaluateConfig\" and \"TrainDecoderConfig\", which inherit from the \"BaseModel\" class. The \"DecoderEvaluateConfig\" class specifies evaluation metrics like FID, IS, KID, and LPIPS, while the \"TrainDecoderConfig\" class combines various configuration elements including a decoder, data, training settings, evaluation settings, tracker, and seed. The \"from_json_path\" method loads configuration from a JSON file, and the \"check_has_embeddings\" validator ensures that enough information is provided to get the embeddings for training.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/train_configs.py\":330-360",
+ "content": "class DecoderEvaluateConfig(BaseModel):\n n_evaluation_samples: int = 1000\n FID: Optional[Dict[str, Any]] = None\n IS: Optional[Dict[str, Any]] = None\n KID: Optional[Dict[str, Any]] = None\n LPIPS: Optional[Dict[str, Any]] = None\nclass TrainDecoderConfig(BaseModel):\n decoder: DecoderConfig\n data: DecoderDataConfig\n train: DecoderTrainConfig\n evaluate: DecoderEvaluateConfig\n tracker: TrackerConfig\n seed: int = 0\n @classmethod\n def from_json_path(cls, json_path):\n with open(json_path) as f:\n config = json.load(f)\n print(config)\n return cls(**config)\n @model_validator(mode = 'after')\n def check_has_embeddings(self, m):\n # Makes sure that enough information is provided to get the embeddings specified for training\n values = dict(self)\n data_config, decoder_config = values.get('data'), values.get('decoder')\n if not exists(data_config) or not exists(decoder_config):\n # Then something else errored and we should just pass through"
+ },
+ {
+ "comment": "This code checks if the text embeddings and/or CLIP model are being used, ensuring that only one of these is provided to avoid redundancy. It asserts that either the CLIP or text embeddings URL must be present if text conditioning is enabled, and if only the CLIP model is loaded, it asserts that neither the text embeddings nor image embeddings URL should be provided.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/train_configs.py\":361-379",
+ "content": " return values\n using_text_embeddings = any([unet.cond_on_text_encodings for unet in decoder_config.unets])\n using_clip = exists(decoder_config.clip)\n img_emb_url = data_config.img_embeddings_url\n text_emb_url = data_config.text_embeddings_url\n if using_text_embeddings:\n # Then we need some way to get the embeddings\n assert using_clip or exists(text_emb_url), 'If text conditioning, either clip or text_embeddings_url must be provided'\n if using_clip:\n if using_text_embeddings:\n assert not exists(text_emb_url) or not exists(img_emb_url), 'Loaded clip, but also provided text_embeddings_url and img_embeddings_url. This is redundant. Remove the clip model or the text embeddings'\n else:\n assert not exists(img_emb_url), 'Loaded clip, but also provided img_embeddings_url. This is redundant. Remove the clip model or the embeddings'\n if text_emb_url:\n assert using_te"
+ },
+ {
+ "comment": "This code snippet indicates that text embeddings are being loaded but are not necessary for the task, causing unnecessary slowdown in the dataloader. It is recommended to remove this step for efficiency.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/train_configs.py\":379-381",
+ "content": "xt_embeddings, \"Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason.\"\n return m"
+ }
+ ]
+}
\ No newline at end of file
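
The config classes documented above are consumed by the training scripts. As a minimal sketch (not taken from the repository), the snippet below shows how a caller might load and materialize a decoder training configuration; the JSON path is a placeholder and must point to a file that satisfies the validators described above (exactly one of image_size/image_sizes, and an embedding source whenever text conditioning is enabled).

```python
# Sketch only: the config path is a placeholder, not a file guaranteed to ship with these docs.
from dalle2_pytorch.train_configs import TrainDecoderConfig

config = TrainDecoderConfig.from_json_path('configs/train_decoder_config.example.json')

decoder = config.decoder.create()        # builds the Unet stack (and CLIP, if configured)
img_preproc = config.data.img_preproc    # torchvision transform pipeline built from `preprocessing`
tracker = config.tracker.create(config, extra_config={}, dummy_mode=True)  # logger / loader / saver wiring
```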
diff --git a/docs/doc/567b8601-1600-4951-b12d-b038276aa095.json b/docs/doc/567b8601-1600-4951-b12d-b038276aa095.json
new file mode 100644
index 00000000..fa0c81c8
--- /dev/null
+++ b/docs/doc/567b8601-1600-4951-b12d-b038276aa095.json
@@ -0,0 +1,10 @@
+{
+ "summary": "This code defines two functions, `separate_weight_decayable_params` and `get_optimizer`. The `get_optimizer` function takes parameters, learning rate, weight decay, and other options to create an optimizer object. It filters the parameters based on `requires_grad`, separates weight-decayable parameters, and uses either Adam or AdamW optimizer depending on the weight decay value.",
+ "details": [
+ {
+ "comment": "This code defines two functions, `separate_weight_decayable_params` and `get_optimizer`. The `get_optimizer` function takes parameters, learning rate, weight decay, and other options to create an optimizer object. It filters the parameters based on `requires_grad`, separates weight-decayable parameters, and uses either Adam or AdamW optimizer depending on the weight decay value.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/optimizer.py\":0-33",
+ "content": "from torch.optim import AdamW, Adam\ndef separate_weight_decayable_params(params):\n wd_params, no_wd_params = [], []\n for param in params:\n param_list = no_wd_params if param.ndim < 2 else wd_params\n param_list.append(param)\n return wd_params, no_wd_params\ndef get_optimizer(\n params,\n lr = 1e-4,\n wd = 1e-2,\n betas = (0.9, 0.99),\n eps = 1e-8,\n filter_by_requires_grad = False,\n group_wd_params = True,\n **kwargs\n):\n if filter_by_requires_grad:\n params = list(filter(lambda t: t.requires_grad, params))\n if wd == 0:\n return Adam(params, lr = lr, betas = betas, eps = eps)\n if group_wd_params:\n wd_params, no_wd_params = separate_weight_decayable_params(params)\n params = [\n {'params': wd_params},\n {'params': no_wd_params, 'weight_decay': 0},\n ]\n return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)"
+ }
+ ]
+}
\ No newline at end of file
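
To make the branching above concrete, here is a small hedged example; the `nn.Linear` module is only a stand-in for a real model.

```python
import torch.nn as nn
from dalle2_pytorch.optimizer import get_optimizer

model = nn.Linear(512, 512)   # toy module, purely for illustration

# wd > 0 with group_wd_params=True (default): AdamW, with 1-D params (biases, norms) excluded from weight decay
opt = get_optimizer(model.parameters(), lr=1e-4, wd=1e-2)

# wd == 0: falls back to plain Adam with no parameter grouping
opt_no_wd = get_optimizer(model.parameters(), lr=1e-4, wd=0)
```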
diff --git a/docs/doc/7472125d-6db3-456b-873d-d11465f2b72f.json b/docs/doc/7472125d-6db3-456b-873d-d11465f2b72f.json
new file mode 100644
index 00000000..bd7621ec
--- /dev/null
+++ b/docs/doc/7472125d-6db3-456b-873d-d11465f2b72f.json
@@ -0,0 +1,45 @@
+{
+ "summary": "This code defines ImageDataset and VQGanVAETrainer classes for loading image data and training a VAE model, setting parameters, optimizers, and creating loaders. It trains the model, logs losses, saves models, and tracks progress in a results folder.",
+ "details": [
+ {
+ "comment": "This code contains several utility functions and helper methods. It includes import statements for various libraries, classes for data handling and model training, as well as custom functions for logging, looping, and user input.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/vqgan_vae_trainer.py\":0-46",
+ "content": "from math import sqrt\nimport copy\nfrom random import choice\nfrom pathlib import Path\nfrom shutil import rmtree\nfrom PIL import Image\nimport torch\nfrom torch import nn\nfrom torch.cuda.amp import autocast, GradScaler\nfrom torch.utils.data import Dataset, DataLoader, random_split\nimport torchvision.transforms as T\nfrom torchvision.datasets import ImageFolder\nfrom torchvision.utils import make_grid, save_image\nfrom einops import rearrange\nfrom dalle2_pytorch.vqgan_vae import VQGanVAE\nfrom dalle2_pytorch.optimizer import get_optimizer\nfrom ema_pytorch import EMA\n# helpers\ndef exists(val):\n return val is not None\ndef noop(*args, **kwargs):\n pass\ndef cycle(dl):\n while True:\n for data in dl:\n yield data\ndef cast_tuple(t):\n return t if isinstance(t, (tuple, list)) else (t,)\ndef yes_or_no(question):\n answer = input(f'{question} (y/n) ')\n return answer.lower() in ('yes', 'y')\ndef accum_log(log, new_logs):\n for key, new_value in new_logs.items():\n old_value = log.get(key, 0.)\n log[key] = old_value + new_value"
+ },
+ {
+ "comment": "The code defines a class \"ImageDataset\" for loading and transforming image data, and a main trainer class \"VQGanVAETrainer\" for training a VAE model. The \"ImageDataset\" class initializes with a folder path, image size, and extension types to filter the images, then applies image transformations like converting to RGB mode, resizing, horizontal flipping, cropping, and tensor conversion. The \"VQGanVAETrainer\" class initializes with parameters like the VAE model, number of training steps, learning rate, and batch size for the training process.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/vqgan_vae_trainer.py\":47-90",
+ "content": " return log\n# classes\nclass ImageDataset(Dataset):\n def __init__(\n self,\n folder,\n image_size,\n exts = ['jpg', 'jpeg', 'png']\n ):\n super().__init__()\n self.folder = folder\n self.image_size = image_size\n self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]\n print(f'{len(self.paths)} training samples found at {folder}')\n self.transform = T.Compose([\n T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),\n T.Resize(image_size),\n T.RandomHorizontalFlip(),\n T.CenterCrop(image_size),\n T.ToTensor()\n ])\n def __len__(self):\n return len(self.paths)\n def __getitem__(self, index):\n path = self.paths[index]\n img = Image.open(path)\n return self.transform(img)\n# main trainer class\nclass VQGanVAETrainer(nn.Module):\n def __init__(\n self,\n vae,\n *,\n num_train_steps,\n lr,\n batch_size,"
+ },
+ {
+ "comment": "The code initializes an instance of a VQGanVAE and sets up various parameters for training. It checks if the provided vae is of type VQGanVAE, then assigns image size, creates an EMA model with specified update steps and intervals, registers a buffer for tracking steps, sets number of train steps, batch size, grad accumulation every, and initializes optimizer with specified learning rate and weight decay.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/vqgan_vae_trainer.py\":91-122",
+ "content": " folder,\n grad_accum_every,\n wd = 0.,\n save_results_every = 100,\n save_model_every = 1000,\n results_folder = './results',\n valid_frac = 0.05,\n random_split_seed = 42,\n ema_beta = 0.995,\n ema_update_after_step = 500,\n ema_update_every = 10,\n apply_grad_penalty_every = 4,\n amp = False\n ):\n super().__init__()\n assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'\n image_size = vae.image_size\n self.vae = vae\n self.ema_vae = EMA(vae, update_after_step = ema_update_after_step, update_every = ema_update_every)\n self.register_buffer('steps', torch.Tensor([0]))\n self.num_train_steps = num_train_steps\n self.batch_size = batch_size\n self.grad_accum_every = grad_accum_every\n all_parameters = set(vae.parameters())\n discr_parameters = set(vae.discr.parameters())\n vae_parameters = all_parameters - discr_parameters\n self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)"
+ },
+ {
+ "comment": "This code initializes a Discriminator optimizer, Amplitude Signed-Precision (AMP) for mixed precision training, GradScaler for handling gradients, creates an ImageDataset from the given folder and image size, splits the dataset into training and validation if valid_frac is greater than 0, creates DataLoader for the dataset with specified batch_size and shuffle set to True.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/vqgan_vae_trainer.py\":123-149",
+ "content": " self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)\n self.amp = amp\n self.scaler = GradScaler(enabled = amp)\n self.discr_scaler = GradScaler(enabled = amp)\n # create dataset\n self.ds = ImageDataset(folder, image_size = image_size)\n # split for validation\n if valid_frac > 0:\n train_size = int((1 - valid_frac) * len(self.ds))\n valid_size = len(self.ds) - train_size\n self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))\n print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples')\n else:\n self.valid_ds = self.ds\n print(f'training with shared training and valid dataset of {len(self.ds)} samples')\n # dataloader\n self.dl = cycle(DataLoader(\n self.ds,\n batch_size = batch_size,\n shuffle = True"
+ },
+ {
+ "comment": "The code initializes the valid data loader and sets parameters for saving models, results, and applying gradient penalty. It checks if previous experiment checkpoints and results should be cleared, creates the results folder if needed, and defines the train_step function for training the VAE (generator).",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/vqgan_vae_trainer.py\":150-187",
+ "content": " ))\n self.valid_dl = cycle(DataLoader(\n self.valid_ds,\n batch_size = batch_size,\n shuffle = True\n ))\n self.save_model_every = save_model_every\n self.save_results_every = save_results_every\n self.apply_grad_penalty_every = apply_grad_penalty_every\n self.results_folder = Path(results_folder)\n if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):\n rmtree(str(self.results_folder))\n self.results_folder.mkdir(parents = True, exist_ok = True)\n def train_step(self):\n device = next(self.vae.parameters()).device\n steps = int(self.steps.item())\n apply_grad_penalty = not (steps % self.apply_grad_penalty_every)\n self.vae.train()\n # logs\n logs = {}\n # update vae (generator)\n for _ in range(self.grad_accum_every):\n img = next(self.dl)\n img = img.to(device)\n with autocast(enabled = self.amp):"
+ },
+ {
+ "comment": "This code trains a VAE model and updates the discriminator. It uses scaling, accumulation, and gradients for efficient backpropagation. The loss is calculated and logged for both VAE and discriminator, then optimizers are updated.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/vqgan_vae_trainer.py\":188-220",
+ "content": " loss = self.vae(\n img,\n return_loss = True,\n apply_grad_penalty = apply_grad_penalty\n )\n self.scaler.scale(loss / self.grad_accum_every).backward()\n accum_log(logs, {'loss': loss.item() / self.grad_accum_every})\n self.scaler.step(self.optim)\n self.scaler.update()\n self.optim.zero_grad()\n # update discriminator\n if exists(self.vae.discr):\n discr_loss = 0\n for _ in range(self.grad_accum_every):\n img = next(self.dl)\n img = img.to(device)\n with autocast(enabled = self.amp):\n loss = self.vae(img, return_discr_loss = True)\n self.discr_scaler.scale(loss / self.grad_accum_every).backward()\n accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})\n self.discr_scaler.step(self.discr_optim)\n self.discr_scaler.update()\n self.discr_optim.zero_grad()"
+ },
+ {
+ "comment": "This code snippet logs the VAE and discriminator losses, updates the exponential moving average (EMA) generator model, saves models every save_results_every steps, and generates and saves reconstruction images for training.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/vqgan_vae_trainer.py\":222-250",
+ "content": " # log\n print(f\"{steps}: vae loss: {logs['loss']} - discr loss: {logs['discr_loss']}\")\n # update exponential moving averaged generator\n self.ema_vae.update()\n # sample results every so often\n if not (steps % self.save_results_every):\n for model, filename in ((self.ema_vae.ema_model, f'{steps}.ema'), (self.vae, str(steps))):\n model.eval()\n imgs = next(self.dl)\n imgs = imgs.to(device)\n recons = model(imgs)\n nrows = int(sqrt(self.batch_size))\n imgs_and_recons = torch.stack((imgs, recons), dim = 0)\n imgs_and_recons = rearrange(imgs_and_recons, 'r b ... -> (b r) ...')\n imgs_and_recons = imgs_and_recons.detach().cpu().float().clamp(0., 1.)\n grid = make_grid(imgs_and_recons, nrow = 2, normalize = True, value_range = (0, 1))\n logs['reconstructions'] = grid\n save_image(grid, str(self.results_folder / f'{filename}.png'))"
+ },
+ {
+ "comment": "Saves the VAE model and EMA-VAE model periodically during training, tracking progress in specified results folder.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/vqgan_vae_trainer.py\":252-277",
+ "content": " print(f'{steps}: saving to {str(self.results_folder)}')\n # save model every so often\n if not (steps % self.save_model_every):\n state_dict = self.vae.state_dict()\n model_path = str(self.results_folder / f'vae.{steps}.pt')\n torch.save(state_dict, model_path)\n ema_state_dict = self.ema_vae.state_dict()\n model_path = str(self.results_folder / f'vae.{steps}.ema.pt')\n torch.save(ema_state_dict, model_path)\n print(f'{steps}: saving model to {str(self.results_folder)}')\n self.steps += 1\n return logs\n def train(self, log_fn = noop):\n device = next(self.vae.parameters()).device\n while self.steps < self.num_train_steps:\n logs = self.train_step()\n log_fn(logs)\n print('training complete')"
+ }
+ ]
+}
\ No newline at end of file
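
Assuming a `VQGanVAE` has already been constructed (its constructor arguments are outside the scope of this file), a training run could be wired up roughly as in the sketch below; the image folder path is a placeholder.

```python
from dalle2_pytorch.vqgan_vae import VQGanVAE
from dalle2_pytorch.vqgan_vae_trainer import VQGanVAETrainer

vae = VQGanVAE()  # assumes the default constructor arguments are acceptable for your use case

trainer = VQGanVAETrainer(
    vae,
    num_train_steps = 10_000,
    lr = 3e-4,
    batch_size = 4,
    folder = './path/to/images',   # placeholder: a folder of jpg / jpeg / png files
    grad_accum_every = 1,
    amp = False
)

trainer.train()  # loops train_step() until num_train_steps, printing losses and saving checkpoints
```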
diff --git a/docs/doc/7962afc4-93b8-41b6-8917-113b94d2693e.json b/docs/doc/7962afc4-93b8-41b6-8917-113b94d2693e.json
new file mode 100644
index 00000000..86ff341d
--- /dev/null
+++ b/docs/doc/7962afc4-93b8-41b6-8917-113b94d2693e.json
@@ -0,0 +1,115 @@
+{
+ "summary": "This code trains a Diffusion Prior model using PyTorch and DALLE2-pytorch library, with functions for creating the model, training, data loading, acceleration, evaluation, text-image similarity comparison, backpropagation, logging, saving best models, measuring speed, resetting validation timers, handling errors, saving models, and initializing training with data loaders and HFA setup.",
+ "details": [
+ {
+ "comment": "This code is for training a Diffusion Prior model using PyTorch and the DALLE2-pytorch library. It defines functions to create the model, configure the training process, and load data. The cosine similarity function is used for comparison, and there are helper functions to check if values exist and if they fall within specified bounds. The code also uses accelerate for efficient training and allows for device specification (CPU or GPU) and an optional accelerator instance for further optimization.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_diffusion_prior.py\":0-44",
+ "content": "import click\nimport torch\nfrom torch import nn\nfrom typing import List\nfrom accelerate import Accelerator\nfrom accelerate.utils import set_seed\nfrom torch.utils.data import DataLoader\nfrom embedding_reader import EmbeddingReader\nfrom accelerate.utils import dataclasses as accelerate_dataclasses\nfrom dalle2_pytorch.utils import Timer\nfrom dalle2_pytorch.trackers import Tracker\nfrom dalle2_pytorch import DiffusionPriorTrainer\nfrom dalle2_pytorch.dataloaders import get_reader, make_splits\nfrom dalle2_pytorch.train_configs import (\n DiffusionPriorConfig,\n DiffusionPriorTrainConfig,\n TrainDiffusionPriorConfig,\n)\n# helpers\ncos = nn.CosineSimilarity(dim=1, eps=1e-6)\ndef exists(val):\n return val is not None\ndef all_between(values: list, lower_bound, upper_bound):\n for value in values:\n if value < lower_bound or value > upper_bound:\n return False\n return True\ndef make_model(\n prior_config: DiffusionPriorConfig,\n train_config: DiffusionPriorTrainConfig,\n device: str = None,\n accelerator: Accelerator = None,"
+ },
+ {
+ "comment": "This code defines a function `create_trainer` that takes in a `prior_config`, and creates a `DiffusionPriorTrainer` object with specified parameters. It also defines the `create_tracker` function, which creates a `Tracker` object based on the provided configuration. The functions return the created objects.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_diffusion_prior.py\":45-82",
+ "content": "):\n # create model from config\n diffusion_prior = prior_config.create()\n # instantiate the trainer\n trainer = DiffusionPriorTrainer(\n diffusion_prior=diffusion_prior,\n lr=train_config.lr,\n wd=train_config.wd,\n max_grad_norm=train_config.max_grad_norm,\n amp=train_config.amp,\n use_ema=train_config.use_ema,\n device=device,\n accelerator=accelerator,\n warmup_steps=train_config.warmup_steps,\n )\n return trainer\ndef create_tracker(\n accelerator: Accelerator,\n config: TrainDiffusionPriorConfig,\n config_path: str,\n dummy: bool = False,\n) -> Tracker:\n tracker_config = config.tracker\n accelerator_config = {\n \"Distributed\": accelerator.distributed_type\n != accelerate_dataclasses.DistributedType.NO,\n \"DistributedType\": accelerator.distributed_type,\n \"NumProcesses\": accelerator.num_processes,\n \"MixedPrecision\": accelerator.mixed_precision,\n }\n tracker: Tracker = tracker_config.create(\n config, accelerator_config, dummy_mode=dummy"
+ },
+ {
+ "comment": "This function pads a value or tensor across all processes, gathers them and reduces them to a single average. It works with tensors of type \"mean\", \"sum\", \"max\", and \"min\". If the resulting tensor is empty, it returns None. It first waits for everyone to arrive before gathering, converts the input to a tensor if it's not already, and ensures that the tensor is on the proper device.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_diffusion_prior.py\":83-121",
+ "content": " )\n tracker.save_config(config_path, config_name=\"prior_config.json\")\n return tracker\ndef pad_gather_reduce(trainer: DiffusionPriorTrainer, x, method=\"mean\"):\n \"\"\"\n pad a value or tensor across all processes and gather\n params:\n - trainer: a trainer that carries an accelerator object\n - x: a number or torch tensor to reduce\n - method: \"mean\", \"sum\", \"max\", \"min\"\n return:\n - the average tensor after maskin out 0's\n - None if the gather resulted in an empty tensor\n \"\"\"\n assert method in [\n \"mean\",\n \"sum\",\n \"max\",\n \"min\",\n ], \"This function has limited capabilities [sum, mean, max, min]\"\n assert type(x) is not None, \"Cannot reduce a None type object\"\n # wait for everyone to arrive here before gathering\n if type(x) is not torch.Tensor:\n x = torch.tensor([x])\n # verify that the tensor is on the proper device\n x = x.to(trainer.device)\n # pad across processes\n padded_x = trainer.accelerator.pad_across_processes(x, dim=0)"
+ },
+ {
+ "comment": "The code gathers tensor data across all processes, masks out zeros, and handles empty tensors. It then calculates the mean, sum, maximum, or minimum of the masked tensor depending on the method specified. The save_trainer function logs the model with an appropriate method based on the tracker.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_diffusion_prior.py\":123-159",
+ "content": " # gather across all procesess\n gathered_x = trainer.accelerator.gather(padded_x)\n # mask out zeros\n masked_x = gathered_x[gathered_x != 0]\n # if the tensor is empty, warn and return None\n if len(masked_x) == 0:\n click.secho(\n f\"The call to this method resulted in an empty tensor after masking out zeros. The gathered tensor was this: {gathered_x} and the original value passed was: {x}.\",\n fg=\"red\",\n )\n return None\n if method == \"mean\":\n return torch.mean(masked_x)\n elif method == \"sum\":\n return torch.sum(masked_x)\n elif method == \"max\":\n return torch.max(masked_x)\n elif method == \"min\":\n return torch.min(masked_x)\ndef save_trainer(\n tracker: Tracker,\n trainer: DiffusionPriorTrainer,\n is_latest: bool,\n is_best: bool,\n epoch: int,\n samples_seen: int,\n best_validation_loss: float,\n):\n \"\"\"\n Logs the model with an appropriate method depending on the tracker\n \"\"\"\n trainer.accelerator.wait_for_everyone()"
+ },
+ {
+ "comment": "This code is part of a model training process. It saves the model at certain intervals and loads it later depending on the tracker type. The save function reports whether the saved model is best or latest, and the recall_trainer function loads the model with an appropriate method based on the tracker's loader type. Additionally, there are functions for evaluating validation loss.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_diffusion_prior.py\":161-200",
+ "content": " if trainer.accelerator.is_main_process:\n click.secho(\n f\"RANK:{trainer.accelerator.process_index} | Saving Model | Best={is_best} | Latest={is_latest}\",\n fg=\"magenta\",\n )\n tracker.save(\n trainer=trainer,\n is_best=is_best,\n is_latest=is_latest,\n epoch=int(epoch),\n samples_seen=int(samples_seen),\n best_validation_loss=best_validation_loss,\n )\ndef recall_trainer(tracker: Tracker, trainer: DiffusionPriorTrainer):\n \"\"\"\n Loads the model with an appropriate method depending on the tracker\n \"\"\"\n if trainer.accelerator.is_main_process:\n click.secho(f\"Loading model from {type(tracker.loader).__name__}\", fg=\"yellow\")\n state_dict = tracker.recall()\n trainer.load(state_dict, strict=True)\n return (\n int(state_dict.get(\"epoch\", 0)),\n state_dict.get(\"best_validation_loss\", 0),\n int(state_dict.get(\"samples_seen\", 0)),\n )\n# eval functions\ndef report_validation_loss(\n trainer: DiffusionPriorTrainer,"
+ },
+ {
+ "comment": "This code measures validation loss on a given data subset, using an optional EMA model and text conditioning. It iterates through a dataloader, computes losses for each batch, accumulates them in total_loss, and finally returns the average loss. The progress is echoed if the process is the main one.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_diffusion_prior.py\":201-238",
+ "content": " dataloader: DataLoader,\n text_conditioned: bool,\n use_ema: bool,\n tracker: Tracker,\n split: str,\n tracker_folder: str,\n loss_type: str,\n):\n \"\"\"\n Compute the validation loss on a given subset of data.\n \"\"\"\n if trainer.accelerator.is_main_process:\n click.secho(\n f\"Measuring performance on {use_ema}-{split} split\",\n fg=\"green\",\n blink=True,\n )\n total_loss = torch.zeros(1, dtype=torch.float, device=trainer.device)\n for image_embeddings, text_data in dataloader:\n image_embeddings = image_embeddings.to(trainer.device)\n text_data = text_data.to(trainer.device)\n input_args = dict(image_embed=image_embeddings)\n if text_conditioned:\n input_args = dict(**input_args, text=text_data)\n else:\n input_args = dict(**input_args, text_embed=text_data)\n if use_ema:\n loss = trainer.ema_diffusion_prior(**input_args)\n else:\n loss = trainer(**input_args)\n total_loss += loss"
+ },
+ {
+ "comment": "This code measures the cosine similarity on a given split with specified timesteps. It first sets the trainer to evaluation mode and then iterates through each batch of data from the dataloader. Within this loop, it moves both test image embeddings and text data to the device used by the trainer. If the model is text-conditioned, it generates an embedding from the tokenized text using the `embed_text` function provided by the trainer. This information can be useful for understanding how this code measures cosine similarity in a given context.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_diffusion_prior.py\":240-274",
+ "content": " # compute the average loss across all processes\n avg_loss = pad_gather_reduce(trainer, total_loss, method=\"mean\")\n stats = {f\"{tracker_folder}/{loss_type}-loss\": avg_loss}\n # print and log results on main process\n tracker.log(stats, step=trainer.step.item() + 1)\n return avg_loss\ndef report_cosine_sims(\n trainer: DiffusionPriorTrainer,\n dataloader: DataLoader,\n text_conditioned: bool,\n tracker: Tracker,\n split: str,\n timesteps: int,\n tracker_folder: str,\n):\n trainer.eval()\n if trainer.accelerator.is_main_process:\n click.secho(\n f\"Measuring Cosine-Similarity on {split} split with {timesteps} timesteps\",\n fg=\"green\",\n blink=True,\n )\n for test_image_embeddings, text_data in dataloader:\n test_image_embeddings = test_image_embeddings.to(trainer.device)\n text_data = text_data.to(trainer.device)\n # we are text conditioned, we produce an embedding from the tokenized text\n if text_conditioned:\n text_embedding, text_encodings = trainer.embed_text(text_data)"
+ },
+ {
+ "comment": "This code shuffles text embeddings and encodings to simulate \"unrelated\" captions for training the diffusion model. If text-conditioned, it also shuffles the text condition. It prepares both text and image embeddings.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_diffusion_prior.py\":275-302",
+ "content": " text_cond = dict(text_embed=text_embedding, text_encodings=text_encodings)\n else:\n text_embedding = text_data\n text_cond = dict(text_embed=text_embedding)\n # make a copy of the text embeddings for shuffling\n text_embed_shuffled = text_embedding.clone()\n # roll the text to simulate \"unrelated\" captions\n rolled_idx = torch.roll(torch.arange(text_embedding.shape[0]), 1)\n text_embed_shuffled = text_embed_shuffled[rolled_idx]\n text_embed_shuffled = text_embed_shuffled / text_embed_shuffled.norm(\n dim=1, keepdim=True\n )\n if text_conditioned:\n text_encodings_shuffled = text_encodings[rolled_idx]\n else:\n text_encodings_shuffled = None\n text_cond_shuffled = dict(\n text_embed=text_embed_shuffled, text_encodings=text_encodings_shuffled\n )\n # prepare the text embedding\n text_embed = text_embedding / text_embedding.norm(dim=1, keepdim=True)\n # prepare image embeddings"
+ },
+ {
+ "comment": "This code calculates the similarity between text embeddings and image embeddings, then shuffles the text embeddings to create unrelated pairs. It uses diffusion models for prediction and normalizes the embeddings. The final step is calculating the similarities using cosine similarity and mean reduction method.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_diffusion_prior.py\":303-333",
+ "content": " test_image_embeddings = test_image_embeddings / test_image_embeddings.norm(\n dim=1, keepdim=True\n )\n # predict on the unshuffled text embeddings\n predicted_image_embeddings = trainer.p_sample_loop(\n test_image_embeddings.shape,\n text_cond,\n timesteps=timesteps,\n )\n predicted_image_embeddings = (\n predicted_image_embeddings\n / predicted_image_embeddings.norm(dim=1, keepdim=True)\n )\n # predict on the shuffled embeddings\n predicted_unrelated_embeddings = trainer.p_sample_loop(\n test_image_embeddings.shape,\n text_cond_shuffled,\n timesteps=timesteps,\n )\n predicted_unrelated_embeddings = (\n predicted_unrelated_embeddings\n / predicted_unrelated_embeddings.norm(dim=1, keepdim=True)\n )\n # calculate similarities\n orig_sim = pad_gather_reduce(\n trainer, cos(text_embed, test_image_embeddings), method=\"mean\""
+ },
+ {
+ "comment": "This code calculates similarity scores between embeddings of text, predicted images, and original images. It then logs these scores for various steps in the training process to track progress.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_diffusion_prior.py\":334-359",
+ "content": " )\n pred_sim = pad_gather_reduce(\n trainer, cos(text_embed, predicted_image_embeddings), method=\"mean\"\n )\n unrel_sim = pad_gather_reduce(\n trainer, cos(text_embed, predicted_unrelated_embeddings), method=\"mean\"\n )\n pred_img_sim = pad_gather_reduce(\n trainer,\n cos(test_image_embeddings, predicted_image_embeddings),\n method=\"mean\",\n )\n stats = {\n f\"{tracker_folder}/baseline similarity [steps={timesteps}]\": orig_sim,\n f\"{tracker_folder}/similarity with text [steps={timesteps}]\": pred_sim,\n f\"{tracker_folder}/similarity with original image [steps={timesteps}]\": pred_img_sim,\n f\"{tracker_folder}/similarity with unrelated caption [steps={timesteps}]\": unrel_sim,\n f\"{tracker_folder}/difference from baseline similarity [steps={timesteps}]\": pred_sim\n - orig_sim,\n }\n tracker.log(stats, step=trainer.step.item() + 1)\ndef eval_model("
+ },
+ {
+ "comment": "This function runs evaluation on a model, tracks metrics, and returns the loss if requested. It uses DiffusionPriorTrainer and DataLoader. The use_ema parameter is used to differentiate between an Exponential Moving Average (EMA) model and an online (current) model. It checks whether the timesteps are valid for the model's noise scheduler. It also measures cosine metrics across various eta and timesteps if report_cosine is set to True.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_diffusion_prior.py\":360-397",
+ "content": " trainer: DiffusionPriorTrainer,\n dataloader: DataLoader,\n text_conditioned: bool,\n split: str,\n tracker: Tracker,\n use_ema: bool,\n report_cosine: bool,\n report_loss: bool,\n timesteps: List[int],\n loss_type: str = None,\n):\n \"\"\"\n Run evaluation on a model and track metrics\n returns: loss if requested\n \"\"\"\n trainer.eval()\n use_ema = \"ema\" if use_ema else \"online\"\n tracker_folder = f\"metrics/{use_ema}-{split}\"\n # detemine if valid timesteps are passed\n min_timesteps = trainer.accelerator.unwrap_model(\n trainer.diffusion_prior\n ).sample_timesteps\n max_timesteps = trainer.accelerator.unwrap_model(\n trainer.diffusion_prior\n ).noise_scheduler.num_timesteps\n assert all_between(\n timesteps, lower_bound=min_timesteps, upper_bound=max_timesteps\n ), f\"all timesteps values must be between {min_timesteps} and {max_timesteps}: got {timesteps}\"\n # measure cosine metrics across various eta and timesteps\n if report_cosine:\n for timestep in timesteps:"
+ },
+ {
+ "comment": "This code measures cosine similarity on a separate dataset and reports the loss on another split of data in a training script. It also initializes timers for saving, measuring samples per second, and tracking validation time.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_diffusion_prior.py\":398-439",
+ "content": " report_cosine_sims(\n trainer,\n dataloader=dataloader,\n text_conditioned=text_conditioned,\n tracker=tracker,\n split=split,\n timesteps=timestep,\n tracker_folder=tracker_folder,\n )\n # measure loss on a seperate split of data\n if report_loss:\n loss = report_validation_loss(\n trainer=trainer,\n dataloader=dataloader,\n text_conditioned=text_conditioned,\n use_ema=use_ema,\n tracker=tracker,\n split=split,\n tracker_folder=tracker_folder,\n loss_type=loss_type,\n )\n return loss\n# training script\ndef train(\n trainer: DiffusionPriorTrainer,\n tracker: Tracker,\n train_loader: DataLoader,\n eval_loader: DataLoader,\n test_loader: DataLoader,\n config: DiffusionPriorTrainConfig,\n):\n # init timers\n save_timer = Timer() # when to save\n samples_timer = Timer() # samples/sec\n validation_profiler = Timer() # how long is validation taking"
+ },
+ {
+ "comment": "The code sets up a training loop that iterates over epochs and resets the dataloader if it was paused mid-epoch. It places data on the device, tracks the best validation loss, and keeps track of samples seen.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_diffusion_prior.py\":440-470",
+ "content": " validation_countdown = Timer() # when to perform evalutation\n # keep track of best validation loss\n best_validation_loss = config.train.best_validation_loss\n samples_seen = config.train.num_samples_seen\n # do training\n start_epoch = config.train.current_epoch\n for epoch in range(start_epoch, config.train.epochs):\n # if we finished out an old epoch, reset the distribution to be a full epoch\n tracker.log({\"tracking/epoch\": epoch}, step=trainer.step.item())\n if train_loader.dataset.get_start() > 0 and epoch == start_epoch+1:\n if trainer.accelerator.is_main_process:\n click.secho(f\"Finished resumed epoch...resetting dataloader.\")\n train_loader.dataset.set_start(0)\n for img, txt in train_loader:\n # setup things every step\n trainer.train()\n current_step = trainer.step.item()\n samples_timer.reset()\n # place data on device\n img = img.to(trainer.device)\n txt = txt.to(trainer.device)"
+ },
+ {
+ "comment": "This code is performing backpropagation, updating the exponential moving average (EMA), logging training metrics, and tracking evaluation intervals. It calculates the loss from text and image embeddings using the trainer model and updates the EMA diffusion prior. Metrics like samples per second, number of samples seen, EMA decay, and a specific loss type are logged at each step, while evaluating the validation countdown time interval for metrics tracking.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_diffusion_prior.py\":472-503",
+ "content": " # pass to model\n loss = trainer(text=txt, image_embed=img)\n # perform backprop & apply EMA updates\n trainer.update()\n # gather info about training step\n all_loss = pad_gather_reduce(trainer, loss, method=\"mean\")\n num_samples = pad_gather_reduce(trainer, len(txt), method=\"sum\")\n samples_per_sec = num_samples / samples_timer.elapsed()\n samples_seen += num_samples\n ema_decay = trainer.ema_diffusion_prior.get_current_decay()\n # log\n tracker.log(\n {\n \"tracking/samples-sec\": samples_per_sec,\n \"tracking/samples-seen\": samples_seen,\n \"tracking/ema-decay\": ema_decay,\n f\"tracking/training-{config.prior.loss_type}\": all_loss,\n },\n step=current_step,\n )\n # Metric Tracking @ Timed Intervals\n eval_delta = pad_gather_reduce(\n trainer, validation_countdown.elapsed(), method=\"min\""
+ },
+ {
+ "comment": "This code is evaluating the model on validation data with specified options. It checks if it's time to evaluate, resets the profiler for timing, packages evaluation kwargs, and calls eval_model function with dataloader, loss type, split (validation), use_ema, report_cosine, report_loss, and eval_kwargs. It also evaluates the ema model separately.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_diffusion_prior.py\":504-535",
+ "content": " )\n if eval_delta != None and eval_delta > config.data.eval_every_seconds:\n # begin timing how long this takes\n validation_profiler.reset()\n # package kwargs for evaluation\n eval_kwargs = {\n \"trainer\": trainer,\n \"tracker\": tracker,\n \"text_conditioned\": config.prior.condition_on_text_encodings,\n \"timesteps\": config.train.eval_timesteps,\n }\n # ONLINE MODEL : COSINE : LOSS : VALIDATION SPLIT\n eval_model(\n dataloader=eval_loader,\n loss_type=config.prior.loss_type,\n split=\"validation\",\n use_ema=False,\n report_cosine=False,\n report_loss=True,\n **eval_kwargs,\n )\n # EMA MODEL : COSINE : LOSS : VALIDATION DATA\n ema_val_loss = eval_model(\n dataloader=eval_loader,"
+ },
+ {
+ "comment": "In this code, a validation process is executed using ema (exponential moving average) to calculate the loss. The lowest ema validation loss seen so far is stored in `best_validation_loss` and if the current validation loss is lower than the previous best, the model is saved as the 'best' model. This code also logs the time taken for the validation process.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_diffusion_prior.py\":536-565",
+ "content": " loss_type=config.prior.loss_type,\n split=\"validation\",\n use_ema=True,\n report_cosine=True,\n report_loss=True,\n **eval_kwargs,\n )\n tracker.log(\n {\n \"tracking/validation length (minutes)\": validation_profiler.elapsed()\n / 60\n }\n )\n # check if the ema validation is the lowest seen yet\n if ema_val_loss < best_validation_loss:\n best_validation_loss = ema_val_loss\n # go save the model as best\n save_trainer(\n trainer=trainer,\n tracker=tracker,\n is_best=True,\n is_latest=False,\n samples_seen=samples_seen,\n epoch=epoch,\n best_validation_loss=best_validation_loss,"
+ },
+ {
+ "comment": "This code segment resets the validation timer and handles errors in reading eval and save times. It saves the latest model if the elapsed time meets a certain condition, and resets the save timer. This helps keep track of the training progress and ensures timely saving of models for later use.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_diffusion_prior.py\":566-597",
+ "content": " )\n # reset timer for validaiton\n validation_countdown.reset()\n elif eval_delta is None:\n click.secho(\n f\"Error occured reading the eval time on rank: {trainer.device}\",\n fg=\"yellow\",\n )\n # save as latest model on schedule\n save_delta = pad_gather_reduce(trainer, save_timer.elapsed(), method=\"min\")\n if save_delta != None and save_delta >= config.train.save_every_seconds:\n save_trainer(\n trainer=trainer,\n tracker=tracker,\n is_best=False,\n is_latest=True,\n samples_seen=samples_seen,\n epoch=epoch,\n best_validation_loss=best_validation_loss,\n )\n save_timer.reset()\n elif save_delta is None:\n click.secho(\n f\"Error occured reading the save time on rank: {trainer.device}\","
+ },
+ {
+ "comment": "Starting test phase and saving the last model as latest before validation. If test loss is lower than previous best validation loss, it will be saved as the new best model.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_diffusion_prior.py\":598-639",
+ "content": " fg=\"yellow\",\n )\n # evaluate on test data\n if trainer.accelerator.is_main_process:\n click.secho(f\"Starting Test\", fg=\"red\")\n # save one last time as latest before beginning validation\n save_trainer(\n tracker=tracker,\n trainer=trainer,\n is_best=False,\n is_latest=True,\n samples_seen=samples_seen,\n epoch=epoch,\n best_validation_loss=best_validation_loss,\n )\n test_loss = eval_model(\n trainer=trainer,\n dataloader=test_loader,\n text_conditioned=config.prior.condition_on_text_encodings,\n split=\"test\",\n tracker=tracker,\n use_ema=True,\n report_cosine=False,\n report_loss=True,\n timesteps=config.train.eval_timesteps,\n loss_type=config.prior.loss_type,\n )\n if test_loss < best_validation_loss:\n best_validation_loss = test_loss\n # go save the model as best\n save_trainer(\n trainer=trainer,\n tracker=tracker,\n is_best=True,"
+ },
+ {
+ "comment": "The function initialize_training is responsible for loading the configuration file, setting the seed, getting a device, making the trainer, and creating a tracker. The trainer is automatically distributed if possible and configured. Additionally, the function checks whether it can recall from a checkpoint.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_diffusion_prior.py\":640-680",
+ "content": " is_latest=False,\n samples_seen=samples_seen,\n epoch=epoch,\n best_validation_loss=test_loss,\n )\ndef initialize_training(config_file, accelerator):\n \"\"\"\n Parse the configuration file, and prepare everything necessary for training\n \"\"\"\n # load the configuration file\n if accelerator.is_main_process:\n click.secho(f\"Loading configuration from {config_file}\", fg=\"green\")\n config = TrainDiffusionPriorConfig.from_json_path(config_file)\n # seed\n set_seed(config.train.random_seed)\n # get a device\n device = accelerator.device\n # make the trainer (will automatically distribute if possible & configured)\n trainer: DiffusionPriorTrainer = make_model(\n config.prior, config.train, device, accelerator\n ).to(device)\n # create a tracker\n tracker = create_tracker(\n accelerator, config, config_file, dummy=accelerator.process_index != 0\n )\n # reload from chcekpoint\n if tracker.can_recall:\n current_epoch, best_validation_loss, samples_seen = recall_trainer("
+ },
+ {
+ "comment": "This code block displays the current epoch, best validation loss, and samples seen, updates configuration with recalled values, fetches and prepares data for training by creating a loader, and calculates the start point within the epoch.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_diffusion_prior.py\":681-710",
+ "content": " tracker=tracker, trainer=trainer\n )\n # display best values\n if trainer.accelerator.is_main_process:\n click.secho(f\"Current Epoch: {current_epoch} | Best Val Loss: {best_validation_loss} | Samples Seen: {samples_seen}\", fg=\"yellow\")\n # update config to reflect recalled values\n config.train.num_samples_seen = samples_seen\n config.train.current_epoch = current_epoch\n config.train.best_validation_loss = best_validation_loss\n # fetch and prepare data\n if trainer.accelerator.is_main_process:\n click.secho(\"Grabbing data...\", fg=\"blue\", blink=True)\n trainer.accelerator.wait_for_everyone()\n img_reader = get_reader(\n text_conditioned=trainer.text_conditioned,\n img_url=config.data.image_url,\n meta_url=config.data.meta_url,\n )\n # calculate start point within epoch\n trainer.accelerator.wait_for_everyone()\n train_loader, eval_loader, test_loader = make_splits(\n text_conditioned=trainer.text_conditioned,"
+ },
+ {
+ "comment": "This code initializes a data loader and sets the start point for resuming training if necessary. It ensures that the training continues from where it left off in a previous run by adjusting the number of samples seen based on the total number of data points and the current epoch. The main process prints a message indicating the resumption sample count.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_diffusion_prior.py\":711-742",
+ "content": " batch_size=config.data.batch_size,\n num_data_points=config.data.num_data_points,\n train_split=config.data.splits.train,\n eval_split=config.data.splits.val,\n image_reader=img_reader,\n rank=accelerator.state.process_index,\n world_size=accelerator.state.num_processes,\n start=0,\n )\n # update the start point to finish out the epoch on a resumed run\n if tracker.can_recall:\n samples_seen = config.train.num_samples_seen\n length = (\n config.data.num_data_points\n if samples_seen <= img_reader.count\n else img_reader.count\n )\n scaled_samples = length * config.train.current_epoch\n start_point = (\n scaled_samples - samples_seen if scaled_samples > samples_seen else samples_seen\n )\n if trainer.accelerator.is_main_process:\n click.secho(f\"Resuming at sample: {start_point}\", fg=\"yellow\")\n train_loader.dataset.set_start(start_point)\n # start training\n if trainer.accelerator.is_main_process:"
+ },
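The resume arithmetic above is easier to follow with concrete numbers. The sketch below simply replays the formula from the snippet with made-up values for the dataset size, recalled epoch, and samples seen; none of these numbers come from a real run.

```python
# Hypothetical values, purely for illustration.
num_data_points = 1_000_000   # configured dataset size
reader_count = 1_000_000      # rows actually reported by the reader
current_epoch = 2             # epoch recalled from the checkpoint
samples_seen = 1_750_000      # total samples recalled from the checkpoint

# Mirror of the start-point calculation in the snippet above.
length = num_data_points if samples_seen <= reader_count else reader_count
scaled_samples = length * current_epoch
start_point = (
    scaled_samples - samples_seen if scaled_samples > samples_seen else samples_seen
)
print(start_point)  # 250000 -> resume 250k samples into the current epoch
```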
+ {
+ "comment": "Beginning Prior Training message with distributed status. Then, initiates training process using provided configurations and loaders for trainer, tracker, train_loader, eval_loader, and test_loader. Finally, executes main function with the specified config file to start Heterogeneous Fusion Acceleration (HFA) and set up the training environment.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_diffusion_prior.py\":743-769",
+ "content": " click.secho(\n f\"Beginning Prior Training : Distributed={accelerator.state.distributed_type != accelerate_dataclasses.DistributedType.NO}\",\n fg=\"yellow\",\n )\n train(\n trainer=trainer,\n tracker=tracker,\n train_loader=train_loader,\n eval_loader=eval_loader,\n test_loader=test_loader,\n config=config,\n )\n@click.command()\n@click.option(\"--config_file\", default=\"configs/train_prior_config.example.json\")\ndef main(config_file):\n # start HFA\n accelerator = Accelerator()\n # setup training\n initialize_training(config_file, accelerator)\nif __name__ == \"__main__\":\n main()"
+ }
+ ]
+}
\ No newline at end of file
diff --git a/docs/doc/7af9c850-6b81-4f7b-abf7-5a2656e9ef3a.json b/docs/doc/7af9c850-6b81-4f7b-abf7-5a2656e9ef3a.json
new file mode 100644
index 00000000..e785d23b
--- /dev/null
+++ b/docs/doc/7af9c850-6b81-4f7b-abf7-5a2656e9ef3a.json
@@ -0,0 +1,15 @@
+{
+ "summary": "This code defines a Dataset class and get_images_dataloader function for loading image data. The Dataset class initializes with a folder path, image size, and extensions to consider. The get_images_dataloader function returns a DataLoader object for the specified folder with optional parameters like batch size, shuffle, cycle_dl, and pin_memory.",
+ "details": [
+ {
+ "comment": "This code defines a Dataset class and get_images_dataloader function for loading image data. The Dataset class initializes with a folder path, image size, and extensions to consider. It uses transforms to apply resizing, horizontal flipping, centercropping, and converting images to tensors. The get_images_dataloader function returns a data loader object for the specified folder.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dataloaders/simple_image_only_dataloader.py\":0-46",
+ "content": "from pathlib import Path\nimport torch\nfrom torch.utils import data\nfrom torchvision import transforms, utils\nfrom PIL import Image\n# helpers functions\ndef cycle(dl):\n while True:\n for data in dl:\n yield data\n# dataset and dataloader\nclass Dataset(data.Dataset):\n def __init__(\n self,\n folder,\n image_size,\n exts = ['jpg', 'jpeg', 'png']\n ):\n super().__init__()\n self.folder = folder\n self.image_size = image_size\n self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]\n self.transform = transforms.Compose([\n transforms.Resize(image_size),\n transforms.RandomHorizontalFlip(),\n transforms.CenterCrop(image_size),\n transforms.ToTensor()\n ])\n def __len__(self):\n return len(self.paths)\n def __getitem__(self, index):\n path = self.paths[index]\n img = Image.open(path)\n return self.transform(img)\ndef get_images_dataloader(\n folder,\n *,"
+ },
+ {
+ "comment": "This function takes parameters such as folder, batch size, image size, shuffle, cycle_dl, and pin_memory. It creates a dataset from the provided folder using a given image size. Then, it uses DataLoader to create a data loader with the specified batch size, shuffle, and pin memory settings. If cycle_dl is True, it applies cyclic permutations to the data loader. Finally, it returns the data loader.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dataloaders/simple_image_only_dataloader.py\":47-58",
+ "content": " batch_size,\n image_size,\n shuffle = True,\n cycle_dl = True,\n pin_memory = True\n):\n ds = Dataset(folder, image_size)\n dl = data.DataLoader(ds, batch_size = batch_size, shuffle = shuffle, pin_memory = pin_memory)\n if cycle_dl:\n dl = cycle(dl)\n return dl"
+ }
+ ]
+}
\ No newline at end of file
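The `cycle` helper in this file turns a finite DataLoader into an endless stream, which suits step-based rather than epoch-based training loops. Below is a minimal, self-contained sketch of the same pattern using a toy TensorDataset in place of an image folder; the sizes and names are illustrative only.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def cycle(dl):
    # Restart the loader whenever it is exhausted so batches never run out.
    while True:
        for batch in dl:
            yield batch

ds = TensorDataset(torch.randn(10, 3, 8, 8))   # toy stand-in for an image folder
dl = cycle(DataLoader(ds, batch_size=4, shuffle=True))

for step in range(7):                          # more steps than one pass over the data
    (images,) = next(dl)
    print(step, images.shape)
```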
diff --git a/docs/doc/7c49687b-bcf8-461f-9ea3-5406b2617359.json b/docs/doc/7c49687b-bcf8-461f-9ea3-5406b2617359.json
new file mode 100644
index 00000000..2665af9c
--- /dev/null
+++ b/docs/doc/7c49687b-bcf8-461f-9ea3-5406b2617359.json
@@ -0,0 +1,130 @@
+{
+ "summary": "The code initializes DeepSpeed's trainer, sets model parameters, distributes the model, and handles precision. It also initializes optimizers and schedulers, prepares dataloaders, validates compatibility, performs computations, and returns total loss.",
+ "details": [
+ {
+ "comment": "The code imports various libraries and defines several utility functions for working with tensors, learning rates, optimizers, and distributed training. It also includes helper functions to handle default values and handle dictionaries. These utilities are likely used throughout the codebase to train and evaluate models efficiently.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trainer.py\":0-42",
+ "content": "import time\nimport copy\nfrom pathlib import Path\nfrom math import ceil\nfrom functools import partial, wraps\nfrom contextlib import nullcontext\nfrom collections.abc import Iterable\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\nfrom torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR\nfrom torch.cuda.amp import autocast, GradScaler\nfrom dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior\nfrom dalle2_pytorch.optimizer import get_optimizer\nfrom dalle2_pytorch.version import __version__\nfrom packaging import version\nimport pytorch_warmup as warmup\nfrom ema_pytorch import EMA\nfrom accelerate import Accelerator, DistributedType\nimport numpy as np\n# helper functions\ndef exists(val):\n return val is not None\ndef default(val, d):\n if exists(val):\n return val\n return d() if callable(d) else d\ndef cast_tuple(val, length = 1):\n return val if isinstance(val, tuple) else ((val,) * length)\ndef pick_and_pop(keys, d):\n values = list(map(lambda key: d.pop(key), keys))\n return dict(zip(keys, values))"
+ },
+ {
+ "comment": "group_dict_by_key: Creates two dictionaries, one for keys that match the condition and another for those that do not, grouping by key.\nstring_begins_with: Returns a boolean value indicating whether a given string starts with a specified prefix.\ngroup_by_key_prefix: Groups dictionary items based on whether their keys start with a certain prefix.\ngroupby_prefix_and_trim: Similar to group_by_key_prefix, but also trims the common prefix from the keys and returns two dictionaries.\nnum_to_groups: Divides a given number into groups based on a specified divisor, appending any remainder to the last group.\ncast_torch_tensor: A decorator that wraps a function to cast its input and output tensors to specific devices.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trainer.py\":44-77",
+ "content": "def group_dict_by_key(cond, d):\n return_val = [dict(),dict()]\n for key in d.keys():\n match = bool(cond(key))\n ind = int(not match)\n return_val[ind][key] = d[key]\n return (*return_val,)\ndef string_begins_with(prefix, str):\n return str.startswith(prefix)\ndef group_by_key_prefix(prefix, d):\n return group_dict_by_key(partial(string_begins_with, prefix), d)\ndef groupby_prefix_and_trim(prefix, d):\n kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)\n kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))\n return kwargs_without_prefix, kwargs\ndef num_to_groups(num, divisor):\n groups = num // divisor\n remainder = num % divisor\n arr = [divisor] * groups\n if remainder > 0:\n arr.append(remainder)\n return arr\n# decorators\ndef cast_torch_tensor(fn):\n @wraps(fn)\n def inner(model, *args, **kwargs):\n device = kwargs.pop('_device', next(model.parameters()).device)\n cast_device = kwargs.pop('_cast_device', True)"
+ },
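To make the prefix-grouping and grouping helpers concrete, the following standalone sketch re-implements `groupby_prefix_and_trim` and `num_to_groups` along the lines described above and prints their outputs; the example kwargs are invented for illustration.

```python
def group_dict_by_key(cond, d):
    # Split a dict into (matching, non-matching) according to the key predicate.
    matched, unmatched = {}, {}
    for key, value in d.items():
        (matched if cond(key) else unmatched)[key] = value
    return matched, unmatched

def groupby_prefix_and_trim(prefix, d):
    # Pull out prefixed keys and strip the prefix from them.
    with_prefix, rest = group_dict_by_key(lambda k: k.startswith(prefix), d)
    trimmed = {k[len(prefix):]: v for k, v in with_prefix.items()}
    return trimmed, rest

def num_to_groups(num, divisor):
    # Break num into full groups of size divisor plus an optional remainder group.
    groups, remainder = divmod(num, divisor)
    return [divisor] * groups + ([remainder] if remainder else [])

kwargs = {"ema_beta": 0.99, "ema_update_every": 10, "lr": 3e-4}
ema_kwargs, kwargs = groupby_prefix_and_trim("ema_", kwargs)
print(ema_kwargs)            # {'beta': 0.99, 'update_every': 10}
print(kwargs)                # {'lr': 0.0003}
print(num_to_groups(10, 4))  # [4, 4, 2]
```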
+ {
+ "comment": "This code handles argument casting and device assignment for a DeepSpeed-accelerated PyTorch model. It first checks if arguments are DeepSpeed precision types, then casts the tensors to the appropriate type if necessary. This ensures that the model's arguments are correctly prepared for training or evaluation within a DeepSpeed framework.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trainer.py\":78-98",
+ "content": " cast_deepspeed_precision = kwargs.pop('_cast_deepspeed_precision', True)\n kwargs_keys = kwargs.keys()\n all_args = (*args, *kwargs.values())\n split_kwargs_index = len(all_args) - len(kwargs_keys)\n all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))\n if cast_device:\n all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))\n if cast_deepspeed_precision:\n try:\n accelerator = model.accelerator\n if accelerator is not None and accelerator.distributed_type == DistributedType.DEEPSPEED:\n cast_type_map = {\n \"fp16\": torch.half,\n \"bf16\": torch.bfloat16,\n \"no\": torch.float\n }\n precision_type = cast_type_map[accelerator.mixed_precision]\n all_args = tuple(map(lambda t: t.to(precision_type) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))"
+ },
+ {
+ "comment": "This code defines functions for splitting arguments and keywords, as well as handling gradient accumulation. It includes a function to split an iterable into chunks of specified size (`split_iterable`), a `split` function for tensors and iterables, and a `find_first` function to find the first item in an array that meets a given condition. The last function defined is `split_args_and_kwargs`, which splits arguments and keywords based on a specified size.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trainer.py\":99-137",
+ "content": " except AttributeError:\n # Then this model doesn't have an accelerator\n pass\n args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]\n kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))\n out = fn(model, *args, **kwargs)\n return out\n return inner\n# gradient accumulation functions\ndef split_iterable(it, split_size):\n accum = []\n for ind in range(ceil(len(it) / split_size)):\n start_index = ind * split_size\n accum.append(it[start_index: (start_index + split_size)])\n return accum\ndef split(t, split_size = None):\n if not exists(split_size):\n return t\n if isinstance(t, torch.Tensor):\n return t.split(split_size, dim = 0)\n if isinstance(t, Iterable):\n return split_iterable(t, split_size)\n return TypeError\ndef find_first(cond, arr):\n for el in arr:\n if cond(el):\n return el\n return None\ndef split_args_and_kwargs(*args, split_size = None, **kwargs):"
+ },
+ {
+ "comment": "This code splits the input arguments and keyword arguments into chunks based on batch size, split size, and dictionary keys. It then yields the chunk size fraction and the split chunked arguments and keyword arguments for further processing.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trainer.py\":138-158",
+ "content": " all_args = (*args, *kwargs.values())\n len_all_args = len(all_args)\n first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)\n assert exists(first_tensor)\n batch_size = len(first_tensor)\n split_size = default(split_size, batch_size)\n num_chunks = ceil(batch_size / split_size)\n dict_len = len(kwargs)\n dict_keys = kwargs.keys()\n split_kwargs_index = len_all_args - dict_len\n split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args]\n chunk_sizes = tuple(map(len, split_all_args[0]))\n for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):\n chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:]\n chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))\n chunk_size_frac = chunk_size / batch_size\n yield chunk_size_frac, (chunked_args, chunked_kwargs)"
+ },
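The chunk generator above is what enables `max_batch_size`-style gradient accumulation: each sub-batch contributes a loss weighted by its share of the full batch, and gradients simply accumulate across the backward calls. Below is a simplified, tensor-only sketch of that weighting idea, with a toy linear model standing in for the diffusion prior.

```python
import torch
import torch.nn.functional as F
from torch import nn

model = nn.Linear(16, 1)
batch, target = torch.randn(10, 16), torch.randn(10, 1)
max_batch_size = 4

total_loss = 0.0
for chunk, chunk_target in zip(batch.split(max_batch_size), target.split(max_batch_size)):
    frac = chunk.shape[0] / batch.shape[0]             # this chunk's share of the full batch
    loss = F.mse_loss(model(chunk), chunk_target) * frac
    loss.backward()                                     # gradients accumulate across chunks
    total_loss += loss.item()
print(total_loss)
```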
+ {
+ "comment": "This code defines a `DiffusionPriorTrainer` class that takes in a `diffusion_prior`, and allows for training with different batch sizes by splitting arguments and keywords into chunks. It also supports optional accelerator, learning rate, weight decay, epsilon, max gradient norm, grouped weight decay parameters, warmup steps, and cosine decay maximum steps.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trainer.py\":160-191",
+ "content": "# diffusion prior trainer\ndef prior_sample_in_chunks(fn):\n @wraps(fn)\n def inner(self, *args, max_batch_size = None, **kwargs):\n if not exists(max_batch_size):\n return fn(self, *args, **kwargs)\n outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]\n return torch.cat(outputs, dim = 0)\n return inner\nclass DiffusionPriorTrainer(nn.Module):\n def __init__(\n self,\n diffusion_prior,\n accelerator = None,\n use_ema = True,\n lr = 3e-4,\n wd = 1e-2,\n eps = 1e-6,\n max_grad_norm = None,\n group_wd_params = True,\n warmup_steps = None,\n cosine_decay_max_steps = None,\n **kwargs\n ):\n super().__init__()\n assert isinstance(diffusion_prior, DiffusionPrior)\n ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)\n accelerator_kwargs, kwargs = groupby_prefix_and_trim('accelerator_', kwargs)"
+ },
+ {
+ "comment": "Checking if an accelerator is specified, assigning member variables for helpful operations, setting device and transferring model to that device, saving the diffusion prior model, and checking mixed precision settings if applicable.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trainer.py\":193-224",
+ "content": " if not exists(accelerator):\n accelerator = Accelerator(**accelerator_kwargs)\n # assign some helpful member vars\n self.accelerator = accelerator\n self.text_conditioned = diffusion_prior.condition_on_text_encodings\n # setting the device\n self.device = accelerator.device\n diffusion_prior.to(self.device)\n # save model\n self.diffusion_prior = diffusion_prior\n # mixed precision checks\n if (\n exists(self.accelerator) \n and self.accelerator.distributed_type == DistributedType.DEEPSPEED \n and self.diffusion_prior.clip is not None\n ):\n # Then we need to make sure clip is using the correct precision or else deepspeed will error\n cast_type_map = {\n \"fp16\": torch.half,\n \"bf16\": torch.bfloat16,\n \"no\": torch.float\n }\n precision_type = cast_type_map[accelerator.mixed_precision]\n assert precision"
+ },
+ {
+ "comment": "This code initializes the trainer for DeepSpeed, setting precision, optimizer, and scheduler. It checks if on-the-fly embedding generation from CLIP is supported and changes precision accordingly. It also distributes the model using HFA and applies exponential moving average techniques.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trainer.py\":224-248",
+ "content": "_type == torch.float, \"DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip\"\n self.diffusion_prior.clip.to(precision_type)\n # optimizer stuff\n self.optim_kwargs = dict(lr=lr, wd=wd, eps=eps, group_wd_params=group_wd_params)\n self.optimizer = get_optimizer(\n self.diffusion_prior.parameters(),\n **self.optim_kwargs,\n **kwargs\n )\n if exists(cosine_decay_max_steps):\n self.scheduler = CosineAnnealingLR(self.optimizer, T_max = cosine_decay_max_steps)\n else:\n self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)\n self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None\n # distribute the model if using HFA\n self.diffusion_prior, self.optimizer, self.scheduler = self.accelerator.prepare(self.diffusion_prior, self.optimizer, self.scheduler)\n # exponential moving average stuff"
+ },
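The scheduler wiring above combines a cosine-annealing (or constant) schedule with an optional linear warmup from the `pytorch_warmup` package, stepped inside its `dampening()` context. A minimal sketch of that wiring on a toy model follows; the learning rate, warmup period, and step counts are placeholders.

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR
import pytorch_warmup as warmup

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=1000)
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period=100)

for _ in range(5):
    loss = model(torch.randn(2, 8)).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # Scale the LR down during warmup, then follow the cosine schedule.
    with warmup_scheduler.dampening():
        scheduler.step()
```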
+ {
+ "comment": "The code snippet initializes a trainer object with an option for exponential moving average (EMA), gradient clipping, and tracks steps internally. It also defines a save method to save the optimizer, scheduler, model state dictionaries, and warmup scheduler on the main process. Note that LambdaLR cannot be saved due to pickling issues.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trainer.py\":250-279",
+ "content": " self.use_ema = use_ema\n if self.use_ema:\n self.ema_diffusion_prior = EMA(self.accelerator.unwrap_model(self.diffusion_prior), **ema_kwargs)\n # gradient clipping if needed\n self.max_grad_norm = max_grad_norm\n # track steps internally\n self.register_buffer('step', torch.tensor([0], device = self.device))\n # utility\n def save(self, path, overwrite = True, **kwargs):\n # only save on the main process\n if self.accelerator.is_main_process:\n print(f\"Saving checkpoint at step: {self.step.item()}\")\n path = Path(path)\n assert not (path.exists() and not overwrite)\n path.parent.mkdir(parents = True, exist_ok = True)\n # FIXME: LambdaLR can't be saved due to pickling issues\n save_obj = dict(\n optimizer = self.optimizer.state_dict(),\n scheduler = self.scheduler.state_dict(),\n warmup_scheduler = self.warmup_scheduler,\n model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(),"
+ },
+ {
+ "comment": "This code saves and loads a checkpoint for a diffusion prior trainer. It also handles saving the EMA (Exponential Moving Average) model separately for easy ema-only reload, and allows overwriting the learning rate if needed. The `load` method loads an entire trainer, including its optimizer and EMA.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trainer.py\":280-303",
+ "content": " version = version.parse(__version__),\n step = self.step,\n **kwargs\n )\n if self.use_ema:\n save_obj = {\n **save_obj,\n 'ema': self.ema_diffusion_prior.state_dict(),\n 'ema_model': self.ema_diffusion_prior.ema_model.state_dict() # save the ema model specifically for easy ema-only reload\n }\n torch.save(save_obj, str(path))\n def load(self, path_or_state, overwrite_lr = True, strict = True):\n \"\"\"\n Load a checkpoint of a diffusion prior trainer.\n Will load the entire trainer, including the optimizer and EMA.\n Params:\n - path_or_state (str | torch): a path to the DiffusionPriorTrainer checkpoint file\n - overwrite_lr (bool): wether or not to overwrite the stored LR with the LR specified in the new trainer\n - strict (bool): kwarg for `torch.nn.Module.load_state_dict`, will force an exact checkpoint match"
+ },
+ {
+ "comment": "This function loads a checkpoint from a specified path or dictionary, handling both string paths and existing dictionaries. It checks if the loaded version matches the current package version, then unwraps and loads the model's state dict, sets step values, and loads optimizer and scheduler states as well.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trainer.py\":305-326",
+ "content": " Returns:\n loaded_obj (dict): The loaded checkpoint dictionary\n \"\"\"\n # all processes need to load checkpoint. no restriction here\n if isinstance(path_or_state, str):\n path = Path(path_or_state)\n assert path.exists()\n loaded_obj = torch.load(str(path), map_location=self.device)\n elif isinstance(path_or_state, dict):\n loaded_obj = path_or_state\n if version.parse(__version__) != loaded_obj['version']:\n print(f'loading saved diffusion prior at version {loaded_obj[\"version\"]} but current package version is at {__version__}')\n # unwrap the model when loading from checkpoint\n self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)\n self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device))\n self.optimizer.load_state_dict(loaded_obj['optimizer'])\n self.scheduler.load_state_dict(loaded_obj['scheduler'])"
+ },
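Stripped of the Accelerate unwrapping, EMA handling, and version bookkeeping, the save/load pair above reduces to bundling `state_dict`s into one dictionary. A bare-bones sketch of that pattern is shown below; the function names are hypothetical and not part of the library.

```python
import torch
from pathlib import Path

def save_checkpoint(path, model, optimizer, scheduler, step):
    # Bundle every state_dict into one object, creating parent dirs as needed.
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "step": step,
    }, str(path))

def load_checkpoint(path, model, optimizer, scheduler, device="cpu"):
    # Restore each component from the bundled dictionary and return the step.
    obj = torch.load(str(path), map_location=device)
    model.load_state_dict(obj["model"])
    optimizer.load_state_dict(obj["optimizer"])
    scheduler.load_state_dict(obj["scheduler"])
    return obj["step"]
```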
+ {
+ "comment": "This function handles the warmup step, updating the learning rate if needed, loading EMA diffusion prior state from a checkpoint, and performing model update with optimization.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trainer.py\":328-357",
+ "content": " # set warmupstep\n if exists(self.warmup_scheduler):\n self.warmup_scheduler.last_step = self.step.item()\n # ensure new lr is used if different from old one\n if overwrite_lr:\n new_lr = self.optim_kwargs[\"lr\"]\n for group in self.optimizer.param_groups:\n group[\"lr\"] = new_lr if group[\"lr\"] > 0.0 else 0.0\n if self.use_ema:\n assert 'ema' in loaded_obj\n self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)\n # below might not be necessary, but I had a suspicion that this wasn't being loaded correctly\n self.ema_diffusion_prior.ema_model.load_state_dict(loaded_obj[\"ema_model\"])\n return loaded_obj\n # model functionality\n def update(self):\n if exists(self.max_grad_norm):\n self.accelerator.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)\n self.optimizer.step()\n self.optimizer.zero_grad()\n # accelerator will ocassionally skip optimizer steps in a \"dynamic loss scaling strategy\""
+ },
+ {
+ "comment": "The code defines several methods for using the diffusion prior model to generate samples. It uses exponential moving average (EMA) for model averaging, if `use_ema` is enabled. The `p_sample_loop`, `sample`, and `sample_batch_size` methods use `torch.no_grad()` for performance optimization, and `cast_torch_tensor` and `prior_sample_in_chunks` decorators are used to process data in chunks.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trainer.py\":358-385",
+ "content": " if not self.accelerator.optimizer_step_was_skipped:\n sched_context = self.warmup_scheduler.dampening if exists(self.warmup_scheduler) else nullcontext\n with sched_context():\n self.scheduler.step()\n if self.use_ema:\n self.ema_diffusion_prior.update()\n self.step += 1\n @torch.no_grad()\n @cast_torch_tensor\n @prior_sample_in_chunks\n def p_sample_loop(self, *args, **kwargs):\n model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior\n return model.p_sample_loop(*args, **kwargs)\n @torch.no_grad()\n @cast_torch_tensor\n @prior_sample_in_chunks\n def sample(self, *args, **kwargs):\n model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior\n return model.sample(*args, **kwargs)\n @torch.no_grad()\n def sample_batch_size(self, *args, **kwargs):\n model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior\n return model.sample_batch_size(*args, **kwargs)"
+ },
+ {
+ "comment": "This code defines a trainer with a function `embed_text` that uses the unwrapped model for embedding text, and a `forward` method that performs forward pass in chunks to handle large batch sizes. The `decoder_sample_in_chunks` decorator enables chunking when sample decoding.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trainer.py\":387-422",
+ "content": " @torch.no_grad()\n @cast_torch_tensor\n @prior_sample_in_chunks\n def embed_text(self, *args, **kwargs):\n return self.accelerator.unwrap_model(self.diffusion_prior).clip.embed_text(*args, **kwargs)\n @cast_torch_tensor\n def forward(\n self,\n *args,\n max_batch_size = None,\n **kwargs\n ):\n total_loss = 0.\n for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):\n with self.accelerator.autocast():\n loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)\n loss = loss * chunk_size_frac\n total_loss += loss.item()\n if self.training:\n self.accelerator.backward(loss)\n return total_loss\n# decoder trainer\ndef decoder_sample_in_chunks(fn):\n @wraps(fn)\n def inner(self, *args, max_batch_size = None, **kwargs):\n if not exists(max_batch_size):\n return fn(self, *args, **kwargs)\n if self.decoder.unconditional:"
+ },
+ {
+ "comment": "The function is a trainer that takes a decoder, accelerator, and other parameters. It can handle batching the inputs or splitting arguments and keywords to train the decoder in chunks, depending on the size of input data. The returned inner function is used for training the model using the provided configuration.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trainer.py\":423-453",
+ "content": " batch_size = kwargs.get('batch_size')\n batch_sizes = num_to_groups(batch_size, max_batch_size)\n outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes]\n else:\n outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]\n return torch.cat(outputs, dim = 0)\n return inner\nclass DecoderTrainer(nn.Module):\n def __init__(\n self,\n decoder,\n accelerator = None,\n dataloaders = None,\n use_ema = True,\n lr = 1e-4,\n wd = 1e-2,\n eps = 1e-8,\n warmup_steps = None,\n cosine_decay_max_steps = None,\n max_grad_norm = 0.5,\n amp = False,\n group_wd_params = True,\n **kwargs\n ):\n super().__init__()\n assert isinstance(decoder, Decoder)\n ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)\n self.accelerator = default(accelerator, Accelerator)"
+ },
+ {
+ "comment": "The code initializes the trainer with specific configurations for each UNET in the decoder. It checks learning rate, weight decay, warmup steps, and cosine decay max steps for each UNET. If a UNET is an identity, it assigns no optimizer or scheduler. Otherwise, it gets an appropriate optimizer for the UNET's parameters.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trainer.py\":455-480",
+ "content": " self.num_unets = len(decoder.unets)\n self.use_ema = use_ema\n self.ema_unets = nn.ModuleList([])\n self.amp = amp\n # be able to finely customize learning rate, weight decay\n # per unet\n lr, wd, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps, cosine_decay_max_steps))\n assert all([unet_lr <= 1e-2 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'\n optimizers = []\n schedulers = []\n warmup_schedulers = []\n for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps in zip(decoder.unets, lr, wd, eps, warmup_steps, cosine_decay_max_steps):\n if isinstance(unet, nn.Identity):\n optimizers.append(None)\n schedulers.append(None)\n warmup_schedulers.append(None)\n else:\n optimizer = get_optimizer(\n unet.parameters(),"
+ },
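The per-unet hyperparameter handling leans on `cast_tuple`, which broadcasts a scalar across all unets while leaving an explicit tuple untouched. A tiny sketch with illustrative values:

```python
def cast_tuple(val, length=1):
    # Broadcast a scalar to a tuple of the requested length; pass tuples through unchanged.
    return val if isinstance(val, tuple) else (val,) * length

num_unets = 3
lr = cast_tuple(1e-4, length=num_unets)                  # -> (0.0001, 0.0001, 0.0001)
wd = cast_tuple((1e-2, 1e-2, 5e-3), length=num_unets)    # explicit per-unet values pass through
print(lr, wd)
```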
+ {
+ "comment": "The code initializes optimizers, optionally schedulers for learning rate adjustments, and an exponential moving average (EMA) for the UNETs. It also registers a buffer for tracking steps and handles gradient clipping if needed based on distributed type.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trainer.py\":481-509",
+ "content": " lr = unet_lr,\n wd = unet_wd,\n eps = unet_eps,\n group_wd_params = group_wd_params,\n **kwargs\n )\n optimizers.append(optimizer)\n if exists(unet_cosine_decay_max_steps):\n scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)\n else:\n scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)\n warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None\n warmup_schedulers.append(warmup_scheduler)\n schedulers.append(scheduler)\n if self.use_ema:\n self.ema_unets.append(EMA(unet, **ema_kwargs))\n # gradient clipping if needed\n self.max_grad_norm = max_grad_norm\n self.register_buffer('steps', torch.tensor([0] * self.num_unets))\n if self.accelerator.distributed_type == DistributedType.DEEPSPEED and decoder.clip is not None:"
+ },
+ {
+ "comment": "This code ensures that the correct precision is used by DeepSpeed and prepares the decoder, optimizers, and dataloaders for training. It converts the clip to the specified precision type, then prepares them using DeepSpeed's accelerator. The train_loader and val_loader are stored for later use.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trainer.py\":510-536",
+ "content": " # Then we need to make sure clip is using the correct precision or else deepspeed will error\n cast_type_map = {\n \"fp16\": torch.half,\n \"bf16\": torch.bfloat16,\n \"no\": torch.float\n }\n precision_type = cast_type_map[accelerator.mixed_precision]\n assert precision_type == torch.float, \"DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip\"\n clip = decoder.clip\n clip.to(precision_type)\n decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))\n self.decoder = decoder\n # prepare dataloaders\n train_loader = val_loader = None\n if exists(dataloaders):\n train_loader, val_loader = self.accelerator.prepare(dataloaders[\"train\"], dataloaders[\"val\"])\n self.train_loader = train_loader\n self.val_loader = val_loader\n # store optimizers\n for opt_ind, optimizer in zip(range(len(optimizers)), optimizers):"
+ },
+ {
+ "comment": "This code defines a class with optimizers, schedulers, and warmup schedulers. It also validates the unet number and returns the number of steps taken by a specific unet. The save function saves the model's state dict to a specified path.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trainer.py\":537-565",
+ "content": " setattr(self, f'optim{opt_ind}', optimizer)\n # store schedulers\n for sched_ind, scheduler in zip(range(len(schedulers)), schedulers):\n setattr(self, f'sched{sched_ind}', scheduler)\n # store warmup schedulers\n self.warmup_schedulers = warmup_schedulers\n def validate_and_return_unet_number(self, unet_number = None):\n if self.num_unets == 1:\n unet_number = default(unet_number, 1)\n assert exists(unet_number) and 1 <= unet_number <= self.num_unets\n return unet_number\n def num_steps_taken(self, unet_number = None):\n unet_number = self.validate_and_return_unet_number(unet_number)\n return self.steps[unet_number - 1].item()\n def save(self, path, overwrite = True, **kwargs):\n path = Path(path)\n assert not (path.exists() and not overwrite)\n path.parent.mkdir(parents = True, exist_ok = True)\n save_obj = dict(\n model = self.accelerator.unwrap_model(self.decoder).state_dict(),"
+ },
+ {
+ "comment": "This code snippet saves the model state, optimizer state, and scheduler state if they exist, and an optional Exponential Moving Average (EMA) state. It checks the version compatibility before loading the saved state dictionary.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trainer.py\":566-590",
+ "content": " version = __version__,\n steps = self.steps.cpu(),\n **kwargs\n )\n for ind in range(0, self.num_unets):\n optimizer_key = f'optim{ind}'\n scheduler_key = f'sched{ind}'\n optimizer = getattr(self, optimizer_key)\n scheduler = getattr(self, scheduler_key)\n optimizer_state_dict = optimizer.state_dict() if exists(optimizer) else None\n scheduler_state_dict = scheduler.state_dict() if exists(scheduler) else None\n save_obj = {**save_obj, optimizer_key: optimizer_state_dict, scheduler_key: scheduler_state_dict}\n if self.use_ema:\n save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}\n self.accelerator.save(save_obj, str(path))\n def load_state_dict(self, loaded_obj, only_model = False, strict = True):\n if version.parse(__version__) != version.parse(loaded_obj['version']):\n self.accelerator.print(f'loading saved decoder at version {loaded_obj[\"version\"]}, but current package version is {__version__}')"
+ },
+ {
+ "comment": "This code loads a model and its associated optimizers, schedulers, and warmup schedulers from the given path. It also checks if early-stopping (ema) was used and loads that as well. The function returns the loaded state of each component if only_model is True, otherwise it continues with training.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trainer.py\":592-621",
+ "content": " self.accelerator.unwrap_model(self.decoder).load_state_dict(loaded_obj['model'], strict = strict)\n self.steps.copy_(loaded_obj['steps'])\n if only_model:\n return loaded_obj\n for ind, last_step in zip(range(0, self.num_unets), self.steps.tolist()):\n optimizer_key = f'optim{ind}'\n optimizer = getattr(self, optimizer_key)\n scheduler_key = f'sched{ind}'\n scheduler = getattr(self, scheduler_key)\n warmup_scheduler = self.warmup_schedulers[ind]\n if exists(optimizer):\n optimizer.load_state_dict(loaded_obj[optimizer_key])\n if exists(scheduler):\n scheduler.load_state_dict(loaded_obj[scheduler_key])\n if exists(warmup_scheduler):\n warmup_scheduler.last_step = last_step\n if self.use_ema:\n assert 'ema' in loaded_obj\n self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)\n def load(self, path, only_model = False, strict = True):"
+ },
+ {
+ "comment": "This function loads a saved state and returns it. It also provides access to the unets (U-Nets) in the model and allows incrementing the step of a specific unet. The update method updates the optimizer and scheduler for a specified unet.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trainer.py\":622-651",
+ "content": " path = Path(path)\n assert path.exists()\n loaded_obj = torch.load(str(path), map_location = 'cpu')\n self.load_state_dict(loaded_obj, only_model = only_model, strict = strict)\n return loaded_obj\n @property\n def unets(self):\n return nn.ModuleList([ema.ema_model for ema in self.ema_unets])\n def increment_step(self, unet_number):\n assert 1 <= unet_number <= self.num_unets\n unet_index_tensor = torch.tensor(unet_number - 1, device = self.steps.device)\n self.steps += F.one_hot(unet_index_tensor, num_classes = len(self.steps))\n def update(self, unet_number = None):\n unet_number = self.validate_and_return_unet_number(unet_number)\n index = unet_number - 1\n optimizer = getattr(self, f'optim{index}')\n scheduler = getattr(self, f'sched{index}')\n if exists(self.max_grad_norm):\n self.accelerator.clip_grad_norm_(self.decoder.parameters(), self.max_grad_norm) # Automatically unscales gradients\n optimizer.step()"
+ },
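The step bookkeeping uses a one-hot vector so that only the counter of the unet that was just updated is incremented. A small sketch of that trick, assuming three unets:

```python
import torch
import torch.nn.functional as F

steps = torch.zeros(3, dtype=torch.long)          # one counter per unet

def increment_step(steps, unet_number):
    # Add 1 only to the counter of the unet that was just updated.
    return steps + F.one_hot(torch.tensor(unet_number - 1), num_classes=len(steps))

steps = increment_step(steps, unet_number=2)
print(steps)                                      # tensor([0, 1, 0])
```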
+ {
+ "comment": "This code is responsible for the sampling process in a specific model. It uses gradient descent to optimize the model and updates the exponential moving average (EMA) unets if ema is enabled. The sample function enables evaluation mode, handles non-ema usage or disabled use_ema, and returns the output based on the input arguments. The distributed argument is used for multi-process sampling.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trainer.py\":652-682",
+ "content": " optimizer.zero_grad()\n warmup_scheduler = self.warmup_schedulers[index]\n scheduler_context = warmup_scheduler.dampening if exists(warmup_scheduler) else nullcontext\n with scheduler_context():\n scheduler.step()\n if self.use_ema:\n ema_unet = self.ema_unets[index]\n ema_unet.update()\n self.increment_step(unet_number)\n @torch.no_grad()\n @cast_torch_tensor\n @decoder_sample_in_chunks\n def sample(self, *args, **kwargs):\n distributed = self.accelerator.num_processes > 1\n base_decoder = self.accelerator.unwrap_model(self.decoder)\n was_training = base_decoder.training\n base_decoder.eval()\n if kwargs.pop('use_non_ema', False) or not self.use_ema:\n out = base_decoder.sample(*args, **kwargs, distributed = distributed)\n base_decoder.train(was_training)\n return out\n trainable_unets = self.accelerator.unwrap_model(self.decoder).unets\n base_decoder.unets = self.unets # swap in exponential moving averaged unets for sampling"
+ },
+ {
+ "comment": "This code defines a function for embedding text and image using the decoder's CLIP module. It also restores the original training unets, casts torch tensors, validates and returns the correct unet number, and allows for conditional lowres image return.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trainer.py\":684-716",
+ "content": " output = base_decoder.sample(*args, **kwargs, distributed = distributed)\n base_decoder.unets = trainable_unets # restore original training unets\n # cast the ema_model unets back to original device\n for ema in self.ema_unets:\n ema.restore_ema_model_device()\n base_decoder.train(was_training)\n return output\n @torch.no_grad()\n @cast_torch_tensor\n @prior_sample_in_chunks\n def embed_text(self, *args, **kwargs):\n return self.accelerator.unwrap_model(self.decoder).clip.embed_text(*args, **kwargs)\n @torch.no_grad()\n @cast_torch_tensor\n @prior_sample_in_chunks\n def embed_image(self, *args, **kwargs):\n return self.accelerator.unwrap_model(self.decoder).clip.embed_image(*args, **kwargs)\n @cast_torch_tensor\n def forward(\n self,\n *args,\n unet_number = None,\n max_batch_size = None,\n return_lowres_cond_image=False,\n **kwargs\n ):\n unet_number = self.validate_and_return_unet_number(unet_number)"
+ },
+ {
+ "comment": "This code chunk splits the input arguments and keywords into multiple smaller chunks, then iterates over them to perform computations with auto-cast enabled. The resulting losses are accumulated, and if conditional images are returned, they are stacked together. Finally, the total loss is returned.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/trainer.py\":718-741",
+ "content": " total_loss = 0.\n cond_images = []\n for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):\n with self.accelerator.autocast():\n loss_obj = self.decoder(*chunked_args, unet_number = unet_number, return_lowres_cond_image=return_lowres_cond_image, **chunked_kwargs)\n # loss_obj may be a tuple with loss and cond_image\n if return_lowres_cond_image:\n loss, cond_image = loss_obj\n else:\n loss = loss_obj\n cond_image = None\n loss = loss * chunk_size_frac\n if cond_image is not None:\n cond_images.append(cond_image)\n total_loss += loss.item()\n if self.training:\n self.accelerator.backward(loss)\n if return_lowres_cond_image:\n return total_loss, torch.stack(cond_images)\n else:\n return total_loss"
+ }
+ ]
+}
\ No newline at end of file
diff --git a/docs/doc/8575c17d-4679-4e52-9664-8ff238712c9e.json b/docs/doc/8575c17d-4679-4e52-9664-8ff238712c9e.json
new file mode 100644
index 00000000..dacdf829
--- /dev/null
+++ b/docs/doc/8575c17d-4679-4e52-9664-8ff238712c9e.json
@@ -0,0 +1,170 @@
+{
+ "summary": "This code divides shards, initializes training, and trains UNet models for DALL-E 2 using PyTorch. It also supports distributed training and executes as a standalone program.",
+ "details": [
+ {
+ "comment": "This code imports various modules and defines constants for training a decoder model in the DALLE2-pytorch framework. It uses DecoderTrainer, dataloaders, trackers, train configs, utilities, and models from the dalle2_pytorch package. It also includes metrics such as FrechetInceptionDistance, InceptionScore, KernelInceptionDistance, and LearnedPerceptualImagePatchSimilarity for evaluation. Accelerate is used for accelerated training, and webdataset is used for data loading.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_decoder.py\":0-27",
+ "content": "from pathlib import Path\nfrom typing import List\nfrom datetime import timedelta\nfrom dalle2_pytorch.trainer import DecoderTrainer\nfrom dalle2_pytorch.dataloaders import create_image_embedding_dataloader\nfrom dalle2_pytorch.trackers import Tracker\nfrom dalle2_pytorch.train_configs import DecoderConfig, TrainDecoderConfig\nfrom dalle2_pytorch.utils import Timer, print_ribbon\nfrom dalle2_pytorch.dalle2_pytorch import Decoder, resize_image_to\nfrom clip import tokenize\nimport torchvision\nimport torch\nfrom torch import nn\nfrom torchmetrics.image.fid import FrechetInceptionDistance\nfrom torchmetrics.image.inception import InceptionScore\nfrom torchmetrics.image.kid import KernelInceptionDistance\nfrom torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity\nfrom accelerate import Accelerator, DistributedDataParallelKwargs, InitProcessGroupKwargs\nfrom accelerate.utils import dataclasses as accelerate_dataclasses\nimport webdataset as wds\nimport click\n# constants\nTRAIN_CALC_LOSS_EVERY_ITERS = 10\nVALID_CALC_LOSS_EVERY_ITERS = 10"
+ },
+ {
+ "comment": "This function takes available shards, URLs for embeddings, and other parameters to randomly split them into train, validation, and test sets, then returns dataloaders for each. It asserts that the proportions of splits sum up to 1, calculates the actual number of samples in each split based on the proportion, and checks if the sum of splits matches the total number of available shards.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_decoder.py\":29-63",
+ "content": "# helpers functions\ndef exists(val):\n return val is not None\n# main functions\ndef create_dataloaders(\n available_shards,\n webdataset_base_url,\n img_embeddings_url=None,\n text_embeddings_url=None,\n shard_width=6,\n num_workers=4,\n batch_size=32,\n n_sample_images=6,\n shuffle_train=True,\n resample_train=False,\n img_preproc = None,\n index_width=4,\n train_prop = 0.75,\n val_prop = 0.15,\n test_prop = 0.10,\n seed = 0,\n **kwargs\n):\n \"\"\"\n Randomly splits the available shards into train, val, and test sets and returns a dataloader for each\n \"\"\"\n assert train_prop + test_prop + val_prop == 1\n num_train = round(train_prop*len(available_shards))\n num_test = round(test_prop*len(available_shards))\n num_val = len(available_shards) - num_train - num_test\n assert num_train + num_test + num_val == len(available_shards), f\"{num_train} + {num_test} + {num_val} = {num_train + num_test + num_val} != {len(available_shards)}\"\n train_split, test_split, val_split ="
+ },
+ {
+ "comment": "This code randomly splits available shards into training, testing, and validation sets. It then generates corresponding URLs for each set by zero-padding the shard numbers to match the filename format. A lambda function is created to handle creating a dataloader for image embeddings using these URLs, considering various parameters like batch size and number of workers.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_decoder.py\":63-79",
+ "content": " torch.utils.data.random_split(available_shards, [num_train, num_test, num_val], generator=torch.Generator().manual_seed(seed))\n # The shard number in the webdataset file names has a fixed width. We zero pad the shard numbers so they correspond to a filename.\n train_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in train_split]\n test_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in test_split]\n val_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in val_split]\n create_dataloader = lambda tar_urls, shuffle=False, resample=False, for_sampling=False: create_image_embedding_dataloader(\n tar_url=tar_urls,\n num_workers=num_workers,\n batch_size=batch_size if not for_sampling else n_sample_images,\n img_embeddings_url=img_embeddings_url,\n text_embeddings_url=text_embeddings_url,\n index_width=index_width,\n shuffle_num = None,\n extra_keys= [\"txt\"],\n shuffle_shards = shuffle,"
+ },
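The shard handling above splits a list of shard indices with `torch.utils.data.random_split` and zero-pads each index to rebuild a webdataset filename. Below is a small sketch of those two steps with a hypothetical bucket URL and ten shards; the printed URLs are only an example of the format.

```python
import torch

available_shards = list(range(10))                   # hypothetical shard indices
webdataset_base_url = "s3://bucket/data/{}.tar"      # hypothetical URL template
shard_width = 6

num_train, num_test = 7, 2
num_val = len(available_shards) - num_train - num_test
train_split, test_split, val_split = torch.utils.data.random_split(
    available_shards, [num_train, num_test, num_val],
    generator=torch.Generator().manual_seed(0),
)

# Zero-pad shard numbers so they match fixed-width filenames.
train_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in train_split]
print(train_urls[:2])   # e.g. ['s3://bucket/data/000001.tar', 's3://bucket/data/000005.tar']
```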
+ {
+ "comment": "The code creates multiple data loaders for training, validation, and testing datasets. It returns a dictionary with each dataset's corresponding dataloader. The `get_dataset_keys` function extracts the real dataloader if the input is a WebLoader.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_decoder.py\":80-102",
+ "content": " resample_shards = resample, \n img_preproc=img_preproc,\n handler=wds.handlers.warn_and_continue\n )\n train_dataloader = create_dataloader(train_urls, shuffle=shuffle_train, resample=resample_train)\n train_sampling_dataloader = create_dataloader(train_urls, shuffle=False, for_sampling=True)\n val_dataloader = create_dataloader(val_urls, shuffle=False)\n test_dataloader = create_dataloader(test_urls, shuffle=False)\n test_sampling_dataloader = create_dataloader(test_urls, shuffle=False, for_sampling=True)\n return {\n \"train\": train_dataloader,\n \"train_sampling\": train_sampling_dataloader,\n \"val\": val_dataloader,\n \"test\": test_dataloader,\n \"test_sampling\": test_sampling_dataloader\n }\ndef get_dataset_keys(dataloader):\n \"\"\"\n It is sometimes neccesary to get the keys the dataloader is returning. Since the dataset is burried in the dataloader, we need to do a process to recover it.\n \"\"\"\n # If the dataloader is actually a WebLoader, we need to extract the real dataloader"
+ },
+ {
+ "comment": "The code samples the dataloader and returns a zipped list of examples. It iterates through each image, extracts its embedding, converts it to the device's format, extends the respective lists for images and text embeddings, and finally returns them.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_decoder.py\":103-129",
+ "content": " if isinstance(dataloader, wds.WebLoader):\n dataloader = dataloader.pipeline[0]\n return dataloader.dataset.key_map\ndef get_example_data(dataloader, device, n=5):\n \"\"\"\n Samples the dataloader and returns a zipped list of examples\n \"\"\"\n images = []\n img_embeddings = []\n text_embeddings = []\n captions = []\n for img, emb, txt in dataloader:\n img_emb, text_emb = emb.get('img'), emb.get('text')\n if img_emb is not None:\n img_emb = img_emb.to(device=device, dtype=torch.float)\n img_embeddings.extend(list(img_emb))\n else:\n # Then we add None img.shape[0] times\n img_embeddings.extend([None]*img.shape[0])\n if text_emb is not None:\n text_emb = text_emb.to(device=device, dtype=torch.float)\n text_embeddings.extend(list(text_emb))\n else:\n # Then we add None img.shape[0] times\n text_embeddings.extend([None]*img.shape[0])\n img = img.to(device=device, dtype=torch.float)"
+ },
+ {
+ "comment": "This function generates samples by taking example data and creating real images, generated images, and captions. If image embeddings are None, it generates them using the clip model. It returns three lists: real images, generated images, and captions.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_decoder.py\":130-151",
+ "content": " images.extend(list(img))\n captions.extend(list(txt))\n if len(images) >= n:\n break\n return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))\ndef generate_samples(trainer, example_data, clip=None, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=\"\", match_image_size=True):\n \"\"\"\n Takes example data and generates images from the embeddings\n Returns three lists: real images, generated images, and captions\n \"\"\"\n real_images, img_embeddings, text_embeddings, txts = zip(*example_data)\n sample_params = {}\n if img_embeddings[0] is None:\n # Generate image embeddings from clip\n imgs_tensor = torch.stack(real_images)\n assert clip is not None, \"clip is None, but img_embeddings is None\"\n imgs_tensor.to(device=device)\n img_embeddings, img_encoding = clip.embed_image(imgs_tensor)\n sample_params[\"image_embed\"] = img_embeddings\n else:\n # Then we are using precomputed image embeddings"
+ },
+ {
+ "comment": "This code is responsible for preparing training samples by stacking image and text embeddings, setting parameters for start and stop U-net layers, and handling the case where real images are provided. If real images exist, it stacks them as part of the sample. The code also considers whether to generate text embeddings or use precomputed ones and ensures everything is on the specified device.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_decoder.py\":152-171",
+ "content": " img_embeddings = torch.stack(img_embeddings)\n sample_params[\"image_embed\"] = img_embeddings\n if condition_on_text_encodings:\n if text_embeddings[0] is None:\n # Generate text embeddings from text\n assert clip is not None, \"clip is None, but text_embeddings is None\"\n tokenized_texts = tokenize(txts, truncate=True).to(device=device)\n text_embed, text_encodings = clip.embed_text(tokenized_texts)\n sample_params[\"text_encodings\"] = text_encodings\n else:\n # Then we are using precomputed text embeddings\n text_embeddings = torch.stack(text_embeddings)\n sample_params[\"text_encodings\"] = text_embeddings\n sample_params[\"start_at_unet_number\"] = start_unet\n sample_params[\"stop_at_unet_number\"] = end_unet\n if start_unet > 1:\n # If we are only training upsamplers\n sample_params[\"image\"] = torch.stack(real_images)\n if device is not None:\n sample_params[\"_device\"] = device"
+ },
+ {
+ "comment": "This function generates samples, combines them with real images in a grid format for easy viewing. It first calls `generate_samples` to get the real and generated images along with their corresponding captions. Then it uses `torchvision.utils.make_grid` to create grids of original and generated images.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_decoder.py\":172-185",
+ "content": " samples = trainer.sample(**sample_params, _cast_deepspeed_precision=False) # At sampling time we don't want to cast to FP16\n generated_images = list(samples)\n captions = [text_prepend + txt for txt in txts]\n if match_image_size:\n generated_image_size = generated_images[0].shape[-1]\n real_images = [resize_image_to(image, generated_image_size, clamp_range=(0, 1)) for image in real_images]\n return real_images, generated_images, captions\ndef generate_grid_samples(trainer, examples, clip=None, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=\"\"):\n \"\"\"\n Generates samples and uses torchvision to put them in a side by side grid for easy viewing\n \"\"\"\n real_images, generated_images, captions = generate_samples(trainer, examples, clip, start_unet, end_unet, condition_on_text_encodings, cond_scale, device, text_prepend)\n grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]"
+ },
+ {
+ "comment": "This function computes evaluation metrics for a decoder. It prepares data, generates samples using the trainer and start/end unets, converts images from [0, 1] to [0, 255], and types them as uint8. The generated and real images are then stored in variables for further evaluation metrics calculations.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_decoder.py\":186-202",
+ "content": " return grid_images, captions\ndef evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, clip=None, condition_on_text_encodings=False, cond_scale=1.0, inference_device=None, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):\n \"\"\"\n Computes evaluation metrics for the decoder\n \"\"\"\n metrics = {}\n # Prepare the data\n examples = get_example_data(dataloader, device, n_evaluation_samples)\n if len(examples) == 0:\n print(\"No data to evaluate. Check that your dataloader has shards.\")\n return metrics\n real_images, generated_images, captions = generate_samples(trainer, examples, clip, start_unet, end_unet, condition_on_text_encodings, cond_scale, inference_device)\n real_images = torch.stack(real_images).to(device=device, dtype=torch.float)\n generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)\n # Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8\n int_real_images = real_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)"
+ },
+ {
+ "comment": "This code calculates and stores metrics for the quality of generated images, including Frechet Inception Distance (FID), Inception Score (IS), and Kernel Inception Distance (KID). It first scales the generated images, then checks if specific configuration files exist for each metric. If they do, it creates an instance of the corresponding metric class, sets it up on the device, updates with real and generated images, and computes the metric values. The computed metrics are stored in the \"metrics\" dictionary.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_decoder.py\":203-226",
+ "content": " int_generated_images = generated_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)\n def null_sync(t, *args, **kwargs):\n return [t]\n if exists(FID):\n fid = FrechetInceptionDistance(**FID, dist_sync_fn=null_sync)\n fid.to(device=device)\n fid.update(int_real_images, real=True)\n fid.update(int_generated_images, real=False)\n metrics[\"FID\"] = fid.compute().item()\n if exists(IS):\n inception = InceptionScore(**IS, dist_sync_fn=null_sync)\n inception.to(device=device)\n inception.update(int_real_images)\n is_mean, is_std = inception.compute()\n metrics[\"IS_mean\"] = is_mean.item()\n metrics[\"IS_std\"] = is_std.item()\n if exists(KID):\n kernel_inception = KernelInceptionDistance(**KID, dist_sync_fn=null_sync)\n kernel_inception.to(device=device)\n kernel_inception.update(int_real_images, real=True)\n kernel_inception.update(int_generated_images, real=False)\n kid_mean, kid_std = kernel_inception.compute()"
+ },
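The metric computation above converts images from [0, 1] floats to [0, 255] uint8 before feeding them to torchmetrics. The sketch below shows just the FID piece of that flow with random tensors as stand-ins, so the resulting score itself is meaningless; it assumes `torchmetrics` and its `torch-fidelity` dependency are installed.

```python
import torch
from torchmetrics.image.fid import FrechetInceptionDistance

# Random tensors standing in for batches of real and generated images in [0, 1].
real = torch.rand(16, 3, 64, 64)
fake = torch.rand(16, 3, 64, 64)

def to_uint8(t):
    # Same [0, 1] -> [0, 255] uint8 conversion the evaluation code performs.
    return t.mul(255).add(0.5).clamp(0, 255).to(torch.uint8)

fid = FrechetInceptionDistance(feature=64)
fid.update(to_uint8(real), real=True)
fid.update(to_uint8(fake), real=False)
print(fid.compute().item())
```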
+ {
+ "comment": "This code calculates metrics such as KID and LPIPS for a model's performance. It stores the values in a dictionary, normalizes the images if LPIPS is present, applies the LearnedPerceptualImagePatchSimilarity function, and syncs the calculated metrics across processes using accelerator functions.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_decoder.py\":227-246",
+ "content": " metrics[\"KID_mean\"] = kid_mean.item()\n metrics[\"KID_std\"] = kid_std.item()\n if exists(LPIPS):\n # Convert from [0, 1] to [-1, 1]\n renorm_real_images = real_images.mul(2).sub(1).clamp(-1,1)\n renorm_generated_images = generated_images.mul(2).sub(1).clamp(-1,1)\n lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS, dist_sync_fn=null_sync)\n lpips.to(device=device)\n lpips.update(renorm_real_images, renorm_generated_images)\n metrics[\"LPIPS\"] = lpips.compute().item()\n if trainer.accelerator.num_processes > 1:\n # Then we should sync the metrics\n metrics_order = sorted(metrics.keys())\n metrics_tensor = torch.zeros(1, len(metrics), device=device, dtype=torch.float)\n for i, metric_name in enumerate(metrics_order):\n metrics_tensor[0, i] = metrics[metric_name]\n metrics_tensor = trainer.accelerator.gather(metrics_tensor)\n metrics_tensor = metrics_tensor.mean(dim=0)\n for i, metric_name in enumerate(metrics_order):"
+ },
+ {
+ "comment": "This code contains three functions: 1) `train_decoder`, which updates metrics based on the current metric; 2) `save_trainer`, which logs the model using an appropriate method according to the tracker; and 3) `recall_trainer`, which loads the model using the tracker. The code is part of a larger system that likely involves training a machine learning model, tracking its progress, and recalling it for further use or evaluation.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_decoder.py\":247-263",
+ "content": " metrics[metric_name] = metrics_tensor[i].item()\n return metrics\ndef save_trainer(tracker: Tracker, trainer: DecoderTrainer, epoch: int, sample: int, next_task: str, validation_losses: List[float], samples_seen: int, is_latest=True, is_best=False):\n \"\"\"\n Logs the model with an appropriate method depending on the tracker\n \"\"\"\n tracker.save(trainer, is_best=is_best, is_latest=is_latest, epoch=epoch, sample=sample, next_task=next_task, validation_losses=validation_losses, samples_seen=samples_seen)\ndef recall_trainer(tracker: Tracker, trainer: DecoderTrainer):\n \"\"\"\n Loads the model with an appropriate method depending on the tracker\n \"\"\"\n trainer.accelerator.print(print_ribbon(f\"Loading model from {type(tracker.loader).__name__}\"))\n state_dict = tracker.recall()\n trainer.load_state_dict(state_dict, only_model=False, strict=True)\n return state_dict.get(\"epoch\", 0), state_dict.get(\"validation_losses\", []), state_dict.get(\"next_task\", \"train\"), state_dict.get(\"sample\", 0), state_dict.get(\"samples_seen\", 0)"
+ },
+ {
+ "comment": "The function trains a decoder on a dataset, using the specified dataloaders, Decoder instance, and Accelerator. It also has optional arguments for clip, evaluate_config, epoch_samples, validation_samples, save_immediately, epochs, n_sample_images, save_every_n_samples, unet_training_mask, condition_on_text_encodings, and cond_scale. The function checks if the unet_training_mask exists and asserts that its length matches the number of unets in the decoder. It also assigns trainable unet numbers to a list.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_decoder.py\":265-293",
+ "content": "def train(\n dataloaders,\n decoder: Decoder,\n accelerator: Accelerator,\n tracker: Tracker,\n inference_device,\n clip=None,\n evaluate_config=None,\n epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch\n validation_samples = None,\n save_immediately=False,\n epochs = 20,\n n_sample_images = 5,\n save_every_n_samples = 100000,\n unet_training_mask=None,\n condition_on_text_encodings=False,\n cond_scale=1.0,\n **kwargs\n):\n \"\"\"\n Trains a decoder on a dataset.\n \"\"\"\n is_master = accelerator.process_index == 0\n if not exists(unet_training_mask):\n # Then the unet mask should be true for all unets in the decoder\n unet_training_mask = [True] * len(decoder.unets)\n assert len(unet_training_mask) == len(decoder.unets), f\"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {trainer.num_unets}\"\n trainable_unet_numbers = [i+1 for i, trainable in enumerate(unet_training_mask) if trainable]"
+ },
+ {
+ "comment": "The code is removing non-trainable UNet modules and setting up a trainer for the given task. It also checks if the state can be recalled from a previous training session and updates relevant variables accordingly.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_decoder.py\":294-321",
+ "content": " first_trainable_unet = trainable_unet_numbers[0]\n last_trainable_unet = trainable_unet_numbers[-1]\n def move_unets(unet_training_mask):\n for i in range(len(decoder.unets)):\n if not unet_training_mask[i]:\n # Replace the unet from the module list with a nn.Identity(). This training script never uses unets that aren't being trained so this is fine.\n decoder.unets[i] = nn.Identity().to(inference_device)\n # Remove non-trainable unets\n move_unets(unet_training_mask)\n trainer = DecoderTrainer(\n decoder=decoder,\n accelerator=accelerator,\n dataloaders=dataloaders,\n **kwargs\n )\n # Set up starting model and parameters based on a recalled state dict\n start_epoch = 0\n validation_losses = []\n next_task = 'train'\n sample = 0\n samples_seen = 0\n val_sample = 0\n step = lambda: int(trainer.num_steps_taken(unet_number=first_trainable_unet))\n if tracker.can_recall:\n start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)"
+ },
+ {
+ "comment": "The code loads a model and starts training from the specified task, either 'train' or 'val'. It prints the details of the loaded model, including epoch, samples seen, and minimum validation loss. The trainer is moved to the inference device. Example data for both training and testing is generated using get_example_data function with the specified number of sample images.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_decoder.py\":322-336",
+ "content": " if next_task == 'train':\n sample = recalled_sample\n if next_task == 'val':\n val_sample = recalled_sample\n accelerator.print(f\"Loaded model from {type(tracker.loader).__name__} on epoch {start_epoch} having seen {samples_seen} samples with minimum validation loss {min(validation_losses) if len(validation_losses) > 0 else 'N/A'}\")\n accelerator.print(f\"Starting training from task {next_task} at sample {sample} and validation sample {val_sample}\")\n trainer.to(device=inference_device)\n accelerator.print(print_ribbon(\"Generating Example Data\", repeat=40))\n accelerator.print(\"This can take a while to load the shard lists...\")\n if is_master:\n train_example_data = get_example_data(dataloaders[\"train_sampling\"], inference_device, n_sample_images)\n accelerator.print(\"Generated training examples\")\n test_example_data = get_example_data(dataloaders[\"test_sampling\"], inference_device, n_sample_images)\n accelerator.print(\"Generated testing examples\")"
+ },
+ {
+ "comment": "Iterating over epochs in training mode, counting the total number of samples across all processes. Gathering sample length tensors using accelerator's gather function and summing them up to get the total samples seen. Updating sample and samples_seen variables accordingly.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_decoder.py\":338-356",
+ "content": " send_to_device = lambda arr: [x.to(device=inference_device, dtype=torch.float) for x in arr]\n sample_length_tensor = torch.zeros(1, dtype=torch.int, device=inference_device)\n unet_losses_tensor = torch.zeros(TRAIN_CALC_LOSS_EVERY_ITERS, trainer.num_unets, dtype=torch.float, device=inference_device)\n for epoch in range(start_epoch, epochs):\n accelerator.print(print_ribbon(f\"Starting epoch {epoch}\", repeat=40))\n timer = Timer()\n last_sample = sample\n last_snapshot = sample\n if next_task == 'train':\n for i, (img, emb, txt) in enumerate(dataloaders[\"train\"]):\n # We want to count the total number of samples across all processes\n sample_length_tensor[0] = len(img)\n all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.\n total_samples = all_samples.sum().item()\n sample += total_samples\n samples_seen += total_samples"
+ },
+ {
+ "comment": "This code checks if there are image or text embeddings available, sends them to the device, and then trains a model. It also performs a forward pass for image embedding generation if necessary.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_decoder.py\":357-379",
+ "content": " img_emb = emb.get('img')\n has_img_embedding = img_emb is not None\n if has_img_embedding:\n img_emb, = send_to_device((img_emb,))\n text_emb = emb.get('text')\n has_text_embedding = text_emb is not None\n if has_text_embedding:\n text_emb, = send_to_device((text_emb,))\n img, = send_to_device((img,))\n trainer.train()\n for unet in range(1, trainer.num_unets+1):\n # Check if this is a unet we are training\n if not unet_training_mask[unet-1]: # Unet index is the unet number - 1\n continue\n forward_params = {}\n if has_img_embedding:\n forward_params['image_embed'] = img_emb\n else:\n # Forward pass automatically generates embedding\n assert clip is not None\n img_embed, img_encoding = clip.embed_image(img)"
+ },
+ {
+ "comment": "This code chunk is for training the DALL-E 2 model's decoder. It first checks if image and text embeddings are provided, and if not, it tokenizes the text and generates text embeddings using the CLIP model. Then, it passes the required parameters to the trainer and updates the model, storing the loss for each unit in the unet_losses_tensor array.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_decoder.py\":380-393",
+ "content": " forward_params['image_embed'] = img_embed\n if condition_on_text_encodings:\n if has_text_embedding:\n forward_params['text_encodings'] = text_emb\n else:\n # Then we need to pass the text instead\n assert clip is not None\n tokenized_texts = tokenize(txt, truncate=True).to(inference_device)\n assert tokenized_texts.shape[0] == len(img), f\"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})\"\n text_embed, text_encodings = clip.embed_text(tokenized_texts)\n forward_params['text_encodings'] = text_encodings\n loss = trainer.forward(img, **forward_params, unet_number=unet, _device=inference_device)\n trainer.update(unet_number=unet)\n unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss"
+ },
+ {
+ "comment": "This code is calculating the samples per second and resetting timers, then averaging the losses across all processes for a UNet model. It gathers the decay rate on each UNet, logs epoch, sample, and step information.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_decoder.py\":395-412",
+ "content": " samples_per_sec = (sample - last_sample) / timer.elapsed()\n timer.reset()\n last_sample = sample\n if i % TRAIN_CALC_LOSS_EVERY_ITERS == 0:\n # We want to average losses across all processes\n unet_all_losses = accelerator.gather(unet_losses_tensor)\n mask = unet_all_losses != 0\n unet_average_loss = (unet_all_losses * mask).sum(dim=0) / mask.sum(dim=0)\n loss_map = { f\"Unet {index} Training Loss\": loss.item() for index, loss in enumerate(unet_average_loss) if unet_training_mask[index] }\n # gather decay rate on each UNet\n ema_decay_list = {f\"Unet {index} EMA Decay\": ema_unet.get_current_decay() for index, ema_unet in enumerate(trainer.ema_unets) if unet_training_mask[index]}\n log_data = {\n \"Epoch\": epoch,\n \"Sample\": sample,\n \"Step\": i,"
+ },
+ {
+ "comment": "This code snippet is logging data and saving a snapshot of the model at specific intervals. It logs samples per second, samples seen, EMA decay parameters, and loss metrics. The snapshot is saved if the current sample meets certain conditions or every time an immediate save command is issued. The code prints \"Saving snapshot\" when a snapshot is taken.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_decoder.py\":413-430",
+ "content": " \"Samples per second\": samples_per_sec,\n \"Samples Seen\": samples_seen,\n **ema_decay_list,\n **loss_map\n }\n if is_master:\n tracker.log(log_data, step=step())\n if is_master and (last_snapshot + save_every_n_samples < sample or (save_immediately and i == 0)): # This will miss by some amount every time, but it's not a big deal... I hope\n # It is difficult to gather this kind of info on the accelerator, so we have to do it on the master\n print(\"Saving snapshot\")\n last_snapshot = sample\n # We need to know where the model should be saved\n save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen)\n if exists(n_sample_images) and n_sample_images > 0:\n trainer.eval()\n "
+ },
+ {
+ "comment": "This code is used for training a model and validating it. It generates samples from the training dataset, logs them, checks if it should stop based on sample count, switches to validation mode, and initializes variables for validation.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_decoder.py\":430-447",
+ "content": " train_images, train_captions = generate_grid_samples(trainer, train_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, \"Train: \")\n tracker.log_images(train_images, captions=train_captions, image_section=\"Train Samples\", step=step())\n if epoch_samples is not None and sample >= epoch_samples:\n break\n next_task = 'val'\n sample = 0\n all_average_val_losses = None\n if next_task == 'val':\n trainer.eval()\n accelerator.print(print_ribbon(f\"Starting Validation {epoch}\", repeat=40))\n last_val_sample = val_sample\n val_sample_length_tensor = torch.zeros(1, dtype=torch.int, device=inference_device)\n average_val_loss_tensor = torch.zeros(1, trainer.num_unets, dtype=torch.float, device=inference_device)\n timer = Timer()\n accelerator.wait_for_everyone()\n i = 0"
+ },
+ {
+ "comment": "This code is part of the DALLE2-pytorch training process. It iterates over the validation dataloader, gathers sample lengths, calculates total samples, and checks for image and text embeddings. If available, it sends these embeddings along with images to the device for further processing. This code ensures that all necessary data is properly prepared and sent to the device for evaluation.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_decoder.py\":448-466",
+ "content": " for i, (img, emb, txt) in enumerate(dataloaders['val']): # Use the accelerate prepared loader\n val_sample_length_tensor[0] = len(img)\n all_samples = accelerator.gather(val_sample_length_tensor)\n total_samples = all_samples.sum().item()\n val_sample += total_samples\n img_emb = emb.get('img')\n has_img_embedding = img_emb is not None\n if has_img_embedding:\n img_emb, = send_to_device((img_emb,))\n text_emb = emb.get('text')\n has_text_embedding = text_emb is not None\n if has_text_embedding:\n text_emb, = send_to_device((text_emb,))\n img, = send_to_device((img,))\n for unet in range(1, len(decoder.unets)+1):\n if not unet_training_mask[unet-1]: # Unet index is the unet number - 1\n # No need to evaluate an unchanging unet\n continue"
+ },
+ {
+ "comment": "This code segment checks if image and text embeddings are provided. If not, it automatically generates image embedding or passes the text instead based on the condition. It also asserts the number of texts should be equal to the number of images for consistency.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_decoder.py\":468-483",
+ "content": " forward_params = {}\n if has_img_embedding:\n forward_params['image_embed'] = img_emb.float()\n else:\n # Forward pass automatically generates embedding\n assert clip is not None\n img_embed, img_encoding = clip.embed_image(img)\n forward_params['image_embed'] = img_embed\n if condition_on_text_encodings:\n if has_text_embedding:\n forward_params['text_encodings'] = text_emb.float()\n else:\n # Then we need to pass the text instead\n assert clip is not None\n tokenized_texts = tokenize(txt, truncate=True).to(device=inference_device)\n assert tokenized_texts.shape[0] == len(img), f\"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})\""
+ },
+ {
+ "comment": "This code snippet is part of a larger model training process. It calculates the loss based on input images and text, updates the average validation loss, prints validation progress including samples per second and loss, and eventually breaks the loop when the specified number of validation samples have been processed. The code uses the PyTorch framework and the DALLE2 library for embedding text.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_decoder.py\":484-499",
+ "content": " text_embed, text_encodings = clip.embed_text(tokenized_texts)\n forward_params['text_encodings'] = text_encodings\n loss = trainer.forward(img.float(), **forward_params, unet_number=unet, _device=inference_device)\n average_val_loss_tensor[0, unet-1] += loss\n if i % VALID_CALC_LOSS_EVERY_ITERS == 0:\n samples_per_sec = (val_sample - last_val_sample) / timer.elapsed()\n timer.reset()\n last_val_sample = val_sample\n accelerator.print(f\"Epoch {epoch}/{epochs} Val Step {i} - Sample {val_sample} - {samples_per_sec:.2f} samples/sec\")\n accelerator.print(f\"Loss: {(average_val_loss_tensor / (i+1))}\")\n accelerator.print(\"\")\n if validation_samples is not None and val_sample >= validation_samples:\n break\n print(f\"Rank {accelerator.state.process_index} finished validation after {i} steps\")"
+ },
+ {
+ "comment": "This code is used for averaging the validation losses and logging them during training. It also starts the evaluation process if it's time to do so, printing a message to indicate this. The average_val_loss_tensor is gathered by the accelerator, and then the mean of all the average loss tensors is calculated if the current task is 'eval'. If there are no zeros in the unet_average_val_loss, the validation losses are logged.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_decoder.py\":500-514",
+ "content": " accelerator.wait_for_everyone()\n average_val_loss_tensor /= i+1\n # Gather all the average loss tensors\n all_average_val_losses = accelerator.gather(average_val_loss_tensor)\n if is_master:\n unet_average_val_loss = all_average_val_losses.mean(dim=0)\n val_loss_map = { f\"Unet {index} Validation Loss\": loss.item() for index, loss in enumerate(unet_average_val_loss) if loss != 0 }\n tracker.log(val_loss_map, step=step())\n next_task = 'eval'\n if next_task == 'eval':\n if exists(evaluate_config):\n accelerator.print(print_ribbon(f\"Starting Evaluation {epoch}\", repeat=40))\n evaluation = evaluate_trainer(trainer, dataloaders[\"val\"], inference_device, first_trainable_unet, last_trainable_unet, clip=clip, inference_device=inference_device, **evaluate_config.model_dump(), condition_on_text_encodings=condition_on_text_encodings, cond_scale=cond_scale)\n if is_master:"
+ },
+ {
+ "comment": "The code is generating sample images and saving the model if it is the master process. It prints a ribbon and then generates grid samples from both train and test example data, conditioning on text encodings. Finally, it logs the generated images using the tracker, with labels indicating whether they are test or train samples.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_decoder.py\":515-527",
+ "content": " tracker.log(evaluation, step=step())\n next_task = 'sample'\n val_sample = 0\n if next_task == 'sample':\n if is_master:\n # Generate examples and save the model if we are the master\n # Generate sample images\n print(print_ribbon(f\"Sampling Set {epoch}\", repeat=40))\n test_images, test_captions = generate_grid_samples(trainer, test_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, \"Test: \")\n train_images, train_captions = generate_grid_samples(trainer, train_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, \"Train: \")\n tracker.log_images(test_images, captions=test_captions, image_section=\"Test Samples\", step=step())\n tracker.log_images(train_images, captions=train_captions, image_section=\"Train Samples\", step=step())"
+ },
+ {
+ "comment": "The code checks if the average validation loss is lower than previous min, and saves the trainer if it's a new minimum. It's part of a function called create_tracker that creates a tracker object with accelerator, config, and dummy parameters.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_decoder.py\":529-545",
+ "content": " print(print_ribbon(f\"Starting Saving {epoch}\", repeat=40))\n is_best = False\n if all_average_val_losses is not None:\n average_loss = all_average_val_losses.mean(dim=0).sum() / sum(unet_training_mask)\n if len(validation_losses) == 0 or average_loss < min(validation_losses):\n is_best = True\n validation_losses.append(average_loss)\n save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen, is_best=is_best)\n next_task = 'train'\ndef create_tracker(accelerator: Accelerator, config: TrainDecoderConfig, config_path: str, dummy: bool = False) -> Tracker:\n tracker_config = config.tracker\n accelerator_config = {\n \"Distributed\": accelerator.distributed_type != accelerate_dataclasses.DistributedType.NO,\n \"DistributedType\": accelerator.distributed_type,\n \"NumProcesses\": accelerator.num_processes,\n \"MixedPrecision\": accelerator.mixed_precision"
+ },
+ {
+ "comment": "This code initializes distributed training for DALLE2, sets manual seed, and creates an accelerator for parallel processing with optional arguments. The function returns a tracker object to save configuration.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_decoder.py\":546-562",
+ "content": " }\n accelerator.wait_for_everyone() # If nodes arrive at this point at different times they might try to autoresume the current run which makes no sense and will cause errors\n tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy)\n tracker.save_config(config_path, config_name='decoder_config.json')\n tracker.add_save_metadata(state_dict_key='config', metadata=config.model_dump())\n return tracker\ndef initialize_training(config: TrainDecoderConfig, config_path):\n # Make sure if we are not loading, distributed models are initialized to the same values\n torch.manual_seed(config.seed)\n # Set up accelerator for configurable distributed training\n ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters, static_graph=config.train.static_graph)\n init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=60*60))\n accelerator = Accelerator(kwargs_handlers=[ddp_kwargs, init_kwargs])\n if accelerator.num_processes > 1:"
+ },
+ {
+ "comment": "This code snippet is part of a distributed training process where it checks the accelerator settings, data sharding, and creates dataloaders for training. It ensures all processes are connected, handles DeepSpeed mixed precision mode without learned variance, splits data shards evenly across processes, and finally creates the necessary dataloaders for the training process.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_decoder.py\":563-580",
+ "content": " # We are using distributed training and want to immediately ensure all can connect\n accelerator.print(\"Waiting for all processes to connect...\")\n accelerator.wait_for_everyone()\n accelerator.print(\"All processes online and connected\")\n # If we are in deepspeed fp16 mode, we must ensure learned variance is off\n if accelerator.mixed_precision == \"fp16\" and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED and config.decoder.learned_variance:\n raise ValueError(\"DeepSpeed fp16 mode does not support learned variance\")\n # Set up data\n all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))\n world_size = accelerator.num_processes\n rank = accelerator.process_index\n shards_per_process = len(all_shards) // world_size\n assert shards_per_process > 0, \"Not enough shards to split evenly\"\n my_shards = all_shards[rank * shards_per_process: (rank + 1) * shards_per_process]\n dataloaders = create_dataloaders ("
+ },
+ {
+ "comment": "The code initializes the decoder model with specified parameters, removes clip if present for compatibility, and creates a tracker if the current rank is not the master. It also calculates the number of parameters in the model and prepares it for training.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_decoder.py\":581-602",
+ "content": " available_shards=my_shards,\n img_preproc = config.data.img_preproc,\n train_prop = config.data.splits.train,\n val_prop = config.data.splits.val,\n test_prop = config.data.splits.test,\n n_sample_images=config.train.n_sample_images,\n **config.data.model_dump(),\n rank = rank,\n seed = config.seed,\n )\n # If clip is in the model, we need to remove it for compatibility with deepspeed\n clip = None\n if config.decoder.clip is not None:\n clip = config.decoder.clip.create() # Of course we keep it to use it during training, just not in the decoder as that causes issues\n config.decoder.clip = None\n # Create the decoder model and print basic info\n decoder = config.decoder.create()\n get_num_parameters = lambda model, only_training=False: sum(p.numel() for p in model.parameters() if (p.requires_grad or not only_training))\n # Create and initialize the tracker if we are the master\n tracker = create_tracker(accelerator, config, config_path, dummy = rank!=0)"
+ },
+ {
+ "comment": "This code checks if image and/or text embeddings are available, either precomputed or generated using CLIP model. It then prints a message indicating the source of embeddings used for training.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_decoder.py\":604-626",
+ "content": " has_img_embeddings = config.data.img_embeddings_url is not None\n has_text_embeddings = config.data.text_embeddings_url is not None\n conditioning_on_text = any([unet.cond_on_text_encodings for unet in config.decoder.unets])\n has_clip_model = clip is not None\n data_source_string = \"\"\n if has_img_embeddings:\n data_source_string += \"precomputed image embeddings\"\n elif has_clip_model:\n data_source_string += \"clip image embeddings generation\"\n else:\n raise ValueError(\"No image embeddings source specified\")\n if conditioning_on_text:\n if has_text_embeddings:\n data_source_string += \" and precomputed text embeddings\"\n elif has_clip_model:\n data_source_string += \" and clip text encoding generation\"\n else:\n raise ValueError(\"No text embeddings source specified\")\n accelerator.print(print_ribbon(\"Loaded Config\", repeat=40))\n accelerator.print(f\"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training\")"
+ },
+ {
+ "comment": "Training of the decoder is being executed using the specified data source, with or without conditioning on text. The number of parameters in total and for training are displayed, along with similar information for each Unet. The train function is called with dataloaders, decoder, accelerator, clip, tracker, inference_device, evaluate_config, and condition_on_text_encodings as arguments. A simple click command line interface is created to load the config and start training, using a default configuration file path and allowing for an alternative path to be specified with the --config_file option.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_decoder.py\":627-646",
+ "content": " accelerator.print(f\"Training using {data_source_string}. {'conditioned on text' if conditioning_on_text else 'not conditioned on text'}\")\n accelerator.print(f\"Number of parameters: {get_num_parameters(decoder)} total; {get_num_parameters(decoder, only_training=True)} training\")\n for i, unet in enumerate(decoder.unets):\n accelerator.print(f\"Unet {i} has {get_num_parameters(unet)} total; {get_num_parameters(unet, only_training=True)} training\")\n train(dataloaders, decoder, accelerator,\n clip=clip,\n tracker=tracker,\n inference_device=accelerator.device,\n evaluate_config=config.evaluate,\n condition_on_text_encodings=conditioning_on_text,\n **config.train.model_dump(),\n )\n# Create a simple click command line interface to load the config and start the training\n@click.command()\n@click.option(\"--config_file\", default=\"./train_decoder_config.json\", help=\"Path to config file\")\ndef main(config_file):\n config_file_path = Path(config_file)\n config = TrainDecoderConfig.from_json_path(str(config_file_path))"
+ },
+ {
+ "comment": "This code snippet initializes training and then calls the main function if the script is run directly. It ensures proper execution when running the script as a standalone program.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/train_decoder.py\":647-650",
+ "content": " initialize_training(config, config_path=config_file_path)\nif __name__ == \"__main__\":\n main()"
+ }
+ ]
+}
\ No newline at end of file
diff --git a/docs/doc/925c6747-a25b-403d-9bae-5573ba4ac7b0.json b/docs/doc/925c6747-a25b-403d-9bae-5573ba4ac7b0.json
new file mode 100644
index 00000000..a7703b5d
--- /dev/null
+++ b/docs/doc/925c6747-a25b-403d-9bae-5573ba4ac7b0.json
@@ -0,0 +1,50 @@
+{
+ "summary": "This code offers efficient data retrieval classes for DALL-E 2, supports text conditioning and MPI distribution. It divides embedding reader objects into training, evaluation, and test sets using PyTorch Dataloaders, without specifying batch sizes.",
+ "details": [
+ {
+ "comment": "The code defines a class called PriorEmbeddingDataset that wraps the EmbeddingReader class. It allows for simplified sample retrieval from various configurations of EmbeddingReader by enabling batch-based access to prior data, where text_conditioned and batch_size are parameters, along with start and stop indices for the range of data to be loaded.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dataloaders/prior_loader.py\":0-39",
+ "content": "from math import ceil\nfrom clip import tokenize\nfrom embedding_reader import EmbeddingReader\nfrom torch import from_numpy\nfrom torch.utils.data import IterableDataset, DataLoader\nclass PriorEmbeddingDataset(IterableDataset):\n \"\"\"\n PriorEmbeddingDataset is a wrapper of EmbeddingReader.\n It enables one to simplify the logic necessary to yield samples from\n the different EmbeddingReader configurations available.\n \"\"\"\n def __init__(\n self,\n text_conditioned: bool,\n batch_size: int,\n start: int,\n stop: int,\n image_reader,\n text_reader: EmbeddingReader = None,\n ) -> None:\n super(PriorEmbeddingDataset).__init__()\n self.text_conditioned = text_conditioned\n if not self.text_conditioned:\n self.text_reader = text_reader\n self.image_reader = image_reader\n self.start = start\n self.stop = stop\n self.batch_size = batch_size\n def __len__(self):\n return self.stop - self.start\n def __iter__(self):"
+ },
+ {
+ "comment": "The code defines a PriorEmbeddingDataset class for data loading in DALLE2-pytorch. It uses an image_reader and text_reader to load data in a batch, with optional text conditioning. It includes a __next__ method for iterating through the dataset and a set_start method for adjusting the starting point within the reader.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dataloaders/prior_loader.py\":40-71",
+ "content": " # D.R.Y loader args\n loader_args = dict(\n batch_size=self.batch_size,\n start=self.start,\n end=self.stop,\n show_progress=False,\n )\n # if the data requested is text conditioned, only load images\n if self.text_conditioned:\n self.loader = self.image_reader(**loader_args)\n # otherwise, include text embeddings and bypass metadata\n else:\n self.loader = zip(\n self.image_reader(**loader_args), self.text_reader(**loader_args)\n )\n # return the data loader in its formatted state\n return self\n def __next__(self):\n try:\n return self.get_sample()\n except StopIteration:\n raise StopIteration\n def __str__(self):\n return f\"\"\n def set_start(self, start):\n \"\"\"\n Adjust the starting point within the reader, useful for resuming an epoch"
+ },
+ {
+ "comment": "This code defines a class with methods to manage data loading and distribution for the DALL-E 2 model. It supports text-conditioned or unconditioned data, preprocesses input into a common format, and distributes data across multiple ranks using MPI.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dataloaders/prior_loader.py\":72-111",
+ "content": " \"\"\"\n self.start = start\n def get_start(self):\n return self.start\n def get_sample(self):\n \"\"\"\n pre-proocess data from either reader into a common format\n \"\"\"\n if self.text_conditioned:\n image_embedding, caption = next(self.loader)\n image_embedding = from_numpy(image_embedding)\n tokenized_caption = tokenize(caption[\"caption\"].to_list(), truncate=True)\n return image_embedding, tokenized_caption\n else:\n (image_embedding, _), (text_embedding, _) = next(self.loader)\n image_embedding = from_numpy(image_embedding)\n text_embedding = from_numpy(text_embedding)\n return image_embedding, text_embedding\n# helper functions\ndef distribute_to_rank(start, stop, rank, world_size):\n \"\"\"\n Distribute data to each rank given the world size.\n Return:\n - New start and stop points for this rank.\n \"\"\"\n num_samples = int(stop - start)\n per_rank = int(ceil((num_samples) / float(world_size)))"
+ },
+ {
+ "comment": "The code is defining functions that calculate the start and stop points for a given rank, and another function to create an EmbeddingReader object based on URLs. It asserts that certain inputs are not None before proceeding, ensuring necessary information is provided.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dataloaders/prior_loader.py\":113-148",
+ "content": " assert (\n per_rank > 0\n ), f\"Number of samples per rank must be larger than 0, (found: {per_rank})\"\n rank_start = start + rank * per_rank\n rank_stop = min(rank_start + per_rank, stop)\n new_length = rank_stop - rank_start\n assert (\n new_length > 0\n ), \"Calculated start and stop points result in a length of zero for this rank.\"\n return rank_start, rank_stop\ndef get_reader(\n text_conditioned: bool, img_url: str, meta_url: str = None, txt_url: str = None\n):\n \"\"\"\n Create an EmbeddingReader object from the specified URLs\n get_reader() will always expect a url to image embeddings.\n If text-conditioned, it will also expect a meta_url for the captions.\n Otherwise, it will need txt_url for the matching text embeddings.\n Returns an image_reader object if text-conditioned.\n Otherwise it returns both an image_reader and a text_reader\n \"\"\"\n assert img_url is not None, \"Must supply a image url\"\n if text_conditioned:\n assert meta_url is not None, \"Must supply meta url if text-conditioned\""
+ },
+ {
+ "comment": "This code defines a function to split an embedding reader object into training, evaluation, and optional test sets. It takes in the text conditioned flag, batch size, number of data points, and train/eval splits as input parameters. If text-conditioning is not enabled, it requires text embedding URLs as well and returns two readers.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dataloaders/prior_loader.py\":150-186",
+ "content": " image_reader = EmbeddingReader(\n embeddings_folder=img_url,\n file_format=\"parquet_npy\",\n # will assume the caption column exists and is the only one requested\n meta_columns=[\"caption\"],\n metadata_folder=meta_url,\n )\n return image_reader\n # otherwise we will require text embeddings as well and return two readers\n assert (\n txt_url is not None\n ), \"Must supply text embedding url if not text-conditioning\"\n image_reader = EmbeddingReader(img_url, file_format=\"npy\")\n text_reader = EmbeddingReader(txt_url, file_format=\"npy\")\n return image_reader, text_reader\ndef make_splits(\n text_conditioned: bool,\n batch_size: int,\n num_data_points: int,\n train_split: float,\n eval_split: float,\n image_reader: EmbeddingReader,\n text_reader: EmbeddingReader = None,\n start=0,\n rank=0,\n world_size=1,\n):\n \"\"\"\n Split an embedding reader object as needed.\n NOTE: make_splits() will infer the test set size from your train and eval."
+ },
+ {
+ "comment": "This function takes various inputs like batch size, train and eval splits, readers, and starting point to create PyTorch Dataloaders for image-text pairs. It ensures the start position is within the reader's count, and if the specified data points count exceeds the available ones, it defaults to the remaining count.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dataloaders/prior_loader.py\":188-209",
+ "content": " Input:\n - text_conditioned: whether to prepare text-conditioned training data\n - batch_size: the batch size for a single gpu\n - num_data_points: the total number of data points you wish to train on\n - train_split: the percentage of data you wish to train on\n - eval_split: the percentage of data you wish to validate on\n - image_reader: the image_reader you wish to split\n - text_reader: the text_reader you want to split (if !text_conditioned)\n - start: the starting point within your dataset\n - rank: the rank of your worker\n - world_size: the total world size of your distributed training run\n Returns:\n - PyTorch Dataloaders that yield tuples of (img, txt) data.\n \"\"\"\n assert start < image_reader.count, \"start position cannot exceed reader count.\"\n # verify that the num_data_points does not exceed the max points\n if num_data_points > (image_reader.count - start):\n print(\n \"Specified count is larger than what's available...defaulting to reader's count.\""
+ },
+ {
+ "comment": "Computing split points for training and evaluation data sets based on the specified splits. Distributing the data to ranks according to the world size. Wrapping up the splits into a dictionary with start, stop, and batch_size parameters.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dataloaders/prior_loader.py\":210-241",
+ "content": " )\n num_data_points = image_reader.count\n # compute split points\n train_set_size = int(train_split * num_data_points)\n eval_set_size = int(eval_split * num_data_points)\n eval_start = train_set_size\n eval_stop = int(eval_start + eval_set_size)\n assert (\n train_split + eval_split\n ) < 1.0, \"Specified train and eval split is too large to infer a test split.\"\n # distribute to rank\n rank_train_start, rank_train_stop = distribute_to_rank(\n start, train_set_size, rank, world_size\n )\n rank_eval_start, rank_eval_stop = distribute_to_rank(\n train_set_size, eval_stop, rank, world_size\n )\n rank_test_start, rank_test_stop = distribute_to_rank(\n eval_stop, num_data_points, rank, world_size\n )\n # wrap up splits into a dict\n train_split_args = dict(\n start=rank_train_start, stop=rank_train_stop, batch_size=batch_size\n )\n eval_split_args = dict(\n start=rank_eval_start, stop=rank_eval_stop, batch_size=batch_size\n )\n test_split_args = dict("
+ },
+ {
+ "comment": "Code is creating a PriorEmbeddingDataset for train, validation, and test datasets based on given arguments. If text_conditioned, it creates separate dictionaries for each dataset and passes them to the PriorEmbeddingDataset class; otherwise, it adds additional non-conditioned arguments for the same process.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dataloaders/prior_loader.py\":242-270",
+ "content": " start=rank_test_start, stop=rank_test_stop, batch_size=batch_size\n )\n if text_conditioned:\n # add the text-conditioned args to a unified dict\n reader_args = dict(\n text_conditioned=text_conditioned,\n image_reader=image_reader,\n )\n train_split_args = dict(**reader_args, **train_split_args)\n eval_split_args = dict(**reader_args, **eval_split_args)\n test_split_args = dict(**reader_args, **test_split_args)\n train = PriorEmbeddingDataset(**train_split_args)\n val = PriorEmbeddingDataset(**eval_split_args)\n test = PriorEmbeddingDataset(**test_split_args)\n else:\n # add the non-conditioned args to a unified dict\n reader_args = dict(\n text_conditioned=text_conditioned,\n image_reader=image_reader,\n text_reader=text_reader,\n )\n train_split_args = dict(**reader_args, **train_split_args)\n eval_split_args = dict(**reader_args, **eval_split_args)\n test_split_args = dict(**reader_args, **test_split_args)"
+ },
+ {
+ "comment": "This code creates train, val, and test datasets using PriorEmbeddingDataset with specific args. DataLoaders are created without specifying batch sizes, so the true batch size is determined in PriorEmbeddingDataset. The loaders and datasets are returned for further processing.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dataloaders/prior_loader.py\":272-281",
+ "content": " train = PriorEmbeddingDataset(**train_split_args)\n val = PriorEmbeddingDataset(**eval_split_args)\n test = PriorEmbeddingDataset(**test_split_args)\n # true batch size is specifed in the PriorEmbeddingDataset\n train_loader = DataLoader(train, batch_size=None)\n eval_loader = DataLoader(val, batch_size=None)\n test_loader = DataLoader(test, batch_size=None)\n return train_loader, eval_loader, test_loader"
+ }
+ ]
+}
\ No newline at end of file
diff --git a/docs/doc/9b4f39a6-6758-4baa-abe2-20e5b8f4dbe6.json b/docs/doc/9b4f39a6-6758-4baa-abe2-20e5b8f4dbe6.json
new file mode 100644
index 00000000..1d21fd6e
--- /dev/null
+++ b/docs/doc/9b4f39a6-6758-4baa-abe2-20e5b8f4dbe6.json
@@ -0,0 +1,220 @@
+{
+ "summary": "The code enhances DALL-E 2 with additional layers, unconditional generation, and prior models. It provides installation guidance for saving generated images and inpainting through Latent Diffusion. The accompanying code snippet offers BibTeX entries for four research articles published between 2021 and 2022.",
+ "details": [
+ {
+ "comment": "This code is for an implementation of DALL-E 2, OpenAI's text-to-image synthesis neural network, in Pytorch. It includes a link to a Yannic Kilcher summary and AssemblyAI explainer. The main novelty is an extra layer using the prior network, predicting an image embedding from text embedding. This repository builds out the diffusion prior network, using a causal transformer as the denoising network. It currently holds the SOTA for text-to-image synthesis.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":0-12",
+ "content": "\n## DALL-E 2 - Pytorch\nImplementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch.\nYannic Kilcher summary | AssemblyAI explainer\nThe main novelty seems to be an extra layer of indirection with the prior network (whether it is an autoregressive transformer or a diffusion network), which predicts an image embedding based on the text embedding from CLIP. Specifically, this repository will only build out the diffusion prior network, as it is the best performing variant (but which incidentally involves a causal transformer as the denoising network \ud83d\ude02)\nThis model is SOTA for text-to-image for now.\nPlease join if you are interested in helping out with the replication with the LAION community | Yannic Interview\nAs of 5/23/22, it is no longer SOTA. SOTA will be here. Jax versions as well as text-to-video project will be shifted towards the Imagen architecture, as it is way simpler.\n## Status\n- A research group has used the code in this repository to train a functional diffusion prior for their CLIP generations. Will share their work once they release their preprint. This, and Katherine's own experiments, validate OpenAI's finding that the extra prior increases variety of generations.\n- Decoder is now verified working for unconditional generation on my experimental setup for Oxford flowers. 2 researchers have also confirmed Decoder is working for them.\n\n*ongoing at 21k steps*"
+ },
+ {
+ "comment": "Justin Pinkney successfully trained the diffusion prior for his CLIP to Stylegan2 text-to-image application. Romain scaled up training to 800 GPUs with existing scripts without any issues. LAION is training prior models, available on HuggingFace and WANDB. Decoder testing runs are ongoing. DALL-E 2 repository by LAION is under development.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":26-35",
+ "content": "- Justin Pinkney successfully trained the diffusion prior in the repository for his CLIP to Stylegan2 text-to-image application\n- Romain has scaled up training to 800 GPUs with the available scripts without any issues\n## Pre-Trained Models\n- LAION is training prior models. Checkpoints are available on \ud83e\udd17huggingface and the training statistics are available on \ud83d\udc1dWANDB.\n- Decoder - In-progress test run \ud83d\udea7\n- Decoder - Another test run with sparse attention\n- DALL-E 2 \ud83d\udea7 - DALL-E 2 Laion repository"
+ },
+ {
+ "comment": "This code block expresses gratitude to the contributors who assisted in developing and improving this library, acknowledging their efforts for distributed training code, bug fixes, Q&A support, and project management.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":37-47",
+ "content": "## Appreciation\nThis library would not have gotten to this working state without the help of\n- Zion for the distributed training code for the diffusion prior\n- Aidan for the distributed training code for the decoder as well as the dataloaders\n- Kumar for working on the initial diffusion training script\n- Romain for the pull request reviews and project management\n- He Cao and xiankgx for the Q&A and for identifying of critical bugs\n- Marunine for identifying issues with resizing of the low resolution conditioner, when training the upsampler, in addition to various other bug fixes\n- MalumaDev for proposing the use of pixel shuffle upsampler for fixing checkboard artifacts"
+ },
+ {
+ "comment": "Acknowledgments to Katherine, Stability AI, HuggingFace (Sylvain), and Alex for their contributions; installation instructions with pip command; usage notes mentioning CLIP training, x-clip package, and LAION discord; repository integration with `x-clip` mentioned.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":48-69",
+ "content": "- Katherine for her advice\n- Stability AI for the generous sponsorship\n- \ud83e\udd17 Huggingface and in particular Sylvain for the Accelerate library\n- Alex for einops, indispensable tool for tensor manipulation\n... and many others. Thank you! \ud83d\ude4f\n## Install\n```bash\n$ pip install dalle2-pytorch\n```\n## Usage\nTo train DALLE-2 is a 3 step process, with the training of CLIP being the most important\nTo train CLIP, you can either use x-clip package, or join the LAION discord, where a lot of replication efforts are already underway.\nThis repository will demonstrate integration with `x-clip` for starters\n```python"
+ },
+ {
+ "comment": "The code imports necessary libraries and initializes a CLIP model with specific dimensions for text, image, and latent embeddings. It includes various settings such as token counts, encoding depths, image sizes, heads, and learning techniques (FILIP, DCL, CLOOB, DeCLIP, SLIP). It also indicates whether to use masked language learning on text (MLM) or not.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":70-90",
+ "content": "import torch\nfrom dalle2_pytorch import CLIP\nclip = CLIP(\n dim_text = 512,\n dim_image = 512,\n dim_latent = 512,\n num_text_tokens = 49408,\n text_enc_depth = 1,\n text_seq_len = 256,\n text_heads = 8,\n visual_enc_depth = 1,\n visual_image_size = 256,\n visual_patch_size = 32,\n visual_heads = 8,\n use_all_token_embeds = True, # whether to use fine-grained contrastive learning (FILIP)\n decoupled_contrastive_learning = True, # use decoupled contrastive learning (DCL) objective function, removing positive pairs from the denominator of the InfoNCE loss (CLOOB + DCL)\n extra_latent_projection = True, # whether to use separate projections for text-to-image vs image-to-text comparisons (CLOOB)\n use_visual_ssl = True, # whether to do self supervised learning on images\n visual_ssl_type = 'simclr', # can be either 'simclr' or 'simsiam', depending on using DeCLIP or SLIP\n use_mlm = False, # use masked language learning (MLM) on text (DeCLIP)"
+ },
+ {
+ "comment": "The code snippet initializes a CLIP model, sets text and image self-supervised loss weights, generates mock data for training, computes the contrastive loss, backpropagates gradients, and trains the decoder using a Unet architecture. The trained CLIP from step 1 is used in this step to train the decoder.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":91-135",
+ "content": " text_ssl_loss_weight = 0.05, # weight for text MLM loss\n image_ssl_loss_weight = 0.05 # weight for image self-supervised learning loss\n).cuda()\n# mock data\ntext = torch.randint(0, 49408, (4, 256)).cuda()\nimages = torch.randn(4, 3, 256, 256).cuda()\n# train\nloss = clip(\n text,\n images,\n return_loss = True # needs to be set to True to return contrastive loss\n)\nloss.backward()\n# do the above with as many texts and images as possible in a loop\n```\nThen, you will need to train the decoder, which learns to generate images based on the image embedding coming from the trained CLIP above\n```python\nimport torch\nfrom dalle2_pytorch import Unet, Decoder, CLIP\n# trained clip from step 1\nclip = CLIP(\n dim_text = 512,\n dim_image = 512,\n dim_latent = 512,\n num_text_tokens = 49408,\n text_enc_depth = 1,\n text_seq_len = 256,\n text_heads = 8,\n visual_enc_depth = 1,\n visual_image_size = 256,\n visual_patch_size = 32,\n visual_heads = 8\n).cuda()\n# unet for the decoder"
+ },
+ {
+ "comment": "In this code, a U-Net model is created using the provided configuration and then placed on the GPU. A decoder is also created, containing the U-Net and CLIP models, with specific parameters for timesteps, image, and text drop probabilities. The decoder generates images based on CLIP image embeddings after going through many steps of training. Finally, a trained CLIP model from step one is imported for use in the diffusion prior network.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":137-180",
+ "content": "unet = Unet(\n dim = 128,\n image_embed_dim = 512,\n cond_dim = 128,\n channels = 3,\n dim_mults=(1, 2, 4, 8)\n).cuda()\n# decoder, which contains the unet and clip\ndecoder = Decoder(\n unet = unet,\n clip = clip,\n timesteps = 100,\n image_cond_drop_prob = 0.1,\n text_cond_drop_prob = 0.5\n).cuda()\n# mock images (get a lot of this)\nimages = torch.randn(4, 3, 256, 256).cuda()\n# feed images into decoder\nloss = decoder(images)\nloss.backward()\n# do the above for many many many many steps\n# then it will learn to generate images based on the CLIP image embeddings\n```\nFinally, the main contribution of the paper. The repository offers the diffusion prior network. It takes the CLIP text embeddings and tries to generate the CLIP image embeddings. Again, you will need the trained CLIP from the first step\n```python\nimport torch\nfrom dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, CLIP\n# get trained CLIP from step one\nclip = CLIP(\n dim_text = 512,\n dim_image = 512,\n dim_latent = 512,\n num_text_tokens = 49408,"
+ },
+ {
+ "comment": "The code sets up a diffusion prior network for generating image embeddings from text embeddings using PyTorch. The network is composed of an autoregressive transformer, CLIP model, and other layers. It also includes a prior_network, random data, and losses are calculated by feeding the text and images into the diffusion prior network before backpropagation. This process is repeated many times to train the network.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":181-222",
+ "content": " text_enc_depth = 6,\n text_seq_len = 256,\n text_heads = 8,\n visual_enc_depth = 6,\n visual_image_size = 256,\n visual_patch_size = 32,\n visual_heads = 8,\n).cuda()\n# setup prior network, which contains an autoregressive transformer\nprior_network = DiffusionPriorNetwork(\n dim = 512,\n depth = 6,\n dim_head = 64,\n heads = 8\n).cuda()\n# diffusion prior network, which contains the CLIP and network (with transformer) above\ndiffusion_prior = DiffusionPrior(\n net = prior_network,\n clip = clip,\n timesteps = 100,\n cond_drop_prob = 0.2\n).cuda()\n# mock data\ntext = torch.randint(0, 49408, (4, 256)).cuda()\nimages = torch.randn(4, 3, 256, 256).cuda()\n# feed text and images into diffusion prior network\nloss = diffusion_prior(text, images)\nloss.backward()\n# do the above for many many many steps\n# now the diffusion prior can generate image embeddings from the text embeddings\n```\nIn the paper, they actually used a recently discovered technique,"
+ },
+ {
+ "comment": "This code imports necessary modules and initializes a CLIP model, two UNETs for the decoder, and a decoder itself. The CLIP model is trained from a previous step, while the two UNETs are initialized with different dimensions for cascading DDPMs. The decoder contains the CLIP model and both UNETs.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":222-268",
+ "content": " from Jonathan Ho himself (original author of DDPMs, the core technique used in DALL-E v2) for high resolution image synthesis.\nThis can easily be used within this framework as so\n```python\nimport torch\nfrom dalle2_pytorch import Unet, Decoder, CLIP\n# trained clip from step 1\nclip = CLIP(\n dim_text = 512,\n dim_image = 512,\n dim_latent = 512,\n num_text_tokens = 49408,\n text_enc_depth = 6,\n text_seq_len = 256,\n text_heads = 8,\n visual_enc_depth = 6,\n visual_image_size = 256,\n visual_patch_size = 32,\n visual_heads = 8\n).cuda()\n# 2 unets for the decoder (a la cascading DDPM)\nunet1 = Unet(\n dim = 32,\n image_embed_dim = 512,\n cond_dim = 128,\n channels = 3,\n dim_mults = (1, 2, 4, 8)\n).cuda()\nunet2 = Unet(\n dim = 32,\n image_embed_dim = 512,\n cond_dim = 128,\n channels = 3,\n dim_mults = (1, 2, 4, 8, 16)\n).cuda()\n# decoder, which contains the unet(s) and clip\ndecoder = Decoder(\n clip = clip,\n unet = (unet1, unet2), "
+ },
+ {
+ "comment": "The code inserts two U-Nets into a decoder model in ascending order of resolution. The images are generated using a specified number of U-Nets, and the loss is calculated for each U-Net separately. Finally, a trained `DiffusionPrior` and `Decoder` (wrapping `CLIP`, a causal transformer, and unet(s)) are inserted to generate DALL-E2 images from text.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":268-297",
+ "content": " # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)\n image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)\n timesteps = 1000,\n image_cond_drop_prob = 0.1,\n text_cond_drop_prob = 0.5\n).cuda()\n# mock images (get a lot of this)\nimages = torch.randn(4, 3, 512, 512).cuda()\n# feed images into decoder, specifying which unet you want to train\n# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme\nloss = decoder(images, unet_number = 1)\nloss.backward()\nloss = decoder(images, unet_number = 2)\nloss.backward()\n# do the above for many steps for both unets\n```\nFinally, to generate the DALL-E2 images from text. Insert the trained `DiffusionPrior` as well as the `Decoder` (which wraps `CLIP`, the causal transformer, and unet(s))\n```python\nfrom dalle2_pytorch import DALLE2\ndalle2 = DALLE2(\n prior = diffusion_prior,"
+ },
+ {
+ "comment": "This code is importing necessary modules and creating an instance of DALLE2 model. It then generates images from input text using the model, and performs training on the model by calculating loss and performing backpropagation for multiple steps. The code also creates a prior network with specified dimensions and depth.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":298-352",
+ "content": " decoder = decoder\n)\n# send the text as a string if you want to use the simple tokenizer from DALLE v1\n# or you can do it as token ids, if you have your own tokenizer\ntexts = ['glistening morning dew on a flower petal']\nimages = dalle2(texts) # (1, 3, 256, 256)\n```\nThat's it!\nLet's see the whole script below\n```python\nimport torch\nfrom dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, CLIP\nclip = CLIP(\n dim_text = 512,\n dim_image = 512,\n dim_latent = 512,\n num_text_tokens = 49408,\n text_enc_depth = 6,\n text_seq_len = 256,\n text_heads = 8,\n visual_enc_depth = 6,\n visual_image_size = 256,\n visual_patch_size = 32,\n visual_heads = 8\n).cuda()\n# mock data\ntext = torch.randint(0, 49408, (4, 256)).cuda()\nimages = torch.randn(4, 3, 256, 256).cuda()\n# train\nloss = clip(\n text,\n images,\n return_loss = True\n)\nloss.backward()\n# do above for many steps ...\n# prior networks (with transformer)\nprior_network = DiffusionPriorNetwork(\n dim = 512,\n depth = 6,\n dim_head = 64,"
+ },
+ {
+ "comment": "The code sets up a DALLE-like model using PyTorch, with two Unets for the decoder and trains it by iteratively calculating the loss. The model uses diffusion prior and is conditioned on both text and image encodings. It has different dimensions and timesteps for each Unet. The code is used to train an AI image generation model.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":353-399",
+ "content": " heads = 8\n).cuda()\ndiffusion_prior = DiffusionPrior(\n net = prior_network,\n clip = clip,\n timesteps = 1000,\n sample_timesteps = 64,\n cond_drop_prob = 0.2\n).cuda()\nloss = diffusion_prior(text, images)\nloss.backward()\n# do above for many steps ...\n# decoder (with unet)\nunet1 = Unet(\n dim = 128,\n image_embed_dim = 512,\n text_embed_dim = 512,\n cond_dim = 128,\n channels = 3,\n dim_mults=(1, 2, 4, 8),\n cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings\n).cuda()\nunet2 = Unet(\n dim = 16,\n image_embed_dim = 512,\n cond_dim = 128,\n channels = 3,\n dim_mults = (1, 2, 4, 8, 16)\n).cuda()\ndecoder = Decoder(\n unet = (unet1, unet2),\n image_sizes = (128, 256),\n clip = clip,\n timesteps = 100,\n image_cond_drop_prob = 0.1,\n text_cond_drop_prob = 0.5\n).cuda()\nfor unet_number in (1, 2):\n loss = decoder(images, text = text, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much"
+ },
+ {
+ "comment": "The code is initializing a DALLE2 model with specified prior and decoder, generating images from input text using classifier-free guidance (with conditional scale 2), and then saving the generated image of size 256x256. The code also mentions that training will be automated into a CLI tool for small-scale training and that preprocessing images and text into embeddings might be required for scaling up.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":400-430",
+ "content": " loss.backward()\n# do above for many steps\ndalle2 = DALLE2(\n prior = diffusion_prior,\n decoder = decoder\n)\nimages = dalle2(\n ['cute puppy chasing after a squirrel'],\n cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)\n)\n# save your image (in this example, of size 256x256)\n```\nEverything in this readme should run without error\nYou can also train the decoder on images of greater than the size (say 512x512) at which CLIP was trained (256x256). The images will be resized to CLIP image resolution for the image embeddings\nFor the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.\n## Training on Preprocessed CLIP Embeddings\nIt is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings`\nWorking example below\n```python\nimport torch"
+ },
+ {
+ "comment": "This code is importing modules, initializing a trained CLIP model and setting up a diffusion prior network containing an autoregressive transformer. The diffusion prior contains both the CLIP and network. Mock data is then created for testing purposes.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":431-473",
+ "content": "from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, CLIP\n# get trained CLIP from step one\nclip = CLIP(\n dim_text = 512,\n dim_image = 512,\n dim_latent = 512,\n num_text_tokens = 49408,\n text_enc_depth = 6,\n text_seq_len = 256,\n text_heads = 8,\n visual_enc_depth = 6,\n visual_image_size = 256,\n visual_patch_size = 32,\n visual_heads = 8,\n).cuda()\n# setup prior network, which contains an autoregressive transformer\nprior_network = DiffusionPriorNetwork(\n dim = 512,\n depth = 6,\n dim_head = 64,\n heads = 8\n).cuda()\n# diffusion prior network, which contains the CLIP and network (with transformer) above\ndiffusion_prior = DiffusionPrior(\n net = prior_network,\n clip = clip,\n timesteps = 100,\n cond_drop_prob = 0.2,\n condition_on_text_encodings = False # this probably should be true, but just to get Laion started\n).cuda()\n# mock data\ntext = torch.randint(0, 49408, (4, 256)).cuda()\nimages = torch.randn(4, 3, 256, 256).cuda()\n# precompute the text and image embeddings"
+ },
+ {
+ "comment": "The code initializes a diffusion prior network with an autoregressive transformer and uses CLIP for image and text embeddings. It then calculates the loss by feeding the embeddings into the diffusion prior network, backpropagates the gradients, and repeats this process multiple times. Alternatively, CLIP can be excluded from the model initialization by passing the `image_embed_dim` directly to the `DiffusionPrior` class.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":474-509",
+ "content": "# here using the diffusion prior class, but could be done with CLIP alone\nclip_image_embeds = diffusion_prior.clip.embed_image(images).image_embed\nclip_text_embeds = diffusion_prior.clip.embed_text(text).text_embed\n# feed text and images into diffusion prior network\nloss = diffusion_prior(\n text_embed = clip_text_embeds,\n image_embed = clip_image_embeds\n)\nloss.backward()\n# do the above for many many many steps\n# now the diffusion prior can generate image embeddings from the text embeddings\n```\nYou can also completely go `CLIP`-less, in which case you will need to pass in the `image_embed_dim` into the `DiffusionPrior` on initialization\n```python\nimport torch\nfrom dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior\n# setup prior network, which contains an autoregressive transformer\nprior_network = DiffusionPriorNetwork(\n dim = 512,\n depth = 6,\n dim_head = 64,\n heads = 8\n).cuda()\n# diffusion prior network, which contains the CLIP and network (with transformer) above\ndiffusion_prior = DiffusionPrior("
+ },
+ {
+ "comment": "The code snippet is creating a diffusion model using the provided parameters and utilizing the OpenAI CLIP for image and text embeddings. The text and image embeddings are precomputed, then fed into the diffusion prior network to calculate loss and perform backpropagation. This process is repeated many times to train the model for generating image embeddings from text embeddings.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":510-543",
+ "content": " net = prior_network,\n image_embed_dim = 512, # this needs to be set\n timesteps = 100,\n cond_drop_prob = 0.2,\n condition_on_text_encodings = False # this probably should be true, but just to get Laion started\n).cuda()\n# mock data\ntext = torch.randint(0, 49408, (4, 256)).cuda()\nimages = torch.randn(4, 3, 256, 256).cuda()\n# precompute the text and image embeddings\n# here using the diffusion prior class, but could be done with CLIP alone\nclip_image_embeds = torch.randn(4, 512).cuda()\nclip_text_embeds = torch.randn(4, 512).cuda()\n# feed text and images into diffusion prior network\nloss = diffusion_prior(\n text_embed = clip_text_embeds,\n image_embed = clip_image_embeds\n)\nloss.backward()\n# do the above for many many many steps\n# now the diffusion prior can generate image embeddings from the text embeddings\n```\n## OpenAI CLIP\nAlthough there is the possibility they are using an unreleased, more powerful CLIP, you can use one of the released ones, if you do not wish to train your o"
+ },
+ {
+ "comment": "This code snippet demonstrates how to use OpenAI's CLIP model, pre-trained, within the DALLE2 PyTorch framework. It defines a function `OpenAIClipAdapter` that allows easy integration of pre-trained CLIP with DALLE2's prior and decoder networks. The code provides an example of how to use these networks for training purposes by defining a diffusion prior and unet decoder, and applying them to some mock data.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":543-588",
+ "content": "wn CLIP from scratch. This will also allow the community to more quickly validate the conclusions of the paper.\nTo use a pretrained OpenAI CLIP, simply import `OpenAIClipAdapter` and pass it into the `DiffusionPrior` or `Decoder` like so\n```python\nimport torch\nfrom dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenAIClipAdapter\n# openai pretrained clip - defaults to ViT-B/32\nclip = OpenAIClipAdapter()\n# mock data\ntext = torch.randint(0, 49408, (4, 256)).cuda()\nimages = torch.randn(4, 3, 256, 256).cuda()\n# prior networks (with transformer)\nprior_network = DiffusionPriorNetwork(\n dim = 512,\n depth = 6,\n dim_head = 64,\n heads = 8\n).cuda()\ndiffusion_prior = DiffusionPrior(\n net = prior_network,\n clip = clip,\n timesteps = 100,\n cond_drop_prob = 0.2\n).cuda()\nloss = diffusion_prior(text, images)\nloss.backward()\n# do above for many steps ...\n# decoder (with unet)\nunet1 = Unet(\n dim = 128,\n image_embed_dim = 512,\n cond_dim = 128,\n channels = 3,\n dim_mults=(1, 2, 4, 8),"
+ },
+ {
+ "comment": "The code initializes a DALLE2 model and trains it by feeding images and text. It creates Unet layers, a Decoder, and a DALLE2 instance using given dimensions and parameters. The training loop iterates over unet_number, calculates loss, and applies gradient descent to optimize the model. Finally, the DALLE2 model generates images based on input text with conditional scaling.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":589-624",
+ "content": " text_embed_dim = 512,\n cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings (ex. first unet in cascade)\n).cuda()\nunet2 = Unet(\n dim = 16,\n image_embed_dim = 512,\n cond_dim = 128,\n channels = 3,\n dim_mults = (1, 2, 4, 8, 16)\n).cuda()\ndecoder = Decoder(\n unet = (unet1, unet2),\n image_sizes = (128, 256),\n clip = clip,\n timesteps = 1000,\n sample_timesteps = (250, 27),\n image_cond_drop_prob = 0.1,\n text_cond_drop_prob = 0.5\n).cuda()\nfor unet_number in (1, 2):\n loss = decoder(images, text = text, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much\n loss.backward()\n# do above for many steps\ndalle2 = DALLE2(\n prior = diffusion_prior,\n decoder = decoder\n)\nimages = dalle2(\n ['a butterfly trying to escape a tornado'],\n cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)"
+ },
+ {
+ "comment": "The code provides instructions on how to save an image and use Open Clip for image processing. It mentions installing the open-clip-torch package, using a state-of-the-art (SOTA) Open Clip model, initializing the OpenClipAdapter with the desired model, and utilizing the Decoder's built-in inpainting feature, following the formulation presented in Repaint. The code also showcases how to import necessary modules and initialize a CLIP object with specified dimensions.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":625-662",
+ "content": ")\n# save your image (in this example, of size 256x256)\n```\nAlternatively, you can also use Open Clip\n```bash\n$ pip install open-clip-torch\n```\nEx. using the SOTA Open Clip model trained by Romain\n```python\nfrom dalle2_pytorch import OpenClipAdapter\nclip = OpenClipAdapter('ViT-H/14')\n```\nNow you'll just have to worry about training the Prior and the Decoder!\n## Inpainting\nInpainting is also built into the `Decoder`. You simply have to pass in the `inpaint_image` and `inpaint_mask` (boolean tensor where `True` indicates which regions of the inpaint image to keep)\nThis repository uses the formulation put forth by Lugmayr et al. in Repaint\n```python\nimport torch\nfrom dalle2_pytorch import Unet, Decoder, CLIP\n# trained clip from step 1\nclip = CLIP(\n dim_text = 512,\n dim_image = 512,\n dim_latent = 512,\n num_text_tokens = 49408,"
+ },
+ {
+ "comment": "This code initializes a DALL-E 2 model with specified dimensions for text and visual encoders, along with two UNet models for the decoder. The decoder is then instantiated using these components and a set of image sizes, timesteps, and conditional drop probabilities. Finally, mock images are created for training purposes.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":663-699",
+ "content": " text_enc_depth = 6,\n text_seq_len = 256,\n text_heads = 8,\n visual_enc_depth = 6,\n visual_image_size = 256,\n visual_patch_size = 32,\n visual_heads = 8\n).cuda()\n# 2 unets for the decoder (a la cascading DDPM)\nunet = Unet(\n dim = 16,\n image_embed_dim = 512,\n cond_dim = 128,\n channels = 3,\n dim_mults = (1, 1, 1, 1)\n).cuda()\n# decoder, which contains the unet(s) and clip\ndecoder = Decoder(\n clip = clip,\n unet = (unet,), # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)\n image_sizes = (256,), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)\n timesteps = 1000,\n image_cond_drop_prob = 0.1,\n text_cond_drop_prob = 0.5\n).cuda()\n# mock images (get a lot of this)\nimages = torch.randn(4, 3, 256, 256).cuda()\n# feed images into decoder, specifying which unet you want to train\n# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme"
+ },
+ {
+ "comment": "This code initializes a decoder and performs inpainting using DALL-E2 with Latent Diffusion. It generates a mock image embedding, sets the input image and mask for inpainting, then samples the inpainted images from the decoder.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":701-730",
+ "content": "loss = decoder(images, unet_number = 1)\nloss.backward()\n# do the above for many steps for both unets\nmock_image_embed = torch.randn(1, 512).cuda()\n# then to do inpainting\ninpaint_image = torch.randn(1, 3, 256, 256).cuda() # (batch, channels, height, width)\ninpaint_mask = torch.ones(1, 256, 256).bool().cuda() # (batch, height, width)\ninpainted_images = decoder.sample(\n image_embed = mock_image_embed,\n inpaint_image = inpaint_image, # just pass in the inpaint image\n inpaint_mask = inpaint_mask # and the mask\n)\ninpainted_images.shape # (1, 3, 256, 256)\n```\n## Experimental\n### DALL-E2 with Latent Diffusion\nThis repository decides to take the next step and offer DALL-E v2 combined with latent diffusion, from Rombach et al.\nYou can use it as follows. Latent diffusion can be limited to just the first U-Net in the cascade, or to any number you wish.\nThe repository also comes equipped with all the necessary settin"
+ },
+ {
+ "comment": "The code is importing necessary modules for training a VQGAN-VAE model. It initializes a CLIP model and three Unet models for the decoder, as well as a VQGanVAE model. The CLIP model is pre-trained, while the VQGanVAE needs to be trained beforehand. This code seems to aim at improving the performance of an autoencoder using residual or multi-headed quantization techniques.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":730-761",
+ "content": "gs to recreate `ViT-VQGan` from the Improved VQGans paper. Furthermore, the vector quantization library also comes equipped to do residual or multi-headed quantization, which I believe will give an even further boost in performance to the autoencoder.\n```python\nimport torch\nfrom dalle2_pytorch import Unet, Decoder, CLIP, VQGanVAE\n# trained clip from step 1\nclip = CLIP(\n dim_text = 512,\n dim_image = 512,\n dim_latent = 512,\n num_text_tokens = 49408,\n text_enc_depth = 1,\n text_seq_len = 256,\n text_heads = 8,\n visual_enc_depth = 1,\n visual_image_size = 256,\n visual_patch_size = 32,\n visual_heads = 8\n)\n# 3 unets for the decoder (a la cascading DDPM)\n# first two unets are doing latent diffusion\n# vqgan-vae must be trained beforehand\nvae1 = VQGanVAE(\n dim = 32,\n image_size = 256,\n layers = 3,\n layer_mults = (1, 2, 4)"
+ },
+ {
+ "comment": "This code sets up a DALLE2 model by creating and configuring various components: VQGanVAE (vae1), Unet models (unet1, unet2, unet3) and a Decoder. The decoder combines the clip and VAEs with corresponding Unets at different resolutions. The image sizes specify the resolutions for each Unet stage, starting from 256 for the first one up to 1024 for the third one.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":762-806",
+ "content": ")\nvae2 = VQGanVAE(\n dim = 32,\n image_size = 512,\n layers = 3,\n layer_mults = (1, 2, 4)\n)\nunet1 = Unet(\n dim = 32,\n image_embed_dim = 512,\n cond_dim = 128,\n channels = 3,\n sparse_attn = True,\n sparse_attn_window = 2,\n dim_mults = (1, 2, 4, 8)\n)\nunet2 = Unet(\n dim = 32,\n image_embed_dim = 512,\n channels = 3,\n dim_mults = (1, 2, 4, 8, 16),\n cond_on_image_embeds = True,\n cond_on_text_encodings = False\n)\nunet3 = Unet(\n dim = 32,\n image_embed_dim = 512,\n channels = 3,\n dim_mults = (1, 2, 4, 8, 16),\n cond_on_image_embeds = True,\n cond_on_text_encodings = False,\n attend_at_middle = False\n)\n# decoder, which contains the unet(s) and clip\ndecoder = Decoder(\n clip = clip,\n vae = (vae1, vae2), # latent diffusion for unet1 (vae1) and unet2 (vae2), but not for the last unet3\n unet = (unet1, unet2, unet3), # insert unets in order of low resolution to highest resolution (you can have as many stages as you want here)\n image_sizes = (256, 512, 1024), # resolutions, 256 for first unet, 512 for second, 1024 for third"
+ },
+ {
+ "comment": "The code demonstrates how to train a `Decoder` with multiple unets using a cascading DDPM scheme. First, it initializes the model's parameters and assigns the required GPU. Then, it creates random images and specifies which unet to train in each iteration by calling the `one_unet_in_gpu()` method. The code trains multiple steps for each unet before moving on to the next one. Finally, a mock image is generated from an embedding using the trained decoder.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":807-845",
+ "content": " timesteps = 100,\n image_cond_drop_prob = 0.1,\n text_cond_drop_prob = 0.5\n).cuda()\n# mock images (get a lot of this)\nimages = torch.randn(1, 3, 1024, 1024).cuda()\n# feed images into decoder, specifying which unet you want to train\n# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme\nwith decoder.one_unet_in_gpu(1):\n loss = decoder(images, unet_number = 1)\n loss.backward()\nwith decoder.one_unet_in_gpu(2):\n loss = decoder(images, unet_number = 2)\n loss.backward()\nwith decoder.one_unet_in_gpu(3):\n loss = decoder(images, unet_number = 3)\n loss.backward()\n# do the above for many steps for both unets\n# then it will learn to generate images based on the CLIP image embeddings\n# chaining the unets from lowest resolution to highest resolution (thus cascading)\nmock_image_embed = torch.randn(1, 512).cuda()\nimages = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)\n```\n## Training wrapper\n### Decoder Training\nTraining the `Decoder` may be confusing,"
+ },
+ {
+ "comment": "The code defines a `CLIP` object and creates two `Unet` instances with different architectures. The `CLIP` model is used for text-to-image generation, while the `Unet` models are variational autoencoders that will be trained to generate images based on input text. The `unet1` has a smaller architecture compared to `unet2`, and both use the same embeddings. The code also provides mock data for testing the functionality of the decoder and trainers.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":845-887",
+ "content": " as one needs to keep track of an optimizer for each of the `Unet`(s) separately. Each `Unet` will also need its own corresponding exponential moving average. The `DecoderTrainer` hopes to make this simple, as shown below\n```python\nimport torch\nfrom dalle2_pytorch import DALLE2, Unet, Decoder, CLIP, DecoderTrainer\nclip = CLIP(\n dim_text = 512,\n dim_image = 512,\n dim_latent = 512,\n num_text_tokens = 49408,\n text_enc_depth = 6,\n text_seq_len = 256,\n text_heads = 8,\n visual_enc_depth = 6,\n visual_image_size = 256,\n visual_patch_size = 32,\n visual_heads = 8\n).cuda()\n# mock data\ntext = torch.randint(0, 49408, (32, 256)).cuda()\nimages = torch.randn(32, 3, 256, 256).cuda()\n# decoder (with unet)\nunet1 = Unet(\n dim = 128,\n image_embed_dim = 512,\n text_embed_dim = 512,\n cond_dim = 128,\n channels = 3,\n dim_mults=(1, 2, 4, 8),\n cond_on_text_encodings = True,\n).cuda()\nunet2 = Unet(\n dim = 16,\n image_embed_dim = 512,\n cond_dim = 128,\n channels = 3,\n dim_mults = (1, 2, 4, 8, 16),"
+ },
+ {
+ "comment": "This code sets up a decoder, trainer, and trains the unets to generate images based on text input. The trainer updates the unets and their exponential moving averages after each iteration. Finally, it samples from the moving-averaged unets to create new images.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":888-925",
+ "content": ").cuda()\ndecoder = Decoder(\n unet = (unet1, unet2),\n image_sizes = (128, 256),\n clip = clip,\n timesteps = 1000\n).cuda()\ndecoder_trainer = DecoderTrainer(\n decoder,\n lr = 3e-4,\n wd = 1e-2,\n ema_beta = 0.99,\n ema_update_after_step = 1000,\n ema_update_every = 10,\n)\nfor unet_number in (1, 2):\n loss = decoder_trainer(\n images,\n text = text,\n unet_number = unet_number, # which unet to train on\n max_batch_size = 4 # gradient accumulation - this sets the maximum batch size in which to do forward and backwards pass - for this example 32 / 4 == 8 times\n )\n decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average\n# after much training\n# you can sample from the exponentially moving averaged unets as so\nmock_image_embed = torch.randn(32, 512).cuda()\nimages = decoder_trainer.sample(image_embed = mock_image_embed, text = text) # (4, 3, 256, 256)\n```\n### Diffusion Prior Training\nSimilarly, one can use the `Di"
+ },
+ {
+ "comment": "This code creates a CLIP model, initializes a diffusion prior network, and sets up a trainer for the diffusion prior. The CLIP model is used to encode text and images into latent representations, while the diffusion prior network is responsible for predicting the future of latent samples. The trainer will automatically update the moving average of the prior network over time.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":925-970",
+ "content": "ffusionPriorTrainer` to automatically instantiate and keep track of an exponential moving averaged prior.\n```python\nimport torch\nfrom dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, DiffusionPriorTrainer, Unet, Decoder, CLIP\nclip = CLIP(\n dim_text = 512,\n dim_image = 512,\n dim_latent = 512,\n num_text_tokens = 49408,\n text_enc_depth = 6,\n text_seq_len = 256,\n text_heads = 8,\n visual_enc_depth = 6,\n visual_image_size = 256,\n visual_patch_size = 32,\n visual_heads = 8\n).cuda()\n# mock data\ntext = torch.randint(0, 49408, (512, 256)).cuda()\nimages = torch.randn(512, 3, 256, 256).cuda()\n# prior networks (with transformer)\nprior_network = DiffusionPriorNetwork(\n dim = 512,\n depth = 6,\n dim_head = 64,\n heads = 8\n).cuda()\ndiffusion_prior = DiffusionPrior(\n net = prior_network,\n clip = clip,\n timesteps = 100,\n cond_drop_prob = 0.2\n).cuda()\ndiffusion_prior_trainer = DiffusionPriorTrainer(\n diffusion_prior,\n lr = 3e-4,\n wd = 1e-2,\n ema_beta = 0.99,"
+ },
+ {
+ "comment": "This code initializes a diffusion prior trainer with exponential moving average (EMA) update parameters, trains the model using diffusion_prior_trainer, updates optimizer and EMA diffusion prior, and finally samples from the EMA of the diffusion prior. The code also mentions that unconditional training or cascading DDPMs can be done by setting `unconditional = True` in the Decoder.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":971-1008",
+ "content": " ema_update_after_step = 1000,\n ema_update_every = 10,\n)\nloss = diffusion_prior_trainer(text, images, max_batch_size = 4)\ndiffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior\n# after much of the above three lines in a loop\n# you can sample from the exponential moving average of the diffusion prior identically to how you do so for DiffusionPrior\nimage_embeds = diffusion_prior_trainer.sample(text, max_batch_size = 4) # (512, 512) - exponential moving averaged image embeddings\n```\n## Bonus\n### Unconditional Training\nThe repository also contains the means to train unconditional DDPM model, or even cascading DDPMs. You simply have to set `unconditional = True` in the `Decoder`\nex.\n```python\nimport torch\nfrom dalle2_pytorch import Unet, Decoder, DecoderTrainer\n# unet for the cascading ddpm\nunet1 = Unet(\n dim = 128,\n dim_mults=(1, 2, 4, 8)\n).cuda()\nunet2 = Unet(\n dim = 32,\n dim_mults = (1, 2, 4, 8, 16)\n).cuda()\n# decoder, which contains the unets"
+ },
+ {
+ "comment": "The code initializes a decoder, trainer for the decoder, and generates images. It then trains the decoder by feeding images into it, updating the trainer, and repeats this process many times to enable learning. Finally, it uses the trained decoder to generate new images.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":1010-1045",
+ "content": "decoder = Decoder(\n unet = (unet1, unet2),\n image_sizes = (256, 512), # first unet up to 256px, then second to 512px\n timesteps = 1000,\n unconditional = True\n).cuda()\n# decoder trainer\ndecoder_trainer = DecoderTrainer(decoder)\n# images (get a lot of this)\nimages = torch.randn(1, 3, 512, 512).cuda()\n# feed images into decoder\nfor i in (1, 2):\n loss = decoder_trainer(images, unet_number = i)\n decoder_trainer.update(unet_number = i)\n# do the above for many many many many images\n# then it will learn to generate images\nimages = decoder_trainer.sample(batch_size = 36, max_batch_size = 4) # (36, 3, 512, 512)\n```\n## Dataloaders\n### Decoder Dataloaders\nIn order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.\n#### Decoder: Image Embedding Dataset\nWhen training the decoder (and up samplers if training together) in isolation, you will need to load images and corresponding image embeddings. This dataset can read two simi"
+ },
+ {
+ "comment": "This code describes a dataset format using webdataset, containing .jpg and .npy files in .tar archives. It allows specifying an external source for embeddings with the same shard numbers and filename-to-index correspondence. The code provides steps to generate this type of dataset using img2dataset and clip-retrieval.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":1045-1049",
+ "content": "lar types of datasets. First, it can read a [webdataset](https://github.com/webdataset/webdataset) that contains `.jpg` and `.npy` files in the `.tar`s that contain the images and associated image embeddings respectively. Alternatively, you can also specify a source for the embeddings outside of the webdataset. In this case, the path to the embeddings should contain `.npy` files with the same shard numbers as the webdataset and there should be a correspondence between the filename of the `.jpg` and the index of the embedding in the `.npy`. So, for example, `0001.tar` from the webdataset with image `00010509.jpg` (the first 4 digits are the shard number and the last 4 are the index) in it should be paralleled by a `img_emb_0001.npy` which contains a NumPy array with the embedding at index 509.\nGenerating a dataset of this type: \n1. Use [img2dataset](https://github.com/rom1504/img2dataset) to generate a webdataset.\n2. Use [clip-retrieval](https://github.com/rom1504/clip-retrieval) to convert the images to embeddings."
+ },
+ {
+ "comment": "This code snippet demonstrates the usage of the `create_image_embedding_dataloader` function from the DALLE2-pytorch library. It creates an image embedding dataloader by specifying a URL path for the webdataset tar files and optional embeddings folder, setting the number of workers and batch size, defining the shard width, and deciding whether to shuffle the shards or not. The purpose is to reorder the embeddings into the expected format for image generation tasks.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":1050-1065",
+ "content": "3. Use [embedding-dataset-reordering](https://github.com/Veldrovive/embedding-dataset-reordering) to reorder the embeddings into the expected format.\nUsage:\n```python\nfrom dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embedding_dataloader\n# Create a dataloader directly.\ndataloader = create_image_embedding_dataloader(\n tar_url=\"/path/or/url/to/webdataset/{0000..9999}.tar\", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar\n embeddings_url=\"path/or/url/to/embeddings/folder\", # Included if .npy files are not in webdataset. Left out or set to None otherwise\n num_workers=4,\n batch_size=32,\n shard_width=4, # If a file in the webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index\n shuffle_num=200, # Does a shuffle of the data with a buffer size of 200\n shuffle_shards=True, # Shuffle the order the shards are read in"
+ },
+ {
+ "comment": "The code snippet shows how to load an ImageEmbeddingDataset and print its shape. The dataset is loaded from a webdataset at the specified URL, with embedding files located in the given folder. It uses shard_width=4 for sharding the data and sets resample to False. The loader creates images and embeddings which are printed for verification. Additionally, it mentions creating a dataset without a loader if manual configuration is preferred.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":1066-1092",
+ "content": " resample_shards=False, # Sample shards with replacement. If true, an epoch will be infinite unless stopped manually\n)\nfor img, emb in dataloader:\n print(img.shape) # torch.Size([32, 3, 256, 256])\n print(emb[\"img\"].shape) # torch.Size([32, 512])\n # Train decoder only as shown above\n# Or create a dataset without a loader so you can configure it manually\ndataset = ImageEmbeddingDataset(\n urls=\"/path/or/url/to/webdataset/{0000..9999}.tar\",\n embedding_folder_url=\"path/or/url/to/embeddings/folder\",\n shard_width=4,\n shuffle_shards=True,\n resample=False\n)\n```\n### Scripts\n#### `train_diffusion_prior.py`\nFor detailed information on training the diffusion prior, please refer to the [dedicated readme](prior.md)\n## Todo\n- [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon\n- [x] add what was proposed in the paper, where DDPM objective for image latent embedding predicts x0 directly (reread vq-diffusion paper and get caught up on that line of work)"
+ },
+ {
+ "comment": "The code outlines the steps to create a DDPM model, including conditioning it with text encodings and incorporating a cascade of unets for different resolutions. It also mentions adding efficient attention in unet, allowing customization of conditioning for specific unets, offloading unets to CPU, building latent diffusion architecture, and providing the option for vq-reg variant (vqgan-vae). The decoder objective can be customized between predicting epsilon or x0.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":1093-1101",
+ "content": "- [x] make sure it works end to end to produce an output tensor, taking a single gradient step\n- [x] augment unet so that it can also be conditioned on text encodings (although in paper they hinted this didn't make much a difference)\n- [x] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)\n- [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions\n- [x] add efficient attention in unet\n- [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)\n- [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)\n- [x] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional and compatible with cascading ddpms\n- [x] for decoder, allow ability to customize objective (predict epsilon vs x0), in case latent diffusion does better with prediction of x0"
+ },
+ {
+ "comment": "This code is a list of tasks to be completed for the DALLE2-pytorch project. It includes implementing attention-based upsampling, using inheritance, integrating Vit-VQGAN, creating an abstract interface for CLIP adapters, handling mixed precision and gradient accumulation in the decoder trainer, adding a training wrapper class for each unet in the cascade, incorporating convnext backbone for VQGAN-VAE, making sure DDPMs can be run with traditional resnet blocks, enabling super resolution training on crops for latter unets, and allowing conv-like attention with rel pos bias.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":1102-1112",
+ "content": "- [x] use attention-based upsampling https://arxiv.org/abs/2112.11435\n- [x] use inheritance just this once for sharing logic between decoder and prior network ddpms\n- [x] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion\n- [x] abstract interface for CLIP adapter class, so other CLIPs can be brought in\n- [x] take care of mixed precision as well as gradient accumulation within decoder trainer\n- [x] just take care of the training for the decoder in a wrapper class, as each unet in the cascade will need its own optimizer\n- [x] bring in tools to train vqgan-vae\n- [x] add convnext backbone for vqgan-vae (in addition to vit [vit-vqgan] + resnet)\n- [x] make sure DDPMs can be run with traditional resnet blocks (but leave convnext as an option for experimentation)\n- [x] make sure for the latter unets in the cascade, one can train on crops for learning super resolution (constrain the unet to be only convolutions in that case, or allow conv-like attention with rel pos bias)\n- [x] "
+ },
+ {
+ "comment": "The code is about configuring and training a diffusion prior model for image generation. It includes improvements such as making hyperparameters configurable, incorporating cross-scale embedding, introducing cross embed layers for downsampling, and using an experimental tracker agnostic setup. The code also utilizes pydantic for configuration drive training, saves and restores all exponential moving averaged models for both diffusion prior and decoder.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":1112-1120",
+ "content": "offer setting in diffusion prior to split time and image embeddings into multiple tokens, configurable, for more surface area during attention\n- [x] make sure resnet hyperparameters can be configurable across unet depth (groups and expansion factor)\n- [x] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training\n- [x] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images\n- [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14\n- [x] cross embed layers for downsampling, as an option\n- [x] use an experimental tracker agnostic setup, as done here\n- [x] use pydantic for config drive training\n- [x] for both diffusion prior and decoder, all exponential moving averaged models needs to be saved and restored as well (as well as the step number)"
+ },
+ {
+ "comment": "This code lists various tasks and features that have been implemented or are planned for the DALLE2-pytorch model. These include save/load methods, creation of diffusion prior models, skip-layer excitations, grid attention in Cascading DDPM, unet conditioning, speed up inference, resampler from REPAINT paper, final combination of upsample feature maps, and consideration for Elucidated DALLE2.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":1121-1129",
+ "content": "- [x] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes\n- [x] allow for creation of diffusion prior model off pydantic config classes - consider the same for tracker configs\n- [x] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either decoder of unet or vqgan-vae training (doesnt work well)\n- [x] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697 (keeping, seems to be fine)\n- [x] allow for unet to be able to condition non-cross attention style as well\n- [x] speed up inference, read up on papers (ddim)\n- [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865\n- [x] add the final combination of upsample feature maps, used in unet squared, seems to have an effect in local experiments\n- [ ] consider elucidated dalle2 https://arxiv.org/abs/2206.00364"
+ },
+ {
+ "comment": "This code chunk appears to be a task list for the DALLE2-pytorch project, followed by citations in BibTeX format. The tasks include implementing simple outpainting and text-guided 2x image size expansion. The project also plans on integrating the VQGAN-VAE, which can be pulled from a pretrained model to test latent diffusion and DALL-E2 integration. The cited works include \"Hierarchical Text-Conditional Image Generation with CLIP Latents\", \"High-Resolution Image Synthesis with Latent Diffusion Models\", \"Efficient Attention: Attention with Linear Complexities\" and a Twitter post by Katherine Crowson.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":1130-1164",
+ "content": "- [ ] add simple outpainting, text-guided 2x size the image for starters\n- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2\n## Citations\n```bibtex\n@misc{ramesh2022,\n title = {Hierarchical Text-Conditional Image Generation with CLIP Latents}, \n author = {Aditya Ramesh et al},\n year = {2022}\n}\n```\n```bibtex\n@misc{crowson2022,\n author = {Katherine Crowson},\n url = {https://twitter.com/rivershavewings}\n}\n```\n```bibtex\n@misc{rombach2021highresolution,\n title = {High-Resolution Image Synthesis with Latent Diffusion Models}, \n author = {Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Bj\u00f6rn Ommer},\n year = {2021},\n eprint = {2112.10752},\n archivePrefix = {arXiv},\n primaryClass = {cs.CV}\n}\n```\n```bibtex\n@article{shen2019efficient,\n author = {Zhuoran Shen and Mingyuan Zhang and Haiyu Zhao and Shuai Yi and Hongsheng Li},\n title = {Efficient Attention: Attention with Linear Complexities},"
+ },
+ {
+ "comment": "The code defines four BibTeX entries for academic papers, providing the paper title, author(s), journal or arXiv, and publication year. These entries can be used to cite the papers in a BibTeX database or bibliography file.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":1165-1197",
+ "content": " journal = {CoRR},\n year = {2018},\n url = {http://arxiv.org/abs/1812.01243},\n}\n```\n```bibtex\n@article{Yu2021VectorquantizedIM,\n title = {Vector-quantized Image Modeling with Improved VQGAN},\n author = {Jiahui Yu and Xin Li and Jing Yu Koh and Han Zhang and Ruoming Pang and James Qin and Alexander Ku and Yuanzhong Xu and Jason Baldridge and Yonghui Wu},\n journal = {ArXiv},\n year = {2021},\n volume = {abs/2110.04627}\n}\n```\n```bibtex\n@article{Shleifer2021NormFormerIT,\n title = {NormFormer: Improved Transformer Pretraining with Extra Normalization},\n author = {Sam Shleifer and Jason Weston and Myle Ott},\n journal = {ArXiv},\n year = {2021},\n volume = {abs/2110.09456}\n}\n```\n```bibtex\n@article{Yu2022CoCaCC,\n title = {CoCa: Contrastive Captioners are Image-Text Foundation Models},\n author = {Jiahui Yu and Zirui Wang and Vijay Vasudevan and Legg Yeung and Mojtaba Seyedhosseini and Yonghui Wu},\n journal = {ArXiv},\n year = {2022},\n volume = {abs/2205.01917}"
+ },
+ {
+ "comment": "This code snippet represents the citation for a research paper called \"CrossFormer: A Versatile Vision Transformer Hinging on Cross-scale Attention\" by Wenxiao Wang et al. The paper is available at arXiv with ID 2108.00154 and focuses on a versatile vision transformer model using cross-scale attention.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":1198-1224",
+ "content": "}\n```\n```bibtex\n@misc{wang2021crossformer,\n title = {CrossFormer: A Versatile Vision Transformer Hinging on Cross-scale Attention},\n author = {Wenxiao Wang and Lu Yao and Long Chen and Binbin Lin and Deng Cai and Xiaofei He and Wei Liu},\n year = {2021},\n eprint = {2108.00154},\n archivePrefix = {arXiv},\n primaryClass = {cs.CV}\n}\n```\n```bibtex\n@article{ho2021cascaded,\n title = {Cascaded Diffusion Models for High Fidelity Image Generation},\n author = {Ho, Jonathan and Saharia, Chitwan and Chan, William and Fleet, David J and Norouzi, Mohammad and Salimans, Tim},\n journal = {arXiv preprint arXiv:2106.15282},\n year = {2021}\n}\n```\n```bibtex\n@misc{Saharia2022,\n title = {Imagen: unprecedented photorealism \u00d7 deep level of language understanding},\n author = {Chitwan Saharia*, William Chan*, Saurabh Saxena\u2020, Lala Li\u2020, Jay Whang\u2020, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S. Sara Mahdavi, Rapha Gontijo Lopes, Tim Salimans, Jonathan Ho\u2020, David Fleet\u2020, Mohammad Norouzi*},"
+ },
+ {
+ "comment": "The code includes four BibTeX entries, each representing a different research article. The first entry is for the article titled \"Perception Prioritized Training of Diffusion Models\" by Choi et al., published in 2022. The second entry is for the article titled \"Palette: Image-to-Image Diffusion Models\" by Saharia et al., published in 2021. The third entry is for the article titled \"RePaint: Inpainting using Denoising Diffusion Probabilistic Models\" by Lugmayr et al., published in 2022. The last entry, marked as incomplete, is for an unnamed article by Chen et al. published in 2022.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":1225-1260",
+ "content": " year = {2022}\n}\n```\n```bibtex\n@article{Choi2022PerceptionPT,\n title = {Perception Prioritized Training of Diffusion Models},\n author = {Jooyoung Choi and Jungbeom Lee and Chaehun Shin and Sungwon Kim and Hyunwoo J. Kim and Sung-Hoon Yoon},\n journal = {ArXiv},\n year = {2022},\n volume = {abs/2204.00227}\n}\n```\n```bibtex\n@article{Saharia2021PaletteID,\n title = {Palette: Image-to-Image Diffusion Models},\n author = {Chitwan Saharia and William Chan and Huiwen Chang and Chris A. Lee and Jonathan Ho and Tim Salimans and David J. Fleet and Mohammad Norouzi},\n journal = {ArXiv},\n year = {2021},\n volume = {abs/2111.05826}\n}\n```\n```bibtex\n@article{Lugmayr2022RePaintIU,\n title = {RePaint: Inpainting using Denoising Diffusion Probabilistic Models},\n author = {Andreas Lugmayr and Martin Danelljan and Andr{\\'e}s Romero and Fisher Yu and Radu Timofte and Luc Van Gool},\n journal = {ArXiv},\n year = {2022},\n volume = {abs/2201.09865}\n}\n```\n```bibtex\n@misc{chen2022analog,"
+ },
+ {
+ "comment": "The code represents BibTeX entries for research articles and conferences. These entries provide information about the title, authors, journals or conferences, publication years, and related identifiers of the cited works.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":1261-1292",
+ "content": " title = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},\n author = {Ting Chen and Ruixiang Zhang and Geoffrey Hinton},\n year = {2022},\n eprint = {2208.04202},\n archivePrefix = {arXiv},\n primaryClass = {cs.CV}\n}\n```\n```bibtex\n@article{Qiao2019WeightS,\n title = {Weight Standardization},\n author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Loddon Yuille},\n journal = {ArXiv},\n year = {2019},\n volume = {abs/1903.10520}\n}\n```\n```bibtex\n@inproceedings{rogozhnikov2022einops,\n title = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},\n author = {Alex Rogozhnikov},\n booktitle = {International Conference on Learning Representations},\n year = {2022},\n url = {https://openreview.net/forum?id=oapKSVM2bcj}\n}\n```\n```bibtex\n@article{Sunkara2022NoMS,\n title = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects},"
+ },
+ {
+ "comment": "This code snippet provides the citation information for two papers in the BibTeX format. The first paper is titled \"Progressive Distillation for Fast Sampling of Diffusion Models\" by Tim Salimans and Jonathan Ho, published in ArXiv in 2022 with volume abs/2202.00512. The second paper is called \"ArXiv:2208.03641\" which doesn't seem to have a title or authors listed, but it was also published on ArXiv in 2022 with the volume abs/2208.03641. Both papers discuss generative modeling techniques using diffusion models.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/README.md\":1293-1310",
+ "content": " author = {Raja Sunkara and Tie Luo},\n journal = {ArXiv},\n year = {2022},\n volume = {abs/2208.03641}\n}\n```\n```bibtex\n@article{Salimans2022ProgressiveDF,\n title = {Progressive Distillation for Fast Sampling of Diffusion Models},\n author = {Tim Salimans and Jonathan Ho},\n journal = {ArXiv},\n year = {2022},\n volume = {abs/2202.00512}\n}\n```\n*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's paper"
+ }
+ ]
+}
\ No newline at end of file
diff --git a/docs/doc/b1ddbfd7-3af5-4ca1-a1bf-bd27cb28b640.json b/docs/doc/b1ddbfd7-3af5-4ca1-a1bf-bd27cb28b640.json
new file mode 100644
index 00000000..779aaf52
--- /dev/null
+++ b/docs/doc/b1ddbfd7-3af5-4ca1-a1bf-bd27cb28b640.json
@@ -0,0 +1,10 @@
+{
+ "summary": "This code is importing modules from the DALLE2-pytorch library, which includes the main DALLE2 class, diffusion prior network, unet, decoder, clip adapters, trainer for the decoder and diffusion prior, and VQGanVAE. The x_clip module also appears to be imported, but its purpose is not explicitly described in this chunk of code.",
+ "details": [
+ {
+ "comment": "This code is importing modules from the DALLE2-pytorch library, which includes the main DALLE2 class, diffusion prior network, unet, decoder, clip adapters, trainer for the decoder and diffusion prior, and VQGanVAE. The x_clip module also appears to be imported, but its purpose is not explicitly described in this chunk of code.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/__init__.py\":0-6",
+ "content": "from dalle2_pytorch.version import __version__\nfrom dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder\nfrom dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter, OpenClipAdapter\nfrom dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer\nfrom dalle2_pytorch.vqgan_vae import VQGanVAE\nfrom x_clip import CLIP"
+ }
+ ]
+}
\ No newline at end of file
diff --git a/docs/doc/b6e46f34-6f1c-4275-b78b-c1825e4185e0.json b/docs/doc/b6e46f34-6f1c-4275-b78b-c1825e4185e0.json
new file mode 100644
index 00000000..a8d0ea7b
--- /dev/null
+++ b/docs/doc/b6e46f34-6f1c-4275-b78b-c1825e4185e0.json
@@ -0,0 +1,15 @@
+{
+ "summary": "This code imports libraries, defines functions, and parses command-line arguments for model path, conditioning scale, and input text. It loads a DALL-E2 model, generates an image based on the input text, saves it in PIL format, and returns the saved image.",
+ "details": [
+ {
+ "comment": "This code imports necessary libraries, defines some utility functions and a main function. It also includes a command-line argument parser with options for model path, conditioning scale, and the text input. The assert statement ensures that the specified model file exists before proceeding.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/cli.py\":0-32",
+ "content": "import click\nimport torch\nimport torchvision.transforms as T\nfrom functools import reduce\nfrom pathlib import Path\nfrom dalle2_pytorch import DALLE2, Decoder, DiffusionPrior\ndef safeget(dictionary, keys, default = None):\n return reduce(lambda d, key: d.get(key, default) if isinstance(d, dict) else default, keys.split('.'), dictionary)\ndef simple_slugify(text, max_length = 255):\n return text.replace(\"-\", \"_\").replace(\",\", \"\").replace(\" \", \"_\").replace(\"|\", \"--\").strip('-_')[:max_length]\ndef get_pkg_version():\n from pkg_resources import get_distribution\n return get_distribution('dalle2_pytorch').version\ndef main():\n pass\n@click.command()\n@click.option('--model', default = './dalle2.pt', help = 'path to trained DALL-E2 model')\n@click.option('--cond_scale', default = 2, help = 'conditioning scale (classifier free guidance) in decoder')\n@click.argument('text')\ndef dream(\n model,\n cond_scale,\n text\n):\n model_path = Path(model)\n full_model_path = str(model_path.resolve())\n assert model_path.exists(), f'model not found at {full_model_path}'"
+ },
+ {
+ "comment": "This code loads a saved DALL-E2 model from a specified path, checks the version, initializes the prior and decoder components, recreates the model using these components, loads its parameters, generates an image based on input text, converts it to PIL format, saves it with a file name derived from the input text, and returns the saved image.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/cli.py\":33-51",
+ "content": " loaded = torch.load(str(model_path))\n version = safeget(loaded, 'version')\n print(f'loading DALL-E2 from {full_model_path}, saved at version {version} - current package version is {get_pkg_version()}')\n prior_init_params = safeget(loaded, 'init_params.prior')\n decoder_init_params = safeget(loaded, 'init_params.decoder')\n model_params = safeget(loaded, 'model_params')\n prior = DiffusionPrior(**prior_init_params)\n decoder = Decoder(**decoder_init_params)\n dalle2 = DALLE2(prior, decoder)\n dalle2.load_state_dict(model_params)\n image = dalle2(text, cond_scale = cond_scale)\n pil_image = T.ToPILImage()(image)\n return pil_image.save(f'./{simple_slugify(text)}.png')"
+ }
+ ]
+}
\ No newline at end of file
diff --git a/docs/doc/b700c538-7940-4eaa-8ee5-44234a5f680b.json b/docs/doc/b700c538-7940-4eaa-8ee5-44234a5f680b.json
new file mode 100644
index 00000000..c46110fb
--- /dev/null
+++ b/docs/doc/b700c538-7940-4eaa-8ee5-44234a5f680b.json
@@ -0,0 +1,10 @@
+{
+ "summary": "This code snippet includes helper functions for time, print, and import operations. It defines a Timer class for measuring elapsed time, a print_ribbon function to format print statements with a banner, and an import_or_print_error function to handle module imports, displaying an error message if necessary and exiting the program.",
+ "details": [
+ {
+ "comment": "This code snippet includes helper functions for time, print, and import operations. It defines a Timer class for measuring elapsed time, a print_ribbon function to format print statements with a banner, and an import_or_print_error function to handle module imports, displaying an error message if necessary and exiting the program.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/utils.py\":0-34",
+ "content": "import time\nimport importlib\n# helper functions\ndef exists(val):\n return val is not None\n# time helpers\nclass Timer:\n def __init__(self):\n self.reset()\n def reset(self):\n self.last_time = time.time()\n def elapsed(self):\n return time.time() - self.last_time\n# print helpers\ndef print_ribbon(s, symbol = '=', repeat = 40):\n flank = symbol * repeat\n return f'{flank} {s} {flank}'\n# import helpers\ndef import_or_print_error(pkg_name, err_str = None):\n try:\n return importlib.import_module(pkg_name)\n except ModuleNotFoundError as e:\n if exists(err_str):\n print(err_str)\n exit()"
+ }
+ ]
+}
\ No newline at end of file
diff --git a/docs/doc/ba42b845-3454-42bb-8ef1-001dfb0ae97c.json b/docs/doc/ba42b845-3454-42bb-8ef1-001dfb0ae97c.json
new file mode 100644
index 00000000..cef09731
--- /dev/null
+++ b/docs/doc/ba42b845-3454-42bb-8ef1-001dfb0ae97c.json
@@ -0,0 +1,600 @@
+{
+ "summary": "The code uses VQGAN-VAE, CLIP, and CoCa libraries for image generation, and includes helper functions, PyTorch CLIP model, neural networks, DALL-E 2 architecture, self-attention layers with normalization and dropout regularization. It initializes efficient DALL-E 2 and Imagen models, utilizes diffusion models for denoising and inpainting images, and incorporates conditional sampling from DALLE2-pytorch model for low-resolution image generation.",
+ "details": [
+ {
+ "comment": "This code imports various libraries and defines functions for data processing, including image resizing, Gaussian blurring, and rotary embeddings. It also utilizes the VQGAN-VAE, CLIP model, and CoCa. The code contains namedtuples, helper functions, and constants relevant to the tasks of image generation and language modeling.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":0-48",
+ "content": "import math\nimport random\nfrom tqdm.auto import tqdm\nfrom functools import partial, wraps\nfrom contextlib import contextmanager\nfrom collections import namedtuple\nfrom pathlib import Path\nimport torch\nimport torch.nn.functional as F\nfrom torch.utils.checkpoint import checkpoint\nfrom torch import nn, einsum\nimport torchvision.transforms as T\nfrom einops import rearrange, repeat, reduce, pack, unpack\nfrom einops.layers.torch import Rearrange\nfrom kornia.filters import gaussian_blur2d\nimport kornia.augmentation as K\nfrom dalle2_pytorch.tokenizer import tokenizer\nfrom dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE\nfrom resize_right import resize\n# rotary embeddings\nfrom rotary_embedding_torch import RotaryEmbedding\n# use x-clip\nfrom x_clip import CLIP\nfrom coca_pytorch import CoCa\n# constants\nNAT = 1. / math.log(2.)\nUnetOutput = namedtuple('UnetOutput', ['pred', 'var_interp_frac_unnormalized'])\n# helper functions\ndef exists(val):\n return val is not None\ndef identity(t, *args, **kwargs):\n return t\ndef first(arr, d = None):"
+ },
+ {
+ "comment": "Function 'if len(arr) == 0: return d' checks if the array is empty and returns the value 'd' if it is.\n'maybe(fn)' function creates a decorator that checks if the input exists, returning it if it does not.\n'default(val, d)' function returns the provided value 'val' if it exists; otherwise, it returns the default value 'd'.\n'cast_tuple(val, length=None, validate=True)' casts its argument to a tuple and optionally checks its length.\n'module_device(module)' retrieves the device of the module, defaulting to CPU for certain types like nn.Identity.\n'zero_init_(m)' initializes the weights and biases of the given module 'm' with zeros.\n'null_context(*args, **kwargs)' is a context manager that does nothing.\n'eval_decorator(fn)' wraps a function to evaluate the model before executing it.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":49-95",
+ "content": " if len(arr) == 0:\n return d\n return arr[0]\ndef maybe(fn):\n @wraps(fn)\n def inner(x, *args, **kwargs):\n if not exists(x):\n return x\n return fn(x, *args, **kwargs)\n return inner\ndef default(val, d):\n if exists(val):\n return val\n return d() if callable(d) else d\ndef cast_tuple(val, length = None, validate = True):\n if isinstance(val, list):\n val = tuple(val)\n out = val if isinstance(val, tuple) else ((val,) * default(length, 1))\n if exists(length) and validate:\n assert len(out) == length\n return out\ndef module_device(module):\n if isinstance(module, nn.Identity):\n return 'cpu' # It doesn't matter\n return next(module.parameters()).device\ndef zero_init_(m):\n nn.init.zeros_(m.weight)\n if exists(m.bias):\n nn.init.zeros_(m.bias)\n@contextmanager\ndef null_context(*args, **kwargs):\n yield\ndef eval_decorator(fn):\n def inner(model, *args, **kwargs):\n was_training = model.training\n model.eval()\n out = fn(model, *args, **kwargs)"
+ },
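+ {
+ "comment": "Illustrative example, not part of the repository source: how helpers like default and cast_tuple are typically used for configuration handling. The functions are re-declared here so the snippet runs standalone; exact call sites in the library may differ.",
+ "content": "def exists(val):\n    return val is not None\n\ndef default(val, d):\n    # fall back to d when val is None, calling d if it is a function / lambda\n    if exists(val):\n        return val\n    return d() if callable(d) else d\n\ndef cast_tuple(val, length = None):\n    # broadcast a scalar hyperparameter to a tuple of the requested length\n    out = val if isinstance(val, tuple) else ((val,) * default(length, 1))\n    assert length is None or len(out) == length\n    return out\n\nprint(cast_tuple(2, length = 4))        # (2, 2, 2, 2)\nprint(default(None, 16))                # 16\nprint(default(None, lambda: [0] * 3))   # [0, 0, 0]"
+ },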
+ {
+ "comment": "This code defines several helper functions for processing lists of strings, padding tuples to a specific length, and creating checkpointable versions of Python functions. It also includes a function to determine if a given dtype is a floating point type, and a conditional wrapper for creating a checkpointable version of a function or module list.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":96-132",
+ "content": " model.train(was_training)\n return out\n return inner\ndef is_float_dtype(dtype):\n return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])\ndef is_list_str(x):\n if not isinstance(x, (list, tuple)):\n return False\n return all([type(el) == str for el in x])\ndef pad_tuple_to_length(t, length, fillvalue = None):\n remain_length = length - len(t)\n if remain_length <= 0:\n return t\n return (*t, *((fillvalue,) * remain_length))\n# checkpointing helper function\ndef make_checkpointable(fn, **kwargs):\n if isinstance(fn, nn.ModuleList):\n return [maybe(make_checkpointable)(el, **kwargs) for el in fn]\n condition = kwargs.pop('condition', None)\n if exists(condition) and not condition(fn):\n return fn\n @wraps(fn)\n def inner(*args):\n input_needs_grad = any([isinstance(el, torch.Tensor) and el.requires_grad for el in args])\n if not input_needs_grad:\n return fn(*args)\n return checkpoint(fn, *args)"
+ },
+ {
+ "comment": "The code defines functions for controlling the gradient flow in a module, freezing all layers in a model, and making it evaluate only. It also includes helper functions to log a tensor, normalize a tensor using L2 norm, and resize an image to the specified size with optional interpolation method.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":134-178",
+ "content": " return inner\n# for controlling freezing of CLIP\ndef set_module_requires_grad_(module, requires_grad):\n for param in module.parameters():\n param.requires_grad = requires_grad\ndef freeze_all_layers_(module):\n set_module_requires_grad_(module, False)\ndef unfreeze_all_layers_(module):\n set_module_requires_grad_(module, True)\ndef freeze_model_and_make_eval_(model):\n model.eval()\n freeze_all_layers_(model)\n# tensor helpers\ndef log(t, eps = 1e-12):\n return torch.log(t.clamp(min = eps))\ndef l2norm(t):\n return F.normalize(t, dim = -1)\ndef resize_image_to(\n image,\n target_image_size,\n clamp_range = None,\n nearest = False,\n **kwargs\n):\n orig_image_size = image.shape[-1]\n if orig_image_size == target_image_size:\n return image\n if not nearest:\n scale_factors = target_image_size / orig_image_size\n out = resize(image, scale_factors = scale_factors, **kwargs)\n else:\n out = F.interpolate(image, target_image_size, mode = 'nearest')\n if exists(clamp_range):"
+ },
+ {
+ "comment": "This code defines a function for normalizing an image to the range of -1 to 1, and another for unnormalizing it back to the 0 to 1 range. It also includes a namedtuple for returning embedded text and image data along with their encodings. The code further defines a base class for clip adapters that takes a CLIP model as an argument and provides methods for validating and resizing images to match CLIP's requirements.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":179-213",
+ "content": " out = out.clamp(*clamp_range)\n return out\n# image normalization functions\n# ddpms expect images to be in the range of -1 to 1\n# but CLIP may otherwise\ndef normalize_neg_one_to_one(img):\n return img * 2 - 1\ndef unnormalize_zero_to_one(normed_img):\n return (normed_img + 1) * 0.5\n# clip related adapters\nEmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings'])\nEmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings'])\nclass BaseClipAdapter(nn.Module):\n def __init__(self, clip, **kwargs):\n super().__init__()\n self.clip = clip\n self.overrides = kwargs\n def validate_and_resize_image(self, image):\n image_size = image.shape[-1]\n assert image_size >= self.image_size, f'you are passing in an image of size {image_size} but CLIP requires the image size to be at least {self.image_size}'\n return resize_image_to(image, self.image_size)\n @property\n def dim_latent(self):\n raise NotImplementedError\n @property"
+ },
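+ {
+ "comment": "Illustrative example, not from the repository: the round trip between the [0, 1] range images are usually loaded in and the [-1, 1] range the DDPM side expects, re-declared so it runs standalone.",
+ "content": "import torch\n\ndef normalize_neg_one_to_one(img):\n    return img * 2 - 1\n\ndef unnormalize_zero_to_one(normed_img):\n    return (normed_img + 1) * 0.5\n\nimg = torch.rand(1, 3, 64, 64)            # images in [0, 1]\nnormed = normalize_neg_one_to_one(img)    # now in [-1, 1], as the diffusion model expects\nrestored = unnormalize_zero_to_one(normed)\nassert torch.allclose(restored, img, atol = 1e-6)   # equal up to float32 rounding"
+ },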
+ {
+ "comment": "This code defines a base class `BaseClipAdapter` with four methods that must be implemented by derived classes. The `XClipAdapter` class inherits from `BaseClipAdapter` and provides implementations for the properties of the underlying `clip` object, which is an instance of some clip model. The `embed_text` method takes a text input, truncates it to fit the maximum text length defined by `max_text_len`, applies a text transformer from the `clip` object, and returns the embeddings.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":214-256",
+ "content": " def image_size(self):\n raise NotImplementedError\n @property\n def image_channels(self):\n raise NotImplementedError\n @property\n def max_text_len(self):\n raise NotImplementedError\n def embed_text(self, text):\n raise NotImplementedError\n def embed_image(self, image):\n raise NotImplementedError\nclass XClipAdapter(BaseClipAdapter):\n @property\n def dim_latent(self):\n return self.clip.dim_latent\n @property\n def image_size(self):\n return self.clip.image_size\n @property\n def image_channels(self):\n return self.clip.image_channels\n @property\n def max_text_len(self):\n return self.clip.text_seq_len\n @torch.no_grad()\n def embed_text(self, text):\n text = text[..., :self.max_text_len]\n text_mask = text != 0\n encoder_output = self.clip.text_transformer(text)\n encoder_output_is_cls = encoder_output.ndim == 3\n text_cls, text_encodings = (encoder_output[:, 0], encoder_output[:, 1:]) if encoder_output_is_cls else (encoder_output, None)"
+ },
+ {
+ "comment": "This code snippet defines a class called CoCaAdapter, which is a base adapter for the DALL-E 2 PyTorch model. It contains methods to embed text and images, with optional overrides for image size and channels. The dim_latent property returns the dimension of the latent space, while max_text_len is used to set the maximum length for text inputs.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":257-288",
+ "content": " text_embed = self.clip.to_text_latent(text_cls)\n if exists(text_encodings):\n text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)\n return EmbeddedText(l2norm(text_embed), text_encodings)\n @torch.no_grad()\n def embed_image(self, image):\n image = self.validate_and_resize_image(image)\n encoder_output = self.clip.visual_transformer(image)\n image_cls, image_encodings = encoder_output[:, 0], encoder_output[:, 1:]\n image_embed = self.clip.to_visual_latent(image_cls)\n return EmbeddedImage(l2norm(image_embed), image_encodings)\nclass CoCaAdapter(BaseClipAdapter):\n @property\n def dim_latent(self):\n return self.clip.dim\n @property\n def image_size(self):\n assert 'image_size' in self.overrides\n return self.overrides['image_size']\n @property\n def image_channels(self):\n assert 'image_channels' in self.overrides\n return self.overrides['image_channels']\n @property\n def max_text_len(self):"
+ },
+ {
+ "comment": "This code is for a text-to-image model that uses CLIP as its base. It has functions to embed texts and images, with the ability to handle maximum text length. It initializes an OpenAIClipAdapter class using CLIP's 'ViT-B/32' model and finds the layer for text attention final output.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":289-318",
+ "content": " assert 'max_text_len' in self.overrides\n return self.overrides['max_text_len']\n @torch.no_grad()\n def embed_text(self, text):\n text = text[..., :self.max_text_len]\n text_mask = text != 0\n text_embed, text_encodings = self.clip.embed_text(text)\n text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)\n return EmbeddedText(text_embed, text_encodings)\n @torch.no_grad()\n def embed_image(self, image):\n image = self.validate_and_resize_image(image)\n image_embed, image_encodings = self.clip.embed_image(image)\n return EmbeddedImage(image_embed, image_encodings)\nclass OpenAIClipAdapter(BaseClipAdapter):\n def __init__(\n self,\n name = 'ViT-B/32'\n ):\n import clip\n openai_clip, preprocess = clip.load(name)\n super().__init__(openai_clip)\n self.eos_id = 49407 # for handling 0 being also '!'\n text_attention_final = self.find_layer('ln_final')\n self.dim_latent_ = text_attention_final.weight.shape[0]"
+ },
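+ {
+ "comment": "Hedged usage sketch, not taken from the repository: it assumes the dalle2-pytorch package and OpenAI's clip package are installed and that the code runs on CPU; the token ids and images are placeholders, and the 512-dim shapes are specific to ViT-B/32.",
+ "content": "import torch\nfrom dalle2_pytorch import OpenAIClipAdapter\n\nclip = OpenAIClipAdapter('ViT-B/32')        # downloads the CLIP checkpoint on first use\n\ntokens = torch.randint(0, 49408, (2, 77))   # placeholder BPE token ids\nimages = torch.randn(2, 3, 224, 224)\n\ntext_embed, text_encodings = clip.embed_text(tokens)\nimage_embed, _ = clip.embed_image(images)\nprint(text_embed.shape, image_embed.shape)  # both (2, 512) for ViT-B/32"
+ },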
+ {
+ "comment": "This code is part of a neural network model for text-to-image generation using PyTorch. It includes functions to handle text attention, clear the internal state, and embed input text. The class has properties such as `dim_latent`, `image_size`, `image_channels`, `max_text_len` which are used to define the network's structure and behavior.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":319-359",
+ "content": " self.handle = text_attention_final.register_forward_hook(self._hook)\n self.clip_normalize = preprocess.transforms[-1]\n self.cleared = False\n def find_layer(self, layer):\n modules = dict([*self.clip.named_modules()])\n return modules.get(layer, None)\n def clear(self):\n if self.cleared:\n return\n self.handle()\n def _hook(self, _, inputs, outputs):\n self.text_encodings = outputs\n @property\n def dim_latent(self):\n return self.dim_latent_\n @property\n def image_size(self):\n return self.clip.visual.input_resolution\n @property\n def image_channels(self):\n return 3\n @property\n def max_text_len(self):\n return self.clip.context_length\n @torch.no_grad()\n def embed_text(self, text):\n text = text[..., :self.max_text_len]\n is_eos_id = (text == self.eos_id)\n text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0\n text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)"
+ },
+ {
+ "comment": "Method to embed text using CLIP model by encoding the input text, applying a mask on text encodings, and returning EmbeddedText object with L2 normalized text embedding and float text encodings.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":360-389",
+ "content": " text_mask = text_mask & (text != 0)\n assert not self.cleared\n text_embed = self.clip.encode_text(text)\n text_encodings = self.text_encodings\n text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)\n del self.text_encodings\n return EmbeddedText(l2norm(text_embed.float()), text_encodings.float())\n @torch.no_grad()\n def embed_image(self, image):\n assert not self.cleared\n image = self.validate_and_resize_image(image)\n image = self.clip_normalize(image)\n image_embed = self.clip.encode_image(image)\n return EmbeddedImage(l2norm(image_embed.float()), None)\nclass OpenClipAdapter(BaseClipAdapter):\n def __init__(\n self,\n name = 'ViT-B/32',\n pretrained = 'laion400m_e32'\n ):\n import open_clip\n clip, _, preprocess = open_clip.create_model_and_transforms(name, pretrained = pretrained)\n super().__init__(clip)\n self.eos_id = 49407\n text_attention_final = self.find_layer('ln_final')"
+ },
+ {
+ "comment": "The code represents a class that appears to be a part of a larger model. It has methods for embedding text, clearing internal state, finding layers in the network, and retrieving properties like latent dimension and maximum text length. The class relies on other components such as `preprocess`, `clip`, and `image_size`.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":390-432",
+ "content": " self._dim_latent = text_attention_final.weight.shape[0]\n self.handle = text_attention_final.register_forward_hook(self._hook)\n self.clip_normalize = preprocess.transforms[-1]\n self.cleared = False\n def find_layer(self, layer):\n modules = dict([*self.clip.named_modules()])\n return modules.get(layer, None)\n def clear(self):\n if self.cleared:\n return\n self.handle()\n def _hook(self, _, inputs, outputs):\n self.text_encodings = outputs\n @property\n def dim_latent(self):\n return self._dim_latent\n @property\n def image_size(self):\n image_size = self.clip.visual.image_size\n if isinstance(image_size, tuple):\n return max(image_size)\n return image_size\n @property\n def image_channels(self):\n return 3\n @property\n def max_text_len(self):\n return self.clip.context_length\n @torch.no_grad()\n def embed_text(self, text):\n text = text[..., :self.max_text_len]\n is_eos_id = (text == self.eos_id)"
+ },
+ {
+ "comment": "This function takes in a text input and returns an EmbeddedText object containing the embedded text representation and a corresponding mask. It first creates a mask excluding the end of sentence (EOS) token, pads it, and applies the mask to the original mask. Then, it encodes the text using CLIP's encode_text function, and finally normalizes the resulting embeddings. The classifier free guidance functions return a probability mask based on the given probability value for a specific shape and device.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":433-458",
+ "content": " text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0\n text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)\n text_mask = text_mask & (text != 0)\n assert not self.cleared\n text_embed = self.clip.encode_text(text)\n text_encodings = self.text_encodings\n text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)\n del self.text_encodings\n return EmbeddedText(l2norm(text_embed.float()), text_encodings.float())\n @torch.no_grad()\n def embed_image(self, image):\n assert not self.cleared\n image = self.validate_and_resize_image(image)\n image = self.clip_normalize(image)\n image_embed = self.clip.encode_image(image)\n return EmbeddedImage(l2norm(image_embed.float()), None)\n# classifier free guidance functions\ndef prob_mask_like(shape, prob, device):\n if prob == 1:\n return torch.ones(shape, device = device, dtype = torch.bool)\n elif prob == 0:\n return torch.zeros(shape, device = device, dtype = torch.bool)"
+ },
+ {
+ "comment": "This code defines several helper functions used in the DALLE2-pytorch model. These functions are involved in tasks such as extracting values, calculating normal KL divergence, approximating the standard normal cumulative distribution function, and computing the discretized Gaussian log likelihood. The code also includes error handling for potential nan gradients when using deepspeed fp16.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":459-487",
+ "content": " else:\n return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob\n# gaussian diffusion helper functions\ndef extract(a, t, x_shape):\n b, *_ = t.shape\n out = a.gather(-1, t)\n return out.reshape(b, *((1,) * (len(x_shape) - 1)))\ndef meanflat(x):\n return x.mean(dim = tuple(range(1, len(x.shape))))\ndef normal_kl(mean1, logvar1, mean2, logvar2):\n return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2))\ndef approx_standard_normal_cdf(x):\n return 0.5 * (1.0 + torch.tanh(((2.0 / math.pi) ** 0.5) * (x + 0.044715 * (x ** 3))))\ndef discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):\n assert x.shape == means.shape == log_scales.shape\n # attempting to correct nan gradients when learned variance is turned on\n # in the setting of deepspeed fp16\n eps = 1e-12 if x.dtype == torch.float32 else 1e-3\n centered_x = x - means\n inv_stdv = torch.exp(-log_scales)\n plus_in = inv_stdv * (centered_x + 1. / 255.)"
+ },
+ {
+ "comment": "Function at line 488-518 calculates log probabilities for a given input x, using an adaptive quantile regression approach with a cosine or linear schedule. The cosine_beta_schedule function generates a sequence of beta values using a cosine schedule, and the linear_beta_schedule function generates a sequence of beta values linearly.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":488-518",
+ "content": " cdf_plus = approx_standard_normal_cdf(plus_in)\n min_in = inv_stdv * (centered_x - 1. / 255.)\n cdf_min = approx_standard_normal_cdf(min_in)\n log_cdf_plus = log(cdf_plus, eps = eps)\n log_one_minus_cdf_min = log(1. - cdf_min, eps = eps)\n cdf_delta = cdf_plus - cdf_min\n log_probs = torch.where(x < -thres,\n log_cdf_plus,\n torch.where(x > thres,\n log_one_minus_cdf_min,\n log(cdf_delta, eps = eps)))\n return log_probs\ndef cosine_beta_schedule(timesteps, s = 0.008):\n \"\"\"\n cosine schedule\n as proposed in https://openreview.net/forum?id=-NEXDKk8gZ\n \"\"\"\n steps = timesteps + 1\n x = torch.linspace(0, timesteps, steps, dtype = torch.float64)\n alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2\n alphas_cumprod = alphas_cumprod / first(alphas_cumprod)\n betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])\n return torch.clip(betas, 0, 0.999)\ndef linear_beta_schedule(timesteps):\n scale = 1000 / timesteps\n beta_start = scale * 0.0001"
+ },
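+ {
+ "comment": "Standalone sketch, not from the repository: the cosine noise schedule described above, re-implemented to show the shape of the resulting betas.",
+ "content": "import math\nimport torch\n\ndef cosine_beta_schedule(timesteps, s = 0.008):\n    steps = timesteps + 1\n    x = torch.linspace(0, timesteps, steps, dtype = torch.float64)\n    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2\n    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]\n    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])\n    return torch.clip(betas, 0, 0.999)\n\nbetas = cosine_beta_schedule(1000)\nprint(betas[0].item(), betas[-1].item())    # betas start tiny and grow toward the 0.999 clip"
+ },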
+ {
+ "comment": "This code defines three beta scheduling functions (linear, quadratic, cosine) and a class for the NoiseScheduler. The scheduler initializes with a selected beta schedule and timesteps. The beta_schedule parameter determines which function to use for generating the betas, which represent noise scaling factors in the model's training process.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":519-547",
+ "content": " beta_end = scale * 0.02\n return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)\ndef quadratic_beta_schedule(timesteps):\n scale = 1000 / timesteps\n beta_start = scale * 0.0001\n beta_end = scale * 0.02\n return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype = torch.float64) ** 2\ndef sigmoid_beta_schedule(timesteps):\n scale = 1000 / timesteps\n beta_start = scale * 0.0001\n beta_end = scale * 0.02\n betas = torch.linspace(-6, 6, timesteps, dtype = torch.float64)\n return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start\nclass NoiseScheduler(nn.Module):\n def __init__(self, *, beta_schedule, timesteps, loss_type, p2_loss_weight_gamma = 0., p2_loss_weight_k = 1):\n super().__init__()\n if beta_schedule == \"cosine\":\n betas = cosine_beta_schedule(timesteps)\n elif beta_schedule == \"linear\":\n betas = linear_beta_schedule(timesteps)\n elif beta_schedule == \"quadratic\":\n betas = quadratic_beta_schedule(timesteps)"
+ },
+ {
+ "comment": "This code sets the beta schedule and alpha values based on user input, then selects a loss function according to the specified type. The code also registers buffer helper functions for 'betas' and 'alphas_cumprod'.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":548-579",
+ "content": " elif beta_schedule == \"jsd\":\n betas = 1.0 / torch.linspace(timesteps, 1, timesteps)\n elif beta_schedule == \"sigmoid\":\n betas = sigmoid_beta_schedule(timesteps)\n else:\n raise NotImplementedError()\n alphas = 1. - betas\n alphas_cumprod = torch.cumprod(alphas, axis = 0)\n alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)\n timesteps, = betas.shape\n self.num_timesteps = int(timesteps)\n if loss_type == 'l1':\n loss_fn = F.l1_loss\n elif loss_type == 'l2':\n loss_fn = F.mse_loss\n elif loss_type == 'huber':\n loss_fn = F.smooth_l1_loss\n else:\n raise NotImplementedError()\n self.loss_type = loss_type\n self.loss_fn = loss_fn\n # register buffer helper function to cast double back to float\n register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))\n register_buffer('betas', betas)\n register_buffer('alphas_cumprod', alphas_cumprod)"
+ },
+ {
+ "comment": "The code is registering various buffers for computations related to diffusion. It calculates the posterior variance and clips the log of the posterior variance to avoid numerical instability at the beginning of the diffusion chain.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":580-600",
+ "content": " register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)\n # calculations for diffusion q(x_t | x_{t-1}) and others\n register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))\n register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))\n register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))\n register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))\n register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))\n # calculations for posterior q(x_{t-1} | x_t, x_0)\n posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)\n # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)\n register_buffer('posterior_variance', posterior_variance)\n # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain\n register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))"
+ },
+ {
+ "comment": "In this code segment, the author is computing posterior means for a model, performing loss reweighting, generating random times, and calculating posterior values. The posterior means are calculated based on betas and alphas, while the loss reweighting considers p2_loss_weight_gamma. Random times are sampled for a batch of inputs using torch.randint. The q_posterior function calculates posterior mean, variance, and log-variance clipped from these computed values.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":601-619",
+ "content": " register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))\n register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))\n # p2 loss reweighting\n self.has_p2_loss_reweighting = p2_loss_weight_gamma > 0.\n register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma)\n def sample_random_times(self, batch):\n return torch.randint(0, self.num_timesteps, (batch,), device = self.betas.device, dtype = torch.long)\n def q_posterior(self, x_start, x_t, t):\n posterior_mean = (\n extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +\n extract(self.posterior_mean_coef2, t, x_t.shape) * x_t\n )\n posterior_variance = extract(self.posterior_variance, t, x_t.shape)\n posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)\n return posterior_mean, posterior_variance, posterior_log_variance_clipped"
+ },
+ {
+ "comment": "The code defines three functions: `q_sample`, `calculate_v`, and `q_sample_from_to`. These functions are part of a neural network for generating images. `q_sample` combines alpha and noise values to generate a sample, while `calculate_v` calculates the difference between an alpha-blended noise and a one minus alpha-blended image start. The `q_sample_from_to` function samples from one timestep to another by interpolating alphas and sigmas.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":621-644",
+ "content": " def q_sample(self, x_start, t, noise = None):\n noise = default(noise, lambda: torch.randn_like(x_start))\n return (\n extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +\n extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise\n )\n def calculate_v(self, x_start, t, noise = None):\n return (\n extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -\n extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start\n )\n def q_sample_from_to(self, x_from, from_t, to_t, noise = None):\n shape = x_from.shape\n noise = default(noise, lambda: torch.randn_like(x_from))\n alpha = extract(self.sqrt_alphas_cumprod, from_t, shape)\n sigma = extract(self.sqrt_one_minus_alphas_cumprod, from_t, shape)\n alpha_next = extract(self.sqrt_alphas_cumprod, to_t, shape)\n sigma_next = extract(self.sqrt_one_minus_alphas_cumprod, to_t, shape)\n return x_from * (alpha_next / alpha) + noise * (sigma_next * alpha - sigma * alpha_next) / alpha"
+ },
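+ {
+ "comment": "Standalone sketch, not from the repository: the closed-form forward diffusion step that q_sample implements, using a simple linear beta schedule for the cumulative alphas.",
+ "content": "import torch\n\ntimesteps = 1000\nbetas = torch.linspace(1e-4, 0.02, timesteps)\nalphas_cumprod = torch.cumprod(1 - betas, dim = 0)\n\nx_start = torch.randn(2, 512)                 # e.g. clean CLIP image embeddings\nt = torch.randint(0, timesteps, (2,))\nnoise = torch.randn_like(x_start)\n\nsqrt_ac = alphas_cumprod[t].sqrt().unsqueeze(-1)\nsqrt_om = (1 - alphas_cumprod[t]).sqrt().unsqueeze(-1)\nx_t = sqrt_ac * x_start + sqrt_om * noise     # the noised sample fed to the prior network\nprint(x_t.shape)"
+ },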
+ {
+ "comment": "The code defines three methods for predicting values from different inputs, including v and noise. It also includes a method to reweight loss using p2_loss_weight and a class to rearrange images into sequences.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":646-676",
+ "content": " def predict_start_from_v(self, x_t, t, v):\n return (\n extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -\n extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v\n )\n def predict_start_from_noise(self, x_t, t, noise):\n return (\n extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -\n extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise\n )\n def predict_noise_from_start(self, x_t, t, x0):\n return (\n (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \\\n extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)\n )\n def p2_reweigh_loss(self, loss, times):\n if not self.has_p2_loss_reweighting:\n return loss\n return loss * extract(self.p2_loss_weight, times, loss.shape)\n# rearrange image to sequence\nclass RearrangeToSequence(nn.Module):\n def __init__(self, fn):\n super().__init__()\n self.fn = fn\n def forward(self, x):"
+ },
+ {
+ "comment": "This function is applying layer normalization to input tensor 'x' and returning the normalized output. The 'LayerNorm' class is a type of layer normalization, while 'ChanLayerNorm' is a channel-wise version. The code includes settings for epsilon, float precision, and stability options.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":677-710",
+ "content": " x = rearrange(x, 'b c ... -> b ... c')\n x, ps = pack([x], 'b * c')\n x = self.fn(x)\n x, = unpack(x, ps, 'b * c')\n x = rearrange(x, 'b ... c -> b c ...')\n return x\n# diffusion prior\nclass LayerNorm(nn.Module):\n def __init__(self, dim, eps = 1e-5, fp16_eps = 1e-3, stable = False):\n super().__init__()\n self.eps = eps\n self.fp16_eps = fp16_eps\n self.stable = stable\n self.g = nn.Parameter(torch.ones(dim))\n def forward(self, x):\n eps = self.eps if x.dtype == torch.float32 else self.fp16_eps\n if self.stable:\n x = x / x.amax(dim = -1, keepdim = True).detach()\n var = torch.var(x, dim = -1, unbiased = False, keepdim = True)\n mean = torch.mean(x, dim = -1, keepdim = True)\n return (x - mean) * (var + eps).rsqrt() * self.g\nclass ChanLayerNorm(nn.Module):\n def __init__(self, dim, eps = 1e-5, fp16_eps = 1e-3, stable = False):\n super().__init__()\n self.eps = eps\n self.fp16_eps = fp16_eps"
+ },
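+ {
+ "comment": "Small standalone check, not from the repository: at initialization the gain g of the bias-free LayerNorm above is all ones, so its output matches torch's functional layer_norm; the 'stable' and fp16 options are ignored here.",
+ "content": "import torch\nimport torch.nn.functional as F\n\ndim, eps = 16, 1e-5\nx = torch.randn(2, 10, dim)\n\nmean = x.mean(dim = -1, keepdim = True)\nvar = torch.var(x, dim = -1, unbiased = False, keepdim = True)\nmanual = (x - mean) * (var + eps).rsqrt()     # g is ones at init, so the gain is a no-op\n\nreference = F.layer_norm(x, (dim,), eps = eps)\nassert torch.allclose(manual, reference, atol = 1e-4)   # equal within float32 tolerance"
+ },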
+ {
+ "comment": "This code defines a Residual class that wraps a function and adds it to the input. It also contains an MLP (Multi-Layer Perceptron) class with optional normalization and activation functions, followed by a series of fully connected layers. The forward method in DALLE2_PyTorch performs normalization, calculates mean and variance, then applies element-wise transformations before returning the output.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":711-749",
+ "content": " self.stable = stable\n self.g = nn.Parameter(torch.ones(1, dim, 1, 1))\n def forward(self, x):\n eps = self.eps if x.dtype == torch.float32 else self.fp16_eps\n if self.stable:\n x = x / x.amax(dim = 1, keepdim = True).detach()\n var = torch.var(x, dim = 1, unbiased = False, keepdim = True)\n mean = torch.mean(x, dim = 1, keepdim = True)\n return (x - mean) * (var + eps).rsqrt() * self.g\nclass Residual(nn.Module):\n def __init__(self, fn):\n super().__init__()\n self.fn = fn\n def forward(self, x, **kwargs):\n return self.fn(x, **kwargs) + x\n# mlp\nclass MLP(nn.Module):\n def __init__(\n self,\n dim_in,\n dim_out,\n *,\n expansion_factor = 2.,\n depth = 2,\n norm = False,\n ):\n super().__init__()\n hidden_dim = int(expansion_factor * dim_out)\n norm_fn = lambda: nn.LayerNorm(hidden_dim) if norm else nn.Identity()\n layers = [nn.Sequential(\n nn.Linear(dim_in, hidden_dim),"
+ },
+ {
+ "comment": "This code defines a neural network architecture for the DALL-E 2 model. It includes a sequential layer with multiple linear layers, SiLU activation function, and normalization. The forward method performs inference on input data. Another class is defined for relative positional bias in causal transformer. The RelPosBias class initializes an embedding layer to calculate the relative position between elements for attention mechanism. It uses the concept of buckets, where each bucket represents a range of distances between two elements, and computes the relative position bucket based on input data.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":750-791",
+ "content": " nn.SiLU(),\n norm_fn()\n )]\n for _ in range(depth - 1):\n layers.append(nn.Sequential(\n nn.Linear(hidden_dim, hidden_dim),\n nn.SiLU(),\n norm_fn()\n ))\n layers.append(nn.Linear(hidden_dim, dim_out))\n self.net = nn.Sequential(*layers)\n def forward(self, x):\n return self.net(x.float())\n# relative positional bias for causal transformer\nclass RelPosBias(nn.Module):\n def __init__(\n self,\n heads = 8,\n num_buckets = 32,\n max_distance = 128,\n ):\n super().__init__()\n self.num_buckets = num_buckets\n self.max_distance = max_distance\n self.relative_attention_bias = nn.Embedding(num_buckets, heads)\n @staticmethod\n def _relative_position_bucket(\n relative_position,\n num_buckets = 32,\n max_distance = 128\n ):\n n = -relative_position\n n = torch.max(n, torch.zeros_like(n))\n max_exact = num_buckets // 2\n is_small = n < max_exact"
+ },
+ {
+ "comment": "This code snippet defines a class for DALLE2-pytorch, containing a method to calculate relative position buckets and an attention layer. The attention layer uses the SwiGLU activation function in its FeedForward module. The purpose of this code is to facilitate the calculation and application of positional embeddings in a transformer model.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":793-815",
+ "content": " val_if_large = max_exact + (torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)).long()\n val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))\n return torch.where(is_small, n, val_if_large)\n def forward(self, i, j, *, device):\n q_pos = torch.arange(i, dtype = torch.long, device = device)\n k_pos = torch.arange(j, dtype = torch.long, device = device)\n rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')\n rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)\n values = self.relative_attention_bias(rp_bucket)\n return rearrange(values, 'i j h -> h i j')\n# feedforward\nclass SwiGLU(nn.Module):\n \"\"\" used successfully in https://arxiv.org/abs/2204.0231 \"\"\"\n def forward(self, x):\n x, gate = x.chunk(2, dim = -1)\n return x * F.silu(gate)\ndef FeedForward(\n dim,\n mult = 4,"
+ },
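+ {
+ "comment": "Standalone sketch, not from the repository: the SwiGLU gating inside FeedForward, which projects to twice the inner width and gates one half with SiLU of the other before projecting back down.",
+ "content": "import torch\nimport torch.nn.functional as F\nfrom torch import nn\n\ndim, mult = 64, 4\ninner_dim = dim * mult\nproj_in = nn.Linear(dim, inner_dim * 2, bias = False)\nproj_out = nn.Linear(inner_dim, dim, bias = False)\n\nx = torch.randn(2, 10, dim)\nh, gate = proj_in(x).chunk(2, dim = -1)\nout = proj_out(h * F.silu(gate))\nprint(out.shape)                              # torch.Size([2, 10, 64])"
+ },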
+ {
+ "comment": "The code defines a module that applies post-activation normalization. It also includes a nested Attention class that performs multi-head attention with optional causal masking and rotary embedding. The main components include layer normalization, dropout regularization, and linear transformations for dimensionality adjustments. The cosine similarity calculation is utilized if specified.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":816-857",
+ "content": " dropout = 0.,\n post_activation_norm = False\n):\n \"\"\" post-activation norm https://arxiv.org/abs/2110.09456 \"\"\"\n inner_dim = int(mult * dim)\n return nn.Sequential(\n LayerNorm(dim),\n nn.Linear(dim, inner_dim * 2, bias = False),\n SwiGLU(),\n LayerNorm(inner_dim) if post_activation_norm else nn.Identity(),\n nn.Dropout(dropout),\n nn.Linear(inner_dim, dim, bias = False)\n )\n# attention\nclass Attention(nn.Module):\n def __init__(\n self,\n dim,\n *,\n dim_head = 64,\n heads = 8,\n dropout = 0.,\n causal = False,\n rotary_emb = None,\n cosine_sim = True,\n cosine_sim_scale = 16\n ):\n super().__init__()\n self.scale = cosine_sim_scale if cosine_sim else (dim_head ** -0.5)\n self.cosine_sim = cosine_sim\n self.heads = heads\n inner_dim = dim_head * heads\n self.causal = causal\n self.norm = LayerNorm(dim)\n self.dropout = nn.Dropout(dropout)\n self.null_kv = nn.Parameter(torch.randn(2, dim_head))"
+ },
+ {
+ "comment": "This code defines a self-attention layer for DALL\u00b7E 2, initializing linear layers and including the option to use rotary embeddings. It also allows for classifier free guidance by adding null key/value pairs and using cosine similarity if enabled.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":858-890",
+ "content": " self.to_q = nn.Linear(dim, inner_dim, bias = False)\n self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)\n self.rotary_emb = rotary_emb\n self.to_out = nn.Sequential(\n nn.Linear(inner_dim, dim, bias = False),\n LayerNorm(dim)\n )\n def forward(self, x, mask = None, attn_bias = None):\n b, n, device = *x.shape[:2], x.device\n x = self.norm(x)\n q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))\n q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)\n q = q * self.scale\n # rotary embeddings\n if exists(self.rotary_emb):\n q, k = map(self.rotary_emb.rotate_queries_or_keys, (q, k))\n # add null key / value for classifier free guidance in prior net\n nk, nv = map(lambda t: repeat(t, 'd -> b 1 d', b = b), self.null_kv.unbind(dim = -2))\n k = torch.cat((nk, k), dim = -2)\n v = torch.cat((nv, v), dim = -2)\n # whether to use cosine sim\n if self.cosine_sim:"
+ },
+ {
+ "comment": "This code snippet performs multi-head attention by first normalizing the query and key tensors, calculating their similarities, adding relative positional encoding if available, masking irrelevant values based on a given mask, applying causal masking if specified, and finally computing the attention weights and aggregating the corresponding values.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":891-927",
+ "content": " q, k = map(l2norm, (q, k))\n q, k = map(lambda t: t * math.sqrt(self.scale), (q, k))\n # calculate query / key similarities\n sim = einsum('b h i d, b j d -> b h i j', q, k)\n # relative positional encoding (T5 style)\n if exists(attn_bias):\n sim = sim + attn_bias\n # masking\n max_neg_value = -torch.finfo(sim.dtype).max\n if exists(mask):\n mask = F.pad(mask, (1, 0), value = True)\n mask = rearrange(mask, 'b j -> b 1 1 j')\n sim = sim.masked_fill(~mask, max_neg_value)\n if self.causal:\n i, j = sim.shape[-2:]\n causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)\n sim = sim.masked_fill(causal_mask, max_neg_value)\n # attention\n attn = sim.softmax(dim = -1, dtype = torch.float32)\n attn = attn.type(sim.dtype)\n attn = self.dropout(attn)\n # aggregate values\n out = einsum('b h i j, b j d -> b h i d', attn, v)"
+ },
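+ {
+ "comment": "Standalone sketch, not from the repository: the causal masking step above, where entries above the diagonal are filled with a large negative value before the softmax so each query attends only to itself and earlier positions.",
+ "content": "import torch\n\ni = j = 5\nsim = torch.randn(1, 1, i, j)                 # (batch, heads, queries, keys)\ncausal_mask = torch.ones((i, j), dtype = torch.bool).triu(j - i + 1)\nsim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)\nattn = sim.softmax(dim = -1)\nprint(attn[0, 0])                             # zero above the diagonal, each row sums to 1"
+ },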
+ {
+ "comment": "This code defines a `CausalTransformer` class for natural language processing tasks. The class initializes several modules such as LayerNorm, RelPosBias, RotaryEmbedding, and Attention. It also includes a FeedForward layer with configurable parameters like `dim`, `depth`, `dim_head`, `heads`, `ff_mult`, `attn_dropout`, `ff_dropout`, `norm_in`, `norm_out`, `final_proj`, and `rotary_emb`. The code snippet you provided is responsible for rearranging the tensor dimensions and returning it after processing by the `CausalTransformer` model.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":929-960",
+ "content": " out = rearrange(out, 'b h n d -> b n (h d)')\n return self.to_out(out)\nclass CausalTransformer(nn.Module):\n def __init__(\n self,\n *,\n dim,\n depth,\n dim_head = 64,\n heads = 8,\n ff_mult = 4,\n norm_in = False,\n norm_out = True,\n attn_dropout = 0.,\n ff_dropout = 0.,\n final_proj = True,\n normformer = False,\n rotary_emb = True\n ):\n super().__init__()\n self.init_norm = LayerNorm(dim) if norm_in else nn.Identity() # from latest BLOOM model and Yandex's YaLM\n self.rel_pos_bias = RelPosBias(heads = heads)\n rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None\n self.layers = nn.ModuleList([])\n for _ in range(depth):\n self.layers.append(nn.ModuleList([\n Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),\n FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)"
+ },
+ {
+ "comment": "The code initializes a DiffusionPriorNetwork model with multiple layers, including attention and feed-forward modules. It also includes layer normalization and the option to project the output. The network takes in input of varying dimensions and can condition on time, image, and/or text embeddings. The self_cond parameter determines whether or not to use self-conditioning.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":961-992",
+ "content": " ]))\n self.norm = LayerNorm(dim, stable = True) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options\n self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()\n def forward(self, x):\n n, device = x.shape[1], x.device\n x = self.init_norm(x)\n attn_bias = self.rel_pos_bias(n, n + 1, device = device)\n for attn, ff in self.layers:\n x = attn(x, attn_bias = attn_bias) + x\n x = ff(x) + x\n out = self.norm(x)\n return self.project_out(out)\nclass DiffusionPriorNetwork(nn.Module):\n def __init__(\n self,\n dim,\n num_timesteps = None,\n num_time_embeds = 1,\n num_image_embeds = 1,\n num_text_embeds = 1,\n max_text_len = 256,\n self_cond = False,\n **kwargs\n ):\n super().__init__()"
+ },
+ {
+ "comment": "This code defines a class with parameters for dimensionality, number of time, image, and text embeddings. It initializes layers to transform input into text, time, and image embeddings. The \"learned_query\" is a learned parameter for the model.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":993-1016",
+ "content": " self.dim = dim\n self.num_time_embeds = num_time_embeds\n self.num_image_embeds = num_image_embeds\n self.num_text_embeds = num_text_embeds\n self.to_text_embeds = nn.Sequential(\n nn.Linear(dim, dim * num_text_embeds) if num_text_embeds > 1 else nn.Identity(),\n Rearrange('b (n d) -> b n d', n = num_text_embeds)\n )\n self.continuous_embedded_time = not exists(num_timesteps)\n self.to_time_embeds = nn.Sequential(\n nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP\n Rearrange('b (n d) -> b n d', n = num_time_embeds)\n )\n self.to_image_embeds = nn.Sequential(\n nn.Linear(dim, dim * num_image_embeds) if num_image_embeds > 1 else nn.Identity(),\n Rearrange('b (n d) -> b n d', n = num_image_embeds)\n )\n self.learned_query = nn.Parameter(torch.randn(dim))"
+ },
+ {
+ "comment": "The code defines a model with a causal transformer and includes parameters for padding strategy, self-conditioning, and a function to perform forward calculations. The `forward_with_cond_scale` method takes conditional scaling as input and returns the scaled logits by combining original logits with null logits at 100% condition drop probabilities.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":1017-1051",
+ "content": " self.causal_transformer = CausalTransformer(dim = dim, **kwargs)\n # dalle1 learned padding strategy\n self.max_text_len = max_text_len\n self.null_text_encodings = nn.Parameter(torch.randn(1, max_text_len, dim))\n self.null_text_embeds = nn.Parameter(torch.randn(1, num_text_embeds, dim))\n self.null_image_embed = nn.Parameter(torch.randn(1, dim))\n # whether to use self conditioning, Hinton's group's new ddpm technique\n self.self_cond = self_cond\n def forward_with_cond_scale(\n self,\n *args,\n cond_scale = 1.,\n **kwargs\n ):\n logits = self.forward(*args, **kwargs)\n if cond_scale == 1:\n return logits\n null_logits = self.forward(*args, text_cond_drop_prob = 1., image_cond_drop_prob = 1, **kwargs)\n return null_logits + (logits - null_logits) * cond_scale\n def forward(\n self,\n image_embed,\n diffusion_timesteps,\n *,\n text_embed,\n text_encodings = None,"
+ },
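+ {
+ "comment": "Standalone sketch, not from the repository: the classifier-free guidance combination performed by forward_with_cond_scale; the two predictions are random placeholders standing in for the conditioned and fully-dropped network outputs.",
+ "content": "import torch\n\ncond_pred = torch.randn(2, 512)   # network output with text / image conditioning kept\nnull_pred = torch.randn(2, 512)   # network output with conditioning dropped (null embeddings)\ncond_scale = 2.0\n\nguided = null_pred + (cond_pred - null_pred) * cond_scale\n# cond_scale = 1 recovers the conditioned prediction; larger values push further from the unconditional one\nassert torch.allclose(null_pred + (cond_pred - null_pred) * 1.0, cond_pred)"
+ },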
+ {
+ "comment": "This code initializes a model's parameters based on the given image_embed. It sets up self-conditioning if necessary, converts text and image embeddings to the appropriate format, and creates classifier free guidance masks for both text and image inputs. The model will use these embeddings and masks for prediction.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":1052-1075",
+ "content": " self_cond = None,\n text_cond_drop_prob = 0.,\n image_cond_drop_prob = 0.\n ):\n batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype\n num_time_embeds, num_image_embeds, num_text_embeds = self.num_time_embeds, self.num_image_embeds, self.num_text_embeds\n # setup self conditioning\n if self.self_cond:\n self_cond = default(self_cond, lambda: torch.zeros(batch, self.dim, device = device, dtype = dtype))\n self_cond = rearrange(self_cond, 'b d -> b 1 d')\n # in section 2.2, last paragraph\n # \"... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction\"\n text_embed = self.to_text_embeds(text_embed)\n image_embed = self.to_image_embeds(image_embed)\n # classifier free guidance masks\n text_keep_mask = prob_mask_like((batch,), 1 - text_cond_drop_prob, device = device)\n text_keep_mask = rearrange(text_keep_mask, 'b -> b 1 1')"
+ },
+ {
+ "comment": "This code snippet is preparing the input data for a DALL-E 2 model by handling text encodings. It creates an image_keep_mask, makes text encodings optional based on their existence, applies masking to remove padding or null encodings, and ensures that the length of text_encodings matches the expected maximum length.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":1077-1102",
+ "content": " image_keep_mask = prob_mask_like((batch,), 1 - image_cond_drop_prob, device = device)\n image_keep_mask = rearrange(image_keep_mask, 'b -> b 1 1')\n # make text encodings optional\n # although the paper seems to suggest it is present <--\n if not exists(text_encodings):\n text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)\n mask = torch.any(text_encodings != 0., dim = -1)\n # replace any padding in the text encodings with learned padding tokens unique across position\n text_encodings = text_encodings[:, :self.max_text_len]\n mask = mask[:, :self.max_text_len]\n text_len = text_encodings.shape[-2]\n remainder = self.max_text_len - text_len\n if remainder > 0:\n text_encodings = F.pad(text_encodings, (0, 0, 0, remainder), value = 0.)\n mask = F.pad(mask, (0, remainder), value = False)\n # mask out text encodings with null encodings\n null_text_encodings = self.null_text_encodings.to(text_encodings.dtype)"
+ },
+ {
+ "comment": "This code section is applying masking to text, image, and null embeddings based on the `text_keep_mask` and `image_keep_mask`. It uses these masks to decide which embeddings to keep or replace with null embeddings. The embeddings are also being converted to appropriate data types. Additionally, there's a conditional check for continuous embedded time.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":1104-1133",
+ "content": " text_encodings = torch.where(\n rearrange(mask, 'b n -> b n 1').clone() & text_keep_mask,\n text_encodings,\n null_text_encodings\n )\n # mask out text embeddings with null text embeddings\n null_text_embeds = self.null_text_embeds.to(text_embed.dtype)\n text_embed = torch.where(\n text_keep_mask,\n text_embed,\n null_text_embeds\n )\n # mask out image embeddings with null image embeddings\n null_image_embed = self.null_image_embed.to(image_embed.dtype)\n image_embed = torch.where(\n image_keep_mask,\n image_embed,\n null_image_embed\n )\n # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)\n # but let's just do it right\n if self.continuous_embedded_time:"
+ },
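+ {
+ "comment": "Standalone sketch, not from the repository: the masking pattern above, where a boolean keep-mask broadcasts through torch.where to swap real conditioning for a null embedding (a zero tensor here stands in for the learned null parameter).",
+ "content": "import torch\nfrom einops import rearrange\n\nbatch, n, dim = 4, 1, 8\ntext_embed = torch.randn(batch, n, dim)\nnull_text_embed = torch.zeros(1, n, dim)          # stand-in for the learned null parameter\n\nkeep = torch.zeros(batch).uniform_(0, 1) < 0.8    # keep conditioning for roughly 80% of the batch\nkeep = rearrange(keep, 'b -> b 1 1')\ntext_embed = torch.where(keep, text_embed, null_text_embed)\nprint(text_embed.shape)"
+ },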
+ {
+ "comment": "The code defines a DiffusionPrior class that takes in various inputs such as text encodings, timesteps, and image embeddings. It applies causal transformer to learn the learned_query, which predicts the image embedding per DDPM timestep. The text_cond_drop_prob parameter is optional and if provided, will dropout the text conditioning with a specified probability.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":1134-1173",
+ "content": " diffusion_timesteps = diffusion_timesteps.type(dtype)\n time_embed = self.to_time_embeds(diffusion_timesteps)\n learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)\n if self.self_cond:\n learned_queries = torch.cat((self_cond, learned_queries), dim = -2)\n tokens = torch.cat((\n text_encodings,\n text_embed,\n time_embed,\n image_embed,\n learned_queries\n ), dim = -2)\n # attend\n tokens = self.causal_transformer(tokens)\n # get learned query, which should predict the image embedding (per DDPM timestep)\n pred_image_embed = tokens[..., -1, :]\n return pred_image_embed\nclass DiffusionPrior(nn.Module):\n def __init__(\n self,\n net,\n *,\n clip = None,\n image_embed_dim = None,\n image_size = None,\n image_channels = 3,\n timesteps = 1000,\n sample_timesteps = None,\n cond_drop_prob = 0.,\n text_cond_drop_prob = None,"
+ },
+ {
+ "comment": "This code snippet initializes a DALLE2 model with various optional parameters for training and sampling. These include loss type, conditioning on text encodings, clamping of image embeddings, scaling the L2-normed image embedding, and adapter overrides for CLIP adapter integration.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":1174-1187",
+ "content": " image_cond_drop_prob = None,\n loss_type = \"l2\",\n predict_x_start = True,\n predict_v = False,\n beta_schedule = \"cosine\",\n condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training\n sampling_clamp_l2norm = False, # whether to l2norm clamp the image embed at each denoising iteration (analogous to -1 to 1 clipping for usual DDPMs)\n sampling_final_clamp_l2norm = False, # whether to l2norm the final image embedding output (this is also done for images in ddpm)\n training_clamp_l2norm = False,\n init_image_embed_l2norm = False,\n image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132\n clip_adapter_overrides = dict()\n ):\n super().__init__()"
+ },
+ {
+ "comment": "The code is initializing an instance of a model. It sets the sample_timesteps, creates a NoiseScheduler object with specified parameters, checks if CLIP is provided and adapts it if necessary, sets the image_embed_dim if not given, and assigns the network architecture (net) to be used.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":1189-1213",
+ "content": " self.sample_timesteps = sample_timesteps\n self.noise_scheduler = NoiseScheduler(\n beta_schedule = beta_schedule,\n timesteps = timesteps,\n loss_type = loss_type\n )\n if exists(clip):\n assert image_channels == clip.image_channels, f'channels of image ({image_channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'\n if isinstance(clip, CLIP):\n clip = XClipAdapter(clip, **clip_adapter_overrides)\n elif isinstance(clip, CoCa):\n clip = CoCaAdapter(clip, **clip_adapter_overrides)\n assert isinstance(clip, BaseClipAdapter)\n freeze_model_and_make_eval_(clip)\n self.clip = clip\n else:\n assert exists(image_embed_dim), 'latent dimension must be given, if training prior network without CLIP given'\n self.clip = None\n self.net = net\n self.image_embed_dim = default(image_embed_dim, lambda: clip.dim_latent)"
+ },
+ {
+ "comment": "The code asserts that the diffusion prior network dimension and the image embedding dimension are consistent, and checks if a CLIP is passed in with correct latent dimensions. It also sets channels, text conditional drop probability, image conditional drop probability, enables classifier guidance if probabilities are greater than 0, and conditions on text encodings. It offers both options to predict noise or x0 directly for image embedding as per the paper's claim of better results.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":1215-1226",
+ "content": " assert net.dim == self.image_embed_dim, f'your diffusion prior network has a dimension of {net.dim}, but you set your image embedding dimension (keyword image_embed_dim) on DiffusionPrior to {self.image_embed_dim}'\n assert not exists(clip) or clip.dim_latent == self.image_embed_dim, f'you passed in a CLIP to the diffusion prior with latent dimensions of {clip.dim_latent}, but your image embedding dimension (keyword image_embed_dim) for the DiffusionPrior was set to {self.image_embed_dim}'\n self.channels = default(image_channels, lambda: clip.image_channels)\n self.text_cond_drop_prob = default(text_cond_drop_prob, cond_drop_prob)\n self.image_cond_drop_prob = default(image_cond_drop_prob, cond_drop_prob)\n self.can_classifier_guidance = self.text_cond_drop_prob > 0. and self.image_cond_drop_prob > 0.\n self.condition_on_text_encodings = condition_on_text_encodings\n # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both."
+ },
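+ {
+ "comment": "Usage sketch following the repository README's recipe for training on precomputed CLIP embeddings; it assumes dalle2-pytorch is installed, the hyperparameters are placeholders, and condition_on_text_encodings is turned off as the comment above allows.",
+ "content": "import torch\nfrom dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork\n\nprior_network = DiffusionPriorNetwork(dim = 512, depth = 6, dim_head = 64, heads = 8)\n\ndiffusion_prior = DiffusionPrior(\n    net = prior_network,\n    image_embed_dim = 512,\n    timesteps = 100,\n    cond_drop_prob = 0.2,\n    condition_on_text_encodings = False\n)\n\ntext_embed = torch.randn(4, 512)    # precomputed CLIP text embeddings\nimage_embed = torch.randn(4, 512)   # precomputed CLIP image embeddings\n\nloss = diffusion_prior(text_embed = text_embed, image_embed = image_embed)\nloss.backward()"
+ },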
+ {
+ "comment": "The code sets various parameters and properties for an object, including predict_x_start, image_embed_scale, sampling_clamp_l2norm, etc. It also defines the l2norm_clamp_embed function and p_mean_variance function. The device property retrieves the device used by the object, and there's a register_buffer for tracking device usage.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":1228-1254",
+ "content": " self.predict_x_start = predict_x_start\n self.predict_v = predict_v # takes precedence over predict_x_start\n # @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132\n self.image_embed_scale = default(image_embed_scale, self.image_embed_dim ** 0.5)\n # whether to force an l2norm, similar to clipping denoised, when sampling\n self.sampling_clamp_l2norm = sampling_clamp_l2norm\n self.sampling_final_clamp_l2norm = sampling_final_clamp_l2norm\n self.training_clamp_l2norm = training_clamp_l2norm\n self.init_image_embed_l2norm = init_image_embed_l2norm\n # device tracker\n self.register_buffer('_dummy', torch.tensor([True]), persistent = False)\n @property\n def device(self):\n return self._dummy.device\n def l2norm_clamp_embed(self, image_embed):\n return l2norm(image_embed) * self.image_embed_scale\n def p_mean_variance(self, x, t, text_cond, self_cond = None, clip_denoised = False, cond_scale = 1.):"
+ },
+ {
+ "comment": "This code asserts that the model was not trained with conditional dropout, preventing classifier free guidance if cond_scale is anything other than 1. It then calculates and returns the model mean, posterior variance, posterior log variance, and x_start depending on different conditions.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":1255-1273",
+ "content": " assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'\n pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, self_cond = self_cond, **text_cond)\n if self.predict_v:\n x_start = self.noise_scheduler.predict_start_from_v(x, t = t, v = pred)\n elif self.predict_x_start:\n x_start = pred\n else:\n x_start = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)\n if clip_denoised and not self.predict_x_start:\n x_start.clamp_(-1., 1.)\n if self.predict_x_start and self.sampling_clamp_l2norm:\n x_start = l2norm(x_start) * self.image_embed_scale\n model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t)\n return model_mean, posterior_variance, posterior_log_variance, x_start"
+ },
+ {
+ "comment": "This code defines the `p_sample` and `p_sample_loop_ddpm` functions. `p_sample` takes input, generates a model mean and log variance, applies noise based on whether t is zero or not, and returns the prediction and x_start. `p_sample_loop_ddpm` initializes an image embedding, optionally normalizes it, and iterates through a reversed range to perform some unspecified operation for each iteration. The code uses PyTorch's `@torch.no_grad()` decorator to disable gradient computation during these functions' execution.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":1275-1295",
+ "content": " @torch.no_grad()\n def p_sample(self, x, t, text_cond = None, self_cond = None, clip_denoised = True, cond_scale = 1.):\n b, *_, device = *x.shape, x.device\n model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = t, text_cond = text_cond, self_cond = self_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)\n noise = torch.randn_like(x)\n # no noise when t == 0\n nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))\n pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise\n return pred, x_start\n @torch.no_grad()\n def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1.):\n batch, device = shape[0], self.device\n image_embed = torch.randn(shape, device = device)\n x_start = None # for self-conditioning\n if self.init_image_embed_l2norm:\n image_embed = l2norm(image_embed) * self.image_embed_scale\n for i in tqdm(reversed(range("
+ },
+ {
+ "comment": "The code defines the `p_sample` function which samples images and their corresponding embeddings using a loop over time steps. It also includes an optional L2-norm clamping for final image embedding. The `p_sample_loop_ddim` function is a helper method to define shape, times, and time pairs for the sampling loop in DDIM style.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":1295-1315",
+ "content": "0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):\n times = torch.full((batch,), i, device = device, dtype = torch.long)\n self_cond = x_start if self.net.self_cond else None\n image_embed, x_start = self.p_sample(image_embed, times, text_cond = text_cond, self_cond = self_cond, cond_scale = cond_scale)\n if self.sampling_final_clamp_l2norm and self.predict_x_start:\n image_embed = self.l2norm_clamp_embed(image_embed)\n return image_embed\n @torch.no_grad()\n def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scale = 1.):\n batch, device, alphas, total_timesteps = shape[0], self.device, self.noise_scheduler.alphas_cumprod_prev, self.noise_scheduler.num_timesteps\n times = torch.linspace(-1., total_timesteps, steps = timesteps + 1)[:-1]\n times = list(reversed(times.int().tolist()))\n time_pairs = list(zip(times[:-1], times[1:]))\n image_embed = torch.randn(shape, device = device)"
+ },
+ {
+ "comment": "The code is iterating through time pairs, calculating alpha values and performing a forward pass in the neural network. It also adjusts x_start based on prediction methods and performs noise scheduling. The purpose seems to be generating an image using conditional sampling with self-conditioning and considering different prediction methods for x_start.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":1317-1341",
+ "content": " x_start = None # for self-conditioning\n if self.init_image_embed_l2norm:\n image_embed = l2norm(image_embed) * self.image_embed_scale\n for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):\n alpha = alphas[time]\n alpha_next = alphas[time_next]\n time_cond = torch.full((batch,), time, device = device, dtype = torch.long)\n self_cond = x_start if self.net.self_cond else None\n pred = self.net.forward_with_cond_scale(image_embed, time_cond, self_cond = self_cond, cond_scale = cond_scale, **text_cond)\n # derive x0\n if self.predict_v:\n x_start = self.noise_scheduler.predict_start_from_v(image_embed, t = time_cond, v = pred)\n elif self.predict_x_start:\n x_start = pred\n else:\n x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred)\n # clip x0 before maybe predicting noise"
+ },
+ {
+ "comment": "In this code segment, it checks if predicting x_start is enabled and performs L2-norm clamping if necessary. It then predicts noise using the noise scheduler based on image embeddings, time condition, and x_start. If time_next is less than 0, it sets image_embed to x_start. Calculates coefficients c1 and c2 for RNN sampling and generates noise accordingly. Combines these elements to generate the final image_embed which is then optionally L2-norm clamped if enabled.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":1343-1371",
+ "content": " if not self.predict_x_start:\n x_start.clamp_(-1., 1.)\n if self.predict_x_start and self.sampling_clamp_l2norm:\n x_start = self.l2norm_clamp_embed(x_start)\n # predict noise\n pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start)\n if time_next < 0:\n image_embed = x_start\n continue\n c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()\n c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()\n noise = torch.randn_like(image_embed) if time_next > 0 else 0.\n image_embed = x_start * alpha_next.sqrt() + \\\n c1 * noise + \\\n c2 * pred_noise\n if self.predict_x_start and self.sampling_final_clamp_l2norm:\n image_embed = self.l2norm_clamp_embed(image_embed)\n return image_embed\n @torch.no_grad()\n def p_sample_loop(self, *args, timesteps = None, **kwargs):"
+ },
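+ {
+ "comment": "Illustrative sketch, not from the repository: the DDIM update above as a standalone function. Here alpha and alpha_next stand for the cumulative alpha-bar products at the current and next timestep, and eta = 0 gives deterministic sampling. The function name ddim_step is hypothetical.",
+ "content": "import torch\n\ndef ddim_step(x_start, pred_noise, alpha, alpha_next, eta = 1.0):\n    # c1 is the stochastic term (it vanishes when eta = 0), c2 weights the predicted noise\n    c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()\n    c2 = ((1 - alpha_next) - c1 ** 2).sqrt()\n    noise = torch.randn_like(x_start)\n    return x_start * alpha_next.sqrt() + c1 * noise + c2 * pred_noise\n\nx0 = torch.randn(2, 8)\neps = torch.randn(2, 8)\nalpha, alpha_next = torch.tensor(0.5), torch.tensor(0.7)\nnext_embed = ddim_step(x0, eps, alpha, alpha_next, eta = 0.0)  # deterministic update"
+ },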
+ {
+ "comment": "This code is from the DALLE2-pytorch model. It first determines if the timesteps are less than the number of timesteps in the noise scheduler. If so, it uses the p_sample_loop_ddim function to get the normalized image embeddings, otherwise it uses the p_sample_loop_ddpm function. The code then scales the normalized image embeddings by the image_embed_scale and returns the scaled embeddings. The p_losses function generates a noisy version of the input image embedding using the noise scheduler, and optionally conditions the model with self-conditioning if the condition is met. Finally, it passes the noisy embedding to the network for prediction.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":1372-1395",
+ "content": " timesteps = default(timesteps, self.noise_scheduler.num_timesteps)\n assert timesteps <= self.noise_scheduler.num_timesteps\n is_ddim = timesteps < self.noise_scheduler.num_timesteps\n if not is_ddim:\n normalized_image_embed = self.p_sample_loop_ddpm(*args, **kwargs)\n else:\n normalized_image_embed = self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)\n image_embed = normalized_image_embed / self.image_embed_scale\n return image_embed\n def p_losses(self, image_embed, times, text_cond, noise = None):\n noise = default(noise, lambda: torch.randn_like(image_embed))\n image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)\n self_cond = None\n if self.net.self_cond and random.random() < 0.5:\n with torch.no_grad():\n self_cond = self.net(image_embed_noisy, times, **text_cond).detach()\n pred = self.net(\n image_embed_noisy,"
+ },
+ {
+ "comment": "The code defines a method for predicting and calculating loss. It takes in parameters such as times, self_cond, text_cond_drop_prob, image_cond_drop_prob, and text_cond. If certain conditions are met, it performs l2norm clamping on the prediction, sets the target based on whether to predict x or v, then calculates the loss using the noise scheduler's loss function. The code also includes a sample_batch_size method that samples an image batch and iterates over time steps in reverse order for some processing.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":1396-1424",
+ "content": " times,\n self_cond = self_cond,\n text_cond_drop_prob = self.text_cond_drop_prob,\n image_cond_drop_prob = self.image_cond_drop_prob,\n **text_cond\n )\n if self.predict_x_start and self.training_clamp_l2norm:\n pred = self.l2norm_clamp_embed(pred)\n if self.predict_v:\n target = self.noise_scheduler.calculate_v(image_embed, times, noise)\n elif self.predict_x_start:\n target = image_embed\n else:\n target = noise\n loss = self.noise_scheduler.loss_fn(pred, target)\n return loss\n @torch.no_grad()\n @eval_decorator\n def sample_batch_size(self, batch_size, text_cond, cond_scale = 1.):\n device = self.betas.device\n shape = (batch_size, self.image_embed_dim)\n img = torch.randn(shape, device = device)\n for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = self.noise_scheduler.num_timesteps):"
+ },
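+ {
+ "comment": "Worked sketch, not from the repository: the three regression targets a noised embedding can be trained against, using a toy cumulative alpha-bar value. The v target assumes the standard v-parameterization, v = sqrt(alpha_bar) * noise - sqrt(1 - alpha_bar) * x_start.",
+ "content": "import torch\n\nx_start = torch.randn(2, 8)          # clean image embedding\nnoise = torch.randn_like(x_start)    # epsilon\nalpha_bar = torch.tensor(0.9)        # toy cumulative alpha at some timestep\n\n# forward diffusion (q_sample)\nx_noisy = alpha_bar.sqrt() * x_start + (1 - alpha_bar).sqrt() * noise\n\n# the three possible targets for the loss\ntarget_noise = noise\ntarget_x_start = x_start\ntarget_v = alpha_bar.sqrt() * noise - (1 - alpha_bar).sqrt() * x_start"
+ },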
+ {
+ "comment": "This code is part of a DALL-E 2 model implementation in PyTorch. The sample function generates multiple image embeddings based on provided text, then chooses the most similar one according to CLIP's similarity judgment. The function uses a p_sample_loop method which takes timesteps as input and returns a batch of images with the specified size.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":1425-1453",
+ "content": " img = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond, cond_scale = cond_scale)\n return img\n @torch.no_grad()\n @eval_decorator\n def sample(\n self,\n text,\n num_samples_per_batch = 2,\n cond_scale = 1.,\n timesteps = None\n ):\n timesteps = default(timesteps, self.sample_timesteps)\n # in the paper, what they did was\n # sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP\n text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)\n batch_size = text.shape[0]\n image_embed_dim = self.image_embed_dim\n text_embed, text_encodings = self.clip.embed_text(text)\n text_cond = dict(text_embed = text_embed)\n if self.condition_on_text_encodings:\n text_cond = {**text_cond, 'text_encodings': text_encodings}\n image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale, timesteps = timesteps)"
+ },
+ {
+ "comment": "This function retrieves the original unscaled image embeddings from the input, rearranges them based on the number of samples per batch, calculates text-image similarities using Euclidean distance, gets the top indices and gathers corresponding embeddings. It allows for training on preprocessed CLIP text and image embeddings or CLIP text encodings. If neither text nor text embedding is supplied, an assertion error will be raised.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":1455-1480",
+ "content": " # retrieve original unscaled image embed\n text_embeds = text_cond['text_embed']\n text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)\n image_embeds = rearrange(image_embeds, '(b r) d -> b r d', r = num_samples_per_batch)\n text_image_sims = einsum('b r d, b r d -> b r', l2norm(text_embeds), l2norm(image_embeds))\n top_sim_indices = text_image_sims.topk(k = 1).indices\n top_sim_indices = repeat(top_sim_indices, 'b 1 -> b 1 d', d = image_embed_dim)\n top_image_embeds = image_embeds.gather(1, top_sim_indices)\n return rearrange(top_image_embeds, 'b 1 d -> b d')\n def forward(\n self,\n text = None,\n image = None,\n text_embed = None, # allow for training on preprocessed CLIP text and image embeddings\n image_embed = None,\n text_encodings = None, # as well as CLIP text encodings\n *args,\n **kwargs\n ):\n assert exists(text) ^ exists(text_embed), 'either text or text embedding must be supplied'"
+ },
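+ {
+ "comment": "Standalone sketch, not from the repository, of the top-1 selection above: cosine similarity between l2-normalized text and candidate image embeddings, followed by picking the best candidate per prompt. Shapes and names are illustrative.",
+ "content": "import torch\nimport torch.nn.functional as F\n\nb, r, d = 3, 2, 16                   # prompts, samples per prompt, embedding dim\ntext_embeds = torch.randn(b, r, d)   # text embedding repeated r times per prompt\nimage_embeds = torch.randn(b, r, d)  # r candidate image embeddings per prompt\n\n# cosine similarity = dot product of l2-normalized vectors\nsims = (F.normalize(text_embeds, dim = -1) * F.normalize(image_embeds, dim = -1)).sum(dim = -1)  # (b, r)\ntop = sims.argmax(dim = -1)                # best candidate index per prompt\nbest = image_embeds[torch.arange(b), top]  # (b, d)"
+ },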
+ {
+ "comment": "The code snippet checks if an image or image embedding is supplied and throws an error if neither exists. It also verifies the presence of text encodings or text based on the specified conditioning during initialization. The code then calculates the text embeddings from the given text using the clip model. If conditioned on text encodings, it includes them in the text_cond dictionary. It samples random times for timestep conditioning from the noise scheduler and scales the image embed (by Katherine).",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":1481-1503",
+ "content": " assert exists(image) ^ exists(image_embed), 'either image or image embedding must be supplied'\n assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'\n if exists(image):\n image_embed, _ = self.clip.embed_image(image)\n # calculate text conditionings, based on what is passed in\n if exists(text):\n text_embed, text_encodings = self.clip.embed_text(text)\n text_cond = dict(text_embed = text_embed)\n if self.condition_on_text_encodings:\n assert exists(text_encodings), 'text encodings must be present for diffusion prior if specified'\n text_cond = {**text_cond, 'text_encodings': text_encodings}\n # timestep conditioning from ddpm\n batch, device = image_embed.shape[0], image_embed.device\n times = self.noise_scheduler.sample_random_times(batch)\n # scale image embed (Katherine)"
+ },
+ {
+ "comment": "This code contains two classes, `NearestUpsample` and `PixelShuffleUpsample`. `NearestUpsample` performs nearest neighbor upsampling followed by a convolution operation. `PixelShuffleUpsample` applies pixel shuffling after a 1x1 convolution to reduce checkerboard artifacts. Both classes can be used for image upsampling tasks.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":1505-1542",
+ "content": " image_embed *= self.image_embed_scale\n # calculate forward loss\n return self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)\n# decoder\ndef NearestUpsample(dim, dim_out = None):\n dim_out = default(dim_out, dim)\n return nn.Sequential(\n nn.Upsample(scale_factor = 2, mode = 'nearest'),\n nn.Conv2d(dim, dim_out, 3, padding = 1)\n )\nclass PixelShuffleUpsample(nn.Module):\n \"\"\"\n code shared by @MalumaDev at DALLE2-pytorch for addressing checkboard artifacts\n https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf\n \"\"\"\n def __init__(self, dim, dim_out = None):\n super().__init__()\n dim_out = default(dim_out, dim)\n conv = nn.Conv2d(dim, dim_out * 4, 1)\n self.net = nn.Sequential(\n conv,\n nn.SiLU(),\n nn.PixelShuffle(2)\n )\n self.init_conv_(conv)\n def init_conv_(self, conv):\n o, i, h, w = conv.weight.shape\n conv_weight = torch.empty(o // 4, i, h, w)\n nn.init.kaiming_uniform_(conv_weight)"
+ },
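+ {
+ "comment": "Minimal shape sketch, not from the repository, using only torch: a 1x1 convolution that expands channels fourfold followed by nn.PixelShuffle(2) doubles the spatial resolution, which is the structure of PixelShuffleUpsample above (without its special weight initialization).",
+ "content": "import torch\nfrom torch import nn\n\ndim, dim_out = 8, 8\nupsample = nn.Sequential(\n    nn.Conv2d(dim, dim_out * 4, 1),  # expand channels by 4\n    nn.SiLU(),\n    nn.PixelShuffle(2)               # rearrange (c * 4, h, w) -> (c, 2h, 2w)\n)\n\nx = torch.randn(1, dim, 16, 16)\nprint(upsample(x).shape)  # torch.Size([1, 8, 32, 32])"
+ },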
+ {
+ "comment": "The code defines a class called WeightStandardizedConv2d that extends nn.Conv2d and implements weight standardization for improving synergy with group normalization. It also includes a function named Downsample to downsample the input using pixel unshuffle technique, which is optimal according to a reference paper. The forward method in WeightStandardizedConv2d calculates mean and variance of flattened weights and performs weight standardization before applying convolution operations.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":1543-1573",
+ "content": " conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')\n conv.weight.data.copy_(conv_weight)\n nn.init.zeros_(conv.bias.data)\n def forward(self, x):\n return self.net(x)\ndef Downsample(dim, dim_out = None):\n # https://arxiv.org/abs/2208.03641 shows this is the most optimal way to downsample\n # named SP-conv in the paper, but basically a pixel unshuffle\n dim_out = default(dim_out, dim)\n return nn.Sequential(\n Rearrange('b c (h s1) (w s2) -> b (c s1 s2) h w', s1 = 2, s2 = 2),\n nn.Conv2d(dim * 4, dim_out, 1)\n )\nclass WeightStandardizedConv2d(nn.Conv2d):\n \"\"\"\n https://arxiv.org/abs/1903.10520\n weight standardization purportedly works synergistically with group normalization\n \"\"\"\n def forward(self, x):\n eps = 1e-5 if x.dtype == torch.float32 else 1e-3\n weight = self.weight\n flattened_weights = rearrange(weight, 'o ... -> o (...)')\n mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')\n var = torch.var(flattened_weights, dim = -1, unbiased = False)"
+ },
+ {
+ "comment": "This code snippet contains the definition of three classes: `rearrange`, `SinusoidalPosEmb`, and `Block`. The `rearrange` function is used to reshape tensors, `SinusoidalPosEmb` class computes sinusoidal positional embeddings, and `Block` class defines a convolutional block with an option for weight standardization.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":1574-1604",
+ "content": " var = rearrange(var, 'o -> o 1 1 1')\n weight = (weight - mean) * (var + eps).rsqrt()\n return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)\nclass SinusoidalPosEmb(nn.Module):\n def __init__(self, dim):\n super().__init__()\n self.dim = dim\n def forward(self, x):\n dtype, device = x.dtype, x.device\n assert is_float_dtype(dtype), 'input to sinusoidal pos emb must be a float type'\n half_dim = self.dim // 2\n emb = math.log(10000) / (half_dim - 1)\n emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)\n emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')\n return torch.cat((emb.sin(), emb.cos()), dim = -1).type(dtype)\nclass Block(nn.Module):\n def __init__(\n self,\n dim,\n dim_out,\n groups = 8,\n weight_standardization = False\n ):\n super().__init__()\n conv_klass = nn.Conv2d if not weight_standardization else WeightStandardizedConv2d"
+ },
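+ {
+ "comment": "Standalone sketch, not from the repository, of the sinusoidal timestep embedding computed by SinusoidalPosEmb: exponentially spaced frequencies multiplied by the timestep, with sine and cosine halves concatenated. The function name sinusoidal_pos_emb is hypothetical.",
+ "content": "import math\nimport torch\n\ndef sinusoidal_pos_emb(t, dim):\n    # t: (batch,) float tensor of timesteps, dim: embedding dimension (even)\n    half_dim = dim // 2\n    freqs = torch.exp(torch.arange(half_dim, dtype = t.dtype) * -(math.log(10000) / (half_dim - 1)))\n    angles = t[:, None] * freqs[None, :]\n    return torch.cat((angles.sin(), angles.cos()), dim = -1)\n\nemb = sinusoidal_pos_emb(torch.tensor([0.0, 10.0, 500.0]), 16)\nprint(emb.shape)  # torch.Size([3, 16])"
+ },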
+ {
+ "comment": "This code snippet defines a ResnetBlock class that takes in dimensions and other parameters for its initialization. It includes a project layer, normalization layer, activation function, and optional scale-shift operation. The forward method performs the computation steps involving these layers. Additionally, it checks if time_cond_dim is given to initialize a time MLP and if cond_dim exists to initialize a cross-attention layer.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":1606-1648",
+ "content": " self.project = conv_klass(dim, dim_out, 3, padding = 1)\n self.norm = nn.GroupNorm(groups, dim_out)\n self.act = nn.SiLU()\n def forward(self, x, scale_shift = None):\n x = self.project(x)\n x = self.norm(x)\n if exists(scale_shift):\n scale, shift = scale_shift\n x = x * (scale + 1) + shift\n x = self.act(x)\n return x\nclass ResnetBlock(nn.Module):\n def __init__(\n self,\n dim,\n dim_out,\n *,\n cond_dim = None,\n time_cond_dim = None,\n groups = 8,\n weight_standardization = False,\n cosine_sim_cross_attn = False\n ):\n super().__init__()\n self.time_mlp = None\n if exists(time_cond_dim):\n self.time_mlp = nn.Sequential(\n nn.SiLU(),\n nn.Linear(time_cond_dim, dim_out * 2)\n )\n self.cross_attn = None\n if exists(cond_dim):\n self.cross_attn = CrossAttention(\n dim = dim_out,\n context_dim = cond_dim,"
+ },
+ {
+ "comment": "This code defines a class for an encoder-decoder architecture with residual connections and cross-attention. It includes blocks, convolutions, time MLP, and optional cross-attention with conditional input. The forward method processes the input, applies blocks, optionally performs time embedding and cross-attention, and returns the output.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":1649-1675",
+ "content": " cosine_sim = cosine_sim_cross_attn\n )\n self.block1 = Block(dim, dim_out, groups = groups, weight_standardization = weight_standardization)\n self.block2 = Block(dim_out, dim_out, groups = groups, weight_standardization = weight_standardization)\n self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()\n def forward(self, x, time_emb = None, cond = None):\n scale_shift = None\n if exists(self.time_mlp) and exists(time_emb):\n time_emb = self.time_mlp(time_emb)\n time_emb = rearrange(time_emb, 'b c -> b c 1 1')\n scale_shift = time_emb.chunk(2, dim = 1)\n h = self.block1(x, scale_shift = scale_shift)\n if exists(self.cross_attn):\n assert exists(cond)\n h = rearrange(h, 'b c ... -> b ... c')\n h, ps = pack([h], 'b * c')\n h = self.cross_attn(h, context = cond) + h\n h, = unpack(h, ps, 'b * c')\n h = rearrange(h, 'b ... c -> b c ...')"
+ },
+ {
+ "comment": "The code defines a CrossAttention class with parameters for dimensionality, context dimension, number of heads, dropout rate, and normalization options. It initializes the necessary layers including linear transformations and layer norms. The cosine similarity scale and null keys are also defined.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":1677-1710",
+ "content": " h = self.block2(h)\n return h + self.res_conv(x)\nclass CrossAttention(nn.Module):\n def __init__(\n self,\n dim,\n *,\n context_dim = None,\n dim_head = 64,\n heads = 8,\n dropout = 0.,\n norm_context = False,\n cosine_sim = False,\n cosine_sim_scale = 16\n ):\n super().__init__()\n self.cosine_sim = cosine_sim\n self.scale = cosine_sim_scale if cosine_sim else (dim_head ** -0.5)\n self.heads = heads\n inner_dim = dim_head * heads\n context_dim = default(context_dim, dim)\n self.norm = LayerNorm(dim)\n self.norm_context = LayerNorm(context_dim) if norm_context else nn.Identity()\n self.dropout = nn.Dropout(dropout)\n self.null_kv = nn.Parameter(torch.randn(2, dim_head))\n self.to_q = nn.Linear(dim, inner_dim, bias = False)\n self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)\n self.to_out = nn.Sequential(\n nn.Linear(inner_dim, dim, bias = False),"
+ },
+ {
+ "comment": "This function defines a multi-head attention layer. It normalizes input x and context, splits them into queries (q), keys (k), and values (v). It also includes null key/value pairs for classifier free guidance in the prior net. If cosine_sim is set, it normalizes q and k again. It then computes the attention scores (sim) between q and k, and applies a mask if available, replacing negative values with max_neg_value.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":1711-1742",
+ "content": " LayerNorm(dim)\n )\n def forward(self, x, context, mask = None):\n b, n, device = *x.shape[:2], x.device\n x = self.norm(x)\n context = self.norm_context(context)\n q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))\n q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))\n # add null key / value for classifier free guidance in prior net\n nk, nv = map(lambda t: repeat(t, 'd -> b h 1 d', h = self.heads, b = b), self.null_kv.unbind(dim = -2))\n k = torch.cat((nk, k), dim = -2)\n v = torch.cat((nv, v), dim = -2)\n if self.cosine_sim:\n q, k = map(l2norm, (q, k))\n q, k = map(lambda t: t * math.sqrt(self.scale), (q, k))\n sim = einsum('b h i d, b h j d -> b h i j', q, k)\n max_neg_value = -torch.finfo(sim.dtype).max\n if exists(mask):\n mask = F.pad(mask, (1, 0), value = True)\n mask = rearrange(mask, 'b j -> b 1 1 j')\n sim = sim.masked_fill(~mask, max_neg_value)"
+ },
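+ {
+ "comment": "Compact sketch, not from the repository: scaled dot-product cross-attention with a learned null key/value prepended, so even fully dropped conditioning has a token to attend to, as in CrossAttention above (single head, no masking). All names are illustrative.",
+ "content": "import torch\n\ndef attend_with_null(q, k, v, null_k, null_v):\n    # q: (b, n, d), k and v: (b, m, d), null_k and null_v: (d,)\n    b = q.shape[0]\n    k = torch.cat((null_k.expand(b, 1, -1), k), dim = 1)\n    v = torch.cat((null_v.expand(b, 1, -1), v), dim = 1)\n    sim = torch.einsum('b i d, b j d -> b i j', q, k) * q.shape[-1] ** -0.5\n    attn = sim.softmax(dim = -1)\n    return torch.einsum('b i j, b j d -> b i d', attn, v)\n\nq = torch.randn(2, 4, 32)\nout = attend_with_null(q, torch.randn(2, 0, 32), torch.randn(2, 0, 32), torch.randn(32), torch.randn(32))\nprint(out.shape)  # works even with zero context tokens thanks to the null key/value"
+ },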
+ {
+ "comment": "This code defines a LinearAttention module that performs multi-head attention. It normalizes the input, applies convolutions to split input into queries (Q), keys (K), and values (V), then computes attention weights, rearranges output dimensions for efficiency, and finally passes the result through another set of convolutions before returning it.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":1744-1779",
+ "content": " attn = sim.softmax(dim = -1, dtype = torch.float32)\n attn = attn.type(sim.dtype)\n out = einsum('b h i j, b h j d -> b h i d', attn, v)\n out = rearrange(out, 'b h n d -> b n (h d)')\n return self.to_out(out)\nclass LinearAttention(nn.Module):\n def __init__(\n self,\n dim,\n dim_head = 32,\n heads = 8,\n **kwargs\n ):\n super().__init__()\n self.scale = dim_head ** -0.5\n self.heads = heads\n inner_dim = dim_head * heads\n self.norm = ChanLayerNorm(dim)\n self.nonlin = nn.GELU()\n self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)\n self.to_out = nn.Sequential(\n nn.Conv2d(inner_dim, dim, 1, bias = False),\n ChanLayerNorm(dim)\n )\n def forward(self, fmap):\n h, x, y = self.heads, *fmap.shape[-2:]\n seq_len = x * y\n fmap = self.norm(fmap)\n q, k, v = self.to_qkv(fmap).chunk(3, dim = 1)\n q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (q, k, v))"
+ },
+ {
+ "comment": "The code calculates and applies attention weights to query (q) and key (k) tensors, normalizes them, scales the vectors, and performs element-wise multiplication. It then applies a linear transformation (nonlin) on the result and rearranges the dimensions of the output tensor using the 'rearrange' function. The code also defines a CrossEmbedLayer class that initializes convolutional layers for feature extraction at multiple scales.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":1781-1815",
+ "content": " q = q.softmax(dim = -1)\n k = k.softmax(dim = -2)\n q = q * self.scale\n v = l2norm(v)\n k, v = map(lambda t: t / math.sqrt(seq_len), (k, v))\n context = einsum('b n d, b n e -> b d e', k, v)\n out = einsum('b n d, b d e -> b n e', q, context)\n out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)\n out = self.nonlin(out)\n return self.to_out(out)\nclass CrossEmbedLayer(nn.Module):\n def __init__(\n self,\n dim_in,\n kernel_sizes,\n dim_out = None,\n stride = 2\n ):\n super().__init__()\n assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)])\n dim_out = default(dim_out, dim_in)\n kernel_sizes = sorted(kernel_sizes)\n num_scales = len(kernel_sizes)\n # calculate the dimension at each scale\n dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]\n dim_scales = [*dim_scales, dim_out - sum(dim_scales)]\n self.convs = nn.ModuleList([])"
+ },
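+ {
+ "comment": "Compact sketch, not from the repository, of the linear attention contraction above for a single head over flattened spatial positions: softmax over features for q, over positions for k, then two einsum contractions instead of a full n-by-n attention matrix.",
+ "content": "import torch\nimport torch.nn.functional as F\n\nb, n, d = 2, 64, 32                          # batch, spatial positions (h * w), head dim\nq = torch.randn(b, n, d).softmax(dim = -1)   # softmax over the feature dimension\nk = torch.randn(b, n, d).softmax(dim = -2)   # softmax over the position dimension\nv = F.normalize(torch.randn(b, n, d), dim = -1)\n\ncontext = torch.einsum('b n d, b n e -> b d e', k, v)    # (b, d, d) summary of the values\nout = torch.einsum('b n d, b d e -> b n e', q, context)  # (b, n, d), cost linear in n"
+ },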
+ {
+ "comment": "The code defines a convolutional network with adjustable kernel sizes and applies an upsampling combiner to combine feature maps. The enabled flag controls whether the upsampling combiner is active, and it can be customized with different input/output dimensions.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":1816-1848",
+ "content": " for kernel, dim_scale in zip(kernel_sizes, dim_scales):\n self.convs.append(nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))\n def forward(self, x):\n fmaps = tuple(map(lambda conv: conv(x), self.convs))\n return torch.cat(fmaps, dim = 1)\nclass UpsampleCombiner(nn.Module):\n def __init__(\n self,\n dim,\n *,\n enabled = False,\n dim_ins = tuple(),\n dim_outs = tuple()\n ):\n super().__init__()\n assert len(dim_ins) == len(dim_outs)\n self.enabled = enabled\n if not self.enabled:\n self.dim_out = dim\n return\n self.fmap_convs = nn.ModuleList([Block(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)])\n self.dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0)\n def forward(self, x, fmaps = None):\n target_size = x.shape[-1]\n fmaps = default(fmaps, tuple())\n if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0:"
+ },
+ {
+ "comment": "This code defines a Unet model with multiple components including fmaps, convolutions, image and text embeddings, dimensions, conditional parameters, and attention mechanisms. It also includes options for lowres_cond, self_attn, lowres_noise_cond, sparse_attn, and cosine_sim_cross_attn.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":1849-1876",
+ "content": " return x\n fmaps = [resize_image_to(fmap, target_size) for fmap in fmaps]\n outs = [conv(fmap) for fmap, conv in zip(fmaps, self.fmap_convs)]\n return torch.cat((x, *outs), dim = 1)\nclass Unet(nn.Module):\n def __init__(\n self,\n dim,\n *,\n image_embed_dim = None,\n text_embed_dim = None,\n cond_dim = None,\n num_image_tokens = 4,\n num_time_tokens = 2,\n out_dim = None,\n dim_mults=(1, 2, 4, 8),\n channels = 3,\n channels_out = None,\n self_attn = False,\n attn_dim_head = 32,\n attn_heads = 16,\n lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/\n lowres_noise_cond = False, # for conditioning on low resolution noising, based on Imagen\n self_cond = False, # set this to True to use the self-conditioning technique from - https://arxiv.org/abs/2208.04202\n sparse_attn = False,\n cosine_sim_cross_attn = False,"
+ },
+ {
+ "comment": "The code defines various settings for the DALLE2 model, including whether to use cosine similarity self-attention, if a layer of attention should be at the bottleneck, and whether to condition on text or image embeddings. It also includes options for initializing embeddings, resnet blocks, cross embeddings, and more. These settings allow for customization and optimization in the DALLE2 model's architecture.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":1877-1894",
+ "content": " cosine_sim_self_attn = False,\n attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)\n cond_on_text_encodings = False,\n max_text_len = 256,\n cond_on_image_embeds = False,\n add_image_embeds_to_time = True, # alerted by @mhh0318 to a phrase in the paper - \"Specifically, we modify the architecture described in Nichol et al. (2021) by projecting and adding CLIP embeddings to the existing timestep embedding\"\n init_dim = None,\n init_conv_kernel_size = 7,\n resnet_groups = 8,\n resnet_weight_standardization = False,\n num_resnet_blocks = 2,\n init_cross_embed = True,\n init_cross_embed_kernel_sizes = (3, 7, 15),\n cross_embed_downsample = False,\n cross_embed_downsample_kernel_sizes = (2, 4),\n memory_efficient = False,\n scale_skip_connection = False,\n pixel_shuffle_upsample = True,"
+ },
+ {
+ "comment": "The code initializes a DDPM model with specified parameters such as number of channels, output channels, low resolution conditioning and self-conditioning. It determines the dimensions and initial number of channels based on these inputs and saves the hyperparameters for possible cascading DDPM in the future.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":1895-1926",
+ "content": " final_conv_kernel_size = 1,\n combine_upsample_fmaps = False, # whether to combine the outputs of all upsample blocks, as in unet squared paper\n checkpoint_during_training = False,\n **kwargs\n ):\n super().__init__()\n # save locals to take care of some hyperparameters for cascading DDPM\n self._locals = locals()\n del self._locals['self']\n del self._locals['__class__']\n # for eventual cascading diffusion\n self.lowres_cond = lowres_cond\n # whether to do self conditioning\n self.self_cond = self_cond\n # determine dimensions\n self.channels = channels\n self.channels_out = default(channels_out, channels)\n # initial number of channels depends on\n # (1) low resolution conditioning from cascading ddpm paper, conditioned on previous unet output in the cascade\n # (2) self conditioning (bit diffusion paper)\n init_channels = channels * (1 + int(lowres_cond) + int(self_cond))\n init_dim = default(init_dim, dim)"
+ },
+ {
+ "comment": "This code initializes layers for processing time and image inputs. It creates a CrossEmbedLayer or Conv2d layer for the initial input, sets the dimensions for subsequent stages, defines layers to transform time-based data into conditioning tokens, and initializes an image-to-tokens sequence of layers. These layers will be used in a DALL-E 2 model for processing text, image, and time-based inputs for generating images.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":1928-1955",
+ "content": " self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) if init_cross_embed else nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)\n dims = [init_dim, *map(lambda m: dim * m, dim_mults)]\n in_out = list(zip(dims[:-1], dims[1:]))\n num_stages = len(in_out)\n # time, image embeddings, and optional text encoding\n cond_dim = default(cond_dim, dim)\n time_cond_dim = dim * 4\n self.to_time_hiddens = nn.Sequential(\n SinusoidalPosEmb(dim),\n nn.Linear(dim, time_cond_dim),\n nn.GELU()\n )\n self.to_time_tokens = nn.Sequential(\n nn.Linear(time_cond_dim, cond_dim * num_time_tokens),\n Rearrange('b (r d) -> b r d', r = num_time_tokens)\n )\n self.to_time_cond = nn.Sequential(\n nn.Linear(time_cond_dim, time_cond_dim)\n )\n self.image_to_tokens = nn.Sequential("
+ },
+ {
+ "comment": "The code defines the architecture of a model. It includes linear layers, layer normalization, GELU activation function, and conditioning options for image embeddings, text encodings, and low resolution noise. These components are used to transform inputs and generate conditions based on optional parameters.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":1956-1980",
+ "content": " nn.Linear(image_embed_dim, cond_dim * num_image_tokens),\n Rearrange('b (n d) -> b n d', n = num_image_tokens)\n ) if cond_on_image_embeds and image_embed_dim != cond_dim else nn.Identity()\n self.to_image_hiddens = nn.Sequential(\n nn.Linear(image_embed_dim, time_cond_dim),\n nn.GELU()\n ) if cond_on_image_embeds and add_image_embeds_to_time else None\n self.norm_cond = nn.LayerNorm(cond_dim)\n self.norm_mid_cond = nn.LayerNorm(cond_dim)\n # text encoding conditioning (optional)\n self.text_to_cond = None\n self.text_embed_dim = None\n if cond_on_text_encodings:\n assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text_encodings is True'\n self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)\n self.text_embed_dim = text_embed_dim\n # low resolution noise conditiong, based on Imagen's upsampler training technique\n self.lowres_noise_cond = lowres_noise_cond"
+ },
+ {
+ "comment": "This code initializes various components of a model. It creates an optional sequential layer for low-res noise conditioning based on a flag, allows fine control over whether to condition on image embeddings and text encodings, and sets up parameters for classifier-free guidance. The skip connection scale is set either to 1 or scaled as per Imagen's approach.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":1982-2005",
+ "content": " self.to_lowres_noise_cond = nn.Sequential(\n SinusoidalPosEmb(dim),\n nn.Linear(dim, time_cond_dim),\n nn.GELU(),\n nn.Linear(time_cond_dim, time_cond_dim)\n ) if lowres_noise_cond else None\n # finer control over whether to condition on image embeddings and text encodings\n # so one can have the latter unets in the cascading DDPMs only focus on super-resoluting\n self.cond_on_text_encodings = cond_on_text_encodings\n self.cond_on_image_embeds = cond_on_image_embeds\n # for classifier free guidance\n self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))\n self.null_image_hiddens = nn.Parameter(torch.randn(1, time_cond_dim))\n self.max_text_len = max_text_len\n self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))\n # whether to scale skip connection, adopted in Imagen\n self.skip_connect_scale = 1. if not scale_skip_connection else (2 ** -0.5)"
+ },
+ {
+ "comment": "This code initializes various parameters and classes for the DALL-E 2 model. It sets up attention, resnet block, downsampling, and upsampling functions based on user inputs. The code uses partial function applications to customize the resnet blocks and other components according to specific settings.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2007-2034",
+ "content": " # attention related params\n attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head, cosine_sim = cosine_sim_self_attn)\n self_attn = cast_tuple(self_attn, num_stages)\n create_self_attn = lambda dim: RearrangeToSequence(Residual(Attention(dim, **attn_kwargs)))\n # resnet block klass\n resnet_groups = cast_tuple(resnet_groups, num_stages)\n top_level_resnet_group = first(resnet_groups)\n num_resnet_blocks = cast_tuple(num_resnet_blocks, num_stages)\n # downsample klass\n downsample_klass = Downsample\n if cross_embed_downsample:\n downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)\n # upsample klass\n upsample_klass = NearestUpsample if not pixel_shuffle_upsample else PixelShuffleUpsample\n # prepare resnet klass\n resnet_block = partial(ResnetBlock, cosine_sim_cross_attn = cosine_sim_cross_attn, weight_standardization = resnet_weight_standardization)"
+ },
+ {
+ "comment": "The code initializes the memory efficient UNet with an initial resnet block, and creates two lists for downsampling and upsampling layers. It also keeps track of skip connection dimensions and dimensions for final upsample feature map combiner. The code iterates over different layer configurations, including whether to use self-attention or not.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2036-2058",
+ "content": " # give memory efficient unet an initial resnet block\n self.init_resnet_block = resnet_block(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group) if memory_efficient else None\n # layers\n self.downs = nn.ModuleList([])\n self.ups = nn.ModuleList([])\n num_resolutions = len(in_out)\n skip_connect_dims = [] # keeping track of skip connection dimensions\n upsample_combiner_dims = [] # keeping track of dimensions for final upsample feature map combiner\n for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks, self_attn)):\n is_first = ind == 0\n is_last = ind >= (num_resolutions - 1)\n layer_cond_dim = cond_dim if not is_first else None\n dim_layer = dim_out if memory_efficient else dim_in\n skip_connect_dims.append(dim_layer)\n attention = nn.Identity()\n if layer_self_attn:"
+ },
+ {
+ "comment": "This code initializes a module for a neural network. It adds downsampling modules, resnet blocks, attention layers, and convolutional layers based on the given parameters. The last block of the code initializes two additional blocks and an attention layer for further processing.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2059-2075",
+ "content": " attention = create_self_attn(dim_layer)\n elif sparse_attn:\n attention = Residual(LinearAttention(dim_layer, **attn_kwargs))\n self.downs.append(nn.ModuleList([\n downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None,\n resnet_block(dim_layer, dim_layer, time_cond_dim = time_cond_dim, groups = groups),\n nn.ModuleList([resnet_block(dim_layer, dim_layer, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),\n attention,\n downsample_klass(dim_layer, dim_out = dim_out) if not is_last and not memory_efficient else nn.Conv2d(dim_layer, dim_out, 1)\n ]))\n mid_dim = dims[-1]\n self.mid_block1 = resnet_block(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])\n self.mid_attn = create_self_attn(mid_dim)\n self.mid_block2 = re"
+ },
+ {
+ "comment": "The code is defining a ResNet-based architecture with optional self-attention layers. It iterates through the input and output dimensions, groups, number of resnet blocks, and self-attention usage to create a series of resnet blocks, optionally including an identity or linear attention layer after each block.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2075-2093",
+ "content": "snet_block(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])\n for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks), reversed(self_attn))):\n is_last = ind >= (len(in_out) - 1)\n layer_cond_dim = cond_dim if not is_last else None\n skip_connect_dim = skip_connect_dims.pop()\n attention = nn.Identity()\n if layer_self_attn:\n attention = create_self_attn(dim_out)\n elif sparse_attn:\n attention = Residual(LinearAttention(dim_out, **attn_kwargs))\n upsample_combiner_dims.append(dim_out)\n self.ups.append(nn.ModuleList([\n resnet_block(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),\n nn.ModuleList([resnet_block(dim_out + skip_connect_dim,"
+ },
+ {
+ "comment": "This code defines a DALL\u00b7E 2 model architecture. It includes multiple resnet blocks, an upsampling sequence, and a final convolution layer. The number of resnet blocks is determined by the `layer_num_resnet_blocks` parameter. The upsample sequence combines outputs from all upsample blocks if `combine_upsample_fmaps` is set to True. The final resnet block takes in the combined output and the model's channels, with time conditioning (`time_cond_dim`) and top-level resnet grouping (`top_level_resnet_group`). Finally, a convolution layer converts the output to the desired channel size (`channels_out`). The `zero_init_` function initializes the final convolution layer with zero values.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2093-2115",
+ "content": " dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),\n attention,\n upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity()\n ]))\n # whether to combine outputs from all upsample blocks for final resnet block\n self.upsample_combiner = UpsampleCombiner(\n dim = dim,\n enabled = combine_upsample_fmaps,\n dim_ins = upsample_combiner_dims,\n dim_outs = (dim,) * len(upsample_combiner_dims)\n )\n # a final resnet block\n self.final_resnet_block = resnet_block(self.upsample_combiner.dim_out + dim, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)\n out_dim_in = dim + (channels if lowres_cond else 0)\n self.to_out = nn.Conv2d(out_dim_in, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)\n zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it"
+ },
+ {
+ "comment": "This code function checks if the current unet model parameters are correct for cascading DDPM. If not, it reinitializes the unet with the new settings. The parameters being checked include lowres_cond, channels, cond_on_image_embeds, and cond_on_text_encodings.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2117-2145",
+ "content": " # whether to checkpoint during training\n self.checkpoint_during_training = checkpoint_during_training\n # if the current settings for the unet are not correct\n # for cascading DDPM, then reinit the unet with the right settings\n def cast_model_parameters(\n self,\n *,\n lowres_cond,\n lowres_noise_cond,\n channels,\n channels_out,\n cond_on_image_embeds,\n cond_on_text_encodings,\n ):\n if lowres_cond == self.lowres_cond and \\\n channels == self.channels and \\\n cond_on_image_embeds == self.cond_on_image_embeds and \\\n cond_on_text_encodings == self.cond_on_text_encodings and \\\n lowres_noise_cond == self.lowres_noise_cond and \\\n channels_out == self.channels_out:\n return self\n updated_kwargs = dict(\n lowres_cond = lowres_cond,\n channels = channels,\n channels_out = channels_out,\n cond_on_image_embeds = cond_on_image_embeds,"
+ },
+ {
+ "comment": "This code defines a class with forward, forward_with_cond_scale methods that take various parameters and perform image processing operations. The forward method calculates logits based on input images, time, image embeddings, and other optional parameters. The forward_with_cond_scale method applies conditional scaling to the logits calculated by the forward method.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2146-2184",
+ "content": " cond_on_text_encodings = cond_on_text_encodings,\n lowres_noise_cond = lowres_noise_cond\n )\n return self.__class__(**{**self._locals, **updated_kwargs})\n def forward_with_cond_scale(\n self,\n *args,\n cond_scale = 1.,\n **kwargs\n ):\n logits = self.forward(*args, **kwargs)\n if cond_scale == 1:\n return logits\n null_logits = self.forward(*args, text_cond_drop_prob = 1., image_cond_drop_prob = 1., **kwargs)\n return null_logits + (logits - null_logits) * cond_scale\n def forward(\n self,\n x,\n time,\n *,\n image_embed,\n lowres_cond_img = None,\n lowres_noise_level = None,\n text_encodings = None,\n image_cond_drop_prob = 0.,\n text_cond_drop_prob = 0.,\n blur_sigma = None,\n blur_kernel_size = None,\n disable_checkpoint = False,\n self_cond = None\n ):\n batch_size, device = x.shape[0], x.device\n # add low resolution conditioning, if present"
+ },
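+ {
+ "comment": "Tiny numeric sketch, not from the repository, of the classifier-free guidance combination in forward_with_cond_scale: the unconditional prediction is extrapolated toward the conditional one by cond_scale. The helper name guided is hypothetical.",
+ "content": "import torch\n\ndef guided(logits, null_logits, cond_scale):\n    # cond_scale = 1 returns the conditional prediction unchanged\n    return null_logits + (logits - null_logits) * cond_scale\n\ncond = torch.tensor([1.0, 2.0])\nuncond = torch.tensor([0.5, 0.5])\nprint(guided(cond, uncond, 1.0))  # [1.0, 2.0]\nprint(guided(cond, uncond, 3.0))  # [2.0, 5.0]"
+ },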
+ {
+ "comment": "The code checks if low resolution conditioning image exists and appends it to the input. It then concatenates self-conditioning, initializes a convolution, clones the input for residual calculations, performs time conditioning, and applies low resolution noise conditioning (if enabled).",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2186-2215",
+ "content": " assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'\n # concat self conditioning, if needed\n if self.self_cond:\n self_cond = default(self_cond, lambda: torch.zeros_like(x))\n x = torch.cat((x, self_cond), dim = 1)\n # concat low resolution conditioning\n if exists(lowres_cond_img):\n x = torch.cat((x, lowres_cond_img), dim = 1)\n # initial convolution\n x = self.init_conv(x)\n r = x.clone() # final residual\n # time conditioning\n time = time.type_as(x)\n time_hiddens = self.to_time_hiddens(time)\n time_tokens = self.to_time_tokens(time_hiddens)\n t = self.to_time_cond(time_hiddens)\n # low res noise conditioning (similar to time above)\n if exists(lowres_noise_level):\n assert exists(self.to_lowres_noise_cond), 'lowres_noise_cond must be set to True on instantiation of the unet in order to conditiong on lowres noise'"
+ },
+ {
+ "comment": "This code performs conditional dropout by maintaining image and text masks, checks if an image embedding exists, applies a conditional dropout to the image embedding based on the masks, and adds it to the time embedding.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2216-2242",
+ "content": " lowres_noise_level = lowres_noise_level.type_as(x)\n t = t + self.to_lowres_noise_cond(lowres_noise_level)\n # conditional dropout\n image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)\n text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device)\n text_keep_mask = rearrange(text_keep_mask, 'b -> b 1 1')\n # image embedding to be summed to time embedding\n # discovered by @mhh0318 in the paper\n if exists(image_embed) and exists(self.to_image_hiddens):\n image_hiddens = self.to_image_hiddens(image_embed)\n image_keep_mask_hidden = rearrange(image_keep_mask, 'b -> b 1')\n null_image_hiddens = self.null_image_hiddens.to(image_hiddens.dtype)\n image_hiddens = torch.where(\n image_keep_mask_hidden,\n image_hiddens,\n null_image_hiddens\n )\n t = t + image_hiddens\n # mask out image embedding depending on condition dropout"
+ },
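+ {
+ "comment": "Minimal sketch, not from the repository, of the conditional dropout pattern above: draw a Bernoulli keep mask per sample and swap dropped rows for a null embedding with torch.where. In the real model the null embedding is a learned nn.Parameter; zeros are used here only for illustration.",
+ "content": "import torch\n\nbatch, dim = 4, 8\ncond_drop_prob = 0.5\n\nimage_hiddens = torch.randn(batch, dim)\nnull_hiddens = torch.zeros(dim)  # stands in for a learned nn.Parameter\n\nkeep_mask = torch.rand(batch) < (1 - cond_drop_prob)  # True where conditioning is kept\nimage_hiddens = torch.where(keep_mask[:, None], image_hiddens, null_hiddens)"
+ },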
+ {
+ "comment": "This code chunk is setting up the input for a classifier-free guidance model. It checks if the image and text encodings are provided, and if so, prepares them for the model's input. If both the image embeddings and text encodings are present, it applies conditional guidance by masking the image tokens with the image_keep_mask and nullifying where needed. It asserts that the text encodings match the batch size and the expected embedding dimension of the model.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2243-2264",
+ "content": " # for classifier free guidance\n image_tokens = None\n if self.cond_on_image_embeds:\n image_keep_mask_embed = rearrange(image_keep_mask, 'b -> b 1 1')\n image_tokens = self.image_to_tokens(image_embed)\n null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working\n image_tokens = torch.where(\n image_keep_mask_embed,\n image_tokens,\n null_image_embed\n )\n # take care of text encodings (optional)\n text_tokens = None\n if exists(text_encodings) and self.cond_on_text_encodings:\n assert text_encodings.shape[0] == batch_size, f'the text encodings being passed into the unet does not have the proper batch size - text encoding shape {text_encodings.shape} - required batch size is {batch_size}'\n assert self.text_embed_dim == text_encodings.shape[-1], f'the text encodings you are passing in have a dimension of {text_encodings.shape[-1]}, but the unet was created with text_embed_dim of {self.text_embed_dim}.'"
+ },
+ {
+ "comment": "This code snippet is preparing text_tokens for the model by applying padding and ensuring correct shape. It creates a binary mask (text_mask) from the non-zero elements in text_encodings to indicate which tokens are present, then applies this mask to both text_tokens and text_keep_mask. The code also checks if there's remaining space in the max_text_len and pads text_tokens accordingly. Lastly, it asserts that the shapes of text_mask and text_keep_mask match before combining them using a logical AND operation.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2266-2287",
+ "content": " text_mask = torch.any(text_encodings != 0., dim = -1)\n text_tokens = self.text_to_cond(text_encodings)\n text_tokens = text_tokens[:, :self.max_text_len]\n text_mask = text_mask[:, :self.max_text_len]\n text_tokens_len = text_tokens.shape[1]\n remainder = self.max_text_len - text_tokens_len\n if remainder > 0:\n text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))\n text_mask = F.pad(text_mask, (0, remainder), value = False)\n text_mask = rearrange(text_mask, 'b n -> b n 1')\n assert text_mask.shape[0] == text_keep_mask.shape[0], f'text_mask has shape of {text_mask.shape} while text_keep_mask has shape {text_keep_mask.shape}. text encoding is of shape {text_encodings.shape}'\n text_keep_mask = text_mask & text_keep_mask\n null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working\n text_tokens = torch.where("
+ },
+ {
+ "comment": "This code snippet is part of the DALLE2-pytorch model, responsible for handling conditioning tokens (main and auxiliary) for image and text inputs. The code normalizes these tokens using `self.norm_cond` and `self.norm_mid_cond`, applies gradient checkpointing, and makes certain modules (e.g., `self.init_resnet_block`) checkpointable based on training parameters. This helps to optimize the model's computation during inference and improve its performance.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2288-2317",
+ "content": " text_keep_mask,\n text_tokens,\n null_text_embed\n )\n # main conditioning tokens (c)\n c = time_tokens\n if exists(image_tokens):\n c = torch.cat((c, image_tokens), dim = -2)\n # text and image conditioning tokens (mid_c)\n # to save on compute, only do cross attention based conditioning on the inner most layers of the Unet\n mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2)\n # normalize conditioning tokens\n c = self.norm_cond(c)\n mid_c = self.norm_mid_cond(mid_c)\n # gradient checkpointing\n can_checkpoint = self.training and self.checkpoint_during_training and not disable_checkpoint\n apply_checkpoint_fn = make_checkpointable if can_checkpoint else identity\n # make checkpointable modules\n init_resnet_block, mid_block1, mid_attn, mid_block2, final_resnet_block = [maybe(apply_checkpoint_fn)(module) for module in (self.init_resnet_block, self.mid_block1, self.mid_attn, self.mid_block2, self.final_resnet_block)]"
+ },
+ {
+ "comment": "This code initializes a U-Net model by iterating over its components. It applies pre-downsample, initial block, and resnet blocks to the input x. Then, it adds hidden representations of down and up stages into separate lists. After that, it passes x through an attention module and potentially post-downsample. Finally, it processes x with two more blocks, possibly applies mid-attention, and returns the final result.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2319-2355",
+ "content": " can_checkpoint_cond = lambda m: isinstance(m, ResnetBlock)\n downs, ups = [maybe(apply_checkpoint_fn)(m, condition = can_checkpoint_cond) for m in (self.downs, self.ups)]\n # initial resnet block\n if exists(init_resnet_block):\n x = init_resnet_block(x, t)\n # go through the layers of the unet, down and up\n down_hiddens = []\n up_hiddens = []\n for pre_downsample, init_block, resnet_blocks, attn, post_downsample in downs:\n if exists(pre_downsample):\n x = pre_downsample(x)\n x = init_block(x, t, c)\n for resnet_block in resnet_blocks:\n x = resnet_block(x, t, c)\n down_hiddens.append(x.contiguous())\n x = attn(x)\n down_hiddens.append(x.contiguous())\n if exists(post_downsample):\n x = post_downsample(x)\n x = mid_block1(x, t, mid_c)\n if exists(mid_attn):\n x = mid_attn(x)\n x = mid_block2(x, t, mid_c)\n "
+ },
+ {
+ "comment": "This code defines a class for processing input images, which consists of an upscaling network and a low-resolution conditioner. The upscaling network takes in a low-resolution image and upscales it using skip connections and residual blocks. The low-resolution conditioner can optionally take a low-resolution version of the input image as additional input. The final output is passed through an activation function before being returned.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2355-2392",
+ "content": " connect_skip = lambda fmap: torch.cat((fmap, down_hiddens.pop() * self.skip_connect_scale), dim = 1)\n for init_block, resnet_blocks, attn, upsample in ups:\n x = connect_skip(x)\n x = init_block(x, t, c)\n for resnet_block in resnet_blocks:\n x = connect_skip(x)\n x = resnet_block(x, t, c)\n x = attn(x)\n up_hiddens.append(x.contiguous())\n x = upsample(x)\n x = self.upsample_combiner(x, up_hiddens)\n x = torch.cat((x, r), dim = 1)\n x = final_resnet_block(x, t)\n if exists(lowres_cond_img):\n x = torch.cat((x, lowres_cond_img), dim = 1)\n return self.to_out(x)\nclass LowresConditioner(nn.Module):\n def __init__(\n self,\n downsample_first = True,\n use_blur = True,\n blur_prob = 0.5,\n blur_sigma = 0.6,\n blur_kernel_size = 3,\n use_noise = False,\n input_image_range = None,\n normalize_img_fn = identity,\n unnormalize_img_fn = identity"
+ },
+ {
+ "comment": "This code initializes an object with various parameters, including downsampling, image range, and noise-related options. It also includes methods for generating noise images based on the given parameters. The class utilizes normalization and denormalization functions as well as a NoiseScheduler instance to apply noise to the input condition maps.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2393-2417",
+ "content": " ):\n super().__init__()\n self.downsample_first = downsample_first\n self.input_image_range = input_image_range\n self.use_blur = use_blur\n self.blur_prob = blur_prob\n self.blur_sigma = blur_sigma\n self.blur_kernel_size = blur_kernel_size\n self.use_noise = use_noise\n self.normalize_img = normalize_img_fn\n self.unnormalize_img = unnormalize_img_fn\n self.noise_scheduler = NoiseScheduler(beta_schedule = 'linear', timesteps = 1000, loss_type = 'l2') if use_noise else None\n def noise_image(self, cond_fmap, noise_levels = None):\n assert exists(self.noise_scheduler)\n batch = cond_fmap.shape[0]\n cond_fmap = self.normalize_img(cond_fmap)\n random_noise_levels = default(noise_levels, lambda: self.noise_scheduler.sample_random_times(batch))\n cond_fmap = self.noise_scheduler.q_sample(cond_fmap, t = random_noise_levels, noise = torch.randn_like(cond_fmap))\n cond_fmap = self.unnormalize_img(cond_fmap)"
+ },
+ {
+ "comment": "This function takes a conditional feature map and optional parameters to resize, blur, and downsample the image. The code checks if downsampling is needed first, then decides whether to apply blurring based on a probability setting. Blur sigma and kernel size are also set based on default values or user input.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2418-2446",
+ "content": " return cond_fmap, random_noise_levels\n def forward(\n self,\n cond_fmap,\n *,\n target_image_size,\n downsample_image_size = None,\n should_blur = True,\n blur_sigma = None,\n blur_kernel_size = None\n ):\n if self.downsample_first and exists(downsample_image_size):\n cond_fmap = resize_image_to(cond_fmap, downsample_image_size, clamp_range = self.input_image_range, nearest = True)\n # blur is only applied 50% of the time\n # section 3.1 in https://arxiv.org/abs/2106.15282\n if self.use_blur and should_blur and random.random() < self.blur_prob:\n # when training, blur the low resolution conditional image\n blur_sigma = default(blur_sigma, self.blur_sigma)\n blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)\n # allow for drawing a random sigma between lo and hi float values\n if isinstance(blur_sigma, tuple):\n blur_sigma = tuple(map(float, blur_sigma))"
+ },
+ {
+ "comment": "This code performs image conditioning by applying Gaussian blur and noise addition, then resizes the image to a target size. The blurring and noise addition are optional depending on the use_noise flag, and the final result is returned along with any applied random noise levels.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2447-2470",
+ "content": " blur_sigma = random.uniform(*blur_sigma)\n # allow for drawing a random kernel size between lo and hi int values\n if isinstance(blur_kernel_size, tuple):\n blur_kernel_size = tuple(map(int, blur_kernel_size))\n kernel_size_lo, kernel_size_hi = blur_kernel_size\n blur_kernel_size = random.randrange(kernel_size_lo, kernel_size_hi + 1)\n cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))\n # resize to target image size\n cond_fmap = resize_image_to(cond_fmap, target_image_size, clamp_range = self.input_image_range, nearest = True)\n # noise conditioning, as done in Imagen\n # as a replacement for the BSR noising, and potentially replace blurring for first stage too\n random_noise_levels = None\n if self.use_noise:\n cond_fmap, random_noise_levels = self.noise_image(cond_fmap)\n # return conditioning feature map, as well as the augmentation noise levels"
+ },
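+ {
+ "comment": "A minimal, hypothetical sketch of the blur conditioning logic described above, assuming kornia is installed; the helper name blur_lowres_cond and its default ranges are illustrative only and not part of the source file.",
+ "location": "illustrative sketch (not present in the source file)",
+ "content": "import random\nimport torch\nfrom kornia.filters import gaussian_blur2d\n\ndef blur_lowres_cond(cond_fmap, blur_prob = 0.5, blur_sigma = (0.4, 0.6), blur_kernel_size = (3, 7)):\n    # apply gaussian blur only part of the time, as in cascading ddpm training\n    if random.random() >= blur_prob:\n        return cond_fmap\n    sigma = random.uniform(*blur_sigma)                 # draw sigma from a (lo, hi) float range\n    kernel = random.randrange(blur_kernel_size[0], blur_kernel_size[1] + 1)\n    kernel = kernel if kernel % 2 == 1 else kernel + 1  # gaussian kernels must be odd sized\n    return gaussian_blur2d(cond_fmap, (kernel, kernel), (sigma, sigma))\n\n# usage: blurred = blur_lowres_cond(torch.randn(2, 3, 64, 64))"
+ },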
+ {
+ "comment": "The code defines a Decoder class that takes various parameters like unet, clip, image_size, channels, vae, timesteps, sample_timesteps, image_cond_drop_prob, text_cond_drop_prob, loss_type, beta_schedule, predict_x_start, predict_v, predict_x_start_for_latent_diffusion, image_sizes, random_crop_sizes, use_noise_for_lowres_cond, and use_blur_for_lowres_cond. It returns cond_fmap and random_noise_levels.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2472-2495",
+ "content": " return cond_fmap, random_noise_levels\nclass Decoder(nn.Module):\n def __init__(\n self,\n unet,\n *,\n clip = None,\n image_size = None,\n channels = 3,\n vae = tuple(),\n timesteps = 1000,\n sample_timesteps = None,\n image_cond_drop_prob = 0.1,\n text_cond_drop_prob = 0.5,\n loss_type = 'l2',\n beta_schedule = None,\n predict_x_start = False,\n predict_v = False,\n predict_x_start_for_latent_diffusion = False,\n image_sizes = None, # for cascading ddpm, image size at each stage\n random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)\n use_noise_for_lowres_cond = False, # whether to use Imagen-like noising for low resolution conditioning \n use_blur_for_lowres_cond = True, # whether to use the blur conditioning used in the original cascading ddpm paper, as well as DALL-E2"
+ },
+ {
+ "comment": "This code snippet is responsible for configuring the settings for a denoising diffusion probabilistic model (DDPM) in the DALLE2-pytorch project. The settings include cascading DDPM parameters, noise level at sample time, clip options, learned variance configuration, and unconditional image generation toggles.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2496-2508",
+ "content": " lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur\n blur_prob = 0.5, # cascading ddpm - when training, the gaussian blur is only applied 50% of the time\n blur_sigma = 0.6, # cascading ddpm - blur sigma\n blur_kernel_size = 3, # cascading ddpm - blur kernel size\n lowres_noise_sample_level = 0.2, # in imagen paper, they use a 0.2 noise level at sample time for low resolution conditioning\n clip_denoised = True,\n clip_x_start = True,\n clip_adapter_overrides = dict(),\n learned_variance = True,\n learned_variance_constrain_frac = False,\n vb_loss_weight = 0.001,\n unconditional = False, # set to True for generating images without conditioning\n auto_normalize_img = True, # whether to take care of normalizing the i"
+ },
+ {
+ "comment": "The code initializes an object with various parameters such as use_dynamic_thres, dynamic_thres_percentile, p2_loss_weight_gamma, p2_loss_weight_k, ddim_sampling_eta, and clip. It also checks if the 'clip' parameter is given and performs necessary assertions. If 'clip' exists and unconditional image training is not being done, it ensures the channels match with CLIP's accepted channels. It also uses XClipAdapter for compatibility with additional overrides.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2508-2525",
+ "content": "mage from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader\n use_dynamic_thres = False, # from the Imagen paper\n dynamic_thres_percentile = 0.95,\n p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended\n p2_loss_weight_k = 1,\n ddim_sampling_eta = 0. # can be set to 0. for deterministic sampling afaict\n ):\n super().__init__()\n # clip\n self.clip = None\n if exists(clip):\n assert not unconditional, 'clip must not be given if doing unconditional image training'\n assert channels == clip.image_channels, f'channels of image ({channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'\n if isinstance(clip, CLIP):\n clip = XClipAdapter(clip, **clip_adapter_overrides)"
+ },
+ {
+ "comment": "The code checks the input 'clip' type and applies the CoCaAdapter if it's an instance of CoCa. It then freezes the model for evaluation, ensures 'clip' is a BaseClipAdapter instance, and assigns it to self.clip. The image_size is determined from either 'image_size', 'image_sizes', or 'clip'. It sets the 'channels', 'normalize_img', and 'unnormalize_img' based on given parameters.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2526-2554",
+ "content": " elif isinstance(clip, CoCa):\n clip = CoCaAdapter(clip, **clip_adapter_overrides)\n freeze_model_and_make_eval_(clip)\n assert isinstance(clip, BaseClipAdapter)\n self.clip = clip\n # determine image size, with image_size and image_sizes taking precedence\n if exists(image_size) or exists(image_sizes):\n assert exists(image_size) ^ exists(image_sizes), 'only one of image_size or image_sizes must be given'\n image_size = default(image_size, lambda: image_sizes[-1])\n elif exists(clip):\n image_size = clip.image_size\n else:\n raise Error('either image_size, image_sizes, or clip must be given to decoder')\n # channels\n self.channels = channels\n # normalize and unnormalize image functions\n self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity\n self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity\n # verify conditioning method"
+ },
+ {
+ "comment": "This code initializes the U-Nets and VAEs for a DALL-E 2 model. It sets the number of unets, whether they are unconditional or conditioned on previous unets, and their learned variance. It also sets default parameters for conditioning with noise and constrains the output of the network from 0 to 1.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2556-2576",
+ "content": " unets = cast_tuple(unet)\n num_unets = len(unets)\n self.num_unets = num_unets\n self.unconditional = unconditional\n # automatically take care of ensuring that first unet is unconditional\n # while the rest of the unets are conditioned on the low resolution image produced by previous unet\n vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))\n # whether to use learned variance, defaults to True for the first unet in the cascade, as in paper\n learned_variance = pad_tuple_to_length(cast_tuple(learned_variance), len(unets), fillvalue = False)\n self.learned_variance = learned_variance\n self.learned_variance_constrain_frac = learned_variance_constrain_frac # whether to constrain the output of the network (the interpolation fraction) from 0 to 1\n self.vb_loss_weight = vb_loss_weight\n # default and validate conditioning parameters\n use_noise_for_lowres_cond = cast_tuple(use_noise_for_lowres_cond, num_unets - 1, validate = False)"
+ },
+ {
+ "comment": "This code is setting up Unets and Vaes for a model. It ensures that the lists of noise conditions and blur conditions are long enough to correspond to each Unet, adds the Unets and Vaes to module lists, and asserts that at least one Unet will not need low res noise or blur conditioning.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2577-2596",
+ "content": " use_blur_for_lowres_cond = cast_tuple(use_blur_for_lowres_cond, num_unets - 1, validate = False)\n if len(use_noise_for_lowres_cond) < num_unets:\n use_noise_for_lowres_cond = (False, *use_noise_for_lowres_cond)\n if len(use_blur_for_lowres_cond) < num_unets:\n use_blur_for_lowres_cond = (False, *use_blur_for_lowres_cond)\n assert not use_noise_for_lowres_cond[0], 'first unet will never need low res noise conditioning'\n assert not use_blur_for_lowres_cond[0], 'first unet will never need low res blur conditioning'\n assert num_unets == 1 or all((use_noise or use_blur) for use_noise, use_blur in zip(use_noise_for_lowres_cond[1:], use_blur_for_lowres_cond[1:]))\n # construct unets and vaes\n self.unets = nn.ModuleList([])\n self.vaes = nn.ModuleList([])\n for ind, (one_unet, one_vae, one_unet_learned_var, lowres_noise_cond) in enumerate(zip(unets, vaes, learned_variance, use_noise_for_lowres_cond)):\n assert isinstance(one_unet, Unet)"
+ },
+ {
+ "comment": "This code block appends a new VAE instance to the list of VAEs and a copied evaluation version of that VAE to the VAEs list. The code also sets the sampling timesteps and ddim_sampling_eta based on the input parameters.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2597-2620",
+ "content": " assert isinstance(one_vae, (VQGanVAE, NullVQGanVAE))\n is_first = ind == 0\n latent_dim = one_vae.encoded_dim if exists(one_vae) else None\n unet_channels = default(latent_dim, self.channels)\n unet_channels_out = unet_channels * (1 if not one_unet_learned_var else 2)\n one_unet = one_unet.cast_model_parameters(\n lowres_cond = not is_first,\n lowres_noise_cond = lowres_noise_cond,\n cond_on_image_embeds = not unconditional and is_first,\n cond_on_text_encodings = not unconditional and one_unet.cond_on_text_encodings,\n channels = unet_channels,\n channels_out = unet_channels_out\n )\n self.unets.append(one_unet)\n self.vaes.append(one_vae.copy_for_eval())\n # sampling timesteps, defaults to non-ddim with full timesteps sampling\n self.sample_timesteps = cast_tuple(sample_timesteps, num_unets)\n self.ddim_sampling_eta = ddim_sampling_eta"
+ },
+ {
+ "comment": "This code creates noise schedulers for each unet, based on the provided beta schedule and loss weight gamma. It asserts that sampling timesteps must be less than or equal to the number of training timesteps, and initializes a NoiseScheduler object with the specified parameters for each unet.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2622-2640",
+ "content": " # create noise schedulers per unet\n if not exists(beta_schedule):\n beta_schedule = ('cosine', *(('cosine',) * max(num_unets - 2, 0)), *(('linear',) * int(num_unets > 1)))\n beta_schedule = cast_tuple(beta_schedule, num_unets)\n p2_loss_weight_gamma = cast_tuple(p2_loss_weight_gamma, num_unets)\n self.noise_schedulers = nn.ModuleList([])\n for ind, (unet_beta_schedule, unet_p2_loss_weight_gamma, sample_timesteps) in enumerate(zip(beta_schedule, p2_loss_weight_gamma, self.sample_timesteps)):\n assert not exists(sample_timesteps) or sample_timesteps <= timesteps, f'sampling timesteps {sample_timesteps} must be less than or equal to the number of training timesteps {timesteps} for unet {ind + 1}'\n noise_scheduler = NoiseScheduler(\n beta_schedule = unet_beta_schedule,\n timesteps = timesteps,\n loss_type = loss_type,\n p2_loss_weight_gamma = unet_p2_loss_weight_gamma,\n p2_loss_weight_k = p2_loss_weight_k"
+ },
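+ {
+ "comment": "A worked example of the default schedule construction above: with three u-nets and no beta_schedule supplied, the base and middle u-nets default to a cosine schedule and the final super-resolution u-net to a linear one. The snippet below only reproduces the tuple-building expression and is illustrative.",
+ "location": "illustrative sketch (not present in the source file)",
+ "content": "# reproduce the default beta_schedule tuple for a 3-unet cascade\nnum_unets = 3\nbeta_schedule = ('cosine', *(('cosine',) * max(num_unets - 2, 0)), *(('linear',) * int(num_unets > 1)))\nprint(beta_schedule)  # ('cosine', 'cosine', 'linear')\n\nnum_unets = 1\nprint(('cosine', *(('cosine',) * max(num_unets - 2, 0)), *(('linear',) * int(num_unets > 1))))  # ('cosine',)"
+ },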
+ {
+ "comment": "This code is setting up the parameters for a model. It creates noise schedulers, defines image sizes and crop sizes for different resolutions, and configures predicting x0 and v values. These settings will be used to train or use the model. The code also performs assertions to ensure that the correct number of unets and vaes are provided for each resolution.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2641-2665",
+ "content": " )\n self.noise_schedulers.append(noise_scheduler)\n # unet image sizes\n image_sizes = default(image_sizes, (image_size,))\n image_sizes = tuple(sorted(set(image_sizes)))\n assert self.num_unets == len(image_sizes), f'you did not supply the correct number of u-nets ({self.num_unets}) for resolutions {image_sizes}'\n self.image_sizes = image_sizes\n self.sample_channels = cast_tuple(self.channels, len(image_sizes))\n # random crop sizes (for super-resoluting unets at the end of cascade?)\n self.random_crop_sizes = cast_tuple(random_crop_sizes, len(image_sizes))\n assert not exists(self.random_crop_sizes[0]), 'you would not need to randomly crop the image for the base unet'\n # predict x0 config\n self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))\n # predict v\n self.predict_v = cast_tuple(predict_v, len(unets))"
+ },
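+ {
+ "comment": "A hedged usage sketch, in the spirit of the repository README, showing a two-stage cascading decoder so that the number of u-nets matches image_sizes; the particular dims, sizes, and dropout probabilities are illustrative placeholders rather than recommended settings.",
+ "location": "illustrative sketch (not present in the source file)",
+ "content": "import torch\nfrom dalle2_pytorch import Unet, Decoder\n\nunet1 = Unet(dim = 128, image_embed_dim = 512, cond_dim = 128, channels = 3, dim_mults = (1, 2, 4, 8))\nunet2 = Unet(dim = 16, image_embed_dim = 512, cond_dim = 128, channels = 3, dim_mults = (1, 2, 4, 8, 16))\n\ndecoder = Decoder(\n    unet = (unet1, unet2),        # one unet per stage of the cascade\n    image_sizes = (64, 256),      # must match the number of unets\n    timesteps = 1000,\n    image_cond_drop_prob = 0.1,   # conditioning dropout enables classifier free guidance\n    text_cond_drop_prob = 0.5\n)\n\nimages = torch.randn(4, 3, 256, 256)\nimage_embed = torch.randn(4, 512)\nloss = decoder(images, image_embed = image_embed, unet_number = 1)  # train one unet at a time\nloss.backward()"
+ },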
+ {
+ "comment": "The code initializes the input image range and handles lowres_cond for each unet in the model. It ensures that the first unet is unconditioned, while the rest have `lowres_cond` set to True. The `LowresConditioner` class is used with specified parameters for downsampling, blurring, and input image range.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2667-2690",
+ "content": " # input image range\n self.input_image_range = (-1. if not auto_normalize_img else 0., 1.)\n # cascading ddpm related stuff\n lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))\n assert lowres_conditions == (False, *((True,) * (num_unets - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'\n self.lowres_conds = nn.ModuleList([])\n for unet_index, use_noise, use_blur in zip(range(num_unets), use_noise_for_lowres_cond, use_blur_for_lowres_cond):\n if unet_index == 0:\n self.lowres_conds.append(None)\n continue\n lowres_cond = LowresConditioner(\n downsample_first = lowres_downsample_first,\n use_blur = use_blur,\n use_noise = use_noise,\n blur_prob = blur_prob,\n blur_sigma = blur_sigma,\n blur_kernel_size = blur_kernel_size,\n input_image_range = self.input_image_range,"
+ },
+ {
+ "comment": "This code is setting up parameters and functions for an image generation model. It includes normalization and unnormalization functions, lowres noise sample level, classifier free guidance settings, clipping options during sampling, dynamic thresholding settings, and device management. The model can condition on text encodings and uses a device tracker to keep track of device information.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2691-2724",
+ "content": " normalize_img_fn = self.normalize_img,\n unnormalize_img_fn = self.unnormalize_img\n )\n self.lowres_conds.append(lowres_cond)\n self.lowres_noise_sample_level = lowres_noise_sample_level\n # classifier free guidance\n self.image_cond_drop_prob = image_cond_drop_prob\n self.text_cond_drop_prob = text_cond_drop_prob\n self.can_classifier_guidance = image_cond_drop_prob > 0. or text_cond_drop_prob > 0.\n # whether to clip when sampling\n self.clip_denoised = clip_denoised\n self.clip_x_start = clip_x_start\n # dynamic thresholding settings, if clipping denoised during sampling\n self.use_dynamic_thres = use_dynamic_thres\n self.dynamic_thres_percentile = dynamic_thres_percentile\n # device tracker\n self.register_buffer('_dummy', torch.Tensor([True]), persistent = False)\n @property\n def device(self):\n return self._dummy.device\n @property\n def condition_on_text_encodings(self):"
+ },
+ {
+ "comment": "This code defines methods for working with a collection of UNET models. The `get_unet` method retrieves a specific UNET based on its number, ensuring it is within the valid range. `parse_unet_output` parses the output of a UNET, interpreting learned variance if present. The `one_unet_in_gpu` context manager allows running inference for one UNET on the GPU while keeping other UNETs on the CPU.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2725-2761",
+ "content": " return any([unet.cond_on_text_encodings for unet in self.unets if isinstance(unet, Unet)])\n def get_unet(self, unet_number):\n assert 0 < unet_number <= self.num_unets\n index = unet_number - 1\n return self.unets[index]\n def parse_unet_output(self, learned_variance, output):\n var_interp_frac_unnormalized = None\n if learned_variance:\n output, var_interp_frac_unnormalized = output.chunk(2, dim = 1)\n return UnetOutput(output, var_interp_frac_unnormalized)\n @contextmanager\n def one_unet_in_gpu(self, unet_number = None, unet = None):\n assert exists(unet_number) ^ exists(unet)\n if exists(unet_number):\n unet = self.get_unet(unet_number)\n # devices\n cuda, cpu = torch.device('cuda'), torch.device('cpu')\n self.cuda()\n devices = [module_device(unet) for unet in self.unets]\n self.unets.to(cpu)\n unet.to(cuda)\n yield\n for unet, device in zip(self.unets, devices):\n unet.to(device)"
+ },
+ {
+ "comment": "This code snippet defines a function `dynamic_threshold` and `p_mean_variance`. The `dynamic_threshold` function adjusts the threshold for clamping based on the input's quantile values. It uses static thresholding (s=1) by default, but can be set to dynamic thresholding if `self.use_dynamic_thres` is true. The `p_mean_variance` function performs classifier-free guidance for image generation and includes options for mean/variance prediction, conditioning, noise scheduling, and more.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2763-2784",
+ "content": " def dynamic_threshold(self, x):\n \"\"\" proposed in https://arxiv.org/abs/2205.11487 as an improved clamping in the setting of classifier free guidance \"\"\"\n # s is the threshold amount\n # static thresholding would just be s = 1\n s = 1.\n if self.use_dynamic_thres:\n s = torch.quantile(\n rearrange(x, 'b ... -> b (...)').abs(),\n self.dynamic_thres_percentile,\n dim = -1\n )\n s.clamp_(min = 1.)\n s = s.view(-1, *((1,) * (x.ndim - 1)))\n # clip by threshold, depending on whether static or dynamic\n x = x.clamp(-s, s) / s\n return x\n def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, predict_v = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):\n assert not (cond_scale != 1. and not self."
+ },
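+ {
+ "comment": "A self-contained sketch of the dynamic thresholding rule referenced above (https://arxiv.org/abs/2205.11487): clamp the predicted x0 to the q-th percentile of its per-sample absolute values (floored at 1), then rescale back into the [-1, 1] range; the standalone function name is illustrative.",
+ "location": "illustrative sketch (not present in the source file)",
+ "content": "import torch\nfrom einops import rearrange\n\ndef dynamic_threshold(x, percentile = 0.95):\n    # per-sample threshold s; s = 1 everywhere recovers static thresholding\n    s = torch.quantile(rearrange(x, 'b ... -> b (...)').abs(), percentile, dim = -1)\n    s.clamp_(min = 1.)\n    s = s.view(-1, *((1,) * (x.ndim - 1)))  # broadcast over channel / spatial dims\n    return x.clamp(-s, s) / s\n\nx0 = torch.randn(2, 3, 8, 8) * 3.\nprint(dynamic_threshold(x0).abs().max())  # always <= 1"
+ },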
+ {
+ "comment": "This code block is responsible for decoding an input image using a pre-trained unet model. It applies classifier free guidance if enabled, and then calculates the mean and variance of the posterior distribution to perform denoising diffusion probability.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2784-2802",
+ "content": "can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'\n model_output = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))\n pred, var_interp_frac_unnormalized = self.parse_unet_output(learned_variance, model_output)\n if predict_v:\n x_start = noise_scheduler.predict_start_from_v(x, t = t, v = pred)\n elif predict_x_start:\n x_start = pred\n else:\n x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)\n if clip_denoised:\n x_start = self.dynamic_threshold(x_start)\n model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t)\n if learned_variance:"
+ },
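+ {
+ "comment": "The cond_scale handled above follows the standard classifier free guidance recipe; a minimal generic sketch of that combination is below, where the model signature (taking a drop_cond flag) is a hypothetical stand-in for forward_with_cond_scale rather than the actual unet API.",
+ "location": "illustrative sketch (not present in the source file)",
+ "content": "def guided_prediction(model, x, t, cond, cond_scale = 1.):\n    # hypothetical model signature: model(x, t, cond, drop_cond = bool) -> prediction\n    cond_out = model(x, t, cond, drop_cond = False)\n    if cond_scale == 1.:\n        return cond_out\n    null_out = model(x, t, cond, drop_cond = True)        # conditioning fully dropped\n    return null_out + (cond_out - null_out) * cond_scale  # extrapolate away from the unconditional output"
+ },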
+ {
+ "comment": "This code calculates the posterior variance and log variance for a model based on the maximum and minimum log beta values, as described in Equation 15 from arXiv paper. It also applies a learned constraint factor and uses sigmoid activation if required. The function returns the model mean, posterior variance, posterior log variance, and x_start.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2803-2819",
+ "content": " # if learned variance, posterio variance and posterior log variance are predicted by the network\n # by an interpolation of the max and min log beta values\n # eq 15 - https://arxiv.org/abs/2102.09672\n min_log = extract(noise_scheduler.posterior_log_variance_clipped, t, x.shape)\n max_log = extract(torch.log(noise_scheduler.betas), t, x.shape)\n var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)\n if self.learned_variance_constrain_frac:\n var_interp_frac = var_interp_frac.sigmoid()\n posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log\n posterior_variance = posterior_log_variance.exp()\n return model_mean, posterior_variance, posterior_log_variance, x_start\n @torch.no_grad()\n def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_"
+ },
+ {
+ "comment": "This function takes input x and returns the predicted values pred and x_start. It uses a p_mean_variance method from self to calculate model_mean, model_log_variance, and x_start. Noise is added to the input x, except when t == 0. The result is the sum of model_mean and nonzero_mask * (0.5 * model_log_variance).exp() * noise. This is a part of the DDPM (Denoising Diffusion Probabilistic Models) framework for generating images.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2819-2835",
+ "content": "start = False, predict_v = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):\n b, *_, device = *x.shape, x.device\n model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, predict_v = predict_v, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)\n noise = torch.randn_like(x)\n # no noise when t == 0\n nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))\n pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise\n return pred, x_start\n @torch.no_grad()\n def p_sample_loop_ddpm(\n self,\n unet,\n shape,\n image_embed,\n noise_scheduler,\n predict_x_start = False,"
+ },
+ {
+ "comment": "This function initializes image and related variables. If inpainting is present, it normalizes and resizes the image, mask, and sets their dimensions accordingly. The function also determines if the model is performing latent diffusion by checking for provided parameters. It then proceeds to an if-not condition where it assumes that the model is not performing latent diffusion.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2836-2865",
+ "content": " predict_v = False,\n learned_variance = False,\n clip_denoised = True,\n lowres_cond_img = None,\n text_encodings = None,\n cond_scale = 1,\n is_latent_diffusion = False,\n lowres_noise_level = None,\n inpaint_image = None,\n inpaint_mask = None,\n inpaint_resample_times = 5\n ):\n device = self.device\n b = shape[0]\n img = torch.randn(shape, device = device)\n x_start = None # for self-conditioning\n is_inpaint = exists(inpaint_image)\n resample_times = inpaint_resample_times if is_inpaint else 1\n if is_inpaint:\n inpaint_image = self.normalize_img(inpaint_image)\n inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True)\n inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float()\n inpaint_mask = resize_image_to(inpaint_mask, shape[-1], nearest = True)\n inpaint_mask = inpaint_mask.bool()\n if not is_latent_diffusion:"
+ },
+ {
+ "comment": "This code performs progressive growing of an image using a diffusion model, such as DALLE 2. It iterates over timesteps in reverse order and resamples each timestep to produce a final output image. It also includes the option for inpainting by following the Repaint paper's approach. The self-conditioning and U-Net are utilized within the p_sample function, which takes care of the actual sampling process.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2866-2889",
+ "content": " lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)\n for time in tqdm(reversed(range(0, noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = noise_scheduler.num_timesteps):\n is_last_timestep = time == 0\n for r in reversed(range(0, resample_times)):\n is_last_resample_step = r == 0\n times = torch.full((b,), time, device = device, dtype = torch.long)\n if is_inpaint:\n # following the repaint paper\n # https://arxiv.org/abs/2201.09865\n noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = times)\n img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)\n self_cond = x_start if unet.self_cond else None\n img, x_start = self.p_sample(\n unet,\n img,\n times,\n image_embed = image_embed,\n text_encodings = text_encodings,"
+ },
+ {
+ "comment": "This code is part of a model that performs image denoising using diffusion models. It samples images at different timesteps, applies noise scheduling for resampling, and handles inpainting by combining input mask and image embeddings. The output is then unnormalized for the final result.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2890-2916",
+ "content": " cond_scale = cond_scale,\n self_cond = self_cond,\n lowres_cond_img = lowres_cond_img,\n lowres_noise_level = lowres_noise_level,\n predict_x_start = predict_x_start,\n predict_v = predict_v,\n noise_scheduler = noise_scheduler,\n learned_variance = learned_variance,\n clip_denoised = clip_denoised\n )\n if is_inpaint and not (is_last_timestep or is_last_resample_step):\n # in repaint, you renoise and resample up to 10 times every step\n img = noise_scheduler.q_sample_from_to(img, times - 1, times)\n if is_inpaint:\n img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask)\n unnormalize_img = self.unnormalize_img(img)\n return unnormalize_img\n @torch.no_grad()\n def p_sample_loop_ddim(\n self,\n unet,\n shape,\n image_embed,"
+ },
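+ {
+ "comment": "A small sketch of the RePaint-style blending referenced above: at every step the masked region is replaced with a freshly noised copy of the inpaint image, so the model only ever generates the unconstrained pixels; q_sample stands in for the noise scheduler's forward process and the function name is illustrative.",
+ "location": "illustrative sketch (not present in the source file)",
+ "content": "def repaint_blend(img, inpaint_image, inpaint_mask, q_sample, t):\n    # inpaint_mask: boolean [b, 1, h, w]; True where pixels are constrained to the supplied image\n    noised_known = q_sample(inpaint_image, t)                  # noise the known pixels to the current level\n    return img * ~inpaint_mask + noised_known * inpaint_mask   # generated content is kept only outside the mask"
+ },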
+ {
+ "comment": "This function takes multiple parameters including noise_scheduler, timesteps, eta, and more. It extracts necessary information like batch size, device, total timesteps, alphas, and other parameters to perform DDIM sampling. It also checks if inpainting is required and resamples times accordingly.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2917-2945",
+ "content": " noise_scheduler,\n timesteps,\n eta = 1.,\n predict_x_start = False,\n predict_v = False,\n learned_variance = False,\n clip_denoised = True,\n lowres_cond_img = None,\n text_encodings = None,\n cond_scale = 1,\n is_latent_diffusion = False,\n lowres_noise_level = None,\n inpaint_image = None,\n inpaint_mask = None,\n inpaint_resample_times = 5\n ):\n batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod, self.ddim_sampling_eta\n times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]\n times = list(reversed(times.int().tolist()))\n time_pairs = list(zip(times[:-1], times[1:]))\n time_pairs = list(filter(lambda t: t[0] > t[1], time_pairs))\n is_inpaint = exists(inpaint_image)\n resample_times = inpaint_resample_times if is_inpaint else 1\n if is_inpaint:\n inpaint_image = self.normalize_img(inpaint_image)"
+ },
+ {
+ "comment": "The code is sampling from a diffusion model and applying inpainting. It resizes images, prepares masks for inpainting, sets up variables for time steps, and conditions the model based on inpainting or not. The code follows the process described in the Repaint paper (https://arxiv.org/abs/2201.09865).",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2946-2971",
+ "content": " inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True)\n inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float()\n inpaint_mask = resize_image_to(inpaint_mask, shape[-1], nearest = True)\n inpaint_mask = inpaint_mask.bool()\n img = torch.randn(shape, device = device)\n x_start = None # for self-conditioning\n if not is_latent_diffusion:\n lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)\n for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):\n is_last_timestep = time_next == 0\n for r in reversed(range(0, resample_times)):\n is_last_resample_step = r == 0\n alpha = alphas[time]\n alpha_next = alphas[time_next]\n time_cond = torch.full((batch,), time, device = device, dtype = torch.long)\n if is_inpaint:\n # following the repaint paper\n # https://arxiv.org/abs/2201.09865"
+ },
+ {
+ "comment": "This code is using a conditional image generation model to generate an output image based on the input image, conditioning factors (time_cond, image_embed, text_encodings), and possibly predicting x0 values for further processing or clipping.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2972-2993",
+ "content": " noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = time_cond)\n img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)\n self_cond = x_start if unet.self_cond else None\n unet_output = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, self_cond = self_cond, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)\n pred, _ = self.parse_unet_output(learned_variance, unet_output)\n # predict x0\n if predict_v:\n x_start = noise_scheduler.predict_start_from_v(img, t = time_cond, v = pred)\n elif predict_x_start:\n x_start = pred\n else:\n x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)\n # maybe clip x0\n if clip_denoised:\n x_start = self.dynamic_threshold(x_start)"
+ },
+ {
+ "comment": "Predicts noise based on the current state and time, applies coefficients to noise and image, performs inpainting if necessary, and unnormalizes the image.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":2995-3016",
+ "content": " # predict noise\n pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = x_start)\n c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()\n c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()\n noise = torch.randn_like(img) if not is_last_timestep else 0.\n img = x_start * alpha_next.sqrt() + \\\n c1 * noise + \\\n c2 * pred_noise\n if is_inpaint and not (is_last_timestep or is_last_resample_step):\n # in repaint, you renoise and resample up to 10 times every step\n time_next_cond = torch.full((batch,), time_next, device = device, dtype = torch.long)\n img = noise_scheduler.q_sample_from_to(img, time_next_cond, time_cond)\n if exists(inpaint_image):\n img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask)\n img = self.unnormalize_img(img)\n return img"
+ },
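+ {
+ "comment": "The DDIM update computed above can be read in isolation; a hedged standalone sketch using the same alpha-cumprod notation, where eta = 0 gives deterministic sampling and the function name is illustrative:",
+ "location": "illustrative sketch (not present in the source file)",
+ "content": "import torch\n\ndef ddim_step(x_start, pred_noise, alpha, alpha_next, eta = 0., last_step = False):\n    # coefficients from the DDIM paper (https://arxiv.org/abs/2010.02502)\n    c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()\n    c2 = ((1 - alpha_next) - c1 ** 2).sqrt()\n    noise = torch.randn_like(x_start) if not last_step else 0.\n    return x_start * alpha_next.sqrt() + c1 * noise + c2 * pred_noise"
+ },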
+ {
+ "comment": "Function `p_sample_loop` takes in arguments, determines if DDPM or DDIM should be used for sampling, and calls respective function.\nIn `p_losses`, noise is defaulted if not provided, and images are normalized before processing if not latent diffusion.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":3018-3038",
+ "content": " @torch.no_grad()\n def p_sample_loop(self, *args, noise_scheduler, timesteps = None, **kwargs):\n num_timesteps = noise_scheduler.num_timesteps\n timesteps = default(timesteps, num_timesteps)\n assert timesteps <= num_timesteps\n is_ddim = timesteps < num_timesteps\n if not is_ddim:\n return self.p_sample_loop_ddpm(*args, noise_scheduler = noise_scheduler, **kwargs)\n return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs)\n def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, predict_v = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False, lowres_noise_level = None):\n noise = default(noise, lambda: torch.randn_like(x_start))\n # normalize to [-1, 1]\n if not is_latent_diffusion:\n x_start = self.normalize_img(x_start)\n lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)"
+ },
+ {
+ "comment": "Code snippet is from the DALLE2-pytorch model. It samples noisy images and uses them to conditionally generate unet outputs for self-conditioning and prediction, with optional dropout probabilities for image and text conditions.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":3040-3074",
+ "content": " # get x_t\n x_noisy = noise_scheduler.q_sample(x_start = x_start, t = times, noise = noise)\n # unet kwargs\n unet_kwargs = dict(\n image_embed = image_embed,\n text_encodings = text_encodings,\n lowres_cond_img = lowres_cond_img,\n lowres_noise_level = lowres_noise_level,\n )\n # self conditioning\n self_cond = None\n if unet.self_cond and random.random() < 0.5:\n with torch.no_grad():\n unet_output = unet(x_noisy, times, **unet_kwargs)\n self_cond, _ = self.parse_unet_output(learned_variance, unet_output)\n self_cond = self_cond.detach()\n # forward to get model prediction\n unet_output = unet(\n x_noisy,\n times,\n **unet_kwargs,\n self_cond = self_cond,\n image_cond_drop_prob = self.image_cond_drop_prob,\n text_cond_drop_prob = self.text_cond_drop_prob,\n )\n pred, _ = self.parse_unet_output(learned_variance, unet_output)"
+ },
+ {
+ "comment": "The code calculates the loss in a specific manner depending on the input parameters. If predict_v is true, it calculates the target value for v. If predict_x_start is true, it uses x_start as the target. Otherwise, it uses noise as the target. Then, it applies the loss function, reduces the loss, reweighs the loss based on times, and finally calculates the mean of the loss. If learned_variance is not used, it returns the simple loss.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":3076-3097",
+ "content": " if predict_v:\n target = noise_scheduler.calculate_v(x_start, times, noise)\n elif predict_x_start:\n target = x_start\n else:\n target = noise\n loss = noise_scheduler.loss_fn(pred, target, reduction = 'none')\n loss = reduce(loss, 'b ... -> b (...)', 'mean')\n loss = noise_scheduler.p2_reweigh_loss(loss, times)\n loss = loss.mean()\n if not learned_variance:\n # return simple loss if not using learned variance\n return loss\n # most of the code below is transcribed from\n # https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/diffusion_utils_2.py\n # the Improved DDPM paper then further modified it so that the mean is detached (shown a couple lines before), and weighted to be smaller than the l1 or l2 \"simple\" loss\n # it is questionable whether this is really needed, looking at some of the figures in the paper, but may as well stay faithful to their implementation"
+ },
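+ {
+ "comment": "For reference, the three regression targets selected above are the noise, x0, or the v parameterization of Salimans and Ho (https://arxiv.org/abs/2202.00512), which is what calculate_v is assumed to compute here; a standalone sketch of the target choice, with alphas_cumprod_t already broadcast to the shape of x_start, is:",
+ "location": "illustrative sketch (not present in the source file)",
+ "content": "def make_target(x_start, noise, alphas_cumprod_t, predict_v = False, predict_x_start = False):\n    if predict_v:\n        # v = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x0\n        return alphas_cumprod_t.sqrt() * noise - (1 - alphas_cumprod_t).sqrt() * x_start\n    if predict_x_start:\n        return x_start\n    return noise"
+ },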
+ {
+ "comment": "This code calculates the KL divergence between true and model predicted posterior distributions, and decoder negative log likelihood. It uses detached model predictions for stability reasons as per the paper. The loss at the first timestep is the decoder NLL, otherwise it's the KL divergence.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":3099-3114",
+ "content": " # if learning the variance, also include the extra weight kl loss\n true_mean, _, true_log_variance_clipped = noise_scheduler.q_posterior(x_start = x_start, x_t = x_noisy, t = times)\n model_mean, _, model_log_variance, _ = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = unet_output)\n # kl loss with detached model predicted mean, for stability reasons as in paper\n detached_model_mean = model_mean.detach()\n kl = normal_kl(true_mean, true_log_variance_clipped, detached_model_mean, model_log_variance)\n kl = meanflat(kl) * NAT\n decoder_nll = -discretized_gaussian_log_likelihood(x_start, means = detached_model_mean, log_scales = 0.5 * model_log_variance)\n decoder_nll = meanflat(decoder_nll) * NAT\n # at the first timestep return the decoder NLL, otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))"
+ },
+ {
+ "comment": "This function calculates the variational Bayes loss and adds it to the main loss. It then samples from the model given input parameters such as image, text, batch size, etc., with option for conditional or unconditional sampling. The function also performs some assertions on the inputs to ensure proper usage.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":3116-3148",
+ "content": " vb_losses = torch.where(times == 0, decoder_nll, kl)\n # weight the vb loss smaller, for stability, as in the paper (recommended 0.001)\n vb_loss = vb_losses.mean() * self.vb_loss_weight\n return loss + vb_loss\n @torch.no_grad()\n @eval_decorator\n def sample(\n self,\n image = None,\n image_embed = None,\n text = None,\n text_encodings = None,\n batch_size = 1,\n cond_scale = 1.,\n start_at_unet_number = 1,\n stop_at_unet_number = None,\n distributed = False,\n inpaint_image = None,\n inpaint_mask = None,\n inpaint_resample_times = 5,\n one_unet_in_gpu_at_time = True\n ):\n assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally'\n if not self.unconditional:\n batch_size = image_embed.shape[0]\n if exists(text) and not exists(text_encodings) and not self.unconditional:\n assert exists(self.clip)"
+ },
+ {
+ "comment": "This code checks for valid inputs and asserts whether text, text encodings, or inpaint_image and mask are present based on the condition specified. It also ensures that the image input has the correct batch size when starting at a specific unet number. If necessary, it resizes the image using nearest-neighbor interpolation.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":3149-3162",
+ "content": " _, text_encodings = self.clip.embed_text(text)\n assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'\n assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'\n assert not (exists(inpaint_image) ^ exists(inpaint_mask)), 'inpaint_image and inpaint_mask (boolean mask of [batch, height, width]) must be both given for inpainting'\n img = None\n if start_at_unet_number > 1:\n # Then we are not generating the first image and one must have been passed in\n assert exists(image), 'image must be passed in if starting at unet number > 1'\n assert image.shape[0] == batch_size, 'image must have batch size of {} if starting at unet number > 1'.format(batch_size)\n prev_unet_output_size = self.image_sizes[start_at_unet_number - 2]\n img = resize_image_to(image, prev_unet_output_size, nearest = True)"
+ },
+ {
+ "comment": "This code is iterating through each unet in the model, skipping the first X unets based on a given parameter. It checks if the current unet should be processed based on its position, and then prepares low resolution conditioning for upsamplers if required. The code also handles CUDA processing and uses context managers to ensure efficient resource usage.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":3164-3181",
+ "content": " is_cuda = next(self.parameters()).is_cuda\n num_unets = self.num_unets\n cond_scale = cast_tuple(cond_scale, num_unets)\n for unet_number, unet, vae, channel, image_size, predict_x_start, predict_v, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.predict_v, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):\n if unet_number < start_at_unet_number:\n continue # It's the easiest way to do it\n context = self.one_unet_in_gpu(unet = unet) if is_cuda and one_unet_in_gpu_at_time else null_context()\n with context:\n # prepare low resolution conditioning for upsamplers\n lowres_cond_img = lowres_noise_level = None\n shape = (batch_size, channel, image_size, image_size)\n if unet.lowres_cond:"
+ },
+ {
+ "comment": "This code is part of a denoising diffusion model. It first resizes the input image to a target size and applies noise if needed. Then, it checks if the VAE (Variational Autoencoder) is used for latent diffusion and adjusts the image size accordingly. Finally, it encodes the low-resolution image using the VAE and enters a denoising loop with a UNet model to generate the final output image.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":3182-3203",
+ "content": " lowres_cond_img = resize_image_to(img, target_image_size = image_size, clamp_range = self.input_image_range, nearest = True)\n if lowres_cond.use_noise:\n lowres_noise_level = torch.full((batch_size,), int(self.lowres_noise_sample_level * 1000), dtype = torch.long, device = self.device)\n lowres_cond_img, _ = lowres_cond.noise_image(lowres_cond_img, lowres_noise_level)\n # latent diffusion\n is_latent_diffusion = isinstance(vae, VQGanVAE)\n image_size = vae.get_encoded_fmap_size(image_size)\n shape = (batch_size, vae.encoded_dim, image_size, image_size)\n lowres_cond_img = maybe(vae.encode)(lowres_cond_img)\n # denoising loop for image\n img = self.p_sample_loop(\n unet,\n shape,\n image_embed = image_embed,\n text_encodings = text_encodings,\n cond_scale = unet_cond_scale,"
+ },
+ {
+ "comment": "The function takes an image and optionally text, generates images at different UNet resolutions based on input parameters, and returns the generated image. It includes options for low-resolution output, inpainting, and stopping at a specific UNet resolution.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":3204-3232",
+ "content": " predict_x_start = predict_x_start,\n predict_v = predict_v,\n learned_variance = learned_variance,\n clip_denoised = not is_latent_diffusion,\n lowres_cond_img = lowres_cond_img,\n lowres_noise_level = lowres_noise_level,\n is_latent_diffusion = is_latent_diffusion,\n noise_scheduler = noise_scheduler,\n timesteps = sample_timesteps,\n inpaint_image = inpaint_image,\n inpaint_mask = inpaint_mask,\n inpaint_resample_times = inpaint_resample_times\n )\n img = vae.decode(img)\n if exists(stop_at_unet_number) and stop_at_unet_number == unet_number:\n break\n return img\n def forward(\n self,\n image,\n text = None,\n image_embed = None,\n text_encodings = None,\n unet_number = None,\n return_lowres_"
+ },
+ {
+ "comment": "This function is initializing variables for a specific U-Net in the model, based on the provided unet_number. It assigns the corresponding U-Net, VAE, noise scheduler, lowres conditioner, target image size, predict x start, predict v, random crop size, and learned variance from predefined lists for that U-Net index. It also ensures the image shape aligns with the expected number of channels.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":3232-3250",
+ "content": "cond_image = False # whether to return the low resolution conditioning images, for debugging upsampler purposes\n ):\n assert not (self.num_unets > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {self.num_unets}, if you are training cascading DDPM (multiple unets)'\n unet_number = default(unet_number, 1)\n unet_index = unet_number - 1\n unet = self.get_unet(unet_number)\n vae = self.vaes[unet_index]\n noise_scheduler = self.noise_schedulers[unet_index]\n lowres_conditioner = self.lowres_conds[unet_index]\n target_image_size = self.image_sizes[unet_index]\n predict_x_start = self.predict_x_start[unet_index]\n predict_v = self.predict_v[unet_index]\n random_crop_size = self.random_crop_sizes[unet_index]\n learned_variance = self.learned_variance[unet_index]\n b, c, h, w, device, = *image.shape, image.device\n assert image.shape[1] == self.channels"
+ },
+ {
+ "comment": "The code checks if the image and/or text inputs exist, ensuring that either the CLIP model or the necessary inputs are present. It asserts that if the decoder is supposed to be conditioned on text encodings, then the text encodings must be provided, and vice versa. This helps prevent errors in the input data for generating image embeddings.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":3251-3266",
+ "content": " assert h >= target_image_size and w >= target_image_size\n times = torch.randint(0, noise_scheduler.num_timesteps, (b,), device = device, dtype = torch.long)\n if not exists(image_embed) and not self.unconditional:\n assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'\n image_embed, _ = self.clip.embed_image(image)\n if exists(text) and not exists(text_encodings) and not self.unconditional:\n assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder'\n _, text_encodings = self.clip.embed_text(text)\n assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'\n assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'\n "
+ },
+ {
+ "comment": "This code snippet is conditioning a low-resolution image using the lowres_conditioner and performing data augmentation via Kornia's RandomCrop. It also encodes both the image and the conditioned image using a VAE (Variational Autoencoder) and calculates loss from p_losses for further processing in the U-net model.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":3266-3284",
+ "content": "lowres_cond_img, lowres_noise_level = lowres_conditioner(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if exists(lowres_conditioner) else (None, None)\n image = resize_image_to(image, target_image_size, nearest = True)\n if exists(random_crop_size):\n aug = K.RandomCrop((random_crop_size, random_crop_size), p = 1.)\n # make sure low res conditioner and image both get augmented the same way\n # detailed https://kornia.readthedocs.io/en/latest/augmentation.module.html?highlight=randomcrop#kornia.augmentation.RandomCrop\n image = aug(image)\n lowres_cond_img = aug(lowres_cond_img, params = aug._params)\n is_latent_diffusion = not isinstance(vae, NullVQGanVAE)\n vae.eval()\n with torch.no_grad():\n image = vae.encode(image)\n lowres_cond_img = maybe(vae.encode)(lowres_cond_img)\n losses = self.p_losses(unet, image, times, image_embed = image"
+ },
+ {
+ "comment": "This code defines a DALLE2 class with prior and decoder modules. It takes text input, performs diffusion, and returns losses or lowres_cond_img based on the return flag. If not returning the lowres conditional image, it returns only losses.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":3284-3318",
+ "content": "_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, predict_v = predict_v, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level)\n if not return_lowres_cond_image:\n return losses\n return losses, lowres_cond_img\n# main class\nclass DALLE2(nn.Module):\n def __init__(\n self,\n *,\n prior,\n decoder,\n prior_num_samples = 2\n ):\n super().__init__()\n assert isinstance(prior, DiffusionPrior)\n assert isinstance(decoder, Decoder)\n self.prior = prior\n self.decoder = decoder\n self.prior_num_samples = prior_num_samples\n self.decoder_need_text_cond = self.decoder.condition_on_text_encodings\n self.to_pil = T.ToPILImage()\n @torch.no_grad()\n @eval_decorator\n def forward(\n self,\n text,\n cond_scale = 1.,\n prior_cond_scale = 1.,"
+ },
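+ {
+ "comment": "A hedged end-to-end usage sketch of the DALLE2 wrapper defined above, in the spirit of the repository README; prior and decoder are assumed to be already constructed and trained DiffusionPrior and Decoder instances, so the snippet is illustrative rather than runnable on its own.",
+ "location": "illustrative sketch (not present in the source file)",
+ "content": "from dalle2_pytorch import DALLE2\n\n# prior and decoder assumed trained elsewhere\ndalle2 = DALLE2(prior = prior, decoder = decoder)\n\nimages = dalle2(\n    ['a butterfly trying to escape a tornado'],\n    cond_scale = 2.,          # classifier free guidance strength passed to the decoder\n    return_pil_images = True  # convert the output tensor to a list of PIL images\n)"
+ },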
+ {
+ "comment": "This function takes text as input, tokenizes it if necessary, and uses a prior model to generate image embeddings. It then passes these embeddings along with the text (if required) to a decoder model to generate images. Optionally, it converts the images to PIL format and returns them. If only one text is given, it returns the first generated image.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dalle2_pytorch.py\":3319-3339",
+ "content": " return_pil_images = False\n ):\n device = module_device(self)\n one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1)\n if isinstance(text, str) or is_list_str(text):\n text = [text] if not isinstance(text, (list, tuple)) else text\n text = tokenizer.tokenize(text).to(device)\n image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples, cond_scale = prior_cond_scale)\n text_cond = text if self.decoder_need_text_cond else None\n images = self.decoder.sample(image_embed = image_embed, text = text_cond, cond_scale = cond_scale)\n if return_pil_images:\n images = list(map(self.to_pil, images.unbind(dim = 0)))\n if one_text:\n return first(images)\n return images"
+ }
+ ]
+}
\ No newline at end of file
diff --git a/docs/doc/c3edea65-23d8-4615-9c12-d7fe96faa179.json b/docs/doc/c3edea65-23d8-4615-9c12-d7fe96faa179.json
new file mode 100644
index 00000000..ad9e7bdc
--- /dev/null
+++ b/docs/doc/c3edea65-23d8-4615-9c12-d7fe96faa179.json
@@ -0,0 +1,10 @@
+{
+ "summary": "This code imports necessary classes for ImageEmbeddingDataset and PriorEmbeddingDataset from their respective modules in the DALLE2-pytorch library. These datasets are used to load data for the model's training and inference.",
+ "details": [
+ {
+ "comment": "This code imports necessary classes for ImageEmbeddingDataset and PriorEmbeddingDataset from their respective modules in the DALLE2-pytorch library. These datasets are used to load data for the model's training and inference.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/dataloaders/__init__.py\":0-1",
+ "content": "from dalle2_pytorch.dataloaders.decoder_loader import ImageEmbeddingDataset, create_image_embedding_dataloader\nfrom dalle2_pytorch.dataloaders.prior_loader import make_splits, get_reader, PriorEmbeddingDataset"
+ }
+ ]
+}
\ No newline at end of file
diff --git a/docs/doc/d1e3a524-d0ee-487e-a5f2-c3f231bbdadd.json b/docs/doc/d1e3a524-d0ee-487e-a5f2-c3f231bbdadd.json
new file mode 100644
index 00000000..51388604
--- /dev/null
+++ b/docs/doc/d1e3a524-d0ee-487e-a5f2-c3f231bbdadd.json
@@ -0,0 +1,40 @@
+{
+ "summary": "The code simplifies DALL-E2 text tokenization by offering a PyTorch BPE tokenizer implementation with features for whitespace cleanup, formatting fixes, human-readable conversion, and handling context length limitations.",
+ "details": [
+ {
+ "comment": "This code imports necessary libraries and defines functions for tokenization, specifically for the DALL-E2 model. It uses OpenAI's simple tokenizer, a byte-to-unicode conversion, and a function to generate character pairs from a given word. The code is meant to provide users with an easy way to start training DALL-E without implementing BPE (Byte Pair Encoding).",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/tokenizer.py\":0-41",
+ "content": "# take from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py\n# to give users a quick easy start to training DALL-E without doing BPE\nimport torch\nimport html\nimport os\nimport ftfy\nimport regex as re\nfrom functools import lru_cache\nfrom pathlib import Path\nfrom dalle2_pytorch.utils import import_or_print_error\n# OpenAI simple tokenizer\n@lru_cache()\ndef default_bpe():\n return os.path.join(os.path.dirname(os.path.abspath(__file__)), \"data/bpe_simple_vocab_16e6.txt\")\n@lru_cache()\ndef bytes_to_unicode():\n bs = list(range(ord(\"!\"), ord(\"~\") + 1)) + list(range(ord(\"\u00a1\"), ord(\"\u00ac\") + 1)) + list(range(ord(\"\u00ae\"), ord(\"\u00ff\") + 1))\n cs = bs[:]\n n = 0\n for b in range(2 ** 8):\n if b not in bs:\n bs.append(b)\n cs.append(2 ** 8 + n)\n n += 1\n cs = [chr(n) for n in cs]\n return dict(zip(bs, cs))\ndef get_pairs(word):\n pairs = set()\n prev_char = word[0]\n for char in word[1:]:\n pairs.add((prev_char, char))\n prev_char = char\n return pairs\ndef basic_clean(text):"
+ },
+ {
+ "comment": "This code is a Python class for a tokenizer that utilizes byte encoding and decoding, along with byte-pair encoding (BPE) to convert text into tokens. The class also includes methods for cleaning whitespace and fixing text formatting issues. The BPE merges are loaded from a specified file path, and the vocabulary is expanded by adding special tokens like \"<|startoftext|>\" and \"<|endoftext|>\".",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/tokenizer.py\":42-68",
+ "content": " text = ftfy.fix_text(text)\n text = html.unescape(html.unescape(text))\n return text.strip()\ndef whitespace_clean(text):\n text = re.sub(r'\\s+', ' ', text)\n text = text.strip()\n return text\nclass SimpleTokenizer(object):\n def __init__(self, bpe_path = default_bpe()):\n self.byte_encoder = bytes_to_unicode()\n self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}\n merges = Path(bpe_path).read_text(encoding='utf8').split('\\n')\n merges = merges[1:49152 - 256 - 2 + 1]\n merges = [tuple(merge.split()) for merge in merges]\n vocab = list(bytes_to_unicode().values())\n vocab = vocab + [v + '' for v in vocab]\n for merge in merges:\n vocab.append(''.join(merge))\n vocab.extend(['<|startoftext|>', '<|endoftext|>'])\n self.vocab_size = 49408\n self.encoder = dict(zip(vocab, range(len(vocab))))\n self.decoder = {v: k for k, v in self.encoder.items()}\n self.bpe_ranks = dict(zip(merges, range(len(merges))))"
+ },
+ {
+ "comment": "The code defines a tokenizer that uses byte-pair encoding (BPE) for text. It compiles a regular expression pattern to match words and special tokens like \"<|startoftext|>\" and \"<|endoftext|>\". The `bpe` method takes a token, checks if it's in the cache, and if not, processes it using BPE by splitting it into smaller parts until no more splits are possible.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/tokenizer.py\":69-97",
+ "content": " self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}\n self.pat = re.compile(\n r\"\"\"<\\|startoftext\\|>|<\\|endoftext\\|>|'s|'t|'re|'ve|'m|'ll|'d|[\\p{L}]+|[\\p{N}]|[^\\s\\p{L}\\p{N}]+\"\"\",\n re.IGNORECASE)\n def bpe(self, token):\n if token in self.cache:\n return self.cache[token]\n word = tuple(token[:-1]) + (token[-1] + '',)\n pairs = get_pairs(word)\n if not pairs:\n return token + ''\n while True:\n bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))\n if bigram not in self.bpe_ranks:\n break\n first, second = bigram\n new_word = []\n i = 0\n while i < len(word):\n try:\n j = word.index(first, i)\n new_word.extend(word[i:j])\n i = j\n except:\n new_word.extend(word[i:])\n break"
+ },
+ {
+ "comment": "Code snippet is from a byte-pair encoding (BPE) tokenizer implementation in PyTorch. The code encodes input text into BPE tokens, performs wordpiece tokenization, and caches the mapping between tokens and words for decoding. The encode() function processes the input text by applying preprocessing steps, performing BPE, and extending tokens list with BPE tokens. The decode() function allows decoding of encoded tokens back to words using cached mappings.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/tokenizer.py\":99-125",
+ "content": " if word[i] == first and i < len(word) - 1 and word[i + 1] == second:\n new_word.append(first + second)\n i += 2\n else:\n new_word.append(word[i])\n i += 1\n new_word = tuple(new_word)\n word = new_word\n if len(word) == 1:\n break\n else:\n pairs = get_pairs(word)\n word = ' '.join(word)\n self.cache[token] = word\n return word\n def encode(self, text):\n bpe_tokens = []\n text = whitespace_clean(basic_clean(text)).lower()\n for token in re.findall(self.pat, text):\n token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))\n bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))\n return bpe_tokens\n def decode(self, tokens, remove_start_end = True, pad_tokens = set()):\n if torch.is_tensor(tokens):\n tokens = tokens.tolist()"
+ },
+ {
+ "comment": "The code defines a SimpleTokenizer class that tokenizes input texts using an encoding scheme and provides a method to convert encoded tokens into human-readable text. It also includes a tokenize function to process multiple input texts, considering context length limitations and handling truncation. The provided code snippet focuses on the process of converting encoded tokens into text.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/tokenizer.py\":127-150",
+ "content": " if remove_start_end:\n tokens = [token for token in tokens if token not in (49406, 40407, 0)]\n text = ''.join([self.decoder[token] for token in tokens if token not in pad_tokens])\n text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=\"replace\").replace('', ' ')\n return text\n def tokenize(self, texts, context_length = 256, truncate_text = False):\n if isinstance(texts, str):\n texts = [texts]\n all_tokens = [self.encode(text) for text in texts]\n result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)\n for i, tokens in enumerate(all_tokens):\n if len(tokens) > context_length:\n if truncate_text:\n tokens = tokens[:context_length]\n else:\n raise RuntimeError(f\"Input {texts[i]} is too long for context length {context_length}\")\n result[i, :len(tokens)] = torch.tensor(tokens)\n return result\ntokenizer = SimpleTokenizer()"
+ },
+ {
+ "comment": "This code defines a YTTM tokenizer class in PyTorch. The constructor loads the BPE model from the specified path and initializes the tokenizer instance, which can decode and encode text sequences. The decode function converts tokenized lists to human-readable strings, while the encode function transforms input texts into tokenized lists. The tokenize method takes a list of texts, encodes them, and returns a tensor of shape (number_of_texts, context_length) for further processing.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/tokenizer.py\":152-181",
+ "content": "# YTTM tokenizer\nclass YttmTokenizer:\n def __init__(self, bpe_path = None):\n bpe_path = Path(bpe_path)\n assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist'\n self.yttm = import_or_print_error('youtokentome', 'you need to install youtokentome by `pip install youtokentome`')\n tokenizer = self.yttm.BPE(model = str(bpe_path))\n self.tokenizer = tokenizer\n self.vocab_size = tokenizer.vocab_size()\n def decode(self, tokens, pad_tokens = set()):\n if torch.is_tensor(tokens):\n tokens = tokens.tolist()\n return self.tokenizer.decode(tokens, ignore_ids = pad_tokens.union({0}))\n def encode(self, texts):\n encoded = self.tokenizer.encode(texts, output_type = self.yttm.OutputType.ID)\n return list(map(torch.tensor, encoded))\n def tokenize(self, texts, context_length = 256, truncate_text = False):\n if isinstance(texts, str):\n texts = [texts]\n all_tokens = self.encode(texts)\n result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)"
+ },
+ {
+ "comment": "This code segment iterates through all tokens in a list, truncating any token sequence longer than the specified context length. If truncation is not allowed and an input text is too long, it raises a RuntimeError. The truncated or original tokens are then converted to torch tensors and stored in a result array.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/tokenizer.py\":182-190",
+ "content": " for i, tokens in enumerate(all_tokens):\n if len(tokens) > context_length:\n if truncate_text:\n tokens = tokens[:context_length]\n else:\n raise RuntimeError(f\"Input {texts[i]} is too long for context length {context_length}\")\n result[i, :len(tokens)] = torch.tensor(tokens)\n return result"
+ }
+ ]
+}
\ No newline at end of file
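The tokenizer entries above document `encode`, `decode`, and `tokenize` on the module-level `SimpleTokenizer` instance. As a quick, hedged illustration of that interface (not taken from the repository; it assumes `dalle2_pytorch` is installed and the bundled BPE vocabulary file is available), a usage sketch might look like this:

```python
# Minimal sketch of the SimpleTokenizer interface documented above.
# Assumes dalle2_pytorch is installed; `tokenizer` is the module-level
# SimpleTokenizer() instance created at the bottom of tokenizer.py.
from dalle2_pytorch.tokenizer import tokenizer

texts = ["a corgi wearing sunglasses", "an oil painting of a lighthouse"]

# encode() returns a list of BPE token ids for a single string
ids = tokenizer.encode(texts[0])

# tokenize() zero-pads (or truncates, if allowed) to a fixed context length
tokens = tokenizer.tokenize(texts, context_length = 256, truncate_text = True)
print(tokens.shape)  # torch.Size([2, 256])

# decode() maps ids back to (approximately) the cleaned, lowercased text,
# dropping start/end and pad tokens by default
print(tokenizer.decode(tokens[0]))
```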
diff --git a/docs/doc/dc144f5b-33a9-4b70-a7f5-b6eb9409dc11.json b/docs/doc/dc144f5b-33a9-4b70-a7f5-b6eb9409dc11.json
new file mode 100644
index 00000000..a86f152f
--- /dev/null
+++ b/docs/doc/dc144f5b-33a9-4b70-a7f5-b6eb9409dc11.json
@@ -0,0 +1,110 @@
+{
+ "summary": "Code describes VQGAN-VAE and Vision Transformer architectures for image generation models, including convolutional layers, self-attention mechanisms, layer normalization, initializes model, calculates losses, determines adaptive weight, applies clamp function, calculates combined loss, returns reconstructed feature maps if required.",
+ "details": [
+ {
+ "comment": "This code imports various libraries and defines several constants, helper functions, and decorators for use in a deep learning model. It also sets up a class for a Vector Quantize module using PyTorch, with functionality to evaluate the model and remove the VGG feature if present.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/vqgan_vae.py\":0-50",
+ "content": "import copy\nimport math\nfrom math import sqrt\nfrom functools import partial, wraps\nfrom vector_quantize_pytorch import VectorQuantize as VQ\nimport torch\nfrom torch import nn, einsum\nimport torch.nn.functional as F\nfrom torch.autograd import grad as torch_grad\nimport torchvision\nfrom einops import rearrange, reduce, repeat, pack, unpack\nfrom einops.layers.torch import Rearrange\n# constants\nMList = nn.ModuleList\n# helper functions\ndef exists(val):\n return val is not None\ndef default(val, d):\n return val if exists(val) else d\n# decorators\ndef eval_decorator(fn):\n def inner(model, *args, **kwargs):\n was_training = model.training\n model.eval()\n out = fn(model, *args, **kwargs)\n model.train(was_training)\n return out\n return inner\ndef remove_vgg(fn):\n @wraps(fn)\n def inner(self, *args, **kwargs):\n has_vgg = hasattr(self, 'vgg')\n if has_vgg:\n vgg = self.vgg\n delattr(self, 'vgg')\n out = fn(self, *args, **kwargs)\n if has_vgg:\n self.vgg = vgg"
+ },
+ {
+ "comment": "This code contains various utility functions. \"pick_and_pop\" removes and returns keys from a dictionary, \"group_dict_by_key\" groups dictionary items by key condition, \"string_begins_with\" checks if a string begins with a given prefix, \"group_by_key_prefix\" groups dictionary items based on a key prefix, and \"groupby_prefix_and_trim\" trims key prefixes before grouping. Lastly, the \"log\" function calculates the natural logarithm of an input tensor, and the \"gradient_penalty\" function is used to calculate a gradient penalty for image generation tasks.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/vqgan_vae.py\":52-86",
+ "content": " return out\n return inner\n# keyword argument helpers\ndef pick_and_pop(keys, d):\n values = list(map(lambda key: d.pop(key), keys))\n return dict(zip(keys, values))\ndef group_dict_by_key(cond, d):\n return_val = [dict(),dict()]\n for key in d.keys():\n match = bool(cond(key))\n ind = int(not match)\n return_val[ind][key] = d[key]\n return (*return_val,)\ndef string_begins_with(prefix, string_input):\n return string_input.startswith(prefix)\ndef group_by_key_prefix(prefix, d):\n return group_dict_by_key(partial(string_begins_with, prefix), d)\ndef groupby_prefix_and_trim(prefix, d):\n kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)\n kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))\n return kwargs_without_prefix, kwargs\n# tensor helper functions\ndef log(t, eps = 1e-10):\n return torch.log(t + eps)\ndef gradient_penalty(images, output, weight = 10):\n batch_size = images.shape[0]"
+ },
+ {
+ "comment": "This code contains several utility functions and loss functions used in the VQ-VAE-GAN model. It includes functions for gradient calculations, normalization, activation functions, and various GAN losses. The functions are defined to be reusable throughout the codebase.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/vqgan_vae.py\":87-120",
+ "content": " gradients = torch_grad(outputs = output, inputs = images,\n grad_outputs = torch.ones(output.size(), device = images.device),\n create_graph = True, retain_graph = True, only_inputs = True)[0]\n gradients = rearrange(gradients, 'b ... -> b (...)')\n return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()\ndef l2norm(t):\n return F.normalize(t, dim = -1)\ndef leaky_relu(p = 0.1):\n return nn.LeakyReLU(0.1)\ndef stable_softmax(t, dim = -1, alpha = 32 ** 2):\n t = t / alpha\n t = t - torch.amax(t, dim = dim, keepdim = True).detach()\n return (t * alpha).softmax(dim = dim)\ndef safe_div(numer, denom, eps = 1e-8):\n return numer / (denom + eps)\n# gan losses\ndef hinge_discr_loss(fake, real):\n return (F.relu(1 + fake) + F.relu(1 - real)).mean()\ndef hinge_gen_loss(fake):\n return -fake.mean()\ndef bce_discr_loss(fake, real):\n return (-log(1 - torch.sigmoid(fake)) - log(torch.sigmoid(real))).mean()\ndef bce_gen_loss(fake):\n return -log(torch.sigmoid(fake)).mean()"
+ },
+ {
+ "comment": "The code defines a function to compute gradients of a layer wrt the loss, and introduces two custom modules: LayerNormChan for layer normalization and Discriminator for a convolutional network. The discriminator consists of multiple layers with decreasing kernel sizes, each followed by a leaky ReLU activation function. These components are part of the VQGAN-VAE architecture in DALLE2-pytorch.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/vqgan_vae.py\":122-162",
+ "content": "def grad_layer_wrt_loss(loss, layer):\n return torch_grad(\n outputs = loss,\n inputs = layer,\n grad_outputs = torch.ones_like(loss),\n retain_graph = True\n )[0].detach()\n# vqgan vae\nclass LayerNormChan(nn.Module):\n def __init__(\n self,\n dim,\n eps = 1e-5\n ):\n super().__init__()\n self.eps = eps\n self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))\n def forward(self, x):\n var = torch.var(x, dim = 1, unbiased = False, keepdim = True)\n mean = torch.mean(x, dim = 1, keepdim = True)\n return (x - mean) / (var + self.eps).sqrt() * self.gamma\n# discriminator\nclass Discriminator(nn.Module):\n def __init__(\n self,\n dims,\n channels = 3,\n groups = 16,\n init_kernel_size = 5\n ):\n super().__init__()\n dim_pairs = zip(dims[:-1], dims[1:])\n self.layers = MList([nn.Sequential(nn.Conv2d(channels, dims[0], init_kernel_size, padding = init_kernel_size // 2), leaky_relu())])\n for dim_in, dim_out in dim_pairs:"
+ },
+ {
+ "comment": "The code defines a VQGAN-VAE model. It uses convolutional layers and group normalization for downsampling the input image, followed by linear layers and leaky ReLU activation functions in a sequential manner to generate logits. The `ContinuousPositionBias` class is used for positional encoding in the model.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/vqgan_vae.py\":163-196",
+ "content": " self.layers.append(nn.Sequential(\n nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1),\n nn.GroupNorm(groups, dim_out),\n leaky_relu()\n ))\n dim = dims[-1]\n self.to_logits = nn.Sequential( # return 5 x 5, for PatchGAN-esque training\n nn.Conv2d(dim, dim, 1),\n leaky_relu(),\n nn.Conv2d(dim, 1, 4)\n )\n def forward(self, x):\n for net in self.layers:\n x = net(x)\n return self.to_logits(x)\n# positional encoding\nclass ContinuousPositionBias(nn.Module):\n \"\"\" from https://arxiv.org/abs/2111.09883 \"\"\"\n def __init__(self, *, dim, heads, layers = 2):\n super().__init__()\n self.net = MList([])\n self.net.append(nn.Sequential(nn.Linear(2, dim), leaky_relu()))\n for _ in range(layers - 1):\n self.net.append(nn.Sequential(nn.Linear(dim, dim), leaky_relu()))\n self.net.append(nn.Linear(dim, heads))\n self.register_buffer('rel_pos', None, persistent = False)"
+ },
+ {
+ "comment": "The code defines a VQ-VAE implementation with a resnet encoder/decoder for image generation. The function calculates relative positional embeddings and applies them to the input, then passes the result through a resnet encoder/decoder network before returning the transformed input. The ResnetEncDec class creates an instance of the resnet encoder/decoder with optional parameters such as dimensions, channels, layers, layer_mults, num_resnet_blocks, resnet_groups, first_conv_kernel_size, and use_attn.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/vqgan_vae.py\":198-231",
+ "content": " def forward(self, x):\n n, device = x.shape[-1], x.device\n fmap_size = int(sqrt(n))\n if not exists(self.rel_pos):\n pos = torch.arange(fmap_size, device = device)\n grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))\n grid = rearrange(grid, 'c i j -> (i j) c')\n rel_pos = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c')\n rel_pos = torch.sign(rel_pos) * torch.log(rel_pos.abs() + 1)\n self.register_buffer('rel_pos', rel_pos, persistent = False)\n rel_pos = self.rel_pos.float()\n for layer in self.net:\n rel_pos = layer(rel_pos)\n bias = rearrange(rel_pos, 'i j h -> h i j')\n return x + bias\n# resnet encoder / decoder\nclass ResnetEncDec(nn.Module):\n def __init__(\n self,\n dim,\n *,\n channels = 3,\n layers = 4,\n layer_mults = None,\n num_resnet_blocks = 1,\n resnet_groups = 16,\n first_conv_kernel_size = 5,\n use_attn = True,"
+ },
+ {
+ "comment": "This code defines a class with specified parameters for layers, encoders, and decoders. It ensures the dimension is divisible by resnet_groups. The layer multipliers are stored in a list and used to determine the dimensions of each layer. num_resnet_blocks and use_attn are checked to make sure they match the designated number of layers.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/vqgan_vae.py\":232-261",
+ "content": " attn_dim_head = 64,\n attn_heads = 8,\n attn_dropout = 0.,\n ):\n super().__init__()\n assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)'\n self.layers = layers\n self.encoders = MList([])\n self.decoders = MList([])\n layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))\n assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'\n layer_dims = [dim * mult for mult in layer_mults]\n dims = (dim, *layer_dims)\n self.encoded_dim = dims[-1]\n dim_pairs = zip(dims[:-1], dims[1:])\n append = lambda arr, t: arr.append(t)\n prepend = lambda arr, t: arr.insert(0, t)\n if not isinstance(num_resnet_blocks, tuple):\n num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)\n if not isinstance(use_attn, tuple):\n use_attn = (*((False,) * (layers - 1)), use_attn)"
+ },
+ {
+ "comment": "This code creates encoder and decoder blocks for a VQ-VAE model. It asserts that the number of resnet blocks and use_attn match the layers, then iterates over each layer creating convolutional layers, LeakyReLU activation functions, optionally adding attention modules, and repeating a specific number of residual blocks in both encoders and decoders.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/vqgan_vae.py\":263-278",
+ "content": " assert len(num_resnet_blocks) == layers, 'number of resnet blocks config must be equal to number of layers'\n assert len(use_attn) == layers\n for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):\n append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))\n prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))\n if layer_use_attn:\n prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))\n for _ in range(layer_num_resnet_blocks):\n append(self.encoders, ResBlock(dim_out, groups = resnet_groups))\n prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups))\n if layer_use_attn:\n append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))"
+ },
+ {
+ "comment": "The code defines a class for a VQGAN-VAE model. It consists of encoder and decoder blocks, along with a GLUResBlock for the residual connections in the decoder. The encoder and decoder are composed of convolutional layers that reduce and increase image size respectively. The encoded image size is defined as the original image size divided by 2 to the power of the number of layers. The model can encode and decode images using the encoder and decoder blocks, and the last decoder layer's weights can be accessed separately.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/vqgan_vae.py\":280-314",
+ "content": " prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))\n append(self.decoders, nn.Conv2d(dim, channels, 1))\n def get_encoded_fmap_size(self, image_size):\n return image_size // (2 ** self.layers)\n @property\n def last_dec_layer(self):\n return self.decoders[-1].weight\n def encode(self, x):\n for enc in self.encoders:\n x = enc(x)\n return x\n def decode(self, x):\n for dec in self.decoders:\n x = dec(x)\n return x\nclass GLUResBlock(nn.Module):\n def __init__(self, chan, groups = 16):\n super().__init__()\n self.net = nn.Sequential(\n nn.Conv2d(chan, chan * 2, 3, padding = 1),\n nn.GLU(dim = 1),\n nn.GroupNorm(groups, chan),\n nn.Conv2d(chan, chan * 2, 3, padding = 1),\n nn.GLU(dim = 1),\n nn.GroupNorm(groups, chan),\n nn.Conv2d(chan, chan, 1)\n )\n def forward(self, x):\n return self.net(x) + x"
+ },
+ {
+ "comment": "This code defines a residual block and a VQGAN attention layer for image processing. The ResBlock consists of multiple 2D convolutions and GroupNorm layers, followed by leaky ReLU activation functions. The VQGANAttention class is responsible for self-attention in the VQGAN model, using continuous position bias and multi-head attention with dropout regularization.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/vqgan_vae.py\":316-353",
+ "content": "class ResBlock(nn.Module):\n def __init__(self, chan, groups = 16):\n super().__init__()\n self.net = nn.Sequential(\n nn.Conv2d(chan, chan, 3, padding = 1),\n nn.GroupNorm(groups, chan),\n leaky_relu(),\n nn.Conv2d(chan, chan, 3, padding = 1),\n nn.GroupNorm(groups, chan),\n leaky_relu(),\n nn.Conv2d(chan, chan, 1)\n )\n def forward(self, x):\n return self.net(x) + x\n# vqgan attention layer\nclass VQGanAttention(nn.Module):\n def __init__(\n self,\n *,\n dim,\n dim_head = 64,\n heads = 8,\n dropout = 0.\n ):\n super().__init__()\n self.heads = heads\n self.scale = dim_head ** -0.5\n inner_dim = heads * dim_head\n self.dropout = nn.Dropout(dropout)\n self.pre_norm = LayerNormChan(dim)\n self.cpb = ContinuousPositionBias(dim = dim // 4, heads = heads)\n self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)\n self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)"
+ },
+ {
+ "comment": "This code defines a class for the Attention module in a ViT (Vision Transformer) model. It performs multi-head attention using key, query, and value tensors, followed by a softmax function to compute attention weights. The output is then passed through a linear layer and layer normalization before being added back to the input with residual connection.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/vqgan_vae.py\":355-395",
+ "content": " def forward(self, x):\n h = self.heads\n height, width, residual = *x.shape[-2:], x.clone()\n x = self.pre_norm(x)\n q, k, v = self.to_qkv(x).chunk(3, dim = 1)\n q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = h), (q, k, v))\n sim = einsum('b h c i, b h c j -> b h i j', q, k) * self.scale\n sim = self.cpb(sim)\n attn = stable_softmax(sim, dim = -1)\n attn = self.dropout(attn)\n out = einsum('b h i j, b h c j -> b h c i', attn, v)\n out = rearrange(out, 'b h c (x y) -> b (h c) x y', x = height, y = width)\n out = self.to_out(out)\n return out + residual\n# ViT encoder / decoder\nclass RearrangeImage(nn.Module):\n def forward(self, x):\n n = x.shape[1]\n w = h = int(sqrt(n))\n return rearrange(x, 'b (h w) ... -> b h w ...', h = h, w = w)\nclass Attention(nn.Module):\n def __init__(\n self,\n dim,\n *,\n heads = 8,\n dim_head = 32\n ):\n super().__init__()\n self.norm = nn.LayerNorm(dim)"
+ },
+ {
+ "comment": "This code defines a MultiHeadAttention module for a transformer model. It initializes the attention head count and scale, calculates inner dimension based on head count and input dimension. The forward function performs multi-head attention by splitting input into query, key, value tensors, scaling query tensor, computing similarity between query and key, subtracting maximum similarity to avoid zero gradients, performing softmax on attention scores, and finally producing output tensor through weighted sum of value tensors. The FeedForward function defines a feedforward network for the transformer model, consisting of layer normalization, linear layers with GELU activation function.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/vqgan_vae.py\":396-432",
+ "content": " self.heads = heads\n self.scale = dim_head ** -0.5\n inner_dim = dim_head * heads\n self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)\n self.to_out = nn.Linear(inner_dim, dim)\n def forward(self, x):\n h = self.heads\n x = self.norm(x)\n q, k, v = self.to_qkv(x).chunk(3, dim = -1)\n q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))\n q = q * self.scale\n sim = einsum('b h i d, b h j d -> b h i j', q, k)\n sim = sim - sim.amax(dim = -1, keepdim = True).detach()\n attn = sim.softmax(dim = -1)\n out = einsum('b h i j, b h j d -> b h i d', attn, v)\n out = rearrange(out, 'b h n d -> b n (h d)')\n return self.to_out(out)\ndef FeedForward(dim, mult = 4):\n return nn.Sequential(\n nn.LayerNorm(dim),\n nn.Linear(dim, dim * mult, bias = False),\n nn.GELU(),\n nn.Linear(dim * mult, dim, bias = False)\n )\nclass Transformer(nn.Module):\n def __init__(\n self,"
+ },
+ {
+ "comment": "The code defines a class for an encoder-decoder architecture, which is part of the Vision Transformer (ViT) model. It utilizes attention and feedforward layers, and includes layer normalization in its forward pass. The encoder section takes input images, reshapes them into patches, and passes them through multiple attention and feedforward layers.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/vqgan_vae.py\":433-475",
+ "content": " dim,\n *,\n layers,\n dim_head = 32,\n heads = 8,\n ff_mult = 4\n ):\n super().__init__()\n self.layers = nn.ModuleList([])\n for _ in range(layers):\n self.layers.append(nn.ModuleList([\n Attention(dim = dim, dim_head = dim_head, heads = heads),\n FeedForward(dim = dim, mult = ff_mult)\n ]))\n self.norm = nn.LayerNorm(dim)\n def forward(self, x):\n for attn, ff in self.layers:\n x = attn(x) + x\n x = ff(x) + x\n return self.norm(x)\nclass ViTEncDec(nn.Module):\n def __init__(\n self,\n dim,\n channels = 3,\n layers = 4,\n patch_size = 8,\n dim_head = 32,\n heads = 8,\n ff_mult = 4\n ):\n super().__init__()\n self.encoded_dim = dim\n self.patch_size = patch_size\n input_dim = channels * (patch_size ** 2)\n self.encoder = nn.Sequential(\n Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),"
+ },
+ {
+ "comment": "The code defines a VQ-VAE model for image generation, consisting of an encoder and decoder. The encoder processes the input image and outputs a compressed codebook index followed by a positional embedding. The decoder then reconstructs the original image from these inputs using a series of transformers and linear layers. The get_encoded_fmap_size function calculates the encoded feature map size based on the input image size.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/vqgan_vae.py\":476-509",
+ "content": " nn.Linear(input_dim, dim),\n Transformer(\n dim = dim,\n dim_head = dim_head,\n heads = heads,\n ff_mult = ff_mult,\n layers = layers\n ),\n RearrangeImage(),\n Rearrange('b h w c -> b c h w')\n )\n self.decoder = nn.Sequential(\n Rearrange('b c h w -> b (h w) c'),\n Transformer(\n dim = dim,\n dim_head = dim_head,\n heads = heads,\n ff_mult = ff_mult,\n layers = layers\n ),\n nn.Sequential(\n nn.Linear(dim, dim * 4, bias = False),\n nn.Tanh(),\n nn.Linear(dim * 4, input_dim, bias = False),\n ),\n RearrangeImage(),\n Rearrange('b h w (p1 p2 c) -> b c (h p1) (w p2)', p1 = patch_size, p2 = patch_size)\n )\n def get_encoded_fmap_size(self, image_size):\n return image_size // self.patch_size\n @property"
+ },
+ {
+ "comment": "This code defines two classes: NullVQGanVAE and VQGanVAE. The NullVQGanVAE is a placeholder class without any specific layers or functionality, while the VQGanVAE class represents a variant of the VAE model with optional features like VGG loss, GAN integration, and customizable parameters for codebook dimensions and layers.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/vqgan_vae.py\":510-561",
+ "content": " def last_dec_layer(self):\n return self.decoder[-3][-1].weight\n def encode(self, x):\n return self.encoder(x)\n def decode(self, x):\n return self.decoder(x)\n# main vqgan-vae classes\nclass NullVQGanVAE(nn.Module):\n def __init__(\n self,\n *,\n channels\n ):\n super().__init__()\n self.encoded_dim = channels\n self.layers = 0\n def get_encoded_fmap_size(self, size):\n return size\n def copy_for_eval(self):\n return self\n def encode(self, x):\n return x\n def decode(self, x):\n return x\nclass VQGanVAE(nn.Module):\n def __init__(\n self,\n *,\n dim,\n image_size,\n channels = 3,\n layers = 4,\n l2_recon_loss = False,\n use_hinge_loss = True,\n vgg = None,\n vq_codebook_dim = 256,\n vq_codebook_size = 512,\n vq_decay = 0.8,\n vq_commitment_weight = 1.,\n vq_kmeans_init = True,\n vq_use_cosine_sim = True,\n use_vgg_and_gan = True,\n vae_type = 'resnet',"
+ },
+ {
+ "comment": "This code initializes a VQ-VAE model with given parameters. It uses a specified encoder-decoder network (ResNet or ViT), codebook size, and other VQ-specific options. The VQ module is initialized based on the dimensionality of the encoder-decoder's encoded output, and the codebook size and related options. If an invalid VAE type is given, a ValueError is raised.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/vqgan_vae.py\":562-595",
+ "content": " discr_layers = 4,\n **kwargs\n ):\n super().__init__()\n vq_kwargs, kwargs = groupby_prefix_and_trim('vq_', kwargs)\n encdec_kwargs, kwargs = groupby_prefix_and_trim('encdec_', kwargs)\n self.image_size = image_size\n self.channels = channels\n self.codebook_size = vq_codebook_size\n if vae_type == 'resnet':\n enc_dec_klass = ResnetEncDec\n elif vae_type == 'vit':\n enc_dec_klass = ViTEncDec\n else:\n raise ValueError(f'{vae_type} not valid')\n self.enc_dec = enc_dec_klass(\n dim = dim,\n channels = channels,\n layers = layers,\n **encdec_kwargs\n )\n self.vq = VQ(\n dim = self.enc_dec.encoded_dim,\n codebook_dim = vq_codebook_dim,\n codebook_size = vq_codebook_size,\n decay = vq_decay,\n commitment_weight = vq_commitment_weight,\n accept_image_fmap = True,\n kmeans_init = vq_kmeans_init,\n use_cosine_sim = vq_use_cosine_sim,"
+ },
+ {
+ "comment": "This code defines a VQGAN-VAE model with optional GAN and perceptual loss components. It initializes the VGG model, Discriminator, and sets the reconstruction and generator losses based on provided arguments. The encoded_dim property returns the dimension of the encoded images.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/vqgan_vae.py\":596-632",
+ "content": " **vq_kwargs\n )\n # reconstruction loss\n self.recon_loss_fn = F.mse_loss if l2_recon_loss else F.l1_loss\n # turn off GAN and perceptual loss if grayscale\n self.vgg = None\n self.discr = None\n self.use_vgg_and_gan = use_vgg_and_gan\n if not use_vgg_and_gan:\n return\n # preceptual loss\n if exists(vgg):\n self.vgg = vgg\n else:\n self.vgg = torchvision.models.vgg16(pretrained = True)\n self.vgg.classifier = nn.Sequential(*self.vgg.classifier[:-2])\n # gan related losses\n layer_mults = list(map(lambda t: 2 ** t, range(discr_layers)))\n layer_dims = [dim * mult for mult in layer_mults]\n dims = (dim, *layer_dims)\n self.discr = Discriminator(dims = dims, channels = channels)\n self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss\n self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss\n @property\n def encoded_dim(self):"
+ },
+ {
+ "comment": "This code defines a class with methods to get encoded dimensions, calculate encoded frame map size, copy the model for evaluation, save and load state dictionary while removing VGG, encode input frames, and decode encoded frames.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/vqgan_vae.py\":633-671",
+ "content": " return self.enc_dec.encoded_dim\n def get_encoded_fmap_size(self, image_size):\n return self.enc_dec.get_encoded_fmap_size(image_size)\n def copy_for_eval(self):\n device = next(self.parameters()).device\n vae_copy = copy.deepcopy(self.cpu())\n if vae_copy.use_vgg_and_gan:\n del vae_copy.discr\n del vae_copy.vgg\n vae_copy.eval()\n return vae_copy.to(device)\n @remove_vgg\n def state_dict(self, *args, **kwargs):\n return super().state_dict(*args, **kwargs)\n @remove_vgg\n def load_state_dict(self, *args, **kwargs):\n return super().load_state_dict(*args, **kwargs)\n @property\n def codebook(self):\n return self.vq.codebook\n def encode(self, fmap):\n fmap = self.enc_dec.encode(fmap)\n return fmap\n def decode(self, fmap, return_indices_and_loss = False):\n fmap, indices, commit_loss = self.vq(fmap)\n fmap = self.enc_dec.decode(fmap)\n if not return_indices_and_loss:\n return fmap"
+ },
+ {
+ "comment": "This function encodes an input image, decodes it, and can optionally return autoencoder or discriminator losses. It expects the image to have the specified dimensions and number of channels. The code asserts that the image's height, width, and number of channels match the expected values, and that only one type of loss is returned at a time.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/vqgan_vae.py\":673-699",
+ "content": " return fmap, indices, commit_loss\n def forward(\n self,\n img,\n return_loss = False,\n return_discr_loss = False,\n return_recons = False,\n add_gradient_penalty = True\n ):\n batch, channels, height, width, device = *img.shape, img.device\n assert height == self.image_size and width == self.image_size, 'height and width of input image must be equal to {self.image_size}'\n assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE'\n fmap = self.encode(img)\n fmap, indices, commit_loss = self.decode(fmap, return_indices_and_loss = True)\n if not return_loss and not return_discr_loss:\n return fmap\n assert return_loss ^ return_discr_loss, 'you should either return autoencoder loss or discriminator loss, but not both'\n # whether to return discriminator loss\n if return_discr_loss:\n assert exists(self.discr), 'discriminator must exist to train it'"
+ },
+ {
+ "comment": "The code is calculating the reconstruction and perceptual loss for an image generation model. It also includes gradient penalty for the discriminator loss, and optionally returns the reconstructed feature map.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/vqgan_vae.py\":701-738",
+ "content": " fmap.detach_()\n img.requires_grad_()\n fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))\n discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)\n if add_gradient_penalty:\n gp = gradient_penalty(img, img_discr_logits)\n loss = discr_loss + gp\n if return_recons:\n return loss, fmap\n return loss\n # reconstruction loss\n recon_loss = self.recon_loss_fn(fmap, img)\n # early return if training on grayscale\n if not self.use_vgg_and_gan:\n if return_recons:\n return recon_loss, fmap\n return recon_loss\n # perceptual loss\n img_vgg_input = img\n fmap_vgg_input = fmap\n if img.shape[1] == 1:\n # handle grayscale for vgg\n img_vgg_input, fmap_vgg_input = map(lambda t: repeat(t, 'b 1 ... -> b c ...', c = 3), (img_vgg_input, fmap_vgg_input))\n img_vgg_feats = self.vgg(img_vgg_input)"
+ },
+ {
+ "comment": "This code calculates a combination of losses, including reconstruction, perceptual, and commitment. The adaptive weight is determined based on the gradients of these losses. A clamp function limits the adaptive weight to prevent extreme values. Finally, the combined loss is calculated and returned. If return_recons is True, fmap is also returned.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/dalle2_pytorch/vqgan_vae.py\":739-763",
+ "content": " recon_vgg_feats = self.vgg(fmap_vgg_input)\n perceptual_loss = F.mse_loss(img_vgg_feats, recon_vgg_feats)\n # generator loss\n gen_loss = self.gen_loss(self.discr(fmap))\n # calculate adaptive weight\n last_dec_layer = self.enc_dec.last_dec_layer\n norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)\n norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)\n adaptive_weight = safe_div(norm_grad_wrt_perceptual_loss, norm_grad_wrt_gen_loss)\n adaptive_weight.clamp_(max = 1e4)\n # combine losses\n loss = recon_loss + perceptual_loss + commit_loss + adaptive_weight * gen_loss\n if return_recons:\n return loss, fmap\n return loss"
+ }
+ ]
+}
\ No newline at end of file
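To make the VQGanVAE entries above more concrete, here is a hedged smoke-test sketch. The constructor arguments shown are real parameters from the quoted code, but the specific values are illustrative placeholders, and the VGG/GAN terms are disabled so only the reconstruction-loss path documented above is exercised:

```python
# Illustrative smoke test for the VQGanVAE described above (not a training
# recipe from the repository; dimensions and sizes are placeholder choices).
import torch
from dalle2_pytorch.vqgan_vae import VQGanVAE

vae = VQGanVAE(
    dim = 32,                  # base channel width; must be divisible by resnet_groups (16)
    image_size = 256,          # forward() asserts inputs match this resolution
    channels = 3,
    layers = 4,
    vq_codebook_size = 512,
    use_vgg_and_gan = False    # skip perceptual/discriminator losses for this sketch
)

images = torch.randn(2, 3, 256, 256)

# with VGG/GAN disabled, return_loss = True yields just the reconstruction loss
loss = vae(images, return_loss = True)
loss.backward()

# plain autoencoding: encode to a feature map, quantize and decode back to pixels
fmap  = vae.encode(images)
recon = vae.decode(fmap)       # same shape as `images`
```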
diff --git a/docs/doc/e6fa9e64-fbf2-4547-8c7a-7a39a322fc17.json b/docs/doc/e6fa9e64-fbf2-4547-8c7a-7a39a322fc17.json
new file mode 100644
index 00000000..beee2e9e
--- /dev/null
+++ b/docs/doc/e6fa9e64-fbf2-4547-8c7a-7a39a322fc17.json
@@ -0,0 +1,75 @@
+{
+ "summary": "This code uses diffusion prior and CLIP to generate images from text prompts, implements pre-trained decoders, compares EMA models, checks image embeddings in DALLE2-pytorch, and discusses overfitting and running diffusion model training scripts.",
+ "details": [
+ {
+ "comment": "This code introduces the concept of a diffusion prior, which is a trained model that allows translation between two embedding spaces. It motivates the use case of generating images from text using CLIP and a Decoder when embeddings are not guaranteed to be in the same space. The code loads CLIP and a pre-trained decoder, then retrieves a prompt from the user and encodes it with CLIP for further processing.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/prior.md\":0-20",
+ "content": "# Diffusion Prior\nThis readme serves as an introduction to the diffusion prior.\n## Intro\nA properly trained prior will allow you to translate between two embedding spaces. If you know *a priori* that two embeddings are connected some way\u2014then ability the translate between them could extremely helpful.\n### Motivation\nBefore we dive into the model, let\u2019s look at a quick example of where the model may be helpful.\nFor demonstration purposes we will imagine that we wish to generate images from text using CLIP and a Decoder.\n> [CLIP](https://openai.com/blog/clip/) is a contrastive model that learns to maximize the cosine similarity between a given image and caption, however, there is no guarantee that these embeddings are in the same space. While the embeddings generated are ***close*** the image and text embeddings occupy two disjoint sets.\n```python\n# Load Models\nclip_model = clip.load(\"ViT-L/14\")\ndecoder = Decoder(checkpoint=\"best.pth\") # A decoder trained on CLIP Image embeddings\n# Retrieve prompt from user and encode with CLIP"
+ },
+ {
+ "comment": "This code snippet demonstrates the process of generating an image from a text prompt using deep learning models. The decoder model is trained to convert text into embeddings that are in the same space as CLIP image embeddings. First, we load two models: Prior and Decoder. Then, we retrieve a user-inputted prompt, tokenize it, and use the Prior model to sample a text embedding in the same space as images. Finally, we pass this text embedding into the Decoder model to generate an image.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/prior.md\":21-46",
+ "content": "prompt = \"A corgi wearing sunglasses\"\ntokenized_text = tokenize(prompt)\ntext_embedding = clip_model.encode_text(tokenized_text)\n# Now, pass the text embedding to the decoder\npredicted_image = decoder.sample(text_embedding)\n```\n> **Question**: *Can you spot the issue here?*\n>\n> **Answer**: *We\u2019re trying to generate an image from a text embedding!*\nUnfortunately, we run into the issue previously mentioned--the image embeddings and the text embeddings are not interchangeable! Now let's look at a better solution\n```python\n# Load Models\nprior= Prior(checkpoint=\"prior.pth\") # A decoder trained to go from: text-> clip text emb -> clip img emb\ndecoder = Decoder(checkpoint=\"decoder.pth\") # A decoder trained on CLIP Image embeddings\n# Retrieve prompt from user and encode with a prior\nprompt = \"A corgi wearing sunglasses\"\ntokenized_text = tokenize(prompt)\ntext_embedding = prior.sample(tokenized_text) # <-- now we get an embedding in the same space as images!\n# Now, pass the predicted image embedding to the decoder\npredicted_image = decoder.sample(text_embedding)"
+ },
+ {
+ "comment": "The code demonstrates how to load a pre-trained prior model for use in generating embeddings within CLIP's image space, enhancing the performance of the decoder. The usage section outlines the necessary steps to load a checkpoint from a specific path using `load_diffusion_model()`.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/prior.md\":47-75",
+ "content": "```\nWith the prior we are able to successfully generate embeddings *within* CLIP's image space! For this reason, the decoder will perform much better as it receives input that is much closer to its training data.\n> **You may be asking yourself the following question:**\n>\n> *\"Why don't you just train the decoder on clip text embeddings instead of image embeddings?\"*\n>\n> OpenAI covers this topic in their [DALLE-2 paper](https://arxiv.org/abs/2204.06125). The TL;DR is *\"it doesn't work as well as decoders trained on image embeddings\"*...also...its just an example :smile:\n## Usage\nTo utilize a pre-trained prior, it\u2019s quite simple.\n### Loading Checkpoints\n```python\nimport torch\nfrom dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter\nfrom dalle2_pytorch.trainer import DiffusionPriorTrainer\ndef load_diffusion_model(dprior_path):\n prior_network = DiffusionPriorNetwork(\n dim=768,\n depth=24,\n dim_head=64,\n heads=32,\n normformer=True,\n attn_dropout=5e-2,"
+ },
+ {
+ "comment": "Here, a pre-trained model is instantiated and its weights are loaded. This can be done just like any other PyTorch model. To generate embeddings from text, first tokenize the input text using `clip.tokenize()`.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/prior.md\":76-118",
+ "content": " ff_dropout=5e-2,\n num_time_embeds=1,\n num_image_embeds=1,\n num_text_embeds=1,\n num_timesteps=1000,\n ff_mult=4\n )\n diffusion_prior = DiffusionPrior(\n net=prior_network,\n clip=OpenAIClipAdapter(\"ViT-L/14\"),\n image_embed_dim=768,\n timesteps=1000,\n cond_drop_prob=0.1,\n loss_type=\"l2\",\n condition_on_text_encodings=True,\n )\n trainer = DiffusionPriorTrainer(\n diffusion_prior=diffusion_prior,\n lr=1.1e-4,\n wd=6.02e-2,\n max_grad_norm=0.5,\n amp=False,\n group_wd_params=True,\n use_ema=True,\n device=device,\n accelerator=None,\n )\n trainer.load(dprior_path)\n return trainer\n```\n Here we instantiate a model matches the configuration it was trained with, and then load the weights (*just like any other PyTorch model!*)\n### Sampling\nOnce we have a pre-trained model, generating embeddings is quite simple!\n```python\n# tokenize the text\ntokenized_text = clip.tokenize(\"\")"
+ },
+ {
+ "comment": "The code snippet is predicting an embedding using the prior's sample function, which returns a tensor of the same shape as the training data. The number of embeddings to sample can be specified and conditioning scale can be adjusted for better results. It serves as a replacement for clip.encode_text() in CLIP priors.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/prior.md\":119-129",
+ "content": "# predict an embedding\npredicted_embedding = prior.sample(tokenized_text, n_samples_per_batch=2, cond_scale=1.0)\n```\nThe resulting tensor returned from `.sample()` is of the same shape as your training data along the non-batch dimension(s). For example, a prior trained on `ViT-L/14` embeddings will predict an embedding of shape (1, 768).\n> For CLIP priors, this is quite handy as it means that you can use prior.sample(tokenizer_text) as a drop in replacement for clip.encode_text().\n**Some things to note:**\n* It is possible to specify the number of embeddings to sample from (the default suggested by OpenAI is `n=2`). Put simply, the idea here is that you avoid getting unlucky with a bad embedding generation by creating two; and selecting the one with the higher cosine similarity with the prompt.\n* You may specify a higher conditioning scale than the default (`1.0`). It is unclear whether OpenAI uses a higher value for the prior specifically, or only on the decoder. Local testing has shown poor results with anything higher than `1.0` but *ymmv*."
+ },
+ {
+ "comment": "Training the prior involves preparing a dataset in the format expected by EmbeddingReader. Precomputed embeddings for images significantly increase training efficiency and are beneficial for other tasks as well. To obtain precomputed embeddings, you can use img2dataset and clip_retrieval. The configuration file enables tracking and reproducing experiments.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/prior.md\":131-145",
+ "content": "---\n## Training\n### Overview\nTraining the prior is a relatively straightforward process thanks to the Trainer base class. The major step that is required of you is preparing a dataset in the format that EmbeddingReader expects. Having pre-computed embeddings massively increases training efficiency and is generally recommended as you will likely benefit from having them on hand for other tasks as well. Once you have a dataset, you are ready to move onto configuration\n## Dataset\nTo train the prior, it is highly recommended to use precomputed embeddings for the images. To obtain these for a custom dataset, you can leverage [img2datset](https://github.com/rom1504/img2dataset) to pull images from a list of URLs and [clip_retrieval](https://github.com/rom1504/clip-retrieval#clip-inference) for generating the actual embeddings that can be used in the prior's dataloader.\n## Configuration\nThe configuration file allows for you to easily track and reproduce experiments. It is a simple JSON file that wil"
+ },
+ {
+ "comment": "This code describes the architecture, dataset, and training parameters for a specific task. It also mentions distributed training using HuggingFace's Accelerate library and various evaluation metrics available during training.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/prior.md\":145-154",
+ "content": "l specify the architecture, dataset, and training parameters. For more information and specifics please see the configuration README.\n## Distributed Training\nIf you would like to train in a distributed manner we have opted to leverage huggingface\u2019 new Accelerate library. HFA makes it extremely simple to distribute work across multiple GPU\u2019s and nodes. All that is required of you is to follow the simple CLI configuration tool [more information here](https://huggingface.co/docs/accelerate/accelerator).\n## Evaluation\nThere are a variety of metrics available to you when training the prior. You can read a brief description of each in the table below:\n| Metric | Description | Comments "
+ },
+ {
+ "comment": "This code is for calculating the validation loss associated with the online model validation process. The calculated validation loss will be used to evaluate the performance of the trained model during inference.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/prior.md\":154-156",
+ "content": " |\n| ----------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |\n| Online Model Validation | The validation loss associated "
+ },
+ {
+ "comment": "This code is discussing the usage of an Exponential Moving Average (EMA) model in a machine learning context. The EMA model's performance is compared to the online model, specifically focusing on validation loss as a metric. The lower the validation loss, the better the model's performance, with values around 0.1 achievable after billions of samples. However, the EMA validation loss might lag behind but should outperform in the long term.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/prior.md\":156-157",
+ "content": "with your online model. | Ideally validation loss will be as low as possible. Using L2 loss, values as low as `0.1` and lower are possible after around 1 Billion samples seen. |\n| EMA Validation | This metric measures the validation loss associated with your EMA model. | This will likely lag behind your \"online\" model's validation loss, but should outperform in the long-term. "
+ },
+ {
+ "comment": "This code snippet is explaining the concept of baseline similarity in the context of DALLE2-pytorch, where it refers to the similarity between dataset prompts and image embeddings. It also mentions that generally, a cosine similarity value of 0.3 is considered good for caption similarity. Additionally, there's information about another metric - similarity with original image, which measures cosine similarity between the prior's predicted image embedding and the actual image.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/prior.md\":157-159",
+ "content": " |\n| Baseline Similarity | Baseline similarity refers to the similarity between your dataset's prompts and associated image embeddings. This will serve as a guide for your prior's performance in cosine similarity. | Generally `0.3` is considered a good cosine similarity for caption similarity. |\n| Similarity With Original Image | This metric will measure the cosine similarity between your prior's predicted image embedding and the actual image"
+ },
+ {
+ "comment": "The code provides information about the similarity metric between generated images and captions, as well as the difference from baseline similarity. The values should improve rapidly in early stages of training and plateau over time, while staying around 0 for the difference metric. Values above 0.5/0.6 or climbing to high values may indicate issues with training efficiency or overfitting, respectively.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/prior.md\":159-160",
+ "content": " that the caption was associated with. This is useful for determining wether your prior is generating images with the right contents. | Values around `0.75`+ are obtainable. This metric should improve rapidly in the early stages of training and plateau with diminishing increases over time. If it takes hundreds of millions of samples to reach above `0.5`/`0.6` similarity--then you likely are suffering from some kind of training error or inefficiency (i.e. not using EMA) |\n| Difference From Baseline Similarity | Sometimes its useful to visualize a metric in another light. This metric will show you how your prior's predicted image embeddings match up with the baseline similarity measured in your dataset. | This value should float around `0.0` with some room for variation. After a billion samples seen, values are within `0.01`+/- of `0.0`. If this climbs to high, (~>`0.02`) then this may be a sign that your model is overfitting "
+ },
+ {
+ "comment": "The code measures the cosine similarity between predicted image embeddings and original captions, as well as with unrelated captions to detect overfitting. Monitoring these metrics is crucial for model performance, as they indicate how well the model is learning from captions and generating valid image embeddings.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/prior.md\":160-162",
+ "content": "somehow. |\n| Similarity With Text | This metric is your bread and butter cosine similarity between the predicted image embedding and the original caption given to the prior. Monitoring this metric will be on of your main focuses and is probably the second most important behind your loss. | As mentioned, this value should be close to baseline similarity. We have observed early rapid increase with diminishing returns as the prior learns to generate valid image embeddings. If this value increases too far beyond the baseline similarity--it could be an indication that your model is overfitting. |\n| Similarity With Unrelated Caption | This metric will attempt to exposed an overfit prior by feeding it arbitrary prompts (from your dataset) and then measure the similarity of this predicted embedding with some other image. "
+ },
+ {
+ "comment": "The code provides instructions on how to launch the training script for a diffusion model using either distributed training with HuggingFace Accelerate or without it. It also mentions that checkpoints will be saved in the directory specified in the configuration file, and an additional final checkpoint will be saved before running the test split. The prior value should ideally be kept low to avoid fooling CLIP into believing unrelated captions and images have high cosine similarity.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/prior.md\":162-174",
+ "content": " | Early on we found that a poorly trained/modeled prior could effectively fool CLIP into believing that the cosine similarity between two images were high (when in fact the caption and image were completely unrelated). With this in mind--a low value is ideal, anything below `0.1` is probably safe. |\n## Launching the script\nNow that you\u2019ve done all the prep it\u2019s time for the easy part! \ud83d\ude80\nTo actually launch the script, you will either use `accelerate launch train_diffusion_prior.py --config_path ` to launch with distributed training & huggingface accelerate or `python train_diffusion_prior.py` if you would like to train on your gpu/cpu without huggingface accelerate.\n## Checkpointing\nCheckpoints will be saved to the directory specified in your configuration file.\nAdditionally, a final checkpoint is saved before running the test split. This file will be saved to the same directory and"
+ },
+ {
+ "comment": "This code snippet is providing information about the \"latest.pth\" file and its purpose to avoid potential problems with `save_every` configuration not overlapping with data requirements. It also mentions that the prior network has not been trained for tasks other than traditional CLIP embedding translation, hinting at future experiments applying the prior network to other tasks.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/prior.md\":174-182",
+ "content": " titled \u201clatest.pth\u201d. This is to avoid problems where your `save_every` configuration does not overlap with the number of steps required to do a complete pass through the data.\n## Things To Keep In Mind\nThe prior has not been trained for tasks other than the traditional CLIP embedding translation\u2026at least yet.\nAs we finalize the replication of unCLIP, there will almost assuredly be experiments attempting to apply the prior network to other tasks.\nWith that in mind, you are more or less a pioneer in embedding-translation if you are reading this and attempting something you don\u2019t see documentation for!"
+ }
+ ]
+}
\ No newline at end of file
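The evaluation table above is described only in prose, so here is a small, self-contained sketch of how those cosine-similarity metrics could be computed. The dummy tensors, the 768-dimensional shape (ViT-L/14), and the exact reading of "Difference From Baseline Similarity" are assumptions for illustration, not the repository's implementation:

```python
# Illustrative computation of the prior evaluation metrics described above.
# Dummy unit-normalized embeddings stand in for CLIP text/image embeddings
# and for the prior's predictions; shapes assume ViT-L/14 (dim = 768).
import torch
import torch.nn.functional as F

n, d = 8, 768
text_embeds      = F.normalize(torch.randn(n, d), dim = -1)  # dataset captions
image_embeds     = F.normalize(torch.randn(n, d), dim = -1)  # ground-truth images
predicted_embeds = F.normalize(torch.randn(n, d), dim = -1)  # prior outputs

baseline_sim   = F.cosine_similarity(text_embeds, image_embeds, dim = -1).mean()
sim_with_image = F.cosine_similarity(predicted_embeds, image_embeds, dim = -1).mean()
sim_with_text  = F.cosine_similarity(predicted_embeds, text_embeds, dim = -1).mean()

# one plausible reading of "Difference From Baseline Similarity"
diff_from_baseline = sim_with_text - baseline_sim

# "Similarity With Unrelated Caption": break the caption/image pairing
unrelated_sim = F.cosine_similarity(
    predicted_embeds, text_embeds[torch.randperm(n)], dim = -1
).mean()
```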
diff --git a/docs/doc/ef6bd6dd-2706-4200-811d-fd277410e29c.json b/docs/doc/ef6bd6dd-2706-4200-811d-fd277410e29c.json
new file mode 100644
index 00000000..fcd4eef7
--- /dev/null
+++ b/docs/doc/ef6bd6dd-2706-4200-811d-fd277410e29c.json
@@ -0,0 +1,10 @@
+{
+ "summary": "Install the updated pip and then install the project in editable mode. Then, run tests for a specific decoder configuration using CUDA visible devices and a provided JSON file.",
+ "details": [
+ {
+ "comment": "Install the updated pip and then install the project in editable mode. Then, run tests for a specific decoder configuration using CUDA visible devices and a provided JSON file.",
+ "location": "\"/media/root/Toshiba XG3/works/DALLE2-pytorch/docs/src/Makefile\":0-5",
+ "content": "install:\n\tpip install -U pip\n\tpip install -e .\ntest:\n\tCUDA_VISIBLE_DEVICES= python train_decoder.py --config_file configs/train_decoder_config.test.json"
+ }
+ ]
+}
\ No newline at end of file
diff --git a/docs/github-markdown.css b/docs/github-markdown.css
new file mode 100644
index 00000000..96a4f29e
--- /dev/null
+++ b/docs/github-markdown.css
@@ -0,0 +1,1197 @@
+@media (prefers-color-scheme: dark) {
+
+ .markdown-body,
+ [data-theme="dark"] {
+ /*dark*/
+ color-scheme: dark;
+ --color-prettylights-syntax-comment: #8b949e;
+ --color-prettylights-syntax-constant: #79c0ff;
+ --color-prettylights-syntax-entity: #d2a8ff;
+ --color-prettylights-syntax-storage-modifier-import: #c9d1d9;
+ --color-prettylights-syntax-entity-tag: #7ee787;
+ --color-prettylights-syntax-keyword: #ff7b72;
+ --color-prettylights-syntax-string: #a5d6ff;
+ --color-prettylights-syntax-variable: #ffa657;
+ --color-prettylights-syntax-brackethighlighter-unmatched: #f85149;
+ --color-prettylights-syntax-invalid-illegal-text: #f0f6fc;
+ --color-prettylights-syntax-invalid-illegal-bg: #8e1519;
+ --color-prettylights-syntax-carriage-return-text: #f0f6fc;
+ --color-prettylights-syntax-carriage-return-bg: #b62324;
+ --color-prettylights-syntax-string-regexp: #7ee787;
+ --color-prettylights-syntax-markup-list: #f2cc60;
+ --color-prettylights-syntax-markup-heading: #1f6feb;
+ --color-prettylights-syntax-markup-italic: #c9d1d9;
+ --color-prettylights-syntax-markup-bold: #c9d1d9;
+ --color-prettylights-syntax-markup-deleted-text: #ffdcd7;
+ --color-prettylights-syntax-markup-deleted-bg: #67060c;
+ --color-prettylights-syntax-markup-inserted-text: #aff5b4;
+ --color-prettylights-syntax-markup-inserted-bg: #033a16;
+ --color-prettylights-syntax-markup-changed-text: #ffdfb6;
+ --color-prettylights-syntax-markup-changed-bg: #5a1e02;
+ --color-prettylights-syntax-markup-ignored-text: #c9d1d9;
+ --color-prettylights-syntax-markup-ignored-bg: #1158c7;
+ --color-prettylights-syntax-meta-diff-range: #d2a8ff;
+ --color-prettylights-syntax-brackethighlighter-angle: #8b949e;
+ --color-prettylights-syntax-sublimelinter-gutter-mark: #484f58;
+ --color-prettylights-syntax-constant-other-reference-link: #a5d6ff;
+ --color-fg-default: #e6edf3;
+ --color-fg-muted: #848d97;
+ --color-fg-subtle: #6e7681;
+ --color-canvas-default: #0d1117;
+ --color-canvas-subtle: #161b22;
+ --color-border-default: #30363d;
+ --color-border-muted: #21262d;
+ --color-neutral-muted: rgba(110, 118, 129, 0.4);
+ --color-accent-fg: #2f81f7;
+ --color-accent-emphasis: #1f6feb;
+ --color-success-fg: #3fb950;
+ --color-success-emphasis: #238636;
+ --color-attention-fg: #d29922;
+ --color-attention-emphasis: #9e6a03;
+ --color-attention-subtle: rgba(187, 128, 9, 0.15);
+ --color-danger-fg: #f85149;
+ --color-danger-emphasis: #da3633;
+ --color-done-fg: #a371f7;
+ --color-done-emphasis: #8957e5;
+ }
+}
+
+@media (prefers-color-scheme: light) {
+
+ .markdown-body,
+ [data-theme="light"] {
+ /*light*/
+ color-scheme: light;
+ --color-prettylights-syntax-comment: #57606a;
+ --color-prettylights-syntax-constant: #0550ae;
+ --color-prettylights-syntax-entity: #6639ba;
+ --color-prettylights-syntax-storage-modifier-import: #24292f;
+ --color-prettylights-syntax-entity-tag: #116329;
+ --color-prettylights-syntax-keyword: #cf222e;
+ --color-prettylights-syntax-string: #0a3069;
+ --color-prettylights-syntax-variable: #953800;
+ --color-prettylights-syntax-brackethighlighter-unmatched: #82071e;
+ --color-prettylights-syntax-invalid-illegal-text: #f6f8fa;
+ --color-prettylights-syntax-invalid-illegal-bg: #82071e;
+ --color-prettylights-syntax-carriage-return-text: #f6f8fa;
+ --color-prettylights-syntax-carriage-return-bg: #cf222e;
+ --color-prettylights-syntax-string-regexp: #116329;
+ --color-prettylights-syntax-markup-list: #3b2300;
+ --color-prettylights-syntax-markup-heading: #0550ae;
+ --color-prettylights-syntax-markup-italic: #24292f;
+ --color-prettylights-syntax-markup-bold: #24292f;
+ --color-prettylights-syntax-markup-deleted-text: #82071e;
+ --color-prettylights-syntax-markup-deleted-bg: #ffebe9;
+ --color-prettylights-syntax-markup-inserted-text: #116329;
+ --color-prettylights-syntax-markup-inserted-bg: #dafbe1;
+ --color-prettylights-syntax-markup-changed-text: #953800;
+ --color-prettylights-syntax-markup-changed-bg: #ffd8b5;
+ --color-prettylights-syntax-markup-ignored-text: #eaeef2;
+ --color-prettylights-syntax-markup-ignored-bg: #0550ae;
+ --color-prettylights-syntax-meta-diff-range: #8250df;
+ --color-prettylights-syntax-brackethighlighter-angle: #57606a;
+ --color-prettylights-syntax-sublimelinter-gutter-mark: #8c959f;
+ --color-prettylights-syntax-constant-other-reference-link: #0a3069;
+ --color-fg-default: #1F2328;
+ --color-fg-muted: #656d76;
+ --color-fg-subtle: #6e7781;
+ --color-canvas-default: #ffffff;
+ --color-canvas-subtle: #f6f8fa;
+ --color-border-default: #d0d7de;
+ --color-border-muted: hsla(210, 18%, 87%, 1);
+ --color-neutral-muted: rgba(175, 184, 193, 0.2);
+ --color-accent-fg: #0969da;
+ --color-accent-emphasis: #0969da;
+ --color-success-fg: #1a7f37;
+ --color-success-emphasis: #1f883d;
+ --color-attention-fg: #9a6700;
+ --color-attention-emphasis: #9a6700;
+ --color-attention-subtle: #fff8c5;
+ --color-danger-fg: #d1242f;
+ --color-danger-emphasis: #cf222e;
+ --color-done-fg: #8250df;
+ --color-done-emphasis: #8250df;
+ }
+}
+
+.markdown-body {
+ -ms-text-size-adjust: 100%;
+ -webkit-text-size-adjust: 100%;
+ margin: 0;
+ color: var(--color-fg-default);
+ background-color: var(--color-canvas-default);
+ font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", "Noto Sans", Helvetica, Arial, sans-serif, "Apple Color Emoji", "Segoe UI Emoji";
+ font-size: 16px;
+ line-height: 1.5;
+ word-wrap: break-word;
+}
+
+.markdown-body .octicon {
+ display: inline-block;
+ fill: currentColor;
+ vertical-align: text-bottom;
+}
+
+.markdown-body h1:hover .anchor .octicon-link:before,
+.markdown-body h2:hover .anchor .octicon-link:before,
+.markdown-body h3:hover .anchor .octicon-link:before,
+.markdown-body h4:hover .anchor .octicon-link:before,
+.markdown-body h5:hover .anchor .octicon-link:before,
+.markdown-body h6:hover .anchor .octicon-link:before {
+ width: 16px;
+ height: 16px;
+ content: ' ';
+ display: inline-block;
+ background-color: currentColor;
+ -webkit-mask-image: url("data:image/svg+xml,");
+ mask-image: url("data:image/svg+xml,");
+}
+
+.markdown-body details,
+.markdown-body figcaption,
+.markdown-body figure {
+ display: block;
+}
+
+.markdown-body summary {
+ display: list-item;
+}
+
+.markdown-body [hidden] {
+ display: none !important;
+}
+
+.markdown-body a {
+ background-color: transparent;
+ color: var(--color-accent-fg);
+ text-decoration: none;
+}
+
+.markdown-body abbr[title] {
+ border-bottom: none;
+ -webkit-text-decoration: underline dotted;
+ text-decoration: underline dotted;
+}
+
+.markdown-body b,
+.markdown-body strong {
+ font-weight: var(--base-text-weight-semibold, 600);
+}
+
+.markdown-body dfn {
+ font-style: italic;
+}
+
+.markdown-body h1 {
+ margin: .67em 0;
+ font-weight: var(--base-text-weight-semibold, 600);
+ padding-bottom: .3em;
+ font-size: 2em;
+ border-bottom: 1px solid var(--color-border-muted);
+}
+
+.markdown-body mark {
+ background-color: var(--color-attention-subtle);
+ color: var(--color-fg-default);
+}
+
+.markdown-body small {
+ font-size: 90%;
+}
+
+.markdown-body sub,
+.markdown-body sup {
+ font-size: 75%;
+ line-height: 0;
+ position: relative;
+ vertical-align: baseline;
+}
+
+.markdown-body sub {
+ bottom: -0.25em;
+}
+
+.markdown-body sup {
+ top: -0.5em;
+}
+
+.markdown-body img {
+ border-style: none;
+ max-width: 100%;
+ box-sizing: content-box;
+ background-color: var(--color-canvas-default);
+}
+
+.markdown-body code,
+.markdown-body kbd,
+.markdown-body pre,
+.markdown-body samp {
+ font-family: monospace;
+ font-size: 1em;
+}
+
+.markdown-body figure {
+ margin: 1em 40px;
+}
+
+.markdown-body hr {
+ box-sizing: content-box;
+ overflow: hidden;
+ background: transparent;
+ border-bottom: 1px solid var(--color-border-muted);
+ height: .25em;
+ padding: 0;
+ margin: 24px 0;
+ background-color: var(--color-border-default);
+ border: 0;
+}
+
+.markdown-body input {
+ font: inherit;
+ margin: 0;
+ overflow: visible;
+ font-family: inherit;
+ font-size: inherit;
+ line-height: inherit;
+}
+
+.markdown-body [type=button],
+.markdown-body [type=reset],
+.markdown-body [type=submit] {
+ -webkit-appearance: button;
+ appearance: button;
+}
+
+.markdown-body [type=checkbox],
+.markdown-body [type=radio] {
+ box-sizing: border-box;
+ padding: 0;
+}
+
+.markdown-body [type=number]::-webkit-inner-spin-button,
+.markdown-body [type=number]::-webkit-outer-spin-button {
+ height: auto;
+}
+
+.markdown-body [type=search]::-webkit-search-cancel-button,
+.markdown-body [type=search]::-webkit-search-decoration {
+ -webkit-appearance: none;
+ appearance: none;
+}
+
+.markdown-body ::-webkit-input-placeholder {
+ color: inherit;
+ opacity: .54;
+}
+
+.markdown-body ::-webkit-file-upload-button {
+ -webkit-appearance: button;
+ appearance: button;
+ font: inherit;
+}
+
+.markdown-body a:hover {
+ text-decoration: underline;
+}
+
+.markdown-body ::placeholder {
+ color: var(--color-fg-subtle);
+ opacity: 1;
+}
+
+.markdown-body hr::before {
+ display: table;
+ content: "";
+}
+
+.markdown-body hr::after {
+ display: table;
+ clear: both;
+ content: "";
+}
+
+.markdown-body table {
+ border-spacing: 0;
+ border-collapse: collapse;
+ display: block;
+ width: max-content;
+ max-width: 100%;
+ overflow: auto;
+}
+
+.markdown-body td,
+.markdown-body th {
+ padding: 0;
+}
+
+.markdown-body details summary {
+ cursor: pointer;
+}
+
+.markdown-body details:not([open])>*:not(summary) {
+ display: none !important;
+}
+
+.markdown-body a:focus,
+.markdown-body [role=button]:focus,
+.markdown-body input[type=radio]:focus,
+.markdown-body input[type=checkbox]:focus {
+ outline: 2px solid var(--color-accent-fg);
+ outline-offset: -2px;
+ box-shadow: none;
+}
+
+.markdown-body a:focus:not(:focus-visible),
+.markdown-body [role=button]:focus:not(:focus-visible),
+.markdown-body input[type=radio]:focus:not(:focus-visible),
+.markdown-body input[type=checkbox]:focus:not(:focus-visible) {
+ outline: solid 1px transparent;
+}
+
+.markdown-body a:focus-visible,
+.markdown-body [role=button]:focus-visible,
+.markdown-body input[type=radio]:focus-visible,
+.markdown-body input[type=checkbox]:focus-visible {
+ outline: 2px solid var(--color-accent-fg);
+ outline-offset: -2px;
+ box-shadow: none;
+}
+
+.markdown-body a:not([class]):focus,
+.markdown-body a:not([class]):focus-visible,
+.markdown-body input[type=radio]:focus,
+.markdown-body input[type=radio]:focus-visible,
+.markdown-body input[type=checkbox]:focus,
+.markdown-body input[type=checkbox]:focus-visible {
+ outline-offset: 0;
+}
+
+.markdown-body kbd {
+ display: inline-block;
+ padding: 3px 5px;
+ font: 11px ui-monospace, SFMono-Regular, SF Mono, Menlo, Consolas, Liberation Mono, monospace;
+ line-height: 10px;
+ color: var(--color-fg-default);
+ vertical-align: middle;
+ background-color: var(--color-canvas-subtle);
+ border: solid 1px var(--color-neutral-muted);
+ border-bottom-color: var(--color-neutral-muted);
+ border-radius: 6px;
+ box-shadow: inset 0 -1px 0 var(--color-neutral-muted);
+}
+
+.markdown-body h1,
+.markdown-body h2,
+.markdown-body h3,
+.markdown-body h4,
+.markdown-body h5,
+.markdown-body h6 {
+ margin-top: 24px;
+ margin-bottom: 16px;
+ font-weight: var(--base-text-weight-semibold, 600);
+ line-height: 1.25;
+}
+
+.markdown-body h2 {
+ font-weight: var(--base-text-weight-semibold, 600);
+ padding-bottom: .3em;
+ font-size: 1.5em;
+ border-bottom: 1px solid var(--color-border-muted);
+}
+
+.markdown-body h3 {
+ font-weight: var(--base-text-weight-semibold, 600);
+ font-size: 1.25em;
+}
+
+.markdown-body h4 {
+ font-weight: var(--base-text-weight-semibold, 600);
+ font-size: 1em;
+}
+
+.markdown-body h5 {
+ font-weight: var(--base-text-weight-semibold, 600);
+ font-size: .875em;
+}
+
+.markdown-body h6 {
+ font-weight: var(--base-text-weight-semibold, 600);
+ font-size: .85em;
+ color: var(--color-fg-muted);
+}
+
+.markdown-body p {
+ margin-top: 0;
+ margin-bottom: 10px;
+}
+
+.markdown-body blockquote {
+ margin: 0;
+ padding: 0 1em;
+ color: var(--color-fg-muted);
+ border-left: .25em solid var(--color-border-default);
+}
+
+.markdown-body ul,
+.markdown-body ol {
+ margin-top: 0;
+ margin-bottom: 0;
+ padding-left: 2em;
+}
+
+.markdown-body ol ol,
+.markdown-body ul ol {
+ list-style-type: lower-roman;
+}
+
+.markdown-body ul ul ol,
+.markdown-body ul ol ol,
+.markdown-body ol ul ol,
+.markdown-body ol ol ol {
+ list-style-type: lower-alpha;
+}
+
+.markdown-body dd {
+ margin-left: 0;
+}
+
+.markdown-body tt,
+.markdown-body code,
+.markdown-body samp {
+ font-family: ui-monospace, SFMono-Regular, SF Mono, Menlo, Consolas, Liberation Mono, monospace;
+ font-size: 12px;
+}
+
+.markdown-body pre {
+ margin-top: 0;
+ margin-bottom: 0;
+ font-family: ui-monospace, SFMono-Regular, SF Mono, Menlo, Consolas, Liberation Mono, monospace;
+ font-size: 12px;
+ word-wrap: normal;
+}
+
+.markdown-body .octicon {
+ display: inline-block;
+ overflow: visible !important;
+ vertical-align: text-bottom;
+ fill: currentColor;
+}
+
+.markdown-body input::-webkit-outer-spin-button,
+.markdown-body input::-webkit-inner-spin-button {
+ margin: 0;
+ -webkit-appearance: none;
+ appearance: none;
+}
+
+.markdown-body .mr-2 {
+ margin-right: var(--base-size-8, 8px) !important;
+}
+
+.markdown-body::before {
+ display: table;
+ content: "";
+}
+
+.markdown-body::after {
+ display: table;
+ clear: both;
+ content: "";
+}
+
+.markdown-body>*:first-child {
+ margin-top: 0 !important;
+}
+
+.markdown-body>*:last-child {
+ margin-bottom: 0 !important;
+}
+
+.markdown-body a:not([href]) {
+ color: inherit;
+ text-decoration: none;
+}
+
+.markdown-body .absent {
+ color: var(--color-danger-fg);
+}
+
+.markdown-body .anchor {
+ float: left;
+ padding-right: 4px;
+ margin-left: -20px;
+ line-height: 1;
+}
+
+.markdown-body .anchor:focus {
+ outline: none;
+}
+
+.markdown-body p,
+.markdown-body blockquote,
+.markdown-body ul,
+.markdown-body ol,
+.markdown-body dl,
+.markdown-body table,
+.markdown-body pre,
+.markdown-body details {
+ margin-top: 0;
+ margin-bottom: 16px;
+}
+
+.markdown-body blockquote>:first-child {
+ margin-top: 0;
+}
+
+.markdown-body blockquote>:last-child {
+ margin-bottom: 0;
+}
+
+.markdown-body h1 .octicon-link,
+.markdown-body h2 .octicon-link,
+.markdown-body h3 .octicon-link,
+.markdown-body h4 .octicon-link,
+.markdown-body h5 .octicon-link,
+.markdown-body h6 .octicon-link {
+ color: var(--color-fg-default);
+ vertical-align: middle;
+ visibility: hidden;
+}
+
+.markdown-body h1:hover .anchor,
+.markdown-body h2:hover .anchor,
+.markdown-body h3:hover .anchor,
+.markdown-body h4:hover .anchor,
+.markdown-body h5:hover .anchor,
+.markdown-body h6:hover .anchor {
+ text-decoration: none;
+}
+
+.markdown-body h1:hover .anchor .octicon-link,
+.markdown-body h2:hover .anchor .octicon-link,
+.markdown-body h3:hover .anchor .octicon-link,
+.markdown-body h4:hover .anchor .octicon-link,
+.markdown-body h5:hover .anchor .octicon-link,
+.markdown-body h6:hover .anchor .octicon-link {
+ visibility: visible;
+}
+
+.markdown-body h1 tt,
+.markdown-body h1 code,
+.markdown-body h2 tt,
+.markdown-body h2 code,
+.markdown-body h3 tt,
+.markdown-body h3 code,
+.markdown-body h4 tt,
+.markdown-body h4 code,
+.markdown-body h5 tt,
+.markdown-body h5 code,
+.markdown-body h6 tt,
+.markdown-body h6 code {
+ padding: 0 .2em;
+ font-size: inherit;
+}
+
+.markdown-body summary h1,
+.markdown-body summary h2,
+.markdown-body summary h3,
+.markdown-body summary h4,
+.markdown-body summary h5,
+.markdown-body summary h6 {
+ display: inline-block;
+}
+
+.markdown-body summary h1 .anchor,
+.markdown-body summary h2 .anchor,
+.markdown-body summary h3 .anchor,
+.markdown-body summary h4 .anchor,
+.markdown-body summary h5 .anchor,
+.markdown-body summary h6 .anchor {
+ margin-left: -40px;
+}
+
+.markdown-body summary h1,
+.markdown-body summary h2 {
+ padding-bottom: 0;
+ border-bottom: 0;
+}
+
+.markdown-body ul.no-list,
+.markdown-body ol.no-list {
+ padding: 0;
+ list-style-type: none;
+}
+
+.markdown-body ol[type="a s"] {
+ list-style-type: lower-alpha;
+}
+
+.markdown-body ol[type="A s"] {
+ list-style-type: upper-alpha;
+}
+
+.markdown-body ol[type="i s"] {
+ list-style-type: lower-roman;
+}
+
+.markdown-body ol[type="I s"] {
+ list-style-type: upper-roman;
+}
+
+.markdown-body ol[type="1"] {
+ list-style-type: decimal;
+}
+
+.markdown-body div>ol:not([type]) {
+ list-style-type: decimal;
+}
+
+.markdown-body ul ul,
+.markdown-body ul ol,
+.markdown-body ol ol,
+.markdown-body ol ul {
+ margin-top: 0;
+ margin-bottom: 0;
+}
+
+.markdown-body li>p {
+ margin-top: 16px;
+}
+
+.markdown-body li+li {
+ margin-top: .25em;
+}
+
+.markdown-body dl {
+ padding: 0;
+}
+
+.markdown-body dl dt {
+ padding: 0;
+ margin-top: 16px;
+ font-size: 1em;
+ font-style: italic;
+ font-weight: var(--base-text-weight-semibold, 600);
+}
+
+.markdown-body dl dd {
+ padding: 0 16px;
+ margin-bottom: 16px;
+}
+
+.markdown-body table th {
+ font-weight: var(--base-text-weight-semibold, 600);
+}
+
+.markdown-body table th,
+.markdown-body table td {
+ padding: 6px 13px;
+ border: 1px solid var(--color-border-default);
+}
+
+.markdown-body table td>:last-child {
+ margin-bottom: 0;
+}
+
+.markdown-body table tr {
+ background-color: var(--color-canvas-default);
+ border-top: 1px solid var(--color-border-muted);
+}
+
+.markdown-body table tr:nth-child(2n) {
+ background-color: var(--color-canvas-subtle);
+}
+
+.markdown-body table img {
+ background-color: transparent;
+}
+
+.markdown-body img[align=right] {
+ padding-left: 20px;
+}
+
+.markdown-body img[align=left] {
+ padding-right: 20px;
+}
+
+.markdown-body .emoji {
+ max-width: none;
+ vertical-align: text-top;
+ background-color: transparent;
+}
+
+.markdown-body span.frame {
+ display: block;
+ overflow: hidden;
+}
+
+.markdown-body span.frame>span {
+ display: block;
+ float: left;
+ width: auto;
+ padding: 7px;
+ margin: 13px 0 0;
+ overflow: hidden;
+ border: 1px solid var(--color-border-default);
+}
+
+.markdown-body span.frame span img {
+ display: block;
+ float: left;
+}
+
+.markdown-body span.frame span span {
+ display: block;
+ padding: 5px 0 0;
+ clear: both;
+ color: var(--color-fg-default);
+}
+
+.markdown-body span.align-center {
+ display: block;
+ overflow: hidden;
+ clear: both;
+}
+
+.markdown-body span.align-center>span {
+ display: block;
+ margin: 13px auto 0;
+ overflow: hidden;
+ text-align: center;
+}
+
+.markdown-body span.align-center span img {
+ margin: 0 auto;
+ text-align: center;
+}
+
+.markdown-body span.align-right {
+ display: block;
+ overflow: hidden;
+ clear: both;
+}
+
+.markdown-body span.align-right>span {
+ display: block;
+ margin: 13px 0 0;
+ overflow: hidden;
+ text-align: right;
+}
+
+.markdown-body span.align-right span img {
+ margin: 0;
+ text-align: right;
+}
+
+.markdown-body span.float-left {
+ display: block;
+ float: left;
+ margin-right: 13px;
+ overflow: hidden;
+}
+
+.markdown-body span.float-left span {
+ margin: 13px 0 0;
+}
+
+.markdown-body span.float-right {
+ display: block;
+ float: right;
+ margin-left: 13px;
+ overflow: hidden;
+}
+
+.markdown-body span.float-right>span {
+ display: block;
+ margin: 13px auto 0;
+ overflow: hidden;
+ text-align: right;
+}
+
+.markdown-body code,
+.markdown-body tt {
+ padding: .2em .4em;
+ margin: 0;
+ font-size: 85%;
+ white-space: break-spaces;
+ background-color: var(--color-neutral-muted);
+ border-radius: 6px;
+}
+
+.markdown-body code br,
+.markdown-body tt br {
+ display: none;
+}
+
+.markdown-body del code {
+ text-decoration: inherit;
+}
+
+.markdown-body samp {
+ font-size: 85%;
+}
+
+.markdown-body pre code {
+ font-size: 100%;
+}
+
+.markdown-body pre>code {
+ padding: 0;
+ margin: 0;
+ word-break: normal;
+ white-space: pre;
+ background: transparent;
+ border: 0;
+}
+
+.markdown-body .highlight {
+ margin-bottom: 16px;
+}
+
+.markdown-body .highlight pre {
+ margin-bottom: 0;
+ word-break: normal;
+}
+
+.markdown-body .highlight pre,
+.markdown-body pre {
+ padding: 16px;
+ overflow: auto;
+ font-size: 85%;
+ line-height: 1.45;
+ color: var(--color-fg-default);
+ background-color: var(--color-canvas-subtle);
+ border-radius: 6px;
+}
+
+.markdown-body pre code,
+.markdown-body pre tt {
+ display: inline;
+ max-width: auto;
+ padding: 0;
+ margin: 0;
+ overflow: visible;
+ line-height: inherit;
+ word-wrap: normal;
+ background-color: transparent;
+ border: 0;
+}
+
+.markdown-body .csv-data td,
+.markdown-body .csv-data th {
+ padding: 5px;
+ overflow: hidden;
+ font-size: 12px;
+ line-height: 1;
+ text-align: left;
+ white-space: nowrap;
+}
+
+.markdown-body .csv-data .blob-num {
+ padding: 10px 8px 9px;
+ text-align: right;
+ background: var(--color-canvas-default);
+ border: 0;
+}
+
+.markdown-body .csv-data tr {
+ border-top: 0;
+}
+
+.markdown-body .csv-data th {
+ font-weight: var(--base-text-weight-semibold, 600);
+ background: var(--color-canvas-subtle);
+ border-top: 0;
+}
+
+.markdown-body [data-footnote-ref]::before {
+ content: "[";
+}
+
+.markdown-body [data-footnote-ref]::after {
+ content: "]";
+}
+
+.markdown-body .footnotes {
+ font-size: 12px;
+ color: var(--color-fg-muted);
+ border-top: 1px solid var(--color-border-default);
+}
+
+.markdown-body .footnotes ol {
+ padding-left: 16px;
+}
+
+.markdown-body .footnotes ol ul {
+ display: inline-block;
+ padding-left: 16px;
+ margin-top: 16px;
+}
+
+.markdown-body .footnotes li {
+ position: relative;
+}
+
+.markdown-body .footnotes li:target::before {
+ position: absolute;
+ top: -8px;
+ right: -8px;
+ bottom: -8px;
+ left: -24px;
+ pointer-events: none;
+ content: "";
+ border: 2px solid var(--color-accent-emphasis);
+ border-radius: 6px;
+}
+
+.markdown-body .footnotes li:target {
+ color: var(--color-fg-default);
+}
+
+.markdown-body .footnotes .data-footnote-backref g-emoji {
+ font-family: monospace;
+}
+
+.markdown-body .pl-c {
+ color: var(--color-prettylights-syntax-comment);
+}
+
+.markdown-body .pl-c1,
+.markdown-body .pl-s .pl-v {
+ color: var(--color-prettylights-syntax-constant);
+}
+
+.markdown-body .pl-e,
+.markdown-body .pl-en {
+ color: var(--color-prettylights-syntax-entity);
+}
+
+.markdown-body .pl-smi,
+.markdown-body .pl-s .pl-s1 {
+ color: var(--color-prettylights-syntax-storage-modifier-import);
+}
+
+.markdown-body .pl-ent {
+ color: var(--color-prettylights-syntax-entity-tag);
+}
+
+.markdown-body .pl-k {
+ color: var(--color-prettylights-syntax-keyword);
+}
+
+.markdown-body .pl-s,
+.markdown-body .pl-pds,
+.markdown-body .pl-s .pl-pse .pl-s1,
+.markdown-body .pl-sr,
+.markdown-body .pl-sr .pl-cce,
+.markdown-body .pl-sr .pl-sre,
+.markdown-body .pl-sr .pl-sra {
+ color: var(--color-prettylights-syntax-string);
+}
+
+.markdown-body .pl-v,
+.markdown-body .pl-smw {
+ color: var(--color-prettylights-syntax-variable);
+}
+
+.markdown-body .pl-bu {
+ color: var(--color-prettylights-syntax-brackethighlighter-unmatched);
+}
+
+.markdown-body .pl-ii {
+ color: var(--color-prettylights-syntax-invalid-illegal-text);
+ background-color: var(--color-prettylights-syntax-invalid-illegal-bg);
+}
+
+.markdown-body .pl-c2 {
+ color: var(--color-prettylights-syntax-carriage-return-text);
+ background-color: var(--color-prettylights-syntax-carriage-return-bg);
+}
+
+.markdown-body .pl-sr .pl-cce {
+ font-weight: bold;
+ color: var(--color-prettylights-syntax-string-regexp);
+}
+
+.markdown-body .pl-ml {
+ color: var(--color-prettylights-syntax-markup-list);
+}
+
+.markdown-body .pl-mh,
+.markdown-body .pl-mh .pl-en,
+.markdown-body .pl-ms {
+ font-weight: bold;
+ color: var(--color-prettylights-syntax-markup-heading);
+}
+
+.markdown-body .pl-mi {
+ font-style: italic;
+ color: var(--color-prettylights-syntax-markup-italic);
+}
+
+.markdown-body .pl-mb {
+ font-weight: bold;
+ color: var(--color-prettylights-syntax-markup-bold);
+}
+
+.markdown-body .pl-md {
+ color: var(--color-prettylights-syntax-markup-deleted-text);
+ background-color: var(--color-prettylights-syntax-markup-deleted-bg);
+}
+
+.markdown-body .pl-mi1 {
+ color: var(--color-prettylights-syntax-markup-inserted-text);
+ background-color: var(--color-prettylights-syntax-markup-inserted-bg);
+}
+
+.markdown-body .pl-mc {
+ color: var(--color-prettylights-syntax-markup-changed-text);
+ background-color: var(--color-prettylights-syntax-markup-changed-bg);
+}
+
+.markdown-body .pl-mi2 {
+ color: var(--color-prettylights-syntax-markup-ignored-text);
+ background-color: var(--color-prettylights-syntax-markup-ignored-bg);
+}
+
+.markdown-body .pl-mdr {
+ font-weight: bold;
+ color: var(--color-prettylights-syntax-meta-diff-range);
+}
+
+.markdown-body .pl-ba {
+ color: var(--color-prettylights-syntax-brackethighlighter-angle);
+}
+
+.markdown-body .pl-sg {
+ color: var(--color-prettylights-syntax-sublimelinter-gutter-mark);
+}
+
+.markdown-body .pl-corl {
+ text-decoration: underline;
+ color: var(--color-prettylights-syntax-constant-other-reference-link);
+}
+
+.markdown-body g-emoji {
+ display: inline-block;
+ min-width: 1ch;
+ font-family: "Apple Color Emoji", "Segoe UI Emoji", "Segoe UI Symbol";
+ font-size: 1em;
+ font-style: normal !important;
+ font-weight: var(--base-text-weight-normal, 400);
+ line-height: 1;
+ vertical-align: -0.075em;
+}
+
+.markdown-body g-emoji img {
+ width: 1em;
+ height: 1em;
+}
+
+.markdown-body .task-list-item {
+ list-style-type: none;
+}
+
+.markdown-body .task-list-item label {
+ font-weight: var(--base-text-weight-normal, 400);
+}
+
+.markdown-body .task-list-item.enabled label {
+ cursor: pointer;
+}
+
+.markdown-body .task-list-item+.task-list-item {
+ margin-top: 4px;
+}
+
+.markdown-body .task-list-item .handle {
+ display: none;
+}
+
+.markdown-body .task-list-item-checkbox {
+ margin: 0 .2em .25em -1.4em;
+ vertical-align: middle;
+}
+
+.markdown-body .contains-task-list:dir(rtl) .task-list-item-checkbox {
+ margin: 0 -1.6em .25em .2em;
+}
+
+.markdown-body .contains-task-list {
+ position: relative;
+}
+
+.markdown-body .contains-task-list:hover .task-list-item-convert-container,
+.markdown-body .contains-task-list:focus-within .task-list-item-convert-container {
+ display: block;
+ width: auto;
+ height: 24px;
+ overflow: visible;
+ clip: auto;
+}
+
+.markdown-body ::-webkit-calendar-picker-indicator {
+ filter: invert(50%);
+}
+
+.markdown-body .markdown-alert {
+ padding: var(--base-size-8) var(--base-size-16);
+ margin-bottom: 16px;
+ color: inherit;
+ border-left: .25em solid var(--color-border-default);
+}
+
+.markdown-body .markdown-alert>:first-child {
+ margin-top: 0;
+}
+
+.markdown-body .markdown-alert>:last-child {
+ margin-bottom: 0;
+}
+
+.markdown-body .markdown-alert .markdown-alert-title {
+ display: flex;
+ font-weight: var(--base-text-weight-medium, 500);
+ align-items: center;
+ line-height: 1;
+}
+
+.markdown-body .markdown-alert.markdown-alert-note {
+ border-left-color: var(--color-accent-emphasis);
+}
+
+.markdown-body .markdown-alert.markdown-alert-note .markdown-alert-title {
+ color: var(--color-accent-fg);
+}
+
+.markdown-body .markdown-alert.markdown-alert-important {
+ border-left-color: var(--color-done-emphasis);
+}
+
+.markdown-body .markdown-alert.markdown-alert-important .markdown-alert-title {
+ color: var(--color-done-fg);
+}
+
+.markdown-body .markdown-alert.markdown-alert-warning {
+ border-left-color: var(--color-attention-emphasis);
+}
+
+.markdown-body .markdown-alert.markdown-alert-warning .markdown-alert-title {
+ color: var(--color-attention-fg);
+}
+
+.markdown-body .markdown-alert.markdown-alert-tip {
+ border-left-color: var(--color-success-emphasis);
+}
+
+.markdown-body .markdown-alert.markdown-alert-tip .markdown-alert-title {
+ color: var(--color-success-fg);
+}
+
+.markdown-body .markdown-alert.markdown-alert-caution {
+ border-left-color: var(--color-danger-emphasis);
+}
+
+.markdown-body .markdown-alert.markdown-alert-caution .markdown-alert-title {
+ color: var(--color-danger-fg);
+}
\ No newline at end of file
diff --git a/docs/index.html b/docs/index.html
new file mode 100644
index 00000000..d1154b4d
--- /dev/null
+++ b/docs/index.html
@@ -0,0 +1,1250 @@
+<!-- docs/index.html: generated "Search Code By Comment" viewer page for the DALLE2-pytorch documentation. -->
+<!-- The original ~1250 lines of markup, styles, and scripts are not recoverable here; the only surviving -->
+<!-- page text is the title "Search Code By Comment" and the heading "Document index of:". -->
\ No newline at end of file
diff --git a/docs/metadata.json b/docs/metadata.json
new file mode 100644
index 00000000..87cc9d07
--- /dev/null
+++ b/docs/metadata.json
@@ -0,0 +1,135 @@
+{
+ "url": {
+ "full": "https://github.com/lucidrains/DALLE2-pytorch",
+ "partial": "lucidrains/DALLE2-pytorch"
+ },
+ "file_mapping": {
+ "0": {
+ "filepath": "/MANIFEST.in",
+ "entry_id": 0,
+ "language_id": "text"
+ },
+ "1": {
+ "filepath": "/Makefile",
+ "entry_id": 4,
+ "language_id": "makefile"
+ },
+ "2": {
+ "filepath": "/README.md",
+ "entry_id": 8,
+ "language_id": "markdown"
+ },
+ "3": {
+ "filepath": "/configs/README.md",
+ "entry_id": 96,
+ "language_id": "markdown"
+ },
+ "4": {
+ "filepath": "/dalle2_pytorch/__init__.py",
+ "entry_id": 122,
+ "language_id": "python"
+ },
+ "5": {
+ "filepath": "/dalle2_pytorch/cli.py",
+ "entry_id": 126,
+ "language_id": "python"
+ },
+ "6": {
+ "filepath": "/dalle2_pytorch/dalle2_pytorch.py",
+ "entry_id": 132,
+ "language_id": "python"
+ },
+ "7": {
+ "filepath": "/dalle2_pytorch/dataloaders/README.md",
+ "entry_id": 372,
+ "language_id": "plain-text"
+ },
+ "8": {
+ "filepath": "/dalle2_pytorch/dataloaders/__init__.py",
+ "entry_id": 384,
+ "language_id": "python"
+ },
+ "9": {
+ "filepath": "/dalle2_pytorch/dataloaders/decoder_loader.py",
+ "entry_id": 388,
+ "language_id": "python"
+ },
+ "10": {
+ "filepath": "/dalle2_pytorch/dataloaders/prior_loader.py",
+ "entry_id": 418,
+ "language_id": "python"
+ },
+ "11": {
+ "filepath": "/dalle2_pytorch/dataloaders/simple_image_only_dataloader.py",
+ "entry_id": 438,
+ "language_id": "python"
+ },
+ "12": {
+ "filepath": "/dalle2_pytorch/optimizer.py",
+ "entry_id": 444,
+ "language_id": "python"
+ },
+ "13": {
+ "filepath": "/dalle2_pytorch/tokenizer.py",
+ "entry_id": 448,
+ "language_id": "python"
+ },
+ "14": {
+ "filepath": "/dalle2_pytorch/trackers.py",
+ "entry_id": 464,
+ "language_id": "python"
+ },
+ "15": {
+ "filepath": "/dalle2_pytorch/train_configs.py",
+ "entry_id": 518,
+ "language_id": "python"
+ },
+ "16": {
+ "filepath": "/dalle2_pytorch/trainer.py",
+ "entry_id": 548,
+ "language_id": "python"
+ },
+ "17": {
+ "filepath": "/dalle2_pytorch/utils.py",
+ "entry_id": 600,
+ "language_id": "python"
+ },
+ "18": {
+ "filepath": "/dalle2_pytorch/version.py",
+ "entry_id": 604,
+ "language_id": "python"
+ },
+ "19": {
+ "filepath": "/dalle2_pytorch/vqgan_vae.py",
+ "entry_id": 608,
+ "language_id": "python"
+ },
+ "20": {
+ "filepath": "/dalle2_pytorch/vqgan_vae_trainer.py",
+ "entry_id": 652,
+ "language_id": "python"
+ },
+ "21": {
+ "filepath": "/prior.md",
+ "entry_id": 670,
+ "language_id": "markdown"
+ },
+ "22": {
+ "filepath": "/setup.py",
+ "entry_id": 700,
+ "language_id": "python"
+ },
+ "23": {
+ "filepath": "/train_decoder.py",
+ "entry_id": 706,
+ "language_id": "python"
+ },
+ "24": {
+ "filepath": "/train_diffusion_prior.py",
+ "entry_id": 774,
+ "language_id": "python"
+ }
+ },
+ "project_name": "DALLE2-pytorch",
+ "split_count": 9
+}
\ No newline at end of file
diff --git a/docs/metadata_title.json b/docs/metadata_title.json
new file mode 100644
index 00000000..f885a7a9
--- /dev/null
+++ b/docs/metadata_title.json
@@ -0,0 +1 @@
+{"split_count": 2}
\ No newline at end of file
diff --git a/docs/sitemap.xml b/docs/sitemap.xml
new file mode 100644
index 00000000..84250a9f
--- /dev/null
+++ b/docs/sitemap.xml
@@ -0,0 +1,163 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">
+  <url><loc>https://james4ever0.github.io/DALLE2-pytorch?q=/MANIFEST.in</loc><lastmod>2023-12-28T09:21:02+00:00</lastmod><priority>1.00</priority></url>
+  <url><loc>https://james4ever0.github.io/DALLE2-pytorch?q=/Makefile</loc><lastmod>2023-12-28T09:21:02+00:00</lastmod><priority>1.00</priority></url>
+  <url><loc>https://james4ever0.github.io/DALLE2-pytorch?q=/README.md</loc><lastmod>2023-12-28T09:21:02+00:00</lastmod><priority>1.00</priority></url>
+  <url><loc>https://james4ever0.github.io/DALLE2-pytorch?q=/configs/README.md</loc><lastmod>2023-12-28T09:21:02+00:00</lastmod><priority>1.00</priority></url>
+  <url><loc>https://james4ever0.github.io/DALLE2-pytorch?q=/dalle2_pytorch/__init__.py</loc><lastmod>2023-12-28T09:21:02+00:00</lastmod><priority>1.00</priority></url>
+  <url><loc>https://james4ever0.github.io/DALLE2-pytorch?q=/dalle2_pytorch/cli.py</loc><lastmod>2023-12-28T09:21:02+00:00</lastmod><priority>1.00</priority></url>
+  <url><loc>https://james4ever0.github.io/DALLE2-pytorch?q=/dalle2_pytorch/dalle2_pytorch.py</loc><lastmod>2023-12-28T09:21:02+00:00</lastmod><priority>1.00</priority></url>
+  <url><loc>https://james4ever0.github.io/DALLE2-pytorch?q=/dalle2_pytorch/dataloaders/README.md</loc><lastmod>2023-12-28T09:21:02+00:00</lastmod><priority>1.00</priority></url>
+  <url><loc>https://james4ever0.github.io/DALLE2-pytorch?q=/dalle2_pytorch/dataloaders/__init__.py</loc><lastmod>2023-12-28T09:21:02+00:00</lastmod><priority>1.00</priority></url>
+  <url><loc>https://james4ever0.github.io/DALLE2-pytorch?q=/dalle2_pytorch/dataloaders/decoder_loader.py</loc><lastmod>2023-12-28T09:21:02+00:00</lastmod><priority>1.00</priority></url>
+  <url><loc>https://james4ever0.github.io/DALLE2-pytorch?q=/dalle2_pytorch/dataloaders/prior_loader.py</loc><lastmod>2023-12-28T09:21:02+00:00</lastmod><priority>1.00</priority></url>
+  <url><loc>https://james4ever0.github.io/DALLE2-pytorch?q=/dalle2_pytorch/dataloaders/simple_image_only_dataloader.py</loc><lastmod>2023-12-28T09:21:02+00:00</lastmod><priority>1.00</priority></url>
+  <url><loc>https://james4ever0.github.io/DALLE2-pytorch?q=/dalle2_pytorch/optimizer.py</loc><lastmod>2023-12-28T09:21:02+00:00</lastmod><priority>1.00</priority></url>
+  <url><loc>https://james4ever0.github.io/DALLE2-pytorch?q=/dalle2_pytorch/tokenizer.py</loc><lastmod>2023-12-28T09:21:02+00:00</lastmod><priority>1.00</priority></url>
+  <url><loc>https://james4ever0.github.io/DALLE2-pytorch?q=/dalle2_pytorch/trackers.py</loc><lastmod>2023-12-28T09:21:02+00:00</lastmod><priority>1.00</priority></url>
+  <url><loc>https://james4ever0.github.io/DALLE2-pytorch?q=/dalle2_pytorch/train_configs.py</loc><lastmod>2023-12-28T09:21:02+00:00</lastmod><priority>1.00</priority></url>
+  <url><loc>https://james4ever0.github.io/DALLE2-pytorch?q=/dalle2_pytorch/trainer.py</loc><lastmod>2023-12-28T09:21:02+00:00</lastmod><priority>1.00</priority></url>
+  <url><loc>https://james4ever0.github.io/DALLE2-pytorch?q=/dalle2_pytorch/utils.py</loc><lastmod>2023-12-28T09:21:02+00:00</lastmod><priority>1.00</priority></url>
+  <url><loc>https://james4ever0.github.io/DALLE2-pytorch?q=/dalle2_pytorch/version.py</loc><lastmod>2023-12-28T09:21:02+00:00</lastmod><priority>1.00</priority></url>
+  <url><loc>https://james4ever0.github.io/DALLE2-pytorch?q=/dalle2_pytorch/vqgan_vae.py</loc><lastmod>2023-12-28T09:21:02+00:00</lastmod><priority>1.00</priority></url>
+  <url><loc>https://james4ever0.github.io/DALLE2-pytorch?q=/dalle2_pytorch/vqgan_vae_trainer.py</loc><lastmod>2023-12-28T09:21:02+00:00</lastmod><priority>1.00</priority></url>
+  <url><loc>https://james4ever0.github.io/DALLE2-pytorch?q=/prior.md</loc><lastmod>2023-12-28T09:21:02+00:00</lastmod><priority>1.00</priority></url>
+  <url><loc>https://james4ever0.github.io/DALLE2-pytorch?q=/setup.py</loc><lastmod>2023-12-28T09:21:02+00:00</lastmod><priority>1.00</priority></url>
+  <url><loc>https://james4ever0.github.io/DALLE2-pytorch?q=/train_decoder.py</loc><lastmod>2023-12-28T09:21:02+00:00</lastmod><priority>1.00</priority></url>
+  <url><loc>https://james4ever0.github.io/DALLE2-pytorch?q=/train_diffusion_prior.py</loc><lastmod>2023-12-28T09:21:02+00:00</lastmod><priority>1.00</priority></url>
+  <url><loc>https://james4ever0.github.io/DALLE2-pytorch/tree.html?full=true</loc><lastmod>2023-12-28T09:21:02+00:00</lastmod><priority>1.00</priority></url>
+</urlset>
\ No newline at end of file
diff --git a/docs/src/MANIFEST.in b/docs/src/MANIFEST.in
new file mode 100644
index 00000000..b9463875
--- /dev/null
+++ b/docs/src/MANIFEST.in
@@ -0,0 +1 @@
+recursive-include dalle2_pytorch *.txt
diff --git a/docs/src/Makefile b/docs/src/Makefile
new file mode 100644
index 00000000..5ce5220f
--- /dev/null
+++ b/docs/src/Makefile
@@ -0,0 +1,6 @@
+install:
+ pip install -U pip
+ pip install -e .
+
+test:
+ CUDA_VISIBLE_DEVICES= python train_decoder.py --config_file configs/train_decoder_config.test.json
diff --git a/docs/src/README.md b/docs/src/README.md
new file mode 100644
index 00000000..6a859734
--- /dev/null
+++ b/docs/src/README.md
@@ -0,0 +1,1311 @@
+
+
+## DALL-E 2 - Pytorch
+
+Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch.
+
+Yannic Kilcher summary | AssemblyAI explainer
+
+The main novelty seems to be an extra layer of indirection with the prior network (whether it is an autoregressive transformer or a diffusion network), which predicts an image embedding based on the text embedding from CLIP. Specifically, this repository will only build out the diffusion prior network, as it is the best performing variant (but which incidentally involves a causal transformer as the denoising network 😂)
+
+This model is SOTA for text-to-image for now.
+
+Please join if you are interested in helping out with the replication with the LAION community | Yannic Interview
+
+As of 5/23/22, it is no longer SOTA. SOTA will be here. Jax versions as well as text-to-video project will be shifted towards the Imagen architecture, as it is way simpler.
+
+## Status
+
+- A research group has used the code in this repository to train a functional diffusion prior for their CLIP generations. Will share their work once they release their preprint. This, and Katherine's own experiments, validate OpenAI's finding that the extra prior increases variety of generations.
+
+- Decoder is now verified working for unconditional generation on my experimental setup for Oxford flowers. 2 researchers have also confirmed Decoder is working for them.
+
+
+
+*ongoing at 21k steps*
+
+- Justin Pinkney successfully trained the diffusion prior in the repository for his CLIP to Stylegan2 text-to-image application
+
+- Romain has scaled up training to 800 GPUs with the available scripts without any issues
+
+## Pre-Trained Models
+
+- LAION is training prior models. Checkpoints are available on 🤗huggingface and the training statistics are available on 🐝WANDB.
+- Decoder - In-progress test run 🚧
+- Decoder - Another test run with sparse attention
+- DALL-E 2 🚧 - DALL-E 2 Laion repository
+
+## Appreciation
+
+This library would not have gotten to this working state without the help of
+
+- Zion for the distributed training code for the diffusion prior
+- Aidan for the distributed training code for the decoder as well as the dataloaders
+- Kumar for working on the initial diffusion training script
+- Romain for the pull request reviews and project management
+- He Cao and xiankgx for the Q&A and for identifying critical bugs
+- Marunine for identifying issues with resizing of the low resolution conditioner, when training the upsampler, in addition to various other bug fixes
+- MalumaDev for proposing the use of pixel shuffle upsampler for fixing checkerboard artifacts
+- Katherine for her advice
+- Stability AI for the generous sponsorship
+- 🤗 Huggingface and in particular Sylvain for the Accelerate library
+- Alex for einops, an indispensable tool for tensor manipulation
+
+... and many others. Thank you! 🙏
+
+## Install
+
+```bash
+$ pip install dalle2-pytorch
+```
+
+## Usage
+
+Training DALLE-2 is a three-step process, with the training of CLIP being the most important
+
+To train CLIP, you can either use the x-clip package, or join the LAION discord, where a lot of replication efforts are already underway.
+
+This repository will demonstrate integration with `x-clip` for starters
+
+```python
+import torch
+from dalle2_pytorch import CLIP
+
+clip = CLIP(
+ dim_text = 512,
+ dim_image = 512,
+ dim_latent = 512,
+ num_text_tokens = 49408,
+ text_enc_depth = 1,
+ text_seq_len = 256,
+ text_heads = 8,
+ visual_enc_depth = 1,
+ visual_image_size = 256,
+ visual_patch_size = 32,
+ visual_heads = 8,
+ use_all_token_embeds = True, # whether to use fine-grained contrastive learning (FILIP)
+ decoupled_contrastive_learning = True, # use decoupled contrastive learning (DCL) objective function, removing positive pairs from the denominator of the InfoNCE loss (CLOOB + DCL)
+ extra_latent_projection = True, # whether to use separate projections for text-to-image vs image-to-text comparisons (CLOOB)
+ use_visual_ssl = True, # whether to do self supervised learning on images
+ visual_ssl_type = 'simclr', # can be either 'simclr' or 'simsiam', depending on using DeCLIP or SLIP
+ use_mlm = False, # use masked language learning (MLM) on text (DeCLIP)
+ text_ssl_loss_weight = 0.05, # weight for text MLM loss
+ image_ssl_loss_weight = 0.05 # weight for image self-supervised learning loss
+).cuda()
+
+# mock data
+
+text = torch.randint(0, 49408, (4, 256)).cuda()
+images = torch.randn(4, 3, 256, 256).cuda()
+
+# train
+
+loss = clip(
+ text,
+ images,
+ return_loss = True # needs to be set to True to return contrastive loss
+)
+
+loss.backward()
+
+# do the above with as many texts and images as possible in a loop
+```
+
+Then, you will need to train the decoder, which learns to generate images based on the image embedding coming from the trained CLIP above
+
+```python
+import torch
+from dalle2_pytorch import Unet, Decoder, CLIP
+
+# trained clip from step 1
+
+clip = CLIP(
+ dim_text = 512,
+ dim_image = 512,
+ dim_latent = 512,
+ num_text_tokens = 49408,
+ text_enc_depth = 1,
+ text_seq_len = 256,
+ text_heads = 8,
+ visual_enc_depth = 1,
+ visual_image_size = 256,
+ visual_patch_size = 32,
+ visual_heads = 8
+).cuda()
+
+# unet for the decoder
+
+unet = Unet(
+ dim = 128,
+ image_embed_dim = 512,
+ cond_dim = 128,
+ channels = 3,
+ dim_mults=(1, 2, 4, 8)
+).cuda()
+
+# decoder, which contains the unet and clip
+
+decoder = Decoder(
+ unet = unet,
+ clip = clip,
+ timesteps = 100,
+ image_cond_drop_prob = 0.1,
+ text_cond_drop_prob = 0.5
+).cuda()
+
+# mock images (get a lot of this)
+
+images = torch.randn(4, 3, 256, 256).cuda()
+
+# feed images into decoder
+
+loss = decoder(images)
+loss.backward()
+
+# do the above for many many many many steps
+# then it will learn to generate images based on the CLIP image embeddings
+```
+
+Finally, the main contribution of the paper. The repository offers the diffusion prior network. It takes the CLIP text embeddings and tries to generate the CLIP image embeddings. Again, you will need the trained CLIP from the first step
+
+```python
+import torch
+from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, CLIP
+
+# get trained CLIP from step one
+
+clip = CLIP(
+ dim_text = 512,
+ dim_image = 512,
+ dim_latent = 512,
+ num_text_tokens = 49408,
+ text_enc_depth = 6,
+ text_seq_len = 256,
+ text_heads = 8,
+ visual_enc_depth = 6,
+ visual_image_size = 256,
+ visual_patch_size = 32,
+ visual_heads = 8,
+).cuda()
+
+# setup prior network, which contains an autoregressive transformer
+
+prior_network = DiffusionPriorNetwork(
+ dim = 512,
+ depth = 6,
+ dim_head = 64,
+ heads = 8
+).cuda()
+
+# diffusion prior network, which contains the CLIP and network (with transformer) above
+
+diffusion_prior = DiffusionPrior(
+ net = prior_network,
+ clip = clip,
+ timesteps = 100,
+ cond_drop_prob = 0.2
+).cuda()
+
+# mock data
+
+text = torch.randint(0, 49408, (4, 256)).cuda()
+images = torch.randn(4, 3, 256, 256).cuda()
+
+# feed text and images into diffusion prior network
+
+loss = diffusion_prior(text, images)
+loss.backward()
+
+# do the above for many many many steps
+# now the diffusion prior can generate image embeddings from the text embeddings
+```
+
+In the paper, they actually used a recently discovered technique (cascading DDPMs) from Jonathan Ho himself (original author of DDPMs, the core technique used in DALL-E v2) for high resolution image synthesis.
+
+This can easily be used within this framework as follows
+
+```python
+import torch
+from dalle2_pytorch import Unet, Decoder, CLIP
+
+# trained clip from step 1
+
+clip = CLIP(
+ dim_text = 512,
+ dim_image = 512,
+ dim_latent = 512,
+ num_text_tokens = 49408,
+ text_enc_depth = 6,
+ text_seq_len = 256,
+ text_heads = 8,
+ visual_enc_depth = 6,
+ visual_image_size = 256,
+ visual_patch_size = 32,
+ visual_heads = 8
+).cuda()
+
+# 2 unets for the decoder (a la cascading DDPM)
+
+unet1 = Unet(
+ dim = 32,
+ image_embed_dim = 512,
+ cond_dim = 128,
+ channels = 3,
+ dim_mults = (1, 2, 4, 8)
+).cuda()
+
+unet2 = Unet(
+ dim = 32,
+ image_embed_dim = 512,
+ cond_dim = 128,
+ channels = 3,
+ dim_mults = (1, 2, 4, 8, 16)
+).cuda()
+
+# decoder, which contains the unet(s) and clip
+
+decoder = Decoder(
+ clip = clip,
+ unet = (unet1, unet2), # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)
+ image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)
+ timesteps = 1000,
+ image_cond_drop_prob = 0.1,
+ text_cond_drop_prob = 0.5
+).cuda()
+
+# mock images (get a lot of this)
+
+images = torch.randn(4, 3, 512, 512).cuda()
+
+# feed images into decoder, specifying which unet you want to train
+# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme
+
+loss = decoder(images, unet_number = 1)
+loss.backward()
+
+loss = decoder(images, unet_number = 2)
+loss.backward()
+
+# do the above for many steps for both unets
+```
+
+Finally, to generate the DALL-E2 images from text, insert the trained `DiffusionPrior` as well as the `Decoder` (which wraps `CLIP`, the causal transformer, and unet(s))
+
+```python
+from dalle2_pytorch import DALLE2
+
+dalle2 = DALLE2(
+ prior = diffusion_prior,
+ decoder = decoder
+)
+
+# send the text as a string if you want to use the simple tokenizer from DALLE v1
+# or you can do it as token ids, if you have your own tokenizer
+
+texts = ['glistening morning dew on a flower petal']
+images = dalle2(texts) # (1, 3, 256, 256)
+```
+
+That's it!
+
+Let's see the whole script below
+
+```python
+import torch
+from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, CLIP
+
+clip = CLIP(
+ dim_text = 512,
+ dim_image = 512,
+ dim_latent = 512,
+ num_text_tokens = 49408,
+ text_enc_depth = 6,
+ text_seq_len = 256,
+ text_heads = 8,
+ visual_enc_depth = 6,
+ visual_image_size = 256,
+ visual_patch_size = 32,
+ visual_heads = 8
+).cuda()
+
+# mock data
+
+text = torch.randint(0, 49408, (4, 256)).cuda()
+images = torch.randn(4, 3, 256, 256).cuda()
+
+# train
+
+loss = clip(
+ text,
+ images,
+ return_loss = True
+)
+
+loss.backward()
+
+# do above for many steps ...
+
+# prior networks (with transformer)
+
+prior_network = DiffusionPriorNetwork(
+ dim = 512,
+ depth = 6,
+ dim_head = 64,
+ heads = 8
+).cuda()
+
+diffusion_prior = DiffusionPrior(
+ net = prior_network,
+ clip = clip,
+ timesteps = 1000,
+ sample_timesteps = 64,
+ cond_drop_prob = 0.2
+).cuda()
+
+loss = diffusion_prior(text, images)
+loss.backward()
+
+# do above for many steps ...
+
+# decoder (with unet)
+
+unet1 = Unet(
+ dim = 128,
+ image_embed_dim = 512,
+ text_embed_dim = 512,
+ cond_dim = 128,
+ channels = 3,
+ dim_mults=(1, 2, 4, 8),
+ cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings
+).cuda()
+
+unet2 = Unet(
+ dim = 16,
+ image_embed_dim = 512,
+ cond_dim = 128,
+ channels = 3,
+ dim_mults = (1, 2, 4, 8, 16)
+).cuda()
+
+decoder = Decoder(
+ unet = (unet1, unet2),
+ image_sizes = (128, 256),
+ clip = clip,
+ timesteps = 100,
+ image_cond_drop_prob = 0.1,
+ text_cond_drop_prob = 0.5
+).cuda()
+
+for unet_number in (1, 2):
+ loss = decoder(images, text = text, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
+ loss.backward()
+
+# do above for many steps
+
+dalle2 = DALLE2(
+ prior = diffusion_prior,
+ decoder = decoder
+)
+
+images = dalle2(
+ ['cute puppy chasing after a squirrel'],
+ cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
+)
+
+# save your image (in this example, of size 256x256)
+```
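+
+As a small aside, the sampled output is just an image tensor, so one way to save it (assuming `torchvision` is available in your environment) is with `torchvision.utils.save_image`:
+
+```python
+# hedged example: saving the sampled tensor from the script above
+from torchvision.utils import save_image
+
+# `images` is the (1, 3, 256, 256) tensor returned by dalle2(...) above
+save_image(images, 'cute_puppy_chasing_after_a_squirrel.png')
+```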
+
+Everything in this readme should run without error
+
+You can also train the decoder on images larger than the size (say 512x512) at which CLIP was trained (256x256). The images will be resized to the CLIP image resolution for the image embeddings (see the sketch below)
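+
+Here is a minimal sketch of that setup, reusing the same kind of `CLIP` and `Unet` as in the earlier examples; the hyperparameters are illustrative rather than prescriptive
+
+```python
+import torch
+from dalle2_pytorch import Unet, Decoder, CLIP
+
+# CLIP trained at 256 x 256, exactly as in the earlier examples
+clip = CLIP(
+    dim_text = 512,
+    dim_image = 512,
+    dim_latent = 512,
+    num_text_tokens = 49408,
+    text_enc_depth = 1,
+    text_seq_len = 256,
+    text_heads = 8,
+    visual_enc_depth = 1,
+    visual_image_size = 256,   # CLIP resolution
+    visual_patch_size = 32,
+    visual_heads = 8
+).cuda()
+
+unet = Unet(
+    dim = 32,
+    image_embed_dim = 512,
+    cond_dim = 128,
+    channels = 3,
+    dim_mults = (1, 2, 4, 8)
+).cuda()
+
+# decoder trained at 512 x 512, larger than the CLIP resolution
+decoder = Decoder(
+    unet = (unet,),
+    image_sizes = (512,),
+    clip = clip,
+    timesteps = 100,
+    image_cond_drop_prob = 0.1,
+    text_cond_drop_prob = 0.5
+).cuda()
+
+# 512 x 512 training images; they are downsized internally for the CLIP image embedding
+images = torch.randn(4, 3, 512, 512).cuda()
+
+loss = decoder(images, unet_number = 1)
+loss.backward()
+```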
+
+For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
+
+## Training on Preprocessed CLIP Embeddings
+
+It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings`
+
+Working example below
+
+```python
+import torch
+from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, CLIP
+
+# get trained CLIP from step one
+
+clip = CLIP(
+ dim_text = 512,
+ dim_image = 512,
+ dim_latent = 512,
+ num_text_tokens = 49408,
+ text_enc_depth = 6,
+ text_seq_len = 256,
+ text_heads = 8,
+ visual_enc_depth = 6,
+ visual_image_size = 256,
+ visual_patch_size = 32,
+ visual_heads = 8,
+).cuda()
+
+# setup prior network, which contains an autoregressive transformer
+
+prior_network = DiffusionPriorNetwork(
+ dim = 512,
+ depth = 6,
+ dim_head = 64,
+ heads = 8
+).cuda()
+
+# diffusion prior network, which contains the CLIP and network (with transformer) above
+
+diffusion_prior = DiffusionPrior(
+ net = prior_network,
+ clip = clip,
+ timesteps = 100,
+ cond_drop_prob = 0.2,
+ condition_on_text_encodings = False # this probably should be true, but just to get Laion started
+).cuda()
+
+# mock data
+
+text = torch.randint(0, 49408, (4, 256)).cuda()
+images = torch.randn(4, 3, 256, 256).cuda()
+
+# precompute the text and image embeddings
+# here using the diffusion prior class, but could be done with CLIP alone
+
+clip_image_embeds = diffusion_prior.clip.embed_image(images).image_embed
+clip_text_embeds = diffusion_prior.clip.embed_text(text).text_embed
+
+# feed text and images into diffusion prior network
+
+loss = diffusion_prior(
+ text_embed = clip_text_embeds,
+ image_embed = clip_image_embeds
+)
+
+loss.backward()
+
+# do the above for many many many steps
+# now the diffusion prior can generate image embeddings from the text embeddings
+```
+
+You can also completely go `CLIP`-less, in which case you will need to pass in the `image_embed_dim` into the `DiffusionPrior` on initialization
+
+```python
+import torch
+from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior
+
+# setup prior network, which contains an autoregressive transformer
+
+prior_network = DiffusionPriorNetwork(
+ dim = 512,
+ depth = 6,
+ dim_head = 64,
+ heads = 8
+).cuda()
+
+# diffusion prior, which wraps just the prior network (with transformer) above - no CLIP this time
+
+diffusion_prior = DiffusionPrior(
+ net = prior_network,
+ image_embed_dim = 512, # this needs to be set
+ timesteps = 100,
+ cond_drop_prob = 0.2,
+ condition_on_text_encodings = False # this probably should be true, but just to get Laion started
+).cuda()
+
+# mock data
+
+text = torch.randint(0, 49408, (4, 256)).cuda()
+images = torch.randn(4, 3, 256, 256).cuda()
+
+# mock the precomputed text and image embeddings
+# in practice these would be your precomputed CLIP embeddings loaded from disk
+
+clip_image_embeds = torch.randn(4, 512).cuda()
+clip_text_embeds = torch.randn(4, 512).cuda()
+
+# feed text and images into diffusion prior network
+
+loss = diffusion_prior(
+ text_embed = clip_text_embeds,
+ image_embed = clip_image_embeds
+)
+
+loss.backward()
+
+# do the above for many many many steps
+# now the diffusion prior can generate image embeddings from the text embeddings
+```
+
+## OpenAI CLIP
+
+Although there is the possibility they are using an unreleased, more powerful CLIP, you can use one of the released ones, if you do not wish to train your own CLIP from scratch. This will also allow the community to more quickly validate the conclusions of the paper.
+
+To use a pretrained OpenAI CLIP, simply import `OpenAIClipAdapter` and pass it into the `DiffusionPrior` or `Decoder` like so
+
+```python
+import torch
+from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, OpenAIClipAdapter
+
+# openai pretrained clip - defaults to ViT-B/32
+
+clip = OpenAIClipAdapter()
+
+# mock data
+
+text = torch.randint(0, 49408, (4, 256)).cuda()
+images = torch.randn(4, 3, 256, 256).cuda()
+
+# prior networks (with transformer)
+
+prior_network = DiffusionPriorNetwork(
+ dim = 512,
+ depth = 6,
+ dim_head = 64,
+ heads = 8
+).cuda()
+
+diffusion_prior = DiffusionPrior(
+ net = prior_network,
+ clip = clip,
+ timesteps = 100,
+ cond_drop_prob = 0.2
+).cuda()
+
+loss = diffusion_prior(text, images)
+loss.backward()
+
+# do above for many steps ...
+
+# decoder (with unet)
+
+unet1 = Unet(
+ dim = 128,
+ image_embed_dim = 512,
+ cond_dim = 128,
+ channels = 3,
+ dim_mults=(1, 2, 4, 8),
+ text_embed_dim = 512,
+ cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings (ex. first unet in cascade)
+).cuda()
+
+unet2 = Unet(
+ dim = 16,
+ image_embed_dim = 512,
+ cond_dim = 128,
+ channels = 3,
+ dim_mults = (1, 2, 4, 8, 16)
+).cuda()
+
+decoder = Decoder(
+ unet = (unet1, unet2),
+ image_sizes = (128, 256),
+ clip = clip,
+ timesteps = 1000,
+ sample_timesteps = (250, 27),
+ image_cond_drop_prob = 0.1,
+ text_cond_drop_prob = 0.5
+).cuda()
+
+for unet_number in (1, 2):
+ loss = decoder(images, text = text, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
+ loss.backward()
+
+# do above for many steps
+
+dalle2 = DALLE2(
+ prior = diffusion_prior,
+ decoder = decoder
+)
+
+images = dalle2(
+ ['a butterfly trying to escape a tornado'],
+ cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
+)
+
+# save your image (in this example, of size 256x256)
+```
+
+Alternatively, you can also use Open Clip
+
+```bash
+$ pip install open-clip-torch
+```
+
+Ex. using the SOTA Open Clip model trained by Romain
+
+```python
+from dalle2_pytorch import OpenClipAdapter
+
+clip = OpenClipAdapter('ViT-H/14')
+```
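+
+The adapter drops into the same places as `OpenAIClipAdapter` above. One detail worth flagging (an assumption on my part, worth double checking against the model you pick): the prior network dimension should match the embedding dimension of the chosen Open Clip model, which for ViT-H/14 is 1024 rather than the 512 used in the earlier examples
+
+```python
+import torch
+from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, OpenClipAdapter
+
+clip = OpenClipAdapter('ViT-H/14')
+
+prior_network = DiffusionPriorNetwork(
+    dim = 1024,       # assumed to match the ViT-H/14 embedding dimension
+    depth = 6,
+    dim_head = 64,
+    heads = 8
+).cuda()
+
+diffusion_prior = DiffusionPrior(
+    net = prior_network,
+    clip = clip,
+    timesteps = 100,
+    cond_drop_prob = 0.2
+).cuda()
+```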
+
+Now you'll just have to worry about training the Prior and the Decoder!
+
+## Inpainting
+
+Inpainting is also built into the `Decoder`. You simply have to pass in the `inpaint_image` and `inpaint_mask` (boolean tensor where `True` indicates which regions of the inpaint image to keep)
+
+This repository uses the formulation put forth by Lugmayr et al. in Repaint
+
+```python
+import torch
+from dalle2_pytorch import Unet, Decoder, CLIP
+
+# trained clip from step 1
+
+clip = CLIP(
+ dim_text = 512,
+ dim_image = 512,
+ dim_latent = 512,
+ num_text_tokens = 49408,
+ text_enc_depth = 6,
+ text_seq_len = 256,
+ text_heads = 8,
+ visual_enc_depth = 6,
+ visual_image_size = 256,
+ visual_patch_size = 32,
+ visual_heads = 8
+).cuda()
+
+# one unet for the decoder
+
+unet = Unet(
+ dim = 16,
+ image_embed_dim = 512,
+ cond_dim = 128,
+ channels = 3,
+ dim_mults = (1, 1, 1, 1)
+).cuda()
+
+
+# decoder, which contains the unet(s) and clip
+
+decoder = Decoder(
+ clip = clip,
+    unet = (unet,), # insert unets in order of low resolution to highest resolution (you can have as many stages as you want here)
+    image_sizes = (256,), # one resolution per unet. these must be unique and in ascending order (matches with the unets passed in)
+ timesteps = 1000,
+ image_cond_drop_prob = 0.1,
+ text_cond_drop_prob = 0.5
+).cuda()
+
+# mock images (get a lot of this)
+
+images = torch.randn(4, 3, 256, 256).cuda()
+
+# feed images into decoder, specifying which unet you want to train
+# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme
+
+loss = decoder(images, unet_number = 1)
+loss.backward()
+
+# do the above for many steps
+
+mock_image_embed = torch.randn(1, 512).cuda()
+
+# then to do inpainting
+
+inpaint_image = torch.randn(1, 3, 256, 256).cuda() # (batch, channels, height, width)
+inpaint_mask = torch.ones(1, 256, 256).bool().cuda() # (batch, height, width)
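+
+# note: an all-True mask keeps the entire image, which is only useful as a smoke test
+# for a more meaningful illustration you could keep the left half and repaint the right half, e.g.
+# inpaint_mask[:, :, 128:] = False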
+
+inpainted_images = decoder.sample(
+ image_embed = mock_image_embed,
+ inpaint_image = inpaint_image, # just pass in the inpaint image
+ inpaint_mask = inpaint_mask # and the mask
+)
+
+inpainted_images.shape # (1, 3, 256, 256)
+```
+
+## Experimental
+
+### DALL-E2 with Latent Diffusion
+
+This repository decides to take the next step and offer DALL-E v2 combined with latent diffusion, from Rombach et al.
+
+You can use it as follows. Latent diffusion can be limited to just the first U-Net in the cascade, or to any number you wish.
+
+The repository also comes equipped with all the necessary settings to recreate `ViT-VQGan` from the Improved VQGans paper. Furthermore, the vector quantization library also comes equipped to do residual or multi-headed quantization, which I believe will give an even further boost in performance to the autoencoder.
+
+```python
+import torch
+from dalle2_pytorch import Unet, Decoder, CLIP, VQGanVAE
+
+# trained clip from step 1
+
+clip = CLIP(
+ dim_text = 512,
+ dim_image = 512,
+ dim_latent = 512,
+ num_text_tokens = 49408,
+ text_enc_depth = 1,
+ text_seq_len = 256,
+ text_heads = 8,
+ visual_enc_depth = 1,
+ visual_image_size = 256,
+ visual_patch_size = 32,
+ visual_heads = 8
+)
+
+# 3 unets for the decoder (a la cascading DDPM)
+
+# first two unets are doing latent diffusion
+# vqgan-vae must be trained beforehand
+
+vae1 = VQGanVAE(
+ dim = 32,
+ image_size = 256,
+ layers = 3,
+ layer_mults = (1, 2, 4)
+)
+
+vae2 = VQGanVAE(
+ dim = 32,
+ image_size = 512,
+ layers = 3,
+ layer_mults = (1, 2, 4)
+)
+
+unet1 = Unet(
+ dim = 32,
+ image_embed_dim = 512,
+ cond_dim = 128,
+ channels = 3,
+ sparse_attn = True,
+ sparse_attn_window = 2,
+ dim_mults = (1, 2, 4, 8)
+)
+
+unet2 = Unet(
+ dim = 32,
+ image_embed_dim = 512,
+ channels = 3,
+ dim_mults = (1, 2, 4, 8, 16),
+ cond_on_image_embeds = True,
+ cond_on_text_encodings = False
+)
+
+unet3 = Unet(
+ dim = 32,
+ image_embed_dim = 512,
+ channels = 3,
+ dim_mults = (1, 2, 4, 8, 16),
+ cond_on_image_embeds = True,
+ cond_on_text_encodings = False,
+ attend_at_middle = False
+)
+
+# decoder, which contains the unet(s) and clip
+
+decoder = Decoder(
+ clip = clip,
+ vae = (vae1, vae2), # latent diffusion for unet1 (vae1) and unet2 (vae2), but not for the last unet3
+ unet = (unet1, unet2, unet3), # insert unets in order of low resolution to highest resolution (you can have as many stages as you want here)
+ image_sizes = (256, 512, 1024), # resolutions, 256 for first unet, 512 for second, 1024 for third
+ timesteps = 100,
+ image_cond_drop_prob = 0.1,
+ text_cond_drop_prob = 0.5
+).cuda()
+
+# mock images (get a lot of this)
+
+images = torch.randn(1, 3, 1024, 1024).cuda()
+
+# feed images into decoder, specifying which unet you want to train
+# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme
+
+with decoder.one_unet_in_gpu(1):
+ loss = decoder(images, unet_number = 1)
+ loss.backward()
+
+with decoder.one_unet_in_gpu(2):
+ loss = decoder(images, unet_number = 2)
+ loss.backward()
+
+with decoder.one_unet_in_gpu(3):
+ loss = decoder(images, unet_number = 3)
+ loss.backward()
+
+# do the above for many steps for all three unets
+
+# then it will learn to generate images based on the CLIP image embeddings
+
+# chaining the unets from lowest resolution to highest resolution (thus cascading)
+
+mock_image_embed = torch.randn(1, 512).cuda()
+images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
+```
+
+## Training wrapper
+
+### Decoder Training
+
+Training the `Decoder` may be confusing, as one needs to keep track of an optimizer for each of the `Unet`(s) separately. Each `Unet` will also need its own corresponding exponential moving average. The `DecoderTrainer` hopes to make this simple, as shown below.
+
+```python
+import torch
+from dalle2_pytorch import DALLE2, Unet, Decoder, CLIP, DecoderTrainer
+
+clip = CLIP(
+ dim_text = 512,
+ dim_image = 512,
+ dim_latent = 512,
+ num_text_tokens = 49408,
+ text_enc_depth = 6,
+ text_seq_len = 256,
+ text_heads = 8,
+ visual_enc_depth = 6,
+ visual_image_size = 256,
+ visual_patch_size = 32,
+ visual_heads = 8
+).cuda()
+
+# mock data
+
+text = torch.randint(0, 49408, (32, 256)).cuda()
+images = torch.randn(32, 3, 256, 256).cuda()
+
+# decoder (with unet)
+
+unet1 = Unet(
+ dim = 128,
+ image_embed_dim = 512,
+ text_embed_dim = 512,
+ cond_dim = 128,
+ channels = 3,
+ dim_mults=(1, 2, 4, 8),
+ cond_on_text_encodings = True,
+).cuda()
+
+unet2 = Unet(
+ dim = 16,
+ image_embed_dim = 512,
+ cond_dim = 128,
+ channels = 3,
+ dim_mults = (1, 2, 4, 8, 16),
+).cuda()
+
+decoder = Decoder(
+ unet = (unet1, unet2),
+ image_sizes = (128, 256),
+ clip = clip,
+ timesteps = 1000
+).cuda()
+
+decoder_trainer = DecoderTrainer(
+ decoder,
+ lr = 3e-4,
+ wd = 1e-2,
+ ema_beta = 0.99,
+ ema_update_after_step = 1000,
+ ema_update_every = 10,
+)
+
+for unet_number in (1, 2):
+ loss = decoder_trainer(
+ images,
+ text = text,
+ unet_number = unet_number, # which unet to train on
+ max_batch_size = 4 # gradient accumulation - this sets the maximum batch size in which to do forward and backwards pass - for this example 32 / 4 == 8 times
+ )
+
+ decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
+
+# after much training
+# you can sample from the exponentially moving averaged unets as so
+
+mock_image_embed = torch.randn(32, 512).cuda()
+images = decoder_trainer.sample(image_embed = mock_image_embed, text = text) # (32, 3, 256, 256)
+```
+
+### Diffusion Prior Training
+
+Similarly, one can use the `DiffusionPriorTrainer` to automatically instantiate and keep track of an exponential moving averaged prior.
+
+```python
+import torch
+from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, DiffusionPriorTrainer, Unet, Decoder, CLIP
+
+clip = CLIP(
+ dim_text = 512,
+ dim_image = 512,
+ dim_latent = 512,
+ num_text_tokens = 49408,
+ text_enc_depth = 6,
+ text_seq_len = 256,
+ text_heads = 8,
+ visual_enc_depth = 6,
+ visual_image_size = 256,
+ visual_patch_size = 32,
+ visual_heads = 8
+).cuda()
+
+# mock data
+
+text = torch.randint(0, 49408, (512, 256)).cuda()
+images = torch.randn(512, 3, 256, 256).cuda()
+
+# prior networks (with transformer)
+
+prior_network = DiffusionPriorNetwork(
+ dim = 512,
+ depth = 6,
+ dim_head = 64,
+ heads = 8
+).cuda()
+
+diffusion_prior = DiffusionPrior(
+ net = prior_network,
+ clip = clip,
+ timesteps = 100,
+ cond_drop_prob = 0.2
+).cuda()
+
+diffusion_prior_trainer = DiffusionPriorTrainer(
+ diffusion_prior,
+ lr = 3e-4,
+ wd = 1e-2,
+ ema_beta = 0.99,
+ ema_update_after_step = 1000,
+ ema_update_every = 10,
+)
+
+loss = diffusion_prior_trainer(text, images, max_batch_size = 4)
+diffusion_prior_trainer.update() # this will update the optimizer as well as the exponential moving averaged diffusion prior
+
+# after much of the above three lines in a loop
+# you can sample from the exponential moving average of the diffusion prior identically to how you do so for DiffusionPrior
+
+image_embeds = diffusion_prior_trainer.sample(text, max_batch_size = 4) # (512, 512) - exponential moving averaged image embeddings
+```
+
+## Bonus
+
+### Unconditional Training
+
+The repository also contains the means to train an unconditional DDPM model, or even cascading DDPMs. You simply have to set `unconditional = True` in the `Decoder`.
+
+ex.
+
+```python
+import torch
+from dalle2_pytorch import Unet, Decoder, DecoderTrainer
+
+# unet for the cascading ddpm
+
+unet1 = Unet(
+ dim = 128,
+ dim_mults=(1, 2, 4, 8)
+).cuda()
+
+unet2 = Unet(
+ dim = 32,
+ dim_mults = (1, 2, 4, 8, 16)
+).cuda()
+
+# decoder, which contains the unets
+
+decoder = Decoder(
+ unet = (unet1, unet2),
+ image_sizes = (256, 512), # first unet up to 256px, then second to 512px
+ timesteps = 1000,
+ unconditional = True
+).cuda()
+
+# decoder trainer
+
+decoder_trainer = DecoderTrainer(decoder)
+
+# images (get a lot of this)
+
+images = torch.randn(1, 3, 512, 512).cuda()
+
+# feed images into decoder
+
+for i in (1, 2):
+ loss = decoder_trainer(images, unet_number = i)
+ decoder_trainer.update(unet_number = i)
+
+# do the above for many many many many images
+# then it will learn to generate images
+
+images = decoder_trainer.sample(batch_size = 36, max_batch_size = 4) # (36, 3, 512, 512)
+```
+
+## Dataloaders
+
+### Decoder Dataloaders
+
+In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
+
+#### Decoder: Image Embedding Dataset
+
+When training the decoder (and upsamplers, if training them together) in isolation, you will need to load images and their corresponding image embeddings. This dataset can read two similar types of sources. First, it can read a [webdataset](https://github.com/webdataset/webdataset) whose `.tar`s contain `.jpg` and `.npy` files holding the images and their associated image embeddings respectively. Alternatively, you can specify a source for the embeddings outside of the webdataset. In this case, the path to the embeddings should contain `.npy` files with the same shard numbers as the webdataset, and the filename of each `.jpg` should correspond to the index of its embedding in the `.npy`. So, for example, `0001.tar` from the webdataset containing the image `00010509.jpg` (the first 4 digits are the shard number and the last 4 are the index) should be paralleled by an `img_emb_0001.npy` which contains a NumPy array with the embedding at index 509. A short sketch of this correspondence follows the generation steps below.
+
+Generating a dataset of this type:
+1. Use [img2dataset](https://github.com/rom1504/img2dataset) to generate a webdataset.
+2. Use [clip-retrieval](https://github.com/rom1504/clip-retrieval) to convert the images to embeddings.
+3. Use [embedding-dataset-reordering](https://github.com/Veldrovive/embedding-dataset-reordering) to reorder the embeddings into the expected format.
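+
+To make the naming correspondence concrete, here is a minimal sketch (the folder layout and the `shard_width` of 4 are illustrative assumptions, and `embedding_for` is a hypothetical helper, not part of the package):
+
+```python
+import numpy as np
+
+shard_width = 4  # number of leading digits in a filename that identify the shard
+
+def embedding_for(jpg_name, embeddings_folder):
+    # e.g. "00010509.jpg" -> shard "0001", index 509
+    stem = jpg_name.split('.')[0]
+    shard, index = stem[:shard_width], int(stem[shard_width:])
+
+    # one embedding array per shard, e.g. img_emb_0001.npy
+    embeddings = np.load(f"{embeddings_folder}/img_emb_{shard}.npy")
+    return embeddings[index]
+
+# embedding_for("00010509.jpg", "/path/to/embeddings")  # row 509 of img_emb_0001.npy
+```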
+
+Usage:
+
+```python
+from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embedding_dataloader
+
+# Create a dataloader directly.
+dataloader = create_image_embedding_dataloader(
+ tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
+ embeddings_url="path/or/url/to/embeddings/folder", # Included if .npy files are not in webdataset. Left out or set to None otherwise
+ num_workers=4,
+ batch_size=32,
+ shard_width=4, # If a file in the webdataset shard 3 is named 0003039.jpg, we know the shard width is 4 and the last three digits are the index
+ shuffle_num=200, # Does a shuffle of the data with a buffer size of 200
+ shuffle_shards=True, # Shuffle the order the shards are read in
+ resample_shards=False, # Sample shards with replacement. If true, an epoch will be infinite unless stopped manually
+)
+for img, emb in dataloader:
+ print(img.shape) # torch.Size([32, 3, 256, 256])
+ print(emb["img"].shape) # torch.Size([32, 512])
+ # Train decoder only as shown above
+
+# Or create a dataset without a loader so you can configure it manually
+dataset = ImageEmbeddingDataset(
+ urls="/path/or/url/to/webdataset/{0000..9999}.tar",
+ embedding_folder_url="path/or/url/to/embeddings/folder",
+ shard_width=4,
+ shuffle_shards=True,
+ resample=False
+)
+```
+
+### Scripts
+
+#### `train_diffusion_prior.py`
+
+For detailed information on training the diffusion prior, please refer to the [dedicated readme](prior.md)
+
+## Todo
+
+- [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon
+- [x] add what was proposed in the paper, where DDPM objective for image latent embedding predicts x0 directly (reread vq-diffusion paper and get caught up on that line of work)
+- [x] make sure it works end to end to produce an output tensor, taking a single gradient step
+- [x] augment unet so that it can also be conditioned on text encodings (although in the paper they hinted this didn't make much of a difference)
+- [x] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
+- [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions
+- [x] add efficient attention in unet
+- [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
+- [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
+- [x] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional and compatible with cascading ddpms
+- [x] for decoder, allow ability to customize objective (predict epsilon vs x0), in case latent diffusion does better with prediction of x0
+- [x] use attention-based upsampling https://arxiv.org/abs/2112.11435
+- [x] use inheritance just this once for sharing logic between decoder and prior network ddpms
+- [x] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
+- [x] abstract interface for CLIP adapter class, so other CLIPs can be brought in
+- [x] take care of mixed precision as well as gradient accumulation within decoder trainer
+- [x] just take care of the training for the decoder in a wrapper class, as each unet in the cascade will need its own optimizer
+- [x] bring in tools to train vqgan-vae
+- [x] add convnext backbone for vqgan-vae (in addition to vit [vit-vqgan] + resnet)
+- [x] make sure DDPMs can be run with traditional resnet blocks (but leave convnext as an option for experimentation)
+- [x] make sure for the latter unets in the cascade, one can train on crops for learning super resolution (constrain the unet to be only convolutions in that case, or allow conv-like attention with rel pos bias)
+- [x] offer setting in diffusion prior to split time and image embeddings into multiple tokens, configurable, for more surface area during attention
+- [x] make sure resnet hyperparameters can be configurable across unet depth (groups and expansion factor)
+- [x] pull logic for training diffusion prior into a class DiffusionPriorTrainer, for eventual script based + CLI based training
+- [x] make sure the cascading ddpm in the repository can be trained unconditionally, offer a one-line CLI tool for training on a folder of images
+- [x] bring in cross-scale embedding from iclr paper https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/crossformer.py#L14
+- [x] cross embed layers for downsampling, as an option
+- [x] use an experimental tracker agnostic setup, as done here
+- [x] use pydantic for config driven training
+- [x] for both diffusion prior and decoder, all exponential moving averaged models need to be saved and restored as well (as well as the step number)
+- [x] offer save / load methods on the trainer classes to automatically take care of state dicts for scalers / optimizers / saving versions and checking for breaking changes
+- [x] allow for creation of diffusion prior model off pydantic config classes - consider the same for tracker configs
+- [x] bring in skip-layer excitations (from lightweight gan paper) to see if it helps for either the decoder unet or vqgan-vae training (doesn't work well)
+- [x] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697 (keeping, seems to be fine)
+- [x] allow for unet to be able to condition non-cross attention style as well
+- [x] speed up inference, read up on papers (ddim)
+- [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
+- [x] add the final combination of upsample feature maps, used in unet squared, seems to have an effect in local experiments
+- [ ] consider elucidated dalle2 https://arxiv.org/abs/2206.00364
+- [ ] add simple outpainting, text-guided 2x size the image for starters
+- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
+
+## Citations
+
+```bibtex
+@misc{ramesh2022,
+ title = {Hierarchical Text-Conditional Image Generation with CLIP Latents},
+ author = {Aditya Ramesh et al},
+ year = {2022}
+}
+```
+
+```bibtex
+@misc{crowson2022,
+ author = {Katherine Crowson},
+ url = {https://twitter.com/rivershavewings}
+}
+```
+
+```bibtex
+@misc{rombach2021highresolution,
+ title = {High-Resolution Image Synthesis with Latent Diffusion Models},
+ author = {Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Björn Ommer},
+ year = {2021},
+ eprint = {2112.10752},
+ archivePrefix = {arXiv},
+ primaryClass = {cs.CV}
+}
+```
+
+```bibtex
+@article{shen2019efficient,
+ author = {Zhuoran Shen and Mingyuan Zhang and Haiyu Zhao and Shuai Yi and Hongsheng Li},
+ title = {Efficient Attention: Attention with Linear Complexities},
+ journal = {CoRR},
+ year = {2018},
+ url = {http://arxiv.org/abs/1812.01243},
+}
+```
+
+```bibtex
+@article{Yu2021VectorquantizedIM,
+ title = {Vector-quantized Image Modeling with Improved VQGAN},
+ author = {Jiahui Yu and Xin Li and Jing Yu Koh and Han Zhang and Ruoming Pang and James Qin and Alexander Ku and Yuanzhong Xu and Jason Baldridge and Yonghui Wu},
+ journal = {ArXiv},
+ year = {2021},
+ volume = {abs/2110.04627}
+}
+```
+
+```bibtex
+@article{Shleifer2021NormFormerIT,
+ title = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
+ author = {Sam Shleifer and Jason Weston and Myle Ott},
+ journal = {ArXiv},
+ year = {2021},
+ volume = {abs/2110.09456}
+}
+```
+
+```bibtex
+@article{Yu2022CoCaCC,
+ title = {CoCa: Contrastive Captioners are Image-Text Foundation Models},
+ author = {Jiahui Yu and Zirui Wang and Vijay Vasudevan and Legg Yeung and Mojtaba Seyedhosseini and Yonghui Wu},
+ journal = {ArXiv},
+ year = {2022},
+ volume = {abs/2205.01917}
+}
+```
+
+```bibtex
+@misc{wang2021crossformer,
+ title = {CrossFormer: A Versatile Vision Transformer Hinging on Cross-scale Attention},
+ author = {Wenxiao Wang and Lu Yao and Long Chen and Binbin Lin and Deng Cai and Xiaofei He and Wei Liu},
+ year = {2021},
+ eprint = {2108.00154},
+ archivePrefix = {arXiv},
+ primaryClass = {cs.CV}
+}
+```
+
+```bibtex
+@article{ho2021cascaded,
+ title = {Cascaded Diffusion Models for High Fidelity Image Generation},
+ author = {Ho, Jonathan and Saharia, Chitwan and Chan, William and Fleet, David J and Norouzi, Mohammad and Salimans, Tim},
+ journal = {arXiv preprint arXiv:2106.15282},
+ year = {2021}
+}
+```
+
+```bibtex
+@misc{Saharia2022,
+ title = {Imagen: unprecedented photorealism × deep level of language understanding},
+ author = {Chitwan Saharia*, William Chan*, Saurabh Saxena†, Lala Li†, Jay Whang†, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S. Sara Mahdavi, Rapha Gontijo Lopes, Tim Salimans, Jonathan Ho†, David Fleet†, Mohammad Norouzi*},
+ year = {2022}
+}
+```
+
+```bibtex
+@article{Choi2022PerceptionPT,
+ title = {Perception Prioritized Training of Diffusion Models},
+ author = {Jooyoung Choi and Jungbeom Lee and Chaehun Shin and Sungwon Kim and Hyunwoo J. Kim and Sung-Hoon Yoon},
+ journal = {ArXiv},
+ year = {2022},
+ volume = {abs/2204.00227}
+}
+```
+
+```bibtex
+@article{Saharia2021PaletteID,
+ title = {Palette: Image-to-Image Diffusion Models},
+ author = {Chitwan Saharia and William Chan and Huiwen Chang and Chris A. Lee and Jonathan Ho and Tim Salimans and David J. Fleet and Mohammad Norouzi},
+ journal = {ArXiv},
+ year = {2021},
+ volume = {abs/2111.05826}
+}
+```
+
+```bibtex
+@article{Lugmayr2022RePaintIU,
+ title = {RePaint: Inpainting using Denoising Diffusion Probabilistic Models},
+ author = {Andreas Lugmayr and Martin Danelljan and Andr{\'e}s Romero and Fisher Yu and Radu Timofte and Luc Van Gool},
+ journal = {ArXiv},
+ year = {2022},
+ volume = {abs/2201.09865}
+}
+```
+
+```bibtex
+@misc{chen2022analog,
+ title = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},
+ author = {Ting Chen and Ruixiang Zhang and Geoffrey Hinton},
+ year = {2022},
+ eprint = {2208.04202},
+ archivePrefix = {arXiv},
+ primaryClass = {cs.CV}
+}
+```
+
+```bibtex
+@article{Qiao2019WeightS,
+ title = {Weight Standardization},
+ author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Loddon Yuille},
+ journal = {ArXiv},
+ year = {2019},
+ volume = {abs/1903.10520}
+}
+```
+
+```bibtex
+@inproceedings{rogozhnikov2022einops,
+ title = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
+ author = {Alex Rogozhnikov},
+ booktitle = {International Conference on Learning Representations},
+ year = {2022},
+ url = {https://openreview.net/forum?id=oapKSVM2bcj}
+}
+```
+
+```bibtex
+@article{Sunkara2022NoMS,
+ title = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects},
+ author = {Raja Sunkara and Tie Luo},
+ journal = {ArXiv},
+ year = {2022},
+ volume = {abs/2208.03641}
+}
+```
+
+```bibtex
+@article{Salimans2022ProgressiveDF,
+ title = {Progressive Distillation for Fast Sampling of Diffusion Models},
+ author = {Tim Salimans and Jonathan Ho},
+ journal = {ArXiv},
+ year = {2022},
+ volume = {abs/2202.00512}
+}
+```
+
+*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's paper
diff --git a/docs/src/configs/README.md b/docs/src/configs/README.md
new file mode 100644
index 00000000..be3cb65d
--- /dev/null
+++ b/docs/src/configs/README.md
@@ -0,0 +1,185 @@
+## DALLE2 Training Configurations
+
+For more complex configuration, we provide the option of using a configuration file instead of command line arguments.
+
+### Decoder Trainer
+
+The decoder trainer has 7 main configuration options. A full example of their use can be found in the [example decoder configuration](train_decoder_config.example.json).
+
+**Unet:**
+
+This is the config for a single unet. Multiple unet configs are nested under the decoder config as a list under the `unets` key.
+
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `dim` | Yes | N/A | The starting channels of the unet. |
+| `image_embed_dim` | Yes | N/A | The dimension of the image embeddings. |
+| `dim_mults` | No | `(1, 2, 4, 8)` | The growth factors of the channels. |
+
+Any parameter from the `Unet` constructor can also be given here.
+
+**Decoder:**
+
+Defines the configuration options for the decoder model. The unets defined above will automatically be inserted.
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `unets` | Yes | N/A | A list of unets, using the configuration above |
+| `image_sizes` | Yes | N/A | The resolution of the image after each upsampling step. The length of this array should be the number of unets defined. |
+| `image_size` | Yes | N/A | Not used. Can be any number. |
+| `timesteps` | No | `1000` | The number of diffusion timesteps used for generation. |
+| `loss_type` | No | `l2` | The loss function. Options are `l1`, `huber`, or `l2`. |
+| `beta_schedule` | No | `cosine` | The noising schedule. Options are `cosine`, `linear`, `quadratic`, `jsd`, or `sigmoid`. |
+| `learned_variance` | No | `True` | Whether to learn the variance. |
+| `clip` | No | `None` | The clip model to use if embeddings are being generated on the fly. Takes keys `make` and `model` with defaults `openai` and `ViT-L/14`. |
+
+Any parameter from the `Decoder` constructor can also be given here.
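+
+As a rough sketch of how the unet and decoder options nest (the values below are placeholders, not recommendations), the relevant fragment of a config could look like the following, written out here as a Python dict that would be serialized to JSON:
+
+```python
+decoder_config = {
+    "unets": [
+        {"dim": 128, "image_embed_dim": 512, "dim_mults": [1, 2, 4, 8]},  # base unet at the lowest resolution
+        {"dim": 64, "image_embed_dim": 512, "dim_mults": [1, 2, 4, 8]}    # upsampler unet
+    ],
+    "image_sizes": [64, 256],  # one resolution per unet
+    "image_size": 256,         # required by the schema but not used
+    "timesteps": 1000,
+    "loss_type": "l2",
+    "beta_schedule": "cosine",
+    "learned_variance": True
+}
+```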
+
+**Data:**
+
+Settings for creation of the dataloaders.
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `webdataset_base_url` | Yes | N/A | The url of a shard in the webdataset, with the shard number replaced by `{}`[^1]. |
+| `img_embeddings_url` | No | `None` | The url of the folder containing image embeddings shards. Not required if embeddings are in webdataset or clip is being used. |
+| `text_embeddings_url` | No | `None` | The url of the folder containing text embeddings shards. Not required if embeddings are in webdataset or clip is being used. |
+| `num_workers` | No | `4` | The number of workers used in the dataloader. |
+| `batch_size` | No | `64` | The batch size. |
+| `start_shard` | No | `0` | Defines the start of the shard range the dataset will read. |
+| `end_shard` | No | `9999999` | Defines the end of the shard range the dataset will read. |
+| `shard_width` | No | `6` | Defines the width of one webdataset shard number[^2]. |
+| `index_width` | No | `4` | Defines the width of the index of a file inside a shard[^3]. |
+| `splits` | No | `{ "train": 0.75, "val": 0.15, "test": 0.1 }` | Defines the proportion of shards that will be allocated to the training, validation, and testing datasets. |
+| `shuffle_train` | No | `True` | Whether to shuffle the shards of the training dataset. |
+| `resample_train` | No | `False` | If true, shards will be randomly sampled with replacement from the datasets making the epoch length infinite if a limit is not set. Cannot be enabled if `shuffle_train` is enabled. |
+| `preprocessing` | No | `{ "ToTensor": True }` | Defines preprocessing applied to images from the datasets. |
+
+[^1]: If your shard files have the paths `protocol://path/to/shard/00104.tar`, then the base url would be `protocol://path/to/shard/{}.tar`. If you are using a protocol like `s3`, you need to pipe the tars. For example `pipe:s3cmd get s3://bucket/path/{}.tar -`.
+
+[^2]: This refers to the string length of the shard number for your webdataset shards. For instance, if your webdataset shard has the filename `00104.tar`, your shard length is 5.
+
+[^3]: Inside the webdataset `tar`, you have files named something like `001045945.jpg`. 5 of these characters refer to the shard, and 4 refer to the index of the file in the webdataset (shard is `00104` and index is `5945`). The `index_width` in this case is 4.
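+
+Putting the footnotes above together, a data section could be shaped like this (the urls and widths are placeholders for illustration only):
+
+```python
+data_config = {
+    "webdataset_base_url": "pipe:s3cmd get s3://bucket/path/{}.tar -",  # shard number is substituted into {}
+    "img_embeddings_url": "s3://bucket/path/img_embeddings/",           # omit if embeddings live inside the webdataset
+    "num_workers": 4,
+    "batch_size": 64,
+    "start_shard": 0,
+    "end_shard": 9999,
+    "shard_width": 5,  # shards named like 00104.tar
+    "index_width": 4,  # files named like 001045945.jpg -> index 5945
+    "splits": {"train": 0.75, "val": 0.15, "test": 0.1}
+}
+```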
+
+**Train:**
+
+Settings for controlling the training hyperparameters.
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `epochs` | No | `20` | The number of epochs in the training run. |
+| `lr` | No | `1e-4` | The learning rate. |
+| `wd` | No | `0.01` | The weight decay. |
+| `max_grad_norm`| No | `0.5` | The grad norm clipping. |
+| `save_every_n_samples` | No | `100000` | Samples will be generated and a checkpoint will be saved every `save_every_n_samples` samples. |
+| `cond_scale` | No | `1.0` | Conditioning scale to use for sampling. Can also be an array of values, one for each unet. |
+| `device` | No | `cuda:0` | The device to train on. |
+| `epoch_samples` | No | `None` | Limits the number of samples iterated through in each epoch. This must be set if resampling. None means no limit. |
+| `validation_samples` | No | `None` | The number of samples to use for validation. None means the entire validation set. |
+| `use_ema` | No | `True` | Whether to use exponential moving average models for sampling. |
+| `ema_beta` | No | `0.99` | The ema coefficient. |
+| `unet_training_mask` | No | `None` | A boolean array of the same length as the number of unets. If an entry is false, the corresponding unet is frozen. A value of `None` trains all unets. |
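+
+A train section might look like this (values are placeholders; note that `cond_scale` and `unet_training_mask` can be given per unet):
+
+```python
+train_config = {
+    "epochs": 20,
+    "lr": 1e-4,
+    "wd": 0.01,
+    "max_grad_norm": 0.5,
+    "save_every_n_samples": 100000,
+    "cond_scale": [1.0, 1.0],            # one value per unet, or a single float
+    "device": "cuda:0",
+    "use_ema": True,
+    "ema_beta": 0.99,
+    "unet_training_mask": [True, False]  # train the first unet, freeze the second
+}
+```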
+
+**Evaluate:**
+
+Defines which evaluation metrics will be used to test the model.
+Each metric can be enabled by setting its configuration. The configuration keys for each metric are defined by the linked torchmetrics constructors. A short example follows the table.
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `n_evaluation_samples` | No | `1000` | The number of samples to generate to test the model. |
+| `FID` | No | `None` | Setting to an object enables the [Frechet Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/frechet_inception_distance.html) metric. |
+| `IS` | No | `None` | Setting to an object enables the [Inception Score](https://torchmetrics.readthedocs.io/en/stable/image/inception_score.html) metric. |
+| `KID` | No | `None` | Setting to an object enables the [Kernel Inception Distance](https://torchmetrics.readthedocs.io/en/stable/image/kernel_inception_distance.html) metric. |
+| `LPIPS` | No | `None` | Setting to an object enables the [Learned Perceptual Image Patch Similarity](https://torchmetrics.readthedocs.io/en/stable/image/learned_perceptual_image_patch_similarity.html) metric. |
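+
+For example, enabling a metric is just a matter of giving it an object; an empty object uses the defaults of the linked torchmetrics constructor, and any of that constructor's arguments may be passed as keys (a sketch with placeholder values):
+
+```python
+evaluate_config = {
+    "n_evaluation_samples": 1000,
+    "FID": {},    # enabled with torchmetrics defaults
+    "LPIPS": {}   # likewise; metrics left out entirely stay disabled
+}
+```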
+
+**Tracker:**
+
+Selects how the experiment will be tracked.
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `data_path` | No | `./.tracker-data` | The path to the folder where temporary tracker data will be saved. |
+| `overwrite_data_path` | No | `False` | If true, the data path will be overwritten. Otherwise, you need to delete it yourself. |
+| `log` | Yes | N/A | Logging configuration. |
+| `load` | No | `None` | Checkpoint loading configuration. |
+| `save` | Yes | N/A | Checkpoint/Model saving configuration. |
+
+Tracking is split up into three sections, put together in the sketch after this list:
+* Log: Where to save run metadata and image output. Options are `console` or `wandb`.
+* Load: Where to load a checkpoint from. Options are `local`, `url`, or `wandb`.
+* Save: Where to save a checkpoint to. Options are `local`, `huggingface`, or `wandb`.
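+
+Putting the three pieces together, a tracker section might be shaped as follows (the entity, project, and paths are placeholders; see the tables below for the full key listings):
+
+```python
+tracker_config = {
+    "data_path": "./.tracker-data",
+    "log": {"log_type": "wandb", "wandb_entity": "my-entity", "wandb_project": "dalle2-decoder"},
+    "load": {"load_from": "local", "file_path": "./checkpoints/latest.pth"},
+    "save": [
+        {"save_to": "local", "save_latest_to": "./checkpoints/latest.pth"},  # save may be a list of locations
+        {"save_to": "wandb"}
+    ]
+}
+```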
+
+**Logging:**
+
+All loggers have the following keys:
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `log_type` | Yes | N/A | The type of logger class to use. |
+| `resume` | No | `False` | For loggers that have the option to resume an old run, resume it using manually input parameters. |
+| `auto_resume` | No | `False` | If true, the logger will attempt to resume an old run using parameters from that previous run. |
+
+If using `console`, there is no further configuration beyond setting `log_type` to `console`.
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `log_type` | Yes | N/A | Must be `console`. |
+
+If using `wandb`
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `log_type` | Yes | N/A | Must be `wandb`. |
+| `wandb_entity` | Yes | N/A | The wandb entity to log to. |
+| `wandb_project` | Yes | N/A | The wandb project to save the run to. |
+| `wandb_run_name` | No | `None` | The wandb run name. |
+| `wandb_run_id` | No | `None` | The wandb run id. Used if resuming an old run. |
+
+**Loading:**
+
+All loaders have the following keys:
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `load_from` | Yes | N/A | The type of loader class to use. |
+| `only_auto_resume` | No | `False` | If true, the loader will only load the model if the run is being auto resumed. |
+
+If using `local`
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `load_from` | Yes | N/A | Must be `local`. |
+| `file_path` | Yes | N/A | The path to the checkpoint file. |
+
+If using `url`
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `load_from` | Yes | N/A | Must be `url`. |
+| `url` | Yes | N/A | The url of the checkpoint file. |
+
+If using `wandb`
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `load_from` | Yes | N/A | Must be `wandb`. |
+| `wandb_run_path` | No | `None` | The wandb run path. If `None`, uses the run that is being resumed. |
+| `wandb_file_path` | Yes | N/A | The path to the checkpoint file in the W&B file system. |
+
+**Saving:**
+
+Unlike `log` and `load`, `save` may be an array of options so that you can save to different locations in a run.
+
+All save locations have these configuration options:
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `save_to` | Yes | N/A | Must be `local`, `huggingface`, or `wandb`. |
+| `save_latest_to` | No | `None` | Sets the relative path to save the latest model to. |
+| `save_best_to` | No | `None` | Sets the relative path to save the best model to every time the model has a lower validation loss than all previous models. |
+| `save_meta_to` | No | `None` | The path to save metadata files in. This includes the config files used to start the training. |
+| `save_type` | No | `checkpoint` | The type of save. `checkpoint` saves a checkpoint, `model` saves a model without any fluff (saves with EMA if EMA is enabled). |
+
+If using `local`
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `save_to` | Yes | N/A | Must be `local`. |
+
+If using `huggingface`
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `save_to` | Yes | N/A | Must be `huggingface`. |
+| `huggingface_repo` | Yes | N/A | The huggingface repository to save to. |
+| `token_path` | No | `None` | If logging in with the huggingface cli is not possible, point to a token file instead. |
+
+If using `wandb`
+| Option | Required | Default | Description |
+| ------ | -------- | ------- | ----------- |
+| `save_to` | Yes | N/A | Must be `wandb`. |
+| `wandb_run_path` | No | `None` | The wandb run path. If `None`, uses the current run. You will almost always want this to be `None`. |
diff --git a/docs/src/dalle2_pytorch/__init__.py b/docs/src/dalle2_pytorch/__init__.py
new file mode 100644
index 00000000..53ebb340
--- /dev/null
+++ b/docs/src/dalle2_pytorch/__init__.py
@@ -0,0 +1,7 @@
+from dalle2_pytorch.version import __version__
+from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
+from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter, OpenClipAdapter
+from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
+
+from dalle2_pytorch.vqgan_vae import VQGanVAE
+from x_clip import CLIP
diff --git a/docs/src/dalle2_pytorch/cli.py b/docs/src/dalle2_pytorch/cli.py
new file mode 100644
index 00000000..a2a66504
--- /dev/null
+++ b/docs/src/dalle2_pytorch/cli.py
@@ -0,0 +1,52 @@
+import click
+import torch
+import torchvision.transforms as T
+from functools import reduce
+from pathlib import Path
+
+from dalle2_pytorch import DALLE2, Decoder, DiffusionPrior
+
+def safeget(dictionary, keys, default = None):
+ return reduce(lambda d, key: d.get(key, default) if isinstance(d, dict) else default, keys.split('.'), dictionary)
+
+def simple_slugify(text, max_length = 255):
+ return text.replace("-", "_").replace(",", "").replace(" ", "_").replace("|", "--").strip('-_')[:max_length]
+
+def get_pkg_version():
+ from pkg_resources import get_distribution
+ return get_distribution('dalle2_pytorch').version
+
+def main():
+ pass
+
+@click.command()
+@click.option('--model', default = './dalle2.pt', help = 'path to trained DALL-E2 model')
+@click.option('--cond_scale', default = 2, help = 'conditioning scale (classifier free guidance) in decoder')
+@click.argument('text')
+def dream(
+ model,
+ cond_scale,
+ text
+):
+ model_path = Path(model)
+ full_model_path = str(model_path.resolve())
+ assert model_path.exists(), f'model not found at {full_model_path}'
+ loaded = torch.load(str(model_path))
+
+ version = safeget(loaded, 'version')
+ print(f'loading DALL-E2 from {full_model_path}, saved at version {version} - current package version is {get_pkg_version()}')
+
+ prior_init_params = safeget(loaded, 'init_params.prior')
+ decoder_init_params = safeget(loaded, 'init_params.decoder')
+ model_params = safeget(loaded, 'model_params')
+
+ prior = DiffusionPrior(**prior_init_params)
+ decoder = Decoder(**decoder_init_params)
+
+ dalle2 = DALLE2(prior, decoder)
+ dalle2.load_state_dict(model_params)
+
+ image = dalle2(text, cond_scale = cond_scale)
+
+ pil_image = T.ToPILImage()(image)
+ return pil_image.save(f'./{simple_slugify(text)}.png')
diff --git a/docs/src/dalle2_pytorch/dalle2_pytorch.py b/docs/src/dalle2_pytorch/dalle2_pytorch.py
new file mode 100644
index 00000000..71a6e4c4
--- /dev/null
+++ b/docs/src/dalle2_pytorch/dalle2_pytorch.py
@@ -0,0 +1,3340 @@
+import math
+import random
+from tqdm.auto import tqdm
+from functools import partial, wraps
+from contextlib import contextmanager
+from collections import namedtuple
+from pathlib import Path
+
+import torch
+import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint
+from torch import nn, einsum
+import torchvision.transforms as T
+
+from einops import rearrange, repeat, reduce, pack, unpack
+from einops.layers.torch import Rearrange
+
+from kornia.filters import gaussian_blur2d
+import kornia.augmentation as K
+
+from dalle2_pytorch.tokenizer import tokenizer
+from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
+
+from resize_right import resize
+
+# rotary embeddings
+
+from rotary_embedding_torch import RotaryEmbedding
+
+# use x-clip
+
+from x_clip import CLIP
+from coca_pytorch import CoCa
+
+# constants
+
+NAT = 1. / math.log(2.)
+
+UnetOutput = namedtuple('UnetOutput', ['pred', 'var_interp_frac_unnormalized'])
+
+# helper functions
+
+def exists(val):
+ return val is not None
+
+def identity(t, *args, **kwargs):
+ return t
+
+def first(arr, d = None):
+ if len(arr) == 0:
+ return d
+ return arr[0]
+
+def maybe(fn):
+ @wraps(fn)
+ def inner(x, *args, **kwargs):
+ if not exists(x):
+ return x
+ return fn(x, *args, **kwargs)
+ return inner
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if callable(d) else d
+
+def cast_tuple(val, length = None, validate = True):
+ if isinstance(val, list):
+ val = tuple(val)
+
+ out = val if isinstance(val, tuple) else ((val,) * default(length, 1))
+
+ if exists(length) and validate:
+ assert len(out) == length
+
+ return out
+
+def module_device(module):
+ if isinstance(module, nn.Identity):
+ return 'cpu' # It doesn't matter
+ return next(module.parameters()).device
+
+def zero_init_(m):
+ nn.init.zeros_(m.weight)
+ if exists(m.bias):
+ nn.init.zeros_(m.bias)
+
+@contextmanager
+def null_context(*args, **kwargs):
+ yield
+
+def eval_decorator(fn):
+ def inner(model, *args, **kwargs):
+ was_training = model.training
+ model.eval()
+ out = fn(model, *args, **kwargs)
+ model.train(was_training)
+ return out
+ return inner
+
+def is_float_dtype(dtype):
+ return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])
+
+def is_list_str(x):
+ if not isinstance(x, (list, tuple)):
+ return False
+ return all([type(el) == str for el in x])
+
+def pad_tuple_to_length(t, length, fillvalue = None):
+ remain_length = length - len(t)
+ if remain_length <= 0:
+ return t
+ return (*t, *((fillvalue,) * remain_length))
+
+# checkpointing helper function
+
+def make_checkpointable(fn, **kwargs):
+ if isinstance(fn, nn.ModuleList):
+ return [maybe(make_checkpointable)(el, **kwargs) for el in fn]
+
+ condition = kwargs.pop('condition', None)
+
+ if exists(condition) and not condition(fn):
+ return fn
+
+ @wraps(fn)
+ def inner(*args):
+ input_needs_grad = any([isinstance(el, torch.Tensor) and el.requires_grad for el in args])
+
+ if not input_needs_grad:
+ return fn(*args)
+
+ return checkpoint(fn, *args)
+
+ return inner
+
+# for controlling freezing of CLIP
+
+def set_module_requires_grad_(module, requires_grad):
+ for param in module.parameters():
+ param.requires_grad = requires_grad
+
+def freeze_all_layers_(module):
+ set_module_requires_grad_(module, False)
+
+def unfreeze_all_layers_(module):
+ set_module_requires_grad_(module, True)
+
+def freeze_model_and_make_eval_(model):
+ model.eval()
+ freeze_all_layers_(model)
+
+# tensor helpers
+
+def log(t, eps = 1e-12):
+ return torch.log(t.clamp(min = eps))
+
+def l2norm(t):
+ return F.normalize(t, dim = -1)
+
+def resize_image_to(
+ image,
+ target_image_size,
+ clamp_range = None,
+ nearest = False,
+ **kwargs
+):
+ orig_image_size = image.shape[-1]
+
+ if orig_image_size == target_image_size:
+ return image
+
+ if not nearest:
+ scale_factors = target_image_size / orig_image_size
+ out = resize(image, scale_factors = scale_factors, **kwargs)
+ else:
+ out = F.interpolate(image, target_image_size, mode = 'nearest')
+
+ if exists(clamp_range):
+ out = out.clamp(*clamp_range)
+
+ return out
+
+# image normalization functions
+# ddpms expect images to be in the range of -1 to 1
+# but CLIP may expect images in a different range
+
+def normalize_neg_one_to_one(img):
+ return img * 2 - 1
+
+def unnormalize_zero_to_one(normed_img):
+ return (normed_img + 1) * 0.5
+
+# clip related adapters
+
+EmbeddedText = namedtuple('EmbedTextReturn', ['text_embed', 'text_encodings'])
+EmbeddedImage = namedtuple('EmbedImageReturn', ['image_embed', 'image_encodings'])
+
+class BaseClipAdapter(nn.Module):
+ def __init__(self, clip, **kwargs):
+ super().__init__()
+ self.clip = clip
+ self.overrides = kwargs
+
+ def validate_and_resize_image(self, image):
+ image_size = image.shape[-1]
+ assert image_size >= self.image_size, f'you are passing in an image of size {image_size} but CLIP requires the image size to be at least {self.image_size}'
+ return resize_image_to(image, self.image_size)
+
+ @property
+ def dim_latent(self):
+ raise NotImplementedError
+
+ @property
+ def image_size(self):
+ raise NotImplementedError
+
+ @property
+ def image_channels(self):
+ raise NotImplementedError
+
+ @property
+ def max_text_len(self):
+ raise NotImplementedError
+
+ def embed_text(self, text):
+ raise NotImplementedError
+
+ def embed_image(self, image):
+ raise NotImplementedError
+
+class XClipAdapter(BaseClipAdapter):
+ @property
+ def dim_latent(self):
+ return self.clip.dim_latent
+
+ @property
+ def image_size(self):
+ return self.clip.image_size
+
+ @property
+ def image_channels(self):
+ return self.clip.image_channels
+
+ @property
+ def max_text_len(self):
+ return self.clip.text_seq_len
+
+ @torch.no_grad()
+ def embed_text(self, text):
+ text = text[..., :self.max_text_len]
+ text_mask = text != 0
+ encoder_output = self.clip.text_transformer(text)
+
+ encoder_output_is_cls = encoder_output.ndim == 3
+
+ text_cls, text_encodings = (encoder_output[:, 0], encoder_output[:, 1:]) if encoder_output_is_cls else (encoder_output, None)
+ text_embed = self.clip.to_text_latent(text_cls)
+
+ if exists(text_encodings):
+ text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
+
+ return EmbeddedText(l2norm(text_embed), text_encodings)
+
+ @torch.no_grad()
+ def embed_image(self, image):
+ image = self.validate_and_resize_image(image)
+ encoder_output = self.clip.visual_transformer(image)
+ image_cls, image_encodings = encoder_output[:, 0], encoder_output[:, 1:]
+ image_embed = self.clip.to_visual_latent(image_cls)
+ return EmbeddedImage(l2norm(image_embed), image_encodings)
+
+class CoCaAdapter(BaseClipAdapter):
+ @property
+ def dim_latent(self):
+ return self.clip.dim
+
+ @property
+ def image_size(self):
+ assert 'image_size' in self.overrides
+ return self.overrides['image_size']
+
+ @property
+ def image_channels(self):
+ assert 'image_channels' in self.overrides
+ return self.overrides['image_channels']
+
+ @property
+ def max_text_len(self):
+ assert 'max_text_len' in self.overrides
+ return self.overrides['max_text_len']
+
+ @torch.no_grad()
+ def embed_text(self, text):
+ text = text[..., :self.max_text_len]
+ text_mask = text != 0
+ text_embed, text_encodings = self.clip.embed_text(text)
+ text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
+ return EmbeddedText(text_embed, text_encodings)
+
+ @torch.no_grad()
+ def embed_image(self, image):
+ image = self.validate_and_resize_image(image)
+ image_embed, image_encodings = self.clip.embed_image(image)
+ return EmbeddedImage(image_embed, image_encodings)
+
+class OpenAIClipAdapter(BaseClipAdapter):
+ def __init__(
+ self,
+ name = 'ViT-B/32'
+ ):
+ import clip
+ openai_clip, preprocess = clip.load(name)
+ super().__init__(openai_clip)
+ self.eos_id = 49407 # for handling 0 being also '!'
+
+ text_attention_final = self.find_layer('ln_final')
+
+ self.dim_latent_ = text_attention_final.weight.shape[0]
+ self.handle = text_attention_final.register_forward_hook(self._hook)
+
+ self.clip_normalize = preprocess.transforms[-1]
+ self.cleared = False
+
+ def find_layer(self, layer):
+ modules = dict([*self.clip.named_modules()])
+ return modules.get(layer, None)
+
+ def clear(self):
+ if self.cleared:
+ return
+
+ self.handle()
+
+ def _hook(self, _, inputs, outputs):
+ self.text_encodings = outputs
+
+ @property
+ def dim_latent(self):
+ return self.dim_latent_
+
+ @property
+ def image_size(self):
+ return self.clip.visual.input_resolution
+
+ @property
+ def image_channels(self):
+ return 3
+
+ @property
+ def max_text_len(self):
+ return self.clip.context_length
+
+ @torch.no_grad()
+ def embed_text(self, text):
+ text = text[..., :self.max_text_len]
+
+ is_eos_id = (text == self.eos_id)
+ text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
+ text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
+ text_mask = text_mask & (text != 0)
+ assert not self.cleared
+
+ text_embed = self.clip.encode_text(text)
+ text_encodings = self.text_encodings
+ text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
+ del self.text_encodings
+ return EmbeddedText(l2norm(text_embed.float()), text_encodings.float())
+
+ @torch.no_grad()
+ def embed_image(self, image):
+ assert not self.cleared
+ image = self.validate_and_resize_image(image)
+ image = self.clip_normalize(image)
+ image_embed = self.clip.encode_image(image)
+ return EmbeddedImage(l2norm(image_embed.float()), None)
+
+class OpenClipAdapter(BaseClipAdapter):
+ def __init__(
+ self,
+ name = 'ViT-B/32',
+ pretrained = 'laion400m_e32'
+ ):
+ import open_clip
+ clip, _, preprocess = open_clip.create_model_and_transforms(name, pretrained = pretrained)
+
+ super().__init__(clip)
+ self.eos_id = 49407
+
+ text_attention_final = self.find_layer('ln_final')
+ self._dim_latent = text_attention_final.weight.shape[0]
+
+ self.handle = text_attention_final.register_forward_hook(self._hook)
+ self.clip_normalize = preprocess.transforms[-1]
+ self.cleared = False
+
+ def find_layer(self, layer):
+ modules = dict([*self.clip.named_modules()])
+ return modules.get(layer, None)
+
+ def clear(self):
+ if self.cleared:
+ return
+
+ self.handle()
+
+ def _hook(self, _, inputs, outputs):
+ self.text_encodings = outputs
+
+ @property
+ def dim_latent(self):
+ return self._dim_latent
+
+ @property
+ def image_size(self):
+ image_size = self.clip.visual.image_size
+ if isinstance(image_size, tuple):
+ return max(image_size)
+ return image_size
+
+ @property
+ def image_channels(self):
+ return 3
+
+ @property
+ def max_text_len(self):
+ return self.clip.context_length
+
+ @torch.no_grad()
+ def embed_text(self, text):
+ text = text[..., :self.max_text_len]
+
+ is_eos_id = (text == self.eos_id)
+ text_mask_excluding_eos = is_eos_id.cumsum(dim = -1) == 0
+ text_mask = F.pad(text_mask_excluding_eos, (1, -1), value = True)
+ text_mask = text_mask & (text != 0)
+ assert not self.cleared
+
+ text_embed = self.clip.encode_text(text)
+ text_encodings = self.text_encodings
+ text_encodings = text_encodings.masked_fill(~text_mask[..., None], 0.)
+ del self.text_encodings
+ return EmbeddedText(l2norm(text_embed.float()), text_encodings.float())
+
+ @torch.no_grad()
+ def embed_image(self, image):
+ assert not self.cleared
+ image = self.validate_and_resize_image(image)
+ image = self.clip_normalize(image)
+ image_embed = self.clip.encode_image(image)
+ return EmbeddedImage(l2norm(image_embed.float()), None)
+
+# classifier free guidance functions
+
+def prob_mask_like(shape, prob, device):
+ if prob == 1:
+ return torch.ones(shape, device = device, dtype = torch.bool)
+ elif prob == 0:
+ return torch.zeros(shape, device = device, dtype = torch.bool)
+ else:
+ return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
+
+# gaussian diffusion helper functions
+
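+# `extract` gathers the coefficient at each sample's timestep t from a 1d buffer `a`,
+# then reshapes it to (b, 1, 1, ...) so it broadcasts across the remaining dims of x_shape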
+def extract(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+def meanflat(x):
+ return x.mean(dim = tuple(range(1, len(x.shape))))
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2))
+
+def approx_standard_normal_cdf(x):
+ return 0.5 * (1.0 + torch.tanh(((2.0 / math.pi) ** 0.5) * (x + 0.044715 * (x ** 3))))
+
+def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
+ assert x.shape == means.shape == log_scales.shape
+
+ # attempting to correct nan gradients when learned variance is turned on
+ # in the setting of deepspeed fp16
+ eps = 1e-12 if x.dtype == torch.float32 else 1e-3
+
+ centered_x = x - means
+ inv_stdv = torch.exp(-log_scales)
+ plus_in = inv_stdv * (centered_x + 1. / 255.)
+ cdf_plus = approx_standard_normal_cdf(plus_in)
+ min_in = inv_stdv * (centered_x - 1. / 255.)
+ cdf_min = approx_standard_normal_cdf(min_in)
+ log_cdf_plus = log(cdf_plus, eps = eps)
+ log_one_minus_cdf_min = log(1. - cdf_min, eps = eps)
+ cdf_delta = cdf_plus - cdf_min
+
+ log_probs = torch.where(x < -thres,
+ log_cdf_plus,
+ torch.where(x > thres,
+ log_one_minus_cdf_min,
+ log(cdf_delta, eps = eps)))
+
+ return log_probs
+
+def cosine_beta_schedule(timesteps, s = 0.008):
+ """
+ cosine schedule
+ as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
+ """
+ steps = timesteps + 1
+ x = torch.linspace(0, timesteps, steps, dtype = torch.float64)
+ alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
+ alphas_cumprod = alphas_cumprod / first(alphas_cumprod)
+ betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
+ return torch.clip(betas, 0, 0.999)
+
+
+def linear_beta_schedule(timesteps):
+ scale = 1000 / timesteps
+ beta_start = scale * 0.0001
+ beta_end = scale * 0.02
+ return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)
+
+
+def quadratic_beta_schedule(timesteps):
+ scale = 1000 / timesteps
+ beta_start = scale * 0.0001
+ beta_end = scale * 0.02
+ return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype = torch.float64) ** 2
+
+
+def sigmoid_beta_schedule(timesteps):
+ scale = 1000 / timesteps
+ beta_start = scale * 0.0001
+ beta_end = scale * 0.02
+ betas = torch.linspace(-6, 6, timesteps, dtype = torch.float64)
+ return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
+
+
+class NoiseScheduler(nn.Module):
+ def __init__(self, *, beta_schedule, timesteps, loss_type, p2_loss_weight_gamma = 0., p2_loss_weight_k = 1):
+ super().__init__()
+
+ if beta_schedule == "cosine":
+ betas = cosine_beta_schedule(timesteps)
+ elif beta_schedule == "linear":
+ betas = linear_beta_schedule(timesteps)
+ elif beta_schedule == "quadratic":
+ betas = quadratic_beta_schedule(timesteps)
+ elif beta_schedule == "jsd":
+ betas = 1.0 / torch.linspace(timesteps, 1, timesteps)
+ elif beta_schedule == "sigmoid":
+ betas = sigmoid_beta_schedule(timesteps)
+ else:
+ raise NotImplementedError()
+
+ alphas = 1. - betas
+ alphas_cumprod = torch.cumprod(alphas, axis = 0)
+ alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
+
+ timesteps, = betas.shape
+ self.num_timesteps = int(timesteps)
+
+ if loss_type == 'l1':
+ loss_fn = F.l1_loss
+ elif loss_type == 'l2':
+ loss_fn = F.mse_loss
+ elif loss_type == 'huber':
+ loss_fn = F.smooth_l1_loss
+ else:
+ raise NotImplementedError()
+
+ self.loss_type = loss_type
+ self.loss_fn = loss_fn
+
+ # register buffer helper function to cast double back to float
+
+ register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
+
+ register_buffer('betas', betas)
+ register_buffer('alphas_cumprod', alphas_cumprod)
+ register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+
+ register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
+ register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
+ register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
+ register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
+ register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+
+ posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
+
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+
+ register_buffer('posterior_variance', posterior_variance)
+
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+
+ register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min = 1e-20)))
+ register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
+ register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
+
+ # p2 loss reweighting
+
+ self.has_p2_loss_reweighting = p2_loss_weight_gamma > 0.
+ register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma)
+
+ def sample_random_times(self, batch):
+ return torch.randint(0, self.num_timesteps, (batch,), device = self.betas.device, dtype = torch.long)
+
+ def q_posterior(self, x_start, x_t, t):
+ posterior_mean = (
+ extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
+ extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ )
+ posterior_variance = extract(self.posterior_variance, t, x_t.shape)
+ posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def q_sample(self, x_start, t, noise = None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+
+ return (
+ extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+ extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
+ )
+
+ def calculate_v(self, x_start, t, noise = None):
+ return (
+ extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
+ extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
+ )
+
+ def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
+ shape = x_from.shape
+ noise = default(noise, lambda: torch.randn_like(x_from))
+
+ alpha = extract(self.sqrt_alphas_cumprod, from_t, shape)
+ sigma = extract(self.sqrt_one_minus_alphas_cumprod, from_t, shape)
+ alpha_next = extract(self.sqrt_alphas_cumprod, to_t, shape)
+ sigma_next = extract(self.sqrt_one_minus_alphas_cumprod, to_t, shape)
+
+ return x_from * (alpha_next / alpha) + noise * (sigma_next * alpha - sigma * alpha_next) / alpha
+
+ def predict_start_from_v(self, x_t, t, v):
+ return (
+ extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
+ extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
+ )
+
+ def predict_start_from_noise(self, x_t, t, noise):
+ return (
+ extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
+ extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
+ )
+
+ def predict_noise_from_start(self, x_t, t, x0):
+ return (
+ (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
+ extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+ )
+
+ def p2_reweigh_loss(self, loss, times):
+ if not self.has_p2_loss_reweighting:
+ return loss
+ return loss * extract(self.p2_loss_weight, times, loss.shape)
+
+# rearrange image to sequence
+
+class RearrangeToSequence(nn.Module):
+ def __init__(self, fn):
+ super().__init__()
+ self.fn = fn
+
+ def forward(self, x):
+ x = rearrange(x, 'b c ... -> b ... c')
+ x, ps = pack([x], 'b * c')
+
+ x = self.fn(x)
+
+ x, = unpack(x, ps, 'b * c')
+ x = rearrange(x, 'b ... c -> b c ...')
+ return x
+
+# diffusion prior
+
+class LayerNorm(nn.Module):
+ def __init__(self, dim, eps = 1e-5, fp16_eps = 1e-3, stable = False):
+ super().__init__()
+ self.eps = eps
+ self.fp16_eps = fp16_eps
+ self.stable = stable
+ self.g = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x):
+ eps = self.eps if x.dtype == torch.float32 else self.fp16_eps
+
+ if self.stable:
+ x = x / x.amax(dim = -1, keepdim = True).detach()
+
+ var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
+ mean = torch.mean(x, dim = -1, keepdim = True)
+ return (x - mean) * (var + eps).rsqrt() * self.g
+
+class ChanLayerNorm(nn.Module):
+ def __init__(self, dim, eps = 1e-5, fp16_eps = 1e-3, stable = False):
+ super().__init__()
+ self.eps = eps
+ self.fp16_eps = fp16_eps
+ self.stable = stable
+ self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
+
+ def forward(self, x):
+ eps = self.eps if x.dtype == torch.float32 else self.fp16_eps
+
+ if self.stable:
+ x = x / x.amax(dim = 1, keepdim = True).detach()
+
+ var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
+ mean = torch.mean(x, dim = 1, keepdim = True)
+ return (x - mean) * (var + eps).rsqrt() * self.g
+
+class Residual(nn.Module):
+ def __init__(self, fn):
+ super().__init__()
+ self.fn = fn
+
+ def forward(self, x, **kwargs):
+ return self.fn(x, **kwargs) + x
+
+# mlp
+
+class MLP(nn.Module):
+ def __init__(
+ self,
+ dim_in,
+ dim_out,
+ *,
+ expansion_factor = 2.,
+ depth = 2,
+ norm = False,
+ ):
+ super().__init__()
+ hidden_dim = int(expansion_factor * dim_out)
+ norm_fn = lambda: nn.LayerNorm(hidden_dim) if norm else nn.Identity()
+
+ layers = [nn.Sequential(
+ nn.Linear(dim_in, hidden_dim),
+ nn.SiLU(),
+ norm_fn()
+ )]
+
+ for _ in range(depth - 1):
+ layers.append(nn.Sequential(
+ nn.Linear(hidden_dim, hidden_dim),
+ nn.SiLU(),
+ norm_fn()
+ ))
+
+ layers.append(nn.Linear(hidden_dim, dim_out))
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, x):
+ return self.net(x.float())
+
+# relative positional bias for causal transformer
+
+class RelPosBias(nn.Module):
+ def __init__(
+ self,
+ heads = 8,
+ num_buckets = 32,
+ max_distance = 128,
+ ):
+ super().__init__()
+ self.num_buckets = num_buckets
+ self.max_distance = max_distance
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
+
+ @staticmethod
+ def _relative_position_bucket(
+ relative_position,
+ num_buckets = 32,
+ max_distance = 128
+ ):
+ n = -relative_position
+ n = torch.max(n, torch.zeros_like(n))
+
+ max_exact = num_buckets // 2
+ is_small = n < max_exact
+
+ val_if_large = max_exact + (torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)).long()
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
+ return torch.where(is_small, n, val_if_large)
+
+ def forward(self, i, j, *, device):
+ q_pos = torch.arange(i, dtype = torch.long, device = device)
+ k_pos = torch.arange(j, dtype = torch.long, device = device)
+ rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
+ rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
+ values = self.relative_attention_bias(rp_bucket)
+ return rearrange(values, 'i j h -> h i j')
+
+# feedforward
+
+class SwiGLU(nn.Module):
+ """ used successfully in https://arxiv.org/abs/2204.0231 """
+ def forward(self, x):
+ x, gate = x.chunk(2, dim = -1)
+ return x * F.silu(gate)
+
+def FeedForward(
+ dim,
+ mult = 4,
+ dropout = 0.,
+ post_activation_norm = False
+):
+ """ post-activation norm https://arxiv.org/abs/2110.09456 """
+
+ inner_dim = int(mult * dim)
+ return nn.Sequential(
+ LayerNorm(dim),
+ nn.Linear(dim, inner_dim * 2, bias = False),
+ SwiGLU(),
+ LayerNorm(inner_dim) if post_activation_norm else nn.Identity(),
+ nn.Dropout(dropout),
+ nn.Linear(inner_dim, dim, bias = False)
+ )
+
+# attention
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ *,
+ dim_head = 64,
+ heads = 8,
+ dropout = 0.,
+ causal = False,
+ rotary_emb = None,
+ cosine_sim = True,
+ cosine_sim_scale = 16
+ ):
+ super().__init__()
+ self.scale = cosine_sim_scale if cosine_sim else (dim_head ** -0.5)
+ self.cosine_sim = cosine_sim
+
+ self.heads = heads
+ inner_dim = dim_head * heads
+
+ self.causal = causal
+ self.norm = LayerNorm(dim)
+ self.dropout = nn.Dropout(dropout)
+
+ self.null_kv = nn.Parameter(torch.randn(2, dim_head))
+ self.to_q = nn.Linear(dim, inner_dim, bias = False)
+ self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
+
+ self.rotary_emb = rotary_emb
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, dim, bias = False),
+ LayerNorm(dim)
+ )
+
+ def forward(self, x, mask = None, attn_bias = None):
+ b, n, device = *x.shape[:2], x.device
+
+ x = self.norm(x)
+ q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
+
+ q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
+ q = q * self.scale
+
+ # rotary embeddings
+
+ if exists(self.rotary_emb):
+ q, k = map(self.rotary_emb.rotate_queries_or_keys, (q, k))
+
+ # add null key / value for classifier free guidance in prior net
+
+ nk, nv = map(lambda t: repeat(t, 'd -> b 1 d', b = b), self.null_kv.unbind(dim = -2))
+ k = torch.cat((nk, k), dim = -2)
+ v = torch.cat((nv, v), dim = -2)
+
+ # whether to use cosine sim
+
+ if self.cosine_sim:
+ q, k = map(l2norm, (q, k))
+
+ q, k = map(lambda t: t * math.sqrt(self.scale), (q, k))
+
+ # calculate query / key similarities
+
+ sim = einsum('b h i d, b j d -> b h i j', q, k)
+
+ # relative positional encoding (T5 style)
+
+ if exists(attn_bias):
+ sim = sim + attn_bias
+
+ # masking
+
+ max_neg_value = -torch.finfo(sim.dtype).max
+
+ if exists(mask):
+ mask = F.pad(mask, (1, 0), value = True)
+ mask = rearrange(mask, 'b j -> b 1 1 j')
+ sim = sim.masked_fill(~mask, max_neg_value)
+
+ if self.causal:
+ i, j = sim.shape[-2:]
+ causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
+ sim = sim.masked_fill(causal_mask, max_neg_value)
+
+ # attention
+
+ attn = sim.softmax(dim = -1, dtype = torch.float32)
+ attn = attn.type(sim.dtype)
+
+ attn = self.dropout(attn)
+
+ # aggregate values
+
+ out = einsum('b h i j, b j d -> b h i d', attn, v)
+
+ out = rearrange(out, 'b h n d -> b n (h d)')
+ return self.to_out(out)
+
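+# usage sketch (illustrative only): causal self attention over a token sequence, optionally
+# taking a boolean key mask and a (heads, i, j) attention bias such as the relative position bias above
+#
+#   attn = Attention(dim = 512, dim_head = 64, heads = 8, causal = True)
+#   out = attn(torch.randn(2, 10, 512))  # -> (2, 10, 512)
+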
+class CausalTransformer(nn.Module):
+ def __init__(
+ self,
+ *,
+ dim,
+ depth,
+ dim_head = 64,
+ heads = 8,
+ ff_mult = 4,
+ norm_in = False,
+ norm_out = True,
+ attn_dropout = 0.,
+ ff_dropout = 0.,
+ final_proj = True,
+ normformer = False,
+ rotary_emb = True
+ ):
+ super().__init__()
+ self.init_norm = LayerNorm(dim) if norm_in else nn.Identity() # from latest BLOOM model and Yandex's YaLM
+
+ self.rel_pos_bias = RelPosBias(heads = heads)
+
+ rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None
+
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(nn.ModuleList([
+ Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),
+ FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
+ ]))
+
+ self.norm = LayerNorm(dim, stable = True) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
+ self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()
+
+ def forward(self, x):
+ n, device = x.shape[1], x.device
+
+ x = self.init_norm(x)
+
+ attn_bias = self.rel_pos_bias(n, n + 1, device = device)
+
+ for attn, ff in self.layers:
+ x = attn(x, attn_bias = attn_bias) + x
+ x = ff(x) + x
+
+ out = self.norm(x)
+ return self.project_out(out)
+
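+# usage sketch (illustrative only): the causal transformer preserves sequence length and dimension,
+# so the final token can later be read out as the prediction
+#
+#   transformer = CausalTransformer(dim = 512, depth = 6, dim_head = 64, heads = 8)
+#   out = transformer(torch.randn(2, 10, 512))  # -> (2, 10, 512)
+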
+class DiffusionPriorNetwork(nn.Module):
+ def __init__(
+ self,
+ dim,
+ num_timesteps = None,
+ num_time_embeds = 1,
+ num_image_embeds = 1,
+ num_text_embeds = 1,
+ max_text_len = 256,
+ self_cond = False,
+ **kwargs
+ ):
+ super().__init__()
+ self.dim = dim
+
+ self.num_time_embeds = num_time_embeds
+ self.num_image_embeds = num_image_embeds
+ self.num_text_embeds = num_text_embeds
+
+ self.to_text_embeds = nn.Sequential(
+ nn.Linear(dim, dim * num_text_embeds) if num_text_embeds > 1 else nn.Identity(),
+ Rearrange('b (n d) -> b n d', n = num_text_embeds)
+ )
+
+ self.continuous_embedded_time = not exists(num_timesteps)
+
+ self.to_time_embeds = nn.Sequential(
+ nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP
+ Rearrange('b (n d) -> b n d', n = num_time_embeds)
+ )
+
+ self.to_image_embeds = nn.Sequential(
+ nn.Linear(dim, dim * num_image_embeds) if num_image_embeds > 1 else nn.Identity(),
+ Rearrange('b (n d) -> b n d', n = num_image_embeds)
+ )
+
+ self.learned_query = nn.Parameter(torch.randn(dim))
+ self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
+
+ # dalle1 learned padding strategy
+
+ self.max_text_len = max_text_len
+
+ self.null_text_encodings = nn.Parameter(torch.randn(1, max_text_len, dim))
+ self.null_text_embeds = nn.Parameter(torch.randn(1, num_text_embeds, dim))
+ self.null_image_embed = nn.Parameter(torch.randn(1, dim))
+
+ # whether to use self conditioning, Hinton's group's new ddpm technique
+
+ self.self_cond = self_cond
+
+ def forward_with_cond_scale(
+ self,
+ *args,
+ cond_scale = 1.,
+ **kwargs
+ ):
+ logits = self.forward(*args, **kwargs)
+
+ if cond_scale == 1:
+ return logits
+
+ null_logits = self.forward(*args, text_cond_drop_prob = 1., image_cond_drop_prob = 1, **kwargs)
+ return null_logits + (logits - null_logits) * cond_scale
+
+ def forward(
+ self,
+ image_embed,
+ diffusion_timesteps,
+ *,
+ text_embed,
+ text_encodings = None,
+ self_cond = None,
+ text_cond_drop_prob = 0.,
+ image_cond_drop_prob = 0.
+ ):
+ batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
+
+ num_time_embeds, num_image_embeds, num_text_embeds = self.num_time_embeds, self.num_image_embeds, self.num_text_embeds
+
+ # setup self conditioning
+
+ if self.self_cond:
+ self_cond = default(self_cond, lambda: torch.zeros(batch, self.dim, device = device, dtype = dtype))
+ self_cond = rearrange(self_cond, 'b d -> b 1 d')
+
+ # in section 2.2, last paragraph
+ # "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
+
+ text_embed = self.to_text_embeds(text_embed)
+ image_embed = self.to_image_embeds(image_embed)
+
+ # classifier free guidance masks
+
+ text_keep_mask = prob_mask_like((batch,), 1 - text_cond_drop_prob, device = device)
+ text_keep_mask = rearrange(text_keep_mask, 'b -> b 1 1')
+
+ image_keep_mask = prob_mask_like((batch,), 1 - image_cond_drop_prob, device = device)
+ image_keep_mask = rearrange(image_keep_mask, 'b -> b 1 1')
+
+ # make text encodings optional
+ # although the paper seems to suggest it is present <--
+
+ if not exists(text_encodings):
+ text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)
+
+ mask = torch.any(text_encodings != 0., dim = -1)
+
+ # replace any padding in the text encodings with learned padding tokens unique across position
+
+ text_encodings = text_encodings[:, :self.max_text_len]
+ mask = mask[:, :self.max_text_len]
+
+ text_len = text_encodings.shape[-2]
+ remainder = self.max_text_len - text_len
+
+ if remainder > 0:
+ text_encodings = F.pad(text_encodings, (0, 0, 0, remainder), value = 0.)
+ mask = F.pad(mask, (0, remainder), value = False)
+
+ # mask out text encodings with null encodings
+
+ null_text_encodings = self.null_text_encodings.to(text_encodings.dtype)
+
+ text_encodings = torch.where(
+ rearrange(mask, 'b n -> b n 1').clone() & text_keep_mask,
+ text_encodings,
+ null_text_encodings
+ )
+
+ # mask out text embeddings with null text embeddings
+
+ null_text_embeds = self.null_text_embeds.to(text_embed.dtype)
+
+ text_embed = torch.where(
+ text_keep_mask,
+ text_embed,
+ null_text_embeds
+ )
+
+ # mask out image embeddings with null image embeddings
+
+ null_image_embed = self.null_image_embed.to(image_embed.dtype)
+
+ image_embed = torch.where(
+ image_keep_mask,
+ image_embed,
+ null_image_embed
+ )
+
+ # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
+ # but let's just do it right
+
+ if self.continuous_embedded_time:
+ diffusion_timesteps = diffusion_timesteps.type(dtype)
+
+ time_embed = self.to_time_embeds(diffusion_timesteps)
+
+ learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
+
+ if self.self_cond:
+ learned_queries = torch.cat((self_cond, learned_queries), dim = -2)
+
+ tokens = torch.cat((
+ text_encodings,
+ text_embed,
+ time_embed,
+ image_embed,
+ learned_queries
+ ), dim = -2)
+
+ # attend
+
+ tokens = self.causal_transformer(tokens)
+
+ # get learned query, which should predict the image embedding (per DDPM timestep)
+
+ pred_image_embed = tokens[..., -1, :]
+
+ return pred_image_embed
+
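+# usage sketch (illustrative only, dimensions are made up): the prior network takes the noised
+# CLIP image embedding, the diffusion timestep and the text conditioning, and predicts the
+# denoised image embedding from the final learned query token
+#
+#   prior_network = DiffusionPriorNetwork(dim = 512, depth = 6, dim_head = 64, heads = 8)
+#   image_embed = torch.randn(4, 512)
+#   text_embed = torch.randn(4, 512)
+#   times = torch.randint(0, 1000, (4,))
+#   pred = prior_network(image_embed, times, text_embed = text_embed)  # -> (4, 512)
+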
+class DiffusionPrior(nn.Module):
+ def __init__(
+ self,
+ net,
+ *,
+ clip = None,
+ image_embed_dim = None,
+ image_size = None,
+ image_channels = 3,
+ timesteps = 1000,
+ sample_timesteps = None,
+ cond_drop_prob = 0.,
+ text_cond_drop_prob = None,
+ image_cond_drop_prob = None,
+ loss_type = "l2",
+ predict_x_start = True,
+ predict_v = False,
+ beta_schedule = "cosine",
+ condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
+ sampling_clamp_l2norm = False, # whether to l2norm clamp the image embed at each denoising iteration (analogous to -1 to 1 clipping for usual DDPMs)
+ sampling_final_clamp_l2norm = False, # whether to l2norm the final image embedding output (this is also done for images in ddpm)
+ training_clamp_l2norm = False,
+ init_image_embed_l2norm = False,
+ image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
+ clip_adapter_overrides = dict()
+ ):
+ super().__init__()
+
+ self.sample_timesteps = sample_timesteps
+
+ self.noise_scheduler = NoiseScheduler(
+ beta_schedule = beta_schedule,
+ timesteps = timesteps,
+ loss_type = loss_type
+ )
+
+ if exists(clip):
+ assert image_channels == clip.image_channels, f'channels of image ({image_channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'
+
+ if isinstance(clip, CLIP):
+ clip = XClipAdapter(clip, **clip_adapter_overrides)
+ elif isinstance(clip, CoCa):
+ clip = CoCaAdapter(clip, **clip_adapter_overrides)
+
+ assert isinstance(clip, BaseClipAdapter)
+ freeze_model_and_make_eval_(clip)
+ self.clip = clip
+ else:
+            assert exists(image_embed_dim), 'latent dimension (image_embed_dim) must be given if training the prior network without CLIP'
+ self.clip = None
+
+ self.net = net
+ self.image_embed_dim = default(image_embed_dim, lambda: clip.dim_latent)
+
+ assert net.dim == self.image_embed_dim, f'your diffusion prior network has a dimension of {net.dim}, but you set your image embedding dimension (keyword image_embed_dim) on DiffusionPrior to {self.image_embed_dim}'
+ assert not exists(clip) or clip.dim_latent == self.image_embed_dim, f'you passed in a CLIP to the diffusion prior with latent dimensions of {clip.dim_latent}, but your image embedding dimension (keyword image_embed_dim) for the DiffusionPrior was set to {self.image_embed_dim}'
+
+ self.channels = default(image_channels, lambda: clip.image_channels)
+
+ self.text_cond_drop_prob = default(text_cond_drop_prob, cond_drop_prob)
+ self.image_cond_drop_prob = default(image_cond_drop_prob, cond_drop_prob)
+
+ self.can_classifier_guidance = self.text_cond_drop_prob > 0. and self.image_cond_drop_prob > 0.
+ self.condition_on_text_encodings = condition_on_text_encodings
+
+ # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
+
+ self.predict_x_start = predict_x_start
+ self.predict_v = predict_v # takes precedence over predict_x_start
+
+ # @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
+
+ self.image_embed_scale = default(image_embed_scale, self.image_embed_dim ** 0.5)
+
+ # whether to force an l2norm, similar to clipping denoised, when sampling
+
+ self.sampling_clamp_l2norm = sampling_clamp_l2norm
+ self.sampling_final_clamp_l2norm = sampling_final_clamp_l2norm
+
+ self.training_clamp_l2norm = training_clamp_l2norm
+ self.init_image_embed_l2norm = init_image_embed_l2norm
+
+ # device tracker
+
+ self.register_buffer('_dummy', torch.tensor([True]), persistent = False)
+
+ @property
+ def device(self):
+ return self._dummy.device
+
+ def l2norm_clamp_embed(self, image_embed):
+ return l2norm(image_embed) * self.image_embed_scale
+
+ def p_mean_variance(self, x, t, text_cond, self_cond = None, clip_denoised = False, cond_scale = 1.):
+ assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
+
+ pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, self_cond = self_cond, **text_cond)
+
+ if self.predict_v:
+ x_start = self.noise_scheduler.predict_start_from_v(x, t = t, v = pred)
+ elif self.predict_x_start:
+ x_start = pred
+ else:
+ x_start = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
+
+ if clip_denoised and not self.predict_x_start:
+ x_start.clamp_(-1., 1.)
+
+ if self.predict_x_start and self.sampling_clamp_l2norm:
+ x_start = l2norm(x_start) * self.image_embed_scale
+
+ model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t)
+ return model_mean, posterior_variance, posterior_log_variance, x_start
+
+ @torch.no_grad()
+ def p_sample(self, x, t, text_cond = None, self_cond = None, clip_denoised = True, cond_scale = 1.):
+ b, *_, device = *x.shape, x.device
+ model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = t, text_cond = text_cond, self_cond = self_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
+ noise = torch.randn_like(x)
+ # no noise when t == 0
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+ pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+ return pred, x_start
+
+ @torch.no_grad()
+ def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1.):
+ batch, device = shape[0], self.device
+
+ image_embed = torch.randn(shape, device = device)
+ x_start = None # for self-conditioning
+
+ if self.init_image_embed_l2norm:
+ image_embed = l2norm(image_embed) * self.image_embed_scale
+
+ for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):
+ times = torch.full((batch,), i, device = device, dtype = torch.long)
+
+ self_cond = x_start if self.net.self_cond else None
+ image_embed, x_start = self.p_sample(image_embed, times, text_cond = text_cond, self_cond = self_cond, cond_scale = cond_scale)
+
+ if self.sampling_final_clamp_l2norm and self.predict_x_start:
+ image_embed = self.l2norm_clamp_embed(image_embed)
+
+ return image_embed
+
+ @torch.no_grad()
+ def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scale = 1.):
+ batch, device, alphas, total_timesteps = shape[0], self.device, self.noise_scheduler.alphas_cumprod_prev, self.noise_scheduler.num_timesteps
+
+ times = torch.linspace(-1., total_timesteps, steps = timesteps + 1)[:-1]
+
+ times = list(reversed(times.int().tolist()))
+ time_pairs = list(zip(times[:-1], times[1:]))
+
+ image_embed = torch.randn(shape, device = device)
+
+ x_start = None # for self-conditioning
+
+ if self.init_image_embed_l2norm:
+ image_embed = l2norm(image_embed) * self.image_embed_scale
+
+ for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
+ alpha = alphas[time]
+ alpha_next = alphas[time_next]
+
+ time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
+
+ self_cond = x_start if self.net.self_cond else None
+
+ pred = self.net.forward_with_cond_scale(image_embed, time_cond, self_cond = self_cond, cond_scale = cond_scale, **text_cond)
+
+ # derive x0
+
+ if self.predict_v:
+ x_start = self.noise_scheduler.predict_start_from_v(image_embed, t = time_cond, v = pred)
+ elif self.predict_x_start:
+ x_start = pred
+ else:
+ x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred)
+
+ # clip x0 before maybe predicting noise
+
+ if not self.predict_x_start:
+ x_start.clamp_(-1., 1.)
+
+ if self.predict_x_start and self.sampling_clamp_l2norm:
+ x_start = self.l2norm_clamp_embed(x_start)
+
+ # predict noise
+
+ pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start)
+
+ if time_next < 0:
+ image_embed = x_start
+ continue
+
+ c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
+ c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
+ noise = torch.randn_like(image_embed) if time_next > 0 else 0.
+
+ image_embed = x_start * alpha_next.sqrt() + \
+ c1 * noise + \
+ c2 * pred_noise
+
+ if self.predict_x_start and self.sampling_final_clamp_l2norm:
+ image_embed = self.l2norm_clamp_embed(image_embed)
+
+ return image_embed
+
+ @torch.no_grad()
+ def p_sample_loop(self, *args, timesteps = None, **kwargs):
+ timesteps = default(timesteps, self.noise_scheduler.num_timesteps)
+ assert timesteps <= self.noise_scheduler.num_timesteps
+ is_ddim = timesteps < self.noise_scheduler.num_timesteps
+
+ if not is_ddim:
+ normalized_image_embed = self.p_sample_loop_ddpm(*args, **kwargs)
+ else:
+ normalized_image_embed = self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)
+
+ image_embed = normalized_image_embed / self.image_embed_scale
+ return image_embed
+
+ def p_losses(self, image_embed, times, text_cond, noise = None):
+ noise = default(noise, lambda: torch.randn_like(image_embed))
+
+ image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)
+
+ self_cond = None
+ if self.net.self_cond and random.random() < 0.5:
+ with torch.no_grad():
+ self_cond = self.net(image_embed_noisy, times, **text_cond).detach()
+
+ pred = self.net(
+ image_embed_noisy,
+ times,
+ self_cond = self_cond,
+ text_cond_drop_prob = self.text_cond_drop_prob,
+ image_cond_drop_prob = self.image_cond_drop_prob,
+ **text_cond
+ )
+
+ if self.predict_x_start and self.training_clamp_l2norm:
+ pred = self.l2norm_clamp_embed(pred)
+
+ if self.predict_v:
+ target = self.noise_scheduler.calculate_v(image_embed, times, noise)
+ elif self.predict_x_start:
+ target = image_embed
+ else:
+ target = noise
+
+ loss = self.noise_scheduler.loss_fn(pred, target)
+ return loss
+
+ @torch.no_grad()
+ @eval_decorator
+ def sample_batch_size(self, batch_size, text_cond, cond_scale = 1.):
+        device = self.device
+ shape = (batch_size, self.image_embed_dim)
+
+ img = torch.randn(shape, device = device)
+
+ for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = self.noise_scheduler.num_timesteps):
+            img, _ = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond, cond_scale = cond_scale)
+ return img
+
+ @torch.no_grad()
+ @eval_decorator
+ def sample(
+ self,
+ text,
+ num_samples_per_batch = 2,
+ cond_scale = 1.,
+ timesteps = None
+ ):
+ timesteps = default(timesteps, self.sample_timesteps)
+
+        # in the paper, they sample 2 image embeddings per text
+        # and keep the one with the highest similarity to the text, as judged by CLIP
+ text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)
+
+ batch_size = text.shape[0]
+ image_embed_dim = self.image_embed_dim
+
+ text_embed, text_encodings = self.clip.embed_text(text)
+
+ text_cond = dict(text_embed = text_embed)
+
+ if self.condition_on_text_encodings:
+ text_cond = {**text_cond, 'text_encodings': text_encodings}
+
+ image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale, timesteps = timesteps)
+
+ # retrieve original unscaled image embed
+
+ text_embeds = text_cond['text_embed']
+
+ text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
+ image_embeds = rearrange(image_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
+
+ text_image_sims = einsum('b r d, b r d -> b r', l2norm(text_embeds), l2norm(image_embeds))
+ top_sim_indices = text_image_sims.topk(k = 1).indices
+
+ top_sim_indices = repeat(top_sim_indices, 'b 1 -> b 1 d', d = image_embed_dim)
+
+ top_image_embeds = image_embeds.gather(1, top_sim_indices)
+ return rearrange(top_image_embeds, 'b 1 d -> b d')
+
+ def forward(
+ self,
+ text = None,
+ image = None,
+ text_embed = None, # allow for training on preprocessed CLIP text and image embeddings
+ image_embed = None,
+ text_encodings = None, # as well as CLIP text encodings
+ *args,
+ **kwargs
+ ):
+ assert exists(text) ^ exists(text_embed), 'either text or text embedding must be supplied'
+ assert exists(image) ^ exists(image_embed), 'either image or image embedding must be supplied'
+        assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on them at initialization'
+
+ if exists(image):
+ image_embed, _ = self.clip.embed_image(image)
+
+ # calculate text conditionings, based on what is passed in
+
+ if exists(text):
+ text_embed, text_encodings = self.clip.embed_text(text)
+
+ text_cond = dict(text_embed = text_embed)
+
+ if self.condition_on_text_encodings:
+ assert exists(text_encodings), 'text encodings must be present for diffusion prior if specified'
+ text_cond = {**text_cond, 'text_encodings': text_encodings}
+
+ # timestep conditioning from ddpm
+
+ batch, device = image_embed.shape[0], image_embed.device
+ times = self.noise_scheduler.sample_random_times(batch)
+
+ # scale image embed (Katherine)
+
+ image_embed *= self.image_embed_scale
+
+ # calculate forward loss
+
+ return self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
+
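+# usage sketch (illustrative only): training directly on precomputed CLIP embeddings, with
+# conditional dropout enabled so classifier free guidance can be used at sampling time
+#
+#   diffusion_prior = DiffusionPrior(net = prior_network, image_embed_dim = 512, timesteps = 1000, cond_drop_prob = 0.2, condition_on_text_encodings = False)
+#   loss = diffusion_prior(text_embed = text_embed, image_embed = image_embed)
+#   loss.backward()
+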
+# decoder
+
+def NearestUpsample(dim, dim_out = None):
+ dim_out = default(dim_out, dim)
+
+ return nn.Sequential(
+ nn.Upsample(scale_factor = 2, mode = 'nearest'),
+ nn.Conv2d(dim, dim_out, 3, padding = 1)
+ )
+
+class PixelShuffleUpsample(nn.Module):
+ """
+    code shared by @MalumaDev at DALLE2-pytorch for addressing checkerboard artifacts
+ https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf
+ """
+ def __init__(self, dim, dim_out = None):
+ super().__init__()
+ dim_out = default(dim_out, dim)
+ conv = nn.Conv2d(dim, dim_out * 4, 1)
+
+ self.net = nn.Sequential(
+ conv,
+ nn.SiLU(),
+ nn.PixelShuffle(2)
+ )
+
+ self.init_conv_(conv)
+
+ def init_conv_(self, conv):
+ o, i, h, w = conv.weight.shape
+ conv_weight = torch.empty(o // 4, i, h, w)
+ nn.init.kaiming_uniform_(conv_weight)
+ conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')
+
+ conv.weight.data.copy_(conv_weight)
+ nn.init.zeros_(conv.bias.data)
+
+ def forward(self, x):
+ return self.net(x)
+
+def Downsample(dim, dim_out = None):
+    # https://arxiv.org/abs/2208.03641 shows this to be the optimal way to downsample
+ # named SP-conv in the paper, but basically a pixel unshuffle
+ dim_out = default(dim_out, dim)
+ return nn.Sequential(
+ Rearrange('b c (h s1) (w s2) -> b (c s1 s2) h w', s1 = 2, s2 = 2),
+ nn.Conv2d(dim * 4, dim_out, 1)
+ )
+
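+# shape sketch (illustrative only): the pixel unshuffle halves height and width and packs the
+# 2x2 neighborhood into channels before the 1x1 projection
+#
+#   down = Downsample(64, 128)
+#   out = down(torch.randn(1, 64, 32, 32))  # -> (1, 128, 16, 16)
+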
+class WeightStandardizedConv2d(nn.Conv2d):
+ """
+ https://arxiv.org/abs/1903.10520
+ weight standardization purportedly works synergistically with group normalization
+ """
+ def forward(self, x):
+ eps = 1e-5 if x.dtype == torch.float32 else 1e-3
+
+ weight = self.weight
+ flattened_weights = rearrange(weight, 'o ... -> o (...)')
+
+ mean = reduce(weight, 'o ... -> o 1 1 1', 'mean')
+
+ var = torch.var(flattened_weights, dim = -1, unbiased = False)
+ var = rearrange(var, 'o -> o 1 1 1')
+
+ weight = (weight - mean) * (var + eps).rsqrt()
+
+ return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+
+class SinusoidalPosEmb(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+
+ def forward(self, x):
+ dtype, device = x.dtype, x.device
+ assert is_float_dtype(dtype), 'input to sinusoidal pos emb must be a float type'
+
+ half_dim = self.dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)
+ emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
+ return torch.cat((emb.sin(), emb.cos()), dim = -1).type(dtype)
+
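+# shape sketch (illustrative only): a (b,) tensor of float timesteps becomes a (b, dim) embedding
+# of sines and cosines at geometrically spaced frequencies
+#
+#   pos_emb = SinusoidalPosEmb(dim = 16)
+#   emb = pos_emb(torch.arange(8).float())  # -> (8, 16)
+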
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim,
+ dim_out,
+ groups = 8,
+ weight_standardization = False
+ ):
+ super().__init__()
+ conv_klass = nn.Conv2d if not weight_standardization else WeightStandardizedConv2d
+
+ self.project = conv_klass(dim, dim_out, 3, padding = 1)
+ self.norm = nn.GroupNorm(groups, dim_out)
+ self.act = nn.SiLU()
+
+ def forward(self, x, scale_shift = None):
+ x = self.project(x)
+ x = self.norm(x)
+
+ if exists(scale_shift):
+ scale, shift = scale_shift
+ x = x * (scale + 1) + shift
+
+ x = self.act(x)
+ return x
+
+class ResnetBlock(nn.Module):
+ def __init__(
+ self,
+ dim,
+ dim_out,
+ *,
+ cond_dim = None,
+ time_cond_dim = None,
+ groups = 8,
+ weight_standardization = False,
+ cosine_sim_cross_attn = False
+ ):
+ super().__init__()
+
+ self.time_mlp = None
+
+ if exists(time_cond_dim):
+ self.time_mlp = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(time_cond_dim, dim_out * 2)
+ )
+
+ self.cross_attn = None
+
+ if exists(cond_dim):
+ self.cross_attn = CrossAttention(
+ dim = dim_out,
+ context_dim = cond_dim,
+ cosine_sim = cosine_sim_cross_attn
+ )
+
+ self.block1 = Block(dim, dim_out, groups = groups, weight_standardization = weight_standardization)
+ self.block2 = Block(dim_out, dim_out, groups = groups, weight_standardization = weight_standardization)
+ self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
+
+ def forward(self, x, time_emb = None, cond = None):
+
+ scale_shift = None
+ if exists(self.time_mlp) and exists(time_emb):
+ time_emb = self.time_mlp(time_emb)
+ time_emb = rearrange(time_emb, 'b c -> b c 1 1')
+ scale_shift = time_emb.chunk(2, dim = 1)
+
+ h = self.block1(x, scale_shift = scale_shift)
+
+ if exists(self.cross_attn):
+ assert exists(cond)
+
+ h = rearrange(h, 'b c ... -> b ... c')
+ h, ps = pack([h], 'b * c')
+
+ h = self.cross_attn(h, context = cond) + h
+
+ h, = unpack(h, ps, 'b * c')
+ h = rearrange(h, 'b ... c -> b c ...')
+
+ h = self.block2(h)
+ return h + self.res_conv(x)
+
+class CrossAttention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ *,
+ context_dim = None,
+ dim_head = 64,
+ heads = 8,
+ dropout = 0.,
+ norm_context = False,
+ cosine_sim = False,
+ cosine_sim_scale = 16
+ ):
+ super().__init__()
+ self.cosine_sim = cosine_sim
+ self.scale = cosine_sim_scale if cosine_sim else (dim_head ** -0.5)
+ self.heads = heads
+ inner_dim = dim_head * heads
+
+ context_dim = default(context_dim, dim)
+
+ self.norm = LayerNorm(dim)
+ self.norm_context = LayerNorm(context_dim) if norm_context else nn.Identity()
+ self.dropout = nn.Dropout(dropout)
+
+ self.null_kv = nn.Parameter(torch.randn(2, dim_head))
+ self.to_q = nn.Linear(dim, inner_dim, bias = False)
+ self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, dim, bias = False),
+ LayerNorm(dim)
+ )
+
+ def forward(self, x, context, mask = None):
+ b, n, device = *x.shape[:2], x.device
+
+ x = self.norm(x)
+ context = self.norm_context(context)
+
+ q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))
+
+ # add null key / value for classifier free guidance in prior net
+
+ nk, nv = map(lambda t: repeat(t, 'd -> b h 1 d', h = self.heads, b = b), self.null_kv.unbind(dim = -2))
+
+ k = torch.cat((nk, k), dim = -2)
+ v = torch.cat((nv, v), dim = -2)
+
+ if self.cosine_sim:
+ q, k = map(l2norm, (q, k))
+
+ q, k = map(lambda t: t * math.sqrt(self.scale), (q, k))
+
+ sim = einsum('b h i d, b h j d -> b h i j', q, k)
+ max_neg_value = -torch.finfo(sim.dtype).max
+
+ if exists(mask):
+ mask = F.pad(mask, (1, 0), value = True)
+ mask = rearrange(mask, 'b j -> b 1 1 j')
+ sim = sim.masked_fill(~mask, max_neg_value)
+
+ attn = sim.softmax(dim = -1, dtype = torch.float32)
+ attn = attn.type(sim.dtype)
+
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
+ out = rearrange(out, 'b h n d -> b n (h d)')
+ return self.to_out(out)
+
+class LinearAttention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ dim_head = 32,
+ heads = 8,
+ **kwargs
+ ):
+ super().__init__()
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+ inner_dim = dim_head * heads
+ self.norm = ChanLayerNorm(dim)
+
+ self.nonlin = nn.GELU()
+ self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
+
+ self.to_out = nn.Sequential(
+ nn.Conv2d(inner_dim, dim, 1, bias = False),
+ ChanLayerNorm(dim)
+ )
+
+ def forward(self, fmap):
+ h, x, y = self.heads, *fmap.shape[-2:]
+ seq_len = x * y
+
+ fmap = self.norm(fmap)
+ q, k, v = self.to_qkv(fmap).chunk(3, dim = 1)
+ q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = h), (q, k, v))
+
+ q = q.softmax(dim = -1)
+ k = k.softmax(dim = -2)
+
+ q = q * self.scale
+ v = l2norm(v)
+
+ k, v = map(lambda t: t / math.sqrt(seq_len), (k, v))
+
+ context = einsum('b n d, b n e -> b d e', k, v)
+ out = einsum('b n d, b d e -> b n e', q, context)
+ out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)
+
+ out = self.nonlin(out)
+ return self.to_out(out)
+
+class CrossEmbedLayer(nn.Module):
+ def __init__(
+ self,
+ dim_in,
+ kernel_sizes,
+ dim_out = None,
+ stride = 2
+ ):
+ super().__init__()
+ assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)])
+ dim_out = default(dim_out, dim_in)
+
+ kernel_sizes = sorted(kernel_sizes)
+ num_scales = len(kernel_sizes)
+
+ # calculate the dimension at each scale
+ dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
+ dim_scales = [*dim_scales, dim_out - sum(dim_scales)]
+
+ self.convs = nn.ModuleList([])
+ for kernel, dim_scale in zip(kernel_sizes, dim_scales):
+ self.convs.append(nn.Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))
+
+ def forward(self, x):
+ fmaps = tuple(map(lambda conv: conv(x), self.convs))
+ return torch.cat(fmaps, dim = 1)
+
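+# shape sketch (illustrative only): with dim_out = 128 and kernel sizes (3, 7, 15), the output
+# channels are split as [64, 32, 32] across the three parallel convolutions and then concatenated
+#
+#   cross_embed = CrossEmbedLayer(dim_in = 3, dim_out = 128, kernel_sizes = (3, 7, 15), stride = 1)
+#   out = cross_embed(torch.randn(1, 3, 64, 64))  # -> (1, 128, 64, 64)
+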
+class UpsampleCombiner(nn.Module):
+ def __init__(
+ self,
+ dim,
+ *,
+ enabled = False,
+ dim_ins = tuple(),
+ dim_outs = tuple()
+ ):
+ super().__init__()
+ assert len(dim_ins) == len(dim_outs)
+ self.enabled = enabled
+
+ if not self.enabled:
+ self.dim_out = dim
+ return
+
+ self.fmap_convs = nn.ModuleList([Block(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)])
+ self.dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0)
+
+ def forward(self, x, fmaps = None):
+ target_size = x.shape[-1]
+
+ fmaps = default(fmaps, tuple())
+
+ if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0:
+ return x
+
+ fmaps = [resize_image_to(fmap, target_size) for fmap in fmaps]
+ outs = [conv(fmap) for fmap, conv in zip(fmaps, self.fmap_convs)]
+ return torch.cat((x, *outs), dim = 1)
+
+class Unet(nn.Module):
+ def __init__(
+ self,
+ dim,
+ *,
+ image_embed_dim = None,
+ text_embed_dim = None,
+ cond_dim = None,
+ num_image_tokens = 4,
+ num_time_tokens = 2,
+ out_dim = None,
+ dim_mults=(1, 2, 4, 8),
+ channels = 3,
+ channels_out = None,
+ self_attn = False,
+ attn_dim_head = 32,
+ attn_heads = 16,
+ lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
+ lowres_noise_cond = False, # for conditioning on low resolution noising, based on Imagen
+ self_cond = False, # set this to True to use the self-conditioning technique from - https://arxiv.org/abs/2208.04202
+ sparse_attn = False,
+ cosine_sim_cross_attn = False,
+ cosine_sim_self_attn = False,
+ attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
+ cond_on_text_encodings = False,
+ max_text_len = 256,
+ cond_on_image_embeds = False,
+ add_image_embeds_to_time = True, # alerted by @mhh0318 to a phrase in the paper - "Specifically, we modify the architecture described in Nichol et al. (2021) by projecting and adding CLIP embeddings to the existing timestep embedding"
+ init_dim = None,
+ init_conv_kernel_size = 7,
+ resnet_groups = 8,
+ resnet_weight_standardization = False,
+ num_resnet_blocks = 2,
+ init_cross_embed = True,
+ init_cross_embed_kernel_sizes = (3, 7, 15),
+ cross_embed_downsample = False,
+ cross_embed_downsample_kernel_sizes = (2, 4),
+ memory_efficient = False,
+ scale_skip_connection = False,
+ pixel_shuffle_upsample = True,
+ final_conv_kernel_size = 1,
+ combine_upsample_fmaps = False, # whether to combine the outputs of all upsample blocks, as in unet squared paper
+ checkpoint_during_training = False,
+ **kwargs
+ ):
+ super().__init__()
+ # save locals to take care of some hyperparameters for cascading DDPM
+
+ self._locals = locals()
+ del self._locals['self']
+ del self._locals['__class__']
+
+ # for eventual cascading diffusion
+
+ self.lowres_cond = lowres_cond
+
+ # whether to do self conditioning
+
+ self.self_cond = self_cond
+
+ # determine dimensions
+
+ self.channels = channels
+ self.channels_out = default(channels_out, channels)
+
+ # initial number of channels depends on
+ # (1) low resolution conditioning from cascading ddpm paper, conditioned on previous unet output in the cascade
+ # (2) self conditioning (bit diffusion paper)
+
+ init_channels = channels * (1 + int(lowres_cond) + int(self_cond))
+
+ init_dim = default(init_dim, dim)
+
+ self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) if init_cross_embed else nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
+
+ dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
+ in_out = list(zip(dims[:-1], dims[1:]))
+
+ num_stages = len(in_out)
+
+ # time, image embeddings, and optional text encoding
+
+ cond_dim = default(cond_dim, dim)
+ time_cond_dim = dim * 4
+
+ self.to_time_hiddens = nn.Sequential(
+ SinusoidalPosEmb(dim),
+ nn.Linear(dim, time_cond_dim),
+ nn.GELU()
+ )
+
+ self.to_time_tokens = nn.Sequential(
+ nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
+ Rearrange('b (r d) -> b r d', r = num_time_tokens)
+ )
+
+ self.to_time_cond = nn.Sequential(
+ nn.Linear(time_cond_dim, time_cond_dim)
+ )
+
+ self.image_to_tokens = nn.Sequential(
+ nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
+ Rearrange('b (n d) -> b n d', n = num_image_tokens)
+ ) if cond_on_image_embeds and image_embed_dim != cond_dim else nn.Identity()
+
+ self.to_image_hiddens = nn.Sequential(
+ nn.Linear(image_embed_dim, time_cond_dim),
+ nn.GELU()
+ ) if cond_on_image_embeds and add_image_embeds_to_time else None
+
+ self.norm_cond = nn.LayerNorm(cond_dim)
+ self.norm_mid_cond = nn.LayerNorm(cond_dim)
+
+ # text encoding conditioning (optional)
+
+ self.text_to_cond = None
+ self.text_embed_dim = None
+
+ if cond_on_text_encodings:
+ assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text_encodings is True'
+ self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
+ self.text_embed_dim = text_embed_dim
+
+        # low resolution noise conditioning, based on Imagen's upsampler training technique
+
+ self.lowres_noise_cond = lowres_noise_cond
+
+ self.to_lowres_noise_cond = nn.Sequential(
+ SinusoidalPosEmb(dim),
+ nn.Linear(dim, time_cond_dim),
+ nn.GELU(),
+ nn.Linear(time_cond_dim, time_cond_dim)
+ ) if lowres_noise_cond else None
+
+ # finer control over whether to condition on image embeddings and text encodings
+ # so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
+
+ self.cond_on_text_encodings = cond_on_text_encodings
+ self.cond_on_image_embeds = cond_on_image_embeds
+
+ # for classifier free guidance
+
+ self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
+ self.null_image_hiddens = nn.Parameter(torch.randn(1, time_cond_dim))
+
+ self.max_text_len = max_text_len
+ self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
+
+ # whether to scale skip connection, adopted in Imagen
+
+ self.skip_connect_scale = 1. if not scale_skip_connection else (2 ** -0.5)
+
+ # attention related params
+
+ attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head, cosine_sim = cosine_sim_self_attn)
+
+ self_attn = cast_tuple(self_attn, num_stages)
+
+ create_self_attn = lambda dim: RearrangeToSequence(Residual(Attention(dim, **attn_kwargs)))
+
+ # resnet block klass
+
+ resnet_groups = cast_tuple(resnet_groups, num_stages)
+ top_level_resnet_group = first(resnet_groups)
+
+ num_resnet_blocks = cast_tuple(num_resnet_blocks, num_stages)
+
+ # downsample klass
+
+ downsample_klass = Downsample
+ if cross_embed_downsample:
+ downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)
+
+ # upsample klass
+
+ upsample_klass = NearestUpsample if not pixel_shuffle_upsample else PixelShuffleUpsample
+
+ # prepare resnet klass
+
+ resnet_block = partial(ResnetBlock, cosine_sim_cross_attn = cosine_sim_cross_attn, weight_standardization = resnet_weight_standardization)
+
+ # give memory efficient unet an initial resnet block
+
+ self.init_resnet_block = resnet_block(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group) if memory_efficient else None
+
+ # layers
+
+ self.downs = nn.ModuleList([])
+ self.ups = nn.ModuleList([])
+ num_resolutions = len(in_out)
+
+ skip_connect_dims = [] # keeping track of skip connection dimensions
+ upsample_combiner_dims = [] # keeping track of dimensions for final upsample feature map combiner
+
+ for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(in_out, resnet_groups, num_resnet_blocks, self_attn)):
+ is_first = ind == 0
+ is_last = ind >= (num_resolutions - 1)
+ layer_cond_dim = cond_dim if not is_first else None
+
+ dim_layer = dim_out if memory_efficient else dim_in
+ skip_connect_dims.append(dim_layer)
+
+ attention = nn.Identity()
+ if layer_self_attn:
+ attention = create_self_attn(dim_layer)
+ elif sparse_attn:
+ attention = Residual(LinearAttention(dim_layer, **attn_kwargs))
+
+ self.downs.append(nn.ModuleList([
+ downsample_klass(dim_in, dim_out = dim_out) if memory_efficient else None,
+ resnet_block(dim_layer, dim_layer, time_cond_dim = time_cond_dim, groups = groups),
+ nn.ModuleList([resnet_block(dim_layer, dim_layer, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
+ attention,
+ downsample_klass(dim_layer, dim_out = dim_out) if not is_last and not memory_efficient else nn.Conv2d(dim_layer, dim_out, 1)
+ ]))
+
+ mid_dim = dims[-1]
+
+ self.mid_block1 = resnet_block(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
+ self.mid_attn = create_self_attn(mid_dim)
+ self.mid_block2 = resnet_block(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
+
+ for ind, ((dim_in, dim_out), groups, layer_num_resnet_blocks, layer_self_attn) in enumerate(zip(reversed(in_out), reversed(resnet_groups), reversed(num_resnet_blocks), reversed(self_attn))):
+ is_last = ind >= (len(in_out) - 1)
+ layer_cond_dim = cond_dim if not is_last else None
+
+ skip_connect_dim = skip_connect_dims.pop()
+
+ attention = nn.Identity()
+ if layer_self_attn:
+ attention = create_self_attn(dim_out)
+ elif sparse_attn:
+ attention = Residual(LinearAttention(dim_out, **attn_kwargs))
+
+ upsample_combiner_dims.append(dim_out)
+
+ self.ups.append(nn.ModuleList([
+ resnet_block(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups),
+ nn.ModuleList([resnet_block(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim, groups = groups) for _ in range(layer_num_resnet_blocks)]),
+ attention,
+ upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else nn.Identity()
+ ]))
+
+ # whether to combine outputs from all upsample blocks for final resnet block
+
+ self.upsample_combiner = UpsampleCombiner(
+ dim = dim,
+ enabled = combine_upsample_fmaps,
+ dim_ins = upsample_combiner_dims,
+ dim_outs = (dim,) * len(upsample_combiner_dims)
+ )
+
+ # a final resnet block
+
+ self.final_resnet_block = resnet_block(self.upsample_combiner.dim_out + dim, dim, time_cond_dim = time_cond_dim, groups = top_level_resnet_group)
+
+ out_dim_in = dim + (channels if lowres_cond else 0)
+
+ self.to_out = nn.Conv2d(out_dim_in, self.channels_out, kernel_size = final_conv_kernel_size, padding = final_conv_kernel_size // 2)
+
+ zero_init_(self.to_out) # since both OpenAI and @crowsonkb are doing it
+
+ # whether to checkpoint during training
+
+ self.checkpoint_during_training = checkpoint_during_training
+
+ # if the current settings for the unet are not correct
+ # for cascading DDPM, then reinit the unet with the right settings
+ def cast_model_parameters(
+ self,
+ *,
+ lowres_cond,
+ lowres_noise_cond,
+ channels,
+ channels_out,
+ cond_on_image_embeds,
+ cond_on_text_encodings,
+ ):
+ if lowres_cond == self.lowres_cond and \
+ channels == self.channels and \
+ cond_on_image_embeds == self.cond_on_image_embeds and \
+ cond_on_text_encodings == self.cond_on_text_encodings and \
+ lowres_noise_cond == self.lowres_noise_cond and \
+ channels_out == self.channels_out:
+ return self
+
+ updated_kwargs = dict(
+ lowres_cond = lowres_cond,
+ channels = channels,
+ channels_out = channels_out,
+ cond_on_image_embeds = cond_on_image_embeds,
+ cond_on_text_encodings = cond_on_text_encodings,
+ lowres_noise_cond = lowres_noise_cond
+ )
+
+ return self.__class__(**{**self._locals, **updated_kwargs})
+
+ def forward_with_cond_scale(
+ self,
+ *args,
+ cond_scale = 1.,
+ **kwargs
+ ):
+ logits = self.forward(*args, **kwargs)
+
+ if cond_scale == 1:
+ return logits
+
+ null_logits = self.forward(*args, text_cond_drop_prob = 1., image_cond_drop_prob = 1., **kwargs)
+ return null_logits + (logits - null_logits) * cond_scale
+
+ def forward(
+ self,
+ x,
+ time,
+ *,
+ image_embed,
+ lowres_cond_img = None,
+ lowres_noise_level = None,
+ text_encodings = None,
+ image_cond_drop_prob = 0.,
+ text_cond_drop_prob = 0.,
+ blur_sigma = None,
+ blur_kernel_size = None,
+ disable_checkpoint = False,
+ self_cond = None
+ ):
+ batch_size, device = x.shape[0], x.device
+
+ # add low resolution conditioning, if present
+
+ assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
+
+ # concat self conditioning, if needed
+
+ if self.self_cond:
+ self_cond = default(self_cond, lambda: torch.zeros_like(x))
+ x = torch.cat((x, self_cond), dim = 1)
+
+ # concat low resolution conditioning
+
+ if exists(lowres_cond_img):
+ x = torch.cat((x, lowres_cond_img), dim = 1)
+
+ # initial convolution
+
+ x = self.init_conv(x)
+ r = x.clone() # final residual
+
+ # time conditioning
+
+ time = time.type_as(x)
+ time_hiddens = self.to_time_hiddens(time)
+
+ time_tokens = self.to_time_tokens(time_hiddens)
+ t = self.to_time_cond(time_hiddens)
+
+ # low res noise conditioning (similar to time above)
+
+ if exists(lowres_noise_level):
+            assert exists(self.to_lowres_noise_cond), 'lowres_noise_cond must be set to True on instantiation of the unet in order to condition on lowres noise'
+ lowres_noise_level = lowres_noise_level.type_as(x)
+ t = t + self.to_lowres_noise_cond(lowres_noise_level)
+
+ # conditional dropout
+
+ image_keep_mask = prob_mask_like((batch_size,), 1 - image_cond_drop_prob, device = device)
+ text_keep_mask = prob_mask_like((batch_size,), 1 - text_cond_drop_prob, device = device)
+
+ text_keep_mask = rearrange(text_keep_mask, 'b -> b 1 1')
+
+ # image embedding to be summed to time embedding
+ # discovered by @mhh0318 in the paper
+
+ if exists(image_embed) and exists(self.to_image_hiddens):
+ image_hiddens = self.to_image_hiddens(image_embed)
+ image_keep_mask_hidden = rearrange(image_keep_mask, 'b -> b 1')
+ null_image_hiddens = self.null_image_hiddens.to(image_hiddens.dtype)
+
+ image_hiddens = torch.where(
+ image_keep_mask_hidden,
+ image_hiddens,
+ null_image_hiddens
+ )
+
+ t = t + image_hiddens
+
+ # mask out image embedding depending on condition dropout
+ # for classifier free guidance
+
+ image_tokens = None
+
+ if self.cond_on_image_embeds:
+ image_keep_mask_embed = rearrange(image_keep_mask, 'b -> b 1 1')
+ image_tokens = self.image_to_tokens(image_embed)
+ null_image_embed = self.null_image_embed.to(image_tokens.dtype) # for some reason pytorch AMP not working
+
+ image_tokens = torch.where(
+ image_keep_mask_embed,
+ image_tokens,
+ null_image_embed
+ )
+
+ # take care of text encodings (optional)
+
+ text_tokens = None
+
+ if exists(text_encodings) and self.cond_on_text_encodings:
+            assert text_encodings.shape[0] == batch_size, f'the text encodings being passed into the unet do not have the proper batch size - text encoding shape {text_encodings.shape} - required batch size is {batch_size}'
+ assert self.text_embed_dim == text_encodings.shape[-1], f'the text encodings you are passing in have a dimension of {text_encodings.shape[-1]}, but the unet was created with text_embed_dim of {self.text_embed_dim}.'
+
+ text_mask = torch.any(text_encodings != 0., dim = -1)
+
+ text_tokens = self.text_to_cond(text_encodings)
+
+ text_tokens = text_tokens[:, :self.max_text_len]
+ text_mask = text_mask[:, :self.max_text_len]
+
+ text_tokens_len = text_tokens.shape[1]
+ remainder = self.max_text_len - text_tokens_len
+
+ if remainder > 0:
+ text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))
+ text_mask = F.pad(text_mask, (0, remainder), value = False)
+
+ text_mask = rearrange(text_mask, 'b n -> b n 1')
+
+ assert text_mask.shape[0] == text_keep_mask.shape[0], f'text_mask has shape of {text_mask.shape} while text_keep_mask has shape {text_keep_mask.shape}. text encoding is of shape {text_encodings.shape}'
+ text_keep_mask = text_mask & text_keep_mask
+
+ null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working
+
+ text_tokens = torch.where(
+ text_keep_mask,
+ text_tokens,
+ null_text_embed
+ )
+
+ # main conditioning tokens (c)
+
+ c = time_tokens
+
+ if exists(image_tokens):
+ c = torch.cat((c, image_tokens), dim = -2)
+
+ # text and image conditioning tokens (mid_c)
+        # to save on compute, only do cross attention based conditioning on the innermost layers of the Unet
+
+ mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2)
+
+ # normalize conditioning tokens
+
+ c = self.norm_cond(c)
+ mid_c = self.norm_mid_cond(mid_c)
+
+ # gradient checkpointing
+
+ can_checkpoint = self.training and self.checkpoint_during_training and not disable_checkpoint
+ apply_checkpoint_fn = make_checkpointable if can_checkpoint else identity
+
+ # make checkpointable modules
+
+ init_resnet_block, mid_block1, mid_attn, mid_block2, final_resnet_block = [maybe(apply_checkpoint_fn)(module) for module in (self.init_resnet_block, self.mid_block1, self.mid_attn, self.mid_block2, self.final_resnet_block)]
+
+ can_checkpoint_cond = lambda m: isinstance(m, ResnetBlock)
+ downs, ups = [maybe(apply_checkpoint_fn)(m, condition = can_checkpoint_cond) for m in (self.downs, self.ups)]
+
+ # initial resnet block
+
+ if exists(init_resnet_block):
+ x = init_resnet_block(x, t)
+
+ # go through the layers of the unet, down and up
+
+ down_hiddens = []
+ up_hiddens = []
+
+ for pre_downsample, init_block, resnet_blocks, attn, post_downsample in downs:
+ if exists(pre_downsample):
+ x = pre_downsample(x)
+
+ x = init_block(x, t, c)
+
+ for resnet_block in resnet_blocks:
+ x = resnet_block(x, t, c)
+ down_hiddens.append(x.contiguous())
+
+ x = attn(x)
+ down_hiddens.append(x.contiguous())
+
+ if exists(post_downsample):
+ x = post_downsample(x)
+
+ x = mid_block1(x, t, mid_c)
+
+ if exists(mid_attn):
+ x = mid_attn(x)
+
+ x = mid_block2(x, t, mid_c)
+
+ connect_skip = lambda fmap: torch.cat((fmap, down_hiddens.pop() * self.skip_connect_scale), dim = 1)
+
+ for init_block, resnet_blocks, attn, upsample in ups:
+ x = connect_skip(x)
+ x = init_block(x, t, c)
+
+ for resnet_block in resnet_blocks:
+ x = connect_skip(x)
+ x = resnet_block(x, t, c)
+
+ x = attn(x)
+
+ up_hiddens.append(x.contiguous())
+ x = upsample(x)
+
+ x = self.upsample_combiner(x, up_hiddens)
+
+ x = torch.cat((x, r), dim = 1)
+
+ x = final_resnet_block(x, t)
+
+ if exists(lowres_cond_img):
+ x = torch.cat((x, lowres_cond_img), dim = 1)
+
+ return self.to_out(x)
+
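+# usage sketch (illustrative only, hyperparameters are made up): a base unet conditioned on the
+# diffusion timestep and the CLIP image embedding, predicting an output with the same shape as its input
+#
+#   unet = Unet(dim = 128, image_embed_dim = 512, cond_on_image_embeds = True, cond_on_text_encodings = False)
+#   images = torch.randn(4, 3, 64, 64)
+#   times = torch.randint(0, 1000, (4,))
+#   image_embed = torch.randn(4, 512)
+#   pred = unet(images, times, image_embed = image_embed)  # -> (4, 3, 64, 64)
+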
+class LowresConditioner(nn.Module):
+ def __init__(
+ self,
+ downsample_first = True,
+ use_blur = True,
+ blur_prob = 0.5,
+ blur_sigma = 0.6,
+ blur_kernel_size = 3,
+ use_noise = False,
+ input_image_range = None,
+ normalize_img_fn = identity,
+ unnormalize_img_fn = identity
+ ):
+ super().__init__()
+ self.downsample_first = downsample_first
+ self.input_image_range = input_image_range
+
+ self.use_blur = use_blur
+ self.blur_prob = blur_prob
+ self.blur_sigma = blur_sigma
+ self.blur_kernel_size = blur_kernel_size
+
+ self.use_noise = use_noise
+ self.normalize_img = normalize_img_fn
+ self.unnormalize_img = unnormalize_img_fn
+ self.noise_scheduler = NoiseScheduler(beta_schedule = 'linear', timesteps = 1000, loss_type = 'l2') if use_noise else None
+
+ def noise_image(self, cond_fmap, noise_levels = None):
+ assert exists(self.noise_scheduler)
+
+ batch = cond_fmap.shape[0]
+ cond_fmap = self.normalize_img(cond_fmap)
+
+ random_noise_levels = default(noise_levels, lambda: self.noise_scheduler.sample_random_times(batch))
+ cond_fmap = self.noise_scheduler.q_sample(cond_fmap, t = random_noise_levels, noise = torch.randn_like(cond_fmap))
+
+ cond_fmap = self.unnormalize_img(cond_fmap)
+ return cond_fmap, random_noise_levels
+
+ def forward(
+ self,
+ cond_fmap,
+ *,
+ target_image_size,
+ downsample_image_size = None,
+ should_blur = True,
+ blur_sigma = None,
+ blur_kernel_size = None
+ ):
+ if self.downsample_first and exists(downsample_image_size):
+ cond_fmap = resize_image_to(cond_fmap, downsample_image_size, clamp_range = self.input_image_range, nearest = True)
+
+ # blur is only applied 50% of the time
+ # section 3.1 in https://arxiv.org/abs/2106.15282
+
+ if self.use_blur and should_blur and random.random() < self.blur_prob:
+
+ # when training, blur the low resolution conditional image
+
+ blur_sigma = default(blur_sigma, self.blur_sigma)
+ blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
+
+ # allow for drawing a random sigma between lo and hi float values
+
+ if isinstance(blur_sigma, tuple):
+ blur_sigma = tuple(map(float, blur_sigma))
+ blur_sigma = random.uniform(*blur_sigma)
+
+ # allow for drawing a random kernel size between lo and hi int values
+
+ if isinstance(blur_kernel_size, tuple):
+ blur_kernel_size = tuple(map(int, blur_kernel_size))
+ kernel_size_lo, kernel_size_hi = blur_kernel_size
+ blur_kernel_size = random.randrange(kernel_size_lo, kernel_size_hi + 1)
+
+ cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
+
+ # resize to target image size
+
+ cond_fmap = resize_image_to(cond_fmap, target_image_size, clamp_range = self.input_image_range, nearest = True)
+
+ # noise conditioning, as done in Imagen
+ # as a replacement for the BSR noising, and potentially replace blurring for first stage too
+
+ random_noise_levels = None
+
+ if self.use_noise:
+ cond_fmap, random_noise_levels = self.noise_image(cond_fmap)
+
+ # return conditioning feature map, as well as the augmentation noise levels
+
+ return cond_fmap, random_noise_levels
+
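+# usage sketch (illustrative only): during cascading training, the previous stage's output is
+# downsampled, optionally blurred and/or noised, then resized back up to the current stage's resolution
+#
+#   lowres_conditioner = LowresConditioner(use_blur = True, blur_prob = 0.5, blur_sigma = 0.6)
+#   cond_img, noise_levels = lowres_conditioner(torch.rand(4, 3, 256, 256), target_image_size = 256, downsample_image_size = 64)
+#   # cond_img -> (4, 3, 256, 256); noise_levels is None unless use_noise = True
+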
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ unet,
+ *,
+ clip = None,
+ image_size = None,
+ channels = 3,
+ vae = tuple(),
+ timesteps = 1000,
+ sample_timesteps = None,
+ image_cond_drop_prob = 0.1,
+ text_cond_drop_prob = 0.5,
+ loss_type = 'l2',
+ beta_schedule = None,
+ predict_x_start = False,
+ predict_v = False,
+ predict_x_start_for_latent_diffusion = False,
+ image_sizes = None, # for cascading ddpm, image size at each stage
+ random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
+ use_noise_for_lowres_cond = False, # whether to use Imagen-like noising for low resolution conditioning
+ use_blur_for_lowres_cond = True, # whether to use the blur conditioning used in the original cascading ddpm paper, as well as DALL-E2
+ lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
+ blur_prob = 0.5, # cascading ddpm - when training, the gaussian blur is only applied 50% of the time
+ blur_sigma = 0.6, # cascading ddpm - blur sigma
+ blur_kernel_size = 3, # cascading ddpm - blur kernel size
+ lowres_noise_sample_level = 0.2, # in imagen paper, they use a 0.2 noise level at sample time for low resolution conditioning
+ clip_denoised = True,
+ clip_x_start = True,
+ clip_adapter_overrides = dict(),
+ learned_variance = True,
+ learned_variance_constrain_frac = False,
+ vb_loss_weight = 0.001,
+ unconditional = False, # set to True for generating images without conditioning
+ auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
+ use_dynamic_thres = False, # from the Imagen paper
+ dynamic_thres_percentile = 0.95,
+ p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
+ p2_loss_weight_k = 1,
+ ddim_sampling_eta = 0. # can be set to 0. for deterministic sampling afaict
+ ):
+ super().__init__()
+
+ # clip
+
+ self.clip = None
+ if exists(clip):
+ assert not unconditional, 'clip must not be given if doing unconditional image training'
+ assert channels == clip.image_channels, f'channels of image ({channels}) should be equal to the channels that CLIP accepts ({clip.image_channels})'
+
+ if isinstance(clip, CLIP):
+ clip = XClipAdapter(clip, **clip_adapter_overrides)
+ elif isinstance(clip, CoCa):
+ clip = CoCaAdapter(clip, **clip_adapter_overrides)
+
+ freeze_model_and_make_eval_(clip)
+ assert isinstance(clip, BaseClipAdapter)
+
+ self.clip = clip
+
+ # determine image size, with image_size and image_sizes taking precedence
+
+ if exists(image_size) or exists(image_sizes):
+ assert exists(image_size) ^ exists(image_sizes), 'only one of image_size or image_sizes must be given'
+ image_size = default(image_size, lambda: image_sizes[-1])
+ elif exists(clip):
+ image_size = clip.image_size
+ else:
+            raise ValueError('either image_size, image_sizes, or clip must be given to decoder')
+
+ # channels
+
+ self.channels = channels
+
+ # normalize and unnormalize image functions
+
+ self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
+ self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
+
+ # verify conditioning method
+
+ unets = cast_tuple(unet)
+ num_unets = len(unets)
+ self.num_unets = num_unets
+
+ self.unconditional = unconditional
+
+ # automatically take care of ensuring that first unet is unconditional
+ # while the rest of the unets are conditioned on the low resolution image produced by previous unet
+
+ vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
+
+ # whether to use learned variance, defaults to True for the first unet in the cascade, as in paper
+
+ learned_variance = pad_tuple_to_length(cast_tuple(learned_variance), len(unets), fillvalue = False)
+ self.learned_variance = learned_variance
+ self.learned_variance_constrain_frac = learned_variance_constrain_frac # whether to constrain the output of the network (the interpolation fraction) from 0 to 1
+ self.vb_loss_weight = vb_loss_weight
+
+ # default and validate conditioning parameters
+
+ use_noise_for_lowres_cond = cast_tuple(use_noise_for_lowres_cond, num_unets - 1, validate = False)
+ use_blur_for_lowres_cond = cast_tuple(use_blur_for_lowres_cond, num_unets - 1, validate = False)
+
+ if len(use_noise_for_lowres_cond) < num_unets:
+ use_noise_for_lowres_cond = (False, *use_noise_for_lowres_cond)
+
+ if len(use_blur_for_lowres_cond) < num_unets:
+ use_blur_for_lowres_cond = (False, *use_blur_for_lowres_cond)
+
+ assert not use_noise_for_lowres_cond[0], 'first unet will never need low res noise conditioning'
+ assert not use_blur_for_lowres_cond[0], 'first unet will never need low res blur conditioning'
+
+ assert num_unets == 1 or all((use_noise or use_blur) for use_noise, use_blur in zip(use_noise_for_lowres_cond[1:], use_blur_for_lowres_cond[1:]))
+
+ # construct unets and vaes
+
+ self.unets = nn.ModuleList([])
+ self.vaes = nn.ModuleList([])
+
+ for ind, (one_unet, one_vae, one_unet_learned_var, lowres_noise_cond) in enumerate(zip(unets, vaes, learned_variance, use_noise_for_lowres_cond)):
+ assert isinstance(one_unet, Unet)
+ assert isinstance(one_vae, (VQGanVAE, NullVQGanVAE))
+
+ is_first = ind == 0
+ latent_dim = one_vae.encoded_dim if exists(one_vae) else None
+
+ unet_channels = default(latent_dim, self.channels)
+ unet_channels_out = unet_channels * (1 if not one_unet_learned_var else 2)
+
+ one_unet = one_unet.cast_model_parameters(
+ lowres_cond = not is_first,
+ lowres_noise_cond = lowres_noise_cond,
+ cond_on_image_embeds = not unconditional and is_first,
+ cond_on_text_encodings = not unconditional and one_unet.cond_on_text_encodings,
+ channels = unet_channels,
+ channels_out = unet_channels_out
+ )
+
+ self.unets.append(one_unet)
+ self.vaes.append(one_vae.copy_for_eval())
+
+ # sampling timesteps, defaults to non-ddim with full timesteps sampling
+
+ self.sample_timesteps = cast_tuple(sample_timesteps, num_unets)
+ self.ddim_sampling_eta = ddim_sampling_eta
+
+ # create noise schedulers per unet
+
+ if not exists(beta_schedule):
+ beta_schedule = ('cosine', *(('cosine',) * max(num_unets - 2, 0)), *(('linear',) * int(num_unets > 1)))
+
+ beta_schedule = cast_tuple(beta_schedule, num_unets)
+ p2_loss_weight_gamma = cast_tuple(p2_loss_weight_gamma, num_unets)
+
+ self.noise_schedulers = nn.ModuleList([])
+
+ for ind, (unet_beta_schedule, unet_p2_loss_weight_gamma, sample_timesteps) in enumerate(zip(beta_schedule, p2_loss_weight_gamma, self.sample_timesteps)):
+ assert not exists(sample_timesteps) or sample_timesteps <= timesteps, f'sampling timesteps {sample_timesteps} must be less than or equal to the number of training timesteps {timesteps} for unet {ind + 1}'
+
+ noise_scheduler = NoiseScheduler(
+ beta_schedule = unet_beta_schedule,
+ timesteps = timesteps,
+ loss_type = loss_type,
+ p2_loss_weight_gamma = unet_p2_loss_weight_gamma,
+ p2_loss_weight_k = p2_loss_weight_k
+ )
+
+ self.noise_schedulers.append(noise_scheduler)
+
+ # unet image sizes
+
+ image_sizes = default(image_sizes, (image_size,))
+ image_sizes = tuple(sorted(set(image_sizes)))
+
+ assert self.num_unets == len(image_sizes), f'you did not supply the correct number of u-nets ({self.num_unets}) for resolutions {image_sizes}'
+ self.image_sizes = image_sizes
+ self.sample_channels = cast_tuple(self.channels, len(image_sizes))
+
+ # random crop sizes (for the super-resolution unets later in the cascade)
+
+ self.random_crop_sizes = cast_tuple(random_crop_sizes, len(image_sizes))
+ assert not exists(self.random_crop_sizes[0]), 'you would not need to randomly crop the image for the base unet'
+
+ # predict x0 config
+
+ self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))
+
+ # predict v
+
+ self.predict_v = cast_tuple(predict_v, len(unets))
+
+ # input image range
+
+ self.input_image_range = (-1. if not auto_normalize_img else 0., 1.)
+
+ # cascading ddpm related stuff
+
+ lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
+ assert lowres_conditions == (False, *((True,) * (num_unets - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
+
+ self.lowres_conds = nn.ModuleList([])
+
+ for unet_index, use_noise, use_blur in zip(range(num_unets), use_noise_for_lowres_cond, use_blur_for_lowres_cond):
+ if unet_index == 0:
+ self.lowres_conds.append(None)
+ continue
+
+ lowres_cond = LowresConditioner(
+ downsample_first = lowres_downsample_first,
+ use_blur = use_blur,
+ use_noise = use_noise,
+ blur_prob = blur_prob,
+ blur_sigma = blur_sigma,
+ blur_kernel_size = blur_kernel_size,
+ input_image_range = self.input_image_range,
+ normalize_img_fn = self.normalize_img,
+ unnormalize_img_fn = self.unnormalize_img
+ )
+
+ self.lowres_conds.append(lowres_cond)
+
+ self.lowres_noise_sample_level = lowres_noise_sample_level
+
+ # classifier free guidance
+
+ self.image_cond_drop_prob = image_cond_drop_prob
+ self.text_cond_drop_prob = text_cond_drop_prob
+ self.can_classifier_guidance = image_cond_drop_prob > 0. or text_cond_drop_prob > 0.
+
+ # whether to clip when sampling
+
+ self.clip_denoised = clip_denoised
+ self.clip_x_start = clip_x_start
+
+ # dynamic thresholding settings, if clipping denoised during sampling
+
+ self.use_dynamic_thres = use_dynamic_thres
+ self.dynamic_thres_percentile = dynamic_thres_percentile
+
+ # device tracker
+
+ self.register_buffer('_dummy', torch.Tensor([True]), persistent = False)
+
+ @property
+ def device(self):
+ return self._dummy.device
+
+ @property
+ def condition_on_text_encodings(self):
+ return any([unet.cond_on_text_encodings for unet in self.unets if isinstance(unet, Unet)])
+
+ def get_unet(self, unet_number):
+ assert 0 < unet_number <= self.num_unets
+ index = unet_number - 1
+ return self.unets[index]
+
+ def parse_unet_output(self, learned_variance, output):
+ var_interp_frac_unnormalized = None
+
+ if learned_variance:
+ output, var_interp_frac_unnormalized = output.chunk(2, dim = 1)
+
+ return UnetOutput(output, var_interp_frac_unnormalized)
+
+ @contextmanager
+ def one_unet_in_gpu(self, unet_number = None, unet = None):
+ assert exists(unet_number) ^ exists(unet)
+
+ if exists(unet_number):
+ unet = self.get_unet(unet_number)
+
+ # devices
+
+ cuda, cpu = torch.device('cuda'), torch.device('cpu')
+
+ self.cuda()
+
+ devices = [module_device(unet) for unet in self.unets]
+
+ self.unets.to(cpu)
+ unet.to(cuda)
+
+ yield
+
+ for unet, device in zip(self.unets, devices):
+ unet.to(device)
+
+ def dynamic_threshold(self, x):
+ """ proposed in https://arxiv.org/abs/2205.11487 as an improved clamping in the setting of classifier free guidance """
+
+ # s is the threshold amount
+ # static thresholding would just be s = 1
+ s = 1.
+ if self.use_dynamic_thres:
+ s = torch.quantile(
+ rearrange(x, 'b ... -> b (...)').abs(),
+ self.dynamic_thres_percentile,
+ dim = -1
+ )
+
+ s.clamp_(min = 1.)
+ s = s.view(-1, *((1,) * (x.ndim - 1)))
+
+ # clip by threshold, depending on whether static or dynamic
+ x = x.clamp(-s, s) / s
+ return x
+
+ def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, predict_v = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
+ assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
+
+ model_output = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))
+
+ pred, var_interp_frac_unnormalized = self.parse_unet_output(learned_variance, model_output)
+
+ if predict_v:
+ x_start = noise_scheduler.predict_start_from_v(x, t = t, v = pred)
+ elif predict_x_start:
+ x_start = pred
+ else:
+ x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
+
+ if clip_denoised:
+ x_start = self.dynamic_threshold(x_start)
+
+ model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t)
+
+ if learned_variance:
+ # if using learned variance, the posterior variance and posterior log variance are predicted by the network
+ # by an interpolation of the max and min log beta values
+ # eq 15 - https://arxiv.org/abs/2102.09672
+ min_log = extract(noise_scheduler.posterior_log_variance_clipped, t, x.shape)
+ max_log = extract(torch.log(noise_scheduler.betas), t, x.shape)
+ var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
+
+ if self.learned_variance_constrain_frac:
+ var_interp_frac = var_interp_frac.sigmoid()
+
+ posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
+ posterior_variance = posterior_log_variance.exp()
+
+ return model_mean, posterior_variance, posterior_log_variance, x_start
+
+ @torch.no_grad()
+ def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, predict_v = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
+ b, *_, device = *x.shape, x.device
+ model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, predict_v = predict_v, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
+ noise = torch.randn_like(x)
+ # no noise when t == 0
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+ pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+ return pred, x_start
+
+ @torch.no_grad()
+ def p_sample_loop_ddpm(
+ self,
+ unet,
+ shape,
+ image_embed,
+ noise_scheduler,
+ predict_x_start = False,
+ predict_v = False,
+ learned_variance = False,
+ clip_denoised = True,
+ lowres_cond_img = None,
+ text_encodings = None,
+ cond_scale = 1,
+ is_latent_diffusion = False,
+ lowres_noise_level = None,
+ inpaint_image = None,
+ inpaint_mask = None,
+ inpaint_resample_times = 5
+ ):
+ device = self.device
+
+ b = shape[0]
+ img = torch.randn(shape, device = device)
+
+ x_start = None # for self-conditioning
+
+ is_inpaint = exists(inpaint_image)
+ resample_times = inpaint_resample_times if is_inpaint else 1
+
+ if is_inpaint:
+ inpaint_image = self.normalize_img(inpaint_image)
+ inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True)
+ inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float()
+ inpaint_mask = resize_image_to(inpaint_mask, shape[-1], nearest = True)
+ inpaint_mask = inpaint_mask.bool()
+
+ if not is_latent_diffusion:
+ lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
+
+ for time in tqdm(reversed(range(0, noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = noise_scheduler.num_timesteps):
+ is_last_timestep = time == 0
+
+ for r in reversed(range(0, resample_times)):
+ is_last_resample_step = r == 0
+
+ times = torch.full((b,), time, device = device, dtype = torch.long)
+
+ if is_inpaint:
+ # following the repaint paper
+ # https://arxiv.org/abs/2201.09865
+ noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = times)
+ img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
+
+ self_cond = x_start if unet.self_cond else None
+
+ img, x_start = self.p_sample(
+ unet,
+ img,
+ times,
+ image_embed = image_embed,
+ text_encodings = text_encodings,
+ cond_scale = cond_scale,
+ self_cond = self_cond,
+ lowres_cond_img = lowres_cond_img,
+ lowres_noise_level = lowres_noise_level,
+ predict_x_start = predict_x_start,
+ predict_v = predict_v,
+ noise_scheduler = noise_scheduler,
+ learned_variance = learned_variance,
+ clip_denoised = clip_denoised
+ )
+
+ if is_inpaint and not (is_last_timestep or is_last_resample_step):
+ # in repaint, you renoise and resample up to 10 times every step
+ img = noise_scheduler.q_sample_from_to(img, times - 1, times)
+
+ if is_inpaint:
+ img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask)
+
+ unnormalize_img = self.unnormalize_img(img)
+ return unnormalize_img
+
+ @torch.no_grad()
+ def p_sample_loop_ddim(
+ self,
+ unet,
+ shape,
+ image_embed,
+ noise_scheduler,
+ timesteps,
+ eta = 1.,
+ predict_x_start = False,
+ predict_v = False,
+ learned_variance = False,
+ clip_denoised = True,
+ lowres_cond_img = None,
+ text_encodings = None,
+ cond_scale = 1,
+ is_latent_diffusion = False,
+ lowres_noise_level = None,
+ inpaint_image = None,
+ inpaint_mask = None,
+ inpaint_resample_times = 5
+ ):
+ batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod, self.ddim_sampling_eta
+
+ times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
+
+ times = list(reversed(times.int().tolist()))
+ time_pairs = list(zip(times[:-1], times[1:]))
+ time_pairs = list(filter(lambda t: t[0] > t[1], time_pairs))
+
+ is_inpaint = exists(inpaint_image)
+ resample_times = inpaint_resample_times if is_inpaint else 1
+
+ if is_inpaint:
+ inpaint_image = self.normalize_img(inpaint_image)
+ inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True)
+ inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float()
+ inpaint_mask = resize_image_to(inpaint_mask, shape[-1], nearest = True)
+ inpaint_mask = inpaint_mask.bool()
+
+ img = torch.randn(shape, device = device)
+
+ x_start = None # for self-conditioning
+
+ if not is_latent_diffusion:
+ lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
+
+ for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
+ is_last_timestep = time_next == 0
+
+ for r in reversed(range(0, resample_times)):
+ is_last_resample_step = r == 0
+
+ alpha = alphas[time]
+ alpha_next = alphas[time_next]
+
+ time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
+
+ if is_inpaint:
+ # following the repaint paper
+ # https://arxiv.org/abs/2201.09865
+ noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = time_cond)
+ img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask)
+
+ self_cond = x_start if unet.self_cond else None
+
+ unet_output = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, self_cond = self_cond, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
+
+ pred, _ = self.parse_unet_output(learned_variance, unet_output)
+
+ # predict x0
+
+ if predict_v:
+ x_start = noise_scheduler.predict_start_from_v(img, t = time_cond, v = pred)
+ elif predict_x_start:
+ x_start = pred
+ else:
+ x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
+
+ # maybe clip x0
+
+ if clip_denoised:
+ x_start = self.dynamic_threshold(x_start)
+
+ # predict noise
+
+ pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = x_start)
+
+ c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
+ c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
+ noise = torch.randn_like(img) if not is_last_timestep else 0.
+
+ img = x_start * alpha_next.sqrt() + \
+ c1 * noise + \
+ c2 * pred_noise
+
+ if is_inpaint and not (is_last_timestep or is_last_resample_step):
+ # in repaint, you renoise and resample up to 10 times every step
+ time_next_cond = torch.full((batch,), time_next, device = device, dtype = torch.long)
+ img = noise_scheduler.q_sample_from_to(img, time_next_cond, time_cond)
+
+ if exists(inpaint_image):
+ img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask)
+
+ img = self.unnormalize_img(img)
+ return img
+
+ @torch.no_grad()
+ def p_sample_loop(self, *args, noise_scheduler, timesteps = None, **kwargs):
+ num_timesteps = noise_scheduler.num_timesteps
+
+ timesteps = default(timesteps, num_timesteps)
+ assert timesteps <= num_timesteps
+ is_ddim = timesteps < num_timesteps
+
+ if not is_ddim:
+ return self.p_sample_loop_ddpm(*args, noise_scheduler = noise_scheduler, **kwargs)
+
+ return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs)
+
+ def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, predict_v = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False, lowres_noise_level = None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+
+ # normalize to [-1, 1]
+
+ if not is_latent_diffusion:
+ x_start = self.normalize_img(x_start)
+ lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
+
+ # get x_t
+
+ x_noisy = noise_scheduler.q_sample(x_start = x_start, t = times, noise = noise)
+
+ # unet kwargs
+
+ unet_kwargs = dict(
+ image_embed = image_embed,
+ text_encodings = text_encodings,
+ lowres_cond_img = lowres_cond_img,
+ lowres_noise_level = lowres_noise_level,
+ )
+
+ # self conditioning
+
+ self_cond = None
+
+ if unet.self_cond and random.random() < 0.5:
+ with torch.no_grad():
+ unet_output = unet(x_noisy, times, **unet_kwargs)
+ self_cond, _ = self.parse_unet_output(learned_variance, unet_output)
+ self_cond = self_cond.detach()
+
+ # forward to get model prediction
+
+ unet_output = unet(
+ x_noisy,
+ times,
+ **unet_kwargs,
+ self_cond = self_cond,
+ image_cond_drop_prob = self.image_cond_drop_prob,
+ text_cond_drop_prob = self.text_cond_drop_prob,
+ )
+
+ pred, _ = self.parse_unet_output(learned_variance, unet_output)
+
+ if predict_v:
+ target = noise_scheduler.calculate_v(x_start, times, noise)
+ elif predict_x_start:
+ target = x_start
+ else:
+ target = noise
+
+ loss = noise_scheduler.loss_fn(pred, target, reduction = 'none')
+ loss = reduce(loss, 'b ... -> b (...)', 'mean')
+
+ loss = noise_scheduler.p2_reweigh_loss(loss, times)
+
+ loss = loss.mean()
+
+ if not learned_variance:
+ # return simple loss if not using learned variance
+ return loss
+
+ # most of the code below is transcribed from
+ # https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/diffusion_utils_2.py
+ # the Improved DDPM paper then further modified it so that the mean is detached (shown a few lines below), and weighted to be smaller than the l1 or l2 "simple" loss
+ # it is questionable whether this is really needed, looking at some of the figures in the paper, but may as well stay faithful to their implementation
+
+ # if learning the variance, also include the extra weight kl loss
+
+ true_mean, _, true_log_variance_clipped = noise_scheduler.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
+ model_mean, _, model_log_variance, _ = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = unet_output)
+
+ # kl loss with detached model predicted mean, for stability reasons as in paper
+
+ detached_model_mean = model_mean.detach()
+
+ kl = normal_kl(true_mean, true_log_variance_clipped, detached_model_mean, model_log_variance)
+ kl = meanflat(kl) * NAT
+
+ decoder_nll = -discretized_gaussian_log_likelihood(x_start, means = detached_model_mean, log_scales = 0.5 * model_log_variance)
+ decoder_nll = meanflat(decoder_nll) * NAT
+
+ # at the first timestep return the decoder NLL, otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
+
+ vb_losses = torch.where(times == 0, decoder_nll, kl)
+
+ # weight the vb loss smaller, for stability, as in the paper (recommended 0.001)
+
+ vb_loss = vb_losses.mean() * self.vb_loss_weight
+
+ return loss + vb_loss
+
+ @torch.no_grad()
+ @eval_decorator
+ def sample(
+ self,
+ image = None,
+ image_embed = None,
+ text = None,
+ text_encodings = None,
+ batch_size = 1,
+ cond_scale = 1.,
+ start_at_unet_number = 1,
+ stop_at_unet_number = None,
+ distributed = False,
+ inpaint_image = None,
+ inpaint_mask = None,
+ inpaint_resample_times = 5,
+ one_unet_in_gpu_at_time = True
+ ):
+ assert self.unconditional or exists(image_embed), 'image embed must be present when sampling from the decoder, unless it was trained unconditionally'
+
+ if not self.unconditional:
+ batch_size = image_embed.shape[0]
+
+ if exists(text) and not exists(text_encodings) and not self.unconditional:
+ assert exists(self.clip)
+ _, text_encodings = self.clip.embed_text(text)
+
+ assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
+ assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder was specified to not be conditioned on text, yet text encodings were given'
+
+ assert not (exists(inpaint_image) ^ exists(inpaint_mask)), 'inpaint_image and inpaint_mask (boolean mask of [batch, height, width]) must be both given for inpainting'
+
+ img = None
+ if start_at_unet_number > 1:
+ # Then we are not generating the first image and one must have been passed in
+ assert exists(image), 'image must be passed in if starting at unet number > 1'
+ assert image.shape[0] == batch_size, 'image must have batch size of {} if starting at unet number > 1'.format(batch_size)
+ prev_unet_output_size = self.image_sizes[start_at_unet_number - 2]
+ img = resize_image_to(image, prev_unet_output_size, nearest = True)
+
+ is_cuda = next(self.parameters()).is_cuda
+
+ num_unets = self.num_unets
+ cond_scale = cast_tuple(cond_scale, num_unets)
+
+ for unet_number, unet, vae, channel, image_size, predict_x_start, predict_v, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.predict_v, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
+ if unet_number < start_at_unet_number:
+ continue # skip unets that come before the starting unet
+
+ context = self.one_unet_in_gpu(unet = unet) if is_cuda and one_unet_in_gpu_at_time else null_context()
+
+ with context:
+ # prepare low resolution conditioning for upsamplers
+
+ lowres_cond_img = lowres_noise_level = None
+ shape = (batch_size, channel, image_size, image_size)
+
+ if unet.lowres_cond:
+ lowres_cond_img = resize_image_to(img, target_image_size = image_size, clamp_range = self.input_image_range, nearest = True)
+
+ if lowres_cond.use_noise:
+ lowres_noise_level = torch.full((batch_size,), int(self.lowres_noise_sample_level * 1000), dtype = torch.long, device = self.device)
+ lowres_cond_img, _ = lowres_cond.noise_image(lowres_cond_img, lowres_noise_level)
+
+ # latent diffusion
+
+ is_latent_diffusion = isinstance(vae, VQGanVAE)
+ image_size = vae.get_encoded_fmap_size(image_size)
+ shape = (batch_size, vae.encoded_dim, image_size, image_size)
+
+ lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
+
+ # denoising loop for image
+
+ img = self.p_sample_loop(
+ unet,
+ shape,
+ image_embed = image_embed,
+ text_encodings = text_encodings,
+ cond_scale = unet_cond_scale,
+ predict_x_start = predict_x_start,
+ predict_v = predict_v,
+ learned_variance = learned_variance,
+ clip_denoised = not is_latent_diffusion,
+ lowres_cond_img = lowres_cond_img,
+ lowres_noise_level = lowres_noise_level,
+ is_latent_diffusion = is_latent_diffusion,
+ noise_scheduler = noise_scheduler,
+ timesteps = sample_timesteps,
+ inpaint_image = inpaint_image,
+ inpaint_mask = inpaint_mask,
+ inpaint_resample_times = inpaint_resample_times
+ )
+
+ img = vae.decode(img)
+
+ if exists(stop_at_unet_number) and stop_at_unet_number == unet_number:
+ break
+
+ return img
+
+ def forward(
+ self,
+ image,
+ text = None,
+ image_embed = None,
+ text_encodings = None,
+ unet_number = None,
+ return_lowres_cond_image = False # whether to return the low resolution conditioning images, for debugging upsampler purposes
+ ):
+ assert not (self.num_unets > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {self.num_unets}, if you are training cascading DDPM (multiple unets)'
+ unet_number = default(unet_number, 1)
+ unet_index = unet_number - 1
+
+ unet = self.get_unet(unet_number)
+
+ vae = self.vaes[unet_index]
+ noise_scheduler = self.noise_schedulers[unet_index]
+ lowres_conditioner = self.lowres_conds[unet_index]
+ target_image_size = self.image_sizes[unet_index]
+ predict_x_start = self.predict_x_start[unet_index]
+ predict_v = self.predict_v[unet_index]
+ random_crop_size = self.random_crop_sizes[unet_index]
+ learned_variance = self.learned_variance[unet_index]
+ b, c, h, w, device, = *image.shape, image.device
+
+ assert image.shape[1] == self.channels
+ assert h >= target_image_size and w >= target_image_size
+
+ times = torch.randint(0, noise_scheduler.num_timesteps, (b,), device = device, dtype = torch.long)
+
+ if not exists(image_embed) and not self.unconditional:
+ assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
+ image_embed, _ = self.clip.embed_image(image)
+
+ if exists(text) and not exists(text_encodings) and not self.unconditional:
+ assert exists(self.clip), 'if you are passing in raw text, you need to supply `clip` to the decoder'
+ _, text_encodings = self.clip.embed_text(text)
+
+ assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
+ assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder was specified to not be conditioned on text, yet text encodings were given'
+
+ lowres_cond_img, lowres_noise_level = lowres_conditioner(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if exists(lowres_conditioner) else (None, None)
+ image = resize_image_to(image, target_image_size, nearest = True)
+
+ if exists(random_crop_size):
+ aug = K.RandomCrop((random_crop_size, random_crop_size), p = 1.)
+
+ # make sure low res conditioner and image both get augmented the same way
+ # detailed https://kornia.readthedocs.io/en/latest/augmentation.module.html?highlight=randomcrop#kornia.augmentation.RandomCrop
+ image = aug(image)
+ lowres_cond_img = aug(lowres_cond_img, params = aug._params)
+
+ is_latent_diffusion = not isinstance(vae, NullVQGanVAE)
+
+ vae.eval()
+ with torch.no_grad():
+ image = vae.encode(image)
+ lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
+
+ losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, predict_v = predict_v, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level)
+
+ if not return_lowres_cond_image:
+ return losses
+
+ return losses, lowres_cond_img
+
+# main class
+
+class DALLE2(nn.Module):
+ def __init__(
+ self,
+ *,
+ prior,
+ decoder,
+ prior_num_samples = 2
+ ):
+ super().__init__()
+ assert isinstance(prior, DiffusionPrior)
+ assert isinstance(decoder, Decoder)
+ self.prior = prior
+ self.decoder = decoder
+
+ self.prior_num_samples = prior_num_samples
+ self.decoder_need_text_cond = self.decoder.condition_on_text_encodings
+
+ self.to_pil = T.ToPILImage()
+
+ @torch.no_grad()
+ @eval_decorator
+ def forward(
+ self,
+ text,
+ cond_scale = 1.,
+ prior_cond_scale = 1.,
+ return_pil_images = False
+ ):
+ device = module_device(self)
+ one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1)
+
+ if isinstance(text, str) or is_list_str(text):
+ text = [text] if not isinstance(text, (list, tuple)) else text
+ text = tokenizer.tokenize(text).to(device)
+
+ image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples, cond_scale = prior_cond_scale)
+
+ text_cond = text if self.decoder_need_text_cond else None
+ images = self.decoder.sample(image_embed = image_embed, text = text_cond, cond_scale = cond_scale)
+
+ if return_pil_images:
+ images = list(map(self.to_pil, images.unbind(dim = 0)))
+
+ if one_text:
+ return first(images)
+
+ return images
diff --git a/docs/src/dalle2_pytorch/dataloaders/README.md b/docs/src/dalle2_pytorch/dataloaders/README.md
new file mode 100644
index 00000000..66f3c90e
--- /dev/null
+++ b/docs/src/dalle2_pytorch/dataloaders/README.md
@@ -0,0 +1,75 @@
+## Dataloaders
+In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
+
+### Decoder: Image Embedding Dataset
+When training the decoder (and upsamplers, if training them together) in isolation, you will need to load images and their corresponding image embeddings. This dataset supports two similar layouts. First, it can read a [webdataset](https://github.com/webdataset/webdataset) whose `.tar`s contain `.jpg` files for the images and `.npy` files for the associated image embeddings. Alternatively, you can specify a source for the embeddings outside of the webdataset. In that case, the embeddings path should contain `.npy` files with the same shard numbers as the webdataset, and the filename of each `.jpg` must correspond to the index of its embedding within the `.npy`. For example, `0001.tar` from the webdataset containing the image `00010509.jpg` (the first 4 digits are the shard number and the last 4 are the index) should be paralleled by an `img_emb_0001.npy` containing a NumPy array with the embedding at index 509.
+
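+For example, the correspondence described above can be sketched as follows (purely illustrative; the dataloader below does this parsing for you):
+```python
+filename = "00010509.jpg"                # image stored inside 0001.tar
+shard = filename[:4]                     # "0001" -> embeddings live in img_emb_0001.npy
+index = int(filename[4:].split(".")[0])  # 509    -> row 509 of that NumPy array
+```
+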
+Generating a dataset of this type:
+1. Use [img2dataset](https://github.com/rom1504/img2dataset) to generate a webdataset.
+2. Use [clip-retrieval](https://github.com/rom1504/clip-retrieval) to convert the images to embeddings.
+3. Use [embedding-dataset-reordering](https://github.com/Veldrovive/embedding-dataset-reordering) to reorder the embeddings into the expected format.
+
+Usage:
+```python
+from dalle2_pytorch.dataloaders import ImageEmbeddingDataset, create_image_embedding_dataloader
+
+# Create a dataloader directly.
+dataloader = create_image_embedding_dataloader(
+ tar_url="/path/or/url/to/webdataset/{0000..9999}.tar", # Uses bracket expanding notation. This specifies to read all tars from 0000.tar to 9999.tar
+ img_embeddings_url="path/or/url/to/embeddings/folder", # Included if the .npy embedding files are not in the webdataset. Left out or set to None otherwise
+ num_workers=4,
+ batch_size=32,
+ index_width=4, # If a file in shard 0001 is named 00010509.jpg, the last 4 digits are the index within the embedding file, so index_width is 4
+ shuffle_num=200, # Does a shuffle of the data with a buffer size of 200
+ shuffle_shards=True, # Shuffle the order the shards are read in
+ resample_shards=False, # Sample shards with replacement. If true, an epoch will be infinite unless stopped manually
+)
+for img, emb in dataloader:
+ print(img.shape) # torch.Size([32, 3, 256, 256])
+ print(emb.shape) # torch.Size([32, 512])
+ # Train decoder only as shown above
+
+# Or create a dataset without a loader so you can configure it manually
+dataset = ImageEmbeddingDataset(
+ urls="/path/or/url/to/webdataset/{0000..9999}.tar",
+ img_embedding_folder_url="path/or/url/to/embeddings/folder",
+ index_width=4,
+ shuffle_shards=True,
+ resample=False
+)
+```
+
+### Diffusion Prior: Prior Embedding Dataset
+When training the prior it is much more efficient to work with pre-computed embeddings. The `PriorEmbeddingDataset` class enables you to leverage the same script (with minimal modification) for both embedding-only and text-conditioned prior training. This saves you from having to worry about a lot of the boilerplate code.
+
+To utilize the `PriorEmbeddingDataset`, all you need to do is make a single call to `get_reader()`, which will create `EmbeddingReader` object(s) for you. Afterwards, you can utilize `make_splits()` to cleanly create DataLoader objects for your training run.
+
+If you are training in a distributed manner, `make_splits()` accepts `rank` and `world_size` arguments to properly distribute to each process. The defaults for these values are `rank=0` and `world_size=1`, so single-process training can safely ignore these parameters.
+
+Usage:
+```python
+from dalle2_pytorch.dataloaders import get_reader, make_splits
+
+# grab embeddings from some specified location
+IMG_URL = "data/img_emb/"
+META_URL = "data/meta/"
+
+reader = get_reader(text_conditioned=True, img_url=IMG_URL, meta_url=META_URL)
+
+# some config for training
+TRAIN_ARGS = {
+ "world_size": 3,
+ "text_conditioned": True,
+ "start": 0,
+ "num_data_points": 10000,
+ "batch_size": 2,
+ "train_split": 0.5,
+ "eval_split": 0.25,
+ "image_reader": reader,
+}
+
+# specifying a rank will handle allocation internally
+rank0_train, rank0_eval, rank0_test = make_splits(rank=0, **TRAIN_ARGS)
+rank1_train, rank1_eval, rank1_test = make_splits(rank=1, **TRAIN_ARGS)
+rank2_train, rank2_eval, rank2_test = make_splits(rank=2, **TRAIN_ARGS)
+```
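+
+Each returned DataLoader yields tuples straight from `PriorEmbeddingDataset`: `(image_embedding, tokenized_caption)` when `text_conditioned=True`, or `(image_embedding, text_embedding)` otherwise. A minimal (hypothetical) consumption loop might look like:
+```python
+for image_embed, text in rank0_train:
+    # image_embed: batch of image embeddings as a torch tensor
+    # text: tokenized captions here, since text_conditioned=True above
+    ...  # feed both into your diffusion prior training step
+```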
diff --git a/docs/src/dalle2_pytorch/dataloaders/__init__.py b/docs/src/dalle2_pytorch/dataloaders/__init__.py
new file mode 100644
index 00000000..72af534e
--- /dev/null
+++ b/docs/src/dalle2_pytorch/dataloaders/__init__.py
@@ -0,0 +1,2 @@
+from dalle2_pytorch.dataloaders.decoder_loader import ImageEmbeddingDataset, create_image_embedding_dataloader
+from dalle2_pytorch.dataloaders.prior_loader import make_splits, get_reader, PriorEmbeddingDataset
diff --git a/docs/src/dalle2_pytorch/dataloaders/decoder_loader.py b/docs/src/dalle2_pytorch/dataloaders/decoder_loader.py
new file mode 100644
index 00000000..6b679e64
--- /dev/null
+++ b/docs/src/dalle2_pytorch/dataloaders/decoder_loader.py
@@ -0,0 +1,266 @@
+import os
+import webdataset as wds
+import torch
+from torch.utils.data import DataLoader
+import numpy as np
+import fsspec
+import shutil
+
+def get_shard(filename):
+ """
+ Filenames with shards in them have a consistent structure that we can take advantage of
+ Standard structure: path/to/file/prefix_string_00001.ext
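+ Example (illustrative): get_shard("img_emb_0001.npy") -> "0001"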
+ """
+ try:
+ return filename.split("_")[-1].split(".")[0]
+ except ValueError:
+ raise RuntimeError(f"Could not find shard for filename {filename}")
+
+def get_example_file(fs, path, file_format):
+ """
+ Given a file system, a path, and a file extension, return an example file with that extension
+ """
+ return fs.glob(os.path.join(path, f"*.{file_format}"))[0]
+
+def embedding_inserter(samples, embeddings_url, index_width, sample_key='npy', handler=wds.handlers.reraise_exception):
+ """Given a datum of {"__key__": str, "__url__": str, ...} adds the cooresponding embedding and yields"""
+ previous_tar_url = None
+ current_embeddings = None
+ # Get a reference to an abstract file system where the embeddings are stored
+ embeddings_fs, embeddings_path = fsspec.core.url_to_fs(embeddings_url)
+ example_embedding_file = get_example_file(embeddings_fs, embeddings_path, "npy")
+ example_embedding_shard = get_shard(example_embedding_file)
+ emb_shard_width = len(example_embedding_shard)
+ # Easier to get the basename without the shard once than search through for the correct file every time
+ embedding_file_basename = '_'.join(example_embedding_file.split("_")[:-1]) + "_"
+
+ def load_corresponding_embeds(tar_url):
+ """Finds and reads the npy files that contains embeddings for the given webdataset tar"""
+ shard = int(tar_url.split("/")[-1].split(".")[0])
+ embedding_url = embedding_file_basename + str(shard).zfill(emb_shard_width) + '.npy'
+ with embeddings_fs.open(embedding_url) as f:
+ data = np.load(f)
+ return torch.from_numpy(data)
+
+ for sample in samples:
+ try:
+ tar_url = sample["__url__"]
+ key = sample["__key__"]
+ if tar_url != previous_tar_url:
+ # If the tar changed, we need to download new embeddings
+ # This means if we shuffle before inserting it will load many more files than we expect and be very inefficient.
+ previous_tar_url = tar_url
+ current_embeddings = load_corresponding_embeds(tar_url)
+
+ embedding_index = int(key[-index_width:])
+ embedding = current_embeddings[embedding_index]
+ # An embedding of all zeros means no embedding was stored for this index, so the sample is invalid
+ if torch.count_nonzero(embedding) == 0:
+ raise RuntimeError(f"Webdataset had a sample, but no embedding was found. ImgShard: {key[:-index_width]} - Index: {key[-index_width:]}")
+ sample[sample_key] = embedding
+ yield sample
+ except Exception as exn: # From wds implementation
+ if handler(exn):
+ continue
+ else:
+ break
+insert_embedding = wds.filters.pipelinefilter(embedding_inserter)
+
+def unassociated_shard_skipper(tarfiles, embeddings_url, handler=wds.handlers.reraise_exception):
+ """Finds if the is a corresponding embedding for the tarfile at { url: [URL] }"""
+ embeddings_fs, embeddings_path = fsspec.core.url_to_fs(embeddings_url)
+ embedding_files = embeddings_fs.ls(embeddings_path)
+ get_embedding_shard = lambda embedding_file: int(embedding_file.split("_")[-1].split(".")[0])
+ embedding_shards = set([get_embedding_shard(filename) for filename in embedding_files]) # Sets have O(1) membership checks
+
+ get_tar_shard = lambda tar_file: int(tar_file.split("/")[-1].split(".")[0])
+ for tarfile in tarfiles:
+ try:
+ webdataset_shard = get_tar_shard(tarfile["url"])
+ # If this shard has an associated embeddings file, we pass it through. Otherwise we iterate until we do have one
+ if webdataset_shard in embedding_shards:
+ yield tarfile
+ except Exception as exn: # From wds implementation
+ if handler(exn):
+ continue
+ else:
+ break
+skip_unassociated_shards = wds.filters.pipelinefilter(unassociated_shard_skipper)
+
+def join_embeddings(samples, handler=wds.handlers.reraise_exception):
+ """
+ Takes the img_emb and text_emb keys and turns them into one key "emb": { "text": text_emb, "img": img_emb }.
+ Either or both of text_emb and img_emb may be missing from the sample, so only the ones that exist are added.
+ """
+ for sample in samples:
+ try:
+ sample['emb'] = {}
+ if 'text_emb' in sample:
+ sample['emb']['text'] = sample['text_emb']
+ if 'img_emb' in sample:
+ sample['emb']['img'] = sample['img_emb']
+ yield sample
+ except Exception as exn: # From wds implementation
+ if handler(exn):
+ continue
+ else:
+ break
+
+def verify_keys(samples, required_keys, handler=wds.handlers.reraise_exception):
+ """
+ Requires that both the image and embedding are present in the sample.
+ This is important because a user may forget they do not have embeddings in their webdataset and neglect to add them using the img_embedding_folder_url or text_embedding_folder_url parameters.
+ """
+ for sample in samples:
+ try:
+ for key in required_keys:
+ assert key in sample, f"Sample {sample['__key__']} missing {key}. Has keys {sample.keys()}"
+ yield sample
+ except Exception as exn: # From wds implementation
+ if handler(exn):
+ continue
+ else:
+ break
+key_verifier = wds.filters.pipelinefilter(verify_keys)
+
+class ImageEmbeddingDataset(wds.DataPipeline, wds.compat.FluidInterface):
+ """
+ A fluid interface wrapper for DataPipeline that returns image-embedding pairs.
+ Reads embeddings as npy files from the webdataset if they exist. If img_embedding_folder_url or text_embedding_folder_url is set, the embeddings will be read in from that alternate source instead.
+ """
+
+ def __init__(
+ self,
+ urls,
+ img_embedding_folder_url=None,
+ text_embedding_folder_url=None,
+ index_width=None,
+ img_preproc=None,
+ extra_keys=[],
+ handler=wds.handlers.reraise_exception,
+ resample=False,
+ shuffle_shards=True
+ ):
+ """
+ Modeled directly off of the WebDataset constructor
+
+ :param urls: A url pointing to the tar files of the webdataset formatted as /path/to/webdataset/{0000..9999}.tar
+ :param img_embedding_folder_url: Required if the webdataset does not contain image embeddings. A url pointing to the npy files of the embeddings. Should have the same number of shards as the webdataset.
+ Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros.
+ :param text_embedding_folder_url: Same as img_embedding_folder_url, but pointing to the npy files of the text embeddings.
+ :param index_width: The number of digits in the index. This is used to align the embedding index with the image index.
+ For example, if a file in webdataset shard 3 is named 0003039.jpg, the shard prefix is 4 digits, the last 3 digits are the index, and index_width is 3.
+ :param img_preproc: This function is run on the img before it is batched and returned. Useful for data augmentation or converting to torch tensor.
+ :param handler: A webdataset handler.
+ :param resample: If true, resample webdataset shards with replacement. You need to set your own epoch size if this is true since it will resample infinitely.
+ :param shuffle_shards: If true, shuffle the order the shards are read in. This cannot be true if resample is true.
+
+
+ """
+ super().__init__()
+ keys = ["jpg", "emb"] + extra_keys
+ # if img_embedding_folder_url is not None:
+ # keys.append("img_emb")
+ # if text_embedding_folder_url is not None:
+ # keys.append("text_emb")
+ # keys.extend(extra_keys)
+ self.key_map = {key: i for i, key in enumerate(keys)}
+ self.resampling = resample
+ self.img_preproc = img_preproc
+ # If using s3 links, check that s3cmd is installed for reading the webdataset and s3fs is installed for reading the embeddings
+ if (isinstance(urls, str) and "s3:" in urls) or (isinstance(urls, list) and any(["s3:" in url for url in urls])):
+ # Then this has an s3 link for the webdataset and we need extra packages
+ if shutil.which("s3cmd") is None:
+ raise RuntimeError("s3cmd is required for s3 webdataset")
+ if (img_embedding_folder_url is not None and "s3:" in img_embedding_folder_url) or (text_embedding_folder_url is not None and "s3:" in text_embedding_folder_url):
+ # Then the embeddings are being loaded from s3 and fsspec requires s3fs
+ try:
+ import s3fs
+ except ImportError:
+ raise RuntimeError("s3fs is required to load embeddings from s3")
+ # Add the shardList and randomize or resample if requested
+ if resample:
+ assert not shuffle_shards, "Cannot both resample and shuffle"
+ self.append(wds.ResampledShards(urls))
+ else:
+ self.append(wds.SimpleShardList(urls))
+ if shuffle_shards:
+ self.append(wds.filters.shuffle(1000))
+
+ if img_embedding_folder_url is not None:
+ # There may be webdataset shards that do not have an embedding shard associated with them. If we do not skip these, they would cause issues.
+ self.append(skip_unassociated_shards(embeddings_url=img_embedding_folder_url, handler=handler))
+ if text_embedding_folder_url is not None:
+ self.append(skip_unassociated_shards(embeddings_url=text_embedding_folder_url, handler=handler))
+
+ self.append(wds.tarfile_to_samples(handler=handler))
+ self.append(wds.decode("pilrgb", handler=handler))
+ if img_embedding_folder_url is not None:
+ # Then we are loading image embeddings from a remote source
+ assert index_width is not None, "Reading embeddings separately requires index_width to be given"
+ self.append(insert_embedding(embeddings_url=img_embedding_folder_url, index_width=index_width, sample_key='img_emb', handler=handler))
+ if text_embedding_folder_url is not None:
+ # Then we are loading text embeddings from a remote source
+ assert index_width is not None, "Reading embeddings separately requires index_width to be given"
+ self.append(insert_embedding(embeddings_url=text_embedding_folder_url, index_width=index_width, sample_key='text_emb', handler=handler))
+ self.append(join_embeddings)
+ self.append(key_verifier(required_keys=keys, handler=handler))
+ # Apply preprocessing
+ self.append(wds.map(self.preproc))
+ self.append(wds.to_tuple(*keys))
+
+ def preproc(self, sample):
+ """Applies the preprocessing for images"""
+ if self.img_preproc is not None:
+ sample["jpg"] = self.img_preproc(sample["jpg"])
+ return sample
+
+def create_image_embedding_dataloader(
+ tar_url,
+ num_workers,
+ batch_size,
+ img_embeddings_url=None,
+ text_embeddings_url=None,
+ index_width=None,
+ shuffle_num = None,
+ shuffle_shards = True,
+ resample_shards = False,
+ img_preproc=None,
+ extra_keys=[],
+ handler=wds.handlers.reraise_exception  # or wds.handlers.warn_and_continue to skip bad samples
+):
+ """
+ Convenience function to create an image embedding dataset and dataloader in one line
+
+ :param tar_url: A url pointing to the tar files of the webdataset formatted as /path/to/webdataset/{0000..9999}.tar
+ :param num_workers: The number of workers to use for the dataloader
+ :param batch_size: The batch size to use for the dataloader
+ :param img_embeddings_url: Required if the webdataset does not contain image embeddings. A url pointing to the npy files of the embeddings. Should have the same number of shards as the webdataset.
+ Webdataset image keys should align with the index of the embedding. This means missing image indices must have a corresponding embedding of all zeros.
+ :param text_embeddings_url: Same as img_embeddings_url, but pointing to the npy files of the text embeddings.
+ :param index_width: The number of digits in the index. This is used to align the embedding index with the image index.
+ For example, if a file in webdataset shard 3 is named 0003039.jpg, the shard prefix is 4 digits, the last 3 digits are the index, and index_width is 3.
+ :param shuffle_num: If not None, shuffle the dataset with this size buffer after sampling.
+ :param shuffle_shards: If true, shuffle the shards before sampling. This cannot be true if resample is true.
+ :param resample_shards: If true, resample webdataset shards with replacement. You need to set your own epoch size if this is true since it will resample infinitely.
+ :param handler: A webdataset handler.
+ """
+ ds = ImageEmbeddingDataset(
+ tar_url,
+ img_embedding_folder_url=img_embeddings_url,
+ text_embedding_folder_url=text_embeddings_url,
+ index_width=index_width,
+ shuffle_shards=shuffle_shards,
+ resample=resample_shards,
+ extra_keys=extra_keys,
+ img_preproc=img_preproc,
+ handler=handler
+ )
+ if shuffle_num is not None and shuffle_num > 0:
+ ds.shuffle(shuffle_num)
+ return DataLoader(
+ ds,
+ num_workers=num_workers,
+ batch_size=batch_size,
+ prefetch_factor=2, # This might be good to have high so the next npy file is prefetched
+ pin_memory=True,
+ shuffle=False
+ )
diff --git a/docs/src/dalle2_pytorch/dataloaders/prior_loader.py b/docs/src/dalle2_pytorch/dataloaders/prior_loader.py
new file mode 100644
index 00000000..f612653e
--- /dev/null
+++ b/docs/src/dalle2_pytorch/dataloaders/prior_loader.py
@@ -0,0 +1,282 @@
+from math import ceil
+from clip import tokenize
+from embedding_reader import EmbeddingReader
+from torch import from_numpy
+from torch.utils.data import IterableDataset, DataLoader
+
+
+class PriorEmbeddingDataset(IterableDataset):
+ """
+ PriorEmbeddingDataset is a wrapper of EmbeddingReader.
+
+ It enables one to simplify the logic necessary to yield samples from
+ the different EmbeddingReader configurations available.
+ """
+
+ def __init__(
+ self,
+ text_conditioned: bool,
+ batch_size: int,
+ start: int,
+ stop: int,
+ image_reader,
+ text_reader: EmbeddingReader = None,
+ ) -> None:
+ super().__init__()
+
+ self.text_conditioned = text_conditioned
+
+ if not self.text_conditioned:
+ self.text_reader = text_reader
+
+ self.image_reader = image_reader
+ self.start = start
+ self.stop = stop
+ self.batch_size = batch_size
+
+ def __len__(self):
+ return self.stop - self.start
+
+ def __iter__(self):
+ # D.R.Y loader args
+ loader_args = dict(
+ batch_size=self.batch_size,
+ start=self.start,
+ end=self.stop,
+ show_progress=False,
+ )
+
+ # if the data requested is text conditioned, only load images
+ if self.text_conditioned:
+ self.loader = self.image_reader(**loader_args)
+ # otherwise, include text embeddings and bypass metadata
+ else:
+ self.loader = zip(
+ self.image_reader(**loader_args), self.text_reader(**loader_args)
+ )
+
+ # return the data loader in its formatted state
+ return self
+
+ def __next__(self):
+ try:
+ return self.get_sample()
+ except StopIteration:
+ raise StopIteration
+
+ def __str__(self):
+ return f""
+
+ def set_start(self, start):
+ """
+ Adjust the starting point within the reader, useful for resuming an epoch
+ """
+ self.start = start
+
+ def get_start(self):
+ return self.start
+
+ def get_sample(self):
+ """
+ pre-process data from either reader into a common format
+ """
+ if self.text_conditioned:
+ image_embedding, caption = next(self.loader)
+
+ image_embedding = from_numpy(image_embedding)
+ tokenized_caption = tokenize(caption["caption"].to_list(), truncate=True)
+
+ return image_embedding, tokenized_caption
+
+ else:
+ (image_embedding, _), (text_embedding, _) = next(self.loader)
+
+ image_embedding = from_numpy(image_embedding)
+ text_embedding = from_numpy(text_embedding)
+
+ return image_embedding, text_embedding
+
+
+# helper functions
+
+
+def distribute_to_rank(start, stop, rank, world_size):
+ """
+ Distribute data to each rank given the world size.
+
+ Return:
+ - New start and stop points for this rank.
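+
+ Example (illustrative): distribute_to_rank(start=0, stop=100, rank=1, world_size=4) -> (25, 50)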
+ """
+ num_samples = int(stop - start)
+
+ per_rank = int(ceil((num_samples) / float(world_size)))
+
+ assert (
+ per_rank > 0
+ ), f"Number of samples per rank must be larger than 0, (found: {per_rank})"
+
+ rank_start = start + rank * per_rank
+
+ rank_stop = min(rank_start + per_rank, stop)
+
+ new_length = rank_stop - rank_start
+
+ assert (
+ new_length > 0
+ ), "Calculated start and stop points result in a length of zero for this rank."
+
+ return rank_start, rank_stop
+
+
+def get_reader(
+ text_conditioned: bool, img_url: str, meta_url: str = None, txt_url: str = None
+):
+ """
+ Create an EmbeddingReader object from the specified URLs
+
+ get_reader() will always expect a url to image embeddings.
+
+ If text-conditioned, it will also expect a meta_url for the captions.
+ Otherwise, it will need txt_url for the matching text embeddings.
+
+ Returns an image_reader object if text-conditioned.
+ Otherwise it returns both an image_reader and a text_reader
+ """
+
+ assert img_url is not None, "Must supply an image url"
+
+ if text_conditioned:
+ assert meta_url is not None, "Must supply meta url if text-conditioned"
+
+ image_reader = EmbeddingReader(
+ embeddings_folder=img_url,
+ file_format="parquet_npy",
+ # will assume the caption column exists and is the only one requested
+ meta_columns=["caption"],
+ metadata_folder=meta_url,
+ )
+
+ return image_reader
+
+ # otherwise we will require text embeddings as well and return two readers
+ assert (
+ txt_url is not None
+ ), "Must supply text embedding url if not text-conditioning"
+
+ image_reader = EmbeddingReader(img_url, file_format="npy")
+ text_reader = EmbeddingReader(txt_url, file_format="npy")
+
+ return image_reader, text_reader
+
+
+def make_splits(
+ text_conditioned: bool,
+ batch_size: int,
+ num_data_points: int,
+ train_split: float,
+ eval_split: float,
+ image_reader: EmbeddingReader,
+ text_reader: EmbeddingReader = None,
+ start=0,
+ rank=0,
+ world_size=1,
+):
+ """
+ Split an embedding reader object as needed.
+
+ NOTE: make_splits() will infer the test set size from your train and eval.
+
+ Input:
+ - text_conditioned: whether to prepare text-conditioned training data
+ - batch_size: the batch size for a single gpu
+ - num_data_points: the total number of data points you wish to train on
+ - train_split: the percentage of data you wish to train on
+ - eval_split: the percentage of data you wish to validate on
+ - image_reader: the image_reader you wish to split
+ - text_reader: the text_reader you want to split (required if not text_conditioned)
+ - start: the starting point within your dataset
+ - rank: the rank of your worker
+ - world_size: the total world size of your distributed training run
+
+ Returns:
+ - PyTorch Dataloaders that yield tuples of (img, txt) data.
+ """
+
+ assert start < image_reader.count, "start position cannot exceed reader count."
+
+ # verify that the num_data_points does not exceed the max points
+ if num_data_points > (image_reader.count - start):
+ print(
+ "Specified count is larger than what's available...defaulting to reader's count."
+ )
+ num_data_points = image_reader.count
+
+ # compute split points
+ train_set_size = int(train_split * num_data_points)
+ eval_set_size = int(eval_split * num_data_points)
+ eval_start = train_set_size
+ eval_stop = int(eval_start + eval_set_size)
+
+ assert (
+ train_split + eval_split
+ ) < 1.0, "Specified train and eval split is too large to infer a test split."
+
+ # distribute to rank
+ rank_train_start, rank_train_stop = distribute_to_rank(
+ start, train_set_size, rank, world_size
+ )
+ rank_eval_start, rank_eval_stop = distribute_to_rank(
+ train_set_size, eval_stop, rank, world_size
+ )
+ rank_test_start, rank_test_stop = distribute_to_rank(
+ eval_stop, num_data_points, rank, world_size
+ )
+
+ # wrap up splits into a dict
+ train_split_args = dict(
+ start=rank_train_start, stop=rank_train_stop, batch_size=batch_size
+ )
+ eval_split_args = dict(
+ start=rank_eval_start, stop=rank_eval_stop, batch_size=batch_size
+ )
+ test_split_args = dict(
+ start=rank_test_start, stop=rank_test_stop, batch_size=batch_size
+ )
+
+ if text_conditioned:
+ # add the text-conditioned args to a unified dict
+ reader_args = dict(
+ text_conditioned=text_conditioned,
+ image_reader=image_reader,
+ )
+
+ train_split_args = dict(**reader_args, **train_split_args)
+ eval_split_args = dict(**reader_args, **eval_split_args)
+ test_split_args = dict(**reader_args, **test_split_args)
+
+ train = PriorEmbeddingDataset(**train_split_args)
+ val = PriorEmbeddingDataset(**eval_split_args)
+ test = PriorEmbeddingDataset(**test_split_args)
+
+ else:
+ # add the non-conditioned args to a unified dict
+ reader_args = dict(
+ text_conditioned=text_conditioned,
+ image_reader=image_reader,
+ text_reader=text_reader,
+ )
+
+ train_split_args = dict(**reader_args, **train_split_args)
+ eval_split_args = dict(**reader_args, **eval_split_args)
+ test_split_args = dict(**reader_args, **test_split_args)
+
+ train = PriorEmbeddingDataset(**train_split_args)
+ val = PriorEmbeddingDataset(**eval_split_args)
+ test = PriorEmbeddingDataset(**test_split_args)
+
+ # the true batch size is specified in the PriorEmbeddingDataset, so the DataLoaders use batch_size=None
+ train_loader = DataLoader(train, batch_size=None)
+ eval_loader = DataLoader(val, batch_size=None)
+ test_loader = DataLoader(test, batch_size=None)
+
+ return train_loader, eval_loader, test_loader
diff --git a/docs/src/dalle2_pytorch/dataloaders/simple_image_only_dataloader.py b/docs/src/dalle2_pytorch/dataloaders/simple_image_only_dataloader.py
new file mode 100644
index 00000000..1418c945
--- /dev/null
+++ b/docs/src/dalle2_pytorch/dataloaders/simple_image_only_dataloader.py
@@ -0,0 +1,59 @@
+from pathlib import Path
+
+import torch
+from torch.utils import data
+from torchvision import transforms, utils
+
+from PIL import Image
+
+# helper functions
+
+def cycle(dl):
+ while True:
+ for data in dl:
+ yield data
+
+# dataset and dataloader
+
+class Dataset(data.Dataset):
+ def __init__(
+ self,
+ folder,
+ image_size,
+ exts = ['jpg', 'jpeg', 'png']
+ ):
+ super().__init__()
+ self.folder = folder
+ self.image_size = image_size
+ self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
+
+ self.transform = transforms.Compose([
+ transforms.Resize(image_size),
+ transforms.RandomHorizontalFlip(),
+ transforms.CenterCrop(image_size),
+ transforms.ToTensor()
+ ])
+
+ def __len__(self):
+ return len(self.paths)
+
+ def __getitem__(self, index):
+ path = self.paths[index]
+ img = Image.open(path)
+ return self.transform(img)
+
+def get_images_dataloader(
+ folder,
+ *,
+ batch_size,
+ image_size,
+ shuffle = True,
+ cycle_dl = True,
+ pin_memory = True
+):
+ ds = Dataset(folder, image_size)
+ dl = data.DataLoader(ds, batch_size = batch_size, shuffle = shuffle, pin_memory = pin_memory)
+
+ if cycle_dl:
+ dl = cycle(dl)
+ return dl
diff --git a/docs/src/dalle2_pytorch/optimizer.py b/docs/src/dalle2_pytorch/optimizer.py
new file mode 100644
index 00000000..2df2d48d
--- /dev/null
+++ b/docs/src/dalle2_pytorch/optimizer.py
@@ -0,0 +1,34 @@
+from torch.optim import AdamW, Adam
+
+def separate_weight_decayable_params(params):
+ wd_params, no_wd_params = [], []
+ for param in params:
+ param_list = no_wd_params if param.ndim < 2 else wd_params
+ param_list.append(param)
+ return wd_params, no_wd_params
+
+def get_optimizer(
+ params,
+ lr = 1e-4,
+ wd = 1e-2,
+ betas = (0.9, 0.99),
+ eps = 1e-8,
+ filter_by_requires_grad = False,
+ group_wd_params = True,
+ **kwargs
+):
+ if filter_by_requires_grad:
+ params = list(filter(lambda t: t.requires_grad, params))
+
+ if wd == 0:
+ return Adam(params, lr = lr, betas = betas, eps = eps)
+
+ if group_wd_params:
+ wd_params, no_wd_params = separate_weight_decayable_params(params)
+
+ params = [
+ {'params': wd_params},
+ {'params': no_wd_params, 'weight_decay': 0},
+ ]
+
+ return AdamW(params, lr = lr, weight_decay = wd, betas = betas, eps = eps)
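+
+# Illustrative usage sketch (the model is a placeholder):
+#
+#   model = torch.nn.Linear(512, 512)
+#   opt = get_optimizer(model.parameters(), lr = 1e-4, wd = 1e-2)
+#   # with wd > 0 and group_wd_params = True, parameters with fewer than 2 dims (biases, norms)
+#   # are placed into a parameter group with weight_decay = 0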
diff --git a/docs/src/dalle2_pytorch/tokenizer.py b/docs/src/dalle2_pytorch/tokenizer.py
new file mode 100644
index 00000000..7c010089
--- /dev/null
+++ b/docs/src/dalle2_pytorch/tokenizer.py
@@ -0,0 +1,191 @@
+# take from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
+# to give users a quick easy start to training DALL-E without doing BPE
+
+import torch
+
+import html
+import os
+import ftfy
+import regex as re
+from functools import lru_cache
+from pathlib import Path
+
+from dalle2_pytorch.utils import import_or_print_error
+
+# OpenAI simple tokenizer
+
+@lru_cache()
+def default_bpe():
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "data/bpe_simple_vocab_16e6.txt")
+
+@lru_cache()
+def bytes_to_unicode():
+ bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+ cs = bs[:]
+ n = 0
+ for b in range(2 ** 8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2 ** 8 + n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+def get_pairs(word):
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+def whitespace_clean(text):
+ text = re.sub(r'\s+', ' ', text)
+ text = text.strip()
+ return text
+
+class SimpleTokenizer(object):
+ def __init__(self, bpe_path = default_bpe()):
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ merges = Path(bpe_path).read_text(encoding='utf8').split('\n')
+ merges = merges[1:49152 - 256 - 2 + 1]
+ merges = [tuple(merge.split()) for merge in merges]
+ vocab = list(bytes_to_unicode().values())
+ vocab = vocab + [v + '</w>' for v in vocab]
+ for merge in merges:
+ vocab.append(''.join(merge))
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
+
+ self.vocab_size = 49408
+
+ self.encoder = dict(zip(vocab, range(len(vocab))))
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
+ self.pat = re.compile(
+ r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
+ re.IGNORECASE)
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+ word = tuple(token[:-1]) + (token[-1] + '</w>',)
+ pairs = get_pairs(word)
+
+ if not pairs:
+ return token + '</w>'
+
+ while True:
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ new_word.extend(word[i:j])
+ i = j
+ except ValueError:
+ new_word.extend(word[i:])
+ break
+
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+ new_word.append(first + second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = ' '.join(word)
+ self.cache[token] = word
+ return word
+
+ def encode(self, text):
+ bpe_tokens = []
+ text = whitespace_clean(basic_clean(text)).lower()
+ for token in re.findall(self.pat, text):
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+ return bpe_tokens
+
+ def decode(self, tokens, remove_start_end = True, pad_tokens = set()):
+ if torch.is_tensor(tokens):
+ tokens = tokens.tolist()
+
+ if remove_start_end:
+ tokens = [token for token in tokens if token not in (49406, 49407, 0)] # strip <|startoftext|>, <|endoftext|> and pad
+ text = ''.join([self.decoder[token] for token in tokens if token not in pad_tokens])
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
+ return text
+
+ def tokenize(self, texts, context_length = 256, truncate_text = False):
+ if isinstance(texts, str):
+ texts = [texts]
+
+ all_tokens = [self.encode(text) for text in texts]
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+ for i, tokens in enumerate(all_tokens):
+ if len(tokens) > context_length:
+ if truncate_text:
+ tokens = tokens[:context_length]
+ else:
+ raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
+ result[i, :len(tokens)] = torch.tensor(tokens)
+
+ return result
+
+tokenizer = SimpleTokenizer()
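+
+# Illustrative usage sketch:
+#
+#   tokens = tokenizer.tokenize(['a photo of a corgi'], context_length = 256)
+#   # tokens has shape (1, 256), zero-padded on the right
+#   text = tokenizer.decode(tokens[0]) # start / end / pad tokens are stripped by default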
+
+# YTTM tokenizer
+
+class YttmTokenizer:
+ def __init__(self, bpe_path = None):
+ bpe_path = Path(bpe_path)
+ assert bpe_path.exists(), f'BPE json path {str(bpe_path)} does not exist'
+
+ self.yttm = import_or_print_error('youtokentome', 'you need to install youtokentome by `pip install youtokentome`')
+
+ tokenizer = self.yttm.BPE(model = str(bpe_path))
+ self.tokenizer = tokenizer
+ self.vocab_size = tokenizer.vocab_size()
+
+ def decode(self, tokens, pad_tokens = set()):
+ if torch.is_tensor(tokens):
+ tokens = tokens.tolist()
+
+ return self.tokenizer.decode(tokens, ignore_ids = pad_tokens.union({0}))
+
+ def encode(self, texts):
+ encoded = self.tokenizer.encode(texts, output_type = self.yttm.OutputType.ID)
+ return list(map(torch.tensor, encoded))
+
+ def tokenize(self, texts, context_length = 256, truncate_text = False):
+ if isinstance(texts, str):
+ texts = [texts]
+
+ all_tokens = self.encode(texts)
+
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+ for i, tokens in enumerate(all_tokens):
+ if len(tokens) > context_length:
+ if truncate_text:
+ tokens = tokens[:context_length]
+ else:
+ raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
+ result[i, :len(tokens)] = torch.tensor(tokens)
+
+ return result
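+
+# Illustrative usage sketch (the BPE model path is a placeholder):
+#
+#   yttm = YttmTokenizer('path/to/bpe.model')
+#   tokens = yttm.tokenize(['a photo of a corgi'], context_length = 256)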
diff --git a/docs/src/dalle2_pytorch/trackers.py b/docs/src/dalle2_pytorch/trackers.py
new file mode 100644
index 00000000..a83f6d42
--- /dev/null
+++ b/docs/src/dalle2_pytorch/trackers.py
@@ -0,0 +1,601 @@
+import urllib.request
+import os
+import json
+from pathlib import Path
+import shutil
+from itertools import zip_longest
+from typing import Any, Optional, List, Union
+from pydantic import BaseModel
+
+import torch
+from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
+from dalle2_pytorch.utils import import_or_print_error
+from dalle2_pytorch.trainer import DecoderTrainer, DiffusionPriorTrainer
+from dalle2_pytorch.version import __version__
+from packaging import version
+
+# constants
+
+DEFAULT_DATA_PATH = './.tracker-data'
+
+# helper functions
+
+def exists(val):
+ return val is not None
+
+class BaseLogger:
+ """
+ An abstract class representing an object that can log data.
+ Parameters:
+ data_path (str): A file path for storing temporary data.
+ verbose (bool): Whether or not to always print logs to the console.
+ """
+ def __init__(self, data_path: str, resume: bool = False, auto_resume: bool = False, verbose: bool = False, **kwargs):
+ self.data_path = Path(data_path)
+ self.resume = resume
+ self.auto_resume = auto_resume
+ self.verbose = verbose
+
+ def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
+ """
+ Initializes the logger.
+ Errors if the logger is invalid.
+ full_config is the parsed config (a pydantic model), while extra_config is anything else from the script that is not defined in the config file.
+ """
+ raise NotImplementedError
+
+ def log(self, log, **kwargs) -> None:
+ raise NotImplementedError
+
+ def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
+ raise NotImplementedError
+
+ def log_file(self, file_path, **kwargs) -> None:
+ raise NotImplementedError
+
+ def log_error(self, error_string, **kwargs) -> None:
+ raise NotImplementedError
+
+ def get_resume_data(self, **kwargs) -> dict:
+ """
+ Returns the data that, along with { "resume": True }, will be used to resume training.
+ It is assumed that after init is called this data will be complete.
+ If the logger does not have any resume functionality, it should return an empty dict.
+ """
+ raise NotImplementedError
+
+class ConsoleLogger(BaseLogger):
+ def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
+ print("Logging to console")
+
+ def log(self, log, **kwargs) -> None:
+ print(log)
+
+ def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
+ pass
+
+ def log_file(self, file_path, **kwargs) -> None:
+ pass
+
+ def log_error(self, error_string, **kwargs) -> None:
+ print(error_string)
+
+ def get_resume_data(self, **kwargs) -> dict:
+ return {}
+
+class WandbLogger(BaseLogger):
+ """
+ Logs to a wandb run.
+ Parameters:
+ data_path (str): A file path for storing temporary data.
+ wandb_entity (str): The wandb entity to log to.
+ wandb_project (str): The wandb project to log to.
+ wandb_run_id (str): The wandb run id to resume.
+ wandb_run_name (str): The wandb run name to use.
+ """
+ def __init__(self,
+ data_path: str,
+ wandb_entity: str,
+ wandb_project: str,
+ wandb_run_id: Optional[str] = None,
+ wandb_run_name: Optional[str] = None,
+ **kwargs
+ ):
+ super().__init__(data_path, **kwargs)
+ self.entity = wandb_entity
+ self.project = wandb_project
+ self.run_id = wandb_run_id
+ self.run_name = wandb_run_name
+
+ def init(self, full_config: BaseModel, extra_config: dict, **kwargs) -> None:
+ assert self.entity is not None, "wandb_entity must be specified for wandb logger"
+ assert self.project is not None, "wandb_project must be specified for wandb logger"
+ self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb logger')
+ os.environ["WANDB_SILENT"] = "true"
+ # Initializes the wandb run
+ init_object = {
+ "entity": self.entity,
+ "project": self.project,
+ "config": {**full_config.dict(), **extra_config}
+ }
+ if self.run_name is not None:
+ init_object['name'] = self.run_name
+ if self.resume:
+ assert self.run_id is not None, '`wandb_run_id` must be provided if `wandb_resume` is True'
+ if self.run_name is not None:
+ print("You are renaming a run. I hope that is what you intended.")
+ init_object['resume'] = 'must'
+ init_object['id'] = self.run_id
+
+ self.wandb.init(**init_object)
+ print(f"Logging to wandb run {self.wandb.run.path}-{self.wandb.run.name}")
+
+ def log(self, log, **kwargs) -> None:
+ if self.verbose:
+ print(log)
+ self.wandb.log(log, **kwargs)
+
+ def log_images(self, images, captions=[], image_section="images", **kwargs) -> None:
+ """
+ Takes a tensor of images and a list of captions and logs them to wandb.
+ """
+ wandb_images = [self.wandb.Image(image, caption=caption) for image, caption in zip_longest(images, captions)]
+ self.wandb.log({ image_section: wandb_images }, **kwargs)
+
+ def log_file(self, file_path, base_path: Optional[str] = None, **kwargs) -> None:
+ if base_path is None:
+ # Then we take the basepath as the parent of the file_path
+ base_path = Path(file_path).parent
+ self.wandb.save(str(file_path), base_path = str(base_path))
+
+ def log_error(self, error_string, step=None, **kwargs) -> None:
+ if self.verbose:
+ print(error_string)
+ self.wandb.log({"error": error_string, **kwargs}, step=step)
+
+ def get_resume_data(self, **kwargs) -> dict:
+ # In order to resume, we need wandb_entity, wandb_project, and wandb_run_id
+ return {
+ "entity": self.entity,
+ "project": self.project,
+ "run_id": self.wandb.run.id
+ }
+
+logger_type_map = {
+ 'console': ConsoleLogger,
+ 'wandb': WandbLogger,
+}
+def create_logger(logger_type: str, data_path: str, **kwargs) -> BaseLogger:
+ if logger_type == 'custom':
+ raise NotImplementedError('Custom loggers are not supported yet. Please use a different logger type.')
+ try:
+ logger_class = logger_type_map[logger_type]
+ except KeyError:
+ raise ValueError(f'Unknown logger type: {logger_type}. Must be one of {list(logger_type_map.keys())}')
+ return logger_class(data_path, **kwargs)
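+
+# Illustrative usage sketch:
+#
+#   logger = create_logger('console', './.tracker-data')
+#   # or, with a wandb account (entity / project names are placeholders):
+#   logger = create_logger('wandb', './.tracker-data', wandb_entity = 'my-entity', wandb_project = 'my-project')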
+
+class BaseLoader:
+ """
+ An abstract class representing an object that can load a model checkpoint.
+ Parameters:
+ data_path (str): A file path for storing temporary data.
+ """
+ def __init__(self, data_path: str, only_auto_resume: bool = False, **kwargs):
+ self.data_path = Path(data_path)
+ self.only_auto_resume = only_auto_resume
+
+ def init(self, logger: BaseLogger, **kwargs) -> None:
+ raise NotImplementedError
+
+ def recall(self) -> dict:
+ raise NotImplementedError
+
+class UrlLoader(BaseLoader):
+ """
+ A loader that downloads the file from a url and loads it
+ Parameters:
+ data_path (str): A file path for storing temporary data.
+ url (str): The url to download the file from.
+ """
+ def __init__(self, data_path: str, url: str, **kwargs):
+ super().__init__(data_path, **kwargs)
+ self.url = url
+
+ def init(self, logger: BaseLogger, **kwargs) -> None:
+ # Makes sure the file exists to be downloaded
+ pass # TODO: Actually implement that
+
+ def recall(self) -> dict:
+ # Download the file
+ save_path = self.data_path / 'loaded_checkpoint.pth'
+ urllib.request.urlretrieve(self.url, str(save_path))
+ # Load the file
+ return torch.load(str(save_path), map_location='cpu')
+
+
+class LocalLoader(BaseLoader):
+ """
+ A loader that loads a file from a local path
+ Parameters:
+ data_path (str): A file path for storing temporary data.
+ file_path (str): The path to the file to load.
+ """
+ def __init__(self, data_path: str, file_path: str, **kwargs):
+ super().__init__(data_path, **kwargs)
+ self.file_path = Path(file_path)
+
+ def init(self, logger: BaseLogger, **kwargs) -> None:
+ # Makes sure the file exists to be loaded
+ if not self.file_path.exists() and not self.only_auto_resume:
+ raise FileNotFoundError(f'Model not found at {self.file_path}')
+
+ def recall(self) -> dict:
+ # Load the file
+ return torch.load(str(self.file_path), map_location='cpu')
+
+class WandbLoader(BaseLoader):
+ """
+ A loader that loads a model from an existing wandb run
+ """
+ def __init__(self, data_path: str, wandb_file_path: str, wandb_run_path: Optional[str] = None, **kwargs):
+ super().__init__(data_path, **kwargs)
+ self.run_path = wandb_run_path
+ self.file_path = wandb_file_path
+
+ def init(self, logger: BaseLogger, **kwargs) -> None:
+ self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb recall function')
+ # Make sure the file can be downloaded
+ if self.wandb.run is not None and self.run_path is None:
+ self.run_path = self.wandb.run.path
+ assert self.run_path is not None, 'wandb run was not found to load from. If you are not using the wandb logger, you must specify `wandb_run_path`.'
+ assert self.file_path is not None, '`wandb_file_path` must be provided for the wandb loader'
+
+ os.environ["WANDB_SILENT"] = "true"
+ pass # TODO: Actually implement that
+
+ def recall(self) -> dict:
+ file_reference = self.wandb.restore(self.file_path, run_path=self.run_path)
+ return torch.load(file_reference.name, map_location='cpu')
+
+loader_type_map = {
+ 'url': UrlLoader,
+ 'local': LocalLoader,
+ 'wandb': WandbLoader,
+}
+def create_loader(loader_type: str, data_path: str, **kwargs) -> BaseLoader:
+ if loader_type == 'custom':
+ raise NotImplementedError('Custom loaders are not supported yet. Please use a different loader type.')
+ try:
+ loader_class = loader_type_map[loader_type]
+ except KeyError:
+ raise ValueError(f'Unknown loader type: {loader_type}. Must be one of {list(loader_type_map.keys())}')
+ return loader_class(data_path, **kwargs)
+
+class BaseSaver:
+ def __init__(self,
+ data_path: str,
+ save_latest_to: Optional[Union[str, bool]] = None,
+ save_best_to: Optional[Union[str, bool]] = None,
+ save_meta_to: Optional[str] = None,
+ save_type: str = 'checkpoint',
+ **kwargs
+ ):
+ self.data_path = Path(data_path)
+ self.save_latest_to = save_latest_to
+ self.saving_latest = save_latest_to is not None and save_latest_to is not False
+ self.save_best_to = save_best_to
+ self.saving_best = save_best_to is not None and save_best_to is not False
+ self.save_meta_to = save_meta_to
+ self.saving_meta = save_meta_to is not None
+ self.save_type = save_type
+ assert save_type in ['checkpoint', 'model'], '`save_type` must be one of `checkpoint` or `model`'
+ assert self.saving_latest or self.saving_best or self.saving_meta, 'At least one saving option must be specified'
+
+ def init(self, logger: BaseLogger, **kwargs) -> None:
+ raise NotImplementedError
+
+ def save_file(self, local_path: Path, save_path: str, is_best=False, is_latest=False, **kwargs) -> None:
+ """
+ Save a general file under save_meta_to
+ """
+ raise NotImplementedError
+
+class LocalSaver(BaseSaver):
+ def __init__(self,
+ data_path: str,
+ **kwargs
+ ):
+ super().__init__(data_path, **kwargs)
+
+ def init(self, logger: BaseLogger, **kwargs) -> None:
+ # Makes sure the directory exists to be saved to
+ print(f"Saving {self.save_type} locally")
+ if not self.data_path.exists():
+ self.data_path.mkdir(parents=True)
+
+ def save_file(self, local_path: str, save_path: str, **kwargs) -> None:
+ # Copy the file to save_path
+ save_path_file_name = Path(save_path).name
+ # Make sure parent directory exists
+ save_path_parent = Path(save_path).parent
+ if not save_path_parent.exists():
+ save_path_parent.mkdir(parents=True)
+ print(f"Saving {save_path_file_name} {self.save_type} to local path {save_path}")
+ shutil.copy(local_path, save_path)
+
+class WandbSaver(BaseSaver):
+ def __init__(self, data_path: str, wandb_run_path: Optional[str] = None, **kwargs):
+ super().__init__(data_path, **kwargs)
+ self.run_path = wandb_run_path
+
+ def init(self, logger: BaseLogger, **kwargs) -> None:
+ self.wandb = import_or_print_error('wandb', '`pip install wandb` to use the wandb logger')
+ os.environ["WANDB_SILENT"] = "true"
+ # Makes sure that the user can upload to this run
+ if self.run_path is not None:
+ entity, project, run_id = self.run_path.split("/")
+ self.run = self.wandb.init(entity=entity, project=project, id=run_id)
+ else:
+ assert self.wandb.run is not None, 'You must be using the wandb logger if you are saving to wandb and have not set `wandb_run_path`'
+ self.run = self.wandb.run
+ # TODO: Now actually check if upload is possible
+ print(f"Saving to wandb run {self.run.path}-{self.run.name}")
+
+ def save_file(self, local_path: Path, save_path: str, **kwargs) -> None:
+ # In order to log something in the correct place in wandb, we need to have the same file structure here
+ save_path_file_name = Path(save_path).name
+ print(f"Saving {save_path_file_name} {self.save_type} to wandb run {self.run.path}-{self.run.name}")
+ save_path = Path(self.data_path) / save_path
+ save_path.parent.mkdir(parents=True, exist_ok=True)
+ shutil.copy(local_path, save_path)
+ self.run.save(str(save_path), base_path = str(self.data_path), policy='now')
+
+class HuggingfaceSaver(BaseSaver):
+ def __init__(self, data_path: str, huggingface_repo: str, token_path: Optional[str] = None, **kwargs):
+ super().__init__(data_path, **kwargs)
+ self.huggingface_repo = huggingface_repo
+ self.token_path = token_path
+
+ def init(self, logger: BaseLogger, **kwargs):
+ # Makes sure this user can upload to the repo
+ self.hub = import_or_print_error('huggingface_hub', '`pip install huggingface_hub` to use the huggingface saver')
+ try:
+ identity = self.hub.whoami() # Errors if not logged in
+ # Then we are logged in
+ except:
+ # We are not logged in. Use the token_path to set the token.
+ if not os.path.exists(self.token_path):
+ raise Exception("Not logged in to huggingface and no token_path specified. Please login with `huggingface-cli login` or if that does not work set the token_path.")
+ with open(self.token_path, "r") as f:
+ token = f.read().strip()
+ self.hub.HfApi.set_access_token(token)
+ identity = self.hub.whoami()
+ print(f"Saving to huggingface repo {self.huggingface_repo}")
+
+ def save_file(self, local_path: Path, save_path: str, **kwargs) -> None:
+ # Saving to huggingface is easy, we just need to upload the file with the correct name
+ save_path_file_name = Path(save_path).name
+ print(f"Saving {save_path_file_name} {self.save_type} to huggingface repo {self.huggingface_repo}")
+ self.hub.upload_file(
+ path_or_fileobj=str(local_path),
+ path_in_repo=str(save_path),
+ repo_id=self.huggingface_repo
+ )
+
+saver_type_map = {
+ 'local': LocalSaver,
+ 'wandb': WandbSaver,
+ 'huggingface': HuggingfaceSaver
+}
+def create_saver(saver_type: str, data_path: str, **kwargs) -> BaseSaver:
+ if saver_type == 'custom':
+ raise NotImplementedError('Custom savers are not supported yet. Please use a different saver type.')
+ try:
+ saver_class = saver_type_map[saver_type]
+ except KeyError:
+ raise ValueError(f'Unknown saver type: {saver_type}. Must be one of {list(saver_type_map.keys())}')
+ return saver_class(data_path, **kwargs)
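+
+# Illustrative usage sketch (paths are placeholders). The `save_latest_to` / `save_best_to`
+# keyword arguments fall through to BaseSaver, and at least one of them (or `save_meta_to`)
+# must be given:
+#
+#   saver = create_saver('local', './.tracker-data', save_latest_to = 'checkpoints/latest.pth', save_best_to = 'checkpoints/best.pth')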
+
+
+class Tracker:
+ def __init__(self, data_path: Optional[str] = DEFAULT_DATA_PATH, overwrite_data_path: bool = False, dummy_mode: bool = False):
+ self.data_path = Path(data_path)
+ if not dummy_mode:
+ if not overwrite_data_path:
+ assert not self.data_path.exists(), f'Data path {self.data_path} already exists. Set overwrite_data_path to True to overwrite.'
+ if not self.data_path.exists():
+ self.data_path.mkdir(parents=True)
+ self.logger: BaseLogger = None
+ self.loader: Optional[BaseLoader] = None
+ self.savers: List[BaseSaver]= []
+ self.dummy_mode = dummy_mode
+
+ def _load_auto_resume(self) -> bool:
+ # If the file does not exist, we return False. If autoresume is enabled we print a warning so that the user can know that this is the first run.
+ if not self.auto_resume_path.exists():
+ if self.logger.auto_resume:
+ print("Auto_resume is enabled but no auto_resume.json file exists. Assuming this is the first run.")
+ return False
+
+ # Now we know that the autoresume file exists, but if we are not auto resuming we should remove it so that we don't accidentally load it next time
+ if not self.logger.auto_resume:
+ print(f'Removing auto_resume.json because auto_resume is not enabled in the config')
+ self.auto_resume_path.unlink()
+ return False
+
+ # Otherwise we read the json into a dictionary that will override parts of logger.__dict__
+ with open(self.auto_resume_path, 'r') as f:
+ auto_resume_dict = json.load(f)
+ # Check if the logger is of the same type as the autoresume save
+ if auto_resume_dict["logger_type"] != self.logger.__class__.__name__:
+ raise Exception(f'The logger type in the auto_resume file is {auto_resume_dict["logger_type"]} but the current logger is {self.logger.__class__.__name__}. Either use the original logger type, set `auto_resume` to `False`, or delete your existing tracker-data folder.')
+ # Then we are ready to override the logger with the autoresume save
+ self.logger.__dict__["resume"] = True
+ print(f"Updating {self.logger.__dict__} with {auto_resume_dict}")
+ self.logger.__dict__.update(auto_resume_dict)
+ return True
+
+ def _save_auto_resume(self):
+ # Gets the autoresume dict from the logger and adds "logger_type" to it then saves it to the auto_resume file
+ auto_resume_dict = self.logger.get_resume_data()
+ auto_resume_dict['logger_type'] = self.logger.__class__.__name__
+ with open(self.auto_resume_path, 'w') as f:
+ json.dump(auto_resume_dict, f)
+
+ def init(self, full_config: BaseModel, extra_config: dict):
+ self.auto_resume_path = self.data_path / 'auto_resume.json'
+ # Check for resuming the run
+ self.did_auto_resume = self._load_auto_resume()
+ if self.did_auto_resume:
+ print(f'\n\nWARNING: RUN HAS BEEN AUTO-RESUMED WITH THE LOGGER TYPE {self.logger.__class__.__name__}.\nIf this was not your intention, stop this run and set `auto_resume` to `False` in the config.\n\n')
+ print(f"New logger config: {self.logger.__dict__}")
+
+ self.save_metadata = dict(
+ version = version.parse(__version__)
+ ) # Data that will be saved alongside the checkpoint or model
+ self.blacklisted_checkpoint_metadata_keys = ['scaler', 'optimizer', 'model', 'version', 'step', 'steps'] # These keys would cause us to error if we try to save them as metadata
+
+ assert self.logger is not None, '`logger` must be set before `init` is called'
+ if self.dummy_mode:
+ # The only thing we need is a loader
+ if self.loader is not None:
+ self.loader.init(self.logger)
+ return
+ assert len(self.savers) > 0, '`savers` must be set before `init` is called'
+
+ self.logger.init(full_config, extra_config)
+ if self.loader is not None:
+ self.loader.init(self.logger)
+ for saver in self.savers:
+ saver.init(self.logger)
+
+ if self.logger.auto_resume:
+ # Then we need to save the autoresume file. It is assumed after logger.init is called that the logger is ready to be saved.
+ self._save_auto_resume()
+
+ def add_logger(self, logger: BaseLogger):
+ self.logger = logger
+
+ def add_loader(self, loader: BaseLoader):
+ self.loader = loader
+
+ def add_saver(self, saver: BaseSaver):
+ self.savers.append(saver)
+
+ def log(self, *args, **kwargs):
+ if self.dummy_mode:
+ return
+ self.logger.log(*args, **kwargs)
+
+ def log_images(self, *args, **kwargs):
+ if self.dummy_mode:
+ return
+ self.logger.log_images(*args, **kwargs)
+
+ def log_file(self, *args, **kwargs):
+ if self.dummy_mode:
+ return
+ self.logger.log_file(*args, **kwargs)
+
+ def save_config(self, current_config_path: str, config_name = 'config.json'):
+ if self.dummy_mode:
+ return
+ # Save the config under config_name in the root folder of data_path
+ shutil.copy(current_config_path, self.data_path / config_name)
+ for saver in self.savers:
+ if saver.saving_meta:
+ remote_path = Path(saver.save_meta_to) / config_name
+ saver.save_file(current_config_path, str(remote_path))
+
+ def add_save_metadata(self, state_dict_key: str, metadata: Any):
+ """
+ Adds a new piece of metadata that will be saved along with the model or decoder.
+ """
+ self.save_metadata[state_dict_key] = metadata
+
+ def _save_state_dict(self, trainer: Union[DiffusionPriorTrainer, DecoderTrainer], save_type: str, file_path: str, **kwargs) -> Path:
+ """
+ Gets the state dict to be saved and writes it to file_path.
+ If save_type is 'checkpoint', we save the entire trainer state dict.
+ If save_type is 'model', we save only the model state dict.
+ """
+ assert save_type in ['checkpoint', 'model']
+ if save_type == 'checkpoint':
+ # Create a metadata dict without the blacklisted keys so we do not error when we create the state dict
+ metadata = {k: v for k, v in self.save_metadata.items() if k not in self.blacklisted_checkpoint_metadata_keys}
+ trainer.save(file_path, overwrite=True, **kwargs, **metadata)
+ elif save_type == 'model':
+ if isinstance(trainer, DiffusionPriorTrainer):
+ prior = trainer.ema_diffusion_prior.ema_model if trainer.use_ema else trainer.diffusion_prior
+ prior: DiffusionPrior = trainer.accelerator.unwrap_model(prior)
+ # Remove CLIP if it is part of the model
+ original_clip = prior.clip
+ prior.clip = None
+ model_state_dict = prior.state_dict()
+ prior.clip = original_clip
+ elif isinstance(trainer, DecoderTrainer):
+ decoder: Decoder = trainer.accelerator.unwrap_model(trainer.decoder)
+ # Remove CLIP if it is part of the model
+ original_clip = decoder.clip
+ decoder.clip = None
+ if trainer.use_ema:
+ trainable_unets = decoder.unets
+ decoder.unets = trainer.unets # Swap EMA unets in
+ model_state_dict = decoder.state_dict()
+ decoder.unets = trainable_unets # Swap back
+ else:
+ model_state_dict = decoder.state_dict()
+ decoder.clip = original_clip
+ else:
+ raise NotImplementedError('Saving this type of model with EMA mode enabled is not yet implemented. Actually, how did you get here?')
+ state_dict = {
+ **self.save_metadata,
+ 'model': model_state_dict
+ }
+ torch.save(state_dict, file_path)
+ return Path(file_path)
+
+ def save(self, trainer, is_best: bool, is_latest: bool, **kwargs):
+ if self.dummy_mode:
+ return
+ if not is_best and not is_latest:
+ # Nothing to do
+ return
+ # Save the checkpoint and model to data_path
+ checkpoint_path = self.data_path / 'checkpoint.pth'
+ self._save_state_dict(trainer, 'checkpoint', checkpoint_path, **kwargs)
+ model_path = self.data_path / 'model.pth'
+ self._save_state_dict(trainer, 'model', model_path, **kwargs)
+ print("Saved cached models")
+ # Call the save methods on the savers
+ for saver in self.savers:
+ local_path = checkpoint_path if saver.save_type == 'checkpoint' else model_path
+ if saver.saving_latest and is_latest:
+ latest_checkpoint_path = saver.save_latest_to.format(**kwargs)
+ try:
+ saver.save_file(local_path, latest_checkpoint_path, is_latest=True, **kwargs)
+ except Exception as e:
+ self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
+ print(f'Error saving checkpoint: {e}')
+ if saver.saving_best and is_best:
+ best_checkpoint_path = saver.save_best_to.format(**kwargs)
+ try:
+ saver.save_file(local_path, best_checkpoint_path, is_best=True, **kwargs)
+ except Exception as e:
+ self.logger.log_error(f'Error saving checkpoint: {e}', **kwargs)
+ print(f'Error saving checkpoint: {e}')
+
+ @property
+ def can_recall(self):
+ # Defines whether a recall can be performed.
+ return self.loader is not None and (not self.loader.only_auto_resume or self.did_auto_resume)
+
+ def recall(self):
+ if self.can_recall:
+ return self.loader.recall()
+ else:
+ raise ValueError('Tried to recall, but no loader was set or auto-resume was not performed.')
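+
+# Illustrative wiring sketch (mirrors what TrackerConfig.create in train_configs.py does;
+# `full_config` is a pydantic training config and `trainer` is a DecoderTrainer or DiffusionPriorTrainer):
+#
+#   tracker = Tracker('./.tracker-data', overwrite_data_path = True)
+#   tracker.add_logger(create_logger('console', tracker.data_path))
+#   tracker.add_saver(create_saver('local', tracker.data_path, save_latest_to = 'checkpoints/latest.pth'))
+#   tracker.init(full_config, {})
+#   tracker.save(trainer, is_best = False, is_latest = True)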
+
+
+
\ No newline at end of file
diff --git a/docs/src/dalle2_pytorch/train_configs.py b/docs/src/dalle2_pytorch/train_configs.py
new file mode 100644
index 00000000..fb72a1df
--- /dev/null
+++ b/docs/src/dalle2_pytorch/train_configs.py
@@ -0,0 +1,382 @@
+import json
+from torchvision import transforms as T
+from pydantic import BaseModel, validator, model_validator
+from typing import List, Optional, Union, Tuple, Dict, Any, TypeVar
+
+from x_clip import CLIP as XCLIP
+from open_clip import list_pretrained
+from coca_pytorch import CoCa
+
+from dalle2_pytorch.dalle2_pytorch import (
+ CoCaAdapter,
+ OpenAIClipAdapter,
+ OpenClipAdapter,
+ Unet,
+ Decoder,
+ DiffusionPrior,
+ DiffusionPriorNetwork,
+ XClipAdapter
+)
+from dalle2_pytorch.trackers import Tracker, create_loader, create_logger, create_saver
+
+# helper functions
+
+def exists(val):
+ return val is not None
+
+def default(val, d):
+ return val if exists(val) else d
+
+InnerType = TypeVar('InnerType')
+ListOrTuple = Union[List[InnerType], Tuple[InnerType, ...]]
+SingularOrIterable = Union[InnerType, ListOrTuple[InnerType]]
+
+# general pydantic classes
+
+class TrainSplitConfig(BaseModel):
+ train: float = 0.75
+ val: float = 0.15
+ test: float = 0.1
+
+ @model_validator(mode = 'after')
+ def validate_all(self, m):
+ actual_sum = sum([*dict(self).values()])
+ if actual_sum != 1.:
+ raise ValueError(f'{dict(self).keys()} must sum to 1.0. Found: {actual_sum}')
+ return self
+
+class TrackerLogConfig(BaseModel):
+ log_type: str = 'console'
+ resume: bool = False # For logs that are saved to unique locations, resume a previous run
+ auto_resume: bool = False # If the process crashes and restarts, resume from the run that crashed
+ verbose: bool = False
+
+ class Config:
+ # Each individual log type has its own arguments that will be passed through the config
+ extra = "allow"
+
+ def create(self, data_path: str):
+ kwargs = self.dict()
+ return create_logger(self.log_type, data_path, **kwargs)
+
+
+class TrackerLoadConfig(BaseModel):
+ load_from: Optional[str] = None
+ only_auto_resume: bool = False # Only attempt to load if the logger is auto-resuming
+
+ class Config:
+ extra = "allow"
+
+ def create(self, data_path: str):
+ kwargs = self.dict()
+ if self.load_from is None:
+ return None
+ return create_loader(self.load_from, data_path, **kwargs)
+
+class TrackerSaveConfig(BaseModel):
+ save_to: str = 'local'
+ save_all: bool = False
+ save_latest: bool = True
+ save_best: bool = True
+
+ class Config:
+ extra = "allow"
+
+ def create(self, data_path: str):
+ kwargs = self.dict()
+ return create_saver(self.save_to, data_path, **kwargs)
+
+class TrackerConfig(BaseModel):
+ data_path: str = '.tracker_data'
+ overwrite_data_path: bool = False
+ log: TrackerLogConfig
+ load: Optional[TrackerLoadConfig] = None
+ save: Union[List[TrackerSaveConfig], TrackerSaveConfig]
+
+ def create(self, full_config: BaseModel, extra_config: dict, dummy_mode: bool = False) -> Tracker:
+ tracker = Tracker(self.data_path, dummy_mode=dummy_mode, overwrite_data_path=self.overwrite_data_path)
+ # Add the logger
+ tracker.add_logger(self.log.create(self.data_path))
+ # Add the loader
+ if self.load is not None:
+ tracker.add_loader(self.load.create(self.data_path))
+ # Add the saver or savers
+ if isinstance(self.save, list):
+ for save_config in self.save:
+ tracker.add_saver(save_config.create(self.data_path))
+ else:
+ tracker.add_saver(self.save.create(self.data_path))
+ # Initialize all the components and verify that all data is valid
+ tracker.init(full_config, extra_config)
+ return tracker
+
+# diffusion prior pydantic classes
+
+class AdapterConfig(BaseModel):
+ make: str = "openai"
+ model: str = "ViT-L/14"
+ base_model_kwargs: Optional[Dict[str, Any]] = None
+
+ def create(self):
+ if self.make == "openai":
+ return OpenAIClipAdapter(self.model)
+ elif self.make == "open_clip":
+ pretrained = dict(list_pretrained())
+ checkpoint = pretrained[self.model]
+ return OpenClipAdapter(name=self.model, pretrained=checkpoint)
+ elif self.make == "x-clip":
+ return XClipAdapter(XCLIP(**self.base_model_kwargs))
+ elif self.make == "coca":
+ return CoCaAdapter(CoCa(**self.base_model_kwargs))
+ else:
+ raise AttributeError("No adapter with that name is available.")
+
+class DiffusionPriorNetworkConfig(BaseModel):
+ dim: int
+ depth: int
+ max_text_len: Optional[int] = None
+ num_timesteps: Optional[int] = None
+ num_time_embeds: int = 1
+ num_image_embeds: int = 1
+ num_text_embeds: int = 1
+ dim_head: int = 64
+ heads: int = 8
+ ff_mult: int = 4
+ norm_in: bool = False
+ norm_out: bool = True
+ attn_dropout: float = 0.
+ ff_dropout: float = 0.
+ final_proj: bool = True
+ normformer: bool = False
+ rotary_emb: bool = True
+
+ class Config:
+ extra = "allow"
+
+ def create(self):
+ kwargs = self.dict()
+ return DiffusionPriorNetwork(**kwargs)
+
+class DiffusionPriorConfig(BaseModel):
+ clip: Optional[AdapterConfig] = None
+ net: DiffusionPriorNetworkConfig
+ image_embed_dim: int
+ image_size: int
+ image_channels: int = 3
+ timesteps: int = 1000
+ sample_timesteps: Optional[int] = None
+ cond_drop_prob: float = 0.
+ loss_type: str = 'l2'
+ predict_x_start: bool = True
+ beta_schedule: str = 'cosine'
+ condition_on_text_encodings: bool = True
+
+ class Config:
+ extra = "allow"
+
+ def create(self):
+ kwargs = self.dict()
+
+ has_clip = exists(kwargs.pop('clip'))
+ kwargs.pop('net')
+
+ clip = None
+ if has_clip:
+ clip = self.clip.create()
+
+ diffusion_prior_network = self.net.create()
+ return DiffusionPrior(net = diffusion_prior_network, clip = clip, **kwargs)
+
+class DiffusionPriorTrainConfig(BaseModel):
+ epochs: int = 1
+ lr: float = 1.1e-4
+ wd: float = 6.02e-2
+ max_grad_norm: float = 0.5
+ use_ema: bool = True
+ ema_beta: float = 0.99
+ amp: bool = False
+ warmup_steps: Optional[int] = None # number of warmup steps
+ save_every_seconds: int = 3600 # how often to save
+ eval_timesteps: List[int] = [64] # which sampling timesteps to evaluate with
+ best_validation_loss: float = 1e9 # the current best validation loss observed
+ current_epoch: int = 0 # the current epoch
+ num_samples_seen: int = 0 # the current number of samples seen
+ random_seed: int = 0 # manual seed for torch
+
+class DiffusionPriorDataConfig(BaseModel):
+ image_url: str # path to embeddings folder
+ meta_url: str # path to metadata (captions) for images
+ splits: TrainSplitConfig # define train, validation, test splits for your dataset
+ batch_size: int # per-gpu batch size used to train the model
+ num_data_points: int = 250_000_000 # total number of datapoints to train on
+ eval_every_seconds: int = 3600 # validation statistics will be performed this often
+
+class TrainDiffusionPriorConfig(BaseModel):
+ prior: DiffusionPriorConfig
+ data: DiffusionPriorDataConfig
+ train: DiffusionPriorTrainConfig
+ tracker: TrackerConfig
+
+ @classmethod
+ def from_json_path(cls, json_path):
+ with open(json_path) as f:
+ config = json.load(f)
+ return cls(**config)
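+
+# Illustrative usage sketch (the config path is a placeholder):
+#
+#   config = TrainDiffusionPriorConfig.from_json_path('path/to/prior_config.json')
+#   diffusion_prior = config.prior.create()
+#   tracker = config.tracker.create(config, {}) # wires up the logger / loader / savers defined in the config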
+
+# decoder pydantic classes
+
+class UnetConfig(BaseModel):
+ dim: int
+ dim_mults: ListOrTuple[int]
+ image_embed_dim: Optional[int] = None
+ text_embed_dim: Optional[int] = None
+ cond_on_text_encodings: Optional[bool] = None
+ cond_dim: Optional[int] = None
+ channels: int = 3
+ self_attn: SingularOrIterable[bool] = False
+ attn_dim_head: int = 32
+ attn_heads: int = 16
+ init_cross_embed: bool = True
+
+ class Config:
+ extra = "allow"
+
+class DecoderConfig(BaseModel):
+ unets: ListOrTuple[UnetConfig]
+ image_size: Optional[int] = None
+ image_sizes: Optional[ListOrTuple[int]] = None
+ clip: Optional[AdapterConfig] = None # The clip model to use if embeddings are not provided
+ channels: int = 3
+ timesteps: int = 1000
+ sample_timesteps: Optional[SingularOrIterable[Optional[int]]] = None
+ loss_type: str = 'l2'
+ beta_schedule: Optional[ListOrTuple[str]] = None # None means all cosine
+ learned_variance: SingularOrIterable[bool] = True
+ image_cond_drop_prob: float = 0.1
+ text_cond_drop_prob: float = 0.5
+
+ def create(self):
+ decoder_kwargs = self.dict()
+
+ unet_configs = decoder_kwargs.pop('unets')
+ unets = [Unet(**config) for config in unet_configs]
+
+ has_clip = exists(decoder_kwargs.pop('clip'))
+ clip = None
+ if has_clip:
+ clip = self.clip.create()
+
+ return Decoder(unets, clip=clip, **decoder_kwargs)
+
+ @validator('image_sizes')
+ def check_image_sizes(cls, image_sizes, values):
+ if exists(values.get('image_size')) ^ exists(image_sizes):
+ return image_sizes
+ raise ValueError('either image_size or image_sizes is required, but not both')
+
+ class Config:
+ extra = "allow"
+
+class DecoderDataConfig(BaseModel):
+ webdataset_base_url: str # path to a webdataset with jpg images
+ img_embeddings_url: Optional[str] = None # path to .npy files with embeddings
+ text_embeddings_url: Optional[str] = None # path to .npy files with embeddings
+ num_workers: int = 4
+ batch_size: int = 64
+ start_shard: int = 0
+ end_shard: int = 9999999
+ shard_width: int = 6
+ index_width: int = 4
+ splits: TrainSplitConfig
+ shuffle_train: bool = True
+ resample_train: bool = False
+ preprocessing: Dict[str, Any] = {'ToTensor': True}
+
+ @property
+ def img_preproc(self):
+ def _get_transformation(transformation_name, **kwargs):
+ if transformation_name == "RandomResizedCrop":
+ return T.RandomResizedCrop(**kwargs)
+ elif transformation_name == "RandomHorizontalFlip":
+ return T.RandomHorizontalFlip()
+ elif transformation_name == "ToTensor":
+ return T.ToTensor()
+
+ transforms = []
+ for transform_name, transform_kwargs_or_bool in self.preprocessing.items():
+ transform_kwargs = {} if not isinstance(transform_kwargs_or_bool, dict) else transform_kwargs_or_bool
+ transforms.append(_get_transformation(transform_name, **transform_kwargs))
+ return T.Compose(transforms)
+
+class DecoderTrainConfig(BaseModel):
+ epochs: int = 20
+ lr: SingularOrIterable[float] = 1e-4
+ wd: SingularOrIterable[float] = 0.01
+ warmup_steps: Optional[SingularOrIterable[int]] = None
+ find_unused_parameters: bool = True
+ static_graph: bool = True
+ max_grad_norm: SingularOrIterable[float] = 0.5
+ save_every_n_samples: int = 100000
+ n_sample_images: int = 6 # The number of example images to produce when sampling the train and test dataset
+ cond_scale: Union[float, List[float]] = 1.0
+ device: str = 'cuda:0'
+ epoch_samples: Optional[int] = None # Limits the number of samples per epoch. None means no limit. Required if resample_train is true as otherwise the number of samples per epoch is infinite.
+ validation_samples: Optional[int] = None # Same as above but for validation.
+ save_immediately: bool = False
+ use_ema: bool = True
+ ema_beta: float = 0.999
+ amp: bool = False
+ unet_training_mask: Optional[ListOrTuple[bool]] = None # If None, use all unets
+
+class DecoderEvaluateConfig(BaseModel):
+ n_evaluation_samples: int = 1000
+ FID: Optional[Dict[str, Any]] = None
+ IS: Optional[Dict[str, Any]] = None
+ KID: Optional[Dict[str, Any]] = None
+ LPIPS: Optional[Dict[str, Any]] = None
+
+class TrainDecoderConfig(BaseModel):
+ decoder: DecoderConfig
+ data: DecoderDataConfig
+ train: DecoderTrainConfig
+ evaluate: DecoderEvaluateConfig
+ tracker: TrackerConfig
+ seed: int = 0
+
+ @classmethod
+ def from_json_path(cls, json_path):
+ with open(json_path) as f:
+ config = json.load(f)
+ print(config)
+ return cls(**config)
+
+ @model_validator(mode = 'after')
+ def check_has_embeddings(self, m):
+ # Makes sure that enough information is provided to get the embeddings specified for training
+ values = dict(self)
+
+ data_config, decoder_config = values.get('data'), values.get('decoder')
+
+ if not exists(data_config) or not exists(decoder_config):
+ # Then something else errored and we should just pass through
+ return self
+
+ using_text_embeddings = any([unet.cond_on_text_encodings for unet in decoder_config.unets])
+ using_clip = exists(decoder_config.clip)
+ img_emb_url = data_config.img_embeddings_url
+ text_emb_url = data_config.text_embeddings_url
+
+ if using_text_embeddings:
+ # Then we need some way to get the embeddings
+ assert using_clip or exists(text_emb_url), 'If text conditioning, either clip or text_embeddings_url must be provided'
+
+ if using_clip:
+ if using_text_embeddings:
+ assert not exists(text_emb_url) or not exists(img_emb_url), 'Loaded clip, but also provided text_embeddings_url and img_embeddings_url. This is redundant. Remove the clip model or the text embeddings'
+ else:
+ assert not exists(img_emb_url), 'Loaded clip, but also provided img_embeddings_url. This is redundant. Remove the clip model or the embeddings'
+
+ if text_emb_url:
+ assert using_text_embeddings, "Text embeddings are being loaded, but text embeddings are not being conditioned on. This will slow down the dataloader for no reason."
+
+ return self
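+
+# Illustrative usage sketch (the config path is a placeholder):
+#
+#   config = TrainDecoderConfig.from_json_path('path/to/decoder_config.json')
+#   decoder = config.decoder.create()
+#   img_preproc = config.data.img_preproc # torchvision transform assembled from the `preprocessing` dict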
diff --git a/docs/src/dalle2_pytorch/trainer.py b/docs/src/dalle2_pytorch/trainer.py
new file mode 100644
index 00000000..b20b2065
--- /dev/null
+++ b/docs/src/dalle2_pytorch/trainer.py
@@ -0,0 +1,742 @@
+import time
+import copy
+from pathlib import Path
+from math import ceil
+from functools import partial, wraps
+from contextlib import nullcontext
+from collections.abc import Iterable
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
+from torch.cuda.amp import autocast, GradScaler
+
+from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
+from dalle2_pytorch.optimizer import get_optimizer
+from dalle2_pytorch.version import __version__
+from packaging import version
+
+import pytorch_warmup as warmup
+
+from ema_pytorch import EMA
+
+from accelerate import Accelerator, DistributedType
+
+import numpy as np
+
+# helper functions
+
+def exists(val):
+ return val is not None
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if callable(d) else d
+
+def cast_tuple(val, length = 1):
+ return val if isinstance(val, tuple) else ((val,) * length)
+
+def pick_and_pop(keys, d):
+ values = list(map(lambda key: d.pop(key), keys))
+ return dict(zip(keys, values))
+
+def group_dict_by_key(cond, d):
+ return_val = [dict(),dict()]
+ for key in d.keys():
+ match = bool(cond(key))
+ ind = int(not match)
+ return_val[ind][key] = d[key]
+ return (*return_val,)
+
+def string_begins_with(prefix, string):
+ return string.startswith(prefix)
+
+def group_by_key_prefix(prefix, d):
+ return group_dict_by_key(partial(string_begins_with, prefix), d)
+
+def groupby_prefix_and_trim(prefix, d):
+ kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
+ kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
+ return kwargs_without_prefix, kwargs
+
+def num_to_groups(num, divisor):
+ groups = num // divisor
+ remainder = num % divisor
+ arr = [divisor] * groups
+ if remainder > 0:
+ arr.append(remainder)
+ return arr
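+
+# e.g. num_to_groups(10, 4) -> [4, 4, 2]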
+
+# decorators
+
+def cast_torch_tensor(fn):
+ @wraps(fn)
+ def inner(model, *args, **kwargs):
+ device = kwargs.pop('_device', next(model.parameters()).device)
+ cast_device = kwargs.pop('_cast_device', True)
+ cast_deepspeed_precision = kwargs.pop('_cast_deepspeed_precision', True)
+
+ kwargs_keys = kwargs.keys()
+ all_args = (*args, *kwargs.values())
+ split_kwargs_index = len(all_args) - len(kwargs_keys)
+ all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))
+
+ if cast_device:
+ all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
+
+ if cast_deepspeed_precision:
+ try:
+ accelerator = model.accelerator
+ if accelerator is not None and accelerator.distributed_type == DistributedType.DEEPSPEED:
+ cast_type_map = {
+ "fp16": torch.half,
+ "bf16": torch.bfloat16,
+ "no": torch.float
+ }
+ precision_type = cast_type_map[accelerator.mixed_precision]
+ all_args = tuple(map(lambda t: t.to(precision_type) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
+ except AttributeError:
+ # Then this model doesn't have an accelerator
+ pass
+
+ args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
+ kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
+
+ out = fn(model, *args, **kwargs)
+ return out
+ return inner
+
+# gradient accumulation functions
+
+def split_iterable(it, split_size):
+ accum = []
+ for ind in range(ceil(len(it) / split_size)):
+ start_index = ind * split_size
+ accum.append(it[start_index: (start_index + split_size)])
+ return accum
+
+def split(t, split_size = None):
+ if not exists(split_size):
+ return t
+
+ if isinstance(t, torch.Tensor):
+ return t.split(split_size, dim = 0)
+
+ if isinstance(t, Iterable):
+ return split_iterable(t, split_size)
+
+ raise TypeError(f'unable to split object of type {type(t)}')
+
+def find_first(cond, arr):
+ for el in arr:
+ if cond(el):
+ return el
+ return None
+
+def split_args_and_kwargs(*args, split_size = None, **kwargs):
+ all_args = (*args, *kwargs.values())
+ len_all_args = len(all_args)
+ first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)
+ assert exists(first_tensor)
+
+ batch_size = len(first_tensor)
+ split_size = default(split_size, batch_size)
+ num_chunks = ceil(batch_size / split_size)
+
+ dict_len = len(kwargs)
+ dict_keys = kwargs.keys()
+ split_kwargs_index = len_all_args - dict_len
+
+ split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args]
+ chunk_sizes = tuple(map(len, split_all_args[0]))
+
+ for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
+ chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:]
+ chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
+ chunk_size_frac = chunk_size / batch_size
+ yield chunk_size_frac, (chunked_args, chunked_kwargs)
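+
+# e.g. with a batch dimension of 10 and split_size = 4, this yields chunk fractions
+# (0.4, 0.4, 0.2) alongside the correspondingly sliced args / kwargs; the fractions are
+# used below to weight each sub-batch's loss during gradient accumulation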
+
+# diffusion prior trainer
+
+def prior_sample_in_chunks(fn):
+ @wraps(fn)
+ def inner(self, *args, max_batch_size = None, **kwargs):
+ if not exists(max_batch_size):
+ return fn(self, *args, **kwargs)
+
+ outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
+ return torch.cat(outputs, dim = 0)
+ return inner
+
+class DiffusionPriorTrainer(nn.Module):
+ def __init__(
+ self,
+ diffusion_prior,
+ accelerator = None,
+ use_ema = True,
+ lr = 3e-4,
+ wd = 1e-2,
+ eps = 1e-6,
+ max_grad_norm = None,
+ group_wd_params = True,
+ warmup_steps = None,
+ cosine_decay_max_steps = None,
+ **kwargs
+ ):
+ super().__init__()
+ assert isinstance(diffusion_prior, DiffusionPrior)
+
+ ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
+ accelerator_kwargs, kwargs = groupby_prefix_and_trim('accelerator_', kwargs)
+
+ if not exists(accelerator):
+ accelerator = Accelerator(**accelerator_kwargs)
+
+ # assign some helpful member vars
+
+ self.accelerator = accelerator
+ self.text_conditioned = diffusion_prior.condition_on_text_encodings
+
+ # setting the device
+
+ self.device = accelerator.device
+ diffusion_prior.to(self.device)
+
+ # save model
+
+ self.diffusion_prior = diffusion_prior
+
+ # mixed precision checks
+
+ if (
+ exists(self.accelerator)
+ and self.accelerator.distributed_type == DistributedType.DEEPSPEED
+ and self.diffusion_prior.clip is not None
+ ):
+ # Then we need to make sure clip is using the correct precision or else deepspeed will error
+ cast_type_map = {
+ "fp16": torch.half,
+ "bf16": torch.bfloat16,
+ "no": torch.float
+ }
+ precision_type = cast_type_map[accelerator.mixed_precision]
+ assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
+ self.diffusion_prior.clip.to(precision_type)
+
+ # optimizer stuff
+
+ self.optim_kwargs = dict(lr=lr, wd=wd, eps=eps, group_wd_params=group_wd_params)
+
+ self.optimizer = get_optimizer(
+ self.diffusion_prior.parameters(),
+ **self.optim_kwargs,
+ **kwargs
+ )
+
+ if exists(cosine_decay_max_steps):
+ self.scheduler = CosineAnnealingLR(self.optimizer, T_max = cosine_decay_max_steps)
+ else:
+ self.scheduler = LambdaLR(self.optimizer, lr_lambda = lambda _: 1.0)
+
+ self.warmup_scheduler = warmup.LinearWarmup(self.optimizer, warmup_period = warmup_steps) if exists(warmup_steps) else None
+
+ # distribute the model if using HFA
+
+ self.diffusion_prior, self.optimizer, self.scheduler = self.accelerator.prepare(self.diffusion_prior, self.optimizer, self.scheduler)
+
+ # exponential moving average stuff
+
+ self.use_ema = use_ema
+
+ if self.use_ema:
+ self.ema_diffusion_prior = EMA(self.accelerator.unwrap_model(self.diffusion_prior), **ema_kwargs)
+
+ # gradient clipping if needed
+
+ self.max_grad_norm = max_grad_norm
+
+ # track steps internally
+
+ self.register_buffer('step', torch.tensor([0], device = self.device))
+
+ # utility
+
+ def save(self, path, overwrite = True, **kwargs):
+
+ # only save on the main process
+ if self.accelerator.is_main_process:
+ print(f"Saving checkpoint at step: {self.step.item()}")
+ path = Path(path)
+ assert not (path.exists() and not overwrite)
+ path.parent.mkdir(parents = True, exist_ok = True)
+
+ # FIXME: LambdaLR can't be saved due to pickling issues
+ save_obj = dict(
+ optimizer = self.optimizer.state_dict(),
+ scheduler = self.scheduler.state_dict(),
+ warmup_scheduler = self.warmup_scheduler,
+ model = self.accelerator.unwrap_model(self.diffusion_prior).state_dict(),
+ version = version.parse(__version__),
+ step = self.step,
+ **kwargs
+ )
+
+ if self.use_ema:
+ save_obj = {
+ **save_obj,
+ 'ema': self.ema_diffusion_prior.state_dict(),
+ 'ema_model': self.ema_diffusion_prior.ema_model.state_dict() # save the ema model specifically for easy ema-only reload
+ }
+
+ torch.save(save_obj, str(path))
+
+ def load(self, path_or_state, overwrite_lr = True, strict = True):
+ """
+ Load a checkpoint of a diffusion prior trainer.
+
+ Will load the entire trainer, including the optimizer and EMA.
+
+ Params:
+ - path_or_state (str | torch): a path to the DiffusionPriorTrainer checkpoint file
+ - overwrite_lr (bool): whether or not to overwrite the stored LR with the LR specified in the new trainer
+ - strict (bool): kwarg for `torch.nn.Module.load_state_dict`, will force an exact checkpoint match
+
+ Returns:
+ loaded_obj (dict): The loaded checkpoint dictionary
+ """
+
+ # all processes need to load checkpoint. no restriction here
+ if isinstance(path_or_state, str):
+ path = Path(path_or_state)
+ assert path.exists()
+ loaded_obj = torch.load(str(path), map_location=self.device)
+
+ elif isinstance(path_or_state, dict):
+ loaded_obj = path_or_state
+
+ if version.parse(__version__) != loaded_obj['version']:
+ print(f'loading saved diffusion prior at version {loaded_obj["version"]} but current package version is at {__version__}')
+
+ # unwrap the model when loading from checkpoint
+ self.accelerator.unwrap_model(self.diffusion_prior).load_state_dict(loaded_obj['model'], strict = strict)
+ self.step.copy_(torch.ones_like(self.step, device=self.device) * loaded_obj['step'].to(self.device))
+
+ self.optimizer.load_state_dict(loaded_obj['optimizer'])
+ self.scheduler.load_state_dict(loaded_obj['scheduler'])
+
+ # set warmupstep
+ if exists(self.warmup_scheduler):
+ self.warmup_scheduler.last_step = self.step.item()
+
+ # ensure new lr is used if different from old one
+ if overwrite_lr:
+ new_lr = self.optim_kwargs["lr"]
+
+ for group in self.optimizer.param_groups:
+ group["lr"] = new_lr if group["lr"] > 0.0 else 0.0
+
+ if self.use_ema:
+ assert 'ema' in loaded_obj
+ self.ema_diffusion_prior.load_state_dict(loaded_obj['ema'], strict = strict)
+ # below might not be necessary, but I had a suspicion that this wasn't being loaded correctly
+ self.ema_diffusion_prior.ema_model.load_state_dict(loaded_obj["ema_model"])
+
+ return loaded_obj
+
+ # model functionality
+
+ def update(self):
+
+ if exists(self.max_grad_norm):
+ self.accelerator.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
+
+ self.optimizer.step()
+ self.optimizer.zero_grad()
+
+ # accelerator will occasionally skip optimizer steps in a "dynamic loss scaling strategy"
+ if not self.accelerator.optimizer_step_was_skipped:
+ sched_context = self.warmup_scheduler.dampening if exists(self.warmup_scheduler) else nullcontext
+ with sched_context():
+ self.scheduler.step()
+
+ if self.use_ema:
+ self.ema_diffusion_prior.update()
+
+ self.step += 1
+
+ @torch.no_grad()
+ @cast_torch_tensor
+ @prior_sample_in_chunks
+ def p_sample_loop(self, *args, **kwargs):
+ model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
+ return model.p_sample_loop(*args, **kwargs)
+
+ @torch.no_grad()
+ @cast_torch_tensor
+ @prior_sample_in_chunks
+ def sample(self, *args, **kwargs):
+ model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
+ return model.sample(*args, **kwargs)
+
+ @torch.no_grad()
+ def sample_batch_size(self, *args, **kwargs):
+ model = self.ema_diffusion_prior.ema_model if self.use_ema else self.diffusion_prior
+ return model.sample_batch_size(*args, **kwargs)
+
+ @torch.no_grad()
+ @cast_torch_tensor
+ @prior_sample_in_chunks
+ def embed_text(self, *args, **kwargs):
+ return self.accelerator.unwrap_model(self.diffusion_prior).clip.embed_text(*args, **kwargs)
+
+ @cast_torch_tensor
+ def forward(
+ self,
+ *args,
+ max_batch_size = None,
+ **kwargs
+ ):
+ total_loss = 0.
+
+ for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
+ with self.accelerator.autocast():
+ loss = self.diffusion_prior(*chunked_args, **chunked_kwargs)
+ loss = loss * chunk_size_frac
+
+ total_loss += loss.item()
+
+ if self.training:
+ self.accelerator.backward(loss)
+
+ return total_loss
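+
+# Illustrative training-step sketch (the batch dict is a placeholder; its contents are
+# forwarded unchanged to DiffusionPrior.forward):
+#
+#   trainer = DiffusionPriorTrainer(diffusion_prior, lr = 3e-4, use_ema = True)
+#   loss = trainer(**batch, max_batch_size = 4) # chunk-weighted scalar loss; backward is applied internally while training
+#   trainer.update() # clips grads (if max_grad_norm is set), steps optimizer / warmup / scheduler, updates EMA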
+
+# decoder trainer
+
+def decoder_sample_in_chunks(fn):
+ @wraps(fn)
+ def inner(self, *args, max_batch_size = None, **kwargs):
+ if not exists(max_batch_size):
+ return fn(self, *args, **kwargs)
+
+ if self.decoder.unconditional:
+ batch_size = kwargs.get('batch_size')
+ batch_sizes = num_to_groups(batch_size, max_batch_size)
+ outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes]
+ else:
+ outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
+
+ return torch.cat(outputs, dim = 0)
+ return inner
+
+class DecoderTrainer(nn.Module):
+ def __init__(
+ self,
+ decoder,
+ accelerator = None,
+ dataloaders = None,
+ use_ema = True,
+ lr = 1e-4,
+ wd = 1e-2,
+ eps = 1e-8,
+ warmup_steps = None,
+ cosine_decay_max_steps = None,
+ max_grad_norm = 0.5,
+ amp = False,
+ group_wd_params = True,
+ **kwargs
+ ):
+ super().__init__()
+ assert isinstance(decoder, Decoder)
+ ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
+
+ self.accelerator = default(accelerator, Accelerator)
+
+ self.num_unets = len(decoder.unets)
+
+ self.use_ema = use_ema
+ self.ema_unets = nn.ModuleList([])
+
+ self.amp = amp
+
+ # be able to finely customize learning rate, weight decay
+ # per unet
+
+ lr, wd, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, wd, eps, warmup_steps, cosine_decay_max_steps))
+
+ assert all([unet_lr <= 1e-2 for unet_lr in lr]), 'your learning rate is too high, recommend sticking with 1e-4, at most 5e-4'
+
+ optimizers = []
+ schedulers = []
+ warmup_schedulers = []
+
+ for unet, unet_lr, unet_wd, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps in zip(decoder.unets, lr, wd, eps, warmup_steps, cosine_decay_max_steps):
+ if isinstance(unet, nn.Identity):
+ optimizers.append(None)
+ schedulers.append(None)
+ warmup_schedulers.append(None)
+ else:
+ optimizer = get_optimizer(
+ unet.parameters(),
+ lr = unet_lr,
+ wd = unet_wd,
+ eps = unet_eps,
+ group_wd_params = group_wd_params,
+ **kwargs
+ )
+
+ optimizers.append(optimizer)
+
+ if exists(unet_cosine_decay_max_steps):
+ scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)
+ else:
+ scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
+
+ warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) if exists(unet_warmup_steps) else None
+ warmup_schedulers.append(warmup_scheduler)
+
+ schedulers.append(scheduler)
+
+ if self.use_ema:
+ self.ema_unets.append(EMA(unet, **ema_kwargs))
+
+ # gradient clipping if needed
+
+ self.max_grad_norm = max_grad_norm
+
+ self.register_buffer('steps', torch.tensor([0] * self.num_unets))
+
+ if self.accelerator.distributed_type == DistributedType.DEEPSPEED and decoder.clip is not None:
+ # Then we need to make sure clip is using the correct precision or else deepspeed will error
+ cast_type_map = {
+ "fp16": torch.half,
+ "bf16": torch.bfloat16,
+ "no": torch.float
+ }
+ precision_type = cast_type_map[accelerator.mixed_precision]
+ assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
+ clip = decoder.clip
+ clip.to(precision_type)
+
+ decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
+
+ self.decoder = decoder
+
+ # prepare dataloaders
+
+ train_loader = val_loader = None
+ if exists(dataloaders):
+ train_loader, val_loader = self.accelerator.prepare(dataloaders["train"], dataloaders["val"])
+
+ self.train_loader = train_loader
+ self.val_loader = val_loader
+
+ # store optimizers
+
+ for opt_ind, optimizer in zip(range(len(optimizers)), optimizers):
+ setattr(self, f'optim{opt_ind}', optimizer)
+
+ # store schedulers
+
+ for sched_ind, scheduler in zip(range(len(schedulers)), schedulers):
+ setattr(self, f'sched{sched_ind}', scheduler)
+
+ # store warmup schedulers
+
+ self.warmup_schedulers = warmup_schedulers
+
+ def validate_and_return_unet_number(self, unet_number = None):
+ if self.num_unets == 1:
+ unet_number = default(unet_number, 1)
+
+ assert exists(unet_number) and 1 <= unet_number <= self.num_unets
+ return unet_number
+
+ def num_steps_taken(self, unet_number = None):
+ unet_number = self.validate_and_return_unet_number(unet_number)
+ return self.steps[unet_number - 1].item()
+
+ def save(self, path, overwrite = True, **kwargs):
+ path = Path(path)
+ assert not (path.exists() and not overwrite)
+ path.parent.mkdir(parents = True, exist_ok = True)
+
+ save_obj = dict(
+ model = self.accelerator.unwrap_model(self.decoder).state_dict(),
+ version = __version__,
+ steps = self.steps.cpu(),
+ **kwargs
+ )
+
+ for ind in range(0, self.num_unets):
+ optimizer_key = f'optim{ind}'
+ scheduler_key = f'sched{ind}'
+
+ optimizer = getattr(self, optimizer_key)
+ scheduler = getattr(self, scheduler_key)
+
+ optimizer_state_dict = optimizer.state_dict() if exists(optimizer) else None
+ scheduler_state_dict = scheduler.state_dict() if exists(scheduler) else None
+
+ save_obj = {**save_obj, optimizer_key: optimizer_state_dict, scheduler_key: scheduler_state_dict}
+
+ if self.use_ema:
+ save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
+
+ self.accelerator.save(save_obj, str(path))
+
+ def load_state_dict(self, loaded_obj, only_model = False, strict = True):
+ if version.parse(__version__) != version.parse(loaded_obj['version']):
+ self.accelerator.print(f'loading saved decoder at version {loaded_obj["version"]}, but current package version is {__version__}')
+
+ self.accelerator.unwrap_model(self.decoder).load_state_dict(loaded_obj['model'], strict = strict)
+ self.steps.copy_(loaded_obj['steps'])
+
+ if only_model:
+ return loaded_obj
+
+ for ind, last_step in zip(range(0, self.num_unets), self.steps.tolist()):
+
+ optimizer_key = f'optim{ind}'
+ optimizer = getattr(self, optimizer_key)
+
+ scheduler_key = f'sched{ind}'
+ scheduler = getattr(self, scheduler_key)
+
+ warmup_scheduler = self.warmup_schedulers[ind]
+
+ if exists(optimizer):
+ optimizer.load_state_dict(loaded_obj[optimizer_key])
+
+ if exists(scheduler):
+ scheduler.load_state_dict(loaded_obj[scheduler_key])
+
+ if exists(warmup_scheduler):
+ warmup_scheduler.last_step = last_step
+
+ if self.use_ema:
+ assert 'ema' in loaded_obj
+ self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
+
+ def load(self, path, only_model = False, strict = True):
+ path = Path(path)
+ assert path.exists()
+
+ loaded_obj = torch.load(str(path), map_location = 'cpu')
+
+ self.load_state_dict(loaded_obj, only_model = only_model, strict = strict)
+
+ return loaded_obj
+
+ @property
+ def unets(self):
+ return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
+
+ def increment_step(self, unet_number):
+ assert 1 <= unet_number <= self.num_unets
+
+ unet_index_tensor = torch.tensor(unet_number - 1, device = self.steps.device)
+ self.steps += F.one_hot(unet_index_tensor, num_classes = len(self.steps))
+
+ def update(self, unet_number = None):
+ unet_number = self.validate_and_return_unet_number(unet_number)
+ index = unet_number - 1
+
+ optimizer = getattr(self, f'optim{index}')
+ scheduler = getattr(self, f'sched{index}')
+
+ if exists(self.max_grad_norm):
+ self.accelerator.clip_grad_norm_(self.decoder.parameters(), self.max_grad_norm) # Automatically unscales gradients
+
+ optimizer.step()
+ optimizer.zero_grad()
+
+ warmup_scheduler = self.warmup_schedulers[index]
+ scheduler_context = warmup_scheduler.dampening if exists(warmup_scheduler) else nullcontext
+
+ with scheduler_context():
+ scheduler.step()
+
+ if self.use_ema:
+ ema_unet = self.ema_unets[index]
+ ema_unet.update()
+
+ self.increment_step(unet_number)
+
+ @torch.no_grad()
+ @cast_torch_tensor
+ @decoder_sample_in_chunks
+ def sample(self, *args, **kwargs):
+ distributed = self.accelerator.num_processes > 1
+ base_decoder = self.accelerator.unwrap_model(self.decoder)
+
+ was_training = base_decoder.training
+ base_decoder.eval()
+
+ if kwargs.pop('use_non_ema', False) or not self.use_ema:
+ out = base_decoder.sample(*args, **kwargs, distributed = distributed)
+ base_decoder.train(was_training)
+ return out
+
+ trainable_unets = self.accelerator.unwrap_model(self.decoder).unets
+ base_decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
+
+ output = base_decoder.sample(*args, **kwargs, distributed = distributed)
+
+ base_decoder.unets = trainable_unets # restore original training unets
+
+ # cast the ema_model unets back to original device
+ for ema in self.ema_unets:
+ ema.restore_ema_model_device()
+
+ base_decoder.train(was_training)
+ return output
+
+ @torch.no_grad()
+ @cast_torch_tensor
+ @prior_sample_in_chunks
+ def embed_text(self, *args, **kwargs):
+ return self.accelerator.unwrap_model(self.decoder).clip.embed_text(*args, **kwargs)
+
+ @torch.no_grad()
+ @cast_torch_tensor
+ @prior_sample_in_chunks
+ def embed_image(self, *args, **kwargs):
+ return self.accelerator.unwrap_model(self.decoder).clip.embed_image(*args, **kwargs)
+
+ @cast_torch_tensor
+ def forward(
+ self,
+ *args,
+ unet_number = None,
+ max_batch_size = None,
+ return_lowres_cond_image=False,
+ **kwargs
+ ):
+ unet_number = self.validate_and_return_unet_number(unet_number)
+
+ total_loss = 0.
+ cond_images = []
+ for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
+ with self.accelerator.autocast():
+ loss_obj = self.decoder(*chunked_args, unet_number = unet_number, return_lowres_cond_image=return_lowres_cond_image, **chunked_kwargs)
+ # loss_obj may be a tuple with loss and cond_image
+ if return_lowres_cond_image:
+ loss, cond_image = loss_obj
+ else:
+ loss = loss_obj
+ cond_image = None
+ loss = loss * chunk_size_frac
+ if cond_image is not None:
+ cond_images.append(cond_image)
+
+ total_loss += loss.item()
+
+ if self.training:
+ self.accelerator.backward(loss)
+
+ if return_lowres_cond_image:
+ return total_loss, torch.stack(cond_images)
+ else:
+ return total_loss
diff --git a/docs/src/dalle2_pytorch/utils.py b/docs/src/dalle2_pytorch/utils.py
new file mode 100644
index 00000000..447b88f8
--- /dev/null
+++ b/docs/src/dalle2_pytorch/utils.py
@@ -0,0 +1,35 @@
+import time
+import importlib
+
+# helper functions
+
+def exists(val):
+ return val is not None
+
+# time helpers
+
+class Timer:
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.last_time = time.time()
+
+ def elapsed(self):
+ return time.time() - self.last_time
+
+# print helpers
+
+def print_ribbon(s, symbol = '=', repeat = 40):
+ flank = symbol * repeat
+ return f'{flank} {s} {flank}'
+
+# import helpers
+
+def import_or_print_error(pkg_name, err_str = None):
+ try:
+ return importlib.import_module(pkg_name)
+ except ModuleNotFoundError as e:
+ if exists(err_str):
+ print(err_str)
+ exit()
diff --git a/docs/src/dalle2_pytorch/version.py b/docs/src/dalle2_pytorch/version.py
new file mode 100644
index 00000000..d061b62a
--- /dev/null
+++ b/docs/src/dalle2_pytorch/version.py
@@ -0,0 +1 @@
+__version__ = '1.15.6'
diff --git a/docs/src/dalle2_pytorch/vqgan_vae.py b/docs/src/dalle2_pytorch/vqgan_vae.py
new file mode 100644
index 00000000..01d72586
--- /dev/null
+++ b/docs/src/dalle2_pytorch/vqgan_vae.py
@@ -0,0 +1,764 @@
+import copy
+import math
+from math import sqrt
+from functools import partial, wraps
+
+from vector_quantize_pytorch import VectorQuantize as VQ
+
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+from torch.autograd import grad as torch_grad
+import torchvision
+
+from einops import rearrange, reduce, repeat, pack, unpack
+from einops.layers.torch import Rearrange
+
+# constants
+
+MList = nn.ModuleList
+
+# helper functions
+
+def exists(val):
+ return val is not None
+
+def default(val, d):
+ return val if exists(val) else d
+
+# decorators
+
+def eval_decorator(fn):
+ def inner(model, *args, **kwargs):
+ was_training = model.training
+ model.eval()
+ out = fn(model, *args, **kwargs)
+ model.train(was_training)
+ return out
+ return inner
+
+def remove_vgg(fn):
+ @wraps(fn)
+ def inner(self, *args, **kwargs):
+ has_vgg = hasattr(self, 'vgg')
+ if has_vgg:
+ vgg = self.vgg
+ delattr(self, 'vgg')
+
+ out = fn(self, *args, **kwargs)
+
+ if has_vgg:
+ self.vgg = vgg
+
+ return out
+ return inner
+
+# keyword argument helpers
+
+def pick_and_pop(keys, d):
+ values = list(map(lambda key: d.pop(key), keys))
+ return dict(zip(keys, values))
+
+def group_dict_by_key(cond, d):
+ return_val = [dict(),dict()]
+ for key in d.keys():
+ match = bool(cond(key))
+ ind = int(not match)
+ return_val[ind][key] = d[key]
+ return (*return_val,)
+
+def string_begins_with(prefix, string_input):
+ return string_input.startswith(prefix)
+
+def group_by_key_prefix(prefix, d):
+ return group_dict_by_key(partial(string_begins_with, prefix), d)
+
+def groupby_prefix_and_trim(prefix, d):
+ kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
+ kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
+ return kwargs_without_prefix, kwargs
+
+# tensor helper functions
+
+def log(t, eps = 1e-10):
+ return torch.log(t + eps)
+
+def gradient_penalty(images, output, weight = 10):
+ batch_size = images.shape[0]
+ gradients = torch_grad(outputs = output, inputs = images,
+ grad_outputs = torch.ones(output.size(), device = images.device),
+ create_graph = True, retain_graph = True, only_inputs = True)[0]
+
+ gradients = rearrange(gradients, 'b ... -> b (...)')
+ return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()
+
+def l2norm(t):
+ return F.normalize(t, dim = -1)
+
+def leaky_relu(p = 0.1):
+    return nn.LeakyReLU(p)
+
+def stable_softmax(t, dim = -1, alpha = 32 ** 2):
+ t = t / alpha
+ t = t - torch.amax(t, dim = dim, keepdim = True).detach()
+ return (t * alpha).softmax(dim = dim)
+
+def safe_div(numer, denom, eps = 1e-8):
+ return numer / (denom + eps)
+
+# gan losses
+
+def hinge_discr_loss(fake, real):
+ return (F.relu(1 + fake) + F.relu(1 - real)).mean()
+
+def hinge_gen_loss(fake):
+ return -fake.mean()
+
+def bce_discr_loss(fake, real):
+ return (-log(1 - torch.sigmoid(fake)) - log(torch.sigmoid(real))).mean()
+
+def bce_gen_loss(fake):
+ return -log(torch.sigmoid(fake)).mean()
+
+def grad_layer_wrt_loss(loss, layer):
+ return torch_grad(
+ outputs = loss,
+ inputs = layer,
+ grad_outputs = torch.ones_like(loss),
+ retain_graph = True
+ )[0].detach()
+
+# vqgan vae
+
+class LayerNormChan(nn.Module):
+ def __init__(
+ self,
+ dim,
+ eps = 1e-5
+ ):
+ super().__init__()
+ self.eps = eps
+ self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))
+
+ def forward(self, x):
+ var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
+ mean = torch.mean(x, dim = 1, keepdim = True)
+ return (x - mean) / (var + self.eps).sqrt() * self.gamma
+
+# discriminator
+
+class Discriminator(nn.Module):
+ def __init__(
+ self,
+ dims,
+ channels = 3,
+ groups = 16,
+ init_kernel_size = 5
+ ):
+ super().__init__()
+ dim_pairs = zip(dims[:-1], dims[1:])
+
+ self.layers = MList([nn.Sequential(nn.Conv2d(channels, dims[0], init_kernel_size, padding = init_kernel_size // 2), leaky_relu())])
+
+ for dim_in, dim_out in dim_pairs:
+ self.layers.append(nn.Sequential(
+ nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1),
+ nn.GroupNorm(groups, dim_out),
+ leaky_relu()
+ ))
+
+ dim = dims[-1]
+ self.to_logits = nn.Sequential( # return 5 x 5, for PatchGAN-esque training
+ nn.Conv2d(dim, dim, 1),
+ leaky_relu(),
+ nn.Conv2d(dim, 1, 4)
+ )
+
+ def forward(self, x):
+ for net in self.layers:
+ x = net(x)
+
+ return self.to_logits(x)
+
+# positional encoding
+
+class ContinuousPositionBias(nn.Module):
+ """ from https://arxiv.org/abs/2111.09883 """
+
+ def __init__(self, *, dim, heads, layers = 2):
+ super().__init__()
+ self.net = MList([])
+ self.net.append(nn.Sequential(nn.Linear(2, dim), leaky_relu()))
+
+ for _ in range(layers - 1):
+ self.net.append(nn.Sequential(nn.Linear(dim, dim), leaky_relu()))
+
+ self.net.append(nn.Linear(dim, heads))
+ self.register_buffer('rel_pos', None, persistent = False)
+
+ def forward(self, x):
+ n, device = x.shape[-1], x.device
+ fmap_size = int(sqrt(n))
+
+ if not exists(self.rel_pos):
+ pos = torch.arange(fmap_size, device = device)
+ grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
+ grid = rearrange(grid, 'c i j -> (i j) c')
+ rel_pos = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c')
+ rel_pos = torch.sign(rel_pos) * torch.log(rel_pos.abs() + 1)
+ self.register_buffer('rel_pos', rel_pos, persistent = False)
+
+ rel_pos = self.rel_pos.float()
+
+ for layer in self.net:
+ rel_pos = layer(rel_pos)
+
+ bias = rearrange(rel_pos, 'i j h -> h i j')
+ return x + bias
+
+# resnet encoder / decoder
+
+class ResnetEncDec(nn.Module):
+ def __init__(
+ self,
+ dim,
+ *,
+ channels = 3,
+ layers = 4,
+ layer_mults = None,
+ num_resnet_blocks = 1,
+ resnet_groups = 16,
+ first_conv_kernel_size = 5,
+ use_attn = True,
+ attn_dim_head = 64,
+ attn_heads = 8,
+ attn_dropout = 0.,
+ ):
+ super().__init__()
+ assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)'
+
+ self.layers = layers
+
+ self.encoders = MList([])
+ self.decoders = MList([])
+
+ layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
+ assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
+
+ layer_dims = [dim * mult for mult in layer_mults]
+ dims = (dim, *layer_dims)
+
+ self.encoded_dim = dims[-1]
+
+ dim_pairs = zip(dims[:-1], dims[1:])
+
+ append = lambda arr, t: arr.append(t)
+ prepend = lambda arr, t: arr.insert(0, t)
+
+ if not isinstance(num_resnet_blocks, tuple):
+ num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)
+
+ if not isinstance(use_attn, tuple):
+ use_attn = (*((False,) * (layers - 1)), use_attn)
+
+ assert len(num_resnet_blocks) == layers, 'number of resnet blocks config must be equal to number of layers'
+ assert len(use_attn) == layers
+
+ for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):
+ append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
+ prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))
+
+ if layer_use_attn:
+ prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
+
+ for _ in range(layer_num_resnet_blocks):
+ append(self.encoders, ResBlock(dim_out, groups = resnet_groups))
+ prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups))
+
+ if layer_use_attn:
+ append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
+
+ prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
+ append(self.decoders, nn.Conv2d(dim, channels, 1))
+
+ def get_encoded_fmap_size(self, image_size):
+ return image_size // (2 ** self.layers)
+
+ @property
+ def last_dec_layer(self):
+ return self.decoders[-1].weight
+
+ def encode(self, x):
+ for enc in self.encoders:
+ x = enc(x)
+ return x
+
+ def decode(self, x):
+ for dec in self.decoders:
+ x = dec(x)
+ return x
+
+class GLUResBlock(nn.Module):
+ def __init__(self, chan, groups = 16):
+ super().__init__()
+ self.net = nn.Sequential(
+ nn.Conv2d(chan, chan * 2, 3, padding = 1),
+ nn.GLU(dim = 1),
+ nn.GroupNorm(groups, chan),
+ nn.Conv2d(chan, chan * 2, 3, padding = 1),
+ nn.GLU(dim = 1),
+ nn.GroupNorm(groups, chan),
+ nn.Conv2d(chan, chan, 1)
+ )
+
+ def forward(self, x):
+ return self.net(x) + x
+
+class ResBlock(nn.Module):
+ def __init__(self, chan, groups = 16):
+ super().__init__()
+ self.net = nn.Sequential(
+ nn.Conv2d(chan, chan, 3, padding = 1),
+ nn.GroupNorm(groups, chan),
+ leaky_relu(),
+ nn.Conv2d(chan, chan, 3, padding = 1),
+ nn.GroupNorm(groups, chan),
+ leaky_relu(),
+ nn.Conv2d(chan, chan, 1)
+ )
+
+ def forward(self, x):
+ return self.net(x) + x
+
+# vqgan attention layer
+
+class VQGanAttention(nn.Module):
+ def __init__(
+ self,
+ *,
+ dim,
+ dim_head = 64,
+ heads = 8,
+ dropout = 0.
+ ):
+ super().__init__()
+ self.heads = heads
+ self.scale = dim_head ** -0.5
+ inner_dim = heads * dim_head
+
+ self.dropout = nn.Dropout(dropout)
+ self.pre_norm = LayerNormChan(dim)
+
+ self.cpb = ContinuousPositionBias(dim = dim // 4, heads = heads)
+ self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
+ self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
+
+ def forward(self, x):
+ h = self.heads
+ height, width, residual = *x.shape[-2:], x.clone()
+
+ x = self.pre_norm(x)
+
+ q, k, v = self.to_qkv(x).chunk(3, dim = 1)
+
+ q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = h), (q, k, v))
+
+ sim = einsum('b h c i, b h c j -> b h i j', q, k) * self.scale
+
+ sim = self.cpb(sim)
+
+ attn = stable_softmax(sim, dim = -1)
+ attn = self.dropout(attn)
+
+ out = einsum('b h i j, b h c j -> b h c i', attn, v)
+ out = rearrange(out, 'b h c (x y) -> b (h c) x y', x = height, y = width)
+ out = self.to_out(out)
+
+ return out + residual
+
+# ViT encoder / decoder
+
+class RearrangeImage(nn.Module):
+ def forward(self, x):
+ n = x.shape[1]
+ w = h = int(sqrt(n))
+ return rearrange(x, 'b (h w) ... -> b h w ...', h = h, w = w)
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim,
+ *,
+ heads = 8,
+ dim_head = 32
+ ):
+ super().__init__()
+ self.norm = nn.LayerNorm(dim)
+ self.heads = heads
+ self.scale = dim_head ** -0.5
+ inner_dim = dim_head * heads
+
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
+ self.to_out = nn.Linear(inner_dim, dim)
+
+ def forward(self, x):
+ h = self.heads
+
+ x = self.norm(x)
+
+ q, k, v = self.to_qkv(x).chunk(3, dim = -1)
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
+
+ q = q * self.scale
+ sim = einsum('b h i d, b h j d -> b h i j', q, k)
+
+ sim = sim - sim.amax(dim = -1, keepdim = True).detach()
+ attn = sim.softmax(dim = -1)
+
+ out = einsum('b h i j, b h j d -> b h i d', attn, v)
+
+ out = rearrange(out, 'b h n d -> b n (h d)')
+ return self.to_out(out)
+
+def FeedForward(dim, mult = 4):
+ return nn.Sequential(
+ nn.LayerNorm(dim),
+ nn.Linear(dim, dim * mult, bias = False),
+ nn.GELU(),
+ nn.Linear(dim * mult, dim, bias = False)
+ )
+
+class Transformer(nn.Module):
+ def __init__(
+ self,
+ dim,
+ *,
+ layers,
+ dim_head = 32,
+ heads = 8,
+ ff_mult = 4
+ ):
+ super().__init__()
+ self.layers = nn.ModuleList([])
+ for _ in range(layers):
+ self.layers.append(nn.ModuleList([
+ Attention(dim = dim, dim_head = dim_head, heads = heads),
+ FeedForward(dim = dim, mult = ff_mult)
+ ]))
+
+ self.norm = nn.LayerNorm(dim)
+
+ def forward(self, x):
+ for attn, ff in self.layers:
+ x = attn(x) + x
+ x = ff(x) + x
+
+ return self.norm(x)
+
+class ViTEncDec(nn.Module):
+ def __init__(
+ self,
+ dim,
+ channels = 3,
+ layers = 4,
+ patch_size = 8,
+ dim_head = 32,
+ heads = 8,
+ ff_mult = 4
+ ):
+ super().__init__()
+ self.encoded_dim = dim
+ self.patch_size = patch_size
+
+ input_dim = channels * (patch_size ** 2)
+
+ self.encoder = nn.Sequential(
+ Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
+ nn.Linear(input_dim, dim),
+ Transformer(
+ dim = dim,
+ dim_head = dim_head,
+ heads = heads,
+ ff_mult = ff_mult,
+ layers = layers
+ ),
+ RearrangeImage(),
+ Rearrange('b h w c -> b c h w')
+ )
+
+ self.decoder = nn.Sequential(
+ Rearrange('b c h w -> b (h w) c'),
+ Transformer(
+ dim = dim,
+ dim_head = dim_head,
+ heads = heads,
+ ff_mult = ff_mult,
+ layers = layers
+ ),
+ nn.Sequential(
+ nn.Linear(dim, dim * 4, bias = False),
+ nn.Tanh(),
+ nn.Linear(dim * 4, input_dim, bias = False),
+ ),
+ RearrangeImage(),
+ Rearrange('b h w (p1 p2 c) -> b c (h p1) (w p2)', p1 = patch_size, p2 = patch_size)
+ )
+
+ def get_encoded_fmap_size(self, image_size):
+ return image_size // self.patch_size
+
+ @property
+ def last_dec_layer(self):
+ return self.decoder[-3][-1].weight
+
+ def encode(self, x):
+ return self.encoder(x)
+
+ def decode(self, x):
+ return self.decoder(x)
+
+# main vqgan-vae classes
+
+class NullVQGanVAE(nn.Module):
+ def __init__(
+ self,
+ *,
+ channels
+ ):
+ super().__init__()
+ self.encoded_dim = channels
+ self.layers = 0
+
+ def get_encoded_fmap_size(self, size):
+ return size
+
+ def copy_for_eval(self):
+ return self
+
+ def encode(self, x):
+ return x
+
+ def decode(self, x):
+ return x
+
+class VQGanVAE(nn.Module):
+ def __init__(
+ self,
+ *,
+ dim,
+ image_size,
+ channels = 3,
+ layers = 4,
+ l2_recon_loss = False,
+ use_hinge_loss = True,
+ vgg = None,
+ vq_codebook_dim = 256,
+ vq_codebook_size = 512,
+ vq_decay = 0.8,
+ vq_commitment_weight = 1.,
+ vq_kmeans_init = True,
+ vq_use_cosine_sim = True,
+ use_vgg_and_gan = True,
+ vae_type = 'resnet',
+ discr_layers = 4,
+ **kwargs
+ ):
+ super().__init__()
+ vq_kwargs, kwargs = groupby_prefix_and_trim('vq_', kwargs)
+ encdec_kwargs, kwargs = groupby_prefix_and_trim('encdec_', kwargs)
+
+ self.image_size = image_size
+ self.channels = channels
+ self.codebook_size = vq_codebook_size
+
+ if vae_type == 'resnet':
+ enc_dec_klass = ResnetEncDec
+ elif vae_type == 'vit':
+ enc_dec_klass = ViTEncDec
+ else:
+ raise ValueError(f'{vae_type} not valid')
+
+ self.enc_dec = enc_dec_klass(
+ dim = dim,
+ channels = channels,
+ layers = layers,
+ **encdec_kwargs
+ )
+
+ self.vq = VQ(
+ dim = self.enc_dec.encoded_dim,
+ codebook_dim = vq_codebook_dim,
+ codebook_size = vq_codebook_size,
+ decay = vq_decay,
+ commitment_weight = vq_commitment_weight,
+ accept_image_fmap = True,
+ kmeans_init = vq_kmeans_init,
+ use_cosine_sim = vq_use_cosine_sim,
+ **vq_kwargs
+ )
+
+ # reconstruction loss
+
+ self.recon_loss_fn = F.mse_loss if l2_recon_loss else F.l1_loss
+
+ # turn off GAN and perceptual loss if grayscale
+
+ self.vgg = None
+ self.discr = None
+ self.use_vgg_and_gan = use_vgg_and_gan
+
+ if not use_vgg_and_gan:
+ return
+
+        # perceptual loss
+
+ if exists(vgg):
+ self.vgg = vgg
+ else:
+ self.vgg = torchvision.models.vgg16(pretrained = True)
+ self.vgg.classifier = nn.Sequential(*self.vgg.classifier[:-2])
+
+ # gan related losses
+
+ layer_mults = list(map(lambda t: 2 ** t, range(discr_layers)))
+ layer_dims = [dim * mult for mult in layer_mults]
+ dims = (dim, *layer_dims)
+
+ self.discr = Discriminator(dims = dims, channels = channels)
+
+ self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
+ self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
+
+ @property
+ def encoded_dim(self):
+ return self.enc_dec.encoded_dim
+
+ def get_encoded_fmap_size(self, image_size):
+ return self.enc_dec.get_encoded_fmap_size(image_size)
+
+ def copy_for_eval(self):
+ device = next(self.parameters()).device
+ vae_copy = copy.deepcopy(self.cpu())
+
+ if vae_copy.use_vgg_and_gan:
+ del vae_copy.discr
+ del vae_copy.vgg
+
+ vae_copy.eval()
+ return vae_copy.to(device)
+
+ @remove_vgg
+ def state_dict(self, *args, **kwargs):
+ return super().state_dict(*args, **kwargs)
+
+ @remove_vgg
+ def load_state_dict(self, *args, **kwargs):
+ return super().load_state_dict(*args, **kwargs)
+
+ @property
+ def codebook(self):
+ return self.vq.codebook
+
+ def encode(self, fmap):
+ fmap = self.enc_dec.encode(fmap)
+ return fmap
+
+ def decode(self, fmap, return_indices_and_loss = False):
+ fmap, indices, commit_loss = self.vq(fmap)
+
+ fmap = self.enc_dec.decode(fmap)
+
+ if not return_indices_and_loss:
+ return fmap
+
+ return fmap, indices, commit_loss
+
+ def forward(
+ self,
+ img,
+ return_loss = False,
+ return_discr_loss = False,
+ return_recons = False,
+ add_gradient_penalty = True
+ ):
+ batch, channels, height, width, device = *img.shape, img.device
+        assert height == self.image_size and width == self.image_size, f'height and width of input image must be equal to {self.image_size}'
+ assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE'
+
+ fmap = self.encode(img)
+
+ fmap, indices, commit_loss = self.decode(fmap, return_indices_and_loss = True)
+
+ if not return_loss and not return_discr_loss:
+ return fmap
+
+ assert return_loss ^ return_discr_loss, 'you should either return autoencoder loss or discriminator loss, but not both'
+
+ # whether to return discriminator loss
+
+ if return_discr_loss:
+ assert exists(self.discr), 'discriminator must exist to train it'
+
+ fmap.detach_()
+ img.requires_grad_()
+
+ fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))
+
+ discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)
+
+ if add_gradient_penalty:
+ gp = gradient_penalty(img, img_discr_logits)
+ loss = discr_loss + gp
+
+ if return_recons:
+ return loss, fmap
+
+ return loss
+
+ # reconstruction loss
+
+ recon_loss = self.recon_loss_fn(fmap, img)
+
+ # early return if training on grayscale
+
+ if not self.use_vgg_and_gan:
+ if return_recons:
+ return recon_loss, fmap
+
+ return recon_loss
+
+ # perceptual loss
+
+ img_vgg_input = img
+ fmap_vgg_input = fmap
+
+ if img.shape[1] == 1:
+ # handle grayscale for vgg
+ img_vgg_input, fmap_vgg_input = map(lambda t: repeat(t, 'b 1 ... -> b c ...', c = 3), (img_vgg_input, fmap_vgg_input))
+
+ img_vgg_feats = self.vgg(img_vgg_input)
+ recon_vgg_feats = self.vgg(fmap_vgg_input)
+ perceptual_loss = F.mse_loss(img_vgg_feats, recon_vgg_feats)
+
+ # generator loss
+
+ gen_loss = self.gen_loss(self.discr(fmap))
+
+ # calculate adaptive weight
+
+ last_dec_layer = self.enc_dec.last_dec_layer
+
+ norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
+ norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)
+
+ adaptive_weight = safe_div(norm_grad_wrt_perceptual_loss, norm_grad_wrt_gen_loss)
+ adaptive_weight.clamp_(max = 1e4)
+
+ # combine losses
+
+ loss = recon_loss + perceptual_loss + commit_loss + adaptive_weight * gen_loss
+
+ if return_recons:
+ return loss, fmap
+
+ return loss
diff --git a/docs/src/dalle2_pytorch/vqgan_vae_trainer.py b/docs/src/dalle2_pytorch/vqgan_vae_trainer.py
new file mode 100644
index 00000000..40479806
--- /dev/null
+++ b/docs/src/dalle2_pytorch/vqgan_vae_trainer.py
@@ -0,0 +1,278 @@
+from math import sqrt
+import copy
+from random import choice
+from pathlib import Path
+from shutil import rmtree
+from PIL import Image
+
+import torch
+from torch import nn
+from torch.cuda.amp import autocast, GradScaler
+from torch.utils.data import Dataset, DataLoader, random_split
+
+import torchvision.transforms as T
+from torchvision.datasets import ImageFolder
+from torchvision.utils import make_grid, save_image
+
+from einops import rearrange
+
+from dalle2_pytorch.vqgan_vae import VQGanVAE
+from dalle2_pytorch.optimizer import get_optimizer
+
+from ema_pytorch import EMA
+
+# helpers
+
+def exists(val):
+ return val is not None
+
+def noop(*args, **kwargs):
+ pass
+
+def cycle(dl):
+ while True:
+ for data in dl:
+ yield data
+
+def cast_tuple(t):
+ return t if isinstance(t, (tuple, list)) else (t,)
+
+def yes_or_no(question):
+ answer = input(f'{question} (y/n) ')
+ return answer.lower() in ('yes', 'y')
+
+def accum_log(log, new_logs):
+ for key, new_value in new_logs.items():
+ old_value = log.get(key, 0.)
+ log[key] = old_value + new_value
+ return log
+
+# classes
+
+class ImageDataset(Dataset):
+ def __init__(
+ self,
+ folder,
+ image_size,
+ exts = ['jpg', 'jpeg', 'png']
+ ):
+ super().__init__()
+ self.folder = folder
+ self.image_size = image_size
+ self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
+
+ print(f'{len(self.paths)} training samples found at {folder}')
+
+ self.transform = T.Compose([
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+ T.Resize(image_size),
+ T.RandomHorizontalFlip(),
+ T.CenterCrop(image_size),
+ T.ToTensor()
+ ])
+
+ def __len__(self):
+ return len(self.paths)
+
+ def __getitem__(self, index):
+ path = self.paths[index]
+ img = Image.open(path)
+ return self.transform(img)
+
+# main trainer class
+
+class VQGanVAETrainer(nn.Module):
+ def __init__(
+ self,
+ vae,
+ *,
+ num_train_steps,
+ lr,
+ batch_size,
+ folder,
+ grad_accum_every,
+ wd = 0.,
+ save_results_every = 100,
+ save_model_every = 1000,
+ results_folder = './results',
+ valid_frac = 0.05,
+ random_split_seed = 42,
+ ema_beta = 0.995,
+ ema_update_after_step = 500,
+ ema_update_every = 10,
+ apply_grad_penalty_every = 4,
+ amp = False
+ ):
+ super().__init__()
+ assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'
+ image_size = vae.image_size
+
+ self.vae = vae
+ self.ema_vae = EMA(vae, update_after_step = ema_update_after_step, update_every = ema_update_every)
+
+ self.register_buffer('steps', torch.Tensor([0]))
+
+ self.num_train_steps = num_train_steps
+ self.batch_size = batch_size
+ self.grad_accum_every = grad_accum_every
+
+ all_parameters = set(vae.parameters())
+ discr_parameters = set(vae.discr.parameters())
+ vae_parameters = all_parameters - discr_parameters
+
+ self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)
+ self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)
+
+ self.amp = amp
+ self.scaler = GradScaler(enabled = amp)
+ self.discr_scaler = GradScaler(enabled = amp)
+
+ # create dataset
+
+ self.ds = ImageDataset(folder, image_size = image_size)
+
+ # split for validation
+
+ if valid_frac > 0:
+ train_size = int((1 - valid_frac) * len(self.ds))
+ valid_size = len(self.ds) - train_size
+ self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
+            print(f'training with dataset of {len(self.ds)} samples and validating with randomly split {len(self.valid_ds)} samples')
+ else:
+ self.valid_ds = self.ds
+ print(f'training with shared training and valid dataset of {len(self.ds)} samples')
+
+ # dataloader
+
+ self.dl = cycle(DataLoader(
+ self.ds,
+ batch_size = batch_size,
+ shuffle = True
+ ))
+
+ self.valid_dl = cycle(DataLoader(
+ self.valid_ds,
+ batch_size = batch_size,
+ shuffle = True
+ ))
+
+ self.save_model_every = save_model_every
+ self.save_results_every = save_results_every
+
+ self.apply_grad_penalty_every = apply_grad_penalty_every
+
+ self.results_folder = Path(results_folder)
+
+ if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
+ rmtree(str(self.results_folder))
+
+ self.results_folder.mkdir(parents = True, exist_ok = True)
+
+ def train_step(self):
+ device = next(self.vae.parameters()).device
+ steps = int(self.steps.item())
+ apply_grad_penalty = not (steps % self.apply_grad_penalty_every)
+
+ self.vae.train()
+
+ # logs
+
+ logs = {}
+
+ # update vae (generator)
+
+ for _ in range(self.grad_accum_every):
+ img = next(self.dl)
+ img = img.to(device)
+
+ with autocast(enabled = self.amp):
+ loss = self.vae(
+ img,
+ return_loss = True,
+                    add_gradient_penalty = apply_grad_penalty
+ )
+
+
+ self.scaler.scale(loss / self.grad_accum_every).backward()
+
+ accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
+
+ self.scaler.step(self.optim)
+ self.scaler.update()
+ self.optim.zero_grad()
+
+ # update discriminator
+
+ if exists(self.vae.discr):
+ discr_loss = 0
+ for _ in range(self.grad_accum_every):
+ img = next(self.dl)
+ img = img.to(device)
+
+ with autocast(enabled = self.amp):
+ loss = self.vae(img, return_discr_loss = True)
+
+ self.discr_scaler.scale(loss / self.grad_accum_every).backward()
+
+ accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})
+
+ self.discr_scaler.step(self.discr_optim)
+ self.discr_scaler.update()
+ self.discr_optim.zero_grad()
+
+ # log
+
+ print(f"{steps}: vae loss: {logs['loss']} - discr loss: {logs['discr_loss']}")
+
+ # update exponential moving averaged generator
+
+ self.ema_vae.update()
+
+ # sample results every so often
+
+ if not (steps % self.save_results_every):
+ for model, filename in ((self.ema_vae.ema_model, f'{steps}.ema'), (self.vae, str(steps))):
+ model.eval()
+
+ imgs = next(self.dl)
+ imgs = imgs.to(device)
+
+ recons = model(imgs)
+ nrows = int(sqrt(self.batch_size))
+
+ imgs_and_recons = torch.stack((imgs, recons), dim = 0)
+ imgs_and_recons = rearrange(imgs_and_recons, 'r b ... -> (b r) ...')
+
+ imgs_and_recons = imgs_and_recons.detach().cpu().float().clamp(0., 1.)
+ grid = make_grid(imgs_and_recons, nrow = 2, normalize = True, value_range = (0, 1))
+
+ logs['reconstructions'] = grid
+
+ save_image(grid, str(self.results_folder / f'{filename}.png'))
+
+ print(f'{steps}: saving to {str(self.results_folder)}')
+
+ # save model every so often
+
+ if not (steps % self.save_model_every):
+ state_dict = self.vae.state_dict()
+ model_path = str(self.results_folder / f'vae.{steps}.pt')
+ torch.save(state_dict, model_path)
+
+ ema_state_dict = self.ema_vae.state_dict()
+ model_path = str(self.results_folder / f'vae.{steps}.ema.pt')
+ torch.save(ema_state_dict, model_path)
+
+ print(f'{steps}: saving model to {str(self.results_folder)}')
+
+ self.steps += 1
+ return logs
+
+ def train(self, log_fn = noop):
+ device = next(self.vae.parameters()).device
+
+ while self.steps < self.num_train_steps:
+ logs = self.train_step()
+ log_fn(logs)
+
+ print('training complete')
diff --git a/docs/src/prior.md b/docs/src/prior.md
new file mode 100644
index 00000000..e4f4841f
--- /dev/null
+++ b/docs/src/prior.md
@@ -0,0 +1,183 @@
+# Diffusion Prior
+This readme serves as an introduction to the diffusion prior.
+
+## Intro
+
+A properly trained prior will allow you to translate between two embedding spaces. If you know *a priori* that two embeddings are connected in some way, then the ability to translate between them could be extremely helpful.
+
+### Motivation
+
+Before we dive into the model, let’s look at a quick example of where the model may be helpful.
+
+For demonstration purposes we will imagine that we wish to generate images from text using CLIP and a Decoder.
+
+> [CLIP](https://openai.com/blog/clip/) is a contrastive model that learns to maximize the cosine similarity between a given image and its caption; however, there is no guarantee that these embeddings lie in the same space. While the generated embeddings are ***close***, the image and text embeddings occupy two disjoint sets.
+
+```python
+# Load Models
+clip_model = clip.load("ViT-L/14")
+decoder = Decoder(checkpoint="best.pth") # A decoder trained on CLIP Image embeddings
+
+# Retrieve prompt from user and encode with CLIP
+prompt = "A corgi wearing sunglasses"
+tokenized_text = tokenize(prompt)
+text_embedding = clip_model.encode_text(tokenized_text)
+
+# Now, pass the text embedding to the decoder
+predicted_image = decoder.sample(text_embedding)
+```
+
+> **Question**: *Can you spot the issue here?*
+>
+> **Answer**: *We’re trying to generate an image from a text embedding!*
+
+Unfortunately, we run into the issue mentioned previously: the image embeddings and the text embeddings are not interchangeable! Now let's look at a better solution.
+
+```python
+# Load Models
+prior = Prior(checkpoint="prior.pth") # A prior trained to go from: text -> clip text emb -> clip img emb
+decoder = Decoder(checkpoint="decoder.pth") # A decoder trained on CLIP Image embeddings
+
+# Retrieve prompt from user and encode with a prior
+prompt = "A corgi wearing sunglasses"
+tokenized_text = tokenize(prompt)
+predicted_image_embedding = prior.sample(tokenized_text) # <-- now we get an embedding in the same space as images!
+
+# Now, pass the predicted image embedding to the decoder
+predicted_image = decoder.sample(predicted_image_embedding)
+```
+
+With the prior we are able to successfully generate embeddings *within* CLIP's image space! For this reason, the decoder will perform much better as it receives input that is much closer to its training data.
+
+> **You may be asking yourself the following question:**
+>
+> *"Why don't you just train the decoder on clip text embeddings instead of image embeddings?"*
+>
+> OpenAI covers this topic in their [DALLE-2 paper](https://arxiv.org/abs/2204.06125). The TL;DR is *"it doesn't work as well as decoders trained on image embeddings"*...also...it's just an example :smile:
+
+## Usage
+
+Utilizing a pre-trained prior is quite simple.
+
+### Loading Checkpoints
+```python
+import torch
+from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
+from dalle2_pytorch.trainer import DiffusionPriorTrainer
+
+def load_diffusion_model(dprior_path, device):
+
+ prior_network = DiffusionPriorNetwork(
+ dim=768,
+ depth=24,
+ dim_head=64,
+ heads=32,
+ normformer=True,
+ attn_dropout=5e-2,
+ ff_dropout=5e-2,
+ num_time_embeds=1,
+ num_image_embeds=1,
+ num_text_embeds=1,
+ num_timesteps=1000,
+ ff_mult=4
+ )
+
+ diffusion_prior = DiffusionPrior(
+ net=prior_network,
+ clip=OpenAIClipAdapter("ViT-L/14"),
+ image_embed_dim=768,
+ timesteps=1000,
+ cond_drop_prob=0.1,
+ loss_type="l2",
+ condition_on_text_encodings=True,
+
+ )
+
+ trainer = DiffusionPriorTrainer(
+ diffusion_prior=diffusion_prior,
+ lr=1.1e-4,
+ wd=6.02e-2,
+ max_grad_norm=0.5,
+ amp=False,
+ group_wd_params=True,
+ use_ema=True,
+ device=device,
+ accelerator=None,
+ )
+
+ trainer.load(dprior_path)
+
+ return trainer
+```
+
+Here we instantiate a model that matches the configuration it was trained with, and then load the weights (*just like any other PyTorch model!*).
+
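+A minimal usage sketch (the checkpoint path below is a placeholder, and `device` is whatever device you intend to run on):
+
+```python
+import torch
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+prior = load_diffusion_model("./checkpoints/prior.latest.pth", device) # hypothetical path; the returned trainer exposes sample() and embed_text()
+```
+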
+### Sampling
+Once we have a pre-trained model, generating embeddings is quite simple!
+```python
+# tokenize the text
+tokenized_text = clip.tokenize("")
+# predict an embedding
+predicted_embedding = prior.sample(tokenized_text, n_samples_per_batch=2, cond_scale=1.0)
+```
+
+The resulting tensor returned from `.sample()` is of the same shape as your training data along the non-batch dimension(s). For example, a prior trained on `ViT-L/14` embeddings will predict an embedding of shape (1, 768).
+
+> For CLIP priors, this is quite handy as it means that you can use `prior.sample(tokenized_text)` as a drop-in replacement for `clip.encode_text()`.
+
+**Some things to note:**
+* It is possible to specify the number of embeddings to sample from (the default suggested by OpenAI is `n=2`). Put simply, the idea is that you avoid getting unlucky with a bad embedding generation by creating two and selecting the one with the higher cosine similarity to the prompt (see the sketch below this list).
+* You may specify a higher conditioning scale than the default (`1.0`). It is unclear whether OpenAI uses a higher value for the prior specifically, or only on the decoder. Local testing has shown poor results with anything higher than `1.0` but *ymmv*.
+
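+The re-ranking that `n_samples_per_batch` performs internally boils down to something like the following sketch (a rough illustration only; `prior` is the loaded trainer and `tokenized_text` the tokenized prompt from above):
+
+```python
+import torch
+import torch.nn.functional as F
+
+# draw two candidate image embeddings for the same prompt
+candidates = torch.stack([prior.sample(tokenized_text, n_samples_per_batch=1) for _ in range(2)], dim=1) # (batch, 2, 768)
+
+# embed the prompt with CLIP (embed_text returns the embedding and the token encodings)
+text_embed, _ = prior.embed_text(tokenized_text)
+
+# keep the candidate with the higher cosine similarity to the prompt
+similarities = F.cosine_similarity(candidates, text_embed.unsqueeze(1), dim=-1) # (batch, 2)
+best_image_embed = candidates[torch.arange(candidates.shape[0]), similarities.argmax(dim=-1)]
+```
+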
+---
+
+## Training
+
+### Overview
+
+Training the prior is a relatively straightforward process thanks to the Trainer base class. The major step required of you is preparing a dataset in the format that EmbeddingReader expects. Having pre-computed embeddings massively increases training efficiency and is generally recommended, as you will likely benefit from having them on hand for other tasks as well. Once you have a dataset, you are ready to move on to configuration.
+
+## Dataset
+
+To train the prior, it is highly recommended to use precomputed embeddings for the images. To obtain these for a custom dataset, you can leverage [img2dataset](https://github.com/rom1504/img2dataset) to pull images from a list of URLs and [clip_retrieval](https://github.com/rom1504/clip-retrieval#clip-inference) for generating the actual embeddings that can be used in the prior's dataloader.
+
+## Configuration
+
+The configuration file allows you to easily track and reproduce experiments. It is a simple JSON file that specifies the architecture, dataset, and training parameters. For more information and specifics please see the configuration README.
+
+## Distributed Training
+
+If you would like to train in a distributed manner, we have opted to leverage Hugging Face's Accelerate library. Accelerate makes it extremely simple to distribute work across multiple GPUs and nodes. All that is required of you is to follow its simple CLI configuration tool ([more information here](https://huggingface.co/docs/accelerate/accelerator)).
+
+## Evaluation
+
+There are a variety of metrics available to you when training the prior. You can read a brief description of each in the table below (a short sketch of how the underlying cosine similarities are computed follows the table):
+| Metric | Description | Comments |
+| ----------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| Online Model Validation | The validation loss associated with your online model. | Ideally validation loss will be as low as possible. Using L2 loss, values as low as `0.1` and lower are possible after around 1 Billion samples seen. |
+| EMA Validation | This metric measures the validation loss associated with your EMA model. | This will likely lag behind your "online" model's validation loss, but should outperform in the long-term. |
+| Baseline Similarity | Baseline similarity refers to the similarity between your dataset's prompts and associated image embeddings. This will serve as a guide for your prior's performance in cosine similarity. | Generally `0.3` is considered a good cosine similarity for caption similarity. |
+| Similarity With Original Image      | This metric will measure the cosine similarity between your prior's predicted image embedding and the actual image that the caption was associated with. This is useful for determining whether your prior is generating embeddings with the right contents. | Values around `0.75`+ are obtainable. This metric should improve rapidly in the early stages of training and plateau with diminishing increases over time. If it takes hundreds of millions of samples to reach above `0.5`/`0.6` similarity, then you are likely suffering from some kind of training error or inefficiency (e.g. not using EMA). |
+| Difference From Baseline Similarity | Sometimes it's useful to visualize a metric in another light. This metric will show you how your prior's predicted image embeddings match up with the baseline similarity measured in your dataset. | This value should float around `0.0`, with some room for variation. After a billion samples seen, values are within `0.01` +/- of `0.0`. If this climbs too high (~>`0.02`), it may be a sign that your model is overfitting somehow. |
+| Similarity With Text                | This metric is your bread-and-butter cosine similarity between the predicted image embedding and the original caption given to the prior. Monitoring this metric will be one of your main focuses, and it is probably the second most important metric behind your loss. | As mentioned, this value should be close to the baseline similarity. We have observed an early rapid increase with diminishing returns as the prior learns to generate valid image embeddings. If this value increases too far beyond the baseline similarity, it could be an indication that your model is overfitting. |
+| Similarity With Unrelated Caption   | This metric will attempt to expose an overfit prior by feeding it arbitrary prompts (from your dataset) and then measuring the similarity of the predicted embedding with some other image. | Early on we found that a poorly trained/modeled prior could effectively fool CLIP into believing that the cosine similarity between two embeddings was high (when in fact the caption and image were completely unrelated). With this in mind, a low value is ideal; anything below `0.1` is probably safe. |
+
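+All of the similarity metrics above reduce to cosine similarities between CLIP embeddings. A minimal sketch with stand-in tensors (the variable names and shapes are illustrative, assuming `ViT-L/14` embeddings):
+
+```python
+import torch
+import torch.nn.functional as F
+
+# stand-ins for a batch of CLIP text embeddings, ground-truth image embeddings,
+# and prior-predicted image embeddings
+text_embed = torch.randn(8, 768)
+image_embed = torch.randn(8, 768)
+predicted_image_embed = torch.randn(8, 768)
+
+baseline_similarity = F.cosine_similarity(text_embed, image_embed, dim=-1).mean()                 # Baseline Similarity
+similarity_with_original = F.cosine_similarity(predicted_image_embed, image_embed, dim=-1).mean() # Similarity With Original Image
+similarity_with_text = F.cosine_similarity(predicted_image_embed, text_embed, dim=-1).mean()      # Similarity With Text
+difference_from_baseline = similarity_with_text - baseline_similarity                             # Difference From Baseline Similarity
+```
+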
+## Launching the script
+
+Now that you’ve done all the prep it’s time for the easy part! 🚀
+
+To actually launch the script, you will either use `accelerate launch train_diffusion_prior.py --config_path ` to launch with distributed training & Hugging Face Accelerate, or `python train_diffusion_prior.py` if you would like to train on your GPU/CPU without Accelerate.
+
+## Checkpointing
+
+Checkpoints will be saved to the directory specified in your configuration file.
+
+Additionally, a final checkpoint is saved before running the test split. This file will be saved to the same directory and titled “latest.pth”. This is to avoid problems where your `save_every` configuration does not overlap with the number of steps required to do a complete pass through the data.
+
+## Things To Keep In Mind
+
+The prior has not been trained for tasks other than the traditional CLIP embedding translation…at least not yet.
+
+As we finalize the replication of unCLIP, there will almost assuredly be experiments attempting to apply the prior network to other tasks.
+
+With that in mind, you are more or less a pioneer in embedding-translation if you are reading this and attempting something you don’t see documentation for!
diff --git a/docs/src/setup.py b/docs/src/setup.py
new file mode 100644
index 00000000..d48a7cb9
--- /dev/null
+++ b/docs/src/setup.py
@@ -0,0 +1,59 @@
+from setuptools import setup, find_packages
+exec(open('dalle2_pytorch/version.py').read())
+
+setup(
+ name = 'dalle2-pytorch',
+ packages = find_packages(exclude=[]),
+ include_package_data = True,
+ entry_points={
+ 'console_scripts': [
+ 'dalle2_pytorch = dalle2_pytorch.cli:main',
+ 'dream = dalle2_pytorch.cli:dream'
+ ],
+ },
+ version = __version__,
+ license='MIT',
+ description = 'DALL-E 2',
+ author = 'Phil Wang',
+ author_email = 'lucidrains@gmail.com',
+ long_description_content_type = 'text/markdown',
+ url = 'https://github.com/lucidrains/dalle2-pytorch',
+ keywords = [
+ 'artificial intelligence',
+ 'deep learning',
+ 'text to image'
+ ],
+ install_requires=[
+ 'accelerate',
+ 'click',
+ 'open-clip-torch>=2.0.0,<3.0.0',
+ 'clip-anytorch>=2.5.2',
+ 'coca-pytorch>=0.0.5',
+ 'ema-pytorch>=0.0.7',
+ 'einops>=0.7.0',
+ 'embedding-reader',
+ 'kornia>=0.5.4',
+ 'numpy',
+ 'packaging',
+ 'pillow',
+ 'pydantic>=2',
+ 'pytorch-warmup',
+ 'resize-right>=0.0.2',
+ 'rotary-embedding-torch',
+ 'torch>=1.10',
+ 'torchvision',
+ 'tqdm',
+ 'vector-quantize-pytorch',
+ 'x-clip>=0.4.4',
+ 'webdataset>=0.2.5',
+ 'fsspec>=2022.1.0',
+ 'torchmetrics[image]>=0.8.0'
+ ],
+ classifiers=[
+ 'Development Status :: 4 - Beta',
+ 'Intended Audience :: Developers',
+ 'Topic :: Scientific/Engineering :: Artificial Intelligence',
+ 'License :: OSI Approved :: MIT License',
+ 'Programming Language :: Python :: 3.6',
+ ],
+)
diff --git a/docs/src/train_decoder.py b/docs/src/train_decoder.py
new file mode 100644
index 00000000..249a093a
--- /dev/null
+++ b/docs/src/train_decoder.py
@@ -0,0 +1,651 @@
+from pathlib import Path
+from typing import List
+from datetime import timedelta
+
+from dalle2_pytorch.trainer import DecoderTrainer
+from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
+from dalle2_pytorch.trackers import Tracker
+from dalle2_pytorch.train_configs import DecoderConfig, TrainDecoderConfig
+from dalle2_pytorch.utils import Timer, print_ribbon
+from dalle2_pytorch.dalle2_pytorch import Decoder, resize_image_to
+from clip import tokenize
+
+import torchvision
+import torch
+from torch import nn
+from torchmetrics.image.fid import FrechetInceptionDistance
+from torchmetrics.image.inception import InceptionScore
+from torchmetrics.image.kid import KernelInceptionDistance
+from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
+from accelerate import Accelerator, DistributedDataParallelKwargs, InitProcessGroupKwargs
+from accelerate.utils import dataclasses as accelerate_dataclasses
+import webdataset as wds
+import click
+
+# constants
+
+TRAIN_CALC_LOSS_EVERY_ITERS = 10
+VALID_CALC_LOSS_EVERY_ITERS = 10
+
+# helper functions
+
+def exists(val):
+ return val is not None
+
+# main functions
+
+def create_dataloaders(
+ available_shards,
+ webdataset_base_url,
+ img_embeddings_url=None,
+ text_embeddings_url=None,
+ shard_width=6,
+ num_workers=4,
+ batch_size=32,
+ n_sample_images=6,
+ shuffle_train=True,
+ resample_train=False,
+ img_preproc = None,
+ index_width=4,
+ train_prop = 0.75,
+ val_prop = 0.15,
+ test_prop = 0.10,
+ seed = 0,
+ **kwargs
+):
+ """
+ Randomly splits the available shards into train, val, and test sets and returns a dataloader for each
+ """
+ assert train_prop + test_prop + val_prop == 1
+ num_train = round(train_prop*len(available_shards))
+ num_test = round(test_prop*len(available_shards))
+ num_val = len(available_shards) - num_train - num_test
+ assert num_train + num_test + num_val == len(available_shards), f"{num_train} + {num_test} + {num_val} = {num_train + num_test + num_val} != {len(available_shards)}"
+ train_split, test_split, val_split = torch.utils.data.random_split(available_shards, [num_train, num_test, num_val], generator=torch.Generator().manual_seed(seed))
+
+ # The shard number in the webdataset file names has a fixed width. We zero pad the shard numbers so they correspond to a filename.
+ train_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in train_split]
+ test_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in test_split]
+ val_urls = [webdataset_base_url.format(str(shard).zfill(shard_width)) for shard in val_split]
+
+ create_dataloader = lambda tar_urls, shuffle=False, resample=False, for_sampling=False: create_image_embedding_dataloader(
+ tar_url=tar_urls,
+ num_workers=num_workers,
+ batch_size=batch_size if not for_sampling else n_sample_images,
+ img_embeddings_url=img_embeddings_url,
+ text_embeddings_url=text_embeddings_url,
+ index_width=index_width,
+ shuffle_num = None,
+ extra_keys= ["txt"],
+ shuffle_shards = shuffle,
+ resample_shards = resample,
+ img_preproc=img_preproc,
+ handler=wds.handlers.warn_and_continue
+ )
+
+ train_dataloader = create_dataloader(train_urls, shuffle=shuffle_train, resample=resample_train)
+ train_sampling_dataloader = create_dataloader(train_urls, shuffle=False, for_sampling=True)
+ val_dataloader = create_dataloader(val_urls, shuffle=False)
+ test_dataloader = create_dataloader(test_urls, shuffle=False)
+ test_sampling_dataloader = create_dataloader(test_urls, shuffle=False, for_sampling=True)
+ return {
+ "train": train_dataloader,
+ "train_sampling": train_sampling_dataloader,
+ "val": val_dataloader,
+ "test": test_dataloader,
+ "test_sampling": test_sampling_dataloader
+ }
+
+def get_dataset_keys(dataloader):
+ """
+    It is sometimes necessary to get the keys the dataloader is returning. Since the dataset is buried in the dataloader, we need to unwrap it to recover the keys.
+ """
+ # If the dataloader is actually a WebLoader, we need to extract the real dataloader
+ if isinstance(dataloader, wds.WebLoader):
+ dataloader = dataloader.pipeline[0]
+ return dataloader.dataset.key_map
+
+def get_example_data(dataloader, device, n=5):
+ """
+ Samples the dataloader and returns a zipped list of examples
+ """
+ images = []
+ img_embeddings = []
+ text_embeddings = []
+ captions = []
+ for img, emb, txt in dataloader:
+ img_emb, text_emb = emb.get('img'), emb.get('text')
+ if img_emb is not None:
+ img_emb = img_emb.to(device=device, dtype=torch.float)
+ img_embeddings.extend(list(img_emb))
+ else:
+ # Then we add None img.shape[0] times
+ img_embeddings.extend([None]*img.shape[0])
+ if text_emb is not None:
+ text_emb = text_emb.to(device=device, dtype=torch.float)
+ text_embeddings.extend(list(text_emb))
+ else:
+ # Then we add None img.shape[0] times
+ text_embeddings.extend([None]*img.shape[0])
+ img = img.to(device=device, dtype=torch.float)
+ images.extend(list(img))
+ captions.extend(list(txt))
+ if len(images) >= n:
+ break
+ return list(zip(images[:n], img_embeddings[:n], text_embeddings[:n], captions[:n]))
+
+def generate_samples(trainer, example_data, clip=None, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend="", match_image_size=True):
+ """
+ Takes example data and generates images from the embeddings
+ Returns three lists: real images, generated images, and captions
+ """
+ real_images, img_embeddings, text_embeddings, txts = zip(*example_data)
+ sample_params = {}
+ if img_embeddings[0] is None:
+ # Generate image embeddings from clip
+ imgs_tensor = torch.stack(real_images)
+ assert clip is not None, "clip is None, but img_embeddings is None"
+        imgs_tensor = imgs_tensor.to(device=device)
+ img_embeddings, img_encoding = clip.embed_image(imgs_tensor)
+ sample_params["image_embed"] = img_embeddings
+ else:
+ # Then we are using precomputed image embeddings
+ img_embeddings = torch.stack(img_embeddings)
+ sample_params["image_embed"] = img_embeddings
+ if condition_on_text_encodings:
+ if text_embeddings[0] is None:
+ # Generate text embeddings from text
+ assert clip is not None, "clip is None, but text_embeddings is None"
+ tokenized_texts = tokenize(txts, truncate=True).to(device=device)
+ text_embed, text_encodings = clip.embed_text(tokenized_texts)
+ sample_params["text_encodings"] = text_encodings
+ else:
+ # Then we are using precomputed text embeddings
+ text_embeddings = torch.stack(text_embeddings)
+ sample_params["text_encodings"] = text_embeddings
+ sample_params["start_at_unet_number"] = start_unet
+ sample_params["stop_at_unet_number"] = end_unet
+ if start_unet > 1:
+ # If we are only training upsamplers
+ sample_params["image"] = torch.stack(real_images)
+ if device is not None:
+ sample_params["_device"] = device
+ samples = trainer.sample(**sample_params, _cast_deepspeed_precision=False) # At sampling time we don't want to cast to FP16
+ generated_images = list(samples)
+ captions = [text_prepend + txt for txt in txts]
+ if match_image_size:
+ generated_image_size = generated_images[0].shape[-1]
+ real_images = [resize_image_to(image, generated_image_size, clamp_range=(0, 1)) for image in real_images]
+ return real_images, generated_images, captions
+
+def generate_grid_samples(trainer, examples, clip=None, start_unet=1, end_unet=None, condition_on_text_encodings=False, cond_scale=1.0, device=None, text_prepend=""):
+ """
+ Generates samples and uses torchvision to put them in a side by side grid for easy viewing
+ """
+ real_images, generated_images, captions = generate_samples(trainer, examples, clip, start_unet, end_unet, condition_on_text_encodings, cond_scale, device, text_prepend)
+ grid_images = [torchvision.utils.make_grid([original_image, generated_image]) for original_image, generated_image in zip(real_images, generated_images)]
+ return grid_images, captions
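+# Minimal usage sketch (variable names assumed, mirroring the training loop further down):
+#   grids, captions = generate_grid_samples(trainer, examples, clip, 1, None, False, 1.0, device, "Train: ")
+#   tracker.log_images(grids, captions=captions, image_section="Train Samples", step=step())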
+
+def evaluate_trainer(trainer, dataloader, device, start_unet, end_unet, clip=None, condition_on_text_encodings=False, cond_scale=1.0, inference_device=None, n_evaluation_samples=1000, FID=None, IS=None, KID=None, LPIPS=None):
+ """
+ Computes evaluation metrics for the decoder
+ """
+ metrics = {}
+ # Prepare the data
+ examples = get_example_data(dataloader, device, n_evaluation_samples)
+ if len(examples) == 0:
+ print("No data to evaluate. Check that your dataloader has shards.")
+ return metrics
+ real_images, generated_images, captions = generate_samples(trainer, examples, clip, start_unet, end_unet, condition_on_text_encodings, cond_scale, inference_device)
+ real_images = torch.stack(real_images).to(device=device, dtype=torch.float)
+ generated_images = torch.stack(generated_images).to(device=device, dtype=torch.float)
+ # Convert from [0, 1] to [0, 255] and from torch.float to torch.uint8
+ int_real_images = real_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
+ int_generated_images = generated_images.mul(255).add(0.5).clamp(0, 255).type(torch.uint8)
+
+ def null_sync(t, *args, **kwargs):
+ return [t]
+
+ if exists(FID):
+ fid = FrechetInceptionDistance(**FID, dist_sync_fn=null_sync)
+ fid.to(device=device)
+ fid.update(int_real_images, real=True)
+ fid.update(int_generated_images, real=False)
+ metrics["FID"] = fid.compute().item()
+ if exists(IS):
+ inception = InceptionScore(**IS, dist_sync_fn=null_sync)
+ inception.to(device=device)
+ inception.update(int_real_images)
+ is_mean, is_std = inception.compute()
+ metrics["IS_mean"] = is_mean.item()
+ metrics["IS_std"] = is_std.item()
+ if exists(KID):
+ kernel_inception = KernelInceptionDistance(**KID, dist_sync_fn=null_sync)
+ kernel_inception.to(device=device)
+ kernel_inception.update(int_real_images, real=True)
+ kernel_inception.update(int_generated_images, real=False)
+ kid_mean, kid_std = kernel_inception.compute()
+ metrics["KID_mean"] = kid_mean.item()
+ metrics["KID_std"] = kid_std.item()
+ if exists(LPIPS):
+ # Convert from [0, 1] to [-1, 1]
+ renorm_real_images = real_images.mul(2).sub(1).clamp(-1,1)
+ renorm_generated_images = generated_images.mul(2).sub(1).clamp(-1,1)
+ lpips = LearnedPerceptualImagePatchSimilarity(**LPIPS, dist_sync_fn=null_sync)
+ lpips.to(device=device)
+ lpips.update(renorm_real_images, renorm_generated_images)
+ metrics["LPIPS"] = lpips.compute().item()
+
+ if trainer.accelerator.num_processes > 1:
+ # Then we should sync the metrics
+ metrics_order = sorted(metrics.keys())
+ metrics_tensor = torch.zeros(1, len(metrics), device=device, dtype=torch.float)
+ for i, metric_name in enumerate(metrics_order):
+ metrics_tensor[0, i] = metrics[metric_name]
+ metrics_tensor = trainer.accelerator.gather(metrics_tensor)
+ metrics_tensor = metrics_tensor.mean(dim=0)
+ for i, metric_name in enumerate(metrics_order):
+ metrics[metric_name] = metrics_tensor[i].item()
+ return metrics
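+# Illustrative note: FID / IS / KID / LPIPS are expected to be dicts of constructor kwargs forwarded to the
+# corresponding metric classes above (presumably torchmetrics), e.g. FID={"feature": 64} or
+# LPIPS={"net_type": "vgg"} as assumed example configs; leaving a metric as None skips it.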
+
+def save_trainer(tracker: Tracker, trainer: DecoderTrainer, epoch: int, sample: int, next_task: str, validation_losses: List[float], samples_seen: int, is_latest=True, is_best=False):
+ """
+ Saves the model with an appropriate method depending on the tracker
+ """
+ tracker.save(trainer, is_best=is_best, is_latest=is_latest, epoch=epoch, sample=sample, next_task=next_task, validation_losses=validation_losses, samples_seen=samples_seen)
+
+def recall_trainer(tracker: Tracker, trainer: DecoderTrainer):
+ """
+ Loads the model with an appropriate method depending on the tracker
+ """
+ trainer.accelerator.print(print_ribbon(f"Loading model from {type(tracker.loader).__name__}"))
+ state_dict = tracker.recall()
+ trainer.load_state_dict(state_dict, only_model=False, strict=True)
+ return state_dict.get("epoch", 0), state_dict.get("validation_losses", []), state_dict.get("next_task", "train"), state_dict.get("sample", 0), state_dict.get("samples_seen", 0)
+
+def train(
+ dataloaders,
+ decoder: Decoder,
+ accelerator: Accelerator,
+ tracker: Tracker,
+ inference_device,
+ clip=None,
+ evaluate_config=None,
+ epoch_samples = None, # If the training dataset is resampling, we have to manually stop an epoch
+ validation_samples = None,
+ save_immediately=False,
+ epochs = 20,
+ n_sample_images = 5,
+ save_every_n_samples = 100000,
+ unet_training_mask=None,
+ condition_on_text_encodings=False,
+ cond_scale=1.0,
+ **kwargs
+):
+ """
+ Trains a decoder on a dataset.
+ """
+ is_master = accelerator.process_index == 0
+
+ if not exists(unet_training_mask):
+ # Then the unet mask should be true for all unets in the decoder
+ unet_training_mask = [True] * len(decoder.unets)
+ assert len(unet_training_mask) == len(decoder.unets), f"The unet training mask should be the same length as the number of unets in the decoder. Got {len(unet_training_mask)} and {len(decoder.unets)}"
+ trainable_unet_numbers = [i+1 for i, trainable in enumerate(unet_training_mask) if trainable]
+ first_trainable_unet = trainable_unet_numbers[0]
+ last_trainable_unet = trainable_unet_numbers[-1]
+ def move_unets(unet_training_mask):
+ for i in range(len(decoder.unets)):
+ if not unet_training_mask[i]:
+ # Replace the unet from the module list with a nn.Identity(). This training script never uses unets that aren't being trained so this is fine.
+ decoder.unets[i] = nn.Identity().to(inference_device)
+ # Remove non-trainable unets
+ move_unets(unet_training_mask)
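+ # Example (illustrative): with two unets, unet_training_mask=[True, False] trains only the first unet;
+ # the second is swapped for nn.Identity so it takes no memory or compute in this script.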
+
+ trainer = DecoderTrainer(
+ decoder=decoder,
+ accelerator=accelerator,
+ dataloaders=dataloaders,
+ **kwargs
+ )
+
+ # Set up starting model and parameters based on a recalled state dict
+ start_epoch = 0
+ validation_losses = []
+ next_task = 'train'
+ sample = 0
+ samples_seen = 0
+ val_sample = 0
+ step = lambda: int(trainer.num_steps_taken(unet_number=first_trainable_unet))
+
+ if tracker.can_recall:
+ start_epoch, validation_losses, next_task, recalled_sample, samples_seen = recall_trainer(tracker, trainer)
+ if next_task == 'train':
+ sample = recalled_sample
+ if next_task == 'val':
+ val_sample = recalled_sample
+ accelerator.print(f"Loaded model from {type(tracker.loader).__name__} on epoch {start_epoch} having seen {samples_seen} samples with minimum validation loss {min(validation_losses) if len(validation_losses) > 0 else 'N/A'}")
+ accelerator.print(f"Starting training from task {next_task} at sample {sample} and validation sample {val_sample}")
+ trainer.to(device=inference_device)
+
+ accelerator.print(print_ribbon("Generating Example Data", repeat=40))
+ accelerator.print("This can take a while to load the shard lists...")
+ if is_master:
+ train_example_data = get_example_data(dataloaders["train_sampling"], inference_device, n_sample_images)
+ accelerator.print("Generated training examples")
+ test_example_data = get_example_data(dataloaders["test_sampling"], inference_device, n_sample_images)
+ accelerator.print("Generated testing examples")
+
+ send_to_device = lambda arr: [x.to(device=inference_device, dtype=torch.float) for x in arr]
+
+ sample_length_tensor = torch.zeros(1, dtype=torch.int, device=inference_device)
+ unet_losses_tensor = torch.zeros(TRAIN_CALC_LOSS_EVERY_ITERS, trainer.num_unets, dtype=torch.float, device=inference_device)
+ for epoch in range(start_epoch, epochs):
+ accelerator.print(print_ribbon(f"Starting epoch {epoch}", repeat=40))
+
+ timer = Timer()
+ last_sample = sample
+ last_snapshot = sample
+
+ if next_task == 'train':
+ for i, (img, emb, txt) in enumerate(dataloaders["train"]):
+ # We want to count the total number of samples across all processes
+ sample_length_tensor[0] = len(img)
+ all_samples = accelerator.gather(sample_length_tensor) # TODO: accelerator.reduce is broken when this was written. If it is fixed replace this.
+ total_samples = all_samples.sum().item()
+ sample += total_samples
+ samples_seen += total_samples
+ img_emb = emb.get('img')
+ has_img_embedding = img_emb is not None
+ if has_img_embedding:
+ img_emb, = send_to_device((img_emb,))
+ text_emb = emb.get('text')
+ has_text_embedding = text_emb is not None
+ if has_text_embedding:
+ text_emb, = send_to_device((text_emb,))
+ img, = send_to_device((img,))
+
+ trainer.train()
+ for unet in range(1, trainer.num_unets+1):
+ # Check if this is a unet we are training
+ if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
+ continue
+
+ forward_params = {}
+ if has_img_embedding:
+ forward_params['image_embed'] = img_emb
+ else:
+ # Forward pass automatically generates embedding
+ assert clip is not None
+ img_embed, img_encoding = clip.embed_image(img)
+ forward_params['image_embed'] = img_embed
+ if condition_on_text_encodings:
+ if has_text_embedding:
+ forward_params['text_encodings'] = text_emb
+ else:
+ # Then we need to pass the text instead
+ assert clip is not None
+ tokenized_texts = tokenize(txt, truncate=True).to(inference_device)
+ assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
+ text_embed, text_encodings = clip.embed_text(tokenized_texts)
+ forward_params['text_encodings'] = text_encodings
+ loss = trainer.forward(img, **forward_params, unet_number=unet, _device=inference_device)
+ trainer.update(unet_number=unet)
+ unet_losses_tensor[i % TRAIN_CALC_LOSS_EVERY_ITERS, unet-1] = loss
+
+ samples_per_sec = (sample - last_sample) / timer.elapsed()
+ timer.reset()
+ last_sample = sample
+
+ if i % TRAIN_CALC_LOSS_EVERY_ITERS == 0:
+ # We want to average losses across all processes
+ unet_all_losses = accelerator.gather(unet_losses_tensor)
+ mask = unet_all_losses != 0
+ unet_average_loss = (unet_all_losses * mask).sum(dim=0) / mask.sum(dim=0)
+ loss_map = { f"Unet {index} Training Loss": loss.item() for index, loss in enumerate(unet_average_loss) if unet_training_mask[index] }
+
+ # gather decay rate on each UNet
+ ema_decay_list = {f"Unet {index} EMA Decay": ema_unet.get_current_decay() for index, ema_unet in enumerate(trainer.ema_unets) if unet_training_mask[index]}
+
+ log_data = {
+ "Epoch": epoch,
+ "Sample": sample,
+ "Step": i,
+ "Samples per second": samples_per_sec,
+ "Samples Seen": samples_seen,
+ **ema_decay_list,
+ **loss_map
+ }
+
+ if is_master:
+ tracker.log(log_data, step=step())
+
+ if is_master and (last_snapshot + save_every_n_samples < sample or (save_immediately and i == 0)): # This will miss by some amount every time, but it's not a big deal... I hope
+ # It is difficult to gather this kind of info on the accelerator, so we have to do it on the master
+ print("Saving snapshot")
+ last_snapshot = sample
+ # We need to know where the model should be saved
+ save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen)
+ if exists(n_sample_images) and n_sample_images > 0:
+ trainer.eval()
+ train_images, train_captions = generate_grid_samples(trainer, train_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
+ tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
+
+ if epoch_samples is not None and sample >= epoch_samples:
+ break
+ next_task = 'val'
+ sample = 0
+
+ all_average_val_losses = None
+ if next_task == 'val':
+ trainer.eval()
+ accelerator.print(print_ribbon(f"Starting Validation {epoch}", repeat=40))
+ last_val_sample = val_sample
+ val_sample_length_tensor = torch.zeros(1, dtype=torch.int, device=inference_device)
+ average_val_loss_tensor = torch.zeros(1, trainer.num_unets, dtype=torch.float, device=inference_device)
+ timer = Timer()
+ accelerator.wait_for_everyone()
+ i = 0
+ for i, (img, emb, txt) in enumerate(dataloaders['val']): # Use the accelerate prepared loader
+ val_sample_length_tensor[0] = len(img)
+ all_samples = accelerator.gather(val_sample_length_tensor)
+ total_samples = all_samples.sum().item()
+ val_sample += total_samples
+ img_emb = emb.get('img')
+ has_img_embedding = img_emb is not None
+ if has_img_embedding:
+ img_emb, = send_to_device((img_emb,))
+ text_emb = emb.get('text')
+ has_text_embedding = text_emb is not None
+ if has_text_embedding:
+ text_emb, = send_to_device((text_emb,))
+ img, = send_to_device((img,))
+
+ for unet in range(1, len(decoder.unets)+1):
+ if not unet_training_mask[unet-1]: # Unet index is the unet number - 1
+ # No need to evaluate an unchanging unet
+ continue
+
+ forward_params = {}
+ if has_img_embedding:
+ forward_params['image_embed'] = img_emb.float()
+ else:
+ # Forward pass automatically generates embedding
+ assert clip is not None
+ img_embed, img_encoding = clip.embed_image(img)
+ forward_params['image_embed'] = img_embed
+ if condition_on_text_encodings:
+ if has_text_embedding:
+ forward_params['text_encodings'] = text_emb.float()
+ else:
+ # Then we need to pass the text instead
+ assert clip is not None
+ tokenized_texts = tokenize(txt, truncate=True).to(device=inference_device)
+ assert tokenized_texts.shape[0] == len(img), f"The number of texts ({tokenized_texts.shape[0]}) should be the same as the number of images ({len(img)})"
+ text_embed, text_encodings = clip.embed_text(tokenized_texts)
+ forward_params['text_encodings'] = text_encodings
+ loss = trainer.forward(img.float(), **forward_params, unet_number=unet, _device=inference_device)
+ average_val_loss_tensor[0, unet-1] += loss
+
+ if i % VALID_CALC_LOSS_EVERY_ITERS == 0:
+ samples_per_sec = (val_sample - last_val_sample) / timer.elapsed()
+ timer.reset()
+ last_val_sample = val_sample
+ accelerator.print(f"Epoch {epoch}/{epochs} Val Step {i} - Sample {val_sample} - {samples_per_sec:.2f} samples/sec")
+ accelerator.print(f"Loss: {(average_val_loss_tensor / (i+1))}")
+ accelerator.print("")
+
+ if validation_samples is not None and val_sample >= validation_samples:
+ break
+ print(f"Rank {accelerator.state.process_index} finished validation after {i} steps")
+ accelerator.wait_for_everyone()
+ average_val_loss_tensor /= i+1
+ # Gather all the average loss tensors
+ all_average_val_losses = accelerator.gather(average_val_loss_tensor)
+ if is_master:
+ unet_average_val_loss = all_average_val_losses.mean(dim=0)
+ val_loss_map = { f"Unet {index} Validation Loss": loss.item() for index, loss in enumerate(unet_average_val_loss) if loss != 0 }
+ tracker.log(val_loss_map, step=step())
+ next_task = 'eval'
+
+ if next_task == 'eval':
+ if exists(evaluate_config):
+ accelerator.print(print_ribbon(f"Starting Evaluation {epoch}", repeat=40))
+ evaluation = evaluate_trainer(trainer, dataloaders["val"], inference_device, first_trainable_unet, last_trainable_unet, clip=clip, inference_device=inference_device, **evaluate_config.model_dump(), condition_on_text_encodings=condition_on_text_encodings, cond_scale=cond_scale)
+ if is_master:
+ tracker.log(evaluation, step=step())
+ next_task = 'sample'
+ val_sample = 0
+
+ if next_task == 'sample':
+ if is_master:
+ # Generate examples and save the model if we are the master
+ # Generate sample images
+ print(print_ribbon(f"Sampling Set {epoch}", repeat=40))
+ test_images, test_captions = generate_grid_samples(trainer, test_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Test: ")
+ train_images, train_captions = generate_grid_samples(trainer, train_example_data, clip, first_trainable_unet, last_trainable_unet, condition_on_text_encodings, cond_scale, inference_device, "Train: ")
+ tracker.log_images(test_images, captions=test_captions, image_section="Test Samples", step=step())
+ tracker.log_images(train_images, captions=train_captions, image_section="Train Samples", step=step())
+
+ print(print_ribbon(f"Starting Saving {epoch}", repeat=40))
+ is_best = False
+ if all_average_val_losses is not None:
+ average_loss = all_average_val_losses.mean(dim=0).sum() / sum(unet_training_mask)
+ if len(validation_losses) == 0 or average_loss < min(validation_losses):
+ is_best = True
+ validation_losses.append(average_loss)
+ save_trainer(tracker, trainer, epoch, sample, next_task, validation_losses, samples_seen, is_best=is_best)
+ next_task = 'train'
+
+def create_tracker(accelerator: Accelerator, config: TrainDecoderConfig, config_path: str, dummy: bool = False) -> Tracker:
+ tracker_config = config.tracker
+ accelerator_config = {
+ "Distributed": accelerator.distributed_type != accelerate_dataclasses.DistributedType.NO,
+ "DistributedType": accelerator.distributed_type,
+ "NumProcesses": accelerator.num_processes,
+ "MixedPrecision": accelerator.mixed_precision
+ }
+ accelerator.wait_for_everyone() # If nodes arrive at this point at different times they might try to autoresume the current run which makes no sense and will cause errors
+ tracker: Tracker = tracker_config.create(config, accelerator_config, dummy_mode=dummy)
+ tracker.save_config(config_path, config_name='decoder_config.json')
+ tracker.add_save_metadata(state_dict_key='config', metadata=config.model_dump())
+ return tracker
+
+def initialize_training(config: TrainDecoderConfig, config_path):
+ # Make sure if we are not loading, distributed models are initialized to the same values
+ torch.manual_seed(config.seed)
+
+ # Set up accelerator for configurable distributed training
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=config.train.find_unused_parameters, static_graph=config.train.static_graph)
+ init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=60*60))
+ accelerator = Accelerator(kwargs_handlers=[ddp_kwargs, init_kwargs])
+
+ if accelerator.num_processes > 1:
+ # We are using distributed training and want to immediately ensure all can connect
+ accelerator.print("Waiting for all processes to connect...")
+ accelerator.wait_for_everyone()
+ accelerator.print("All processes online and connected")
+
+ # If we are in deepspeed fp16 mode, we must ensure learned variance is off
+ if accelerator.mixed_precision == "fp16" and accelerator.distributed_type == accelerate_dataclasses.DistributedType.DEEPSPEED and config.decoder.learned_variance:
+ raise ValueError("DeepSpeed fp16 mode does not support learned variance")
+
+ # Set up data
+ all_shards = list(range(config.data.start_shard, config.data.end_shard + 1))
+ world_size = accelerator.num_processes
+ rank = accelerator.process_index
+ shards_per_process = len(all_shards) // world_size
+ assert shards_per_process > 0, "Not enough shards to split evenly"
+ my_shards = all_shards[rank * shards_per_process: (rank + 1) * shards_per_process]
+
+ dataloaders = create_dataloaders(
+ available_shards=my_shards,
+ img_preproc = config.data.img_preproc,
+ train_prop = config.data.splits.train,
+ val_prop = config.data.splits.val,
+ test_prop = config.data.splits.test,
+ n_sample_images=config.train.n_sample_images,
+ **config.data.model_dump(),
+ rank = rank,
+ seed = config.seed,
+ )
+
+ # If clip is in the model, we need to remove it for compatibility with deepspeed
+ clip = None
+ if config.decoder.clip is not None:
+ clip = config.decoder.clip.create() # Of course we keep it to use it during training, just not in the decoder as that causes issues
+ config.decoder.clip = None
+ # Create the decoder model and print basic info
+ decoder = config.decoder.create()
+ get_num_parameters = lambda model, only_training=False: sum(p.numel() for p in model.parameters() if (p.requires_grad or not only_training))
+
+ # Create and initialize the tracker if we are the master
+ tracker = create_tracker(accelerator, config, config_path, dummy = rank!=0)
+
+ has_img_embeddings = config.data.img_embeddings_url is not None
+ has_text_embeddings = config.data.text_embeddings_url is not None
+ conditioning_on_text = any([unet.cond_on_text_encodings for unet in config.decoder.unets])
+
+ has_clip_model = clip is not None
+ data_source_string = ""
+
+ if has_img_embeddings:
+ data_source_string += "precomputed image embeddings"
+ elif has_clip_model:
+ data_source_string += "clip image embeddings generation"
+ else:
+ raise ValueError("No image embeddings source specified")
+ if conditioning_on_text:
+ if has_text_embeddings:
+ data_source_string += " and precomputed text embeddings"
+ elif has_clip_model:
+ data_source_string += " and clip text encoding generation"
+ else:
+ raise ValueError("No text embeddings source specified")
+
+ accelerator.print(print_ribbon("Loaded Config", repeat=40))
+ accelerator.print(f"Running training with {accelerator.num_processes} processes and {accelerator.distributed_type} distributed training")
+ accelerator.print(f"Training using {data_source_string}. {'conditioned on text' if conditioning_on_text else 'not conditioned on text'}")
+ accelerator.print(f"Number of parameters: {get_num_parameters(decoder)} total; {get_num_parameters(decoder, only_training=True)} training")
+ for i, unet in enumerate(decoder.unets):
+ accelerator.print(f"Unet {i} has {get_num_parameters(unet)} total; {get_num_parameters(unet, only_training=True)} training")
+
+ train(dataloaders, decoder, accelerator,
+ clip=clip,
+ tracker=tracker,
+ inference_device=accelerator.device,
+ evaluate_config=config.evaluate,
+ condition_on_text_encodings=conditioning_on_text,
+ **config.train.model_dump(),
+ )
+
+# Create a simple click command line interface to load the config and start the training
+@click.command()
+@click.option("--config_file", default="./train_decoder_config.json", help="Path to config file")
+def main(config_file):
+ config_file_path = Path(config_file)
+ config = TrainDecoderConfig.from_json_path(str(config_file_path))
+ initialize_training(config, config_path=config_file_path)
+
+if __name__ == "__main__":
+ main()
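+# Example invocation (script location and config path are illustrative; point it at your own config):
+#   python train_decoder.py --config_file ./train_decoder_config.json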
diff --git a/docs/src/train_diffusion_prior.py b/docs/src/train_diffusion_prior.py
new file mode 100644
index 00000000..0887956c
--- /dev/null
+++ b/docs/src/train_diffusion_prior.py
@@ -0,0 +1,770 @@
+import click
+import torch
+
+from torch import nn
+from typing import List
+from accelerate import Accelerator
+from accelerate.utils import set_seed
+from torch.utils.data import DataLoader
+from embedding_reader import EmbeddingReader
+from accelerate.utils import dataclasses as accelerate_dataclasses
+
+from dalle2_pytorch.utils import Timer
+from dalle2_pytorch.trackers import Tracker
+from dalle2_pytorch import DiffusionPriorTrainer
+from dalle2_pytorch.dataloaders import get_reader, make_splits
+from dalle2_pytorch.train_configs import (
+ DiffusionPriorConfig,
+ DiffusionPriorTrainConfig,
+ TrainDiffusionPriorConfig,
+)
+
+
+# helpers
+
+
+cos = nn.CosineSimilarity(dim=1, eps=1e-6)
+
+
+def exists(val):
+ return val is not None
+
+
+def all_between(values: list, lower_bound, upper_bound):
+ for value in values:
+ if value < lower_bound or value > upper_bound:
+ return False
+
+ return True
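+# e.g. (illustrative) all_between([1, 5, 10], lower_bound=0, upper_bound=10) -> True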
+
+
+def make_model(
+ prior_config: DiffusionPriorConfig,
+ train_config: DiffusionPriorTrainConfig,
+ device: str = None,
+ accelerator: Accelerator = None,
+):
+ # create model from config
+ diffusion_prior = prior_config.create()
+
+ # instantiate the trainer
+ trainer = DiffusionPriorTrainer(
+ diffusion_prior=diffusion_prior,
+ lr=train_config.lr,
+ wd=train_config.wd,
+ max_grad_norm=train_config.max_grad_norm,
+ amp=train_config.amp,
+ use_ema=train_config.use_ema,
+ device=device,
+ accelerator=accelerator,
+ warmup_steps=train_config.warmup_steps,
+ )
+
+ return trainer
+
+
+def create_tracker(
+ accelerator: Accelerator,
+ config: TrainDiffusionPriorConfig,
+ config_path: str,
+ dummy: bool = False,
+) -> Tracker:
+ tracker_config = config.tracker
+
+ accelerator_config = {
+ "Distributed": accelerator.distributed_type
+ != accelerate_dataclasses.DistributedType.NO,
+ "DistributedType": accelerator.distributed_type,
+ "NumProcesses": accelerator.num_processes,
+ "MixedPrecision": accelerator.mixed_precision,
+ }
+
+ tracker: Tracker = tracker_config.create(
+ config, accelerator_config, dummy_mode=dummy
+ )
+
+ tracker.save_config(config_path, config_name="prior_config.json")
+
+ return tracker
+
+
+def pad_gather_reduce(trainer: DiffusionPriorTrainer, x, method="mean"):
+ """
+ pad a value or tensor across all processes and gather
+
+ params:
+ - trainer: a trainer that carries an accelerator object
+ - x: a number or torch tensor to reduce
+ - method: "mean", "sum", "max", "min"
+
+ return:
+ - the reduced tensor after masking out 0's
+ - None if the gather resulted in an empty tensor
+ """
+
+ assert method in [
+ "mean",
+ "sum",
+ "max",
+ "min",
+ ], "This function has limited capabilities [sum, mean, max, min]"
+ assert x is not None, "Cannot reduce a None type object"
+
+ # wait for everyone to arrive here before gathering
+
+ if type(x) is not torch.Tensor:
+ x = torch.tensor([x])
+
+ # verify that the tensor is on the proper device
+ x = x.to(trainer.device)
+
+ # pad across processes
+ padded_x = trainer.accelerator.pad_across_processes(x, dim=0)
+
+ # gather across all processes
+ gathered_x = trainer.accelerator.gather(padded_x)
+
+ # mask out zeros
+ masked_x = gathered_x[gathered_x != 0]
+
+ # if the tensor is empty, warn and return None
+ if len(masked_x) == 0:
+ click.secho(
+ f"The call to this method resulted in an empty tensor after masking out zeros. The gathered tensor was this: {gathered_x} and the original value passed was: {x}.",
+ fg="red",
+ )
+ return None
+
+ if method == "mean":
+ return torch.mean(masked_x)
+ elif method == "sum":
+ return torch.sum(masked_x)
+ elif method == "max":
+ return torch.max(masked_x)
+ elif method == "min":
+ return torch.min(masked_x)
+
+
+def save_trainer(
+ tracker: Tracker,
+ trainer: DiffusionPriorTrainer,
+ is_latest: bool,
+ is_best: bool,
+ epoch: int,
+ samples_seen: int,
+ best_validation_loss: float,
+):
+ """
+ Saves the model with an appropriate method depending on the tracker
+ """
+ trainer.accelerator.wait_for_everyone()
+
+ if trainer.accelerator.is_main_process:
+ click.secho(
+ f"RANK:{trainer.accelerator.process_index} | Saving Model | Best={is_best} | Latest={is_latest}",
+ fg="magenta",
+ )
+
+ tracker.save(
+ trainer=trainer,
+ is_best=is_best,
+ is_latest=is_latest,
+ epoch=int(epoch),
+ samples_seen=int(samples_seen),
+ best_validation_loss=best_validation_loss,
+ )
+
+
+def recall_trainer(tracker: Tracker, trainer: DiffusionPriorTrainer):
+ """
+ Loads the model with an appropriate method depending on the tracker
+ """
+
+ if trainer.accelerator.is_main_process:
+ click.secho(f"Loading model from {type(tracker.loader).__name__}", fg="yellow")
+
+ state_dict = tracker.recall()
+
+ trainer.load(state_dict, strict=True)
+
+ return (
+ int(state_dict.get("epoch", 0)),
+ state_dict.get("best_validation_loss", 0),
+ int(state_dict.get("samples_seen", 0)),
+ )
+
+
+# eval functions
+
+
+def report_validation_loss(
+ trainer: DiffusionPriorTrainer,
+ dataloader: DataLoader,
+ text_conditioned: bool,
+ use_ema: bool,
+ tracker: Tracker,
+ split: str,
+ tracker_folder: str,
+ loss_type: str,
+):
+ """
+ Compute the validation loss on a given subset of data.
+ """
+
+ if trainer.accelerator.is_main_process:
+ click.secho(
+ f"Measuring performance on {use_ema}-{split} split",
+ fg="green",
+ blink=True,
+ )
+
+ total_loss = torch.zeros(1, dtype=torch.float, device=trainer.device)
+
+ for image_embeddings, text_data in dataloader:
+ image_embeddings = image_embeddings.to(trainer.device)
+ text_data = text_data.to(trainer.device)
+
+ input_args = dict(image_embed=image_embeddings)
+
+ if text_conditioned:
+ input_args = dict(**input_args, text=text_data)
+ else:
+ input_args = dict(**input_args, text_embed=text_data)
+
+ if use_ema:
+ loss = trainer.ema_diffusion_prior(**input_args)
+ else:
+ loss = trainer(**input_args)
+
+ total_loss += loss
+
+ # compute the average loss across all processes
+
+ avg_loss = pad_gather_reduce(trainer, total_loss, method="mean")
+ stats = {f"{tracker_folder}/{loss_type}-loss": avg_loss}
+
+ # print and log results on main process
+ tracker.log(stats, step=trainer.step.item() + 1)
+
+ return avg_loss
+
+
+def report_cosine_sims(
+ trainer: DiffusionPriorTrainer,
+ dataloader: DataLoader,
+ text_conditioned: bool,
+ tracker: Tracker,
+ split: str,
+ timesteps: int,
+ tracker_folder: str,
+):
+ trainer.eval()
+ if trainer.accelerator.is_main_process:
+ click.secho(
+ f"Measuring Cosine-Similarity on {split} split with {timesteps} timesteps",
+ fg="green",
+ blink=True,
+ )
+
+ for test_image_embeddings, text_data in dataloader:
+ test_image_embeddings = test_image_embeddings.to(trainer.device)
+ text_data = text_data.to(trainer.device)
+
+ # we are text conditioned, we produce an embedding from the tokenized text
+ if text_conditioned:
+ text_embedding, text_encodings = trainer.embed_text(text_data)
+ text_cond = dict(text_embed=text_embedding, text_encodings=text_encodings)
+ else:
+ text_embedding = text_data
+ text_cond = dict(text_embed=text_embedding)
+
+ # make a copy of the text embeddings for shuffling
+ text_embed_shuffled = text_embedding.clone()
+
+ # roll the text to simulate "unrelated" captions
+ rolled_idx = torch.roll(torch.arange(text_embedding.shape[0]), 1)
+ text_embed_shuffled = text_embed_shuffled[rolled_idx]
+ text_embed_shuffled = text_embed_shuffled / text_embed_shuffled.norm(
+ dim=1, keepdim=True
+ )
+
+ if text_conditioned:
+ text_encodings_shuffled = text_encodings[rolled_idx]
+ else:
+ text_encodings_shuffled = None
+
+ text_cond_shuffled = dict(
+ text_embed=text_embed_shuffled, text_encodings=text_encodings_shuffled
+ )
+
+ # prepare the text embedding
+ text_embed = text_embedding / text_embedding.norm(dim=1, keepdim=True)
+
+ # prepare image embeddings
+ test_image_embeddings = test_image_embeddings / test_image_embeddings.norm(
+ dim=1, keepdim=True
+ )
+
+ # predict on the unshuffled text embeddings
+ predicted_image_embeddings = trainer.p_sample_loop(
+ test_image_embeddings.shape,
+ text_cond,
+ timesteps=timesteps,
+ )
+
+ predicted_image_embeddings = (
+ predicted_image_embeddings
+ / predicted_image_embeddings.norm(dim=1, keepdim=True)
+ )
+
+ # predict on the shuffled embeddings
+ predicted_unrelated_embeddings = trainer.p_sample_loop(
+ test_image_embeddings.shape,
+ text_cond_shuffled,
+ timesteps=timesteps,
+ )
+
+ predicted_unrelated_embeddings = (
+ predicted_unrelated_embeddings
+ / predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
+ )
+
+ # calculate similarities
+ orig_sim = pad_gather_reduce(
+ trainer, cos(text_embed, test_image_embeddings), method="mean"
+ )
+ pred_sim = pad_gather_reduce(
+ trainer, cos(text_embed, predicted_image_embeddings), method="mean"
+ )
+ unrel_sim = pad_gather_reduce(
+ trainer, cos(text_embed, predicted_unrelated_embeddings), method="mean"
+ )
+ pred_img_sim = pad_gather_reduce(
+ trainer,
+ cos(test_image_embeddings, predicted_image_embeddings),
+ method="mean",
+ )
+
+ stats = {
+ f"{tracker_folder}/baseline similarity [steps={timesteps}]": orig_sim,
+ f"{tracker_folder}/similarity with text [steps={timesteps}]": pred_sim,
+ f"{tracker_folder}/similarity with original image [steps={timesteps}]": pred_img_sim,
+ f"{tracker_folder}/similarity with unrelated caption [steps={timesteps}]": unrel_sim,
+ f"{tracker_folder}/difference from baseline similarity [steps={timesteps}]": pred_sim
+ - orig_sim,
+ }
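+ # In words: "baseline" compares captions to the real image embeddings; "with text" compares captions to
+ # embeddings predicted from matching captions; "with original image" compares the predicted embeddings to
+ # the real ones; "with unrelated caption" uses the rolled (mismatched) captions as a sanity check.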
+
+ tracker.log(stats, step=trainer.step.item() + 1)
+
+
+def eval_model(
+ trainer: DiffusionPriorTrainer,
+ dataloader: DataLoader,
+ text_conditioned: bool,
+ split: str,
+ tracker: Tracker,
+ use_ema: bool,
+ report_cosine: bool,
+ report_loss: bool,
+ timesteps: List[int],
+ loss_type: str = None,
+):
+ """
+ Run evaluation on a model and track metrics
+
+ returns: loss if requested
+ """
+ trainer.eval()
+
+ use_ema = "ema" if use_ema else "online"
+ tracker_folder = f"metrics/{use_ema}-{split}"
+
+ # determine whether valid timesteps were passed
+
+ min_timesteps = trainer.accelerator.unwrap_model(
+ trainer.diffusion_prior
+ ).sample_timesteps
+ max_timesteps = trainer.accelerator.unwrap_model(
+ trainer.diffusion_prior
+ ).noise_scheduler.num_timesteps
+
+ assert all_between(
+ timesteps, lower_bound=min_timesteps, upper_bound=max_timesteps
+ ), f"all timesteps values must be between {min_timesteps} and {max_timesteps}: got {timesteps}"
+
+ # measure cosine metrics across various eta and timesteps
+
+ if report_cosine:
+ for timestep in timesteps:
+ report_cosine_sims(
+ trainer,
+ dataloader=dataloader,
+ text_conditioned=text_conditioned,
+ tracker=tracker,
+ split=split,
+ timesteps=timestep,
+ tracker_folder=tracker_folder,
+ )
+
+ # measure loss on a separate split of data
+
+ loss = None # keep `loss` defined even when loss reporting is skipped
+ if report_loss:
+ loss = report_validation_loss(
+ trainer=trainer,
+ dataloader=dataloader,
+ text_conditioned=text_conditioned,
+ use_ema=use_ema,
+ tracker=tracker,
+ split=split,
+ tracker_folder=tracker_folder,
+ loss_type=loss_type,
+ )
+
+ return loss
+
+
+# training script
+
+
+def train(
+ trainer: DiffusionPriorTrainer,
+ tracker: Tracker,
+ train_loader: DataLoader,
+ eval_loader: DataLoader,
+ test_loader: DataLoader,
+ config: DiffusionPriorTrainConfig,
+):
+ # init timers
+ save_timer = Timer() # when to save
+ samples_timer = Timer() # samples/sec
+ validation_profiler = Timer() # how long is validation taking
+ validation_countdown = Timer() # when to perform evaluation
+
+ # keep track of best validation loss
+
+ best_validation_loss = config.train.best_validation_loss
+ samples_seen = config.train.num_samples_seen
+
+ # do training
+
+ start_epoch = config.train.current_epoch
+
+ for epoch in range(start_epoch, config.train.epochs):
+ # if we resumed partway through an epoch, reset the dataloader offset once that epoch is finished so later epochs cover the full dataset
+ tracker.log({"tracking/epoch": epoch}, step=trainer.step.item())
+
+ if train_loader.dataset.get_start() > 0 and epoch == start_epoch+1:
+ if trainer.accelerator.is_main_process:
+ click.secho(f"Finished resumed epoch...resetting dataloader.")
+ train_loader.dataset.set_start(0)
+
+ for img, txt in train_loader:
+ # setup things every step
+
+ trainer.train()
+ current_step = trainer.step.item()
+ samples_timer.reset()
+
+ # place data on device
+
+ img = img.to(trainer.device)
+ txt = txt.to(trainer.device)
+
+ # pass to model
+
+ loss = trainer(text=txt, image_embed=img)
+
+ # perform backprop & apply EMA updates
+
+ trainer.update()
+
+ # gather info about training step
+
+ all_loss = pad_gather_reduce(trainer, loss, method="mean")
+ num_samples = pad_gather_reduce(trainer, len(txt), method="sum")
+ samples_per_sec = num_samples / samples_timer.elapsed()
+ samples_seen += num_samples
+ ema_decay = trainer.ema_diffusion_prior.get_current_decay()
+
+ # log
+
+ tracker.log(
+ {
+ "tracking/samples-sec": samples_per_sec,
+ "tracking/samples-seen": samples_seen,
+ "tracking/ema-decay": ema_decay,
+ f"tracking/training-{config.prior.loss_type}": all_loss,
+ },
+ step=current_step,
+ )
+
+ # Metric Tracking @ Timed Intervals
+
+ eval_delta = pad_gather_reduce(
+ trainer, validation_countdown.elapsed(), method="min"
+ )
+
+ if eval_delta is not None and eval_delta > config.data.eval_every_seconds:
+ # begin timing how long this takes
+
+ validation_profiler.reset()
+
+ # package kwargs for evaluation
+
+ eval_kwargs = {
+ "trainer": trainer,
+ "tracker": tracker,
+ "text_conditioned": config.prior.condition_on_text_encodings,
+ "timesteps": config.train.eval_timesteps,
+ }
+
+ # ONLINE MODEL : COSINE : LOSS : VALIDATION SPLIT
+
+ eval_model(
+ dataloader=eval_loader,
+ loss_type=config.prior.loss_type,
+ split="validation",
+ use_ema=False,
+ report_cosine=False,
+ report_loss=True,
+ **eval_kwargs,
+ )
+
+ # EMA MODEL : COSINE : LOSS : VALIDATION DATA
+
+ ema_val_loss = eval_model(
+ dataloader=eval_loader,
+ loss_type=config.prior.loss_type,
+ split="validation",
+ use_ema=True,
+ report_cosine=True,
+ report_loss=True,
+ **eval_kwargs,
+ )
+
+ tracker.log(
+ {
+ "tracking/validation length (minutes)": validation_profiler.elapsed()
+ / 60
+ }
+ )
+
+ # check if the ema validation is the lowest seen yet
+
+ if ema_val_loss < best_validation_loss:
+ best_validation_loss = ema_val_loss
+
+ # go save the model as best
+
+ save_trainer(
+ trainer=trainer,
+ tracker=tracker,
+ is_best=True,
+ is_latest=False,
+ samples_seen=samples_seen,
+ epoch=epoch,
+ best_validation_loss=best_validation_loss,
+ )
+
+ # reset timer for validation
+
+ validation_countdown.reset()
+
+ elif eval_delta is None:
+ click.secho(
+ f"Error occured reading the eval time on rank: {trainer.device}",
+ fg="yellow",
+ )
+
+ # save as latest model on schedule
+
+ save_delta = pad_gather_reduce(trainer, save_timer.elapsed(), method="min")
+
+ if save_delta is not None and save_delta >= config.train.save_every_seconds:
+ save_trainer(
+ trainer=trainer,
+ tracker=tracker,
+ is_best=False,
+ is_latest=True,
+ samples_seen=samples_seen,
+ epoch=epoch,
+ best_validation_loss=best_validation_loss,
+ )
+
+ save_timer.reset()
+
+ elif save_delta is None:
+ click.secho(
+ f"Error occured reading the save time on rank: {trainer.device}",
+ fg="yellow",
+ )
+
+ # evaluate on test data
+
+ if trainer.accelerator.is_main_process:
+ click.secho(f"Starting Test", fg="red")
+
+ # save one last time as latest before beginning validation
+
+ save_trainer(
+ tracker=tracker,
+ trainer=trainer,
+ is_best=False,
+ is_latest=True,
+ samples_seen=samples_seen,
+ epoch=epoch,
+ best_validation_loss=best_validation_loss,
+ )
+
+ test_loss = eval_model(
+ trainer=trainer,
+ dataloader=test_loader,
+ text_conditioned=config.prior.condition_on_text_encodings,
+ split="test",
+ tracker=tracker,
+ use_ema=True,
+ report_cosine=False,
+ report_loss=True,
+ timesteps=config.train.eval_timesteps,
+ loss_type=config.prior.loss_type,
+ )
+
+ if test_loss < best_validation_loss:
+ best_validation_loss = test_loss
+
+ # go save the model as best
+
+ save_trainer(
+ trainer=trainer,
+ tracker=tracker,
+ is_best=True,
+ is_latest=False,
+ samples_seen=samples_seen,
+ epoch=epoch,
+ best_validation_loss=test_loss,
+ )
+
+
+def initialize_training(config_file, accelerator):
+ """
+ Parse the configuration file, and prepare everything necessary for training
+ """
+ # load the configuration file
+ if accelerator.is_main_process:
+ click.secho(f"Loading configuration from {config_file}", fg="green")
+
+ config = TrainDiffusionPriorConfig.from_json_path(config_file)
+
+ # seed
+
+ set_seed(config.train.random_seed)
+
+ # get a device
+
+ device = accelerator.device
+
+ # make the trainer (will automatically distribute if possible & configured)
+
+ trainer: DiffusionPriorTrainer = make_model(
+ config.prior, config.train, device, accelerator
+ ).to(device)
+
+ # create a tracker
+
+ tracker = create_tracker(
+ accelerator, config, config_file, dummy=accelerator.process_index != 0
+ )
+
+ # reload from checkpoint
+
+ if tracker.can_recall:
+ current_epoch, best_validation_loss, samples_seen = recall_trainer(
+ tracker=tracker, trainer=trainer
+ )
+
+ # display best values
+ if trainer.accelerator.is_main_process:
+ click.secho(f"Current Epoch: {current_epoch} | Best Val Loss: {best_validation_loss} | Samples Seen: {samples_seen}", fg="yellow")
+
+ # update config to reflect recalled values
+ config.train.num_samples_seen = samples_seen
+ config.train.current_epoch = current_epoch
+ config.train.best_validation_loss = best_validation_loss
+
+ # fetch and prepare data
+
+ if trainer.accelerator.is_main_process:
+ click.secho("Grabbing data...", fg="blue", blink=True)
+
+ trainer.accelerator.wait_for_everyone()
+ img_reader = get_reader(
+ text_conditioned=trainer.text_conditioned,
+ img_url=config.data.image_url,
+ meta_url=config.data.meta_url,
+ )
+
+ # calculate start point within epoch
+
+ trainer.accelerator.wait_for_everyone()
+
+ train_loader, eval_loader, test_loader = make_splits(
+ text_conditioned=trainer.text_conditioned,
+ batch_size=config.data.batch_size,
+ num_data_points=config.data.num_data_points,
+ train_split=config.data.splits.train,
+ eval_split=config.data.splits.val,
+ image_reader=img_reader,
+ rank=accelerator.state.process_index,
+ world_size=accelerator.state.num_processes,
+ start=0,
+ )
+
+ # update the start point to finish out the epoch on a resumed run
+
+ if tracker.can_recall:
+ samples_seen = config.train.num_samples_seen
+ length = (
+ config.data.num_data_points
+ if samples_seen <= img_reader.count
+ else img_reader.count
+ )
+ scaled_samples = length * config.train.current_epoch
+ start_point = (
+ scaled_samples - samples_seen if scaled_samples > samples_seen else samples_seen
+ )
+
+ if trainer.accelerator.is_main_process:
+ click.secho(f"Resuming at sample: {start_point}", fg="yellow")
+
+ train_loader.dataset.set_start(start_point)
+
+ # start training
+
+ if trainer.accelerator.is_main_process:
+ click.secho(
+ f"Beginning Prior Training : Distributed={accelerator.state.distributed_type != accelerate_dataclasses.DistributedType.NO}",
+ fg="yellow",
+ )
+
+ train(
+ trainer=trainer,
+ tracker=tracker,
+ train_loader=train_loader,
+ eval_loader=eval_loader,
+ test_loader=test_loader,
+ config=config,
+ )
+
+
+@click.command()
+@click.option("--config_file", default="configs/train_prior_config.example.json")
+def main(config_file):
+ # start Hugging Face Accelerate
+ accelerator = Accelerator()
+
+ # setup training
+ initialize_training(config_file, accelerator)
+
+
+if __name__ == "__main__":
+ main()
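+# Example invocation (config path is the click default shown above; multi-GPU line assumes the HF Accelerate CLI):
+#   python train_diffusion_prior.py --config_file configs/train_prior_config.example.json
+#   accelerate launch train_diffusion_prior.py --config_file configs/train_prior_config.example.json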
diff --git a/docs/tree.html b/docs/tree.html
new file mode 100644
index 00000000..9f225012
--- /dev/null
+++ b/docs/tree.html
@@ -0,0 +1,158 @@
+ [Interactive HTML tree view; markup not preserved here. Recoverable text:]
+ Project structure of: lucidrains/DALLE2-pytorch
+ DALLE2-pytorch: DALL-E 2: Image generation & testing, with diffusion prior model training code.
+   configs
+     README.md: Configure DALLE2 model training options in PyTorch
+   dalle2_pytorch: Trains VAE models with DALL-E 2 image library.