Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.15.7"
version = "0.15.8"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand All @@ -14,6 +14,7 @@ IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
Expand All @@ -37,8 +38,8 @@ BijectorsEnzymeCoreExt = "EnzymeCore"
BijectorsForwardDiffExt = "ForwardDiff"
BijectorsLazyArraysExt = "LazyArrays"
BijectorsMooncakeExt = "Mooncake"
BijectorsReverseDiffExt = "ReverseDiff"
BijectorsReverseDiffChainRulesExt = ["ChainRules", "ReverseDiff"]
BijectorsReverseDiffExt = "ReverseDiff"
BijectorsTrackerExt = "Tracker"
BijectorsZygoteExt = "Zygote"

Expand All @@ -59,6 +60,7 @@ LazyArrays = "2"
LogExpFunctions = "0.3.3"
MappedArrays = "0.2.2, 0.3, 0.4"
Mooncake = "0.4.95"
PDMats = "0.11.35"
Reexport = "0.2, 1"
ReverseDiff = "1"
Roots = "1.3.15, 2"
Expand Down
4 changes: 1 addition & 3 deletions src/bijectors/corr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,7 @@ function logabsdetjac(b::VecCorrBijector, x)
end

function with_logabsdet_jacobian(::Inverse{VecCorrBijector}, y)
U_logJ = _inv_link_chol_lkj(y)
# workaround for `Tracker.TrackedTuple` not supporting iteration
U, logJ = U_logJ[1], U_logJ[2]
U, logJ = _inv_link_chol_lkj(y)
K = size(U, 1)
for j in 2:(K - 1)
logJ += (K - j) * log(U[j, j])
Expand Down
14 changes: 6 additions & 8 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using PDMats: PDMat

# `permutedims` seems to work better with AD (cf. KernelFunctions.jl)
aT_b(a::AbstractVector{<:Real}, b::AbstractMatrix{<:Real}) = permutedims(a) * b
# `permutedims` can't be used here since scalar output is desired
Expand All @@ -11,14 +13,8 @@ _vec(x::Real) = x
lower_triangular(A::AbstractMatrix) = convert(typeof(A), LowerTriangular(A))
upper_triangular(A::AbstractMatrix) = convert(typeof(A), UpperTriangular(A))

function pd_from_lower(X)
L = lower_triangular(X)
return L * L'
end
function pd_from_upper(X)
U = upper_triangular(X)
return U' * U
end
pd_from_lower(X) = PDMat(Cholesky(LowerTriangular(X)))
pd_from_upper(X) = PDMat(Cholesky(UpperTriangular(X)))

# HACK: Allows us to define custom chain rules while we wait for upstream fixes.
transpose_eager(X::AbstractMatrix) = permutedims(X)
Expand All @@ -35,6 +31,7 @@ rather than `LowerTriangular`.
that returns a `Matrix` rather than `LowerTriangular`.
"""
cholesky_lower(X::AbstractMatrix) = lower_triangular(parent(cholesky(Hermitian(X, :L)).L))
cholesky_lower(X::PDMat) = X.chol.L
cholesky_lower(X::Cholesky) = X.L

"""
Expand All @@ -48,6 +45,7 @@ rather than `UpperTriangular`.
that returns a `Matrix` rather than `UpperTriangular`.
"""
cholesky_upper(X::AbstractMatrix) = upper_triangular(parent(cholesky(Hermitian(X)).U))
cholesky_upper(X::PDMat) = X.chol.U
cholesky_upper(X::Cholesky) = X.U

"""
Expand Down
15 changes: 15 additions & 0 deletions test/bijectors/corr.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Bijectors, DistributionsAD, LinearAlgebra, Test
using Bijectors: VecCorrBijector, VecCholeskyBijector, CorrBijector
using Random: Xoshiro

@testset "CorrBijector & VecCorrBijector" begin
for d in [1, 2, 5]
Expand Down Expand Up @@ -43,6 +44,20 @@ using Bijectors: VecCorrBijector, VecCholeskyBijector, CorrBijector
@test size(dist_unconstrained) == size(x)
@test dist_unconstrained isa MatrixDistribution
end

@testset "Pathological samples for invlink" begin
# see https://github.com/TuringLang/Bijectors.jl/issues/387
d = LKJ(3, 3.0)
for i in 1:100
rng = Xoshiro(i)
y = randn(rng, 3) * 15
f_inv = inverse(bijector(d))
x = f_inv(y)
@test logpdf(d, x) isa Float64 # used to crash.
x, _ = with_logabsdet_jacobian(f_inv, y)
@test logpdf(d, x) isa Float64
end
end
end

@testset "VecCholeskyBijector" begin
Expand Down
Loading