Skip to content

Commit 8e53643

Browse files
Merge pull request #3609 from AayushSabharwal/as/revert-breaking-change
fix: fix breaking change to `generate_control_function`
2 parents 76c7a9b + 948f38c commit 8e53643

File tree

7 files changed

+25
-23
lines changed

7 files changed

+25
-23
lines changed

docs/src/basics/InputOutput.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ Now we can test the generated function `f` with random input and state values
7070
p = [1]
7171
x = [rand()]
7272
u = [rand()]
73-
@test f(x, u, p, 1) ≈ -p[] * (x + u) # Test that the function computes what we expect D(x) = -k*(x + u)
73+
@test f[1](x, u, p, 1) ≈ -p[] * (x + u) # Test that the function computes what we expect D(x) = -k*(x + u)
7474
```
7575

7676
## Generating an output function, ``g``

docs/src/tutorials/disturbance_modeling.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ disturbance_inputs = [ssys.d1, ssys.d2]
184184
P = ssys.system_model
185185
outputs = [P.inertia1.phi, P.inertia2.phi, P.inertia1.w, P.inertia2.w]
186186
187-
f, x_sym, p_sym, io_sys = ModelingToolkit.generate_control_function(
187+
(f_oop, f_ip), x_sym, p_sym, io_sys = ModelingToolkit.generate_control_function(
188188
model_with_disturbance, inputs, disturbance_inputs; disturbance_argument = true)
189189
190190
g = ModelingToolkit.build_explicit_observed_function(
@@ -195,12 +195,12 @@ x0, _ = ModelingToolkit.get_u0_p(io_sys, op, op)
195195
p = MTKParameters(io_sys, op)
196196
u = zeros(1) # Control input
197197
w = zeros(length(disturbance_inputs)) # Disturbance input
198-
@test f(x0, u, p, t, w) == zeros(5)
198+
@test f_oop(x0, u, p, t, w) == zeros(5)
199199
@test g(x0, u, p, 0.0) == [0, 0, 0, 0]
200200
201201
# Non-zero disturbance inputs should result in non-zero state derivatives. We call `sort` since we do not generally know the order of the state variables
202202
w = [1.0, 2.0]
203-
@test sort(f(x0, u, p, t, w)) == [0, 0, 0, 1, 2]
203+
@test sort(f_oop(x0, u, p, t, w)) == [0, 0, 0, 1, 2]
204204
```
205205

206206
## Input signal library

src/inputoutput.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -160,17 +160,17 @@ has_var(ex, x) = x ∈ Set(get_variables(ex))
160160
# Build control function
161161

