# === PARCtorch/autoencoder/autoencoder.py ===
import torch
import torch.nn as nn

# NOTE(review): the original new file also imported torch.optim, tqdm,
# torchvision.utils.make_grid, numpy, matplotlib.pyplot, json and
# torch.nn.functional; none of them is referenced anywhere in this module,
# so the dead imports were dropped.

# Define Modules #


class MLPEncoder(nn.Module):
    """Fully-connected encoder mapping a flattened sample to a latent vector.

    ``layers`` and ``latent_dim`` are treated the same as for the
    convolutional autoencoder at the input level, but under the hood the
    MLP flattens each sample before the first linear layer, so
    ``layers[0]`` must equal the flattened per-sample size
    (channels * height * width) -- TODO confirm callers pass the flattened
    size here rather than the channel count.

    Args:
        layers: Feature sizes excluding the latent size, e.g. ``[3, 8]``;
            should be the same list as given to the paired decoder.
        latent_dim: Size of the latent bottleneck vector.
        act_fn: Activation module used between hidden layers (a single
            instance is reused, which is safe for stateless activations).
    """

    def __init__(self, layers, latent_dim, act_fn=nn.ReLU()):
        super().__init__()
        modules = []
        in_dim = layers[0]
        for dim in layers[1:]:
            modules.append(nn.Linear(in_dim, dim))
            modules.append(act_fn)
            in_dim = dim
        modules.append(nn.Linear(in_dim, latent_dim))  # bottleneck layer (linear output)
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        # Flatten input except the batch dimension.
        x = x.view(x.size(0), -1)
        return self.net(x)


class MLPDecoder(nn.Module):
    """Fully-connected decoder mapping a latent vector back to an image tensor.

    Args:
        layers: Feature sizes excluding the latent size (same list as the
            encoder); traversed in reverse here.
        latent_dim: Size of the latent bottleneck vector.
        output_shape: ``(n_channels, height, width)`` used to reshape the
            flattened output; ``layers[0]`` must equal its product.
        act_fn: Activation applied after every linear layer.
            NOTE(review): this includes the *final* layer, so outputs are
            constrained to the activation's range (>= 0 for the default
            ReLU) -- confirm a non-linear output is intended.
    """

    def __init__(self, layers, latent_dim, output_shape=(3, 128, 256), act_fn=nn.ReLU()):
        super().__init__()
        self.output_shape = output_shape
        modules = []
        in_dim = latent_dim
        for dim in reversed(layers):
            modules.append(nn.Linear(in_dim, dim))
            modules.append(act_fn)
            in_dim = dim
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        x = self.net(x)
        batch_size = x.size(0)
        return x.view(batch_size, *self.output_shape)


# Convolutional AE
class Encoder(nn.Module):
    """Convolutional encoder; every stage halves the spatial size (stride 2).

    Args:
        layers: Channel values excluding the latent channels, e.g. ``[3, 8]``;
            should be the same list as given to the decoder.
        latent_dim: Number of channels in the latent bottleneck layer.
        act_fn: Activation module used between convolutions.
    """

    def __init__(self, layers, latent_dim=16, act_fn=nn.ReLU()):
        super().__init__()
        modules = []
        in_channels = layers[0]
        for out_channels in layers[1:]:
            # kernel 3 / padding 1 keeps the exact halving from stride=2.
            modules.append(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
            )
            modules.append(act_fn)
            in_channels = out_channels
        # Bottleneck layer is also strided, so the total downsampling factor
        # is 2 ** len(layers), matching the decoder's upsampling count.
        modules.append(
            nn.Conv2d(layers[-1], latent_dim, kernel_size=3, stride=2, padding=1)
        )
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        return self.net(x)


class Decoder(nn.Module):  # no deconv
    """Convolutional decoder using resize-convolutions (no transposed conv).

    Args:
        layers: Channel values excluding the latent channels (same list as
            the encoder); traversed in reverse.
        latent_dim: Number of channels in the latent bottleneck layer.
        act_fn: Activation after each convolution.
            NOTE(review): the final stage is also activated (the
            ``modules.pop()`` that would make it linear is commented out) --
            confirm a non-linear output is intended.
    """

    def __init__(self, layers, latent_dim=16, act_fn=nn.ReLU()):
        super().__init__()

        self.in_channels = layers[-1]
        self.latent_dim = latent_dim

        modules = []
        in_channels = latent_dim

        # Iteratively create resize-convolution layers.
        for out_channels in reversed(layers):
            modules.append(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False))  # resizing
            modules.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))  # convolution
            modules.append(act_fn)  # activation
            in_channels = out_channels

        # modules.pop()  # final activation linear

        self.conv = nn.Sequential(*modules)

    def forward(self, x):
        return self.conv(x)


# Defining the full autoencoder

class Autoencoder(nn.Module):
    """Wrapper for an autoencoder with 1 encoder and 1 decoder handling all
    data channels together."""

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded


