diff --git a/examples/burgers1d.yml b/examples/burgers1d.yml index dc43164..4b6c821 100644 --- a/examples/burgers1d.yml +++ b/examples/burgers1d.yml @@ -1,12 +1,12 @@ lasdi: type: gplasdi gplasdi: - # device: mps + device: cuda n_samples: 20 lr: 0.001 max_iter: 28000 n_iter: 2000 - max_greedy_iter: 28000 + max_greedy_iter: 28000 ld_weight: 0.1 coef_weight: 1.e-6 path_checkpoint: checkpoint @@ -67,12 +67,21 @@ latent_space: ae: hidden_units: [100] latent_dimension: 5 + activation: softplus latent_dynamics: type: sindy sindy: fd_type: sbp12 coef_norm_order: fro + higher_order_terms: 1 + extra_functions: [] + type: edmd + edmd: + fd_type: sbp12 + coef_norm_order: fro + higher_order_terms: 0 + extra_functions: [] physics: type: burgers1d diff --git a/examples/train_model1.py b/examples/train_model1.py new file mode 100644 index 0000000..4309cde --- /dev/null +++ b/examples/train_model1.py @@ -0,0 +1,402 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +import sys + +sys.path.append('/p/gpfs1/khurana1/test/GPLaSDI') +import numpy as np +import torch +import time +from src.lasdi.physics import Physics +import os +import matplotlib.pyplot as plt +import torch.nn as nn +import pickle +from src.lasdi.FNO.fno import FNO +from src.lasdi.gp import eval_gp + +torch.manual_seed(0) +np.random.seed(0) # seed NumPy's global RNG; np.random.rand(0) only returned an empty array + +#For movies + +# import matplotlib.animation as manimation +# FFMpegWriter = manimation.writers['ffmpeg'] +# metadata = dict(title='Movie Test', artist='Matplotlib', +# comment='Blast') +# writer = FFMpegWriter(fps=15) + +# name = 'steelruns' + +date = time.localtime(time.time() + 3 * 3600) # apply the +3 h offset before splitting into fields so hour/day roll over +date_str = "{month:02d}_{day:02d}_{year:04d}_{hour:02d}_{minute:02d}" +date_str = date_str.format(month = date.tm_mon, day = date.tm_mday, year = date.tm_year, hour = date.tm_hour, minute = date.tm_min) + +name = date_str + +#Initialize training params +dt = 0.001 +ae_weight = 1e2 +sindy_weight = 1.e-1 +coef_weight= 1.e-6 +lr = 1e-1 + +path_data = 'data/' + + +path_data = 'data/' +data_train = 
np.load(path_data + 'data_train.npy', allow_pickle = True).item() +data_test = np.load(path_data + 'data_test.npy', allow_pickle = True).item() + +# These can be any data, as long as they are in the form [num_simulations, num_timesteps, num_spatial_nodes] +X_train = torch.Tensor(data_train['X_train']) +X_test = (data_test['X_test']) +param_train = data_train['param_train'] +param_test = data_test['param_grid'] + +sol_dim = 1 +grid_dim = [data_test['param_grid'].shape[0]] +time_dim = X_train.shape[1] +space_dim = X_train.shape[-1] + + +t_grid = np.linspace(0, (time_dim - 1)*dt, time_dim) + +#This physics class will also be able to train with any data. +#However, we will not be able to adaptively sample/generate new data +class CustomPhysicsModel(Physics): + def __init__(self): + self.dim = 1 + self.nt = time_dim + self.dt = dt + self.qdim = sol_dim + self.qgrid_size = [space_dim] + self.t_grid = t_grid + return + + ''' See lasdi.physics.Physics class for necessary subroutines ''' + +physics = CustomPhysicsModel() + +# Autoencoder Definition +hidden_units = [100] + +#latent dim (number of modes to keep) +n_z = 16 +n_layers = 3 +n_iter = 1000 + +# Tranposing the train data to match the shape for 1d fno! 
+model = FNO(n_modes=(1001,n_z), + in_channels=1, + out_channels=1, + hidden_channels=1, + projection_channel_ratio=2, use_channel_mlp=False, + positional_embedding=None, physics = physics, n_layers = n_layers, test_flag=False) + +best_loss = np.inf + +optimizer = torch.optim.Adam(model.parameters(), lr = lr) +MSE = torch.nn.MSELoss() + +cuda = torch.cuda.is_available() +if cuda: + device = 'cuda' +else: + device = 'cpu' + +# device = 'mps' + +print(device) + +model = model.to(device) + +# Adding extra dimension to the X_train +X_train = X_train.unsqueeze(1) +d_Xtrain = X_train.to(device) + +n_train = X_train.shape[0] +save_interval = 1 +hist_file = '%s.loss_history.txt' % name + +loss_hist = np.zeros([n_iter, 4]) +grad_hist = np.zeros([n_iter, 4]) + +tic_start = time.time() +max_coef = np.zeros((n_iter, 2)) +total_loss = [] +loss_ae_list = [] +loss_sindy_list = [] +loss_sindy_coefs_list = [] + +for iter in range(n_iter): + optimizer.zero_grad() + # model = ae.to(device) + d_Xpred = model(d_Xtrain) + + loss_ae = MSE(d_Xtrain, d_Xpred) #Reconstruction loss + + max_coef[iter][0] = torch.max(torch.abs(model.coefs[0])) + + loss = ae_weight * loss_ae + sindy_weight * sum(model.sindy_loss) / n_train + coef_weight * sum(model.loss_coefs) / n_train + + # Append to the lists for plotting purposes! + # Append to list for plotting! 
+ total_loss.append(loss.detach().cpu().numpy()) + loss_sindy_list.append(sindy_weight * sum(model.sindy_loss).detach().cpu().numpy() / n_train) + loss_ae_list.append(ae_weight * loss_ae.detach().cpu().numpy()) + loss_sindy_coefs_list.append(coef_weight * sum(model.loss_coefs).detach().cpu().numpy() / n_train) + + loss_hist[iter] = [loss.item(), loss_ae.item(), sum(model.sindy_loss).item(), sum(model.loss_coefs).item()] + + loss.backward() + + optimizer.step() + + if ((loss.item() < best_loss) and (iter % save_interval == 0)): + os.makedirs(os.path.dirname('checkpoint/' + './%s_checkpoint.pt' % name), exist_ok=True) + torch.save(model.cpu().state_dict(), 'checkpoint/%s_checkpoint.pt' % name) + model = model.to(device) + best_loss = loss.item() + best_coefs = model.coefs + + print("Iter: %05d/%d, Loss: %.5e, Loss AE: %.5e, Loss SI: %.5e, Loss COEF: %.5e" + % (iter + 1, n_iter, loss.item(), loss_ae.item(), sum(model.sindy_loss).item(), sum(model.loss_coefs).item())) + +tic_end = time.time() +total_time = tic_end - tic_start + +## save results + +os.makedirs(os.path.dirname('losses/' + './%s_checkpoint.pt' % name), exist_ok=True) +np.savetxt('losses/%s.loss_history.txt' % name, loss_hist) + +if (loss.item() < best_loss): + torch.save(model.cpu().state_dict(), 'checkpoint/%s_checkpoint.pt' % name) + best_loss = loss.item() +else: + model.cpu().load_state_dict(torch.load('checkpoint/%s_checkpoint.pt' % name, weights_only=False)) + +bglasdi_results = {'autoencoder_param': model.cpu().state_dict(), 'final_param_train': param_train, + 'lr': lr, 'n_iter': n_iter, + 'sindy_weight': sindy_weight, 'coef_weight': coef_weight, + 't_grid' : t_grid, 'dt' : dt,'total_time' : total_time} + +os.makedirs(os.path.dirname('results/'), exist_ok=True) +np.save('results/bglasdi_' + date_str + '_lr' + str(lr) + '_sw' + str(sindy_weight) + '_cw'+str(coef_weight)+'_nt' + str(time_dim) + '_niter' + str(n_iter) + '_nz' + str(n_z) +'.npy', bglasdi_results) + + +## Plotting the loss and stuff! 
+os.makedirs('plots', exist_ok=True) # savefig does not create missing directories +plt.figure() +plt.plot(max_coef[:,0])#, label = 'Real weight') +# plt.plot(max_coef[:,1], label = 'Imaginary weight') +plt.title('Max SINDY coefficient') +plt.ylabel('Magnitude') +plt.xlabel('iterations') +plt.grid() +plt.savefig('plots/max_weight.png') + +plt.figure(figsize=(12, 8)) +plt.suptitle('Training Loss v/s Iterations (LOG SCALE)') + +plt.subplot(2, 2, 1) +plt.plot(total_loss) +plt.xlabel('Iterations') +plt.ylabel('Total Loss') +plt.grid() +plt.yscale('log') + +plt.subplot(2, 2, 2) +plt.plot(loss_ae_list) +plt.xlabel('Iterations') +plt.ylabel('Loss AE') +plt.grid() +plt.yscale('log') + +plt.subplot(2, 2, 3) +plt.plot(loss_sindy_list) +plt.xlabel('Iterations') +plt.ylabel('Sindy loss') +plt.grid() +plt.yscale('log') + +plt.subplot(2, 2, 4) +plt.plot(loss_sindy_coefs_list) +plt.xlabel('Iterations') +plt.ylabel('Sindy coefficient loss') +plt.grid() +plt.yscale('log') + +plt.savefig('plots/loss.png') +# %% Plotting + +# autoencoder_param = model.cpu().state_dict() +# AEparams = [value for value in autoencoder_param.items()] +# AEparamnum = 0 +# for i in range(len(AEparams)): +# AEparamnum = AEparamnum + (AEparams[i][1].detach().numpy()).size + +from lasdi.gp import fit_gps +from lasdi.gplasdi import sample_roms, average_rom +import numpy.linalg as LA + +def IC(param): + return param[0] * np.exp(- np.linspace(-3, 3, 1001) ** 2 / 2 / param[1] / param[1]) + +#Need to know IC explicitly. If you only have IC which corresponds to certain param value, +#then you can do something like + +# def IC(param): +# for i in range(X_test.shape[0]): +# if np.abs(param[0] - all_param[i]) <1E-8: +# initcond = X_test[i,0,:] +# break +# return initcond + +physics.initial_condition = IC +model = model.cpu() + +# Get the model predictions! +""" +Deviating from the gausian process here! GP Stuff needs more thought! +""" + +# Get the initial conditions for the entire param grid! 
+n_param = param_test.shape[0] +z0 = np.zeros((n_param,1,1,1001), dtype=np.float32) + +sol_shape = [1, 1] + physics.qgrid_size + +for i in range(n_param): + u0 = physics.initial_condition(param_test[i]) + u0 = u0.reshape(sol_shape) + z0[i] = u0 + +# Interpolate the sindy coefficients! +gp_dictionnary_real = fit_gps(param_train, best_coefs[0]) +# gp_dictionnary_imag = fit_gps(param_train, best_coefs[1]) +pred_mean_real, _ = eval_gp(gp_dictionnary_real, param_test) +# pred_mean_imag,_ = eval_gp(gp_dictionnary_imag, param_test) + +### Plotting the latent space trajectories! +test_param = X_train[0] +test_coefs = [pred_mean_real[0], 0] + +# Plotting model! +plot_model = FNO(n_modes=(1001,n_z), + in_channels=1, + out_channels=1, + hidden_channels=1, + projection_channel_ratio=2, use_channel_mlp=False, + positional_embedding=None, physics = physics, n_layers = n_layers, + test_flag=False, best_coefs = test_coefs, plot_latent=True) +# Load the weights into the model! +plot_model.cpu().load_state_dict(torch.load('checkpoint/%s_checkpoint.pt' % name, weights_only=False)) + +# Get the plots! +plot_model(torch.unsqueeze(test_param,0)) + +#### Getting the entire testing! +# Get the pred_stuff! +X_pred_mean = np.zeros(X_test.shape) + +for i in range(param_test.shape[0]): + + # Get the interpolated coeff: + coefs = [pred_mean_real[i],0]# pred_mean_imag[i]] + + # Defining the test model! + test_model = FNO(n_modes=(1001,n_z), + in_channels=1, + out_channels=1, + hidden_channels=1, + projection_channel_ratio=2, use_channel_mlp=False, + positional_embedding=None, physics = physics, n_layers = n_layers, + test_flag=True, best_coefs=coefs, plot_latent=False) + + # Load the weights into the model! + test_model.cpu().load_state_dict(torch.load('checkpoint/%s_checkpoint.pt' % name, weights_only=False)) + + # Get the prediction! 
+ X_pred_mean[i] = test_model(torch.unsqueeze(torch.from_numpy(z0[i]),0)).detach().numpy() + + +# %% + +#Plot a single value + +param_ind = 2 +param = np.array([param_test[param_ind]]) + + +true = X_test[param_ind,:,:] + +## Heatmap of errors + +param_grid = param_test +avg_rel_err = LA.norm(X_pred_mean - X_test,axis=2)/LA.norm(X_test,axis=2) +max_rel_err = np.max(avg_rel_err, axis = 1) + +figsize=(12, 12) + +n_a_grid = 21 +n_w_grid = 21 + +a_grid = param_grid[:21,0] +w_grid = param_grid[::21,1] + + +n_p1 = n_a_grid +n_p2 = n_w_grid +p1_grid = param_grid[:21,0] +p2_grid = param_grid[::21,1] + +fig, ax = plt.subplots(1, 1, figsize = figsize) +values = max_rel_err.T.reshape(21,21)*100 + +n_init = len(param_train) + + +from matplotlib.colors import LinearSegmentedColormap +cmap = LinearSegmentedColormap.from_list('rg', ['C0', 'w', 'C3'], N = 256) + +im = ax.imshow(values, cmap = cmap) +fig.colorbar(im, ax = ax, fraction = 0.04) + +ax.set_xticks(np.arange(0, n_a_grid, 2), labels=np.round(a_grid[::2], 2)) +ax.set_yticks(np.arange(0, n_w_grid, 2), labels=np.round(w_grid[::2], 2)) + +for i in range(n_p1): + for j in range(n_p2): + ax.text(j, i, round(values[i, j], 1), ha='center', va='center', color='k') + +grid_square_x = np.arange(-0.5, n_p1, 1) +grid_square_y = np.arange(-0.5, n_p2, 1) + +n_train = param_train.shape[0] +for i in range(n_train): + p1_index = np.sum((p1_grid < param_train[i, 0]) * 1) + p2_index = np.sum((p2_grid < param_train[i, 1]) * 1) + + if i < n_init: + color = 'r' + else: + color = 'k' + + ax.plot([grid_square_x[p1_index], grid_square_x[p1_index]], [grid_square_y[p2_index], grid_square_y[p2_index] + 1], + c=color, linewidth=2) + ax.plot([grid_square_x[p1_index] + 1, grid_square_x[p1_index] + 1], + [grid_square_y[p2_index], grid_square_y[p2_index] + 1], c=color, linewidth=2) + ax.plot([grid_square_x[p1_index], grid_square_x[p1_index] + 1], [grid_square_y[p2_index], grid_square_y[p2_index]], + c=color, linewidth=2) + ax.plot([grid_square_x[p1_index], 
grid_square_x[p1_index] + 1], + [grid_square_y[p2_index] + 1, grid_square_y[p2_index] + 1], c=color, linewidth=2) + +ax.set_xlabel('$a$', fontsize=25) +ax.set_ylabel('$w$', fontsize=25) +ax.set_title('Relative Error (%)', fontsize=30) +plt.savefig('plots/relative_error.png') +plt.show() + +# Print the total time! +print(f'Total Training Time: {total_time}') diff --git a/src/lasdi/FNO/fno.py b/src/lasdi/FNO/fno.py new file mode 100644 index 0000000..cd29b1b --- /dev/null +++ b/src/lasdi/FNO/fno.py @@ -0,0 +1,654 @@ +from functools import partialmethod +from typing import Tuple, List, Union, Literal + +Number = Union[float, int] + +import torch +import torch.nn as nn +import torch.nn.functional as F +#from ..layers.embeddings import GridEmbeddingND, GridEmbedding2D +from .spectral_convolution import SpectralConv +#from ..layers.padding import DomainPadding +from .fno_block import FNOBlocks +from .channel_mlp import ChannelMLP +from .base_model import BaseModel + +class FNO(BaseModel, name='FNO'): + """N-Dimensional Fourier Neural Operator. The FNO learns a mapping between + spaces of functions discretized over regular grids using Fourier convolutions, + as described in [1]_. + + The key component of an FNO is its SpectralConv layer (see + ``neuralop.layers.spectral_convolution``), which is similar to a standard CNN + conv layer but operates in the frequency domain. + + For a deeper dive into the FNO architecture, refer to :ref:`fno_intro`. + + Parameters + ---------- + n_modes : Tuple[int] + number of modes to keep in Fourier Layer, along each dimension + The dimensionality of the FNO is inferred from ``len(n_modes)`` + in_channels : int + Number of channels in input function + out_channels : int + Number of channels in output function + hidden_channels : int + width of the FNO (i.e. number of channels) + n_layers : int, optional + Number of Fourier Layers, by default 4 + + Documentation for more advanced parameters is below. 
+ + Other parameters + ------------------ + lifting_channel_ratio : int, optional + ratio of lifting channels to hidden_channels, by default 2 + The number of lifting channels in the lifting block of the FNO is + lifting_channel_ratio * hidden_channels (e.g. default 2 * hidden_channels) + projection_channel_ratio : int, optional + ratio of projection channels to hidden_channels, by default 2 + The number of projection channels in the projection block of the FNO is + projection_channel_ratio * hidden_channels (e.g. default 2 * hidden_channels) + positional_embedding : Union[str, nn.Module], optional + Positional embedding to apply to last channels of raw input + before being passed through the FNO. Defaults to "grid" + + * If "grid", appends a grid positional embedding with default settings to + the last channels of raw input. Assumes the inputs are discretized + over a grid with entry [0,0,...] at the origin and side lengths of 1. + + * If an initialized GridEmbedding module, uses this module directly + See :mod:`neuralop.embeddings.GridEmbeddingND` for details. + + * If None, does nothing + + non_linearity : nn.Module, optional + Non-Linear activation function module to use, by default F.gelu + norm : Literal ["ada_in", "group_norm", "instance_norm"], optional + Normalization layer to use, by default None + complex_data : bool, optional + Whether data is complex-valued (default False) + if True, initializes complex-valued modules. 
+ use_channel_mlp : bool, optional + Whether to use an MLP layer after each FNO block, by default True + channel_mlp_dropout : float, optional + dropout parameter for ChannelMLP in FNO Block, by default 0 + channel_mlp_expansion : float, optional + expansion parameter for ChannelMLP in FNO Block, by default 0.5 + channel_mlp_skip : Literal['linear', 'identity', 'soft-gating'], optional + Type of skip connection to use in channel-mixing mlp, by default 'soft-gating' + fno_skip : Literal['linear', 'identity', 'soft-gating'], optional + Type of skip connection to use in FNO layers, by default 'linear' + resolution_scaling_factor : Union[Number, List[Number]], optional + layer-wise factor by which to scale the domain resolution of function, by default None + + * If a single number n, scales resolution by n at each layer + + * If a list of numbers [n_0, n_1,...], scales layer i's resolution by n_i. + domain_padding : Union[Number, List[Number]], optional + If not None, percentage of padding to use, by default None + To vary the percentage of padding used along each input dimension, + pass in a list of percentages e.g. [p1, p2, ..., pN] such that + p1 corresponds to the percentage of padding along dim 1, etc. + domain_padding_mode : Literal ['symmetric', 'one-sided'], optional + How to perform domain padding, by default 'symmetric' + fno_block_precision : str {'full', 'half', 'mixed'}, optional + precision mode in which to perform spectral convolution, by default "full" + stabilizer : str {'tanh'} | None, optional + whether to use a tanh stabilizer in FNO block, by default None + + Note: stabilizer greatly improves performance in the case + `fno_block_precision='mixed'`. + + max_n_modes : Tuple[int] | None, optional + + * If not None, this allows incrementally increasing the number of + modes in Fourier domain during training. Must satisfy n <= N + for (n, N) in zip(n_modes, max_n_modes). + + * If None, all the n_modes are used. 
+ + This can be updated dynamically during training. + factorization : str, optional + Tensor factorization of the FNO layer weights to use, by default None. + + * If None, a dense tensor parametrizes the Spectral convolutions + + * Otherwise, the specified tensor factorization is used. + rank : float, optional + tensor rank to use in above factorization, by default 1.0 + fixed_rank_modes : bool, optional + Modes to not factorize, by default False + implementation : str {'factorized', 'reconstructed'}, optional + + * If 'factorized', implements tensor contraction with the individual factors of the decomposition + + * If 'reconstructed', implements with the reconstructed full tensorized weight. + decomposition_kwargs : dict, optional + extra kwargs for tensor decomposition (see `tltorch.FactorizedTensor`), by default dict() + separable : bool, optional (**DEACTIVATED**) + if True, use a depthwise separable spectral convolution, by default False + preactivation : bool, optional (**DEACTIVATED**) + whether to compute FNO forward pass with resnet-style preactivation, by default False + conv_module : nn.Module, optional + module to use for FNOBlock's convolutions, by default SpectralConv + + Examples + --------- + + >>> from neuralop.models import FNO + >>> model = FNO(n_modes=(12,12), in_channels=1, out_channels=1, hidden_channels=64) + >>> model + FNO( + (positional_embedding): GridEmbeddingND() + (fno_blocks): FNOBlocks( + (convs): SpectralConv( + (weight): ModuleList( + (0-3): 4 x DenseTensor(shape=torch.Size([64, 64, 12, 7]), rank=None) + ) + ) + ... torch.nn.Module printout truncated ... + + References + ----------- + .. [1] : + + Li, Z. et al. "Fourier Neural Operator for Parametric Partial Differential + Equations" (2021). ICLR 2021, https://arxiv.org/pdf/2010.08895. 
+ + """ + + def __init__( + self, + n_modes: Tuple[int], + in_channels: int, + out_channels: int, + hidden_channels: int, + n_layers: int=4, + lifting_channel_ratio: Number=2, + projection_channel_ratio: Number=2, + positional_embedding: Union[str, nn.Module]="grid", + non_linearity: nn.Module=F.gelu, + norm: Literal ["ada_in", "group_norm", "instance_norm"]=None, + complex_data: bool=False, + use_channel_mlp: bool=True, + channel_mlp_dropout: float=0, + channel_mlp_expansion: float=0.5, + channel_mlp_skip: Literal['linear', 'identity', 'soft-gating']="soft-gating", + fno_skip: Literal['linear', 'identity', 'soft-gating']="linear", + resolution_scaling_factor: Union[Number, List[Number]]=None, + domain_padding: Union[Number, List[Number]]=None, + domain_padding_mode: Literal['symmetric', 'one-sided']="symmetric", + fno_block_precision: str="full", + stabilizer: str=None, + max_n_modes: Tuple[int]=None, + factorization: str=None, + rank: float=1.0, + fixed_rank_modes: bool=False, + implementation: str="factorized", + decomposition_kwargs: dict=dict(), + separable: bool=False, + preactivation: bool=False, + conv_module: nn.Module=SpectralConv, + physics = None, + test_flag = False, + best_coefs = None, + plot_latent = False, + **kwargs + ): + + super().__init__() + if type(n_modes) == int: + self.n_dim = 1 + else: + self.n_dim = len(n_modes) + + # n_modes is a special property - see the class' property for underlying mechanism + # When updated, change should be reflected in fno blocks + self._n_modes = n_modes + + self.hidden_channels = hidden_channels + self.in_channels = in_channels + self.out_channels = out_channels + self.n_layers = n_layers + + # init lifting and projection channels using ratios w.r.t hidden channels + self.lifting_channel_ratio = lifting_channel_ratio + self.lifting_channels = int(lifting_channel_ratio * self.hidden_channels) + + self.projection_channel_ratio = projection_channel_ratio + self.projection_channels = int(projection_channel_ratio * 
self.hidden_channels) + + self.non_linearity = non_linearity + self.rank = rank + self.factorization = factorization + self.fixed_rank_modes = fixed_rank_modes + self.decomposition_kwargs = decomposition_kwargs + self.fno_skip = (fno_skip,) + self.channel_mlp_skip = (channel_mlp_skip,) + self.implementation = implementation + self.separable = separable + self.preactivation = preactivation + self.fno_block_precision = fno_block_precision + + # Adding terms for the latent dynamics losses! + self.sindy_loss = None + self.coefs = None + self.loss_coefs = None + + + if domain_padding is not None and ( + (isinstance(domain_padding, list) and sum(domain_padding) > 0) + or (isinstance(domain_padding, (float, int)) and domain_padding > 0) + ): + self.domain_padding = DomainPadding( + domain_padding=domain_padding, + padding_mode=domain_padding_mode, + resolution_scaling_factor=resolution_scaling_factor, + ) + else: + self.domain_padding = None + self.positional_embedding = None + self.domain_padding_mode = domain_padding_mode + + if resolution_scaling_factor is not None: + if isinstance(resolution_scaling_factor, (float, int)): + resolution_scaling_factor = [resolution_scaling_factor] * self.n_layers + self.resolution_scaling_factor = resolution_scaling_factor + + self.fno_blocks = FNOBlocks( + in_channels=hidden_channels, + out_channels=hidden_channels, + n_modes=self.n_modes, + physics = physics, + resolution_scaling_factor=resolution_scaling_factor, + use_channel_mlp=use_channel_mlp, + channel_mlp_dropout=channel_mlp_dropout, + channel_mlp_expansion=channel_mlp_expansion, + non_linearity=non_linearity, + stabilizer=stabilizer, + norm=norm, + preactivation=preactivation, + fno_skip=fno_skip, + channel_mlp_skip=channel_mlp_skip, + complex_data=complex_data, + max_n_modes=max_n_modes, + fno_block_precision=fno_block_precision, + rank=rank, + fixed_rank_modes=fixed_rank_modes, + implementation=implementation, + separable=separable, + factorization=factorization, + 
decomposition_kwargs=decomposition_kwargs, + conv_module=conv_module, + n_layers=n_layers, + test_flag = test_flag, + best_coefs = best_coefs, + plot_latent = plot_latent, + **kwargs + ) + + # if adding a positional embedding, add those channels to lifting + lifting_in_channels = self.in_channels + if self.positional_embedding is not None: + lifting_in_channels += self.n_dim + # if lifting_channels is passed, make lifting a Channel-Mixing MLP + # with a hidden layer of size lifting_channels + if self.lifting_channels: + self.lifting = ChannelMLP( + in_channels=lifting_in_channels, + out_channels=self.hidden_channels, + hidden_channels=self.lifting_channels, + n_layers=2, + n_dim=self.n_dim, + non_linearity=non_linearity + ) + # otherwise, make it a linear layer + else: + self.lifting = ChannelMLP( + in_channels=lifting_in_channels, + hidden_channels=self.hidden_channels, + out_channels=self.hidden_channels, + n_layers=1, + n_dim=self.n_dim, + non_linearity=non_linearity + ) + + self.projection = ChannelMLP( + in_channels=self.hidden_channels, + out_channels=out_channels, + hidden_channels=self.projection_channels, + n_layers=2, + n_dim=self.n_dim, + non_linearity=non_linearity, + ) + + def forward(self, x, output_shape=None, **kwargs): + """FNO's forward pass + + 1. Applies optional positional encoding + + 2. Sends inputs through a lifting layer to a high-dimensional latent space + + 3. Applies optional domain padding to high-dimensional intermediate function representation + + 4. Applies `n_layers` Fourier/FNO layers in sequence (SpectralConvolution + skip connections, nonlinearity) + + 5. If domain padding was applied, domain padding is removed + + 6. Projection of intermediate function representation to the output channels + + Parameters + ---------- + x : tensor + input tensor + + output_shape : {tuple, tuple list, None}, default is None + Gives the option of specifying the exact output shape for odd shaped inputs. 
+ + * If None, don't specify an output shape + + * If tuple, specifies the output-shape of the **last** FNO Block + + * If tuple list, specifies the exact output-shape of each FNO Block + """ + + if output_shape is None: + output_shape = [None]*self.n_layers + elif isinstance(output_shape, tuple): + output_shape = [None]*(self.n_layers - 1) + [output_shape] + + # append spatial pos embedding if set + if self.positional_embedding is not None: + x = self.positional_embedding(x) + + # x = self.lifting(x) + + # if self.domain_padding is not None: + # x = self.domain_padding.pad(x) + for layer_idx in range(self.n_layers): + x = self.fno_blocks(x, layer_idx, output_shape=output_shape[layer_idx]) + + # Append the loss terms for the latent dynamics! + self.sindy_loss = self.fno_blocks.sindy_loss + self.coefs = self.fno_blocks.coefs + self.loss_coefs = self.fno_blocks.loss_coef + # if self.domain_padding is not None: + # x = self.domain_padding.unpad(x) + + # x = self.projection(x) + + return x + + @property + def n_modes(self): + return self._n_modes + + @n_modes.setter + def n_modes(self, n_modes): + self.fno_blocks.n_modes = n_modes + self._n_modes = n_modes + + +class FNO1d(FNO): + """1D Fourier Neural Operator + + For the full list of parameters, see :class:`neuralop.models.FNO`. 
+ + Parameters + ---------- + n_modes_height : int + number of Fourier modes to keep along the height + """ + + def __init__( + self, + n_modes_height, + hidden_channels, + in_channels=3, + out_channels=1, + lifting_channels=256, + projection_channels=256, + max_n_modes=None, + n_layers=4, + resolution_scaling_factor=None, + non_linearity=F.gelu, + stabilizer=None, + complex_data=False, + fno_block_precision="full", + channel_mlp_dropout=0, + channel_mlp_expansion=0.5, + norm=None, + skip="soft-gating", + separable=False, + preactivation=False, + factorization=None, + rank=1.0, + fixed_rank_modes=False, + implementation="factorized", + decomposition_kwargs=dict(), + domain_padding=None, + domain_padding_mode="symmetric", + **kwargs + ): + super().__init__( + n_modes=(n_modes_height,), + hidden_channels=hidden_channels, + in_channels=in_channels, + out_channels=out_channels, + lifting_channels=lifting_channels, + projection_channels=projection_channels, + n_layers=n_layers, + resolution_scaling_factor=resolution_scaling_factor, + non_linearity=non_linearity, + stabilizer=stabilizer, + complex_data=complex_data, + fno_block_precision=fno_block_precision, + channel_mlp_dropout=channel_mlp_dropout, + channel_mlp_expansion=channel_mlp_expansion, + max_n_modes=max_n_modes, + norm=norm, + skip=skip, + separable=separable, + preactivation=preactivation, + factorization=factorization, + rank=rank, + fixed_rank_modes=fixed_rank_modes, + implementation=implementation, + decomposition_kwargs=decomposition_kwargs, + domain_padding=domain_padding, + domain_padding_mode=domain_padding_mode, + ) + self.n_modes_height = n_modes_height + + +class FNO2d(FNO): + """2D Fourier Neural Operator + + For the full list of parameters, see :class:`neuralop.models.FNO`. 
+ + Parameters + ---------- + n_modes_width : int + number of modes to keep in Fourier Layer, along the width + n_modes_height : int + number of Fourier modes to keep along the height + """ + + def __init__( + self, + n_modes_height, + n_modes_width, + hidden_channels, + in_channels=3, + out_channels=1, + lifting_channels=256, + projection_channels=256, + n_layers=4, + resolution_scaling_factor=None, + max_n_modes=None, + non_linearity=F.gelu, + stabilizer=None, + complex_data=False, + fno_block_precision="full", + channel_mlp_dropout=0, + channel_mlp_expansion=0.5, + norm=None, + skip="soft-gating", + separable=False, + preactivation=False, + factorization=None, + rank=1.0, + fixed_rank_modes=False, + implementation="factorized", + decomposition_kwargs=dict(), + domain_padding=None, + domain_padding_mode="symmetric", + **kwargs + ): + super().__init__( + n_modes=(n_modes_height, n_modes_width), + hidden_channels=hidden_channels, + in_channels=in_channels, + out_channels=out_channels, + lifting_channels=lifting_channels, + projection_channels=projection_channels, + n_layers=n_layers, + resolution_scaling_factor=resolution_scaling_factor, + non_linearity=non_linearity, + stabilizer=stabilizer, + complex_data=complex_data, + fno_block_precision=fno_block_precision, + channel_mlp_dropout=channel_mlp_dropout, + channel_mlp_expansion=channel_mlp_expansion, + max_n_modes=max_n_modes, + norm=norm, + skip=skip, + separable=separable, + preactivation=preactivation, + factorization=factorization, + rank=rank, + fixed_rank_modes=fixed_rank_modes, + implementation=implementation, + decomposition_kwargs=decomposition_kwargs, + domain_padding=domain_padding, + domain_padding_mode=domain_padding_mode, + ) + self.n_modes_height = n_modes_height + self.n_modes_width = n_modes_width + + +class FNO3d(FNO): + """3D Fourier Neural Operator + + For the full list of parameters, see :class:`neuralop.models.FNO`. 
+ + Parameters + ---------- + n_modes_width : int + number of modes to keep in Fourier Layer, along the width + n_modes_height : int + number of Fourier modes to keep along the height + n_modes_depth : int + number of Fourier modes to keep along the depth + """ + + def __init__( + self, + n_modes_height, + n_modes_width, + n_modes_depth, + hidden_channels, + in_channels=3, + out_channels=1, + lifting_channels=256, + projection_channels=256, + n_layers=4, + resolution_scaling_factor=None, + max_n_modes=None, + non_linearity=F.gelu, + stabilizer=None, + complex_data=False, + fno_block_precision="full", + channel_mlp_dropout=0, + channel_mlp_expansion=0.5, + norm=None, + skip="soft-gating", + separable=False, + preactivation=False, + factorization=None, + rank=1.0, + fixed_rank_modes=False, + implementation="factorized", + decomposition_kwargs=dict(), + domain_padding=None, + domain_padding_mode="symmetric", + **kwargs + ): + super().__init__( + n_modes=(n_modes_height, n_modes_width, n_modes_depth), + hidden_channels=hidden_channels, + in_channels=in_channels, + out_channels=out_channels, + lifting_channels=lifting_channels, + projection_channels=projection_channels, + n_layers=n_layers, + resolution_scaling_factor=resolution_scaling_factor, + non_linearity=non_linearity, + stabilizer=stabilizer, + complex_data=complex_data, + fno_block_precision=fno_block_precision, + max_n_modes=max_n_modes, + channel_mlp_dropout=channel_mlp_dropout, + channel_mlp_expansion=channel_mlp_expansion, + norm=norm, + skip=skip, + separable=separable, + preactivation=preactivation, + factorization=factorization, + rank=rank, + fixed_rank_modes=fixed_rank_modes, + implementation=implementation, + decomposition_kwargs=decomposition_kwargs, + domain_padding=domain_padding, + domain_padding_mode=domain_padding_mode, + ) + self.n_modes_height = n_modes_height + self.n_modes_width = n_modes_width + self.n_modes_depth = n_modes_depth + + +def partialclass(new_name, cls, *args, **kwargs): + """Create a 
new class with different default values + + Notes + ----- + An obvious alternative would be to use functools.partial + >>> new_class = partial(cls, **kwargs) + + The issue is twofold: + 1. the class doesn't have a name, so one would have to set it explicitly: + >>> new_class.__name__ = new_name + + 2. the new class will be a functools object and one cannot inherit from it. + + Instead, here, we define dynamically a new class, inheriting from the existing one. + """ + __init__ = partialmethod(cls.__init__, *args, **kwargs) + new_class = type( + new_name, + (cls,), + { + "__init__": __init__, + "__doc__": cls.__doc__, + "forward": cls.forward, + }, + ) + return new_class + + +TFNO = partialclass("TFNO", FNO, factorization="Tucker") +TFNO1d = partialclass("TFNO1d", FNO1d, factorization="Tucker") +TFNO2d = partialclass("TFNO2d", FNO2d, factorization="Tucker") +TFNO3d = partialclass("TFNO3d", FNO3d, factorization="Tucker") diff --git a/src/lasdi/FNO/fno_block.py b/src/lasdi/FNO/fno_block.py new file mode 100644 index 0000000..f0ccefd --- /dev/null +++ b/src/lasdi/FNO/fno_block.py @@ -0,0 +1,415 @@ +from typing import List, Optional, Union + +import torch +from torch import nn +import torch.nn.functional as F + +from .channel_mlp import ChannelMLP +from .skip_connections import skip_connection +from .spectral_convolution import SpectralConv +import time + +Number = Union[int, float] + + +class FNOBlocks(nn.Module): + """FNOBlocks implements a sequence of Fourier layers, the operations of which + are first described in [1]_. The exact implementation details of the Fourier + layer architecture are discussed in [2]_. + + Parameters + ---------- + in_channels : int + input channels to Fourier layers + out_channels : int + output channels after Fourier layers + n_modes : int, List[int] + number of modes to keep along each dimension + in frequency space. 
Can either be specified as + an int (for all dimensions) or an iterable with one + number per dimension + resolution_scaling_factor : Optional[Union[Number, List[Number]]], optional + factor by which to scale outputs for super-resolution, by default None + n_layers : int, optional + number of Fourier layers to apply in sequence, by default 1 + max_n_modes : int, List[int], optional + maximum number of modes to keep along each dimension, by default None + fno_block_precision : str, optional + floating point precision to use for computations, by default "full" + use_channel_mlp : bool, optional + Whether to use an MLP layer after each FNO block, by default True + channel_mlp_dropout : int, optional + dropout parameter for self.channel_mlp, by default 0 + channel_mlp_expansion : float, optional + expansion parameter for self.channel_mlp, by default 0.5 + non_linearity : torch.nn.F module, optional + nonlinear activation function to use between layers, by default F.gelu + stabilizer : Literal["tanh"], optional + stabilizing module to use between certain layers, by default None + if "tanh", use tanh + norm : Literal["ada_in", "group_norm", "instance_norm", "batch_norm"], optional + Normalization layer to use, by default None + ada_in_features : int, optional + number of features for adaptive instance norm above, by default None + preactivation : bool, optional + whether to call forward pass with pre-activation, by default False + if True, call nonlinear activation and norm before Fourier convolution + if False, call activation and norms after Fourier convolutions + fno_skip : str, optional + module to use for FNO skip connections, by default "linear" + see layers.skip_connections for more details + channel_mlp_skip : str, optional + module to use for ChannelMLP skip connections, by default "soft-gating" + see layers.skip_connections for more details + + Other Parameters + ------------------- + complex_data : bool, optional + whether the FNO's data takes on complex 
values in space, by default False + separable : bool, optional + separable parameter for SpectralConv, by default False + factorization : str, optional + factorization parameter for SpectralConv, by default None + rank : float, optional + rank parameter for SpectralConv, by default 1.0 + conv_module : BaseConv, optional + module to use for convolutions in FNO block, by default SpectralConv + joint_factorization : bool, optional + whether to factorize all spectralConv weights as one tensor, by default False + fixed_rank_modes : bool, optional + fixed_rank_modes parameter for SpectralConv, by default False + implementation : str, optional + implementation parameter for SpectralConv, by default "factorized" + decomposition_kwargs : _type_, optional + kwargs for tensor decomposition in SpectralConv, by default dict() + + References + ----------- + .. [1] Li, Z. et al. "Fourier Neural Operator for Parametric Partial Differential + Equations" (2021). ICLR 2021, https://arxiv.org/pdf/2010.08895. + .. [2] Kossaifi, J., Kovachki, N., Azizzadenesheli, K., Anandkumar, A. "Multi-Grid + Tensorized Fourier Neural Operator for High-Resolution PDEs" (2024). + TMLR 2024, https://openreview.net/pdf?id=AWiDlO63bH. 
def __init__(
    self,
    in_channels,
    out_channels,
    n_modes,
    resolution_scaling_factor=None,
    n_layers=1,
    physics=None,
    max_n_modes=None,
    fno_block_precision="full",
    use_channel_mlp=True,
    channel_mlp_dropout=0,
    channel_mlp_expansion=0.5,
    non_linearity=F.softplus,
    stabilizer=None,
    norm=None,
    ada_in_features=None,
    preactivation=False,
    fno_skip="linear",
    channel_mlp_skip="soft-gating",
    complex_data=False,
    separable=False,
    factorization=None,
    rank=1.0,
    conv_module=SpectralConv,
    fixed_rank_modes=False,
    implementation="factorized",
    decomposition_kwargs=None,
    test_flag=False,
    best_coefs=None,
    plot_latent=False,
    **kwargs,
):
    """Build ``n_layers`` Fourier layers with their skip connections.

    Only the arguments that differ from the class-level parameter list are
    described here.

    Parameters
    ----------
    resolution_scaling_factor : None
        Must be ``None``. Per-layer scaling factors are never stored by this
        class (``self.resolution_scaling_factor`` is always ``None``), so any
        other value is rejected up front.
    physics : object, optional
        Forwarded positionally to every ``conv_module`` instance.
    norm : {None, "batch_norm"}, optional
        Normalization to use, by default ``None`` (no normalization).
    decomposition_kwargs : dict, optional
        Keyword arguments for the tensor decomposition; ``None`` means ``{}``.
    test_flag, best_coefs, plot_latent
        Forwarded unchanged as keyword arguments to every ``conv_module``.
    **kwargs
        Accepted and ignored.

    Raises
    ------
    NotImplementedError
        If ``resolution_scaling_factor`` is not ``None``.
    ValueError
        If ``norm`` is neither ``None`` nor ``"batch_norm"``.
    """
    super().__init__()

    # Previously a non-None value crashed while indexing the always-None
    # ``self.resolution_scaling_factor``; fail with an explicit message.
    if resolution_scaling_factor is not None:
        raise NotImplementedError(
            "resolution_scaling_factor is not supported by this FNOBlocks "
            f"implementation, got {resolution_scaling_factor!r}"
        )

    if isinstance(n_modes, int):
        n_modes = [n_modes]
    self._n_modes = n_modes
    self.n_dim = len(n_modes)

    if decomposition_kwargs is None:
        decomposition_kwargs = {}

    self.resolution_scaling_factor = None
    self.max_n_modes = max_n_modes
    self.fno_block_precision = fno_block_precision
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.n_layers = n_layers
    self.stabilizer = stabilizer
    self.rank = rank
    self.factorization = factorization
    self.fixed_rank_modes = fixed_rank_modes
    self.decomposition_kwargs = decomposition_kwargs
    self.fno_skip = fno_skip
    self.channel_mlp_skip = channel_mlp_skip
    self.complex_data = complex_data

    self.use_channel_mlp = use_channel_mlp
    self.channel_mlp_expansion = channel_mlp_expansion
    self.channel_mlp_dropout = channel_mlp_dropout
    self.implementation = implementation
    self.separable = separable
    self.preactivation = preactivation
    self.ada_in_features = ada_in_features

    # Filled in by forward_with_postactivation from the latent-dynamics
    # results stored on the convolution that ran with sindy_flag=True.
    self.sindy_loss = None
    self.loss_coef = None
    self.coefs = None

    self.non_linearity = non_linearity

    # Module creation order (convs, fno skips, channel MLPs, MLP skips,
    # norms) is kept so that seeded parameter initialisation is unchanged.
    self.convs = nn.ModuleList(
        [
            conv_module(
                self.in_channels,
                self.out_channels,
                self.n_modes,
                physics,
                resolution_scaling_factor=None,
                max_n_modes=max_n_modes,
                rank=rank,
                fixed_rank_modes=fixed_rank_modes,
                implementation=implementation,
                separable=separable,
                factorization=factorization,
                fno_block_precision=fno_block_precision,
                decomposition_kwargs=decomposition_kwargs,
                complex_data=complex_data,
                high_order_terms=0,
                rand_functions=[],
                test_flag=test_flag,
                best_coefs=best_coefs,
                plot_latent=plot_latent,
            )
            for _ in range(n_layers)
        ]
    )

    self.fno_skips = nn.ModuleList(
        [
            skip_connection(
                self.in_channels,
                self.out_channels,
                skip_type=fno_skip,
                n_dim=self.n_dim,
            )
            for _ in range(n_layers)
        ]
    )

    if self.use_channel_mlp:
        self.channel_mlp = nn.ModuleList(
            [
                ChannelMLP(
                    in_channels=self.out_channels,
                    hidden_channels=round(self.out_channels * channel_mlp_expansion),
                    dropout=channel_mlp_dropout,
                    n_dim=self.n_dim,
                )
                for _ in range(n_layers)
            ]
        )

        self.channel_mlp_skips = nn.ModuleList(
            [
                skip_connection(
                    self.in_channels,
                    self.out_channels,
                    skip_type=channel_mlp_skip,
                    n_dim=self.n_dim,
                )
                for _ in range(n_layers)
            ]
        )

    # Each block will have 2 norms if we also use a ChannelMLP
    self.n_norms = 2
    if norm is None:
        self.norm = None
    elif norm == "batch_norm":
        # NOTE(review): BatchNorm is not among the imports visible at the top
        # of this module - confirm it is in scope before using this option.
        self.norm = nn.ModuleList(
            [
                BatchNorm(n_dim=self.n_dim, num_features=self.out_channels)
                for _ in range(n_layers * self.n_norms)
            ]
        )
    else:
        # instance_norm, group_norm and ada_in are disabled in this module.
        raise ValueError(
            f"Got norm={norm} but expected None or 'batch_norm'"
        )
