Commit a6c3459

manyfeatures committed

adds Lagrangian NN and simple example
1 parent 9e1182e commit a6c3459

File tree

  docs/src/examples/lagrangian_nn.jl
  src/DiffEqFlux.jl
  src/lnn.jl

3 files changed: +73 −1 lines changed

docs/src/examples/lagrangian_nn.jl (+28)

@@ -0,0 +1,28 @@
# One point test
using Flux, ReverseDiff, DiffEqFlux # LagrangianNN is exported from DiffEqFlux

m, k, b = 1, 1, 1

X = rand(2, 1)          # state [q; q̇] for a single sample
Y = -k .* X[1] / m      # target acceleration at the sampled position

g = Chain(Dense(2, 10, σ), Dense(10, 1))
model = LagrangianNN(g)
params = model.params
re = model.re

# some toy loss function: squared error between predicted and true acceleration
function loss(x, y, p)
    nn = x -> model(x, p)
    out = sum((y .- nn(x)).^2)
    out
end

opt = ADAM(0.01)
epochs = 100

for epoch in 1:epochs
    x, y = X, Y
    gs = ReverseDiff.gradient(p -> loss(x, y, p), params)
    Flux.Optimise.update!(opt, params, gs)
    @show loss(x, y, params)
end
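
For context, the toy problem above appears to be a one-dimensional harmonic oscillator (the damping coefficient b is defined but never used). Under that reading the Lagrangian and equation of motion are

    L(q, q̇) = (1/2) m q̇² − (1/2) k q²    ⇒    q̈ = −(k/m) q,

so Y = -k .* X[1] / m is the true acceleration at the sampled position X[1], and the loss drives the network's predicted acceleration toward it.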

src/DiffEqFlux.jl (+3 −1)
@@ -2,7 +2,7 @@ module DiffEqFlux

 using GalacticOptim, DataInterpolations, DiffEqBase, DiffResults, DiffEqSensitivity,
       Distributions, ForwardDiff, Flux, Requires, Adapt, LinearAlgebra, RecursiveArrayTools,
-      StaticArrays, Base.Iterators, Printf, Zygote
+      StaticArrays, Base.Iterators, Printf, Zygote, GenericLinearAlgebra

 using DistributionsAD
 import ProgressLogging, ZygoteRules
@@ -82,11 +82,13 @@ include("tensor_product_basis.jl")
 include("tensor_product_layer.jl")
 include("collocation.jl")
 include("hnn.jl")
+include("lnn.jl")
 include("multiple_shooting.jl")

 export diffeq_fd, diffeq_rd, diffeq_adjoint
 export DeterministicCNF, FFJORD, NeuralODE, NeuralDSDE, NeuralSDE, NeuralCDDE, NeuralDAE, NeuralODEMM, TensorLayer, AugmentedNDELayer, SplineLayer, NeuralHamiltonianDE
 export HamiltonianNN
+export LagrangianNN
 export ChebyshevBasis, SinBasis, CosBasis, FourierBasis, LegendreBasis, PolynomialBasis
 export neural_ode, neural_ode_rd
 export neural_dmsde
src/lnn.jl (+42)

@@ -0,0 +1,42 @@
"""
Constructs a Lagrangian Neural Network [1].

References:
[1] Miles Cranmer, Sam Greydanus, Stephan Hoyer, Peter Battaglia, David Spergel, and Shirley Ho.
    Lagrangian Neural Networks. In ICLR 2020 Workshop on Integration of Deep Neural Models and
    Differential Equations, 2020.
"""
struct LagrangianNN
    model
    re
    params

    # Define inner constructor method
    function LagrangianNN(model; p = nothing)
        _p, re = Flux.destructure(model)
        if p === nothing
            p = _p
        end
        return new(model, re, p)
    end
end

function (nn::LagrangianNN)(x, p = nn.params)
    @assert size(x, 1) % 2 == 0 # velocity degrees of freedom must match coordinate degrees of freedom
    M = div(size(x, 1), 2) # number of velocity degrees of freedom
    re = nn.re

    # Hessian of the Lagrangian w.r.t. the full state; keep only the velocity block ∂²L/∂q̇²
    hess = x -> Zygote.hessian_reverse(x -> sum(re(p)(x)), x)
    hess = hess(x)[M+1:end, M+1:end]
    inv_hess = GenericLinearAlgebra.pinv(hess)

    # First term: gradient of the Lagrangian w.r.t. the coordinates, ∂L/∂q
    _grad_q = x -> Zygote.gradient(x -> sum(re(p)(x)), x)[end]
    out1 = _grad_q(x)[1:M, :]

    # Second term: mixed Hessian block ∂²L/∂q̇∂q contracted with the velocities q̇
    _grad_qv = x -> Zygote.gradient(x -> sum(re(p)(x)), x)[end]
    _jac_qv = x -> Zygote.jacobian(x -> _grad_qv(x), x)[end]
    out2 = _jac_qv(x)[M+1:end, 1:M] * x[M+1:end]

    # Euler-Lagrange equations solved for the accelerations:
    # q̈ = (∂²L/∂q̇²)⁻¹ (∂L/∂q − (∂²L/∂q̇∂q) q̇)
    return inv_hess * (out1 .- out2)
end
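
For reference, the functor above evaluates the Euler-Lagrange equations rearranged for the accelerations, as in the Lagrangian Neural Networks paper [1]: with state x = (q, q̇) and learned scalar Lagrangian L(q, q̇),

    q̈ = (∂²L/∂q̇²)⁻¹ ( ∂L/∂q − (∂²L/∂q̇∂q) q̇ ),

where out1 is ∂L/∂q, out2 is the mixed-derivative term contracted with the velocities, and inv_hess is the pseudoinverse of the velocity block of the Hessian.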
