Skip to content

Commit 7ff680d

Browse files
Merge pull request #3812 from AayushSabharwal/as/fix-callbacks
fix: compile symbolic affects after `mtkcompile` in `complete`
2 parents e1509c8 + 1729986 commit 7ff680d

File tree

6 files changed

+209
-124
lines changed

6 files changed

+209
-124
lines changed

src/structural_transformation/utils.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,8 @@ function find_eq_solvables!(state::TearingState, ieq, to_rm = Int[], coeffs = no
259259
symbolic_type(v) == ArraySymbolic() &&
260260
Symbolics.shape(v) != Symbolics.Unknown() &&
261261
any(x -> any(isequal(x), fullvars), collect(v)),
262-
vars(a))
262+
vars(
263+
a; op = Union{Differential, Shift, Pre, Sample, Hold, Initial}))
263264
continue
264265
end
265266
else

src/systems/abstractsystem.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,16 @@ function complete(
646646
if add_initial_parameters
647647
sys = add_initialization_parameters(sys; split)
648648
end
649+
if has_continuous_events(sys) && is_time_dependent(sys)
650+
@set! sys.continuous_events = complete.(
651+
get_continuous_events(sys); iv = get_iv(sys),
652+
alg_eqs = [alg_equations(sys); observed(sys)])
653+
end
654+
if has_discrete_events(sys) && is_time_dependent(sys)
655+
@set! sys.discrete_events = complete.(
656+
get_discrete_events(sys); iv = get_iv(sys),
657+
alg_eqs = [alg_equations(sys); observed(sys)])
658+
end
649659
end
650660
if split && has_index_cache(sys)
651661
@set! sys.index_cache = IndexCache(sys)

src/systems/callbacks.jl

Lines changed: 160 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,27 @@ function has_functional_affect(cb)
44
affects(cb) isa ImperativeAffect
55
end
66

7+
struct SymbolicAffect
8+
affect::Vector{Equation}
9+
alg_eqs::Vector{Equation}
10+
discrete_parameters::Vector{Any}
11+
end
12+
13+
function SymbolicAffect(affect::Vector{Equation}; alg_eqs = Equation[],
14+
discrete_parameters = Any[], kwargs...)
15+
if !(discrete_parameters isa AbstractVector)
16+
discrete_parameters = Any[discrete_parameters]
17+
elseif !(discrete_parameters isa Vector{Any})
18+
discrete_parameters = Vector{Any}(discrete_parameters)
19+
end
20+
SymbolicAffect(affect, alg_eqs, discrete_parameters)
21+
end
22+
function SymbolicAffect(affect::SymbolicAffect; kwargs...)
23+
SymbolicAffect(affect.affect; alg_eqs = affect.alg_eqs,
24+
discrete_parameters = affect.discrete_parameters, kwargs...)
25+
end
26+
SymbolicAffect(affect; kwargs...) = make_affect(affect; kwargs...)
27+
728
struct AffectSystem
829
"""The internal implicit discrete system whose equations are solved to obtain values after the affect."""
930
system::AbstractSystem
@@ -15,6 +36,72 @@ struct AffectSystem
1536
discretes::Vector
1637
end
1738

39+
function AffectSystem(spec::SymbolicAffect; iv = nothing, alg_eqs = Equation[], kwargs...)
40+
AffectSystem(spec.affect; alg_eqs = vcat(spec.alg_eqs, alg_eqs), iv,
41+
discrete_parameters = spec.discrete_parameters, kwargs...)
42+
end
43+
44+
function AffectSystem(affect::Vector{Equation}; discrete_parameters = Any[],
45+
iv = nothing, alg_eqs::Vector{Equation} = Equation[], warn_no_algebraic = true, kwargs...)
46+
isempty(affect) && return nothing
47+
if isnothing(iv)
48+
iv = t_nounits
49+
@warn "No independent variable specified. Defaulting to t_nounits."
50+
end
51+
52+
discrete_parameters isa AbstractVector || (discrete_parameters = [discrete_parameters])
53+
discrete_parameters = unwrap.(discrete_parameters)
54+
55+
for p in discrete_parameters
56+
occursin(unwrap(iv), unwrap(p)) ||
57+
error("Non-time dependent parameter $p passed in as a discrete. Must be declared as @parameters $p(t).")
58+
end
59+
60+
dvs = OrderedSet()
61+
params = OrderedSet()
62+
_varsbuf = Set()
63+
for eq in affect
64+
if !haspre(eq) && !(symbolic_type(eq.rhs) === NotSymbolic() ||
65+
symbolic_type(eq.lhs) === NotSymbolic())
66+
@warn "Affect equation $eq has no `Pre` operator. As such it will be interpreted as an algebraic equation to be satisfied after the callback. If you intended to use the value of a variable x before the affect, use Pre(x). Errors may be thrown if there is no `Pre` and the algebraic equation is unsatisfiable, such as X ~ X + 1."
67+
end
68+
collect_vars!(dvs, params, eq, iv; op = Pre)
69+
empty!(_varsbuf)
70+
vars!(_varsbuf, eq; op = Pre)
71+
filter!(x -> iscall(x) && operation(x) isa Pre, _varsbuf)
72+
union!(params, _varsbuf)
73+
diffvs = collect_applied_operators(eq, Differential)
74+
union!(dvs, diffvs)
75+
end
76+
for eq in alg_eqs
77+
collect_vars!(dvs, params, eq, iv)
78+
end
79+
pre_params = filter(haspre value, params)
80+
sys_params = collect(setdiff(params, union(discrete_parameters, pre_params)))
81+
discretes = map(tovar, discrete_parameters)
82+
dvs = collect(dvs)
83+
_dvs = map(default_toterm, dvs)
84+
85+
rev_map = Dict(zip(discrete_parameters, discretes))
86+
subs = merge(rev_map, Dict(zip(dvs, _dvs)))
87+
affect = Symbolics.fast_substitute(affect, subs)
88+
alg_eqs = Symbolics.fast_substitute(alg_eqs, subs)
89+
90+
@named affectsys = System(
91+
vcat(affect, alg_eqs), iv, collect(union(_dvs, discretes)),
92+
collect(union(pre_params, sys_params)); is_discrete = true)
93+
affectsys = mtkcompile(affectsys; fully_determined = nothing)
94+
# get accessed parameters p from Pre(p) in the callback parameters
95+
accessed_params = Vector{Any}(filter(isparameter, map(unPre, collect(pre_params))))
96+
union!(accessed_params, sys_params)
97+
98+
# add scalarized unknowns to the map.
99+
_dvs = reduce(vcat, map(scalarize, _dvs), init = Any[])
100+
101+
AffectSystem(affectsys, collect(_dvs), collect(accessed_params),
102+
collect(discrete_parameters))
103+
end
104+
18105
system(a::AffectSystem) = a.system
19106
discretes(a::AffectSystem) = a.discretes
20107
unknowns(a::AffectSystem) = a.unknowns
@@ -159,40 +246,40 @@ will run as soon as the solver starts, while finalization affects will be execut
159246
"""
160247
struct SymbolicContinuousCallback <: AbstractCallback
161248
conditions::Vector{Equation}
162-
affect::Union{Affect, Nothing}
163-
affect_neg::Union{Affect, Nothing}
164-
initialize::Union{Affect, Nothing}
165-
finalize::Union{Affect, Nothing}
249+
affect::Union{Affect, SymbolicAffect, Nothing}
250+
affect_neg::Union{Affect, SymbolicAffect, Nothing}
251+
initialize::Union{Affect, SymbolicAffect, Nothing}
252+
finalize::Union{Affect, SymbolicAffect, Nothing}
166253
rootfind::Union{Nothing, SciMLBase.RootfindOpt}
167254
reinitializealg::SciMLBase.DAEInitializationAlgorithm
255+
end
168256

169-
function SymbolicContinuousCallback(
170-
conditions::Union{Equation, Vector{Equation}},
171-
affect = nothing;
172-
affect_neg = affect,
173-
initialize = nothing,
174-
finalize = nothing,
175-
rootfind = SciMLBase.LeftRootFind,
176-
reinitializealg = nothing,
177-
kwargs...)
178-
conditions = (conditions isa AbstractVector) ? conditions : [conditions]
179-
180-
if isnothing(reinitializealg)
181-
if any(a -> a isa ImperativeAffect,
182-
[affect, affect_neg, initialize, finalize])
183-
reinitializealg = SciMLBase.CheckInit()
184-
else
185-
reinitializealg = SciMLBase.NoInit()
186-
end
257+
function SymbolicContinuousCallback(
258+
conditions::Union{Equation, Vector{Equation}},
259+
affect = nothing;
260+
affect_neg = affect,
261+
initialize = nothing,
262+
finalize = nothing,
263+
rootfind = SciMLBase.LeftRootFind,
264+
reinitializealg = nothing,
265+
kwargs...)
266+
conditions = (conditions isa AbstractVector) ? conditions : [conditions]
267+
268+
if isnothing(reinitializealg)
269+
if any(a -> a isa ImperativeAffect,
270+
[affect, affect_neg, initialize, finalize])
271+
reinitializealg = SciMLBase.CheckInit()
272+
else
273+
reinitializealg = SciMLBase.NoInit()
187274
end
275+
end
188276

189-
new(conditions, make_affect(affect; kwargs...),
190-
make_affect(affect_neg; kwargs...),
191-
make_affect(initialize; kwargs...), make_affect(
192-
finalize; kwargs...),
193-
rootfind, reinitializealg)
194-
end # Default affect to nothing
195-
end
277+
SymbolicContinuousCallback(conditions, SymbolicAffect(affect; kwargs...),
278+
SymbolicAffect(affect_neg; kwargs...),
279+
SymbolicAffect(initialize; kwargs...), SymbolicAffect(
280+
finalize; kwargs...),
281+
rootfind, reinitializealg)
282+
end # Default affect to nothing
196283

197284
function SymbolicContinuousCallback(p::Pair, args...; kwargs...)
198285
SymbolicContinuousCallback(p[1], p[2], args...; kwargs...)
@@ -207,71 +294,18 @@ function SymbolicContinuousCallback(cb::Tuple, args...; kwargs...)
207294
end
208295
end
209296

297+
function complete(cb::SymbolicContinuousCallback; kwargs...)
298+
SymbolicContinuousCallback(cb.conditions, make_affect(cb.affect; kwargs...),
299+
make_affect(cb.affect_neg; kwargs...), make_affect(cb.initialize; kwargs...),
300+
make_affect(cb.finalize; kwargs...), cb.rootfind, cb.reinitializealg)
301+
end
302+
303+
make_affect(affect::SymbolicAffect; kwargs...) = AffectSystem(affect; kwargs...)
210304
make_affect(affect::Nothing; kwargs...) = nothing
211305
make_affect(affect::Tuple; kwargs...) = ImperativeAffect(affect...)
212306
make_affect(affect::NamedTuple; kwargs...) = ImperativeAffect(; affect...)
213307
make_affect(affect::Affect; kwargs...) = affect
214-
215-
function make_affect(affect::Vector{Equation}; discrete_parameters = Any[],
216-
iv = nothing, alg_eqs::Vector{Equation} = Equation[], warn_no_algebraic = true, kwargs...)
217-
isempty(affect) && return nothing
218-
if isnothing(iv)
219-
iv = t_nounits
220-
@warn "No independent variable specified. Defaulting to t_nounits."
221-
end
222-
223-
discrete_parameters isa AbstractVector || (discrete_parameters = [discrete_parameters])
224-
discrete_parameters = unwrap.(discrete_parameters)
225-
226-
for p in discrete_parameters
227-
occursin(unwrap(iv), unwrap(p)) ||
228-
error("Non-time dependent parameter $p passed in as a discrete. Must be declared as @parameters $p(t).")
229-
end
230-
231-
dvs = OrderedSet()
232-
params = OrderedSet()
233-
_varsbuf = Set()
234-
for eq in affect
235-
if !haspre(eq) && !(symbolic_type(eq.rhs) === NotSymbolic() ||
236-
symbolic_type(eq.lhs) === NotSymbolic())
237-
@warn "Affect equation $eq has no `Pre` operator. As such it will be interpreted as an algebraic equation to be satisfied after the callback. If you intended to use the value of a variable x before the affect, use Pre(x). Errors may be thrown if there is no `Pre` and the algebraic equation is unsatisfiable, such as X ~ X + 1."
238-
end
239-
collect_vars!(dvs, params, eq, iv; op = Pre)
240-
empty!(_varsbuf)
241-
vars!(_varsbuf, eq; op = Pre)
242-
filter!(x -> iscall(x) && operation(x) isa Pre, _varsbuf)
243-
union!(params, _varsbuf)
244-
diffvs = collect_applied_operators(eq, Differential)
245-
union!(dvs, diffvs)
246-
end
247-
for eq in alg_eqs
248-
collect_vars!(dvs, params, eq, iv)
249-
end
250-
pre_params = filter(haspre value, params)
251-
sys_params = collect(setdiff(params, union(discrete_parameters, pre_params)))
252-
discretes = map(tovar, discrete_parameters)
253-
dvs = collect(dvs)
254-
_dvs = map(default_toterm, dvs)
255-
256-
rev_map = Dict(zip(discrete_parameters, discretes))
257-
subs = merge(rev_map, Dict(zip(dvs, _dvs)))
258-
affect = Symbolics.fast_substitute(affect, subs)
259-
alg_eqs = Symbolics.fast_substitute(alg_eqs, subs)
260-
261-
@named affectsys = System(
262-
vcat(affect, alg_eqs), iv, collect(union(_dvs, discretes)),
263-
collect(union(pre_params, sys_params)); is_discrete = true)
264-
affectsys = mtkcompile(affectsys; fully_determined = nothing)
265-
# get accessed parameters p from Pre(p) in the callback parameters
266-
accessed_params = Vector{Any}(filter(isparameter, map(unPre, collect(pre_params))))
267-
union!(accessed_params, sys_params)
268-
269-
# add scalarized unknowns to the map.
270-
_dvs = reduce(vcat, map(scalarize, _dvs), init = Any[])
271-
272-
AffectSystem(affectsys, collect(_dvs), collect(accessed_params),
273-
collect(discrete_parameters))
274-
end
308+
make_affect(affect::Vector{Equation}; kwargs...) = AffectSystem(affect; kwargs...)
275309

276310
function make_affect(affect; kwargs...)
277311
error("Malformed affect $(affect). This should be a vector of equations or a tuple specifying a functional affect.")
@@ -374,30 +408,30 @@ Arguments:
374408
"""
375409
struct SymbolicDiscreteCallback <: AbstractCallback
376410
conditions::Union{Number, Vector{<:Number}, Symbolic{Bool}}
377-
affect::Union{Affect, Nothing}
378-
initialize::Union{Affect, Nothing}
379-
finalize::Union{Affect, Nothing}
411+
affect::Union{Affect, SymbolicAffect, Nothing}
412+
initialize::Union{Affect, SymbolicAffect, Nothing}
413+
finalize::Union{Affect, SymbolicAffect, Nothing}
380414
reinitializealg::SciMLBase.DAEInitializationAlgorithm
415+
end
381416

382-
function SymbolicDiscreteCallback(
383-
condition::Union{Symbolic{Bool}, Number, Vector{<:Number}}, affect = nothing;
384-
initialize = nothing, finalize = nothing,
385-
reinitializealg = nothing, kwargs...)
386-
c = is_timed_condition(condition) ? condition : value(scalarize(condition))
387-
388-
if isnothing(reinitializealg)
389-
if any(a -> a isa ImperativeAffect,
390-
[affect, initialize, finalize])
391-
reinitializealg = SciMLBase.CheckInit()
392-
else
393-
reinitializealg = SciMLBase.NoInit()
394-
end
417+
function SymbolicDiscreteCallback(
418+
condition::Union{Symbolic{Bool}, Number, Vector{<:Number}}, affect = nothing;
419+
initialize = nothing, finalize = nothing,
420+
reinitializealg = nothing, kwargs...)
421+
c = is_timed_condition(condition) ? condition : value(scalarize(condition))
422+
423+
if isnothing(reinitializealg)
424+
if any(a -> a isa ImperativeAffect,
425+
[affect, initialize, finalize])
426+
reinitializealg = SciMLBase.CheckInit()
427+
else
428+
reinitializealg = SciMLBase.NoInit()
395429
end
396-
new(c, make_affect(affect; kwargs...),
397-
make_affect(initialize; kwargs...),
398-
make_affect(finalize; kwargs...), reinitializealg)
399-
end # Default affect to nothing
400-
end
430+
end
431+
SymbolicDiscreteCallback(c, SymbolicAffect(affect; kwargs...),
432+
SymbolicAffect(initialize; kwargs...),
433+
SymbolicAffect(finalize; kwargs...), reinitializealg)
434+
end # Default affect to nothing
401435

402436
function SymbolicDiscreteCallback(p::Pair, args...; kwargs...)
403437
SymbolicDiscreteCallback(p[1], p[2], args...; kwargs...)
@@ -412,6 +446,12 @@ function SymbolicDiscreteCallback(cb::Tuple, args...; kwargs...)
412446
end
413447
end
414448

449+
function complete(cb::SymbolicDiscreteCallback; kwargs...)
450+
SymbolicDiscreteCallback(cb.conditions, make_affect(cb.affect; kwargs...),
451+
make_affect(cb.initialize; kwargs...),
452+
make_affect(cb.finalize; kwargs...), cb.reinitializealg)
453+
end
454+
415455
function is_timed_condition(condition::T) where {T}
416456
if T === Num
417457
false
@@ -457,6 +497,12 @@ function namespace_affects(affect::AffectSystem, s)
457497
renamespace.((s,), parameters(affect)),
458498
renamespace.((s,), discretes(affect)))
459499
end
500+
function namespace_affects(affect::SymbolicAffect, s)
501+
SymbolicAffect(
502+
namespace_equation.(affect.affect, (s,)), namespace_equation.(affect.alg_eqs, (s,)),
503+
renamespace.((s,), affect.discrete_parameters))
504+
end
505+
460506
namespace_affects(af::Nothing, s) = nothing
461507

