Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 9a42b50

Browse files
Merge pull request #274 from avik-pal/ap/rework_jacvec
Fix type stability and allow JacVec to be non-square jacobians
2 parents 9a713c6 + 84c0dc3 commit 9a42b50

File tree

5 files changed

+133
-94
lines changed

5 files changed

+133
-94
lines changed

Project.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SparseDiffTools"
22
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
33
authors = ["Pankaj Mishra <[email protected]>", "Chris Rackauckas <[email protected]>"]
4-
version = "2.11.0"
4+
version = "2.12.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -49,7 +49,7 @@ LinearAlgebra = "1.6"
4949
PackageExtensionCompat = "1"
5050
Random = "1.6"
5151
Reexport = "1"
52-
SciMLOperators = "0.2.11, 0.3"
52+
SciMLOperators = "0.3.7"
5353
Setfield = "1"
5454
SparseArrays = "1.6"
5555
StaticArrayInterface = "1.3"

src/SparseDiffTools.jl

+1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ include("coloring/greedy_star1_coloring.jl")
4040
include("coloring/greedy_star2_coloring.jl")
4141
include("coloring/matrix2graph.jl")
4242

43+
include("differentiation/common.jl")
4344
include("differentiation/compute_jacobian_ad.jl")
4445
include("differentiation/compute_hessian_ad.jl")
4546
include("differentiation/jaches_products.jl")

src/differentiation/common.jl

+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
mutable struct JacFunctionWrapper{iip, oop, mode, F, FU, P, T} <: Function
2+
f::F
3+
fu::FU
4+
p::P
5+
t::T
6+
end
7+
8+
function SciMLOperators.update_coefficients!(L::JacFunctionWrapper{iip, oop, mode}, _,
9+
p, t) where {iip, oop, mode}
10+
mode == 1 && (L.t = t)
11+
mode == 2 && (L.p = p)
12+
return L
13+
end
14+
function SciMLOperators.update_coefficients(L::JacFunctionWrapper{iip, oop, mode}, _, p,
15+
t) where {iip, oop, mode}
16+
return JacFunctionWrapper{iip, oop, mode, typeof(L.f), typeof(L.fu), typeof(p),
17+
typeof(t)}(L.f, L.fu, p,
18+
t)
19+
end
20+
21+
__internal_iip(::JacFunctionWrapper{iip}) where {iip} = iip
22+
__internal_oop(::JacFunctionWrapper{iip, oop}) where {iip, oop} = oop
23+
24+
(f::JacFunctionWrapper{true, oop, 1})(fu, u) where {oop} = f.f(fu, u, f.p, f.t)
25+
(f::JacFunctionWrapper{true, oop, 2})(fu, u) where {oop} = f.f(fu, u, f.p)
26+
(f::JacFunctionWrapper{true, oop, 3})(fu, u) where {oop} = f.f(fu, u)
27+
(f::JacFunctionWrapper{true, true, 1})(u) = f.f(u, f.p, f.t)
28+
(f::JacFunctionWrapper{true, true, 2})(u) = f.f(u, f.p)
29+
(f::JacFunctionWrapper{true, true, 3})(u) = f.f(u)
30+
(f::JacFunctionWrapper{true, false, 1})(u) = (f.f(f.fu, u, f.p, f.t); copy(f.fu))
31+
(f::JacFunctionWrapper{true, false, 2})(u) = (f.f(f.fu, u, f.p); copy(f.fu))
32+
(f::JacFunctionWrapper{true, false, 3})(u) = (f.f(f.fu, u); copy(f.fu))
33+
34+
(f::JacFunctionWrapper{false, true, 1})(fu, u) = (vec(fu) .= vec(f.f(u, f.p, f.t)))
35+
(f::JacFunctionWrapper{false, true, 2})(fu, u) = (vec(fu) .= vec(f.f(u, f.p)))
36+
(f::JacFunctionWrapper{false, true, 3})(fu, u) = (vec(fu) .= vec(f.f(u)))
37+
(f::JacFunctionWrapper{false, true, 1})(u) = f.f(u, f.p, f.t)
38+
(f::JacFunctionWrapper{false, true, 2})(u) = f.f(u, f.p)
39+
(f::JacFunctionWrapper{false, true, 3})(u) = f.f(u)
40+
41+
function JacFunctionWrapper(f::F, fu_, u, p, t) where {F}
42+
# The warning instead of error ensures a non-breaking change for users relying on an
43+
# undefined / undocumented feature
44+
fu = fu_ === nothing ? copy(u) : copy(fu_)
45+
if t !== nothing
46+
iip = static_hasmethod(f, typeof((fu, u, p, t)))
47+
oop = static_hasmethod(f, typeof((u, p, t)))
48+
if !iip && !oop
49+
@warn """`p` and `t` provided but `f(u, p, t)` or `f(fu, u, p, t)` not defined
50+
for `f`! Will fallback to `f(u)` or `f(fu, u)`.""" maxlog=1
51+
else
52+
return JacFunctionWrapper{iip, oop, 1, F, typeof(fu), typeof(p), typeof(t)}(f,
53+
fu, p, t)
54+
end
55+
elseif p !== nothing
56+
iip = static_hasmethod(f, typeof((fu, u, p)))
57+
oop = static_hasmethod(f, typeof((u, p)))
58+
if !iip && !oop
59+
@warn """`p` provided but `f(u, p)` or `f(fu, u, p)` not defined for `f`! Will
60+
fallback to `f(u)` or `f(fu, u)`.""" maxlog=1
61+
else
62+
return JacFunctionWrapper{iip, oop, 2, F, typeof(fu), typeof(p), typeof(t)}(f,
63+
fu, p, t)
64+
end
65+
end
66+
iip = static_hasmethod(f, typeof((fu, u)))
67+
oop = static_hasmethod(f, typeof((u,)))
68+
!iip && !oop && throw(ArgumentError("`f(u)` or `f(fu, u)` not defined for `f`"))
69+
return JacFunctionWrapper{iip, oop, 3, F, typeof(fu), typeof(p), typeof(t)}(f,
70+
fu, p, t)
71+
end