162162
"""
163-
f, x_sym, p_sym, io_sys = generate_control_function(
163+
(f_oop, f_ip), x_sym, p_sym, io_sys = generate_control_function(
164164
sys::AbstractODESystem,
165165
inputs = unbound_inputs(sys),
166166
disturbance_inputs = nothing;
167167
implicit_dae = false,
168168
simplify = false,
169169
)
170170
171-
For a system `sys` with inputs (as determined by [`unbound_inputs`](@ref) or user specified), generate a function with additional input argument `u`
171+
For a system `sys` with inputs (as determined by [`unbound_inputs`](@ref) or user specified), generate functions with additional input argument `u`
172172
173-
The returned function `f` can be called in the out-of-place or in-place form:
173+
The returned functions are the out-of-place (`f_oop`) and in-place (`f_ip`) forms:
174174
```
175175
f_oop : (x,u,p,t) -> rhs
176176
f_ip : (xout,x,u,p,t) -> nothing
@@ -191,7 +191,7 @@ f, x_sym, ps = generate_control_function(sys, expression=Val{false}, simplify=fa
191191
p = varmap_to_vars(defaults(sys), ps)
192192
x = varmap_to_vars(defaults(sys), x_sym)
193193
t = 0
194-
f(x, inputs, p, t)
194+
f[1](x, inputs, p, t)
195195
```
196196
"""
197197
function generate_control_function(sys::AbstractODESystem, inputs = unbound_inputs(sys),
@@ -253,9 +253,10 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu
253253
f = build_function_wrapper(sys, rhss, args...; p_start = 3 + implicit_dae,
254254
p_end = length(p) + 2 + implicit_dae, kwargs...)
255255
f = eval_or_rgf.(f; eval_expression, eval_module)
256-
f = GeneratedFunctionWrapper{(3, length(args) - length(p) + 1, is_split(sys))}(f...)
256+
f = GeneratedFunctionWrapper{(
257+
3 + implicit_dae, length(args) - length(p) + 1, is_split(sys))}(f...)
257258
ps = setdiff(parameters(sys), inputs, disturbance_inputs)
258-
(; f, dvs, ps, io_sys = sys)
259+
(; f = (f, f), dvs, ps, io_sys = sys)
259260
end
260261

261262
function inputs_to_parameters!(state::TransformationState, io)

src/systems/optimal_control_interface.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ function SciMLBase.ODEInputFunction{iip, specialize}(sys::ODESystem,
5252
kwargs...) where {iip, specialize}
5353
f, _, _ = generate_control_function(
5454
sys, inputs, disturbance_inputs; eval_module, cse, kwargs...)
55+
f = f[1]
5556

5657
if tgrad
5758
tgrad_gen = generate_tgrad(sys, dvs, ps;

test/downstream/test_disturbance_model.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -168,22 +168,22 @@ x0, p = ModelingToolkit.get_u0_p(io_sys, op, op)
168168
x = zeros(5)
169169
u = zeros(1)
170170
d = zeros(3)
171-
@test f(x, u, p, t, d) == zeros(5)
171+
@test f[1](x, u, p, t, d) == zeros(5)
172172
@test measurement(x, u, p, 0.0) == [0, 0, 0, 0]
173173
@test measurement2(x, u, p, 0.0, d) == [0]
174174

175175
# Add to the integrating disturbance input
176176
d = [1, 0, 0]
177-
@test sort(f(x, u, p, 0.0, d)) == [0, 0, 0, 1, 1] # Affects disturbance state and one velocity
177+
@test sort(f[1](x, u, p, 0.0, d)) == [0, 0, 0, 1, 1] # Affects disturbance state and one velocity
178178
@test measurement2(x, u, p, 0.0, d) == [0]
179179

180180
d = [0, 1, 0]
181-
@test sort(f(x, u, p, 0.0, d)) == [0, 0, 0, 0, 1] # Affects one velocity
181+
@test sort(f[1](x, u, p, 0.0, d)) == [0, 0, 0, 0, 1] # Affects one velocity
182182
@test measurement(x, u, p, 0.0) == [0, 0, 0, 0]
183183
@test measurement2(x, u, p, 0.0, d) == [0]
184184

185185
d = [0, 0, 1]
186-
@test sort(f(x, u, p, 0.0, d)) == [0, 0, 0, 0, 0] # Affects nothing
186+
@test sort(f[1](x, u, p, 0.0, d)) == [0, 0, 0, 0, 0] # Affects nothing
187187
@test measurement(x, u, p, 0.0) == [0, 0, 0, 0]
188188
@test measurement2(x, u, p, 0.0, d) == [1] # We have now disturbed the output
189189

test/extensions/test_infiniteopt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ InfiniteOpt.@variables(m,
6565
# Trace the dynamics
6666
x0, p = ModelingToolkit.get_u0_p(io_sys, [model.θ => 0, model.ω => 0], [model.L => L])
6767

68-
xp = f(x, u, p, τ)
68+
xp = f[1](x, u, p, τ)
6969
cp = f_obs(x, u, p, τ) # Test that it's possible to trace through an observed function
7070

7171
@objective(m, Min, tf)

test/input_output_handling.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ end
173173
p = [rand()]
174174
x = [rand()]
175175
u = [rand()]
176-
@test f(x, u, p, 1) -x + u
176+
@test f[1](x, u, p, 1) -x + u
177177

178178
# With disturbance inputs
179179
@variables x(t)=0 u(t)=0 [input = true] d(t)=0
@@ -191,7 +191,7 @@ end
191191
p = [rand()]
192192
x = [rand()]
193193
u = [rand()]
194-
@test f(x, u, p, 1) -x + u
194+
@test f[1](x, u, p, 1) -x + u
195195

196196
## With added d argument
197197
@variables x(t)=0 u(t)=0 [input = true] d(t)=0
@@ -210,7 +210,7 @@ end
210210
x = [rand()]
211211
u = [rand()]
212212
d = [rand()]
213-
@test f(x, u, p, t, d) -x + u + [d[]^2]
213+
@test f[1](x, u, p, t, d) -x + u + [d[]^2]
214214
end
215215
end
216216

@@ -273,7 +273,7 @@ x = ModelingToolkit.varmap_to_vars(
273273
merge(ModelingToolkit.defaults(model),
274274
Dict(D.(unknowns(model)) .=> 0.0)), dvs)
275275
u = [rand()]
276-
out = f(x, u, p, 1)
276+
out = f[1](x, u, p, 1)
277277
i = findfirst(isequal(u[1]), out)
278278
@test i isa Int
279279
@test iszero(out[[1:(i - 1); (i + 1):end]])
@@ -348,8 +348,8 @@ x0 = randn(5)
348348
x1 = copy(x0) + x_add # add disturbance state perturbation
349349
u = randn(1)
350350
pn = MTKParameters(io_sys, [])
351-
xp0 = f(x0, u, pn, 0)
352-
xp1 = f(x1, u, pn, 0)
351+
xp0 = f[1](x0, u, pn, 0)
352+
xp1 = f[1](x1, u, pn, 0)
353353

354354
@test xp0 matrices.A * x0 + matrices.B * [u; 0]
355355
@test xp1 matrices.A * x1 + matrices.B * [u; 0]
@@ -447,7 +447,7 @@ end
447447
@named sys = ODESystem(eqs, t, [x], [])
448448

449449
f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(sys, simplify = true)
450-
@test f([0.5], nothing, MTKParameters(io_sys, []), 0.0) [1.0]
450+
@test f[1]([0.5], nothing, MTKParameters(io_sys, []), 0.0) [1.0]
451451
end
452452

453453
@testset "With callable symbolic" begin
@@ -459,5 +459,5 @@ end
459459
p = MTKParameters(io_sys, [])
460460
u = [1.0]
461461
x = [1.0]
462-
@test_nowarn f(x, u, p, 0.0)
462+
@test_nowarn f[1](x, u, p, 0.0)
463463
end

0 commit comments

Comments
 (0)