def forward_with_preactivation(self, x, index=0, output_shape=None):
    """Run layer ``index`` with activation (and norm) applied before the conv.

    Parameters
    ----------
    x : tensor
        Input to the layer.
    index : int, optional
        Which of the stacked layers to apply, by default 0.
    output_shape : optional
        Forwarded to the convolution and to its ``transform`` method.

    Returns
    -------
    The layer output: convolution plus FNO skip, optionally followed by the
    channel MLP plus its own skip.

    Notes
    -----
    The convolution is called without ``sindy_flag``, so it runs with that
    argument's default.
    """
    # Pre-activation: non-linearity first, then the first norm of this layer.
    hidden = self.non_linearity(x)
    if self.norm is not None:
        hidden = self.norm[self.n_norms * index](hidden)

    conv = self.convs[index]

    # Both skip branches start from the pre-activated input and are passed
    # through the convolution's ``transform`` with the same output_shape.
    skip_fno = conv.transform(
        self.fno_skips[index](hidden), output_shape=output_shape
    )
    if self.use_channel_mlp:
        skip_mlp = conv.transform(
            self.channel_mlp_skips[index](hidden), output_shape=output_shape
        )

    hidden = conv(hidden, output_shape=output_shape) + skip_fno

    # No activation after the last layer's convolution.
    if index < self.n_layers - 1:
        hidden = self.non_linearity(hidden)

    if self.norm is not None:
        hidden = self.norm[self.n_norms * index + 1](hidden)

    if self.use_channel_mlp:
        hidden = self.channel_mlp[index](hidden) + skip_mlp

    return hidden
