Status: Open
Labels: ChainRules, bug
Description
`gradient` breaks when triple-multiplying a `Diagonal{<:Real}`, a `Matrix{<:Complex}`, and a `Diagonal{<:Real}`. This worked on Julia 1.8 and broke when moving to 1.9; the trace below was produced on 1.10.0-rc3.
MWE:
```julia
using LinearAlgebra
using Zygote

D = Diagonal(rand(3))          # real diagonal, Diagonal{Float64}
Ac = rand(ComplexF64, 3, 3)    # complex dense matrix
Ar = rand(Float64, 3, 3)       # real dense matrix, for comparison

f_real(x) = abs(sum(Diagonal(x) * Ar * D))
f_complex(x) = abs(sum(Diagonal(x) * Ac * D))

g_real = gradient(f_real, rand(3))       # works
g_complex = gradient(f_complex, rand(3)) # breaks, error message below
```
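For context, the three-factor product in the MWE is purely elementwise: `(Diagonal(d1) * A * Diagonal(d2))[i, j] == d1[i] * A[i, j] * d2[j]`, i.e. one broadcast of `*` over a real vector, a complex matrix and a real row. Frames [17] and [5] of the trace below (the latter receives `(::Float64, ::ComplexF64, ::Float64)`) suggest that on the Julia version used here `LinearAlgebra` evaluates the triple product in exactly that broadcast form, so Zygote differentiates it with its broadcast rule rather than as two matrix products. That this is the change behind the 1.8 -> 1.9 regression is a guess. The sketch below only checks the identity on hand-picked values; `d1`, `d2` and `A` are hypothetical, and this is not LinearAlgebra's actual source.

```julia
using LinearAlgebra

# Hypothetical inputs: real diagonals and a complex dense matrix.
d1 = [1.0, 2.0, 3.0]
d2 = [0.5, -1.0, 2.0]
A = ComplexF64[1+2im 3-1im 0+1im; 2+0im -1+1im 4-2im; 1-1im 0+0im 2+3im]

# Left-multiplying by Diagonal(d1) scales rows; right-multiplying by Diagonal(d2) scales columns.
prod3 = Diagonal(d1) * A * Diagonal(d2)

# The same result as a single three-argument broadcast (Float64, ComplexF64, Float64 arguments).
bcast = d1 .* A .* permutedims(d2)

println(prod3 ≈ bcast)
```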
Error message:
```
ERROR: MethodError: no method matching _mul_partials(::ForwardDiff.Partials{3, Float64}, ::ForwardDiff.Partials{6, Float64}, ::Float64, ::Float64)

Closest candidates are:
  _mul_partials(::ForwardDiff.Partials{N, A}, ::ForwardDiff.Partials{0, B}, ::Any, ::Any) where {N, A, B}
   @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/partials.jl:142
  _mul_partials(::ForwardDiff.Partials{0, A}, ::ForwardDiff.Partials{N, B}, ::Any, ::Any) where {N, A, B}
   @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/partials.jl:141
  _mul_partials(::ForwardDiff.Partials{N}, ::ForwardDiff.Partials{N}, ::Any, ::Any) where N
   @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/partials.jl:118
  ...

Stacktrace:
  [1] dual_definition_retval(::Val{…}, val::Float64, deriv1::Float64, partial1::ForwardDiff.Partials{…}, deriv2::Float64, partial2::ForwardDiff.Partials{…})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/dual.jl:203
  [2] *
    @ ~/.julia/packages/ForwardDiff/PcZ48/src/dual.jl:271 [inlined]
  [3] *(x::ForwardDiff.Dual{Nothing, Float64, 3}, z::Complex{ForwardDiff.Dual{Nothing, Float64, 6}})
    @ Base ./complex.jl:339
  [4] *
    @ ./operators.jl:587 [inlined]
  [5] (::Zygote.var"#1388#1389"{typeof(*)})(::Float64, ::ComplexF64, ::Float64)
    @ Zygote ~/.julia/packages/Zygote/WOy6z/src/lib/broadcast.jl:276
  [6] _broadcast_getindex_evalf
    @ ./broadcast.jl:709 [inlined]
  [7] _broadcast_getindex
    @ ./broadcast.jl:682 [inlined]
  [8] getindex
    @ ./broadcast.jl:636 [inlined]
  [9] copy
    @ ./broadcast.jl:942 [inlined]
 [10] materialize
    @ ./broadcast.jl:903 [inlined]
 [11] broadcast_forward
    @ ~/.julia/packages/Zygote/WOy6z/src/lib/broadcast.jl:282 [inlined]
 [12] _broadcast_generic
    @ ~/.julia/packages/Zygote/WOy6z/src/lib/broadcast.jl:212 [inlined]
 [13] adjoint
    @ ~/.julia/packages/Zygote/WOy6z/src/lib/broadcast.jl:169 [inlined]
 [14] _pullback
    @ ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:66 [inlined]
 [15] adjoint
    @ ~/.julia/packages/Zygote/WOy6z/src/lib/broadcast.jl:245 [inlined]
 [16] _pullback
    @ ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:66 [inlined]
 [17] *
    @ ~/.julia/juliaup/julia-1.10.0-rc3+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/diagonal.jl:409 [inlined]
 [18] _pullback(::Zygote.Context{…}, ::typeof(*), ::Diagonal{…}, ::Matrix{…}, ::Diagonal{…})
    @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface2.jl:0
 [19] f_complex
    @ ~/.julia/dev/SMOL/dev/update-julia.jl:9 [inlined]
 [20] _pullback(ctx::Zygote.Context{false}, f::typeof(f_complex), args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface2.jl:0
 [21] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface.jl:44
 [22] pullback
    @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface.jl:42 [inlined]
 [23] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface.jl:96
 [24] top-level scope
    @ ~/.julia/dev/SMOL/dev/update-julia.jl:12
Some type information was truncated. Use `show(err)` to see complete types.
```
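The innermost frames show a `Dual{Nothing, Float64, 3}` being multiplied by a `Complex{Dual{Nothing, Float64, 6}}`: the real broadcast arguments appear to be seeded with 3 partials (one per argument) while the complex argument carries 6 (presumably separate slots for its real and imaginary parts). `Base`'s `Real * Complex` method then multiplies `Dual`s of different widths, for which ForwardDiff has no `_mul_partials` method. The sketch below reproduces just that ForwardDiff-level mismatch by hand, without Zygote; the seeding pattern (which partial slot is set to 1) is an assumption read off the trace, not Zygote's actual code.

```julia
using ForwardDiff: Dual

# Real input with 3 partials, as for one real argument of a 3-argument broadcast.
x = Dual(1.0, 1.0, 0.0, 0.0)                          # Dual{Nothing, Float64, 3}

# Complex input whose real and imaginary parts each carry 6 partials (hypothetical seeding).
zr = Dual(2.0, ntuple(j -> j == 2 ? 1.0 : 0.0, 6))    # Dual{Nothing, Float64, 6}
zi = Dual(3.0, ntuple(j -> j == 5 ? 1.0 : 0.0, 6))
z = Complex(zr, zi)

# Real * Complex computes x * real(z) and x * imag(z); the partial counts differ (3 vs 6),
# so the Dual * Dual product fails, as in frames [1]-[3] of the trace above.
err = try
    x * z
    nothing
catch e
    e
end
println(typeof(err))
```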
Versions:
```
julia> versioninfo()
Julia Version 1.10.0-rc3
Commit ed79752b939 (2023-12-18 09:57 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 20 × 12th Gen Intel(R) Core(TM) i7-12700H
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, alderlake)
  Threads: 11 on 20 virtual cores
Environment:
  LD_PRELOAD = /usr/lib/x86_64-linux-gnu/libstdc++.so.6
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = 8
```
Package versions:
```
(jl_w76chw) pkg> st
Status `/tmp/jl_w76chw/Project.toml`
  [e88e6eb3] Zygote v0.6.68
```