
LKJ bijector numerical stability issues / DomainError with -1.0 #387

@penelopeysm


As reported in TuringLang/Turing.jl#2095, sampling a model with an LKJ prior often throws errors:

```julia
using Turing

# Minimal model with a single 3x3 LKJ-distributed correlation matrix
@model function f()
    a ~ LKJ(3, 3.0)
end

sample(f(), NUTS(), 2000)
```
```
ERROR: DomainError with -1.0:
log was called with a negative real argument but will only return a complex result if called with a complex argument. Try log(Complex(x)).
DomainError detected in the user `f` function. This occurs when the domain of a function is violated.
For example, `log(-1.0)` is undefined because `log` of a real number is defined to only output real
numbers, but `log` of a negative number is complex valued and therefore Julia throws a DomainError
by default. Cases to be aware of include:

* `log(x)`, `sqrt(x)`, `cbrt(x)`, etc. where `x<0`
* `x^y` for `x<0` floating point `y` (example: `(-1.0)^(1/2) == im`)

Within the context of SciML, this error can occur within the solver process even if the domain constraint
would not be violated in the solution due to adaptivity. For example, an ODE solver or optimization
routine may check a step at `new_u` which violates the domain constraint, and if violated reject the
step and use a smaller `dt`. However, the throwing of this error will have halted the solving process.

Thus the recommended fix is to replace this function with the equivalent ones from NaNMath.jl
(https://github.com/JuliaMath/NaNMath.jl) which returns a NaN instead of an error. The solver will then
effectively use the NaN within the error control routines to reject the out of bounds step. Additionally,
one could perform a domain transformation on the variables so that such an issue does not occur in the
definition of `f`.

For more information, check out the following FAQ page:
https://docs.sciml.ai/Optimization/stable/API/FAQ/#The-Solver-Seems-to-Violate-Constraints-During-the-Optimization,-Causing-DomainErrors,-What-Can-I-Do-About-That?

Stacktrace:
  [1] throw_complex_domainerror(f::Symbol, x::Float64)
    @ Base.Math ./math.jl:33
  [2] _log
    @ ./special/log.jl:295 [inlined]
  [3] log(x::Float64)
    @ Base.Math ./special/log.jl:261
  [4] log
    @ ~/.julia/packages/ForwardDiff/UBbGT/src/dual.jl:244 [inlined]
  [5] logdet
    @ ~/.julia/juliaup/julia-1.11.5+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/LinearAlgebra/src/generic.jl:1719 [inlined]
  [6] logkernel(d::LKJ{Float64, Int64}, R::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{…}, Float64, 3}})
    @ Distributions ~/.julia/packages/Distributions/YQSrn/src/matrix/lkj.jl:117
  [7] _logpdf
    @ ~/.julia/packages/Distributions/YQSrn/src/matrixvariates.jl:82 [inlined]
  [8] logpdf
    @ ~/.julia/packages/Distributions/YQSrn/src/common.jl:269 [inlined]
  [9] invlink_with_logpdf(vi::DynamicPPL.VarInfo{…}, vn::AbstractPPL.VarName{…}, dist::LKJ{…}, y::Vector{…})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/vlXDM/src/abstract_varinfo.jl:723
 [10] invlink_with_logpdf
    @ ~/.julia/packages/DynamicPPL/vlXDM/src/abstract_varinfo.jl:718 [inlined]
[...]
```
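Frames [3] to [5] of the stack trace show where the exception comes from. For a generic dense `Matrix`, `LinearAlgebra.logdet` first calls `logabsdet` (which returns the log of the absolute determinant together with its sign, computed through an LU factorization) and then adds `log(sign)`. As a result, any matrix whose computed determinant is negative ends up in `log(-1.0)`. The following is a minimal illustration; the matrix `A` is a toy example with determinant -1, not an LKJ draw.

```julia
using LinearAlgebra

# Toy matrix with determinant -1 (the identity with its rows swapped); not an LKJ sample.
A = [0.0 1.0; 1.0 0.0]

# logabsdet returns (log|det A|, sign of det A) and does not throw for real input.
logabs, s = logabsdet(A)
println("log|det A| = ", logabs, ", sign = ", s)

# logdet(A) is log|det A| + log(sign), so a negative sign ends up calling log(-1.0).
err = try
    logdet(A)
    nothing
catch e
    e
end
println("logdet(A) threw: ", typeof(err))
```

The exact determinant of a correlation matrix is never negative. In the LKJ case, a negative sign therefore presumably comes from rounding error in the LU factorization of a nearly singular matrix.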

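This is largely due to a numerical stability issue in the LKJ bijector. Mapping a large unconstrained vector back through the inverse bijector yields a matrix on which `logpdf` fails: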

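```julia
using Bijectors

d = LKJ(3, 3.0)
x = rand(d)        # a valid 3x3 correlation matrix
f = bijector(d)    # maps correlation matrices to unconstrained space
y = f(x)

# Replace the unconstrained representation with large random values
y_new = randn(length(y)) * 10
f_inv = inverse(f)
x_new, logjac = with_logabsdet_jacobian(f_inv, y_new)
logpdf(d, x_new) # usually errors
```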
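The snippet above multiplies the unconstrained vector by 10, which pushes many of its entries far into the tails. The sketch below illustrates why that matters. It assumes that the inverse bijector maps each unconstrained entry to a partial correlation through `tanh` and assembles a Cholesky factor from those values, as in the usual LKJ parameterization. The helper `hypothetical_corr_from_unconstrained` is purely illustrative and is not the Bijectors.jl implementation. In `Float64`, `tanh(y)` rounds to exactly ±1 once |y| exceeds roughly 19. At that point, diagonal entries of the factor collapse to zero and the reconstructed correlation matrix becomes singular. Entries merely close to ±1 already make it nearly singular. Computing the log-determinant from the factor's non-negative diagonal cannot produce a negative sign, unlike calling `logdet` on the full matrix.

```julia
using LinearAlgebra

# HYPOTHETICAL helper (not the Bijectors.jl code): build a 3x3 correlation matrix
# from 3 unconstrained reals via tanh-transformed canonical partial correlations.
# Returns the lower Cholesky factor L and the matrix R = L * L'.
function hypothetical_corr_from_unconstrained(y::AbstractVector{<:Real})
    z12, z13, z23 = tanh.(y)
    L = zeros(3, 3)
    L[1, 1] = 1.0
    L[2, 1] = z12
    L[2, 2] = sqrt(1 - z12^2)
    L[3, 1] = z13
    L[3, 2] = z23 * sqrt(1 - z13^2)
    L[3, 3] = sqrt(1 - z13^2) * sqrt(1 - z23^2)
    return L, L * L'
end

# log det(R) = 2 * sum(log.(diag(L))); diag(L) is non-negative by construction,
# so this never calls log on a negative number (it can return -Inf instead).
logdet_from_factor(L) = 2 * sum(log, diag(L))

y_moderate = [0.5, -0.3, 0.8]
y_extreme  = [30.0, -25.0, 40.0]   # |y| well above ~19, so tanh rounds to exactly ±1

L1, R1 = hypothetical_corr_from_unconstrained(y_moderate)
L2, R2 = hypothetical_corr_from_unconstrained(y_extreme)

println(tanh(30.0) == 1.0)                 # saturation in Float64
println(logdet_from_factor(L1), " vs ", logdet(R1))
println(logdet_from_factor(L2))            # zero diagonal entries in the factor
println(isposdef(R2))                      # reconstructed matrix is singular
```

Working with the Cholesky factor directly, rather than rebuilding `R` and calling `logdet`, sidesteps the sign problem in this toy setting. This is only an illustration of the failure mode, not a statement of how Bijectors.jl or Distributions.jl should be changed.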
