Skip to content
Draft
952 changes: 439 additions & 513 deletions Manifest.toml

Large diffs are not rendered by default.

13 changes: 7 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,37 +39,38 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
Tracy = "e689c965-62c8-4b79-b2c5-8359227902fd"

[weakdeps]
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"

[sources]
Compiler = {rev = "master", url = "https://github.com/JuliaLang/BaseCompiler.jl.git"}
Cthulhu = {rev = "master", url = "https://github.com/JuliaDebug/Cthulhu.jl.git"}
Cthulhu = {rev = "low-level-interface", url = "https://github.com/serenity4/Cthulhu.jl"}
DifferentiationInterface = {rev = "main", subdir = "DifferentiationInterface", url = "https://github.com/Keno/DifferentiationInterface.jl"}
Diffractor = {rev = "main", url = "https://github.com/JuliaDiff/Diffractor.jl.git"}
Diffractor = {rev = "cthulhu", url = "https://github.com/JuliaDiff/Diffractor.jl.git"}
SimpleNonlinearSolve = {rev = "master", subdir = "lib/SimpleNonlinearSolve", url = "https://github.com/SciML/NonlinearSolve.jl.git"}
StateSelection = {rev = "main", url = "https://github.com/JuliaComputing/StateSelection.jl.git"}

[extensions]
DAECompilerCthulhuExt = ["Compiler", "Cthulhu"]

[compat]
Accessors = "0.1.36"
AutoHashEquals = "2.2.0"
CentralizedCaches = "1.1.0"
ChainRules = "1.50"
ChainRulesCore = "1.20"
Compiler = "0"
Cthulhu = "3.0.0"
DiffEqBase = "6.149.2"
DifferentiationInterface = "0.6.52"
Diffractor = "0.2.7"
DifferentiationInterface = "0.7.9"
ForwardDiff = "0.10.36"
InteractiveUtils = "1.11.0"
NonlinearSolve = "3.5.0, 4"
OrderedCollections = "1.6.3"
PrecompileTools = "1"
Preferences = "1.4"
SciMLBase = "2.86.2"
SimpleNonlinearSolve = "2.3.0"
StateSelection = "0.2.0"
StaticArraysCore = "1.4.2"
Sundials = "4.19"
Expand Down
144 changes: 144 additions & 0 deletions ext/DAECompilerCthulhuExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
module DAECompilerCthulhuExt

using Core.IR
using DAECompiler: DAECompiler, DAEIPOResult, UncompilableIPOResult, Settings, ADAnalyzer, structural_analysis!, find_matching_ci, matched_system_structure, StructureCache, ir_to_src, get_method_instance, MappingInfo, AnalyzedSource
using Compiler: Compiler, InferenceResult, NativeInterpreter, SOURCE_MODE_GET_SOURCE, get_inference_world, typeinf_ext, Effects, get_ci_mi, NoCallInfo
using Accessors: setproperties
using Diffractor: FRuleCallInfo

import Cthulhu as _Cthulhu
const Cthulhu = Base.get_extension(_Cthulhu, :CthulhuCompilerExt)
using .Cthulhu: CthulhuState, AbstractProvider, Command, InferenceKey, InferenceDict, PC2Remarks, PC2CallMeta, PC2Effects, PC2Excts, LookupResult, generate_code_instance, value_for_default_command

mutable struct DAEProvider <: AbstractProvider
world::UInt
settings::Settings
remarks::InferenceDict{PC2Remarks}
calls::InferenceDict{PC2CallMeta}
effects::InferenceDict{PC2Effects}
exception_types::InferenceDict{PC2Excts}
end
DAEProvider(; world = Base.tls_world_age(), settings = Settings()) = DAEProvider(world, settings, InferenceDict{PC2Remarks}(), InferenceDict{PC2CallMeta}(), InferenceDict{PC2Effects}(), InferenceDict{PC2Excts}())

Cthulhu.get_inference_world(provider::DAEProvider) = provider.world

function Cthulhu.find_method_instance(provider::DAEProvider, @nospecialize(tt::Type{<:Tuple}), world::UInt)
return get_method_instance(tt, world)
end

function check_result(ci::CodeInstance)
isa(ci.inferred, UncompilableIPOResult) && throw(ci.inferred.error)
return true
end