class SubModule(nn.Module):
    """View onto one layer of a jointly parametrized parent module.

    Notes
    -----
    No parameters are copied: the wrapper keeps a reference to the parent
    module and delegates to it, so both share the same nn.Parameters.
    """

    def __init__(self, main_module, indices):
        """Store the parent module and the layer index (or indices) to use."""
        super().__init__()
        self.indices = indices
        self.main_module = main_module

    def forward(self, x):
        """Call the parent's ``forward`` on ``x`` with the stored indices."""
        parent = self.main_module
        return parent.forward(x, self.indices)
+ weight_syms = list(x_syms[1:]) # no batch-size + + # batch-size, out_channels, x, y... + if separable: + out_syms = [x_syms[0]] + list(weight_syms) + else: + weight_syms.insert(1, einsum_symbols[order]) # outputs + out_syms = list(weight_syms) + out_syms[0] = x_syms[0] + + eq = f'{"".join(x_syms)},{"".join(weight_syms)}->{"".join(out_syms)}' + + if not torch.is_tensor(weight): + weight = weight.to_tensor() + + if x.dtype == torch.complex32: + # if x is half precision, run a specialized einsum + return einsum_complexhalf(eq, x, weight) + else: + return tl.einsum(eq, x, weight) + +def _contract_dense_separable(x, weight, separable): + if not torch.is_tensor(weight): + weight = weight.to_tensor() + return x * weight + +def _contract_cp(x, cp_weight, separable=False): + order = tl.ndim(x) + + x_syms = str(einsum_symbols[:order]) + rank_sym = einsum_symbols[order] + out_sym = einsum_symbols[order + 1] + out_syms = list(x_syms) + if separable: + factor_syms = [einsum_symbols[1] + rank_sym] # in only + else: + out_syms[1] = out_sym + factor_syms = [einsum_symbols[1] + rank_sym, out_sym + rank_sym] # in, out + factor_syms += [xs + rank_sym for xs in x_syms[2:]] # x, y, ... + eq = f'{x_syms},{rank_sym},{",".join(factor_syms)}->{"".join(out_syms)}' + + if x.dtype == torch.complex32: + return einsum_complexhalf(eq, x, cp_weight.weights, *cp_weight.factors) + else: + return tl.einsum(eq, x, cp_weight.weights, *cp_weight.factors) + + +def _contract_tucker(x, tucker_weight, separable=False): + order = tl.ndim(x) + + x_syms = str(einsum_symbols[:order]) + out_sym = einsum_symbols[order] + out_syms = list(x_syms) + if separable: + core_syms = einsum_symbols[order + 1 : 2 * order] + # factor_syms = [einsum_symbols[1]+core_syms[0]] #in only + # x, y, ... 
def _contract_tt(x, tt_weight, separable=False):
    """Contract ``x`` directly with the cores of a tensor-train weight.

    Every weight mode gets one core whose einsum term is
    ``<rank_i><mode><rank_{i+1}>``; the cores are taken from
    ``tt_weight.factors``.

    Parameters
    ----------
    x : torch.Tensor
        Input; dimension 0 is the batch dimension.
    tt_weight : tensor-train factorized tensor
        Must expose ``factors``.
    separable : bool, optional
        If False an output-channel mode is inserted after the input-channel
        mode; if True the output keeps the symbols of ``x``, by default False.

    Returns
    -------
    torch.Tensor
        Result of the einsum contraction.
    """
    order = tl.ndim(x)

    x_syms = list(einsum_symbols[:order])
    weight_syms = list(x_syms[1:])  # no batch-size
    if separable:
        out_syms = list(x_syms)
    else:
        weight_syms.insert(1, einsum_symbols[order])  # outputs
        out_syms = list(weight_syms)
        out_syms[0] = x_syms[0]

    rank_syms = list(einsum_symbols[order + 1:])
    core_terms = [
        "".join((rank_syms[pos], mode_sym, rank_syms[pos + 1]))
        for pos, mode_sym in enumerate(weight_syms)
    ]
    eq = f'{"".join(x_syms)},{",".join(core_terms)}->{"".join(out_syms)}'

    einsum = einsum_complexhalf if x.dtype == torch.complex32 else tl.einsum
    return einsum(eq, x, *tt_weight.factors)


