Skip to content

Commit 93c1cbb

Browse files
committed
Use float_type_with_fallback for logjacs and logpdfs
1 parent fabdc44 commit 93c1cbb

File tree

2 files changed

+24
-14
lines changed

2 files changed

+24
-14
lines changed

src/distribution_wrappers.jl

+16-8
Original file line numberDiff line numberDiff line change
@@ -54,30 +54,38 @@ function Distributions.rand!(
5454
) where {N}
5555
return Distributions.rand!(rng, d.dist, x)
5656
end
57-
Distributions.logpdf(::NoDist{<:Univariate}, x::Real) = zero(eltype(x))
58-
Distributions.logpdf(::NoDist{<:Multivariate}, x::AbstractVector{<:Real}) = zero(eltype(x))
57+
function Distributions.logpdf(::NoDist{<:Univariate}, x::Real)
58+
return zero(float_type_with_fallback(eltype(x)))
59+
end
60+
function Distributions.logpdf(::NoDist{<:Multivariate}, x::AbstractVector{<:Real})
61+
return zero(float_type_with_fallback(eltype(x)))
62+
end
5963
function Distributions.logpdf(::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real})
60-
return zeros(eltype(x), size(x, 2))
64+
return zeros(float_type_with_fallback(eltype(x)), size(x, 2))
65+
end
66+
function Distributions.logpdf(::NoDist{<:Matrixvariate}, x::AbstractMatrix{<:Real})
67+
return zero(float_type_with_fallback(eltype(x)))
6168
end
62-
Distributions.logpdf(::NoDist{<:Matrixvariate}, x::AbstractMatrix{<:Real}) = zero(eltype(x))
6369
Distributions.minimum(d::NoDist) = minimum(d.dist)
6470
Distributions.maximum(d::NoDist) = maximum(d.dist)
6571

66-
Bijectors.logpdf_with_trans(::NoDist{<:Univariate}, x::Real, ::Bool) = zero(eltype(x))
72+
function Bijectors.logpdf_with_trans(::NoDist{<:Univariate}, x::Real, ::Bool)
73+
return zero(float_type_with_fallback(eltype(x)))
74+
end
6775
function Bijectors.logpdf_with_trans(
6876
::NoDist{<:Multivariate}, x::AbstractVector{<:Real}, ::Bool
6977
)
70-
return zero(eltype(x))
78+
return zero(float_type_with_fallback(eltype(x)))
7179
end
7280
function Bijectors.logpdf_with_trans(
7381
::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real}, ::Bool
7482
)
75-
return zeros(eltype(x), size(x, 2))
83+
return zeros(float_type_with_fallback(eltype(x)), size(x, 2))
7684
end
7785
function Bijectors.logpdf_with_trans(
7886
::NoDist{<:Matrixvariate}, x::AbstractMatrix{<:Real}, ::Bool
7987
)
80-
return zero(eltype(x))
88+
return zero(float_type_with_fallback(eltype(x)))
8189
end
8290

8391
Bijectors.bijector(d::NoDist) = Bijectors.bijector(d.dist)

src/utils.jl

+8-6
Original file line numberDiff line numberDiff line change
@@ -253,15 +253,15 @@ function (f::UnwrapSingletonTransform)(x)
253253
end
254254

255255
function Bijectors.with_logabsdet_jacobian(f::UnwrapSingletonTransform, x)
256-
return f(x), zero(eltype(x))
256+
return f(x), zero(float_type_with_fallback(eltype(x)))
257257
end
258258

259259
function Bijectors.with_logabsdet_jacobian(
260260
inv_f::Bijectors.Inverse{<:UnwrapSingletonTransform}, x
261261
)
262262
f = inv_f.orig
263263
result = reshape([x], f.input_size)
264-
return result, zero(eltype(x))
264+
return result, zero(float_type_with_fallback(eltype(x)))
265265
end
266266

267267
"""
@@ -310,22 +310,24 @@ function (inv_f::Bijectors.Inverse{<:ReshapeTransform})(x)
310310
return inverse(x)
311311
end
312312

313-
Bijectors.with_logabsdet_jacobian(f::ReshapeTransform, x) = (f(x), zero(eltype(x)))
313+
function Bijectors.with_logabsdet_jacobian(f::ReshapeTransform, x)
314+
return f(x), zero(float_type_with_fallback(eltype(x)))
315+
end
314316

315317
function Bijectors.with_logabsdet_jacobian(inv_f::Bijectors.Inverse{<:ReshapeTransform}, x)
316-
return inv_f(x), zero(eltype(x))
318+
return inv_f(x), zero(float_type_with_fallback(eltype(x)))
317319
end
318320

319321
struct ToChol <: Bijectors.Bijector
320322
uplo::Char
321323
end
322324

323325
function Bijectors.with_logabsdet_jacobian(f::ToChol, x)
324-
return Cholesky(Matrix(x), f.uplo, 0), zero(eltype(x))
326+
return Cholesky(Matrix(x), f.uplo, 0), zero(float_type_with_fallback(eltype(x)))
325327
end
326328

327329
function Bijectors.with_logabsdet_jacobian(::Bijectors.Inverse{<:ToChol}, y::Cholesky)
328-
return y.UL, zero(eltype(y))
330+
return y.UL, zero(float_type_with_fallback(eltype(y)))
329331
end
330332

331333
function Bijectors.with_logabsdet_jacobian(::Bijectors.Inverse{<:ToChol}, y)

0 commit comments

Comments
 (0)