src/differentiation/jaches_products.jl

+50-14
Original file line numberDiff line numberDiff line change
@@ -223,36 +223,72 @@ function Base.resize!(L::FwdModeAutoDiffVecProd, n::Integer)
223223
end
224224
end
225225

226-
function JacVec(f, u::AbstractArray, p = nothing, t = nothing;
226+
"""
227+
JacVec(f, u, [p, t]; fu = nothing, autodiff = AutoForwardDiff(), tag = DeivVecTag(),
228+
kwargs...)
229+
230+
Returns SciMLOperators.FunctionOperator which computes jacobian-vector product `df/du * v`.
231+
232+
!!! note
233+
234+
For non-square jacobians with inplace `f`, `fu` must be specified, else `JacVec` assumes
235+
a square jacobian.
236+
237+
```julia
238+
L = JacVec(f, u)
239+
240+
L * v # = df/du * v
241+
mul!(w, L, v) # = df/du * v
242+
243+
L(v, p, t) # = df/dw * v
244+
L(x, v, p, t) # = df/dw * v
245+
```
246+
247+
## Allowed Function Signatures for `f`
248+
249+
For Out of Place Functions:
250+
251+
```julia
252+
f(u, p, t) # t !== nothing
253+
f(u, p) # p !== nothing
254+
f(u) # Otherwise
255+
```
256+
257+
For In Place Functions:
258+
259+
```julia
260+
f(du, u, p, t) # t !== nothing
261+
f(du, u, p) # p !== nothing
262+
f(du, u) # Otherwise
263+
```
264+
"""
265+
function JacVec(f, u::AbstractArray, p = nothing, t = nothing; fu = nothing,
227266
autodiff = AutoForwardDiff(), tag = DeivVecTag(), kwargs...)
267+
ff = JacFunctionWrapper(f, fu, u, p, t)
268+
fu === nothing && (fu = __internal_oop(ff) ? ff(u) : u)
269+
228270
cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff
229-
cache1 = similar(u)
271+
cache1 = similar(fu)
230272
cache2 = similar(u)
231273

232274
(cache1, cache2), num_jacvec, num_jacvec!
233275
elseif autodiff isa AutoForwardDiff
234276
cache1 = Dual{
235277
typeof(ForwardDiff.Tag(tag, eltype(u))), eltype(u), 1,
236278
}.(u, ForwardDiff.Partials.(tuple.(u)))
237-
238-
cache2 = copy(cache1)
279+
cache2 = Dual{
280+
typeof(ForwardDiff.Tag(tag, eltype(fu))), eltype(fu), 1,
281+
}.(fu, ForwardDiff.Partials.(tuple.(fu)))
239282

240283
(cache1, cache2), auto_jacvec, auto_jacvec!
241284
else
242285
error("Set autodiff to either AutoForwardDiff(), or AutoFiniteDiff()")
243286
end
244287

245-
outofplace = static_hasmethod(f, typeof((u,)))
246-
isinplace = static_hasmethod(f, typeof((u, u)))
288+
op = FwdModeAutoDiffVecProd(ff, u, cache, vecprod, vecprod!)
247289

248-
if !(isinplace) & !(outofplace)
249-
error("$f must have signature f(u), or f(du, u).")
250-
end
251-
252-
L = FwdModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!)
253-
254-
return FunctionOperator(L, u, u; isinplace, outofplace, p, t, islinear = true,
255-
kwargs...)
290+
return FunctionOperator(op, u, fu; isinplace = Val(true), outofplace = Val(true), p, t,
291+
islinear = true, kwargs...)
256292
end
257293

