Skip to content

Commit d3b536b

Browse files
Merge pull request #3615 from AayushSabharwal/as/diagonal-mm
fix: use `Diagonal` for diagonal mass matrices
2 parents 4b52b02 + 32a7781 commit d3b536b

File tree

3 files changed

+34
-2
lines changed

3 files changed

+34
-2
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,9 @@ function calculate_massmatrix(sys::AbstractODESystem; simplify = false)
284284
end
285285
M = simplify ? ModelingToolkit.simplify.(M) : M
286286
# M should only contain concrete numbers
287+
if isdiag(M)
288+
M = Diagonal(M)
289+
end
287290
M == I ? I : M
288291
end
289292

@@ -410,6 +413,8 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
410413
SparseArrays.sparse(M)
411414
elseif u0 === nothing || M === I
412415
M
416+
elseif M isa Diagonal
417+
Diagonal(ArrayInterface.restructure(u0, diag(M)))
413418
else
414419
ArrayInterface.restructure(u0 .* u0', M)
415420
end

src/systems/diffeqs/sdesystem.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -652,7 +652,15 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
652652
W_prototype = nothing
653653
end
654654

655-
_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0', M)
655+
_M = if sparse && !(u0 === nothing || M === I)
656+
SparseArrays.sparse(M)
657+
elseif u0 === nothing || M === I
658+
M
659+
elseif M isa Diagonal
660+
Diagonal(ArrayInterface.restructure(u0, diag(M)))
661+
else
662+
ArrayInterface.restructure(u0 .* u0', M)
663+
end
656664

657665
observedfun = ObservedFunctionCache(
658666
sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false), cse)

test/mass_matrix.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using OrdinaryDiffEq, ModelingToolkit, Test, LinearAlgebra
1+
using OrdinaryDiffEq, ModelingToolkit, Test, LinearAlgebra, StaticArrays
22
using ModelingToolkit: t_nounits as t, D_nounits as D, MTKParameters
33

44
@variables y(t)[1:3]
@@ -12,13 +12,18 @@ eqs = [D(y[1]) ~ -k[1] * y[1] + k[3] * y[2] * y[3],
1212
sys = complete(sys)
1313
@test_throws ArgumentError ODESystem(eqs, y[1])
1414
M = calculate_massmatrix(sys)
15+
@test M isa Diagonal
1516
@test M == [1 0 0
1617
0 1 0
1718
0 0 0]
1819

1920
prob_mm = ODEProblem(sys, [y => [1.0, 0.0, 0.0]], (0.0, 1e5),
2021
[k => [0.04, 3e7, 1e4]])
22+
@test prob_mm.f.mass_matrix isa Diagonal{Float64, Vector{Float64}}
2123
sol = solve(prob_mm, Rodas5(), reltol = 1e-8, abstol = 1e-8)
24+
prob_mm = ODEProblem(sys, SA[y => [1.0, 0.0, 0.0]], (0.0, 1e5),
25+
[k => [0.04, 3e7, 1e4]])
26+
@test prob_mm.f.mass_matrix isa Diagonal{Float64, SVector{3, Float64}}
2227

2328
function rober(du, u, p, t)
2429
y₁, y₂, y₃ = u
@@ -43,3 +48,17 @@ eqs = [D(y[1]) ~ y[1], D(y[2]) ~ y[2], D(y[3]) ~ y[3]]
4348
@named sys = ODESystem(eqs, t, collect(y), [k])
4449

4550
@test calculate_massmatrix(sys) === I
51+
52+
@testset "Mass matrix `isa Diagonal` for `SDEProblem`" begin
53+
eqs = [D(y[1]) ~ -k[1] * y[1] + k[3] * y[2] * y[3],
54+
D(y[2]) ~ k[1] * y[1] - k[3] * y[2] * y[3] - k[2] * y[2]^2,
55+
0 ~ y[1] + y[2] + y[3] - 1]
56+
57+
@named sys = ODESystem(eqs, t, collect(y), [k])
58+
@named sys = SDESystem(sys, [1, 1, 0])
59+
sys = complete(sys)
60+
prob = SDEProblem(sys, [y => [1.0, 0.0, 0.0]], (0.0, 1e5), [k => [0.04, 3e7, 1e4]])
61+
@test prob.f.mass_matrix isa Diagonal{Float64, Vector{Float64}}
62+
prob = SDEProblem(sys, SA[y => [1.0, 0.0, 0.0]], (0.0, 1e5), [k => [0.04, 3e7, 1e4]])
63+
@test prob.f.mass_matrix isa Diagonal{Float64, SVector{3, Float64}}
64+
end

0 commit comments

Comments
 (0)