@@ -15,13 +15,13 @@ using DifferentiationInterface
15
15
using SciMLSensitivity
16
16
using Zygote: Zygote
17
17
using Statistics
18
+ using Lux
18
19
19
- function lotka_ude ()
20
+ function lotka_ude (chain )
20
21
@variables t x (t)= 3.1 y (t)= 1.5
21
22
@parameters α= 1.3 [tunable = false ] δ= 1.8 [tunable = false ]
22
23
Dt = ModelingToolkit. D_nounits
23
24
24
- chain = multi_layer_feed_forward (2 , 2 )
25
25
@named nn = NeuralNetworkBlock (2 , 2 ; chain, rng = StableRNG (42 ))
26
26
27
27
eqs = [
36
36
37
37
function lotka_true ()
38
38
@variables t x (t)= 3.1 y (t)= 1.5
39
- @parameters α= 1.3 β= 0.9 γ= 0.8 δ= 1.8
39
+ @parameters α= 1.3 [tunable = false ] β= 0.9 γ= 0.8 δ= 1.8 [tunable = false ]
40
40
Dt = ModelingToolkit. D_nounits
41
41
42
42
eqs = [
43
43
Dt (x) ~ α * x - β * x * y,
44
- Dt (y) ~ - δ * y + δ * x * y
44
+ Dt (y) ~ - δ * y + γ * x * y
45
45
]
46
46
return System (eqs, ModelingToolkit. t_nounits, name = :lotka_true )
47
47
end
48
48
49
- ude_sys = lotka_ude ()
49
+ rbf (x) = exp .(- (x .^ 2 ))
50
+
51
+ chain = Lux. Chain (
52
+ Lux. Dense (2 , 5 , rbf), Lux. Dense (5 , 5 , rbf), Lux. Dense (5 , 5 , rbf),
53
+ Lux. Dense (5 , 2 ))
54
+ ude_sys = lotka_ude (chain)
50
55
51
56
sys = mtkcompile (ude_sys, allow_symbolic = true )
52
57
53
- prob = ODEProblem {true, SciMLBase.FullSpecialize} (sys, [], (0 , 1 .0 ))
58
+ prob = ODEProblem {true, SciMLBase.FullSpecialize} (sys, [], (0 , 5 .0 ))
54
59
55
60
model_true = mtkcompile (lotka_true ())
56
- prob_true = ODEProblem {true, SciMLBase.FullSpecialize} (model_true, [], (0 , 1.0 ))
57
- sol_ref = solve (prob_true, Vern9 (), abstol = 1e-10 , reltol = 1e-8 )
61
+ prob_true = ODEProblem {true, SciMLBase.FullSpecialize} (model_true, [], (0 , 5.0 ))
62
+ sol_ref = solve (prob_true, Vern9 (), abstol = 1e-12 , reltol = 1e-12 )
63
+
64
+ ts = range (0 , 5.0 , length = 21 )
65
+ data = reduce (hcat, sol_ref (ts, idxs = [model_true. x, model_true. y]). u)
58
66
59
67
x0 = default_values (sys)[sys. nn. p]
60
68
61
69
get_vars = getu (sys, [sys. x, sys. y])
62
- get_refs = getu (model_true, [model_true. x, model_true. y])
63
- set_x = setp_oop (sys, sys. nn. p)
70
+ set_x = setsym_oop (sys, sys. nn. p)
64
71
65
- function loss (x, (prob, sol_ref, get_vars, get_refs , set_x))
66
- new_p = set_x (prob, x)
67
- new_prob = remake (prob, p = new_p, u0 = eltype (x).( prob. u0) )
68
- ts = sol_ref . t
72
+ function loss (x, (prob, sol_ref, get_vars, data, ts , set_x))
73
+ # new_u0, new_p = set_x(prob, 1 , x)
74
+ new_u0, new_p = set_x ( prob, x )
75
+ new_prob = remake (prob, p = new_p, u0 = new_u0)
69
76
new_sol = solve (new_prob, Vern9 (), abstol = 1e-10 , reltol = 1e-8 , saveat = ts)
70
77
71
78
if SciMLBase. successful_retcode (new_sol)
72
- mean (abs2 .(reduce (hcat, get_vars (new_sol)) .- reduce (hcat, get_refs (sol_ref)) ))
79
+ mean (abs2 .(reduce (hcat, get_vars (new_sol)) .- data ))
73
80
else
74
81
Inf
75
82
end
76
83
end
77
84
78
85
of = OptimizationFunction {true} (loss, AutoZygote ())
79
86
80
- ps = (prob, sol_ref, get_vars, get_refs , set_x);
87
+ ps = (prob, sol_ref, get_vars, data, ts , set_x);
81
88
82
89
@test_call target_modules= (ModelingToolkitNeuralNets,) loss (x0, ps)
83
90
@test_opt target_modules= (ModelingToolkitNeuralNets,) loss (x0, ps)
@@ -89,7 +96,7 @@ ps = (prob, sol_ref, get_vars, get_refs, set_x);
89
96
@test all (.! isnan .(∇l1))
90
97
@test ! iszero (∇l1)
91
98
92
- @test ∇l1≈ ∇l2 rtol= 1e-5
99
+ @test ∇l1≈ ∇l2 rtol= 1e-4
93
100
@test ∇l1 ≈ ∇l3
94
101
95
102
op = OptimizationProblem (of, x0, ps)
0 commit comments