function Cthulhu.generate_code_instance(provider::DAEProvider, mi::MethodInstance)
world = get_inference_world(provider)
ci = find_matching_ci(ci->ci.owner == StructureCache(), mi, world)
# XXX: We should not cache the CodeInstance this way, or at least invalidate in the provider in `toggle_setting!`.
if ci !== nothing
haskey(provider.remarks, ci) && return ci
else
provider.settings.force_inline_all && @warn "`force_inline_all=true` is not supported yet; this setting will be ignored"
analyzer = ADAnalyzer(; world)
ci_pre = typeinf_ext(analyzer, mi, SOURCE_MODE_GET_SOURCE)
result = structural_analysis!(ci_pre, world, provider.settings)
ci = find_matching_ci(ci->ci.owner == StructureCache(), mi, world)::CodeInstance
end

check_result(ci)
provider.remarks[ci] = PC2Remarks()
provider.calls[ci] = PC2CallMeta()
provider.effects[ci] = PC2Effects()
provider.exception_types[ci] = PC2Excts()

@eval Main global result = $(ci.inferred)

return ci
end

Cthulhu.get_override(provider::DAEProvider, @nospecialize(info)) = nothing

Cthulhu.get_pc_remarks(provider::DAEProvider, key::CodeInstance) = get(provider.remarks, key, nothing)
Cthulhu.get_pc_effects(provider::DAEProvider, key::CodeInstance) = get(provider.effects, key, nothing)
Cthulhu.get_pc_excts(provider::DAEProvider, key::CodeInstance) = get(provider.exception_types, key, nothing)

function Cthulhu.LookupResult(provider::DAEProvider, ci::CodeInstance, optimize::Bool)
if isa(ci.inferred, AnalyzedSource)
mi = get_ci_mi(ci)
new_ci = generate_code_instance(provider, mi)
check_result(new_ci)
@assert isa(new_ci.inferred, DAEIPOResult) "Inferred type of newly generated `CodeInstance` must be `DAEIPOResult`, got `$(typeof(new_ci.inferred))`"
return LookupResult(provider, new_ci, optimize)
end
result = ci.inferred::DAEIPOResult
ir = copy(result.ir)
pushfirst!(ir.argtypes, Tuple)
src = ir_to_src(ir, provider.settings; widen = false)
src.ssavaluetypes = copy(ir.stmts.type)
src.min_world = @atomic ci.min_world
src.max_world = @atomic ci.max_world
optimized = true
rt = Cthulhu.cached_return_type(ci)
exct = Cthulhu.cached_exception_type(ci)
infos = widen_call_infos(ir.stmts.info)
return LookupResult(ir, src, rt, exct, infos, src.slottypes, Cthulhu.get_effects(ci), optimized)
end

function widen_call_infos(infos)
infos = copy(infos)
for (i, info) in enumerate(infos)
while true
isa(info, FRuleCallInfo) && (info = info.info; continue)
isa(info, MappingInfo) && (info = info.info; continue)
break
end
infos[i] = info
end
return infos
end

function toggle_setting(provider::DAEProvider, setting::Symbol, value)
return setproperties(provider.settings, NamedTuple((setting => value,)))
end

function Cthulhu.menu_commands(provider::DAEProvider)
commands = Cthulhu.default_menu_commands(provider)
filter!(x -> !in(x.name, (:optimize, :dump_params, :llvm, :native)), commands)
push!(commands, toggle_setting(provider, 'f', :force_inline_all, "force inline all"))
push!(commands, Cthulhu.perform_action(show_mss, 'm', :show_mss, :actions, "Show system structure"))
return commands
end

function show_mss(state::CthulhuState)
result = state.ci.inferred::DAEIPOResult
terminal = state.terminal
io = terminal.out_stream::IO
mss = matched_system_structure(result, state.provider.settings.mode)
(_, width) = displaysize(terminal)
printstyled(io, '\n', '-'^((width - 26) ÷ 2), " Showing system structure ", '-'^((width - 26) ÷ 2), '\n'; color = :light_black)
show(io, MIME"text/plain"(), mss)
printstyled(io, '\n', '-'^width, "\n\n"; color = :light_black)
end

function toggle_setting(provider::DAEProvider, key::Char, name::Symbol, description::String = string(name))
callback = state -> toggle_setting!(state, name)
Command(callback, key, name, description, :toggles)
end

function Cthulhu.value_for_command(provider::DAEProvider, state::CthulhuState, command::Command)
hasproperty(provider.settings, command.name) &&
return getproperty(provider.settings, command.name)
return value_for_default_command(provider, state, command)
end

function toggle_setting!(state::CthulhuState, name::Symbol)
(; provider) = state
(; settings) = provider
value = !getproperty(settings, name)::Bool
provider.settings = setproperties(settings, NamedTuple((name => value,)))
state.display_code = true
end

DAECompiler.dae_provider(args...; kwargs...) = DAEProvider(args...; kwargs...)

