35
35
"""
36
36
VecJac(f, u, [p, t]; fu = nothing, autodiff = AutoFiniteDiff())
37
37
38
- Returns SciMLOperators.FunctionOperator which computes vector-jacobian product `df/du * v`.
38
+ Returns SciMLOperators.FunctionOperator which computes vector-jacobian product
39
+ `(df/du)ᵀ * v`.
39
40
40
41
!!! note
41
42
@@ -45,11 +46,11 @@ Returns SciMLOperators.FunctionOperator which computes vector-jacobian product `
45
46
```julia
46
47
L = VecJac(f, u)
47
48
48
- L * v # = df/du * v
49
- mul!(w, L, v) # = df/du * v
49
+ L * v # = ( df/du)ᵀ * v
50
+ mul!(w, L, v) # = ( df/du)ᵀ * v
50
51
51
- L(v, p, t; VJP_input = w) # = df/dw * v
52
- L(x, v, p, t; VJP_input = w) # = df/dw * v
52
+ L(v, p, t; VJP_input = w) # = ( df/du)ᵀ * v
53
+ L(x, v, p, t; VJP_input = w) # = ( df/du)ᵀ * v
53
54
```
54
55
55
56
## Allowed Function Signatures for `f`
@@ -72,7 +73,7 @@ f(du, u) # Otherwise
72
73
"""
73
74
function VecJac (f, u:: AbstractArray , p = nothing , t = nothing ; fu = nothing ,
74
75
autodiff = AutoFiniteDiff (), kwargs... )
75
- ff = VecJacFunctionWrapper (f, fu, u, p, t)
76
+ ff = JacFunctionWrapper (f, fu, u, p, t)
76
77
77
78
if ! __internal_oop (ff) && autodiff isa AutoZygote
78
79
msg = " Zygote requires an out of place method with signature f(u)."
@@ -83,82 +84,12 @@ function VecJac(f, u::AbstractArray, p = nothing, t = nothing; fu = nothing,
83
84
84
85
op = _vecjac (ff, fu, u, autodiff)
85
86
86
- # FIXME : FunctionOperator is terribly type unstable. It makes it `::Any`
87
87
# NOTE: We pass `p`, `t` to Function Operator but we always use the cached version from
88
- # VecJacFunctionWrapper
89
- return FunctionOperator (op, fu, u; p, t, isinplace = true , outofplace = true ,
88
+ # JacFunctionWrapper
89
+ return FunctionOperator (op, fu, u; p, t, isinplace = Val ( true ) , outofplace = Val ( true ) ,
90
90
islinear = true , accepted_kwargs = (:VJP_input ,), kwargs... )
91
91
end
92
92
93
- mutable struct VecJacFunctionWrapper{iip, oop, mode, F, FU, P, T} <: Function
94
- f:: F
95
- fu:: FU
96
- p:: P
97
- t:: T
98
- end
99
-
100
- function SciMLOperators. update_coefficients! (L:: VecJacFunctionWrapper{iip, oop, mode} , _,
101
- p, t) where {iip, oop, mode}
102
- mode == 1 && (L. t = t)
103
- mode == 2 && (L. p = p)
104
- return L
105
- end
106
- function SciMLOperators. update_coefficients (L:: VecJacFunctionWrapper{iip, oop, mode} , _, p,
107
- t) where {iip, oop, mode}
108
- return VecJacFunctionWrapper{iip, oop, mode, typeof (L. f), typeof (L. fu), typeof (p),
109
- typeof (t)}(L. f, L. fu, p,
110
- t)
111
- end
112
-
113
- __internal_iip (:: VecJacFunctionWrapper{iip} ) where {iip} = iip
114
- __internal_oop (:: VecJacFunctionWrapper{iip, oop} ) where {iip, oop} = oop
115
-
116
- (f:: VecJacFunctionWrapper{true, oop, 1} )(fu, u) where {oop} = f. f (fu, u, f. p, f. t)
117
- (f:: VecJacFunctionWrapper{true, oop, 2} )(fu, u) where {oop} = f. f (fu, u, f. p)
118
- (f:: VecJacFunctionWrapper{true, oop, 3} )(fu, u) where {oop} = f. f (fu, u)
119
- (f:: VecJacFunctionWrapper{true, true, 1} )(u) = f. f (u, f. p, f. t)
120
- (f:: VecJacFunctionWrapper{true, true, 2} )(u) = f. f (u, f. p)
121
- (f:: VecJacFunctionWrapper{true, true, 3} )(u) = f. f (u)
122
- (f:: VecJacFunctionWrapper{true, false, 1} )(u) = (f. f (f. fu, u, f. p, f. t); copy (f. fu))
123
- (f:: VecJacFunctionWrapper{true, false, 2} )(u) = (f. f (f. fu, u, f. p); copy (f. fu))
124
- (f:: VecJacFunctionWrapper{true, false, 3} )(u) = (f. f (f. fu, u); copy (f. fu))
125
-
126
- (f:: VecJacFunctionWrapper{false, true, 1} )(fu, u) = (vec (fu) .= vec (f. f (u, f. p, f. t)))
127
- (f:: VecJacFunctionWrapper{false, true, 2} )(fu, u) = (vec (fu) .= vec (f. f (u, f. p)))
128
- (f:: VecJacFunctionWrapper{false, true, 3} )(fu, u) = (vec (fu) .= vec (f. f (u)))
129
- (f:: VecJacFunctionWrapper{false, true, 1} )(u) = f. f (u, f. p, f. t)
130
- (f:: VecJacFunctionWrapper{false, true, 2} )(u) = f. f (u, f. p)
131
- (f:: VecJacFunctionWrapper{false, true, 3} )(u) = f. f (u)
132
-
133
- function VecJacFunctionWrapper (f:: F , fu_, u, p, t) where {F}
134
- fu = fu_ === nothing ? copy (u) : copy (fu_)
135
- if t != = nothing
136
- iip = static_hasmethod (f, typeof ((fu, u, p, t)))
137
- oop = static_hasmethod (f, typeof ((u, p, t)))
138
- if ! iip && ! oop
139
- throw (ArgumentError (" `f(u, p, t)` or `f(fu, u, p, t)` not defined for `f`" ))
140
- end
141
- return VecJacFunctionWrapper {iip, oop, 1, F, typeof(fu), typeof(p), typeof(t)} (f,
142
- fu, p, t)
143
- elseif p != = nothing
144
- iip = static_hasmethod (f, typeof ((fu, u, p)))
145
- oop = static_hasmethod (f, typeof ((u, p)))
146
- if ! iip && ! oop
147
- throw (ArgumentError (" `f(u, p)` or `f(fu, u, p)` not defined for `f`" ))
148
- end
149
- return VecJacFunctionWrapper {iip, oop, 2, F, typeof(fu), typeof(p), typeof(t)} (f,
150
- fu, p, t)
151
- else
152
- iip = static_hasmethod (f, typeof ((fu, u)))
153
- oop = static_hasmethod (f, typeof ((u,)))
154
- if ! iip && ! oop
155
- throw (ArgumentError (" `f(u)` or `f(fu, u)` not defined for `f`" ))
156
- end
157
- return VecJacFunctionWrapper {iip, oop, 3, F, typeof(fu), typeof(p), typeof(t)} (f,
158
- fu, p, t)
159
- end
160
- end
161
-
162
93
function _vecjac (f:: F , fu, u, autodiff:: AutoFiniteDiff ) where {F}
163
94
cache = (similar (fu), similar (fu))
164
95
pullback = nothing
0 commit comments