From 33f9cae826c827f9ca4f0addf42d04fb0cb70ff7 Mon Sep 17 00:00:00 2001 From: Aaina Jain <2022uee1052@mnit.ac.in> Date: Mon, 2 Jun 2025 15:49:08 +0530 Subject: [PATCH 01/10] Update train_ddpm.py Changed the dataset to CIFAR-10 --- tools/train_ddpm.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/tools/train_ddpm.py b/tools/train_ddpm.py index 21611a4..c53369b 100644 --- a/tools/train_ddpm.py +++ b/tools/train_ddpm.py @@ -5,12 +5,22 @@ import numpy as np from tqdm import tqdm from torch.optim import Adam -from dataset.mnist_dataset import MnistDataset + from torch.utils.data import DataLoader from models.unet_base import Unet from scheduler.linear_noise_scheduler import LinearNoiseScheduler device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +from torchvision import datasets, transforms + + +transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Resize((32, 32)), +]) + +dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) +train_loader = DataLoader(dataset, batch_size=64, shuffle=True) def train(args): @@ -33,10 +43,9 @@ def train(args): beta_start=diffusion_config['beta_start'], beta_end=diffusion_config['beta_end']) - # Create the dataset - mnist = MnistDataset('train', im_path=dataset_config['im_path']) - mnist_loader = DataLoader(mnist, batch_size=train_config['batch_size'], shuffle=True, num_workers=4) - + # Before training + print("Sample input shape:", im.shape) + # Instantiate the model model = Unet(model_config).to(device) model.train() @@ -58,7 +67,7 @@ def train(args): # Run training for epoch_idx in range(num_epochs): losses = [] - for im in tqdm(mnist_loader): + for im in tqdm(train_loader): optimizer.zero_grad() im = im.float().to(device) From ae692d74eb230068fe1a6189864d619dac3a4370 Mon Sep 17 00:00:00 2001 From: Aaina Jain <2022uee1052@mnit.ac.in> Date: Mon, 2 Jun 2025 16:34:47 +0530 Subject: [PATCH 02/10] Update default.yaml --- config/default.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/config/default.yaml b/config/default.yaml index 7547712..c7b2048 100644 --- a/config/default.yaml +++ b/config/default.yaml @@ -7,10 +7,10 @@ diffusion_params: beta_end : 0.02 model_params: - im_channels : 1 - im_size : 28 - down_channels : [32, 64, 128, 256] - mid_channels : [256, 256, 128] + im_channels : 3 + im_size : 32 + down_channels: [64, 128, 256, 512] + mid_channels: [512, 512, 256] down_sample : [True, True, False] time_emb_dim : 128 num_down_layers : 2 From 0036dc412c947b61991f8b562b8fa0cb21875da5 Mon Sep 17 00:00:00 2001 From: Aaina Jain <2022uee1052@mnit.ac.in> Date: Mon, 2 Jun 2025 17:36:49 +0530 Subject: [PATCH 03/10] Update train_ddpm.py --- tools/train_ddpm.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tools/train_ddpm.py b/tools/train_ddpm.py index c53369b..f37a612 100644 --- a/tools/train_ddpm.py +++ b/tools/train_ddpm.py @@ -43,8 +43,7 @@ def train(args): beta_start=diffusion_config['beta_start'], beta_end=diffusion_config['beta_end']) - # Before training - print("Sample input shape:", im.shape) + # Instantiate the model model = Unet(model_config).to(device) @@ -68,6 +67,8 @@ def train(args): for epoch_idx in range(num_epochs): losses = [] for im in tqdm(train_loader): + # Before training + print("Sample input shape:", im.shape) optimizer.zero_grad() im = im.float().to(device) From 376226979d82036bf992f12df606594433966e28 Mon Sep 17 00:00:00 2001 From: Aaina Jain <2022uee1052@mnit.ac.in> Date: Tue, 3 Jun 2025 16:38:44 +0530 Subject: [PATCH 04/10] Update train_ddpm.py --- tools/train_ddpm.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tools/train_ddpm.py b/tools/train_ddpm.py index f37a612..6e4f33b 100644 --- a/tools/train_ddpm.py +++ b/tools/train_ddpm.py @@ -67,8 +67,6 @@ def train(args): for epoch_idx in range(num_epochs): losses = [] for im in tqdm(train_loader): - # Before training - print("Sample input shape:", im.shape) optimizer.zero_grad() im = im.float().to(device) From 4271250b4715db9453ed8301eaec397209ced67d Mon Sep 17 00:00:00 2001 From: Aaina Jain <2022uee1052@mnit.ac.in> Date: Tue, 3 Jun 2025 17:05:28 +0530 Subject: [PATCH 05/10] Update train_ddpm.py --- tools/train_ddpm.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tools/train_ddpm.py b/tools/train_ddpm.py index 6e4f33b..9b89566 100644 --- a/tools/train_ddpm.py +++ b/tools/train_ddpm.py @@ -30,11 +30,10 @@ def train(args): config = yaml.safe_load(file) except yaml.YAMLError as exc: print(exc) - print(config) + #print(config) ######################## diffusion_config = config['diffusion_params'] - dataset_config = config['dataset_params'] model_config = config['model_params'] train_config = config['train_params'] @@ -66,7 +65,7 @@ def train(args): # Run training for epoch_idx in range(num_epochs): losses = [] - for im in tqdm(train_loader): + for im, _ in train_loader: optimizer.zero_grad() im = im.float().to(device) From 10d1c0a8f67b3531a3d4ab7a6e4138d7ee2cf120 Mon Sep 17 00:00:00 2001 From: aainaaa <2022uee1052@mnit.ac.in> Date: Tue, 3 Jun 2025 11:59:27 +0000 Subject: [PATCH 06/10] WIP: local edits before merge --- tools/train_ddpm.py | 103 +------------------------------------------- 1 file changed, 1 insertion(+), 102 deletions(-) diff --git a/tools/train_ddpm.py b/tools/train_ddpm.py index 6e4f33b..4d69d29 100644 --- a/tools/train_ddpm.py +++ b/tools/train_ddpm.py @@ -1,102 +1 @@ -import torch -import yaml -import argparse -import os -import numpy as np -from tqdm import tqdm -from torch.optim import Adam - -from torch.utils.data import DataLoader -from models.unet_base import Unet -from scheduler.linear_noise_scheduler import LinearNoiseScheduler - -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -from torchvision import datasets, transforms - - -transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Resize((32, 32)), -]) - -dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) -train_loader = DataLoader(dataset, batch_size=64, shuffle=True) - - -def train(args): - # Read the config file # - with open(args.config_path, 'r') as file: - try: - config = yaml.safe_load(file) - except yaml.YAMLError as exc: - print(exc) - print(config) - ######################## - - diffusion_config = config['diffusion_params'] - dataset_config = config['dataset_params'] - model_config = config['model_params'] - train_config = config['train_params'] - - # Create the noise scheduler - scheduler = LinearNoiseScheduler(num_timesteps=diffusion_config['num_timesteps'], - beta_start=diffusion_config['beta_start'], - beta_end=diffusion_config['beta_end']) - - - - # Instantiate the model - model = Unet(model_config).to(device) - model.train() - - # Create output directories - if not os.path.exists(train_config['task_name']): - os.mkdir(train_config['task_name']) - - # Load checkpoint if found - if os.path.exists(os.path.join(train_config['task_name'],train_config['ckpt_name'])): - print('Loading checkpoint as found one') - model.load_state_dict(torch.load(os.path.join(train_config['task_name'], - train_config['ckpt_name']), map_location=device)) - # Specify training parameters - num_epochs = train_config['num_epochs'] - optimizer = Adam(model.parameters(), lr=train_config['lr']) - criterion = torch.nn.MSELoss() - - # Run training - for epoch_idx in range(num_epochs): - losses = [] - for im in tqdm(train_loader): - optimizer.zero_grad() - im = im.float().to(device) - - # Sample random noise - noise = torch.randn_like(im).to(device) - - # Sample timestep - t = torch.randint(0, diffusion_config['num_timesteps'], (im.shape[0],)).to(device) - - # Add noise to images according to timestep - noisy_im = scheduler.add_noise(im, noise, t) - noise_pred = model(noisy_im, t) - - loss = criterion(noise_pred, noise) - losses.append(loss.item()) - loss.backward() - optimizer.step() - print('Finished epoch:{} | Loss : {:.4f}'.format( - epoch_idx + 1, - np.mean(losses), - )) - torch.save(model.state_dict(), os.path.join(train_config['task_name'], - train_config['ckpt_name'])) - - print('Done Training ...') - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Arguments for ddpm training') - parser.add_argument('--config', dest='config_path', - default='config/default.yaml', type=str) - args = parser.parse_args() - train(args) + \ No newline at end of file From 436af9d0561bf7ff45a340b6edf77bd3fc73fe2b Mon Sep 17 00:00:00 2001 From: Aaina Jain <2022uee1052@mnit.ac.in> Date: Mon, 9 Jun 2025 10:46:18 +0530 Subject: [PATCH 07/10] Update train_ddpm.py --- tools/train_ddpm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/train_ddpm.py b/tools/train_ddpm.py index 2ba647b..a4c2d13 100644 --- a/tools/train_ddpm.py +++ b/tools/train_ddpm.py @@ -102,4 +102,4 @@ def train(args): default='config/default.yaml', type=str) args = parser.parse_args() train(args) ->>>>>>> 4271250b4715db9453ed8301eaec397209ced67d + From cb628da9ea0c93982e2154a802b88d51d898c86d Mon Sep 17 00:00:00 2001 From: Aaina Jain <2022uee1052@mnit.ac.in> Date: Mon, 9 Jun 2025 10:54:38 +0530 Subject: [PATCH 08/10] Update train_ddpm.py --- tools/train_ddpm.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tools/train_ddpm.py b/tools/train_ddpm.py index a4c2d13..8b8af5e 100644 --- a/tools/train_ddpm.py +++ b/tools/train_ddpm.py @@ -1,6 +1,3 @@ -<<<<<<< HEAD - -======= import torch import yaml import argparse From 25db718641c196d5d654d51f93f6bbfb023b1d2c Mon Sep 17 00:00:00 2001 From: aainaaa <2022uee1052@mnit.ac.in> Date: Mon, 9 Jun 2025 06:08:59 +0000 Subject: [PATCH 09/10] Ignore dataset directory --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 17e8732..d512813 100644 --- a/.gitignore +++ b/.gitignore @@ -17,4 +17,4 @@ __pycache__ *.pth # Ignore pickle files -*.pkl \ No newline at end of file +*.pkldata/ From e8504a7ff7e045da268febf8098eb505ed9e94b8 Mon Sep 17 00:00:00 2001 From: aainaaa <2022uee1052@mnit.ac.in> Date: Mon, 9 Jun 2025 06:14:41 +0000 Subject: [PATCH 10/10] Ignore dataset directory --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index d512813..32fdff7 100644 --- a/.gitignore +++ b/.gitignore @@ -18,3 +18,4 @@ __pycache__ # Ignore pickle files *.pkldata/ +data/