Commit 223b908

fix typos in test
1 parent 0e6121e commit 223b908

1 file changed (+24 -17)

test/lotka_volterra.jl

Lines changed: 24 additions & 17 deletions
@@ -15,13 +15,13 @@ using DifferentiationInterface
 using SciMLSensitivity
 using Zygote: Zygote
 using Statistics
+using Lux

-function lotka_ude()
+function lotka_ude(chain)
     @variables t x(t)=3.1 y(t)=1.5
     @parameters α=1.3 [tunable = false] δ=1.8 [tunable = false]
     Dt = ModelingToolkit.D_nounits

-    chain = multi_layer_feed_forward(2, 2)
     @named nn = NeuralNetworkBlock(2, 2; chain, rng = StableRNG(42))

     eqs = [
@@ -36,48 +36,55 @@ end

 function lotka_true()
     @variables t x(t)=3.1 y(t)=1.5
-    @parameters α=1.3 β=0.9 γ=0.8 δ=1.8
+    @parameters α=1.3 [tunable = false] β=0.9 γ=0.8 δ=1.8 [tunable = false]
     Dt = ModelingToolkit.D_nounits

     eqs = [
         Dt(x) ~ α * x - β * x * y,
-        Dt(y) ~ -δ * y + δ * x * y
+        Dt(y) ~ -δ * y + γ * x * y
     ]
     return System(eqs, ModelingToolkit.t_nounits, name = :lotka_true)
 end

-ude_sys = lotka_ude()
+rbf(x) = exp.(-(x .^ 2))
+
+chain = Lux.Chain(
+    Lux.Dense(2, 5, rbf), Lux.Dense(5, 5, rbf), Lux.Dense(5, 5, rbf),
+    Lux.Dense(5, 2))
+ude_sys = lotka_ude(chain)

 sys = mtkcompile(ude_sys, allow_symbolic = true)

-prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys, [], (0, 1.0))
+prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys, [], (0, 5.0))

 model_true = mtkcompile(lotka_true())
-prob_true = ODEProblem{true, SciMLBase.FullSpecialize}(model_true, [], (0, 1.0))
-sol_ref = solve(prob_true, Vern9(), abstol = 1e-10, reltol = 1e-8)
+prob_true = ODEProblem{true, SciMLBase.FullSpecialize}(model_true, [], (0, 5.0))
+sol_ref = solve(prob_true, Vern9(), abstol = 1e-12, reltol = 1e-12)
+
+ts = range(0, 5.0, length = 21)
+data = reduce(hcat, sol_ref(ts, idxs = [model_true.x, model_true.y]).u)

 x0 = default_values(sys)[sys.nn.p]

 get_vars = getu(sys, [sys.x, sys.y])
-get_refs = getu(model_true, [model_true.x, model_true.y])
-set_x = setp_oop(sys, sys.nn.p)
+set_x = setsym_oop(sys, sys.nn.p)

-function loss(x, (prob, sol_ref, get_vars, get_refs, set_x))
-    new_p = set_x(prob, x)
-    new_prob = remake(prob, p = new_p, u0 = eltype(x).(prob.u0))
-    ts = sol_ref.t
+function loss(x, (prob, sol_ref, get_vars, data, ts, set_x))
+    # new_u0, new_p = set_x(prob, 1, x)
+    new_u0, new_p = set_x(prob, x)
+    new_prob = remake(prob, p = new_p, u0 = new_u0)
     new_sol = solve(new_prob, Vern9(), abstol = 1e-10, reltol = 1e-8, saveat = ts)

     if SciMLBase.successful_retcode(new_sol)
-        mean(abs2.(reduce(hcat, get_vars(new_sol)) .- reduce(hcat, get_refs(sol_ref))))
+        mean(abs2.(reduce(hcat, get_vars(new_sol)) .- data))
     else
         Inf
     end
 end

 of = OptimizationFunction{true}(loss, AutoZygote())

-ps = (prob, sol_ref, get_vars, get_refs, set_x);
+ps = (prob, sol_ref, get_vars, data, ts, set_x);

 @test_call target_modules=(ModelingToolkitNeuralNets,) loss(x0, ps)
 @test_opt target_modules=(ModelingToolkitNeuralNets,) loss(x0, ps)
@@ -89,7 +96,7 @@ ps = (prob, sol_ref, get_vars, get_refs, set_x);
 @test all(.!isnan.(∇l1))
 @test !iszero(∇l1)

-@test ∇l1≈∇l2 rtol=1e-5
+@test ∇l1≈∇l2 rtol=1e-4
 @test ∇l1 ≈ ∇l3

 op = OptimizationProblem(of, x0, ps)
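
For reference, the corrected lotka_true system is the standard Lotka-Volterra model (a sketch of the intended dynamics, assuming the usual roles of the parameters; not part of the commit itself):

\begin{aligned}
\dot{x} &= \alpha x - \beta x y \\
\dot{y} &= -\delta y + \gamma x y
\end{aligned}

The old predator equation reused δ in the growth term (-δ * y + δ * x * y); substituting the interaction coefficient γ restores the intended reference dynamics that the UDE is compared against.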
