Skip to content

Commit b162275

Browse files
feat: add respecialize
1 parent b1d4592 commit b162275

File tree

3 files changed

+103
-1
lines changed

3 files changed

+103
-1
lines changed

docs/src/API/model_building.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ add_accumulations
227227
noise_to_brownians
228228
convert_system_indepvar
229229
subset_tunables
230+
respecialize
230231
```
231232

232233
## Hybrid systems

src/ModelingToolkit.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,8 @@ export isinput, isoutput, getbounds, hasbounds, getguess, hasguess, isdisturbanc
268268
hasmisc, getmisc, state_priority,
269269
subset_tunables
270270
export liouville_transform, change_independent_variable, substitute_component,
271-
add_accumulations, noise_to_brownians, Girsanov_transform, change_of_variables
271+
add_accumulations, noise_to_brownians, Girsanov_transform, change_of_variables,
272+
respecialize
272273
export PDESystem
273274
export Differential, expand_derivatives, @derivatives
274275
export Equation, ConstrainedEquation

src/systems/diffeqs/basic_transformations.jl

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,3 +706,103 @@ function convert_system_indepvar(sys::System, t; name = nameof(sys))
706706
@set! sys.var_to_name = var_to_name
707707
return sys
708708
end
709+
710+
"""
711+
$(TYPEDSIGNATURES)
712+
713+
Shorthand for `respecialize(sys, []; all = true)`
714+
"""
715+
respecialize(sys::AbstractSystem) = respecialize(sys, []; all = true)
716+
717+
"""
718+
$(TYPEDSIGNATURES)
719+
720+
Specialize nonnumeric parameters in `sys` by changing their symtype to a concrete type.
721+
`mapping` is an iterable, where each element can be a parameter or a pair mapping a parameter
722+
to a value. If the element is a parameter, it must have a default. Each specified parameter
723+
is updated to have the symtype of the value associated with it (either in `mapping` or in
724+
the defaults). This operation can only be performed on nonnumeric, non-array parameters. The
725+
defaults of respecialized parameters are set to the associated values.
726+
727+
This operation can only be performed on `complete`d systems.
728+
729+
# Keyword arguments
730+
731+
- `all`: Specialize all nonnumeric parameters in the system. This will error if any such
732+
parameter does not have a default.
733+
"""
734+
function respecialize(sys::AbstractSystem, mapping; all = false)
735+
if !iscomplete(sys)
736+
error("""
737+
This operation can only be performed on completed systems. Use `complete(sys)` or
738+
`mtkcompile(sys)`.
739+
""")
740+
end
741+
if !is_split(sys)
742+
error("""
743+
This operation can only be performed on split systems. Use `complete(sys)` or
744+
`mtkcompile(sys)` with the `split = true` keyword argument.
745+
""")
746+
end
747+
748+
new_ps = copy(get_ps(sys))
749+
@set! sys.ps = new_ps
750+
751+
extras = []
752+
if all
753+
for x in filter(!is_variable_numeric, get_ps(sys))
754+
if any(y -> isequal(x, y) || y isa Pair && isequal(x, y[1]), mapping) || symbolic_type(x) === ArraySymbolic() || iscall(x) && operation(x) === getindex
755+
continue
756+
end
757+
push!(extras, x)
758+
end
759+
end
760+
761+
defs = copy(defaults(sys))
762+
@set! sys.defaults = defs
763+
764+
for element in Iterators.flatten((extras, mapping))
765+
if element isa Pair
766+
k, v = element
767+
else
768+
k = element
769+
v = get(defs, k, nothing)
770+
@assert v !== nothing """
771+
Parameter $k needs an associated value to be respecialized.
772+
"""
773+
end
774+
775+
k = unwrap(k)
776+
T = typeof(v)
777+
778+
@assert !is_variable_numeric(k) """
779+
Numeric types cannot be respecialized - tried to respecialize $k.
780+
"""
781+
@assert symbolic_type(k) !== ArraySymbolic() """
782+
Cannot respecialize array symbolics - tried to respecialize $k.
783+
"""
784+
@assert !iscall(k) || operation(k) !== getindex """
785+
Cannot respecialized scalarized array variables - tried to respecialize $k.
786+
"""
787+
idx = findfirst(isequal(k), get_ps(sys))
788+
@assert idx !== nothing """
789+
Parameter $k does not exist in the system.
790+
"""
791+
792+
793+
if iscall(k)
794+
op = operation(k)
795+
args = arguments(k)
796+
new_p = SymbolicUtils.term(op, args...; type = T)
797+
else
798+
new_p = SymbolicUtils.Sym{T}(getname(k))
799+
end
800+
801+
get_ps(sys)[idx] = new_p
802+
defaults(sys)[new_p] = v
803+
end
804+
805+
sys = complete(sys; split = is_split(sys))
806+
return sys
807+
end
808+

0 commit comments

Comments
 (0)