diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 2f42dbb761..7a9bc2617e 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -711,7 +711,8 @@ end $(TYPEDEF) A callable struct which applies `p_constructor` to possibly nested arrays. It also -ensures that views (including nested ones) are concretized. +ensures that views (including nested ones) are concretized. This is implemented manually +of using `narrow_buffer_type` to preserve type-stability. """ struct PConstructorApplicator{F} p_constructor::F @@ -721,10 +722,18 @@ function (pca::PConstructorApplicator)(x::AbstractArray) pca.p_constructor(x) end +function (pca::PConstructorApplicator)(x::AbstractArray{Bool}) + pca.p_constructor(BitArray(x)) +end + function (pca::PConstructorApplicator{typeof(identity)})(x::SubArray) collect(x) end +function (pca::PConstructorApplicator{typeof(identity)})(x::SubArray{Bool}) + BitArray(x) +end + function (pca::PConstructorApplicator{typeof(identity)})(x::SubArray{<:AbstractArray}) collect(pca.(x)) end @@ -749,6 +758,7 @@ takes a value provider of `srcsys` and a value provider of `dstsys` and returns """ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::AbstractSystem; initials = false, unwrap_initials = false, p_constructor = identity) + _p_constructor = p_constructor p_constructor = PConstructorApplicator(p_constructor) # if we call `getu` on this (and it were able to handle empty tuples) we get the # fields of `MTKParameters` except caches. @@ -802,14 +812,24 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac Base.Fix1(broadcast, p_constructor) ∘ getu(srcsys, syms[3]) end - rest_getters = map(Base.tail(Base.tail(Base.tail(syms)))) do buf - if buf == () - return Returns(()) - else - return Base.Fix1(broadcast, p_constructor) ∘ getu(srcsys, buf) - end + const_getter = if syms[4] == () + Returns(()) + else + Base.Fix1(broadcast, p_constructor) ∘ getu(srcsys, syms[4]) end - getters = (tunable_getter, initials_getter, discs_getter, rest_getters...) + nonnumeric_getter = if syms[5] == () + Returns(()) + else + ic = get_index_cache(dstsys) + buftypes = Tuple(map(ic.nonnumeric_buffer_sizes) do bufsize + Vector{bufsize.type} + end) + # nonnumerics retain the assigned buffer type without narrowing + Base.Fix1(broadcast, _p_constructor) ∘ + Base.Fix1(Broadcast.BroadcastFunction(call), buftypes) ∘ getu(srcsys, syms[5]) + end + getters = ( + tunable_getter, initials_getter, discs_getter, const_getter, nonnumeric_getter) getter = let getters = getters function _getter(valp, initprob) oldcache = parameter_values(initprob).caches @@ -822,6 +842,10 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac return getter end +function call(f, args...) + f(args...) +end + """ $(TYPEDSIGNATURES) diff --git a/test/initializationsystem.jl b/test/initializationsystem.jl index 4cd1775467..19209b5e46 100644 --- a/test/initializationsystem.jl +++ b/test/initializationsystem.jl @@ -1682,3 +1682,23 @@ end prob = ODEProblem(sys, [x[1] => nothing, x[2] => 1], (0.0, 1.0)) @test SciMLBase.initialization_status(prob) == SciMLBase.FULLY_DETERMINED end + +@testset "Nonnumerics aren't narrowed" begin + @mtkmodel Foo begin + @variables begin + x(t) = 1.0 + end + @parameters begin + p::AbstractString + r = 1.0 + end + @equations begin + D(x) ~ r * x + end + end + @mtkbuild sys = Foo(p = "a") + prob = ODEProblem(sys, [], (0.0, 1.0)) + @test prob.p.nonnumeric[1] isa Vector{AbstractString} + integ = init(prob) + @test integ.p.nonnumeric[1] isa Vector{AbstractString} +end