258294
function HesVec(f, u::AbstractArray, p = nothing, t = nothing;

src/differentiation/vecjac_products.jl

+9-78
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ end
3535
"""
3636
VecJac(f, u, [p, t]; fu = nothing, autodiff = AutoFiniteDiff())
3737
38-
Returns SciMLOperators.FunctionOperator which computes vector-jacobian product `df/du * v`.
38+
Returns SciMLOperators.FunctionOperator which computes vector-jacobian product
39+
`(df/du)ᵀ * v`.
3940
4041
!!! note
4142
@@ -45,11 +46,11 @@ Returns SciMLOperators.FunctionOperator which computes vector-jacobian product `
4546
```julia
4647
L = VecJac(f, u)
4748
48-
L * v # = df/du * v
49-
mul!(w, L, v) # = df/du * v
49+
L * v # = (df/du)ᵀ * v
50+
mul!(w, L, v) # = (df/du)ᵀ * v
5051
51-
L(v, p, t; VJP_input = w) # = df/dw * v
52-
L(x, v, p, t; VJP_input = w) # = df/dw * v
52+
L(v, p, t; VJP_input = w) # = (df/du)ᵀ * v
53+
L(x, v, p, t; VJP_input = w) # = (df/du)ᵀ * v
5354
```
5455
5556
## Allowed Function Signatures for `f`
@@ -72,7 +73,7 @@ f(du, u) # Otherwise
7273
"""
7374
function VecJac(f, u::AbstractArray, p = nothing, t = nothing; fu = nothing,
7475
autodiff = AutoFiniteDiff(), kwargs...)
75-
ff = VecJacFunctionWrapper(f, fu, u, p, t)
76+
ff = JacFunctionWrapper(f, fu, u, p, t)
7677

7778
if !__internal_oop(ff) && autodiff isa AutoZygote
7879
msg = "Zygote requires an out of place method with signature f(u)."
@@ -83,82 +84,12 @@ function VecJac(f, u::AbstractArray, p = nothing, t = nothing; fu = nothing,
8384

8485
op = _vecjac(ff, fu, u, autodiff)
8586

86-
# FIXME: FunctionOperator is terribly type unstable. It makes it `::Any`
8787
# NOTE: We pass `p`, `t` to Function Operator but we always use the cached version from
88-
# VecJacFunctionWrapper
89-
return FunctionOperator(op, fu, u; p, t, isinplace = true, outofplace = true,
88+
# JacFunctionWrapper
89+
return FunctionOperator(op, fu, u; p, t, isinplace = Val(true), outofplace = Val(true),
9090
islinear = true, accepted_kwargs = (:VJP_input,), kwargs...)
9191
end
9292

93-
mutable struct VecJacFunctionWrapper{iip, oop, mode, F, FU, P, T} <: Function
94-
f::F
95-
fu::FU
96-
p::P
97-
t::T
98-
end
99-
100-
function SciMLOperators.update_coefficients!(L::VecJacFunctionWrapper{iip, oop, mode}, _,
101-
p, t) where {iip, oop, mode}
102-
mode == 1 && (L.t = t)
103-
mode == 2 && (L.p = p)
104-
return L
105-
end
106-
function SciMLOperators.update_coefficients(L::VecJacFunctionWrapper{iip, oop, mode}, _, p,
107-
t) where {iip, oop, mode}
108-
return VecJacFunctionWrapper{iip, oop, mode, typeof(L.f), typeof(L.fu), typeof(p),
109-
typeof(t)}(L.f, L.fu, p,
110-
t)
111-
end
112-
113-
__internal_iip(::VecJacFunctionWrapper{iip}) where {iip} = iip
114-
__internal_oop(::VecJacFunctionWrapper{iip, oop}) where {iip, oop} = oop
115-
116-
(f::VecJacFunctionWrapper{true, oop, 1})(fu, u) where {oop} = f.f(fu, u, f.p, f.t)
117-
(f::VecJacFunctionWrapper{true, oop, 2})(fu, u) where {oop} = f.f(fu, u, f.p)
118-
(f::VecJacFunctionWrapper{true, oop, 3})(fu, u) where {oop} = f.f(fu, u)
119-
(f::VecJacFunctionWrapper{true, true, 1})(u) = f.f(u, f.p, f.t)
120-
(f::VecJacFunctionWrapper{true, true, 2})(u) = f.f(u, f.p)
121-
(f::VecJacFunctionWrapper{true, true, 3})(u) = f.f(u)
122-
(f::VecJacFunctionWrapper{true, false, 1})(u) = (f.f(f.fu, u, f.p, f.t); copy(f.fu))
123-
(f::VecJacFunctionWrapper{true, false, 2})(u) = (f.f(f.fu, u, f.p); copy(f.fu))
124-
(f::VecJacFunctionWrapper{true, false, 3})(u) = (f.f(f.fu, u); copy(f.fu))
125-
126-
(f::VecJacFunctionWrapper{false, true, 1})(fu, u) = (vec(fu) .= vec(f.f(u, f.p, f.t)))
127-
(f::VecJacFunctionWrapper{false, true, 2})(fu, u) = (vec(fu) .= vec(f.f(u, f.p)))
128-
(f::VecJacFunctionWrapper{false, true, 3})(fu, u) = (vec(fu) .= vec(f.f(u)))
129-
(f::VecJacFunctionWrapper{false, true, 1})(u) = f.f(u, f.p, f.t)
130-
(f::VecJacFunctionWrapper{false, true, 2})(u) = f.f(u, f.p)
131-
(f::VecJacFunctionWrapper{false, true, 3})(u) = f.f(u)
132-
133-
function VecJacFunctionWrapper(f::F, fu_, u, p, t) where {F}
134-
fu = fu_ === nothing ? copy(u) : copy(fu_)
135-
if t !== nothing
136-
iip = static_hasmethod(f, typeof((fu, u, p, t)))
137-
oop = static_hasmethod(f, typeof((u, p, t)))
138-
if !iip && !oop
139-
throw(ArgumentError("`f(u, p, t)` or `f(fu, u, p, t)` not defined for `f`"))
140-
end
141-
return VecJacFunctionWrapper{iip, oop, 1, F, typeof(fu), typeof(p), typeof(t)}(f,
142-
fu, p, t)
143-
elseif p !== nothing
144-
iip = static_hasmethod(f, typeof((fu, u, p)))
145-
oop = static_hasmethod(f, typeof((u, p)))
146-
if !iip && !oop
147-
throw(ArgumentError("`f(u, p)` or `f(fu, u, p)` not defined for `f`"))
148-
end
149-
return VecJacFunctionWrapper{iip, oop, 2, F, typeof(fu), typeof(p), typeof(t)}(f,
150-
fu, p, t)
151-
else
152-
iip = static_hasmethod(f, typeof((fu, u)))
153-
oop = static_hasmethod(f, typeof((u,)))
154-
if !iip && !oop
155-
throw(ArgumentError("`f(u)` or `f(fu, u)` not defined for `f`"))
156-
end
157-
return VecJacFunctionWrapper{iip, oop, 3, F, typeof(fu), typeof(p), typeof(t)}(f,
158-
fu, p, t)
159-
end
160-
end
161-
16293
function _vecjac(f::F, fu, u, autodiff::AutoFiniteDiff) where {F}
16394
cache = (similar(fu), similar(fu))
16495
pullback = nothing

0 commit comments

Comments
 (0)