class AutoencoderSeparate(nn.Module):
    """Autoencoder with 3 encoder/decoder pairs, one per data channel.

    Hard coded for 3 channels corresponding to EM data (T, P, M), though
    any data with 3 channels is compatible.
    """

    def __init__(self, encoder_T, encoder_P, encoder_M, decoder_T, decoder_P, decoder_M):
        super().__init__()
        self.encoderT = encoder_T
        self.encoderP = encoder_P
        self.encoderM = encoder_M
        self.decoderT = decoder_T
        self.decoderP = decoder_P
        self.decoderM = decoder_M

    def forward(self, x):
        z_t = self.encoderT(x[:, 0:1, :, :])  # only T channel
        z_p = self.encoderP(x[:, 1:2, :, :])  # only P channel
        z_m = self.encoderM(x[:, 2:3, :, :])  # only M channel

        decoded_t = self.decoderT(z_t)  # decode T
        decoded_p = self.decoderP(z_p)  # decode P
        decoded_m = self.decoderM(z_m)  # decode M
        # Concatenate the per-channel reconstructions back into one tensor.
        decoded = torch.cat((decoded_t, decoded_p, decoded_m), dim=1)

        return decoded


class ConvolutionalAutoencoder:
    """Thin convenience wrapper bundling a network, optimizer and device.

    NOTE(review): ``encode``/``decode`` assume the wrapped network exposes
    ``.encoder``/``.decoder`` attributes -- true for ``Autoencoder`` but not
    for ``AutoencoderSeparate``; confirm intended usage.
    """

    def __init__(self, autoencoder, optimizer, device, save_path=None, weights_name=None):
        self.network = autoencoder.to(device)
        self.optimizer = optimizer
        self.device = device
        self.save_path = save_path
        self.weights_name = weights_name

    def autoencode(self, x):
        return self.network(x)

    def encode(self, x):
        return self.network.encoder(x)

    def decode(self, x):
        return self.network.decoder(x)


# === PARCtorch/autoencoder/train.py ===
import numpy as np

try:
    from tqdm import tqdm
except ImportError:  # keep the module importable when tqdm is not installed
    def tqdm(iterable, **kwargs):
        """No-op fallback progress wrapper."""
        return iterable


def _train_loop(model, optimizer, loss_function, train_loader, val_loader,
                device, epochs, image_size, in_slice, target_channels,
                scheduler, noise_fn, initial_max_noise, n_reduce_factor,
                reduce_on, save_path, weights_name):
    """Shared training/validation loop behind both trainer entry points.

    The two public trainers were byte-for-byte duplicates except for the
    channel selection; ``in_slice`` selects the input channels of each batch
    and ``target_channels`` is the channel count the output is compared to.

    Batches are expected as tuples/lists whose first element is an image
    tensor of shape (batch, channels, H, W) -- TODO confirm against the
    project's DataLoader.
    """
    log_dict = {'training_loss_per_epoch': [], 'validation_loss_per_epoch': []}

    model.to(device)

    max_noise = initial_max_noise  # current noise level

    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")

        # Reduce noise periodically
        if (epoch + 1) % reduce_on == 0:
            max_noise *= n_reduce_factor

        # --- Training ---
        model.train()
        train_losses = []
        for images in tqdm(train_loader, desc="Training"):
            optimizer.zero_grad()
            images = images[0][:, in_slice, ...].to(device)

            # Apply noise if function is provided
            noisy_images = noise_fn(images, max_val=max_noise) if noise_fn else images

            # Forward pass; loss is computed against the clean images.
            output = model(noisy_images)
            loss = loss_function(output, images.view(-1, target_channels, *image_size))
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        avg_train_loss = np.mean(train_losses)
        log_dict['training_loss_per_epoch'].append(avg_train_loss)

        # --- Validation --- (no noise injection here)
        model.eval()
        val_losses = []
        with torch.no_grad():
            for val_images in tqdm(val_loader, desc="Validating"):
                val_images = val_images[0][:, in_slice, ...].to(device)

                output = model(val_images)
                val_loss = loss_function(output, val_images.view(-1, target_channels, *image_size))
                val_losses.append(val_loss.item())

        avg_val_loss = np.mean(val_losses)
        log_dict['validation_loss_per_epoch'].append(avg_val_loss)

        print(f"Epoch {epoch+1}: Training Loss: {avg_train_loss:.4f} | Validation Loss: {avg_val_loss:.4f}")

        if scheduler:
            scheduler.step(avg_val_loss)

    # Save model and logs. save_model/save_log already no-op on a falsy
    # save_path; guarding here as well lets the import be skipped entirely
    # when no saving is requested.
    if save_path:
        from utils import save_model, save_log
        save_model(model, save_path, weights_name, epochs)
        save_log(log_dict, save_path, weights_name, epochs)

    return log_dict


def train_autoencoder(model, optimizer, loss_function, train_loader, val_loader,
                      device, epochs=10, image_size=(64, 64), n_channels=3,
                      scheduler=None, noise_fn=None, initial_max_noise=0.16,
                      n_reduce_factor=0.5, reduce_on=1000, save_path=None, weights_name=None):
    """Train an autoencoder with optional noise injection.

    Uses the first ``n_channels`` channels of every batch. Returns a dict
    with per-epoch training/validation losses; weights and logs are saved
    when ``save_path`` is given.
    """
    return _train_loop(model, optimizer, loss_function, train_loader, val_loader,
                       device, epochs, image_size, slice(0, n_channels), n_channels,
                       scheduler, noise_fn, initial_max_noise, n_reduce_factor,
                       reduce_on, save_path, weights_name)


