|
| 1 | + |
| 2 | +# Fourier Neural Operator for Nonlinear Pendulum |
| 3 | + |
| 4 | +This example demonstrates how to use a Fourier Neural Operator (FNO) to model the dynamics of a nonlinear pendulum with a time\-dependent forcing term. The governing equation is given by the following second\-order ordinary differential equation (ODE): |
| 5 | + |
| 6 | + |
| 7 | + $\ddot{\theta} (t)=-\omega_0^2 \sin (\theta (t))+f(t)$, |
| 8 | + |
| 9 | + |
| 10 | +where $\theta$ is the angular position, $\omega_0$ is the natural frequency, and $f$ is the an external forcing function. |
| 11 | + |
| 12 | + |
| 13 | +Traditionally, solving this system for a new forcing function $f$ requires numerically solving the ODE from scratch. In constrast, an FNO learns a mapping from the input forcing function $f$ to the corresponding solution $\theta$, enabling rapid inference without re\-solving the equation. |
| 14 | + |
| 15 | + |
| 16 | +The FNO is trained on a dataset of input\-output pairs $(f_i (t),\theta_i (t))$, and once trained, it can predict $\theta$ for unseen forcing functions. While this approach may be excessive for simple ODEs like the pendulum, it can be advantageous for complex simulations involving partial differential equations (PDEs), where traditional solvers (e.g. finite element method) are computationally expensive. |
| 17 | + |
| 18 | +# Generate or Import Data |
| 19 | +```matlab |
| 20 | +% Get the path to the main directory |
| 21 | +mainDir = findProjectRoot('generatePendulumDataFNO.m'); |
| 22 | +% If first time generating data, set generateData to true. Else, set to false. |
| 23 | +res = 512; |
| 24 | +generateData = 1; |
| 25 | +if generateData |
| 26 | + g = 9.81; r = 1; |
| 27 | + omega0 = sqrt(g/r); |
| 28 | + x0 = [0.5;0.5*omega0]; |
| 29 | + numSamples = 2000; |
| 30 | + doPlot = 0; |
| 31 | + generatePendulumDataFNO(omega0,x0,numSamples,res,doPlot); |
| 32 | +end |
| 33 | +``` |
| 34 | + |
| 35 | +```matlabTextOutput |
| 36 | +FNO data of resolution 512 written to fno_data_512.mat |
| 37 | +``` |
| 38 | + |
| 39 | +```matlab |
| 40 | +% Construct full path to the data file |
| 41 | +dataFile = fullfile(mainDir, 'pendulumData', sprintf('fno_data_%d.mat',res)); |
| 42 | +% Read the data |
| 43 | +load(dataFile,'data'); |
| 44 | +res = 512; |
| 45 | +
|
| 46 | +f = data.fSamples; |
| 47 | +theta = data.thetaSamples; |
| 48 | +t = data.tGrid; |
| 49 | +
|
| 50 | +% normalize and center the data |
| 51 | +fMean = mean(f, 'all'); |
| 52 | +fStd = std(f, 0, 'all'); |
| 53 | +thetaMean = mean(theta, 'all'); |
| 54 | +thetaStd = std(theta, 0, 'all'); |
| 55 | +f = (f - fMean) / fStd; |
| 56 | +theta = (theta - thetaMean) / thetaStd; |
| 57 | +
|
| 58 | +% visualize some of the training data |
| 59 | +numPlots = 3; |
| 60 | +figure |
| 61 | +tiledlayout(numPlots,2) |
| 62 | +for i = 1:numPlots |
| 63 | + nexttile |
| 64 | + plot(t,f(i,:)); |
| 65 | + title("Observation " + i + newline + "Forcing Function") |
| 66 | + xlabel("$t$",Interpreter='latex'); |
| 67 | + ylabel("$f(t)$",Interpreter='latex'); |
| 68 | +
|
| 69 | + nexttile |
| 70 | + plot(t,theta(i,:)); |
| 71 | + title("ODE Solution") |
| 72 | + xlabel("$t$",Interpreter='latex') |
| 73 | + ylabel("$\theta(t)$",Interpreter='latex') |
| 74 | +end |
| 75 | +``` |
| 76 | + |
| 77 | + |
| 78 | +# Prepare Training Data |
| 79 | +```matlab |
| 80 | +numObservations = size(f,1); |
| 81 | +[idxTrain,idxValidation,idxTest] = trainingPartitions(numObservations,[0.8 0.1 0.1]); |
| 82 | +fTrain = f(idxTrain,:); |
| 83 | +thetaTrain = theta(idxTrain,:); |
| 84 | +fValidation = f(idxValidation,:); |
| 85 | +thetaValidation = theta(idxValidation,:); |
| 86 | +fTest = f(idxTest,:); |
| 87 | +thetaTest = theta(idxTest,:); |
| 88 | +``` |
| 89 | + |
| 90 | +FNO requires input data which contains spatio\-temporal information. Concatenate the grid with the input data. |
| 91 | + |
| 92 | +```matlab |
| 93 | +tGridTrain = repmat(t, [numel(idxTrain) 1]); |
| 94 | +tGridValidation = repmat(t, [numel(idxValidation) 1]); |
| 95 | +fTrain = cat(3,fTrain,tGridTrain); |
| 96 | +fValidation = cat(3, fValidation, tGridValidation); |
| 97 | +
|
| 98 | +size(fTrain) |
| 99 | +``` |
| 100 | + |
| 101 | +```matlabTextOutput |
| 102 | +ans = 1x3 |
| 103 | + 1600 512 2 |
| 104 | +
|
| 105 | +``` |
| 106 | + |
| 107 | +```matlab |
| 108 | +size(fValidation) |
| 109 | +``` |
| 110 | + |
| 111 | +```matlabTextOutput |
| 112 | +ans = 1x3 |
| 113 | + 200 512 2 |
| 114 | +
|
| 115 | +``` |
| 116 | + |
| 117 | +# Define Neural Network Architecture |
| 118 | + |
| 119 | +Network consists of multiple Fourier\-GeLU blocks connected in series. |
| 120 | + |
| 121 | +```matlab |
| 122 | +numModes = 8; |
| 123 | +tWidth = 32; |
| 124 | +numBlocks = 4; |
| 125 | +
|
| 126 | +fourierBlock = [ |
| 127 | + fourierLayer(numModes,tWidth) |
| 128 | + geluLayer]; |
| 129 | +
|
| 130 | +layers = [ |
| 131 | + inputLayer([NaN tWidth 2],"BSC"); |
| 132 | + convolution1dLayer(1,tWidth,Name="fc0") |
| 133 | + repmat(fourierBlock,[numBlocks 1]) |
| 134 | + convolution1dLayer(1,128) |
| 135 | + geluLayer |
| 136 | + convolution1dLayer(1,1)]; |
| 137 | +``` |
| 138 | +# Specify Training Options |
| 139 | +```matlab |
| 140 | +schedule = piecewiseLearnRate(DropFactor=0.5); |
| 141 | +
|
| 142 | +options = trainingOptions("adam", ... |
| 143 | + InitialLearnRate=1e-3, ... |
| 144 | + LearnRateSchedule=schedule, ... |
| 145 | + MaxEpochs=10, ... |
| 146 | + MiniBatchSize=64, ... |
| 147 | + Shuffle="every-epoch", ... |
| 148 | + InputDataFormats="BSC", ... |
| 149 | + Plots="training-progress", ... |
| 150 | + ValidationData={fValidation,thetaValidation}, ... |
| 151 | + Verbose=false); |
| 152 | +``` |
| 153 | +# Train the Network |
| 154 | +```matlab |
| 155 | +net = trainnet(fTrain,thetaTrain,layers,"mse",options); |
| 156 | +``` |
| 157 | + |
| 158 | + |
| 159 | +# Test the Model and Visualize Results |
| 160 | +```matlab |
| 161 | +tGridTest = repmat(t, [numel(idxTest) 1]); |
| 162 | +
|
| 163 | +fTest = cat(3,fTest,tGridTest); |
| 164 | +mseTest = testnet(net,fTest,thetaTest,"mse") |
| 165 | +``` |
| 166 | + |
| 167 | +```matlabTextOutput |
| 168 | +mseTest = 0.0079 |
| 169 | +``` |
| 170 | + |
| 171 | + |
| 172 | +Visualize the predictions on the test set |
| 173 | + |
| 174 | +```matlab |
| 175 | +Y = minibatchpredict(net, fTest); |
| 176 | +numTestPlots = 3; |
| 177 | +for i = 1:numTestPlots |
| 178 | + figure(); |
| 179 | + plot(t,fTest(i,:,1),LineWidth=2.5) |
| 180 | + title("Forcing Function") |
| 181 | + xlabel("$t$",Interpreter="latex") |
| 182 | + ylabel("$f(t)$",Interpreter="latex") |
| 183 | + set(gca,FontSize=14,LineWidth=2.5) |
| 184 | + figure(); |
| 185 | + plot(t,Y(i,:),'b-',LineWidth=2.5,DisplayName='FNO'); hold on |
| 186 | + plot(t,thetaTest(i,:),'k--',LineWidth=2.5,DisplayName='True Solution'); hold off |
| 187 | + title("Angular Position") |
| 188 | + xlabel("$t$",Interpreter="latex") |
| 189 | + ylabel("$\theta(t)$",Interpreter="latex") |
| 190 | + legend(Location='best'); |
| 191 | + set(gca,FontSize=14,LineWidth=2.5) |
| 192 | +end |
| 193 | +``` |
| 194 | + |
| 195 | + |
| 196 | + |
| 197 | + |
| 198 | + |
| 199 | + |
| 200 | + |
| 201 | + |
| 202 | + |
| 203 | + |
| 204 | + |
| 205 | + |
0 commit comments