diff --git a/docs/src/examples/LTC_layer.md b/docs/src/examples/LTC_layer.md
new file mode 100644
index 0000000000..97ae188a79
--- /dev/null
+++ b/docs/src/examples/LTC_layer.md
@@ -0,0 +1,50 @@
+using DiffEqFlux, Flux, Plots, Statistics
+using Random
+Random.seed!(1234) # fix the seed for reproducibility
+
+# Toy sequence-to-sequence task: map (sin t, cos t) to sin 2t.
+N = 48
+π_32 = Float32(π)
+t = range(0.0f0, stop = 3π_32, length = N)
+sin_t = sin.(t)
+cos_t = cos.(t)
+data_x = vcat(reshape(sin_t, (1, N)), reshape(cos_t, (1, N)))
+data_y = reshape(sin.(range(0.0f0, stop = 6π_32, length = N)), (1, N))
+data_x = [data_x[:, i] for i in 1:N] # sequence of N two-element inputs
+data_y = [[data_y[i]] for i in 1:N]  # sequence of N one-element targets
+
+println(size(data_x))
+println(size(data_y))
+
+m = Chain(LTC(2, 32), Dense(32, 1)) # Dense defaults to the identity activation
+function loss_(x, y)
+    diff = m.(x) .- y
+    mean(abs2.([diff[i][1] for i in 1:length(diff)]))
+end
+
+# Callback to observe training progress
+cb = function ()
+    cur_pred = m.(data_x)
+    pl = plot(t, [data_y[i][1] for i in 1:length(data_y)], label = "data")
+    plot!(pl, t, [cur_pred[i][1] for i in 1:length(cur_pred)], label = "prediction")
+    display(pl)
+    @show loss_(data_x, data_y)
+end
+
+ps = Flux.params(m)
+
+opt = Flux.ADAM(0.05)
+epochs = 400
+for epoch in 1:epochs
+    Flux.reset!(m) # start each epoch from the initial recurrent state
+    gs = Flux.gradient(ps) do
+        loss_(data_x, data_y)
+    end
+    Flux.Optimise.update!(opt, ps, gs)
+    if epoch % 10 == 0
+        @show epoch
+        cb()
+    end
+end
diff --git a/src/DiffEqFlux.jl b/src/DiffEqFlux.jl
index e9c09e7637..d37b8b9f55 100644
--- a/src/DiffEqFlux.jl
+++ b/src/DiffEqFlux.jl
@@ -83,10 +83,12 @@ include("tensor_product_layer.jl")
 include("collocation.jl")
 include("hnn.jl")
 include("multiple_shooting.jl")
+include("ltc.jl")
 export diffeq_fd, diffeq_rd, diffeq_adjoint
 export DeterministicCNF, FFJORD, NeuralODE, NeuralDSDE, NeuralSDE, NeuralCDDE, NeuralDAE,
        NeuralODEMM, TensorLayer, AugmentedNDELayer, SplineLayer, NeuralHamiltonianDE
 export HamiltonianNN
+export LTC
 export ChebyshevBasis, SinBasis, CosBasis, FourierBasis, LegendreBasis, PolynomialBasis
 export neural_ode, neural_ode_rd
 export neural_dmsde
@@ -95,6 +97,7 @@ export FastDense, StaticDense, FastChain, initial_params
 export EpanechnikovKernel, UniformKernel, TriangularKernel, QuarticKernel
 export TriweightKernel, TricubeKernel, GaussianKernel, CosineKernel
 export LogisticKernel, SigmoidKernel, SilvermanKernel
+
 export collocate_data
 export multiple_shoot
diff --git a/src/ltc.jl b/src/ltc.jl
new file mode 100644
index 0000000000..a1b184c68b
--- /dev/null
+++ b/src/ltc.jl
@@ -0,0 +1,47 @@
+"""
+Constructs a Liquid Time-Constant (LTC) network [1]: a recurrent cell whose
+hidden state follows an input-dependent ODE, integrated with a fused
+fixed-step solver.
+
+References:
+[1] Hasani, R., Lechner, M., Amini, A., Rus, D. & Grosu, R. Liquid time-constant
+networks. 2020.
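+
+A minimal usage sketch (the 2 => 32 sizes and the random input below are
+illustrative, not part of the API):
+
+```julia
+ltc = LTC(2, 32)      # an LTCCell wrapped in Flux.Recur
+x = rand(Float32, 2)  # one time step with two input features
+h = ltc(x)            # updated hidden state, also the layer output
+Flux.reset!(ltc)      # restore the initial state between sequences
+```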
+""" + +struct LTCCell{F,A,V,S,AB,OU,TA} + σ::F + Wi::A + Wh::A + b::V + A::AB + τ::TA + _ode_unfolds::OU + state0::S + elapsed_time +end + +LTCCell(in::Integer, out::Integer, σ=tanh; init=Flux.glorot_uniform, initb=zeros, init_state=zeros, init_tau=rand, ode_unfolds=6, elapsed_time=1.0) = + LTCCell(σ, init(out, in), init(out, out), initb(out), initb(out), init_tau(out), ode_unfolds, init_state(out,1), elapsed_time) + +Flux.trainable(m::LTCCell) = (m.Wi, m.Wh, m.b, m.A, m.τ,) + +function (m::LTCCell)(h, x) + h = _ode_solver(m::LTCCell, h, x) + out = h + return h, out +end + +function _ode_solver(m::LTCCell, h, x) + σ, Wi, Wh, b, τ, A = m.σ, m.Wi, m.Wh, m.b, m.τ, m.A # assert it is > 0 + τ = Flux.softplus.(τ) # to ensure τ>=0 + Δt = m.elapsed_time/m._ode_unfolds + for t = 1:m._ode_unfolds # FuseStep + f = σ.(Wi*x .+ Wh*h .+ b) + numerator = h .+ Δt .* f .* A + denominator = 1 .+ Δt .* (1 ./ τ .+ f) + h = numerator ./ (denominator .+ 1e-8) # insert epsilon + h = clamp.(h, -1, 1) # to ensure stability + end + return h +end + +LTC(a...; ka...) = Flux.Recur(LTCCell(a...; ka...)) +Flux.Recur(m::LTCCell) = Flux.Recur(m, m.state0)