def train_individual_autoencoder(model, optimizer, loss_function, train_loader, val_loader,
                                 device, epochs=10, image_size=(64, 64), channel_index=0,
                                 scheduler=None, noise_fn=None, initial_max_noise=0.16,
                                 n_reduce_factor=0.8, reduce_on=1000, save_path=None,
                                 weights_name=None):
    """Train an autoencoder on just one channel (``channel_index``) at a time
    with optional noise injection. Same contract as ``train_autoencoder``.
    """
    return _train_loop(model, optimizer, loss_function, train_loader, val_loader,
                       device, epochs, image_size, slice(channel_index, channel_index + 1), 1,
                       scheduler, noise_fn, initial_max_noise, n_reduce_factor,
                       reduce_on, save_path, weights_name)
# === PARCtorch/autoencoder/utils.py ===
import json

import torch
import torch.nn as nn  # FIX: LpLoss subclasses nn.Module but `nn` was never imported (NameError on import)


def save_model(model, save_path, weights_name, epochs):
    """Save model weights to ``{save_path}/{weights_name}_{epochs}.pth``.

    No-op when ``save_path`` is falsy.
    """
    if save_path:
        save_file = f"{save_path}/{weights_name}_{epochs}.pth"
        torch.save(model.state_dict(), save_file)
        print(f"Model weights saved to {save_file}")


def save_log(log_dict, save_path, weights_name, epochs):
    """Save training logs as JSON next to the weights file.

    No-op when ``save_path`` is falsy.
    """
    if save_path:
        log_file = f"{save_path}/{weights_name}_{epochs}.json"
        with open(log_file, 'w') as f:
            json.dump(log_dict, f)


def add_random_noise(images, min_val=0.0, max_val=0.1):
    """
    Add random (uniform) noise to the images.

    Parameters:
        images: Tensor of input images.
        min_val: Minimum value of the noise.
        max_val: Maximum value of the noise.

    Returns:
        Noisy images, clamped to [0, 1].
    """
    noise = torch.rand_like(images) * (max_val - min_val) + min_val
    noisy_images = images + noise
    return torch.clamp(noisy_images, 0.0, 1.0)  # Keep pixel values in [0, 1]


class LpLoss(nn.Module):
    """Per-sample L^p norm of the input/target difference.

    The norm is taken over all non-batch dimensions, then reduced across
    the batch according to ``reduction`` ('none' | 'mean' | 'sum').

    NOTE(review): the default ``p=10`` is unusually high (approaches a
    max-norm); confirm this is intended rather than e.g. p=2.
    """

    def __init__(self, p=10, reduction='mean'):
        super(LpLoss, self).__init__()
        self.p = p
        if reduction not in ('none', 'mean', 'sum'):
            raise ValueError(f"Invalid reduction mode: {reduction}")
        self.reduction = reduction

    def forward(self, input, target):
        diff = torch.abs(input - target) ** self.p
        # p-norm per sample: sum over all non-batch dims, then the 1/p root.
        loss = torch.sum(diff, dim=tuple(range(1, diff.ndim))) ** (1 / self.p)

        if self.reduction == 'mean':
            return torch.mean(loss)
        elif self.reduction == 'sum':
            return torch.sum(loss)
        else:  # 'none'
            return loss


# === tests/test_latentparc.py ===
from PARCtorch.LatentPARC import LatentPARC
from PARCtorch.model import PARC
from PARCtorch.utilities.autoencoder import Encoder, Decoder


def test_latentparc_default():
    # This test requires a GPU to run
    model = LatentPARC().cuda()
    # FIX: `model` and its submodules are *instances*; issubclass() raises
    # TypeError on instances, so isinstance() must be used throughout.
    assert isinstance(model, PARC)
    assert isinstance(model.encoder, Encoder)
    assert isinstance(model.decoder, Decoder)
    assert isinstance(model.differentiator, torch.nn.Module)
    assert isinstance(model.integrator, torch.nn.Module)
    # freeze_encoder_decoder functionalities
    model.freeze_encoder_decoder()
    # FIX: typos -- `paramters()` and `reqires_grad` would raise
    # AttributeError instead of checking the frozen flags.
    for p in model.encoder.parameters():
        assert p.requires_grad is False
    for p in model.decoder.parameters():
        assert p.requires_grad is False
    # Forward pass
    # TODO: complete upon submission of draft to avoid leaking implementation details
    ic = torch.randn(4, 5, 128, 256, dtype=torch.float32, device="cuda")
    t0 = torch.tensor(0.0, dtype=torch.float32, device="cuda")
    t1 = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32, device="cuda")
    gt = torch.randn(3, 4, 5, 128, 256, dtype=torch.float32, device="cuda")
    # Backward pass
    # TODO: complete upon submission of draft to avoid leaking implementation details