Skip to content

Commit 1f58ae9

Browse files
committed
fix tests
1 parent d679213 commit 1f58ae9

File tree

1 file changed

+16
-15
lines changed

1 file changed

+16
-15
lines changed

test/lotka_volterra.jl

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -48,18 +48,18 @@ end
4848

4949
rbf(x) = exp.(-(x .^ 2))
5050

51-
chain = Lux.Chain(
52-
Lux.Dense(2, 5, rbf), Lux.Dense(5, 5, rbf), Lux.Dense(5, 5, rbf),
53-
Lux.Dense(5, 2))
51+
chain = multi_layer_feed_forward(2, 2, width = 5, initial_scaling_factor = 1)
5452
ude_sys = lotka_ude(chain)
5553

56-
sys = mtkcompile(ude_sys, allow_symbolic = true)
54+
sys = mtkcompile(ude_sys)
55+
56+
@test length(equations(sys)) == 2
5757

5858
prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys, [], (0, 5.0))
5959

6060
model_true = mtkcompile(lotka_true())
6161
prob_true = ODEProblem{true, SciMLBase.FullSpecialize}(model_true, [], (0, 5.0))
62-
sol_ref = solve(prob_true, Vern9(), abstol = 1e-12, reltol = 1e-12)
62+
sol_ref = solve(prob_true, Vern9(), abstol = 1e-8, reltol = 1e-8)
6363

6464
ts = range(0, 5.0, length = 21)
6565
data = reduce(hcat, sol_ref(ts, idxs = [model_true.x, model_true.y]).u)
@@ -70,10 +70,9 @@ get_vars = getu(sys, [sys.x, sys.y])
7070
set_x = setsym_oop(sys, sys.nn.p)
7171

7272
function loss(x, (prob, sol_ref, get_vars, data, ts, set_x))
73-
# new_u0, new_p = set_x(prob, 1, x)
7473
new_u0, new_p = set_x(prob, x)
7574
new_prob = remake(prob, p = new_p, u0 = new_u0)
76-
new_sol = solve(new_prob, Vern9(), abstol = 1e-10, reltol = 1e-8, saveat = ts)
75+
new_sol = solve(new_prob, Vern9(), abstol = 1e-8, reltol = 1e-8, saveat = ts)
7776

7877
if SciMLBase.successful_retcode(new_sol)
7978
mean(abs2.(reduce(hcat, get_vars(new_sol)) .- data))
@@ -106,30 +105,32 @@ op = OptimizationProblem(of, x0, ps)
106105
# oh = []
107106

108107
# plot_cb = (opt_state, loss) -> begin
108+
# opt_state.iter % 500 ≠ 0 && return false
109109
# @info "step $(opt_state.iter), loss: $loss"
110110
# push!(oh, opt_state)
111111
# new_p = SciMLStructures.replace(Tunable(), prob.p, opt_state.u)
112112
# new_prob = remake(prob, p = new_p)
113-
# sol = solve(new_prob, Rodas4())
113+
# sol = solve(new_prob, Vern9(), abstol = 1e-8, reltol = 1e-8)
114114
# display(plot(sol))
115115
# false
116116
# end
117117

118-
res = solve(op, Adam(), maxiters = 10000)#, callback = plot_cb)
118+
res = solve(op, Adam(1e-3), maxiters = 25_000, callback = plot_cb)
119119

120120
display(res.stats)
121-
@test res.objective < 1
121+
@test res.objective < 1e-4
122+
123+
u0, p = set_x(prob, res.u)
124+
res_prob = remake(prob; u0, p)
125+
res_sol = solve(res_prob, Vern9(), abstol = 1e-8, reltol = 1e-8, saveat = ts)
122126

123-
res_p = set_x(prob, res.u)
124-
res_prob = remake(prob, p = res_p)
125-
res_sol = solve(res_prob, Vern9())
127+
@test SciMLBase.successful_retcode(res_sol)
128+
@test mean(abs2.(reduce(hcat, get_vars(res_sol)) .- data)) res.objective
126129

127130
# using Plots
128131
# plot(sol_ref, idxs = [model_true.x, model_true.y])
129132
# plot!(res_sol, idxs = [sys.x, sys.y])
130133

131-
@test SciMLBase.successful_retcode(res_sol)
132-
133134
function lotka_ude2()
134135
@variables t x(t)=3.1 y(t)=1.5 pred(t)[1:2]
135136
@parameters α=1.3 [tunable = false] δ=1.8 [tunable = false]

0 commit comments

Comments
 (0)