Latent PARC #59 (Open)
chengxinlun wants to merge 6 commits into main from latent_parc
Commits (6)
- dec4001 Test for latentparc
- 16a5fcd Merge branch 'main' of https://github.com/baeklab/PARCtorch into late…
- 53b7f98 Merge branch 'main' into latent_parc
- 9f6911e added autoencoder modules (zgrayblue)
- 8a6ee7b Merge branch 'latent_parc' of https://github.com/baeklab/PARCtorch in…
- 1d50031 adressing naming convention, imports, module location
autoencoder.py (new file, +176 lines)

```python
import torch
import torch.nn as nn


# Define Modules #


class MLPEncoder(nn.Module):
    def __init__(self, layers, latent_dim, act_fn=nn.ReLU()):
        '''
        Note: layers and latent_dim are treated the same as for the convolutional
        autoencoder at the input level, but under the hood the MLP model flattens
        each layer as it should.

        layers: list, channel values excluding the latent dim, e.g. [3, 8];
                should match the layers passed to the decoder
        latent_dim: int, number of channels in the latent bottleneck layer
        act_fn: activation function used throughout the entire model
        '''
        super().__init__()
        modules = []
        in_dim = layers[0]
        for dim in layers[1:]:
            modules.append(nn.Linear(in_dim, dim))
            modules.append(act_fn)
            in_dim = dim
        modules.append(nn.Linear(in_dim, latent_dim))  # Bottleneck layer
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        # Flatten input except batch dimension
        x = x.view(x.size(0), -1)
        return self.net(x)


class MLPDecoder(nn.Module):
    def __init__(self, layers, latent_dim, output_shape=(3, 128, 256), act_fn=nn.ReLU()):
        '''
        Note: layers and latent_dim are treated the same as for the convolutional
        autoencoder at the input level, but under the hood the MLP model flattens
        each layer as it should.

        layers: list, channel values excluding the latent dim, e.g. [3, 8];
                should match the layers passed to the encoder
        latent_dim: int, number of channels in the latent bottleneck layer
        output_shape: tuple (n_channels, height, width), used to reshape the
                      flattened vector correctly on output
        act_fn: activation function used throughout the entire model
        '''
        super().__init__()
        self.output_shape = output_shape
        modules = []
        in_dim = latent_dim
        for dim in reversed(layers):
            modules.append(nn.Linear(in_dim, dim))
            modules.append(act_fn)
            in_dim = dim
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        x = self.net(x)
        batch_size = x.size(0)
        return x.view(batch_size, *self.output_shape)


# Convolutional AE
class Encoder(nn.Module):
    def __init__(self, layers, latent_dim=16, act_fn=nn.ReLU()):
        '''
        layers: list, channel values excluding the latent dim, e.g. [3, 8];
                should match the layers passed to the decoder
        latent_dim: int, number of channels in the latent bottleneck layer
        act_fn: activation function used throughout the entire model
        '''
        super().__init__()
        modules = []
        in_channels = layers[0]
        for out_channels in layers[1:]:
            modules.append(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
            )  # padding=1 so each stride-2 convolution exactly halves the spatial dims
            modules.append(act_fn)
            in_channels = out_channels
        modules.append(
            nn.Conv2d(layers[-1], latent_dim, kernel_size=3, stride=2, padding=1)
        )  # Bottleneck layer
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        return self.net(x)


class Decoder(nn.Module):  # no deconv
    def __init__(self, layers, latent_dim=16, act_fn=nn.ReLU()):
        '''
        layers: list, channel values excluding the latent dim, e.g. [3, 8];
                should match the layers passed to the encoder
        latent_dim: int, number of channels in the latent bottleneck layer
        act_fn: activation function used throughout the entire model
        '''
        super().__init__()

        self.in_channels = layers[-1]
        self.latent_dim = latent_dim

        modules = []
        in_channels = latent_dim

        # Iteratively create resize-convolution layers
        for out_channels in reversed(layers):
            modules.append(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False))  # Resizing
            modules.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))  # Convolution
            modules.append(act_fn)  # Activation function
            in_channels = out_channels

        # modules.pop()  # final activation linear

        self.conv = nn.Sequential(*modules)

    def forward(self, x):
        return self.conv(x)


# Defining the full autoencoder

class Autoencoder(nn.Module):
    '''
    Wrapper for an autoencoder with one encoder and one decoder handling all
    data channels together.
    '''
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded


class AutoencoderSeparate(nn.Module):
    def __init__(self, encoder_T, encoder_P, encoder_M, decoder_T, decoder_P, decoder_M):
        '''
        Wrapper for an autoencoder with three encoders and three decoders handling
        each data channel separately. Currently hard-coded for 3 channels
        corresponding to the EM data (though any 3-channel data would be compatible).
        '''
        super().__init__()
        self.encoderT = encoder_T
        self.encoderP = encoder_P
        self.encoderM = encoder_M
        self.decoderT = decoder_T
        self.decoderP = decoder_P
        self.decoderM = decoder_M

    def forward(self, x):
        z_t = self.encoderT(x[:, 0:1, :, :])  # only T channel
        z_p = self.encoderP(x[:, 1:2, :, :])  # only P channel
        z_m = self.encoderM(x[:, 2:3, :, :])  # only M channel

        decoded_t = self.decoderT(z_t)  # decode T
        decoded_p = self.decoderP(z_p)  # decode P
        decoded_m = self.decoderM(z_m)  # decode M
        decoded = torch.cat((decoded_t, decoded_p, decoded_m), dim=1)  # concat for output

        return decoded


class ConvolutionalAutoencoder:
    def __init__(self, autoencoder, optimizer, device, save_path=None, weights_name=None):
        self.network = autoencoder.to(device)
        self.optimizer = optimizer
        self.device = device
        self.save_path = save_path
        self.weights_name = weights_name

    def autoencode(self, x):
        return self.network(x)

    def encode(self, x):
        return self.network.encoder(x)

    def decode(self, x):
        return self.network.decoder(x)
```
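With `layers=[3, 8]`, the convolutional encoder applies two stride-2 convolutions (3 to 8 channels, then 8 to `latent_dim`), so each spatial dimension is quartered; the decoder mirrors this with two upsample-plus-convolution stages. Below is a minimal, self-contained shape check; the modules are re-declared inline (condensed) so the snippet runs on its own, and the 128x256 grid size is illustrative, matching the tensors used in the test file of this PR:

```python
import torch
import torch.nn as nn

# Inline, condensed copies of the Encoder/Decoder above for a standalone check.
class Encoder(nn.Module):
    def __init__(self, layers, latent_dim=16, act_fn=nn.ReLU()):
        super().__init__()
        modules = []
        in_ch = layers[0]
        for out_ch in layers[1:]:
            # stride-2, padding-1, kernel-3 conv halves each spatial dimension
            modules += [nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1), act_fn]
            in_ch = out_ch
        modules.append(nn.Conv2d(layers[-1], latent_dim, 3, stride=2, padding=1))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        return self.net(x)

class Decoder(nn.Module):
    def __init__(self, layers, latent_dim=16, act_fn=nn.ReLU()):
        super().__init__()
        modules = []
        in_ch = latent_dim
        for out_ch in reversed(layers):
            # fixed bilinear resize followed by a plain convolution
            modules += [nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                        nn.Conv2d(in_ch, out_ch, 3, padding=1), act_fn]
            in_ch = out_ch
        self.conv = nn.Sequential(*modules)

    def forward(self, x):
        return self.conv(x)

# layers=[3, 8]: two stride-2 convs downsample a 128x256 input to 32x64,
# and the decoder's two upsample stages restore the original resolution.
enc = Encoder([3, 8], latent_dim=16)
dec = Decoder([3, 8], latent_dim=16)
x = torch.randn(2, 3, 128, 256)
z = enc(x)   # latent: (2, 16, 32, 64)
y = dec(z)   # reconstruction: (2, 3, 128, 256)
```

The resize-convolution design in the decoder is a common way to avoid the checkerboard artifacts that transposed convolutions can produce.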
train.py (new file, +138 lines)

```python
import torch
import numpy as np
from tqdm import tqdm

from autoencoder import *
from utils import save_model, save_log, add_random_noise


def train_autoencoder(model, optimizer, loss_function, train_loader, val_loader,
                      device, epochs=10, image_size=(64, 64), n_channels=3,
                      scheduler=None, noise_fn=None, initial_max_noise=0.16,
                      n_reduce_factor=0.5, reduce_on=1000, save_path=None, weights_name=None):
    """Train an autoencoder with optional noise injection."""

    log_dict = {'training_loss_per_epoch': [], 'validation_loss_per_epoch': []}

    model.to(device)

    max_noise = initial_max_noise  # Initial noise level

    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")

        # Reduce noise periodically
        if (epoch + 1) % reduce_on == 0:
            max_noise *= n_reduce_factor

        # --- Training ---
        model.train()
        train_losses = []
        for images in tqdm(train_loader, desc="Training"):
            optimizer.zero_grad()
            images = images[0][:, 0:n_channels, ...].to(device)

            # Apply noise if a noise function is provided
            noisy_images = noise_fn(images, max_val=max_noise) if noise_fn else images

            # Forward pass
            output = model(noisy_images)
            loss = loss_function(output, images.view(-1, n_channels, *image_size))
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        avg_train_loss = np.mean(train_losses)
        log_dict['training_loss_per_epoch'].append(avg_train_loss)

        # --- Validation ---
        model.eval()
        val_losses = []
        with torch.no_grad():
            for val_images in tqdm(val_loader, desc="Validating"):
                val_images = val_images[0][:, 0:n_channels, ...].to(device)

                # Forward pass
                output = model(val_images)
                val_loss = loss_function(output, val_images.view(-1, n_channels, *image_size))
                val_losses.append(val_loss.item())

        avg_val_loss = np.mean(val_losses)
        log_dict['validation_loss_per_epoch'].append(avg_val_loss)

        print(f"Epoch {epoch+1}: Training Loss: {avg_train_loss:.4f} | Validation Loss: {avg_val_loss:.4f}")

        if scheduler:
            scheduler.step(avg_val_loss)

    # Save model and logs
    save_model(model, save_path, weights_name, epochs)
    save_log(log_dict, save_path, weights_name, epochs)

    return log_dict


def train_individual_autoencoder(model, optimizer, loss_function, train_loader, val_loader,
                                 device, epochs=10, image_size=(64, 64), channel_index=0,
                                 scheduler=None, noise_fn=None, initial_max_noise=0.16,
                                 n_reduce_factor=0.8, reduce_on=1000, save_path=None, weights_name=None):
    """Train an autoencoder on just one channel at a time with optional noise injection."""

    log_dict = {'training_loss_per_epoch': [], 'validation_loss_per_epoch': []}

    model.to(device)

    max_noise = initial_max_noise  # Initial noise level

    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")

        # Reduce noise periodically
        if (epoch + 1) % reduce_on == 0:
            max_noise *= n_reduce_factor

        # --- Training ---
        model.train()
        train_losses = []
        for images in tqdm(train_loader, desc="Training"):
            optimizer.zero_grad()
            images = images[0][:, channel_index:channel_index+1, ...].to(device)

            # Apply noise if a noise function is provided
            noisy_images = noise_fn(images, max_val=max_noise) if noise_fn else images

            # Forward pass
            output = model(noisy_images)
            loss = loss_function(output, images.view(-1, 1, *image_size))
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        avg_train_loss = np.mean(train_losses)
        log_dict['training_loss_per_epoch'].append(avg_train_loss)

        # --- Validation ---
        model.eval()
        val_losses = []
        with torch.no_grad():
            for val_images in tqdm(val_loader, desc="Validating"):
                val_images = val_images[0][:, channel_index:channel_index+1, ...].to(device)

                # Forward pass
                output = model(val_images)
                val_loss = loss_function(output, val_images.view(-1, 1, *image_size))
                val_losses.append(val_loss.item())

        avg_val_loss = np.mean(val_losses)
        log_dict['validation_loss_per_epoch'].append(avg_val_loss)

        print(f"Epoch {epoch+1}: Training Loss: {avg_train_loss:.4f} | Validation Loss: {avg_val_loss:.4f}")

        if scheduler:
            scheduler.step(avg_val_loss)

    # Save model and logs
    save_model(model, save_path, weights_name, epochs)
    save_log(log_dict, save_path, weights_name, epochs)

    return log_dict
```

Review comment (chengxinlun, on train_individual_autoencoder): Same as before. Please add functionalities for resuming training from checkpoint.
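Both training functions expect loaders that yield tuples whose first element is a (batch, channels, height, width) tensor, which is why the loop indexes `images[0]` before slicing channels. The sketch below illustrates that contract with synthetic data and a trivial stand-in model; everything here (the tiny conv "autoencoder", the data sizes) is illustrative, not part of the PR:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Synthetic images; TensorDataset yields (tensor,) tuples, matching the
# images[0] indexing used in train_autoencoder.
images = torch.rand(8, 3, 64, 64)
loader = DataLoader(TensorDataset(images), batch_size=4)

# A trivial one-layer stand-in so the loop runs standalone; the real
# pipeline would pass the Autoencoder defined in autoencoder.py.
model = nn.Sequential(nn.Conv2d(3, 3, kernel_size=3, padding=1))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

model.train()
losses = []
for batch in loader:
    optimizer.zero_grad()
    x = batch[0][:, 0:3, ...]              # same channel slicing as the PR's loop
    out = model(x)                         # no noise_fn here: noise injection is optional
    loss = loss_fn(out, x.view(-1, 3, 64, 64))
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```

With 8 images and a batch size of 4, the loop records one loss per batch, i.e. two entries per epoch.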
utils.py (new file, +51 lines)

```python
import torch
import torch.nn as nn
import json


def save_model(model, save_path, weights_name, epochs):
    """Save model weights to file."""
    if save_path:
        save_file = f"{save_path}/{weights_name}_{epochs}.pth"
        torch.save(model.state_dict(), save_file)
        print(f"Model weights saved to {save_file}")


def save_log(log_dict, save_path, weights_name, epochs):
    """Save training logs as JSON."""
    if save_path:
        log_file = f"{save_path}/{weights_name}_{epochs}.json"
        with open(log_file, 'w') as f:
            json.dump(log_dict, f)


def add_random_noise(images, min_val=0.0, max_val=0.1):
    """
    Add random (uniform) noise to the images.

    Parameters:
        images: Tensor of input images.
        min_val: Minimum value of the noise.
        max_val: Maximum value of the noise.

    Returns:
        Noisy images.
    """
    noise = torch.rand_like(images) * (max_val - min_val) + min_val
    noisy_images = images + noise
    return torch.clamp(noisy_images, 0.0, 1.0)  # Keep pixel values in [0, 1]


class LpLoss(nn.Module):
    def __init__(self, p=10, reduction='mean'):
        super().__init__()
        self.p = p
        if reduction not in ('none', 'mean', 'sum'):
            raise ValueError(f"Invalid reduction mode: {reduction}")
        self.reduction = reduction

    def forward(self, input, target):
        diff = torch.abs(input - target) ** self.p
        loss = torch.sum(diff, dim=tuple(range(1, diff.ndim))) ** (1 / self.p)  # norm per sample

        if self.reduction == 'mean':
            return torch.mean(loss)
        elif self.reduction == 'sum':
            return torch.sum(loss)
        else:  # 'none'
            return loss
```
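Two properties of these helpers are worth noting: with the default `min_val=0.0`, the uniform noise from `add_random_noise` is non-negative, so it only shifts pixel values upward before clamping to [0, 1]; and `LpLoss` with a large `p` computes a per-sample p-norm that approaches the maximum absolute error, so identical tensors give exactly zero loss. A self-contained check (the helpers are re-declared inline so the snippet runs on its own):

```python
import torch
import torch.nn as nn

# Inline copies of the helpers above for a standalone check.
def add_random_noise(images, min_val=0.0, max_val=0.1):
    noise = torch.rand_like(images) * (max_val - min_val) + min_val
    return torch.clamp(images + noise, 0.0, 1.0)

class LpLoss(nn.Module):
    def __init__(self, p=10, reduction='mean'):
        super().__init__()
        self.p = p
        self.reduction = reduction

    def forward(self, input, target):
        diff = torch.abs(input - target) ** self.p
        # p-norm over all non-batch dimensions, one value per sample
        loss = torch.sum(diff, dim=tuple(range(1, diff.ndim))) ** (1 / self.p)
        return torch.mean(loss) if self.reduction == 'mean' else loss

x = torch.rand(4, 3, 8, 8)
noisy = add_random_noise(x, max_val=0.16)   # the initial_max_noise used in train.py

crit = LpLoss(p=10)
zero_loss = crit(x, x)   # identical input and target: the norm is exactly 0
```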
Test for LatentPARC (new file, +28 lines)

```python
from PARCtorch.LatentPARC import LatentPARC
from PARCtorch.model import PARC
from PARCtorch.utilities.autoencoder import Encoder, Decoder
import torch


def test_latentparc_default():
    # This test requires a GPU to run
    model = LatentPARC().cuda()
    assert isinstance(model, PARC)
    assert isinstance(model.encoder, Encoder)
    assert isinstance(model.decoder, Decoder)
    assert isinstance(model.differentiator, torch.nn.Module)
    assert isinstance(model.integrator, torch.nn.Module)
    # freeze_encoder_decoder functionality
    model.freeze_encoder_decoder()
    for p in model.encoder.parameters():
        assert p.requires_grad is False
    for p in model.decoder.parameters():
        assert p.requires_grad is False
    # Forward pass
    # TODO: complete upon submission of draft to avoid leaking implementation details
    ic = torch.randn(4, 5, 128, 256, dtype=torch.float32, device="cuda")
    t0 = torch.tensor(0.0, dtype=torch.float32, device="cuda")
    t1 = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32, device="cuda")
    gt = torch.randn(3, 4, 5, 128, 256, dtype=torch.float32, device="cuda")
    # Backward pass
    # TODO: complete upon submission of draft to avoid leaking implementation details
```
Review comment: Please add functionalities for resuming training from checkpoint.