
gradient broken for (*)(::Diagonal{Real}, ::Matrix{Complex}, ::Diagonal{Real}) when updating Julia 1.8 -> 1.9 #1483

@kylebeggs


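gradient fails when differentiating the triple product of a Diagonal{<:Real}, a Matrix{<:Complex}, and a Diagonal{<:Real}. This worked on Julia 1.8 and broke after updating to Julia 1.9. The stack trace below was produced on Julia 1.10.0-rc3.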

MWE:

```julia
using LinearAlgebra
using Zygote

D = Diagonal(rand(3))
Ac = rand(ComplexF64, 3, 3)
Ar = rand(Float64, 3, 3)

f_real(x) = abs(sum(Diagonal(x) * Ar * D))
f_complex(x) = abs(sum(Diagonal(x) * Ac * D))

g_real = gradient(f_real, rand(3))       # works
g_complex = gradient(f_complex, rand(3)) # breaks, error message below
```
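To check that a fix returns the right numbers, it helps to have a reference gradient that does not depend on AD. Writing c = Ac * diag(D), the function is f(x) = |s| with s = sum_i x[i] * c[i], so for real x the gradient is real(conj(s) * c[i]) / |s| (valid whenever s != 0). The sketch below is not part of the original report; `grad_f_complex` and `fd_grad` are hypothetical helper names, and only LinearAlgebra is needed. It compares the hand-derived formula with central finite differences.

```julia
using LinearAlgebra

D  = Diagonal(rand(3))
Ac = rand(ComplexF64, 3, 3)
f_complex(x) = abs(sum(Diagonal(x) * Ac * D))

# Hand-derived gradient of f_complex for real x:
# c = Ac * diag(D), s = sum(x .* c), f(x) = |s|,
# df/dx[i] = real(conj(s) * c[i]) / |s|   (requires s != 0).
function grad_f_complex(x::AbstractVector{<:Real})
    c = Ac * D.diag
    s = sum(x .* c)
    return real.(conj(s) .* c) ./ abs(s)
end

# Central finite differences, used as an AD-free cross-check.
function fd_grad(f, x::AbstractVector{<:Real}; h = 1e-6)
    g = similar(x, Float64)
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = h
        g[i] = (f(x .+ e) - f(x .- e)) / (2h)
    end
    return g
end

x = rand(3)
@show grad_f_complex(x)
@show fd_grad(f_complex, x)
```

Once the bug is fixed, `gradient(f_complex, x)[1]` would be expected to agree with `grad_f_complex(x)` up to floating-point error.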

Error message:

ERROR: MethodError: no method matching _mul_partials(::ForwardDiff.Partials{3, Float64}, ::ForwardDiff.Partials{6, Float64}, ::Float64, ::Float64)

Closest candidates are:
  _mul_partials(::ForwardDiff.Partials{N, A}, ::ForwardDiff.Partials{0, B}, ::Any, ::Any) where {N, A, B}
   @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/partials.jl:142
  _mul_partials(::ForwardDiff.Partials{0, A}, ::ForwardDiff.Partials{N, B}, ::Any, ::Any) where {N, A, B}
   @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/partials.jl:141
  _mul_partials(::ForwardDiff.Partials{N}, ::ForwardDiff.Partials{N}, ::Any, ::Any) where N
   @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/partials.jl:118
  ...

Stacktrace:
  [1] dual_definition_retval(::Val{…}, val::Float64, deriv1::Float64, partial1::ForwardDiff.Partials{…}, deriv2::Float64, partial2::ForwardDiff.Partials{…})
    @ ForwardDiff ~/.julia/packages/ForwardDiff/PcZ48/src/dual.jl:203
  [2] *
    @ ~/.julia/packages/ForwardDiff/PcZ48/src/dual.jl:271 [inlined]
  [3] *(x::ForwardDiff.Dual{Nothing, Float64, 3}, z::Complex{ForwardDiff.Dual{Nothing, Float64, 6}})
    @ Base ./complex.jl:339
  [4] *
    @ ./operators.jl:587 [inlined]
  [5] (::Zygote.var"#1388#1389"{typeof(*)})(::Float64, ::ComplexF64, ::Float64)
    @ Zygote ~/.julia/packages/Zygote/WOy6z/src/lib/broadcast.jl:276
  [6] _broadcast_getindex_evalf
    @ ./broadcast.jl:709 [inlined]
  [7] _broadcast_getindex
    @ ./broadcast.jl:682 [inlined]
  [8] getindex
    @ ./broadcast.jl:636 [inlined]
  [9] copy
    @ ./broadcast.jl:942 [inlined]
 [10] materialize
    @ ./broadcast.jl:903 [inlined]
 [11] broadcast_forward
    @ ~/.julia/packages/Zygote/WOy6z/src/lib/broadcast.jl:282 [inlined]
 [12] _broadcast_generic
    @ ~/.julia/packages/Zygote/WOy6z/src/lib/broadcast.jl:212 [inlined]
 [13] adjoint
    @ ~/.julia/packages/Zygote/WOy6z/src/lib/broadcast.jl:169 [inlined]
 [14] _pullback
    @ ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:66 [inlined]
 [15] adjoint
    @ ~/.julia/packages/Zygote/WOy6z/src/lib/broadcast.jl:245 [inlined]
 [16] _pullback
    @ ~/.julia/packages/ZygoteRules/4nXuu/src/adjoint.jl:66 [inlined]
 [17] *
    @ ~/.julia/juliaup/julia-1.10.0-rc3+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/diagonal.jl:409 [inlined]
 [18] _pullback(::Zygote.Context{…}, ::typeof(*), ::Diagonal{…}, ::Matrix{…}, ::Diagonal{…})
    @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface2.jl:0
 [19] f_complex
    @ ~/.julia/dev/SMOL/dev/update-julia.jl:9 [inlined]
 [20] _pullback(ctx::Zygote.Context{false}, f::typeof(f_complex), args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface2.jl:0
 [21] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface.jl:44
 [22] pullback
    @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface.jl:42 [inlined]
 [23] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/WOy6z/src/compiler/interface.jl:96
 [24] top-level scope
    @ ~/.julia/dev/SMOL/dev/update-julia.jl:12
Some type information was truncated. Use `show(err)` to see complete types.
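Reading the trace: frames [17] and [18] show the three-argument product Diagonal * Matrix * Diagonal dispatching to a dedicated method in LinearAlgebra's diagonal.jl. Frames [5] to [16] then show Zygote differentiating an elementwise broadcast of `*` over (Float64, ComplexF64, Float64) arguments through `broadcast_forward`. Frame [3] is where a Dual with 3 partials meets a Complex of Duals with 6 partials. Mathematically, that product scales row i by x[i] and column j by d[j]. The sketch below uses only LinearAlgebra to check this equivalence numerically. The broadcast expression illustrates the math and is not necessarily the exact code in diagonal.jl.

```julia
using LinearAlgebra

x  = rand(3)
d  = rand(3)
Ac = rand(ComplexF64, 3, 3)

# Three-argument product, as in f_complex.
lhs = Diagonal(x) * Ac * Diagonal(d)

# The same matrix as one fused elementwise broadcast:
# entry (i, j) is x[i] * Ac[i, j] * d[j].
rhs = x .* Ac .* transpose(d)

@show lhs ≈ rhs
```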

Versions:

julia> versioninfo()
Julia Version 1.10.0-rc3
Commit ed79752b939 (2023-12-18 09:57 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 20 × 12th Gen Intel(R) Core(TM) i7-12700H
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, alderlake)
  Threads: 11 on 20 virtual cores
Environment:
  LD_PRELOAD = /usr/lib/x86_64-linux-gnu/libstdc++.so.6
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = 8

Package versions:

(jl_w76chw) pkg> st
Status `/tmp/jl_w76chw/Project.toml`
  [e88e6eb3] Zygote v0.6.68
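The stack trace also passes through ForwardDiff and ZygoteRules. Their versions do not appear in the project status above because they are indirect dependencies of Zygote. A minimal sketch for including them in a report uses Pkg's manifest mode. In the Pkg REPL, `st -m` is the equivalent.

```julia
using Pkg

# Project-level status, as in the report: lists only direct dependencies.
Pkg.status()

# Manifest-level status: also lists indirect dependencies
# (in an environment with Zygote, this includes ForwardDiff and ZygoteRules).
Pkg.status(; mode = Pkg.PKGMODE_MANIFEST)
```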

Labels: ChainRules (adjoint -> rrule, and further integration), bug (Something isn't working)
