Skip to content

Commit e1509c8

Browse files
Merge pull request #3796 from AayushSabharwal/as/v10-nonnumeric-init
fix: fix `get_mtkparameters_reconstructor` handling of nonnumerics
2 parents 819556e + 6683cf5 commit e1509c8

File tree

2 files changed

+52
-8
lines changed

2 files changed

+52
-8
lines changed

src/systems/problem_utils.jl

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -711,7 +711,8 @@ end
711711
$(TYPEDEF)
712712
713713
A callable struct which applies `p_constructor` to possibly nested arrays. It also
714-
ensures that views (including nested ones) are concretized.
714+
ensures that views (including nested ones) are concretized. This is implemented manually
715+
of using `narrow_buffer_type` to preserve type-stability.
715716
"""
716717
struct PConstructorApplicator{F}
717718
p_constructor::F
@@ -721,10 +722,18 @@ function (pca::PConstructorApplicator)(x::AbstractArray)
721722
pca.p_constructor(x)
722723
end
723724

725+
function (pca::PConstructorApplicator)(x::AbstractArray{Bool})
726+
pca.p_constructor(BitArray(x))
727+
end
728+
724729
function (pca::PConstructorApplicator{typeof(identity)})(x::SubArray)
725730
collect(x)
726731
end
727732

733+
function (pca::PConstructorApplicator{typeof(identity)})(x::SubArray{Bool})
734+
BitArray(x)
735+
end
736+
728737
function (pca::PConstructorApplicator{typeof(identity)})(x::SubArray{<:AbstractArray})
729738
collect(pca.(x))
730739
end
@@ -749,6 +758,7 @@ takes a value provider of `srcsys` and a value provider of `dstsys` and returns
749758
"""
750759
function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::AbstractSystem;
751760
initials = false, unwrap_initials = false, p_constructor = identity)
761+
_p_constructor = p_constructor
752762
p_constructor = PConstructorApplicator(p_constructor)
753763
# if we call `getu` on this (and it were able to handle empty tuples) we get the
754764
# fields of `MTKParameters` except caches.
@@ -802,14 +812,24 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac
802812
Base.Fix1(broadcast, p_constructor)
803813
getu(srcsys, syms[3])
804814
end
805-
rest_getters = map(Base.tail(Base.tail(Base.tail(syms)))) do buf
806-
if buf == ()
807-
return Returns(())
808-
else
809-
return Base.Fix1(broadcast, p_constructor) getu(srcsys, buf)
810-
end
815+
const_getter = if syms[4] == ()
816+
Returns(())
817+
else
818+
Base.Fix1(broadcast, p_constructor) getu(srcsys, syms[4])
811819
end
812-
getters = (tunable_getter, initials_getter, discs_getter, rest_getters...)
820+
nonnumeric_getter = if syms[5] == ()
821+
Returns(())
822+
else
823+
ic = get_index_cache(dstsys)
824+
buftypes = Tuple(map(ic.nonnumeric_buffer_sizes) do bufsize
825+
Vector{bufsize.type}
826+
end)
827+
# nonnumerics retain the assigned buffer type without narrowing
828+
Base.Fix1(broadcast, _p_constructor)
829+
Base.Fix1(Broadcast.BroadcastFunction(call), buftypes) getu(srcsys, syms[5])
830+
end
831+
getters = (
832+
tunable_getter, initials_getter, discs_getter, const_getter, nonnumeric_getter)
813833
getter = let getters = getters
814834
function _getter(valp, initprob)
815835
oldcache = parameter_values(initprob).caches
@@ -822,6 +842,10 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac
822842
return getter
823843
end
824844

845+
function call(f, args...)
846+
f(args...)
847+
end
848+
825849
"""
826850
$(TYPEDSIGNATURES)
827851

test/initializationsystem.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1682,3 +1682,23 @@ end
16821682
prob = ODEProblem(sys, [x[1] => nothing, x[2] => 1], (0.0, 1.0))
16831683
@test SciMLBase.initialization_status(prob) == SciMLBase.FULLY_DETERMINED
16841684
end
1685+
1686+
@testset "Nonnumerics aren't narrowed" begin
1687+
@mtkmodel Foo begin
1688+
@variables begin
1689+
x(t) = 1.0
1690+
end
1691+
@parameters begin
1692+
p::AbstractString
1693+
r = 1.0
1694+
end
1695+
@equations begin
1696+
D(x) ~ r * x
1697+
end
1698+
end
1699+
@mtkbuild sys = Foo(p = "a")
1700+
prob = ODEProblem(sys, [], (0.0, 1.0))
1701+
@test prob.p.nonnumeric[1] isa Vector{AbstractString}
1702+
integ = init(prob)
1703+
@test integ.p.nonnumeric[1] isa Vector{AbstractString}
1704+
end

0 commit comments

Comments
 (0)