def get_contract_fun(weight, implementation="reconstructed", separable=False):
    """Generic ND implementation of Fourier Spectral Conv contraction

    Parameters
    ----------
    weight : tensorly-torch's FactorizedTensor
    implementation : {'reconstructed', 'factorized'}, default is 'reconstructed'
        whether to reconstruct the weight and do a forward pass (reconstructed)
        or contract directly the factors of the factorized weight with the input (factorized)
    separable : bool
        only consulted when ``implementation == 'reconstructed'``:
        True selects the element-wise dense contraction, False the dense
        einsum contraction.

    Returns
    -------
    function : (x, weight) -> x * weight in Fourier space

    Raises
    ------
    ValueError
        for an unknown ``implementation``, an unsupported weight class, or a
        factorized weight whose name matches none of dense/tucker/tt/cp.
    """
    if implementation == "reconstructed":
        return _contract_dense_separable if separable else _contract_dense

    if implementation != "factorized":
        raise ValueError(
            f'Got implementation={implementation}, expected "reconstructed" or "factorized"'
        )

    # Plain tensors are always contracted densely.
    if torch.is_tensor(weight):
        return _contract_dense

    if not isinstance(weight, FactorizedTensor):
        raise ValueError(
            f"Got unexpected weight type of class {weight.__class__.__name__}"
        )

    # Dispatch on the suffix of the factorization name, in the same order
    # the checks were originally written.
    kind = weight.name.lower()
    if kind.endswith("dense"):
        return _contract_dense
    if kind.endswith("tucker"):
        return _contract_tucker
    if kind.endswith("tt"):
        return _contract_tt
    if kind.endswith("cp"):
        return _contract_cp
    raise ValueError(f"Got unexpected factorized weight type {weight.name}")
