Skip to content

fix: compile symbolic affects after mtkcompile in complete #3812

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/structural_transformation/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,8 @@ function find_eq_solvables!(state::TearingState, ieq, to_rm = Int[], coeffs = no
symbolic_type(v) == ArraySymbolic() &&
Symbolics.shape(v) != Symbolics.Unknown() &&
any(x -> any(isequal(x), fullvars), collect(v)),
vars(a))
vars(
a; op = Union{Differential, Shift, Pre, Sample, Hold, Initial}))
continue
end
else
Expand Down
10 changes: 10 additions & 0 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,16 @@ function complete(
if add_initial_parameters
sys = add_initialization_parameters(sys; split)
end
if has_continuous_events(sys) && is_time_dependent(sys)
@set! sys.continuous_events = complete.(
get_continuous_events(sys); iv = get_iv(sys),
alg_eqs = [alg_equations(sys); observed(sys)])
end
if has_discrete_events(sys) && is_time_dependent(sys)
@set! sys.discrete_events = complete.(
get_discrete_events(sys); iv = get_iv(sys),
alg_eqs = [alg_equations(sys); observed(sys)])
end
end
if split && has_index_cache(sys)
@set! sys.index_cache = IndexCache(sys)
Expand Down
278 changes: 160 additions & 118 deletions src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,27 @@ function has_functional_affect(cb)
affects(cb) isa ImperativeAffect
end

struct SymbolicAffect
affect::Vector{Equation}
alg_eqs::Vector{Equation}
discrete_parameters::Vector{Any}
end

function SymbolicAffect(affect::Vector{Equation}; alg_eqs = Equation[],
discrete_parameters = Any[], kwargs...)
if !(discrete_parameters isa AbstractVector)
discrete_parameters = Any[discrete_parameters]
elseif !(discrete_parameters isa Vector{Any})
discrete_parameters = Vector{Any}(discrete_parameters)
end
SymbolicAffect(affect, alg_eqs, discrete_parameters)
end
function SymbolicAffect(affect::SymbolicAffect; kwargs...)
SymbolicAffect(affect.affect; alg_eqs = affect.alg_eqs,
discrete_parameters = affect.discrete_parameters, kwargs...)
end
SymbolicAffect(affect; kwargs...) = make_affect(affect; kwargs...)

struct AffectSystem
"""The internal implicit discrete system whose equations are solved to obtain values after the affect."""
system::AbstractSystem
Expand All @@ -15,6 +36,72 @@ struct AffectSystem
discretes::Vector
end

function AffectSystem(spec::SymbolicAffect; iv = nothing, alg_eqs = Equation[], kwargs...)
AffectSystem(spec.affect; alg_eqs = vcat(spec.alg_eqs, alg_eqs), iv,
discrete_parameters = spec.discrete_parameters, kwargs...)
end

function AffectSystem(affect::Vector{Equation}; discrete_parameters = Any[],
iv = nothing, alg_eqs::Vector{Equation} = Equation[], warn_no_algebraic = true, kwargs...)
isempty(affect) && return nothing
if isnothing(iv)
iv = t_nounits
@warn "No independent variable specified. Defaulting to t_nounits."
end

discrete_parameters isa AbstractVector || (discrete_parameters = [discrete_parameters])
discrete_parameters = unwrap.(discrete_parameters)

for p in discrete_parameters
occursin(unwrap(iv), unwrap(p)) ||
error("Non-time dependent parameter $p passed in as a discrete. Must be declared as @parameters $p(t).")
end

dvs = OrderedSet()
params = OrderedSet()
_varsbuf = Set()
for eq in affect
if !haspre(eq) && !(symbolic_type(eq.rhs) === NotSymbolic() ||
symbolic_type(eq.lhs) === NotSymbolic())
@warn "Affect equation $eq has no `Pre` operator. As such it will be interpreted as an algebraic equation to be satisfied after the callback. If you intended to use the value of a variable x before the affect, use Pre(x). Errors may be thrown if there is no `Pre` and the algebraic equation is unsatisfiable, such as X ~ X + 1."
end
collect_vars!(dvs, params, eq, iv; op = Pre)
empty!(_varsbuf)
vars!(_varsbuf, eq; op = Pre)
filter!(x -> iscall(x) && operation(x) isa Pre, _varsbuf)
union!(params, _varsbuf)
diffvs = collect_applied_operators(eq, Differential)
union!(dvs, diffvs)
end
for eq in alg_eqs
collect_vars!(dvs, params, eq, iv)
end
pre_params = filter(haspre ∘ value, params)
sys_params = collect(setdiff(params, union(discrete_parameters, pre_params)))
discretes = map(tovar, discrete_parameters)
dvs = collect(dvs)
_dvs = map(default_toterm, dvs)

rev_map = Dict(zip(discrete_parameters, discretes))
subs = merge(rev_map, Dict(zip(dvs, _dvs)))
affect = Symbolics.fast_substitute(affect, subs)
alg_eqs = Symbolics.fast_substitute(alg_eqs, subs)

@named affectsys = System(
vcat(affect, alg_eqs), iv, collect(union(_dvs, discretes)),
collect(union(pre_params, sys_params)); is_discrete = true)
affectsys = mtkcompile(affectsys; fully_determined = nothing)
# get accessed parameters p from Pre(p) in the callback parameters
accessed_params = Vector{Any}(filter(isparameter, map(unPre, collect(pre_params))))
union!(accessed_params, sys_params)

# add scalarized unknowns to the map.
_dvs = reduce(vcat, map(scalarize, _dvs), init = Any[])

AffectSystem(affectsys, collect(_dvs), collect(accessed_params),
collect(discrete_parameters))
end

system(a::AffectSystem) = a.system
discretes(a::AffectSystem) = a.discretes
unknowns(a::AffectSystem) = a.unknowns
Expand Down Expand Up @@ -159,40 +246,40 @@ will run as soon as the solver starts, while finalization affects will be execut
"""
struct SymbolicContinuousCallback <: AbstractCallback
conditions::Vector{Equation}
affect::Union{Affect, Nothing}
affect_neg::Union{Affect, Nothing}
initialize::Union{Affect, Nothing}
finalize::Union{Affect, Nothing}
affect::Union{Affect, SymbolicAffect, Nothing}
affect_neg::Union{Affect, SymbolicAffect, Nothing}
initialize::Union{Affect, SymbolicAffect, Nothing}
finalize::Union{Affect, SymbolicAffect, Nothing}
rootfind::Union{Nothing, SciMLBase.RootfindOpt}
reinitializealg::SciMLBase.DAEInitializationAlgorithm
end

function SymbolicContinuousCallback(
conditions::Union{Equation, Vector{Equation}},
affect = nothing;
affect_neg = affect,
initialize = nothing,
finalize = nothing,
rootfind = SciMLBase.LeftRootFind,
reinitializealg = nothing,
kwargs...)
conditions = (conditions isa AbstractVector) ? conditions : [conditions]

if isnothing(reinitializealg)
if any(a -> a isa ImperativeAffect,
[affect, affect_neg, initialize, finalize])
reinitializealg = SciMLBase.CheckInit()
else
reinitializealg = SciMLBase.NoInit()
end
function SymbolicContinuousCallback(
conditions::Union{Equation, Vector{Equation}},
affect = nothing;
affect_neg = affect,
initialize = nothing,
finalize = nothing,
rootfind = SciMLBase.LeftRootFind,
reinitializealg = nothing,
kwargs...)
conditions = (conditions isa AbstractVector) ? conditions : [conditions]

if isnothing(reinitializealg)
if any(a -> a isa ImperativeAffect,
[affect, affect_neg, initialize, finalize])
reinitializealg = SciMLBase.CheckInit()
else
reinitializealg = SciMLBase.NoInit()
end
end

new(conditions, make_affect(affect; kwargs...),
make_affect(affect_neg; kwargs...),
make_affect(initialize; kwargs...), make_affect(
finalize; kwargs...),
rootfind, reinitializealg)
end # Default affect to nothing
end
SymbolicContinuousCallback(conditions, SymbolicAffect(affect; kwargs...),
SymbolicAffect(affect_neg; kwargs...),
SymbolicAffect(initialize; kwargs...), SymbolicAffect(
finalize; kwargs...),
rootfind, reinitializealg)
end # Default affect to nothing

function SymbolicContinuousCallback(p::Pair, args...; kwargs...)
SymbolicContinuousCallback(p[1], p[2], args...; kwargs...)
Expand All @@ -207,71 +294,18 @@ function SymbolicContinuousCallback(cb::Tuple, args...; kwargs...)
end
end

function complete(cb::SymbolicContinuousCallback; kwargs...)
SymbolicContinuousCallback(cb.conditions, make_affect(cb.affect; kwargs...),
make_affect(cb.affect_neg; kwargs...), make_affect(cb.initialize; kwargs...),
make_affect(cb.finalize; kwargs...), cb.rootfind, cb.reinitializealg)
end

make_affect(affect::SymbolicAffect; kwargs...) = AffectSystem(affect; kwargs...)
make_affect(affect::Nothing; kwargs...) = nothing
make_affect(affect::Tuple; kwargs...) = ImperativeAffect(affect...)
make_affect(affect::NamedTuple; kwargs...) = ImperativeAffect(; affect...)
make_affect(affect::Affect; kwargs...) = affect

function make_affect(affect::Vector{Equation}; discrete_parameters = Any[],
iv = nothing, alg_eqs::Vector{Equation} = Equation[], warn_no_algebraic = true, kwargs...)
isempty(affect) && return nothing
if isnothing(iv)
iv = t_nounits
@warn "No independent variable specified. Defaulting to t_nounits."
end

discrete_parameters isa AbstractVector || (discrete_parameters = [discrete_parameters])
discrete_parameters = unwrap.(discrete_parameters)

for p in discrete_parameters
occursin(unwrap(iv), unwrap(p)) ||
error("Non-time dependent parameter $p passed in as a discrete. Must be declared as @parameters $p(t).")
end

dvs = OrderedSet()
params = OrderedSet()
_varsbuf = Set()
for eq in affect
if !haspre(eq) && !(symbolic_type(eq.rhs) === NotSymbolic() ||
symbolic_type(eq.lhs) === NotSymbolic())
@warn "Affect equation $eq has no `Pre` operator. As such it will be interpreted as an algebraic equation to be satisfied after the callback. If you intended to use the value of a variable x before the affect, use Pre(x). Errors may be thrown if there is no `Pre` and the algebraic equation is unsatisfiable, such as X ~ X + 1."
end
collect_vars!(dvs, params, eq, iv; op = Pre)
empty!(_varsbuf)
vars!(_varsbuf, eq; op = Pre)
filter!(x -> iscall(x) && operation(x) isa Pre, _varsbuf)
union!(params, _varsbuf)
diffvs = collect_applied_operators(eq, Differential)
union!(dvs, diffvs)
end
for eq in alg_eqs
collect_vars!(dvs, params, eq, iv)
end
pre_params = filter(haspre ∘ value, params)
sys_params = collect(setdiff(params, union(discrete_parameters, pre_params)))
discretes = map(tovar, discrete_parameters)
dvs = collect(dvs)
_dvs = map(default_toterm, dvs)

rev_map = Dict(zip(discrete_parameters, discretes))
subs = merge(rev_map, Dict(zip(dvs, _dvs)))
affect = Symbolics.fast_substitute(affect, subs)
alg_eqs = Symbolics.fast_substitute(alg_eqs, subs)

@named affectsys = System(
vcat(affect, alg_eqs), iv, collect(union(_dvs, discretes)),
collect(union(pre_params, sys_params)); is_discrete = true)
affectsys = mtkcompile(affectsys; fully_determined = nothing)
# get accessed parameters p from Pre(p) in the callback parameters
accessed_params = Vector{Any}(filter(isparameter, map(unPre, collect(pre_params))))
union!(accessed_params, sys_params)

# add scalarized unknowns to the map.
_dvs = reduce(vcat, map(scalarize, _dvs), init = Any[])

AffectSystem(affectsys, collect(_dvs), collect(accessed_params),
collect(discrete_parameters))
end
make_affect(affect::Vector{Equation}; kwargs...) = AffectSystem(affect; kwargs...)

function make_affect(affect; kwargs...)
error("Malformed affect $(affect). This should be a vector of equations or a tuple specifying a functional affect.")
Expand Down Expand Up @@ -374,30 +408,30 @@ Arguments:
"""
struct SymbolicDiscreteCallback <: AbstractCallback
conditions::Union{Number, Vector{<:Number}, Symbolic{Bool}}
affect::Union{Affect, Nothing}
initialize::Union{Affect, Nothing}
finalize::Union{Affect, Nothing}
affect::Union{Affect, SymbolicAffect, Nothing}
initialize::Union{Affect, SymbolicAffect, Nothing}
finalize::Union{Affect, SymbolicAffect, Nothing}
reinitializealg::SciMLBase.DAEInitializationAlgorithm
end

function SymbolicDiscreteCallback(
condition::Union{Symbolic{Bool}, Number, Vector{<:Number}}, affect = nothing;
initialize = nothing, finalize = nothing,
reinitializealg = nothing, kwargs...)
c = is_timed_condition(condition) ? condition : value(scalarize(condition))

if isnothing(reinitializealg)
if any(a -> a isa ImperativeAffect,
[affect, initialize, finalize])
reinitializealg = SciMLBase.CheckInit()
else
reinitializealg = SciMLBase.NoInit()
end
function SymbolicDiscreteCallback(
condition::Union{Symbolic{Bool}, Number, Vector{<:Number}}, affect = nothing;
initialize = nothing, finalize = nothing,
reinitializealg = nothing, kwargs...)
c = is_timed_condition(condition) ? condition : value(scalarize(condition))

if isnothing(reinitializealg)
if any(a -> a isa ImperativeAffect,
[affect, initialize, finalize])
reinitializealg = SciMLBase.CheckInit()
else
reinitializealg = SciMLBase.NoInit()
end
new(c, make_affect(affect; kwargs...),
make_affect(initialize; kwargs...),
make_affect(finalize; kwargs...), reinitializealg)
end # Default affect to nothing
end
end
SymbolicDiscreteCallback(c, SymbolicAffect(affect; kwargs...),
SymbolicAffect(initialize; kwargs...),
SymbolicAffect(finalize; kwargs...), reinitializealg)
end # Default affect to nothing

function SymbolicDiscreteCallback(p::Pair, args...; kwargs...)
SymbolicDiscreteCallback(p[1], p[2], args...; kwargs...)
Expand All @@ -412,6 +446,12 @@ function SymbolicDiscreteCallback(cb::Tuple, args...; kwargs...)
end
end

function complete(cb::SymbolicDiscreteCallback; kwargs...)
SymbolicDiscreteCallback(cb.conditions, make_affect(cb.affect; kwargs...),
make_affect(cb.initialize; kwargs...),
make_affect(cb.finalize; kwargs...), cb.reinitializealg)
end

function is_timed_condition(condition::T) where {T}
if T === Num
false
Expand Down Expand Up @@ -457,6 +497,12 @@ function namespace_affects(affect::AffectSystem, s)
renamespace.((s,), parameters(affect)),
renamespace.((s,), discretes(affect)))
end
function namespace_affects(affect::SymbolicAffect, s)
SymbolicAffect(
namespace_equation.(affect.affect, (s,)), namespace_equation.(affect.alg_eqs, (s,)),
renamespace.((s,), affect.discrete_parameters))
end