end # module
2 changes: 2 additions & 0 deletions src/DAECompiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,6 @@ module DAECompiler
include("analysis/consistency.jl")
include("interface.jl")
include("problem_interface.jl")

export dae_provider # use with Cthulhu, `@descend provider=dae_provider() pingpong()`
end
2 changes: 1 addition & 1 deletion src/analysis/ADAnalyzer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ end
return AnalyzedSource(ir, slotnames, Compiler.compute_inlining_cost(interp, result), result.src.src.nargs, result.src.src.isva)
end

@override function Compiler.transform_result_for_local_cache(interp::ADAnalyzer, result::InferenceResult)
@override function Compiler.transform_result_for_local_cache(interp::ADAnalyzer, result::InferenceResult, edges::SimpleVector)
if Compiler.result_is_constabi(interp, result)
return nothing
end
Expand Down
17 changes: 17 additions & 0 deletions src/analysis/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,20 @@ function make_structure_from_ipo(ipo::DAEIPOResult)

structure = DAESystemStructure(StateSelection.complete(var_to_diff), StateSelection.complete(eq_to_diff), graph, solvable_graph)
end

function matched_system_structure(result::DAEIPOResult, mode)
structure = make_structure_from_ipo(result)

tstate = TransformationState(result, structure)
err = StateSelection.check_consistency(tstate, nothing)
err !== nothing && throw(err)

ret = top_level_state_selection!(tstate)
isa(ret, UncompilableIPOResult) && throw(ret.error)

(diff_key, init_key) = ret
key = in(mode, (DAE, DAENoInit, ODE, ODENoInit)) ? diff_key : init_key

var_eq_matching = matching_for_key(tstate, key)
return StateSelection.MatchedSystemStructure(result, structure, var_eq_matching)
end
3 changes: 3 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,6 @@ function refresh()
return nothing
end
refresh()

# methods are to be added via the Cthulhu extension
function dae_provider end
17 changes: 1 addition & 16 deletions src/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,7 @@ function code_structure_by_type(@nospecialize(tt::Type); world::UInt = Base.tls_
_result = structural_analysis!(ci, world, settings)
isa(_result, UncompilableIPOResult) && throw(_result.error)
!matched && return result ? _result : _result.ir
result = _result

structure = make_structure_from_ipo(result)

tstate = TransformationState(result, structure)
err = StateSelection.check_consistency(tstate, nothing)
err !== nothing && throw(err)

ret = top_level_state_selection!(tstate)
isa(ret, UncompilableIPOResult) && throw(ret.error)

(diff_key, init_key) = ret
key = in(mode, (DAE, DAENoInit, ODE, ODENoInit)) ? diff_key : init_key

var_eq_matching = matching_for_key(tstate, key)
return StateSelection.MatchedSystemStructure(result, structure, var_eq_matching)
return matched_system_structure(_result, mode)
end

"""
Expand Down
13 changes: 5 additions & 8 deletions src/transform/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,26 +40,23 @@ function widen_extra_info!(ir)
end
end

function ir_to_src(ir::IRCode, settings::Settings; slotnames = nothing)
isva = false
function ir_to_src(ir::IRCode, settings::Settings; slotnames = nothing, widen = true, isva = false)
ir.debuginfo.def === nothing && (ir.debuginfo.def = :var"generated IR for OpaqueClosure")
maybe_rewrite_debuginfo!(ir, settings)
nargtypes = length(ir.argtypes)
nargs = nargtypes-1
sig = Compiler.compute_oc_signature(ir, nargs, isva)
rt = Compiler.compute_ir_rettype(ir)
src = ccall(:jl_new_code_info_uninit, Ref{CodeInfo}, ())
if slotnames === nothing
src.slotnames = Symbol[Symbol("arg$i") for i = 1:nargtypes]
else
length(slotnames) == nargtypes || error("mismatched `argtypes` and `slotnames`")
src.slotnames = slotnames
end
src.nargs = length(ir.argtypes)
src.isva = false
src.nargs = nargtypes
src.isva = isva
src.slotflags = fill(zero(UInt8), nargtypes)
src.slottypes = copy(ir.argtypes)
src = Compiler.ir_to_codeinf!(src, ir)
Compiler.replace_code_newstyle!(src, ir)
widen && Compiler.widen_all_consts!(src)
return src
end

Expand Down
4 changes: 0 additions & 4 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,10 @@ REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StateSelection = "64909d44-ed92-46a8-bbd9-f047dfbdc84b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
XSteam = "95ff35a0-be81-11e9-2ca3-5b4e338e8476"

[sources]
SciMLSensitivity = {rev = "kf/mindep4", url = "https://github.com/CedarEDA/SciMLSensitivity.jl"}
Loading