+ + max_n_modes : int tuple or None, default is None + * If not None, **maximum** number of modes to keep in Fourier Layer, along each dim + The number of modes (`n_modes`) cannot be increased beyond that. + * If None, all the n_modes are used. + + separable : bool, default is True + whether to use separable implementation of contraction + if True, contracts factors of factorized + tensor weight individually + init_std : float or 'auto', default is 'auto' + std to use for the init + factorization : str or None, {'tucker', 'cp', 'tt'}, default is None + If None, a single dense weight is learned for the FNO. + Otherwise, that weight, used for the contraction in the Fourier domain + is learned in factorized form. In that case, `factorization` is the + tensor factorization of the parameters weight used. + rank : float or rank, optional + Rank of the tensor factorization of the Fourier weights, by default 1.0 + Ignored if ``factorization is None`` + fixed_rank_modes : bool, optional + Modes to not factorize, by default False + Ignored if ``factorization is None`` + fft_norm : str, optional + fft normalization parameter, by default 'forward' + implementation : {'factorized', 'reconstructed'}, optional, default is 'factorized' + If factorization is not None, forward mode to use:: + * `reconstructed` : the full weight tensor is reconstructed from the + factorization and used for the forward pass + * `factorized` : the input is directly contracted with the factors of + the decomposition + Ignored if ``factorization is None`` + decomposition_kwargs : dict, optional, default is {} + Optionaly additional parameters to pass to the tensor decomposition + Ignored if ``factorization is None`` + complex_data: bool, optional + whether data takes on complex values in the spatial domain, by default False + if True, uses different logic for FFT contraction and uses full FFT instead of real-valued + + References + ----------- + .. [1] : + + Li, Z. et al. 
def __init__(
    self,
    in_channels,
    out_channels,
    n_modes,
    physics,
    complex_data=False,
    max_n_modes=None,
    bias=True,
    separable=False,
    resolution_scaling_factor: Optional[Union[Number, List[Number]]] = None,
    fno_block_precision="full",
    rank=0.5,
    factorization=None,
    implementation="reconstructed",
    fixed_rank_modes=False,
    decomposition_kwargs: Optional[dict] = None,
    init_std="auto",
    fft_norm="forward",
    device=None,
    high_order_terms=2,
    rand_functions=None,
    test_flag=False,
    best_coefs=None,
    plot_latent=False,
):
    """Set up the spectral weight, bias and the latent-dynamics (SINDy) model.

    Only the arguments added on top of the class-level parameter list are
    described here.

    Parameters
    ----------
    physics : object
        Stored as ``self.physics``; the forward pass reads ``physics.dt`` and
        ``physics.t_grid`` from it.
    resolution_scaling_factor : optional
        Accepted but ignored: ``self.resolution_scaling_factor`` is always
        set to ``None``.
    high_order_terms : int, optional
        Passed to the ``SINDy`` constructor, by default 2.
    rand_functions : list, optional
        Passed to the ``SINDy`` constructor; ``None`` means an empty list.
    test_flag : bool, optional
        Stored flag selecting between fitting latent-dynamics coefficients
        and simulating with ``best_coefs``, by default False.
    best_coefs : optional
        Stored coefficients used when simulating, by default None.
    plot_latent : bool, optional
        Stored flag enabling the latent-trajectory plots, by default False.
    """
    super().__init__(device=device)

    self.in_channels = in_channels
    self.out_channels = out_channels

    # Must be set before ``self.n_modes`` is assigned: the n_modes setter
    # reads it to decide whether to halve the last mode count.
    self.complex_data = complex_data

    # Latent-dynamics configuration.
    self.high_order_terms = high_order_terms
    self.rand_functions = [] if rand_functions is None else rand_functions
    self.physics = physics

    # NOTE(review): the raw ``n_modes`` argument is indexed at [0] and [1]
    # here, so it must be a sequence with at least two entries; an int or a
    # single-entry n_modes fails on this line - confirm that is intended.
    sindy_options = {'sindy': {'fd_type': 'sbp12', 'coef_norm_order': 2}}
    self.ld = SINDy(
        n_modes[1] // 2 + 1,
        self.high_order_terms,
        self.rand_functions,
        (n_modes[0]),
        sindy_options,
    )

    # Two slots each; the forward pass fills index 0 from the real part and
    # index 1 from the imaginary part of the retained modes.
    self.coefs = [0, 0]
    self.sindy_loss = [0, 0]
    self.loss_coefs = [0, 0]

    self.test_flag = test_flag
    self.best_coefs = best_coefs
    self.plot_latent = plot_latent

    # n_modes is the total number of modes kept along each dimension
    # (normalised by the property setter).
    self.n_modes = n_modes
    self.order = len(self.n_modes)

    if max_n_modes is None:
        max_n_modes = self.n_modes
    elif isinstance(max_n_modes, int):
        max_n_modes = [max_n_modes]
    self.max_n_modes = max_n_modes

    self.fno_block_precision = fno_block_precision
    self.rank = rank
    # The caller's value (possibly None) is what gets stored; only the local
    # name is replaced by "Dense" below.
    self.factorization = factorization
    self.implementation = implementation

    self.resolution_scaling_factor = None

    if init_std == "auto":
        init_std = (2 / (in_channels + out_channels)) ** 0.5

    if isinstance(fixed_rank_modes, bool):
        # True keeps mode 0 fixed; False means no fixed modes.
        fixed_rank_modes = [0] if fixed_rank_modes else None
    self.fft_norm = fft_norm

    if factorization is None:
        factorization = "Dense"  # No factorization

    # The weight always carries both channel dimensions, whatever the value
    # of ``separable``.
    weight_shape = (in_channels, out_channels, *max_n_modes)
    self.separable = separable

    tensor_kwargs = decomposition_kwargs if decomposition_kwargs is not None else {}

    # Create/init spectral weight tensor. ``factorization`` is never None at
    # this point, so the weight is always built through FactorizedTensor.
    self.weight = FactorizedTensor.new(
        weight_shape,
        rank=self.rank,
        factorization=factorization,
        fixed_rank_modes=fixed_rank_modes,
        **tensor_kwargs,
        dtype=torch.cfloat,
    )
    self.weight.normal_(0, init_std)
    self._contract = get_contract_fun(
        self.weight, implementation=implementation, separable=separable
    )

    if bias:
        # One bias value per output channel, with a singleton axis per mode.
        self.bias = nn.Parameter(
            init_std * torch.randn(*(tuple([self.out_channels]) + (1,) * self.order))
        )
    else:
        self.bias = None
