diff --git a/examples/burgers1d.yml b/examples/burgers1d.yml
index dc43164..7eb627e 100644
--- a/examples/burgers1d.yml
+++ b/examples/burgers1d.yml
@@ -1,12 +1,12 @@
 lasdi:
   type: gplasdi
   gplasdi:
-    # device: mps
+    device: cuda
     n_samples: 20
     lr: 0.001
     max_iter: 28000
     n_iter: 2000
-    max_greedy_iter: 28000 
+    max_greedy_iter: 28000
     ld_weight: 0.1
     coef_weight: 1.e-6
     path_checkpoint: checkpoint
@@ -67,12 +67,21 @@ latent_space:
   ae:
     hidden_units: [100]
     latent_dimension: 5
+    activation: softplus
 
 latent_dynamics:
   type: sindy
   sindy:
     fd_type: sbp12
     coef_norm_order: fro
+    higher_order_terms: 1
+    extra_functions: []
+  #type: edmd
+  #edmd:
+  #  fd_type: sbp12
+  #  coef_norm_order: fro
+  #  higher_order_terms: 0
+  #  extra_functions: []
 
 physics:
   type: burgers1d
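A note before the new module: with higher_order_terms: 1 and entries in extra_functions, EDMD lifts each latent snapshot through a dictionary of candidate functions and fits a single linear operator A by least squares on the shifted snapshot matrices. A minimal standalone sketch of that fit, not part of the patch (random data stands in for the encoder output, and 'torch.sin' is only an illustrative extra function):

    import torch

    Z = torch.randn(100, 5)                             # n_t = 100 snapshots, n_z = 5
    Z_lift = torch.cat([Z, Z**2, torch.sin(Z)], dim=1)  # dictionary [z, z^2, sin(z)]
    Z_lift = Z_lift.T                                   # columns are snapshots
    Z_minus, Z_plus = Z_lift[:, :-1], Z_lift[:, 1:]

    # Solve Z_plus ~= A @ Z_minus. lstsq(X, Y) minimizes ||X @ B - Y||, so
    # transposing both matrices turns the row problem into this column problem.
    A = torch.linalg.lstsq(Z_minus.T, Z_plus.T).solution.T  # (15 x 15) operator

This is the same transpose trick used in calibrate below, and the reason the commented pseudo-inverse line is kept only as a reference.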
diff --git a/src/lasdi/latent_dynamics/edmd.py b/src/lasdi/latent_dynamics/edmd.py
new file mode 100644
index 0000000..23503d8
--- /dev/null
+++ b/src/lasdi/latent_dynamics/edmd.py
@@ -0,0 +1,168 @@
+import importlib
+
+import numpy as np
+import torch
+
+from . import LatentDynamics
+from ..inputs import InputParser
+from ..fd import FDdict
+
+def get_function_from_string(func_str):
+    # Split a dotted path such as 'torch.sin' into module and function name.
+    module_name, func_name = func_str.rsplit('.', 1)
+    module = importlib.import_module(module_name)
+    return getattr(module, func_name)
+
+class EDMD(LatentDynamics):
+    fd_type = ''
+    fd = None
+    fd_oper = None
+
+    def __init__(self, dim, high_order_terms, rand_functions, nt, config):
+        super().__init__(dim, nt)
+
+        # Number of polynomial orders (>= 2) to append to the dictionary.
+        self.high_order_terms = high_order_terms
+        # Extra dictionary functions, given as dotted strings such as 'torch.sin'.
+        self.rand_functions = rand_functions
+
+        # The number of coefficients depends on the dictionary: A is a square
+        # operator on the lifted state, which has dim entries for the identity
+        # block plus dim entries per polynomial or extra-function block.
+        self.ncoefs = ((len(self.rand_functions) + self.high_order_terms) * self.dim + self.dim) ** 2
+
+        assert('edmd' in config)
+        parser = InputParser(config['edmd'], name='edmd_input')
+
+        '''
+        fd_type is the string that specifies the finite-difference scheme for the time derivative:
+        - 'sbp12': summation-by-parts 1st/2nd (boundary/interior) order operator
+        - 'sbp24': summation-by-parts 2nd/4th order operator
+        - 'sbp36': summation-by-parts 3rd/6th order operator
+        - 'sbp48': summation-by-parts 4th/8th order operator
+        NOTE(khushant): EDMD fits a discrete-time map from snapshot pairs, so the
+        operator is not used in calibrate; it is parsed and exported only to keep
+        the configuration interchangeable with SINDy.
+        '''
+        self.fd_type = parser.getInput(['fd_type'], fallback='sbp12')
+        self.fd = FDdict[self.fd_type]
+        self.fd_oper, _, _ = self.fd.getOperators(self.nt)
+
+        # NOTE(kevin): by default, this will be the L1 norm.
+        self.coef_norm_order = parser.getInput(['coef_norm_order'], fallback=1)
+
+        # TODO(kevin): other loss functions
+        self.MSE = torch.nn.MSELoss()
+
+        return
+
+    def calibrate(self, Z, dt, compute_loss=True, numpy=False):
+        ''' Loop over all train cases if Z has three dimensions. '''
+        if (Z.dim() == 3):
+            n_train = Z.size(0)
+
+            if (numpy):
+                coefs = np.zeros([n_train, self.ncoefs])
+            else:
+                coefs = torch.zeros([n_train, self.ncoefs])
+            loss_edmd, loss_coef = 0.0, 0.0
+
+            for i in range(n_train):
+                result = self.calibrate(Z[i], dt, compute_loss, numpy)
+                if (compute_loss):
+                    coefs[i] = result[0]
+                    loss_edmd += result[1]
+                    loss_coef += result[2]
+                else:
+                    coefs[i] = result
+
+            if (compute_loss):
+                return coefs, loss_edmd, loss_coef
+            else:
+                return coefs
+
+        ''' Evaluate for one train case. '''
+        assert(Z.dim() == 2)
+
+        # torch.cat below allocates new tensors, so Z itself is never modified.
+        Z_i = Z
+
+        # Append the higher-order polynomial terms (Z^2, Z^3, ...) to the
+        # candidate library, one order at a time.
+        for i in range(self.high_order_terms):
+            Z_i = torch.cat([Z_i, Z ** (i + 2)], dim=1)
+
+        # Append the extra user-specified functions (e.g. trigonometric terms).
+        for i in self.rand_functions:
+            Z_i = torch.cat([Z_i, get_function_from_string(i)(Z)], dim=1)
+
+        # Reshape Z_i so that columns are snapshots.
+        Z_i = torch.transpose(Z_i, 0, 1)
+
+        # Build the time-shifted snapshot matrices.
+        Z_plus = Z_i[:, 1:]
+        Z_minus = Z_i[:, :-1]
+
+        # Get the A operator. lstsq is used since it is more stable than forming
+        # the pseudo-inverse explicitly.
+        A = (torch.linalg.lstsq(Z_minus.T, Z_plus.T).solution).T
+        # A = Z_plus @ torch.linalg.pinv(Z_minus)
+
+        # Compute the losses.
+        if (compute_loss):
+            # NOTE(khushant): This loss is different from what is used in SINDy.
+            loss_edmd = self.MSE(Z_plus, A @ Z_minus)
+
+            # NOTE(kevin): by default, this will be the L1 norm.
+            loss_coef = torch.norm(A, self.coef_norm_order)
+
+        # The output of lstsq is not contiguous in memory.
+        coefs = A.detach().flatten()
+        if (numpy):
+            coefs = coefs.numpy()
+
+        if (compute_loss):
+            return coefs, loss_edmd, loss_coef
+        else:
+            return coefs
+
+    def simulate(self, coefs, z0, t_grid):
+        '''
+        Advances the discrete-time EDMD map for one training point, given the
+        initial condition Z0 = encoder(U0). Unlike SINDy, no ODE solver is
+        involved: the lifted state is multiplied by A once per time step.
+        '''
+        # A copy is inevitable for numpy==1.26; removed copy=False temporarily.
+        n_lifted = (self.high_order_terms + len(self.rand_functions)) * self.dim + self.dim
+        A = coefs.reshape([n_lifted, n_lifted])
+
+        Z_i = np.zeros((len(t_grid), self.dim))
+        Z_i[0, :] = z0
+
+        # Step the discrete map forward in time.
+        for i in range(1, len(t_grid)):
+
+            # Start the lifted state from the previous latent state.
+            z_new = Z_i[i - 1, :]
+
+            # Append the higher-order polynomial terms, matching the dictionary
+            # used in calibrate.
+            for j in range(self.high_order_terms):
+                new_terms = np.power(Z_i[i - 1, :], j + 2)
+                z_new = np.hstack((z_new, new_terms))
+
+            # Append the extra user-specified functions (e.g. trigonometric terms).
+            for k in self.rand_functions:
+                new_terms = get_function_from_string(k)(torch.from_numpy(Z_i[i - 1, :]))
+                z_new = np.hstack((z_new, new_terms.detach().cpu().numpy()))
+
+            # Apply the map and keep only the original latent coordinates.
+            Z_i[i, :] = (A @ z_new)[:self.dim]
+
+        return Z_i
+
+    def export(self):
+        param_dict = super().export()
+        param_dict['fd_type'] = self.fd_type
+        param_dict['coef_norm_order'] = self.coef_norm_order
+        return param_dict
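The simulate method above is a discrete-time rollout rather than an ODE solve: each step re-lifts the current latent state through the same dictionary used for calibration and applies A once. A self-contained sketch of that loop, not part of the patch (a small random matrix stands in for the calibrated operator):

    import numpy as np

    n_z = 5
    n_lift = 3 * n_z                           # [z, z^2, sin(z)] blocks
    A = 0.1 * np.random.randn(n_lift, n_lift)  # placeholder for the fitted operator
    t_grid = np.linspace(0.0, 1.0, 50)

    Z_out = np.zeros((len(t_grid), n_z))
    Z_out[0, :] = np.random.randn(n_z)         # stands in for z0 = encoder(u0)

    for i in range(1, len(t_grid)):
        z = Z_out[i - 1, :]
        z_lift = np.hstack([z, z**2, np.sin(z)])  # same dictionary as the fit
        Z_out[i, :] = (A @ z_lift)[:n_z]          # keep only the latent coordinates

Only the first n_z entries of the mapped state are kept, so the lifted coordinates are recomputed from scratch at every step rather than propagated.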
diff --git a/src/lasdi/workflow.py b/src/lasdi/workflow.py
index 3d20fc1..8e17c37 100644
--- a/src/lasdi/workflow.py
+++ b/src/lasdi/workflow.py
@@ -8,6 +8,7 @@
 from .gplasdi import BayesianGLaSDI
 from .latent_space import Autoencoder
 from .latent_dynamics.sindy import SINDy
+from .latent_dynamics.edmd import EDMD
 from .physics.burgers1d import Burgers1D
 from .param import ParameterSpace
 from .inputs import InputParser
@@ -16,7 +17,7 @@
 latent_dict = {'ae': Autoencoder}
 
-ld_dict = {'sindy': SINDy}
+ld_dict = {'sindy': SINDy, 'edmd': EDMD}
 
 physics_dict = {'burgers1d': Burgers1D}
@@ -158,7 +159,17 @@
     ld_type = config['latent_dynamics']['type']
     assert(ld_type in config['latent_dynamics'])
     assert(ld_type in ld_dict)
-    latent_dynamics = ld_dict[ld_type](latent_space.n_z, physics.nt, config['latent_dynamics'])
+
+    # EDMD additionally takes the number of higher-order terms and the extra
+    # nonlinear functions from the .yml file; SINDy keeps its original signature.
+    if (ld_type == 'edmd'):
+        ld_config = config['latent_dynamics'][ld_type]
+        latent_dynamics = ld_dict[ld_type](latent_space.n_z,
+                                           ld_config.get('higher_order_terms', 0),
+                                           ld_config.get('extra_functions', []),
+                                           physics.nt, config['latent_dynamics'])
+    else:
+        latent_dynamics = ld_dict[ld_type](latent_space.n_z, physics.nt, config['latent_dynamics'])
 
     if (restart_file is not None):
         latent_dynamics.load(restart_file['latent_dynamics'])
@@ -329,4 +340,4 @@
     return result, next_step
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
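Note that ld_dict keeps both entries: replacing it with {'edmd': EDMD} alone would break the example config above, which still selects type: sindy. To exercise the new path end to end, the commented block in examples/burgers1d.yml replaces the active sindy block; a hedged example of the resulting section (values mirror the commented defaults, and extra_functions entries are dotted names resolved by get_function_from_string, so any string such as 'torch.sin' that imports to a callable is accepted):

    latent_dynamics:
      type: edmd
      edmd:
        fd_type: sbp12
        coef_norm_order: fro
        higher_order_terms: 0
        extra_functions: []

With higher_order_terms: 0 and an empty extra_functions, the dictionary is just the identity, so this configuration reduces to plain DMD on the latent trajectories; raising higher_order_terms or listing functions such as 'torch.sin' is what makes the method extended.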