
Truncated logabsdetjac forward / inverse inconsistency #375

@penelopeysm

import Bijectors as B
using Distributions: Normal, truncated

dist = truncated(Normal(); lower=6.515552440303498)
fwd = B.bijector(dist)
ivs = B.inverse(fwd)
y = -24.79099078521386

Inverse transform (changed in #325):

julia> B.logabsdetjac(ivs, y)
-24.79099078521386

Implemented as:

function logabsdetjac(ib::Inverse{<:TruncatedBijector}, y)
a, b = ib.orig.lb, ib.orig.ub
return sum(truncated_inv_logabsdetjac.(y, a, b))
end
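In the lower-bounded case used here, the inverse map is presumably x = a + exp(y), so log|dx/dy| = y and the inverse log-Jacobian can be returned analytically without ever forming x. This matches the REPL output above, where the result equals `y` exactly. The sketch below is a hypothetical stand-alone re-implementation of only that branch; `inv_lower` and `inv_lower_logabsdetjac` are illustrative names, not Bijectors API.

```julia
# Hypothetical stand-alone sketch of the lower-bounded branch only
# (not Bijectors' actual code, which also handles upper and two-sided bounds).

# Assumed inverse map: unconstrained y -> constrained x in (a, Inf)
inv_lower(y, a) = a + exp(y)

# dx/dy = exp(y) > 0, so log|dx/dy| = y; no subtraction involving x is needed
inv_lower_logabsdetjac(y, a) = y

a = 6.515552440303498
y = -24.79099078521386

x = inv_lower(y, a)
println("x = ", x)
println("inverse logabsdetjac = ", inv_lower_logabsdetjac(y, a))
```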

Forward transform (not changed, now out of sync):

julia> B.logabsdetjac(fwd, ivs(y))
24.791002373799344

Implemented as:

function logabsdetjac(b::TruncatedBijector, x)
a, b = b.lb, b.ub
return sum(truncated_logabsdetjac.(_clamp.(x, a, b), a, b))
end
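The forward direction, by contrast, has to go through x - a. Assuming the lower-bounded forward map is y = log(x - a) (so log|dy/dx| = -log(x - a)), a plausible explanation for the mismatch is floating-point cancellation: exp(y) ≈ 1.7e-11 is tiny compared with a ≈ 6.5, so once x = a + exp(y) is rounded to Float64, x - a is only resolved to a relative precision of roughly eps(a) / exp(y), about 5e-5 here. The hypothetical sketch below (same illustrative names, lower-bounded branch only) measures the round-trip mismatch, which should be exactly zero in exact arithmetic, for a few values of y.

```julia
# Hypothetical stand-alone sketch of the lower-bounded branch only.
inv_lower(y, a) = a + exp(y)                # assumed inverse: y -> x in (a, Inf)
inv_lower_logabsdetjac(y, a) = y            # analytic: log(exp(y)) = y
fwd_lower_logabsdetjac(x, a) = -log(x - a)  # assumed forward: y = log(x - a)

a = 6.515552440303498

# Forward logabsdetjac at x = inverse(y) should be exactly -(inverse logabsdetjac),
# so this sum is the round-trip inconsistency.
function roundtrip_error(y, a)
    x = inv_lower(y, a)
    return fwd_lower_logabsdetjac(x, a) + inv_lower_logabsdetjac(y, a)
end

for y in (0.0, -5.0, -15.0, -24.79099078521386)
    # eps(a) / exp(y) approximates the relative precision left in x - a
    println("y = ", y, "  mismatch = ", roundtrip_error(y, a),
            "  eps(a)/exp(y) = ", eps(a) / exp(y))
end
```

The mismatch grows as y becomes more negative, because x - a shrinks relative to the spacing of Float64 values near a.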

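The two results should be exact negatives of each other, since the log-absolute-determinant of the Jacobian of an inverse map is minus that of the forward map; instead they differ by about 1.2e-5. This is causing CI failures on TuringLang/DynamicPPL.jl#853; the example above was minimised from the failing test there, hence the very specific values.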
