48
48
49
49
rbf (x) = exp .(- (x .^ 2 ))
50
50
51
- chain = Lux. Chain (
52
- Lux. Dense (2 , 5 , rbf), Lux. Dense (5 , 5 , rbf), Lux. Dense (5 , 5 , rbf),
53
- Lux. Dense (5 , 2 ))
51
+ chain = multi_layer_feed_forward (2 , 2 , width = 5 , initial_scaling_factor = 1 )
54
52
ude_sys = lotka_ude (chain)
55
53
56
- sys = mtkcompile (ude_sys, allow_symbolic = true )
54
+ sys = mtkcompile (ude_sys)
55
+
56
+ @test length (equations (sys)) == 2
57
57
58
58
prob = ODEProblem {true, SciMLBase.FullSpecialize} (sys, [], (0 , 5.0 ))
59
59
60
60
model_true = mtkcompile (lotka_true ())
61
61
prob_true = ODEProblem {true, SciMLBase.FullSpecialize} (model_true, [], (0 , 5.0 ))
62
- sol_ref = solve (prob_true, Vern9 (), abstol = 1e-12 , reltol = 1e-12 )
62
+ sol_ref = solve (prob_true, Vern9 (), abstol = 1e-8 , reltol = 1e-8 )
63
63
64
64
ts = range (0 , 5.0 , length = 21 )
65
65
data = reduce (hcat, sol_ref (ts, idxs = [model_true. x, model_true. y]). u)
@@ -70,10 +70,9 @@ get_vars = getu(sys, [sys.x, sys.y])
70
70
set_x = setsym_oop (sys, sys. nn. p)
71
71
72
72
function loss (x, (prob, sol_ref, get_vars, data, ts, set_x))
73
- # new_u0, new_p = set_x(prob, 1, x)
74
73
new_u0, new_p = set_x (prob, x)
75
74
new_prob = remake (prob, p = new_p, u0 = new_u0)
76
- new_sol = solve (new_prob, Vern9 (), abstol = 1e-10 , reltol = 1e-8 , saveat = ts)
75
+ new_sol = solve (new_prob, Vern9 (), abstol = 1e-8 , reltol = 1e-8 , saveat = ts)
77
76
78
77
if SciMLBase. successful_retcode (new_sol)
79
78
mean (abs2 .(reduce (hcat, get_vars (new_sol)) .- data))
@@ -106,30 +105,32 @@ op = OptimizationProblem(of, x0, ps)
106
105
# oh = []
107
106
108
107
# plot_cb = (opt_state, loss) -> begin
108
+ # opt_state.iter % 500 ≠ 0 && return false
109
109
# @info "step $(opt_state.iter), loss: $loss"
110
110
# push!(oh, opt_state)
111
111
# new_p = SciMLStructures.replace(Tunable(), prob.p, opt_state.u)
112
112
# new_prob = remake(prob, p = new_p)
113
- # sol = solve(new_prob, Rodas4() )
113
+ # sol = solve(new_prob, Vern9(), abstol = 1e-8, reltol = 1e-8 )
114
114
# display(plot(sol))
115
115
# false
116
116
# end
117
117
118
- res = solve (op, Adam (), maxiters = 10000 ) # , callback = plot_cb)
118
+ res = solve (op, Adam (1e-3 ), maxiters = 25_000 , callback = plot_cb)
119
119
120
120
display (res. stats)
121
- @test res. objective < 1
121
+ @test res. objective < 1e-4
122
+
123
+ u0, p = set_x (prob, res. u)
124
+ res_prob = remake (prob; u0, p)
125
+ res_sol = solve (res_prob, Vern9 (), abstol = 1e-8 , reltol = 1e-8 , saveat = ts)
122
126
123
- res_p = set_x (prob, res. u)
124
- res_prob = remake (prob, p = res_p)
125
- res_sol = solve (res_prob, Vern9 ())
127
+ @test SciMLBase. successful_retcode (res_sol)
128
+ @test mean (abs2 .(reduce (hcat, get_vars (res_sol)) .- data)) ≈ res. objective
126
129
127
130
# using Plots
128
131
# plot(sol_ref, idxs = [model_true.x, model_true.y])
129
132
# plot!(res_sol, idxs = [sys.x, sys.y])
130
133
131
- @test SciMLBase. successful_retcode (res_sol)
132
-
133
134
function lotka_ude2 ()
134
135
@variables t x (t)= 3.1 y (t)= 1.5 pred (t)[1 : 2 ]
135
136
@parameters α= 1.3 [tunable = false ] δ= 1.8 [tunable = false ]
0 commit comments