+ fft_dims = [-1]#list(range(-self.order, 0)) + + + # if self.fno_block_precision == "half": + # x = x.half() + + # if self.complex_data: + # x = torch.fft.fftn(x, norm=self.fft_norm, dim=fft_dims) + # dims_to_fft_shift = fft_dims + # else: + x = torch.fft.rfftn(x, norm=self.fft_norm, dim=fft_dims) + # When x is real in spatial domain, the last half of the last dim is redundant. + # See :ref:`fft_shift_explanation` for discussion of the FFT shift. + out_fft = torch.zeros(x.shape, dtype=torch.cfloat, device = x.device) + # out_fft = x[:,:,:,:self.n_modes[1]] + + ######## The following can be adjusted to choose the first n modes or n dominant modes. Burgers1d results were derived using first n modes. Tried the following for vlasov. + # # Find the dominant modes! + mag = torch.abs(x) + average = torch.mean(mag, dim = -2, keepdim = True) + dominant_indices = torch.argsort(average,descending=True)[0,:,:,:self.n_modes[1]].squeeze() + + out_fft[:,:,:,dominant_indices] = x[:,:,:,dominant_indices] + # IF test flag == false! + # if self.order > 1 and self.test_flag == False: + # x = torch.fft.fftshift(x, dim=dims_to_fft_shift) + + # if self.fno_block_precision == "mixed": + # # if 'mixed', the above fft runs in full precision, but the + # # following operations run at half precision + # x = x.chalf() + + # if self.fno_block_precision in ["half", "mixed"]: + # out_dtype = torch.chalf + # else: + # out_dtype = torch.cfloat + + # # UPDATING THE SIZE OF THE OUT_FFT HERE! + # out_fft = torch.zeros([batchsize, self.out_channels, x.shape[2], self.n_modes[1]], + # device=x.device, dtype=out_dtype) + + # # if current modes are less than max, start indexing modes closer to the center of the weight tensor + # starts = [(max_modes - min(size, n_mode)) for (size, n_mode, max_modes) in zip(fft_size, self.n_modes, self.max_n_modes)] + # # if contraction is separable, weights have shape (channels, modes_x, ...) + # # otherwise they have shape (in_channels, out_channels, modes_x, ...) 
+ # if self.separable: + # slices_w = [slice(None)] # channels + # else: + # slices_w = [slice(None), slice(None)] # in_channels, out_channels + # if self.complex_data: + # slices_w += [slice(start//2, -start//2) if start else slice(start, None) for start in starts] + # else: + # # The last mode already has redundant half removed in real FFT + # slices_w += [slice(start//2, -start//2) if start else slice(start, None) for start in starts[:-1]] + # slices_w += [slice(None, -starts[-1]) if starts[-1] else slice(None)] + + # weight = self.weight[slices_w] + + # ### Pick the first n_modes modes of FFT signal along each dim + + # # if separable conv, weight tensor only has one channel dim + # if self.separable: + # weight_start_idx = 1 + # # otherwise drop first two dims (in_channels, out_channels) + # else: + # weight_start_idx = 2 + + # slices_x = [slice(None), slice(None)] # Batch_size, channels + + # for all_modes, kept_modes in zip(fft_size, list(weight.shape[weight_start_idx:])): + # # After fft-shift, the 0th frequency is located at n // 2 in each direction + # # We select n_modes modes around the 0th frequency (kept at index n//2) by grabbing indices + # # n//2 - n_modes//2 to n//2 + n_modes//2 if n_modes is even + # # n//2 - n_modes//2 to n//2 + n_modes//2 + 1 if n_modes is odd + # center = all_modes // 2 + # negative_freqs = kept_modes // 2 + # positive_freqs = kept_modes // 2 + kept_modes % 2 + + # # this slice represents the desired indices along each dim + # slices_x += [slice(center - negative_freqs, center + positive_freqs)] + + # if weight.shape[-1] < fft_size[-1]: + # slices_x[-1] = slice(None, weight.shape[-1]) + # else: + # slices_x[-1] = slice(None) + + # out_fft[slices_x] = self._contract(x[slices_x], weight, separable=self.separable) + + # if self.resolution_scaling_factor is not None and output_shape is None: + # mode_sizes = tuple([round(s * r) for (s, r) in zip(mode_sizes, self.resolution_scaling_factor)]) + + # if output_shape is not None: + # 
mode_sizes = output_shape + + # # if self.order > 1: + # # out_fft = torch.fft.fftshift(out_fft, dim=fft_dims[:-1]) + + if sindy_flag: + """ + # The following routine determines the sindy coefficients for the real and imaginary parts together. + if self.test_flag == False and self.plot_latent == False: + + # Training Mode! + coefs, loss_sindy, loss_coef = self.ld.calibrate(out_fft.squeeze(), self.physics.dt, compute_loss=True, numpy=False) + + # Update the terms in the class itself! + self.coefs[0] = coefs + self.sindy_loss[0] = loss_sindy + self.loss_coefs[0] = loss_coef + + # Adding functionality to plot the latent space! + elif self.test_flag == False and self.plot_latent == True: + + # Get the simulated response by extracting the initial condition! + ic = out_fft.squeeze()[0,:].detach().numpy() + out_fft_combined = self.ld.simulate(self.best_coefs[0], ic, self.physics.t_grid) + plt.figure() + plt.title('True v/s Estimated Latent Trajectory (Real)') + plt.grid() + plt.xlabel('Time') + plt.ylabel('Magnitude') + plt.plot(self.physics.t_grid, out_fft_combined[:,:self.n_modes[1]], label = 'Estimated') + plt.plot(self.physics.t_grid, out_fft.squeeze().real.detach().numpy(), linestyle = '--',label = 'True') + + plt.figure() + plt.grid() + plt.xlabel('Time') + plt.ylabel('Magnitude') + plt.title('True v/s Estimated Latent Trajectory (Imaginary)') + plt.plot(self.physics.t_grid, out_fft_combined[:,self.n_modes[1]:], label = 'Estimated') + plt.plot(self.physics.t_grid, out_fft.squeeze().imag.detach().numpy(), label = 'True', linestyle = '--') + + else: + + # Testing Mode! + out_fft_combined = self.ld.simulate(self.best_coefs[0], out_fft.squeeze().detach().numpy(), self.physics.t_grid) + + # Combine the real and imag parts! + out_fft = out_fft_combined[:,:self.n_modes[1]] + 1j*out_fft_combined[:,self.n_modes[1]:] + + # Convert numpy back to torch! + out_fft = torch.from_numpy(out_fft) + + # Get the size consistency! 
+ out_fft = (out_fft.unsqueeze(0)).unsqueeze(0) + """ + # The following splits up the complex numbers into real and imgainry and then does the sindy fitting separately + if self.test_flag == False and self.plot_latent == False: + coefs, loss_sindy, loss_coef = self.ld.calibrate(out_fft[:,:,:,dominant_indices].squeeze().real, self.physics.dt, compute_loss=True, numpy=False) + + # Update the terms in the class itself! + self.coefs[0] = coefs + self.sindy_loss[0] = loss_sindy + self.loss_coefs[0] = loss_coef + #### Getting the coefficients for the imaginary part! + coefs, loss_sindy, loss_coef = self.ld.calibrate(out_fft[:,:,:,dominant_indices].squeeze().imag, self.physics.dt, compute_loss=True, numpy=False) + + # Update the terms in the class itself! + self.coefs[1] = coefs + self.sindy_loss[1] = loss_sindy + self.loss_coefs[1] = loss_coef + elif self.test_flag == False and self.plot_latent == True: + + # Get the simulated response by extracting the initial condition! + ic = out_fft[:,:,:,dominant_indices].squeeze()[0,:].detach().numpy() + out_fft_real = self.ld.simulate(self.best_coefs[0], ic.real, self.physics.t_grid) + out_fft_imag = self.ld.simulate(self.best_coefs[1], ic.imag, self.physics.t_grid) + plt.figure() + plt.title('True v/s Estimated Latent Trajectory (Real)') + plt.grid() + plt.xlabel('Time') + plt.ylabel('Magnitude') + plt.plot(self.physics.t_grid, out_fft_real, label = 'Estimated') + plt.plot(self.physics.t_grid, out_fft.squeeze().real.detach().numpy(), linestyle = '--',label = 'True') + plt.legend(['-- True','- Estimate']) + plt.savefig('plots/real_traj.png') + + plt.figure() + plt.grid() + plt.xlabel('Time') + plt.ylabel('Magnitude') + plt.title('True v/s Estimated Latent Trajectory (Imaginary)') + plt.plot(self.physics.t_grid, out_fft_imag, label = 'Estimated') + plt.plot(self.physics.t_grid, out_fft.squeeze().imag.detach().numpy(), label = 'True', linestyle = '--') + plt.legend(['-- True','- Estimate']) + plt.savefig('plots/imag_traj.png') + 
else: + + # Simulate the sindy based on the latent space and the coefficients! + # Simulate the real part! + out_fft_real = self.ld.simulate(self.best_coefs[0], (out_fft[:,:,:,dominant_indices].squeeze().real).detach().numpy(), self.physics.t_grid) + + # Simulate the imaginary part! + out_fft_imag = self.ld.simulate(self.best_coefs[1], (out_fft[:,:,:,dominant_indices].squeeze().imag).detach().numpy(), self.physics.t_grid) + + # Combine the real and imaginary parts! + out_fft = out_fft_real + 1j * out_fft_imag + out_fft = torch.from_numpy(out_fft).to(torch.cfloat) + + # Convert the out_fft back to the same + out_fft_new = torch.zeros((1,1,len(self.physics.t_grid),mode_sizes[1]),device=x.device, dtype=torch.cfloat) + out_fft_new[:,:,:,dominant_indices] == out_fft + # Get the size consistency! + out_fft = out_fft_new#(out_fft.unsqueeze(0)).unsqueeze(0) + # out_fft_filtered = torch.zeros([batchsize, self.out_channels, x.shape[2], mode_sizes[1]], + # device=x.device, dtype=torch.cfloat) + + # out_fft_filtered[...,dominant_indices] = out_fft + + # x = torch.fft.irfftn(out_fft, s=mode_sizes, dim=fft_dims, norm=self.fft_norm) + if self.test_flag == True: + x = torch.fft.irfftn(out_fft, s=[mode_sizes[1]], dim=fft_dims, norm=self.fft_norm) + else: + x = torch.fft.irfftn(out_fft, s=[mode_sizes[1]], dim=fft_dims, norm=self.fft_norm) + + x = x.to(torch.float32) + if self.bias is not None: + x = x + self.bias.to(x.device) + return x diff --git a/src/lasdi/latent_dynamics/edmd.py b/src/lasdi/latent_dynamics/edmd.py new file mode 100644 index 0000000..d069fb6 --- /dev/null +++ b/src/lasdi/latent_dynamics/edmd.py @@ -0,0 +1,172 @@ +import numpy as np +import torch +from scipy.integrate import odeint +from . 
import LatentDynamics
# NOTE: the line above completes the "from . " that ends the previous line of
# this file, i.e. the full statement reads "from . import LatentDynamics".
from ..inputs import InputParser
from ..fd import FDdict
from scipy.linalg import logm
import importlib


def get_function_from_string(func_str):
    """Resolve a dotted path such as ``'torch.sin'`` to the object it names.

    The text before the last dot is imported as a module with
    ``importlib.import_module``; the text after it is looked up on that module
    with ``getattr``.

    NOTE: this imports whatever module the string names, so the strings must
    come from a trusted source (e.g. a configuration file you control).

    Args:
        func_str: dotted ``'<module>.<attribute>'`` string.

    Returns:
        The attribute found on the imported module.

    Raises:
        ValueError: if ``func_str`` contains no dot.
        ImportError / AttributeError: propagated from the import / lookup.
    """
    module_name, dot, func_name = func_str.rpartition('.')
    if not dot:
        raise ValueError(
            "expected a dotted path such as 'torch.sin', got %r" % (func_str,))
    module = importlib.import_module(module_name)
    return getattr(module, func_name)


class edmd(LatentDynamics):
    """EDMD-style latent dynamics: a linear one-step map on a lifted state.

    A latent state ``z`` is lifted to
    ``[z, z**2, ..., z**(high_order_terms + 1), f_1(z), ..., f_m(z)]`` where the
    ``f_k`` are the callables named in ``rand_functions``.  ``calibrate`` fits a
    square matrix ``A`` such that ``lift(z[k+1]) ~= A @ lift(z[k])`` by least
    squares, and ``simulate`` iterates that map, keeping only the first
    ``dim`` entries of each prediction.
    """

    # Name of the finite-difference scheme read from the config.
    fd_type = ''
    # Entry of FDdict selected by fd_type.
    fd = None
    # First value returned by fd.getOperators(nt); not used by calibrate/simulate.
    fd_oper = None

    def __init__(self, dim, high_order_terms, rand_functions, nt, config):
        """Set up the lifted-state bookkeeping and read the 'edmd' config section.

        Args:
            dim: latent dimension, forwarded to the base class.
            high_order_terms: number of extra element-wise powers
                (``z**2`` up to ``z**(high_order_terms + 1)``) in the lifted state.
            rand_functions: iterable of dotted function names (see
                ``get_function_from_string``) applied to ``z`` and appended to
                the lifted state.
            nt: number of time steps, forwarded to the base class.
            config: dict that must contain an ``'edmd'`` section.
        """
        super().__init__(dim, nt)

        # Defining the higher order terms
        self.high_order_terms = high_order_terms
        self.rand_functions = rand_functions

        # Number of coefficients depends on the basis functions used:
        # A is a square matrix over the lifted state.
        self.ncoefs = self._lifted_dim() ** 2

        assert('edmd' in config)
        parser = InputParser(config['edmd'], name='edmd_input')

        '''
            fd_type is the string that specifies finite-difference scheme for time derivative:
                - 'sbp12': summation-by-parts 1st/2nd (boundary/interior) order operator
                - 'sbp24': summation-by-parts 2nd/4th order operator
                - 'sbp36': summation-by-parts 3rd/6th order operator
                - 'sbp48': summation-by-parts 4th/8th order operator
        '''
        self.fd_type = parser.getInput(['fd_type'], fallback='sbp12')
        self.fd = FDdict[self.fd_type]
        self.fd_oper, _, _ = self.fd.getOperators(self.nt)

        # NOTE(kevin): by default, this will be L1 norm.
        self.coef_norm_order = parser.getInput(['coef_norm_order'], fallback=1)

        # TODO(kevin): other loss functions
        self.MSE = torch.nn.MSELoss()

        return

    def _lifted_dim(self):
        """Return the length of the lifted state vector."""
        return (len(self.rand_functions) + self.high_order_terms) * self.dim + self.dim

    def calibrate(self, Z, dt, compute_loss=True, numpy=False):
        """Fit the one-step operator ``A`` for one or several latent trajectories.

        Args:
            Z: torch tensor.  If 3-D, the method calls itself on each ``Z[i]``
                and stacks the results.  If 2-D, rows are consecutive
                snapshots and columns are latent variables (assumes the number
                of columns equals ``self.dim`` so that the flattened operator
                has ``self.ncoefs`` entries -- TODO confirm with callers).
            dt: not used by this method; present in the signature only.
            compute_loss: when True, also return the fitting loss and the
                coefficient-norm loss.
            numpy: when True, return the coefficients as a numpy array.

        Returns:
            ``coefs`` or ``(coefs, loss_edmd, loss_coef)``.  ``coefs`` is the
            flattened ``A`` (detached from the autograd graph); for 3-D input
            it has one row per trajectory and the losses are summed.
        """
        # Loop over all train cases, if Z dimension is 3.
        if (Z.dim() == 3):
            n_train = Z.size(0)

            if (numpy):
                coefs = np.zeros([n_train, self.ncoefs])
            else:
                coefs = torch.zeros([n_train, self.ncoefs])
            loss_edmd, loss_coef = 0.0, 0.0

            for i in range(n_train):
                result = self.calibrate(Z[i], dt, compute_loss, numpy)
                if (compute_loss):
                    coefs[i] = result[0]
                    loss_edmd += result[1]
                    loss_coef += result[2]
                else:
                    coefs[i] = result

            if (compute_loss):
                return coefs, loss_edmd, loss_coef
            else:
                return coefs

        # Evaluate for one train case.
        assert(Z.dim() == 2)

        # Build the lifted snapshots.  torch.cat allocates a new tensor, so Z
        # itself is never modified.
        Z_i = Z

        # Append the higher-order powers z**2, z**3, ...
        for order in range(self.high_order_terms):
            Z_i = torch.cat([Z_i, Z ** (order + 2)], dim=1)

        # Append the user-specified functions of z.
        for func_name in self.rand_functions:
            Z_i = torch.cat([Z_i, get_function_from_string(func_name)(Z)], dim=1)

        # Reshape so that columns are snapshots.
        Z_i = torch.transpose(Z_i, 0, 1)

        # Shifted snapshot pairs: Z_plus[:, k] follows Z_minus[:, k].
        Z_plus = Z_i[:, 1:]
        Z_minus = Z_i[:, 0:-1]

        # Solve Z_plus ~= A @ Z_minus in the least-squares sense.  lstsq is
        # used rather than an explicit pseudo-inverse for stability.
        A = (torch.linalg.lstsq(Z_minus.T, Z_plus.T).solution).T

        # Compute the losses!
        if (compute_loss):

            # NOTE(khushant): This loss is different from what is used in SINDy.
            loss_edmd = self.MSE(Z_plus, A @ Z_minus)

            # NOTE(kevin): by default, this will be L1 norm.
            loss_coef = torch.norm(A, self.coef_norm_order)

        # A is a transposed view and therefore not contiguous; flatten() copes
        # with that.
        coefs = A.detach().flatten()
        if (numpy):
            # .cpu() is required before .numpy() when the tensor lives on an
            # accelerator; it is a no-op for CPU tensors.
            coefs = coefs.cpu().numpy()

        if (compute_loss):
            return coefs, loss_edmd, loss_coef
        else:
            return coefs

    def simulate(self, coefs, z0, t_grid):
        """Roll the fitted one-step map forward from an initial latent state.

        This is a discrete iteration ``z[k+1] = (A @ lift(z[k]))[:dim]``, not an
        ODE integration; only ``len(t_grid)`` is used, never the time values.

        Args:
            coefs: flattened operator, reshaped here to a square matrix over
                the lifted state (presumably the numpy output of
                ``calibrate`` -- it is multiplied with numpy vectors).
            z0: initial latent state, written into the first output row.
            t_grid: sequence whose length sets the number of output rows.

        Returns:
            numpy array of shape ``(len(t_grid), self.dim)``.
        """
        lifted_dim = self._lifted_dim()
        A = coefs.reshape([lifted_dim, lifted_dim])

        # Resolve the extra functions once instead of on every time step.
        extra_functions = [get_function_from_string(name)
                           for name in self.rand_functions]

        Z_i = np.zeros((len(t_grid), self.dim))
        Z_i[0, :] = z0

        for i in range(1, len(t_grid)):
            z_prev = Z_i[i - 1, :]

            # Lifted state, in the same order used by calibrate:
            # z, then powers, then extra functions.
            pieces = [z_prev]
            pieces.extend(np.power(z_prev, order + 2)
                          for order in range(self.high_order_terms))
            if extra_functions:
                z_prev_torch = torch.from_numpy(z_prev)
                pieces.extend(func(z_prev_torch).detach().cpu().numpy()
                              for func in extra_functions)
            z_new = np.hstack(pieces)

            # Advance one step and keep only the original latent variables.
            Z_i[i, :] = (A @ z_new)[:self.dim]

        return Z_i

    def export(self):
        """Return the base-class export dict extended with this class's settings."""
        param_dict = super().export()
        param_dict['fd_type'] = self.fd_type
        param_dict['coef_norm_order'] = self.coef_norm_order
        return param_dict


