diff --git a/.gitignore b/.gitignore index 17e8732..32fdff7 100644 --- a/.gitignore +++ b/.gitignore @@ -17,4 +17,5 @@ __pycache__ *.pth # Ignore pickle files -*.pkl \ No newline at end of file +*.pkl +data/ diff --git a/config/default.yaml b/config/default.yaml index 7547712..c7b2048 100644 --- a/config/default.yaml +++ b/config/default.yaml @@ -7,10 +7,10 @@ diffusion_params: beta_end : 0.02 model_params: - im_channels : 1 - im_size : 28 - down_channels : [32, 64, 128, 256] - mid_channels : [256, 256, 128] + im_channels : 3 + im_size : 32 + down_channels: [64, 128, 256, 512] + mid_channels: [512, 512, 256] down_sample : [True, True, False] time_emb_dim : 128 num_down_layers : 2 diff --git a/tools/train_ddpm.py b/tools/train_ddpm.py index 21611a4..8b8af5e 100644 --- a/tools/train_ddpm.py +++ b/tools/train_ddpm.py @@ -5,12 +5,22 @@ import numpy as np from tqdm import tqdm from torch.optim import Adam -from dataset.mnist_dataset import MnistDataset + from torch.utils.data import DataLoader from models.unet_base import Unet from scheduler.linear_noise_scheduler import LinearNoiseScheduler device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +from torchvision import datasets, transforms + + +transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Resize((32, 32)), +]) + +dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) +train_loader = DataLoader(dataset, batch_size=64, shuffle=True) def train(args): @@ -20,11 +30,10 @@ def train(args): config = yaml.safe_load(file) except yaml.YAMLError as exc: print(exc) - print(config) + #print(config) ######################## diffusion_config = config['diffusion_params'] - dataset_config = config['dataset_params'] model_config = config['model_params'] train_config = config['train_params'] @@ -33,10 +42,8 @@ def train(args): beta_start=diffusion_config['beta_start'], beta_end=diffusion_config['beta_end']) - # Create the dataset - mnist = MnistDataset('train', im_path=dataset_config['im_path']) - mnist_loader = DataLoader(mnist, batch_size=train_config['batch_size'], shuffle=True, num_workers=4) + # Instantiate the model model = Unet(model_config).to(device) model.train() @@ -58,7 +65,7 @@ def train(args): # Run training for epoch_idx in range(num_epochs): losses = [] - for im in tqdm(mnist_loader): + for im, _ in train_loader: optimizer.zero_grad() im = im.float().to(device) @@ -92,3 +99,4 @@ def train(args): default='config/default.yaml', type=str) args = parser.parse_args() train(args) +