Skip to content

Commit bf8d242

Browse files
Merge pull request #3997 from AayushSabharwal/as/fix-subset-tunables-ad
fix: fix generated `getindex` and `length` for `MTKParameters`
2 parents e822248 + 818c8c2 commit bf8d242

File tree

2 files changed

+54
-4
lines changed

2 files changed

+54
-4
lines changed

src/systems/parameter_buffer.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -867,10 +867,10 @@ end
867867
@generated function Base.getindex(
868868
ps::MTKParameters{T, I, D, C, N, H}, idx::Int) where {T, I, D, C, N, H}
869869
paths = []
870-
if !(T <: SizedVector{0, Float64})
870+
if !(T <: SizedVector{0})
871871
push!(paths, :(ps.tunable))
872872
end
873-
if !(I <: SizedVector{0, Float64})
873+
if !(I <: SizedVector{0})
874874
push!(paths, :(ps.initials))
875875
end
876876
for i in 1:fieldcount(D)
@@ -897,10 +897,10 @@ end
897897
@generated function Base.length(ps::MTKParameters{
898898
T, I, D, C, N, H}) where {T, I, D, C, N, H}
899899
len = 0
900-
if !(T <: SizedVector{0, Float64})
900+
if !(T <: SizedVector{0})
901901
len += 1
902902
end
903-
if !(I <: SizedVector{0, Float64})
903+
if !(I <: SizedVector{0})
904904
len += 1
905905
end
906906
len += fieldcount(D) + fieldcount(C) + fieldcount(N) + fieldcount(H)

test/mtkparameters.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using ModelingToolkit
22
using ModelingToolkit: t_nounits as t, D_nounits as D, MTKParameters
33
using SymbolicIndexingInterface, StaticArrays
44
using SciMLStructures: SciMLStructures, canonicalize, Tunable, Discrete, Constants
5+
using ModelingToolkitStandardLibrary.Electrical, ModelingToolkitStandardLibrary.Blocks
56
using BlockArrays: BlockedArray, BlockedVector, Block
67
using OrdinaryDiffEq
78
using ForwardDiff
@@ -379,3 +380,52 @@ with_updated_parameter_timeseries_values(
379380
ps2 = remake_buffer(sys, ps, [p], [:a])
380381
@test ps2.nonnumeric isa Tuple{Vector{Any}}
381382
end
383+
384+
@testset "Issue#3925: Autodiff after `subset_tunables`" begin
385+
function circuit_model()
386+
@named resistor1 = Resistor(R=5.0)
387+
@named resistor2 = Resistor(R=2.0)
388+
@named capacitor1 = Capacitor(C=2.4)
389+
@named capacitor2 = Capacitor(C=60.0)
390+
@named source = Voltage()
391+
@named input_signal = Sine(frequency=1.0)
392+
@named ground = Ground()
393+
@named ampermeter = CurrentSensor()
394+
395+
eqs = [connect(input_signal.output, source.V)
396+
connect(source.p, capacitor1.n, capacitor2.n)
397+
connect(source.n, resistor1.p, resistor2.p, ground.g)
398+
connect(resistor1.n, capacitor1.p, ampermeter.n)
399+
connect(resistor2.n, capacitor2.p, ampermeter.p)]
400+
401+
@named circuit_model = System(eqs, t,
402+
systems=[
403+
resistor1, resistor2, capacitor1, capacitor2,
404+
source, input_signal, ground, ampermeter
405+
])
406+
end
407+
408+
model = circuit_model()
409+
sys = mtkcompile(model)
410+
411+
tunable_parameters(sys)
412+
413+
sub_sys = subset_tunables(sys, [sys.capacitor2.C])
414+
415+
tunable_parameters(sub_sys)
416+
417+
prob = ODEProblem(sub_sys, [sys.capacitor2.v => 0.0], (0, 3.))
418+
419+
setter = setsym_oop(prob, [sys.capacitor2.C]);
420+
421+
function loss(x, ps)
422+
setter, prob = ps
423+
u0, p = setter(prob, x)
424+
new_prob = remake(prob; u0, p)
425+
sol = solve(new_prob, Rodas5P())
426+
sum(sol)
427+
end
428+
429+
grad = ForwardDiff.gradient(Base.Fix2(loss, (setter, prob)), [3.0])
430+
@test grad [0.14882627068752538] atol=1e-10
431+
end

0 commit comments

Comments
 (0)