Skip to content

Commit f2497e9

Browse files
authored
Merged #9 into main
* Reviewer edits and clean up
1 parent 776829e commit f2497e9

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+2960
-0
lines changed
Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
2+
# Fourier Neural Operator for Nonlinear Pendulum
3+
4+
This example demonstrates how to use a Fourier Neural Operator (FNO) to model the dynamics of a nonlinear pendulum with a time\-dependent forcing term. The governing equation is given by the following second\-order ordinary differential equation (ODE):
5+
6+
7+
$\ddot{\theta} (t)=-\omega_0^2 \sin (\theta (t))+f(t)$,
8+
9+
10+
where $\theta$ is the angular position, $\omega_0$ is the natural frequency, and $f$ is the an external forcing function.
11+
12+
13+
Traditionally, solving this system for a new forcing function $f$ requires numerically solving the ODE from scratch. In constrast, an FNO learns a mapping from the input forcing function $f$ to the corresponding solution $\theta$, enabling rapid inference without re\-solving the equation.
14+
15+
16+
The FNO is trained on a dataset of input\-output pairs $(f_i (t),\theta_i (t))$, and once trained, it can predict $\theta$ for unseen forcing functions. While this approach may be excessive for simple ODEs like the pendulum, it can be advantageous for complex simulations involving partial differential equations (PDEs), where traditional solvers (e.g. finite element method) are computationally expensive.
17+
18+
# Generate or Import Data
19+
```matlab
20+
% Get the path to the main directory
21+
mainDir = findProjectRoot('generatePendulumDataFNO.m');
22+
% If first time generating data, set generateData to true. Else, set to false.
23+
res = 512;
24+
generateData = 1;
25+
if generateData
26+
g = 9.81; r = 1;
27+
omega0 = sqrt(g/r);
28+
x0 = [0.5;0.5*omega0];
29+
numSamples = 2000;
30+
doPlot = 0;
31+
generatePendulumDataFNO(omega0,x0,numSamples,res,doPlot);
32+
end
33+
```
34+
35+
```matlabTextOutput
36+
FNO data of resolution 512 written to fno_data_512.mat
37+
```
38+
39+
```matlab
40+
% Construct full path to the data file
41+
dataFile = fullfile(mainDir, 'pendulumData', sprintf('fno_data_%d.mat',res));
42+
% Read the data
43+
load(dataFile,'data');
44+
res = 512;
45+
46+
f = data.fSamples;
47+
theta = data.thetaSamples;
48+
t = data.tGrid;
49+
50+
% normalize and center the data
51+
fMean = mean(f, 'all');
52+
fStd = std(f, 0, 'all');
53+
thetaMean = mean(theta, 'all');
54+
thetaStd = std(theta, 0, 'all');
55+
f = (f - fMean) / fStd;
56+
theta = (theta - thetaMean) / thetaStd;
57+
58+
% visualize some of the training data
59+
numPlots = 3;
60+
figure
61+
tiledlayout(numPlots,2)
62+
for i = 1:numPlots
63+
nexttile
64+
plot(t,f(i,:));
65+
title("Observation " + i + newline + "Forcing Function")
66+
xlabel("$t$",Interpreter='latex');
67+
ylabel("$f(t)$",Interpreter='latex');
68+
69+
nexttile
70+
plot(t,theta(i,:));
71+
title("ODE Solution")
72+
xlabel("$t$",Interpreter='latex')
73+
ylabel("$\theta(t)$",Interpreter='latex')
74+
end
75+
```
76+
77+
![Training data](FNO_README_media/figure_0.png)
78+
# Prepare Training Data
79+
```matlab
80+
numObservations = size(f,1);
81+
[idxTrain,idxValidation,idxTest] = trainingPartitions(numObservations,[0.8 0.1 0.1]);
82+
fTrain = f(idxTrain,:);
83+
thetaTrain = theta(idxTrain,:);
84+
fValidation = f(idxValidation,:);
85+
thetaValidation = theta(idxValidation,:);
86+
fTest = f(idxTest,:);
87+
thetaTest = theta(idxTest,:);
88+
```
89+
90+
FNO requires input data which contains spatio\-temporal information. Concatenate the grid with the input data.
91+
92+
```matlab
93+
tGridTrain = repmat(t, [numel(idxTrain) 1]);
94+
tGridValidation = repmat(t, [numel(idxValidation) 1]);
95+
fTrain = cat(3,fTrain,tGridTrain);
96+
fValidation = cat(3, fValidation, tGridValidation);
97+
98+
size(fTrain)
99+
```
100+
101+
```matlabTextOutput
102+
ans = 1x3
103+
1600 512 2
104+
105+
```
106+
107+
```matlab
108+
size(fValidation)
109+
```
110+
111+
```matlabTextOutput
112+
ans = 1x3
113+
200 512 2
114+
115+
```
116+
117+
# Define Neural Network Architecture
118+
119+
Network consists of multiple Fourier\-GeLU blocks connected in series.
120+
121+
```matlab
122+
numModes = 8;
123+
tWidth = 32;
124+
numBlocks = 4;
125+
126+
fourierBlock = [
127+
fourierLayer(numModes,tWidth)
128+
geluLayer];
129+
130+
layers = [
131+
inputLayer([NaN tWidth 2],"BSC");
132+
convolution1dLayer(1,tWidth,Name="fc0")
133+
repmat(fourierBlock,[numBlocks 1])
134+
convolution1dLayer(1,128)
135+
geluLayer
136+
convolution1dLayer(1,1)];
137+
```
138+
# Specify Training Options
139+
```matlab
140+
schedule = piecewiseLearnRate(DropFactor=0.5);
141+
142+
options = trainingOptions("adam", ...
143+
InitialLearnRate=1e-3, ...
144+
LearnRateSchedule=schedule, ...
145+
MaxEpochs=10, ...
146+
MiniBatchSize=64, ...
147+
Shuffle="every-epoch", ...
148+
InputDataFormats="BSC", ...
149+
Plots="training-progress", ...
150+
ValidationData={fValidation,thetaValidation}, ...
151+
Verbose=false);
152+
```
153+
# Train the Network
154+
```matlab
155+
net = trainnet(fTrain,thetaTrain,layers,"mse",options);
156+
```
157+
158+
![Training progress](FNO_README_media/figure_1.png)
159+
# Test the Model and Visualize Results
160+
```matlab
161+
tGridTest = repmat(t, [numel(idxTest) 1]);
162+
163+
fTest = cat(3,fTest,tGridTest);
164+
mseTest = testnet(net,fTest,thetaTest,"mse")
165+
```
166+
167+
```matlabTextOutput
168+
mseTest = 0.0079
169+
```
170+
171+
172+
Visualize the predictions on the test set
173+
174+
```matlab
175+
Y = minibatchpredict(net, fTest);
176+
numTestPlots = 3;
177+
for i = 1:numTestPlots
178+
figure();
179+
plot(t,fTest(i,:,1),LineWidth=2.5)
180+
title("Forcing Function")
181+
xlabel("$t$",Interpreter="latex")
182+
ylabel("$f(t)$",Interpreter="latex")
183+
set(gca,FontSize=14,LineWidth=2.5)
184+
figure();
185+
plot(t,Y(i,:),'b-',LineWidth=2.5,DisplayName='FNO'); hold on
186+
plot(t,thetaTest(i,:),'k--',LineWidth=2.5,DisplayName='True Solution'); hold off
187+
title("Angular Position")
188+
xlabel("$t$",Interpreter="latex")
189+
ylabel("$\theta(t)$",Interpreter="latex")
190+
legend(Location='best');
191+
set(gca,FontSize=14,LineWidth=2.5)
192+
end
193+
```
194+
195+
![Test f 1](FNO_README_media/figure_2.png)
196+
197+
![Solution plot 1](FNO_README_media/figure_3.png)
198+
199+
![Test f 2](FNO_README_media/figure_4.png)
200+
201+
![Solution plot 2](FNO_README_media/figure_5.png)
202+
203+
![Test f 3](FNO_README_media/figure_6.png)
204+
205+
![Solution plot 3](FNO_README_media/figure_7.png)
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Loading
Binary file not shown.

0 commit comments

Comments
 (0)