462508
function namespace_callback(cb::SymbolicContinuousCallback, s)::SymbolicContinuousCallback
@@ -1060,12 +1106,8 @@ end
10601106
"""
10611107
Process the symbolic events of a system.
10621108
"""
1063-
function create_symbolic_events(cont_events, disc_events, sys_eqs, iv)
1064-
alg_eqs = filter(eq -> eq.lhs isa Union{Symbolic, Number} && !is_diff_equation(eq),
1065-
sys_eqs)
1066-
cont_callbacks = to_cb_vector(cont_events; CB_TYPE = SymbolicContinuousCallback,
1067-
iv = iv, alg_eqs = alg_eqs, warn_no_algebraic = false)
1068-
disc_callbacks = to_cb_vector(disc_events; CB_TYPE = SymbolicDiscreteCallback,
1069-
iv = iv, alg_eqs = alg_eqs, warn_no_algebraic = false)
1109+
function create_symbolic_events(cont_events, disc_events)
1110+
cont_callbacks = to_cb_vector(cont_events; CB_TYPE = SymbolicContinuousCallback)
1111+
disc_callbacks = to_cb_vector(disc_events; CB_TYPE = SymbolicDiscreteCallback)
10701112
cont_callbacks, disc_callbacks
10711113
end

src/systems/system.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = [];
389389
end
390390
continuous_events,
391391
discrete_events = create_symbolic_events(
392-
continuous_events, discrete_events, eqs, iv)
392+
continuous_events, discrete_events)
393393

394394
if iv === nothing && (!isempty(continuous_events) || !isempty(discrete_events))
395395
throw(EventsInTimeIndependentSystemError(continuous_events, discrete_events))

test/implicit_discrete_system.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,6 @@ end
7575
y(k) ~ x(k - 1) + x(k - 2),
7676
z(k) * x(k) ~ 3]
7777
@mtkcompile sys = System(eqs, t)
78-
@test occursin("var\"Shift(t, 1)(x(t))\"",
78+
@test occursin("var\"Shift(t, 1)(z(t))\"",
7979
string(ImplicitDiscreteFunction(sys; expression = Val{true})))
8080
end

0 commit comments

Comments
 (0)