# NOTE: the comment line below is the remainder of the original final line of this span -- the start of the patch to src/lasdi/workflow.py -- preserved verbatim (commented out so this block stays valid Python).
# diff --git a/src/lasdi/workflow.py b/src/lasdi/workflow.py index 3d20fc1..a7afb7a 100644 --- a/src/lasdi/workflow.py +++ b/src/lasdi/workflow.py @@ -8,6 +8,8 @@ from .gplasdi import BayesianGLaSDI from .latent_space import Autoencoder from .latent_dynamics.sindy import SINDy +from .latent_dynamics.sindy_high_order import SINDy as sindy_high +from .latent_dynamics.edmd import edmd from .physics.burgers1d import Burgers1D from .param import ParameterSpace from .inputs import InputParser @@ -16,7 +18,7 @@ latent_dict = {'ae': Autoencoder} -ld_dict = {'sindy': SINDy} +ld_dict = {'edmd': edmd} physics_dict = {'burgers1d': Burgers1D} @@ -151,6 +153,7 @@ def initialize_trainer(config, restart_file=None): physics = initialize_physics(config, param_space.param_name) latent_space = initialize_latent_space(physics, config) + if (restart_file is not None): latent_space.load(restart_file['latent_space']) @@ -158,7 +161,10 @@ def initialize_trainer(config, restart_file=None): ld_type = config['latent_dynamics']['type'] assert(ld_type in config['latent_dynamics']) assert(ld_type in ld_dict) - latent_dynamics = ld_dict[ld_type](latent_space.n_z, physics.nt, config['latent_dynamics']) + + # Updating the dynamics callback to account for the higher order terms and other non linear functions from the .yml file. + + latent_dynamics = ld_dict[ld_type](latent_space.n_z, config['latent_dynamics'][ld_type]['higher_order_terms'], config['latent_dynamics'][ld_type]['extra_functions'], physics.nt, config['latent_dynamics']) if (restart_file is not None): latent_dynamics.load(restart_file['latent_dynamics']) @@ -329,4 +335,4 @@ def collect_samples(trainer, config): return result, next_step if __name__ == "__main__": - main() \ No newline at end of file + main()