namespace_affects(af::Nothing, s) = nothing

function namespace_callback(cb::SymbolicContinuousCallback, s)::SymbolicContinuousCallback
Expand Down Expand Up @@ -1060,12 +1106,8 @@ end
"""
Process the symbolic events of a system.
"""
function create_symbolic_events(cont_events, disc_events, sys_eqs, iv)
alg_eqs = filter(eq -> eq.lhs isa Union{Symbolic, Number} && !is_diff_equation(eq),
sys_eqs)
cont_callbacks = to_cb_vector(cont_events; CB_TYPE = SymbolicContinuousCallback,
iv = iv, alg_eqs = alg_eqs, warn_no_algebraic = false)
disc_callbacks = to_cb_vector(disc_events; CB_TYPE = SymbolicDiscreteCallback,
iv = iv, alg_eqs = alg_eqs, warn_no_algebraic = false)
function create_symbolic_events(cont_events, disc_events)
cont_callbacks = to_cb_vector(cont_events; CB_TYPE = SymbolicContinuousCallback)
disc_callbacks = to_cb_vector(disc_events; CB_TYPE = SymbolicDiscreteCallback)
cont_callbacks, disc_callbacks
end
2 changes: 1 addition & 1 deletion src/systems/system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = [];
end
continuous_events,
discrete_events = create_symbolic_events(
continuous_events, discrete_events, eqs, iv)
continuous_events, discrete_events)

if iv === nothing && (!isempty(continuous_events) || !isempty(discrete_events))
throw(EventsInTimeIndependentSystemError(continuous_events, discrete_events))
Expand Down
2 changes: 1 addition & 1 deletion test/implicit_discrete_system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,6 @@ end
y(k) ~ x(k - 1) + x(k - 2),
z(k) * x(k) ~ 3]
@mtkcompile sys = System(eqs, t)
@test occursin("var\"Shift(t, 1)(x(t))\"",
@test occursin("var\"Shift(t, 1)(z(t))\"",
string(ImplicitDiscreteFunction(sys; expression = Val{true})))
end
Loading
Loading