Skip to content
2 changes: 1 addition & 1 deletion src/TransformVariables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using ArgCheck: @argcheck
using DocStringExtensions: FUNCTIONNAME, SIGNATURES, TYPEDEF
import ForwardDiff
using LogExpFunctions
using LinearAlgebra: UpperTriangular, logabsdet
using LinearAlgebra: UpperTriangular, logabsdet, norm, rmul!
using Random: AbstractRNG, GLOBAL_RNG
using StaticArrays: MMatrix, SMatrix, SArray, SVector, pushfirst
using CompositionsBase
Expand Down
53 changes: 35 additions & 18 deletions src/special_arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,18 +62,23 @@ end
"""
UnitVector(n)

Transform `n-1` real numbers to a unit vector of length `n`, under the
Transform `n` real numbers to a unit vector of length `n`, under the
Euclidean norm.

!!! note
This transform is non-bijective and is undefined at the origin.
If maximizing a target distribution whose density is constant for the unit vector,
then the maximizer is at the origin, and behavior is undefined.
"""
struct UnitVector <: VectorTransform
n::Int
function UnitVector(n::Int)
@argcheck n ≥ 1 "Dimension should be positive."
@argcheck n ≥ 2 "Dimension should be at least 2."
new(n)
end
end

dimension(t::UnitVector) = t.n - 1
dimension(t::UnitVector) = t.n

function _summary_rows(transformation::UnitVector, mime)
_summary_row(transformation, "$(transformation.n) element unit vector transformation")
Expand All @@ -82,30 +87,42 @@ end
function transform_with(flag::LogJacFlag, t::UnitVector, x::AbstractVector, index)
(; n) = t
T = robust_eltype(x)
log_r = zero(T)
y = Vector{T}(undef, n)
ℓ = logjac_zero(flag, T)
@inbounds for i in 1:(n - 1)
xi = x[index]
index += 1
y[i], log_r, ℓi = l2_remainder_transform(flag, xi, log_r)
ℓ += ℓi
end
y[end] = exp(log_r / 2)
z = view(x, index:index+n-1)
r = norm(z)
copyto!(y, z)
__normalize!(y, r)
ℓ = flag isa NoLogJac ? flag : -r^2 / 2
index += n
y, ℓ, index
end

# Adapted from LinearAlgebra.__normalize!
# MIT license
# Copyright (c) 2018-2024 LinearAlgebra.jl contributors: https://github.com/JuliaLang/LinearAlgebra.jl/contributors
@inline function __normalize!(a::AbstractArray, nrm)
# The largest positive floating point number whose inverse is less than infinity
δ = inv(prevfloat(typemax(nrm)))
if nrm ≥ δ # Safe to multiply with inverse
invnrm = inv(nrm)
rmul!(a, invnrm)
else # scale elements to avoid overflow
εδ = eps(one(nrm))/δ
rmul!(a, εδ)
rmul!(a, inv(nrm*εδ))
end
return a
end

inverse_eltype(t::UnitVector, y::AbstractVector) = robust_eltype(y)

function inverse_at!(x::AbstractVector, index, t::UnitVector, y::AbstractVector)
(; n) = t
@argcheck length(y) == n
log_r = zero(eltype(y))
@inbounds for yi in axes(y, 1)[1:(end-1)]
x[index], log_r = l2_remainder_inverse(y[yi], log_r)
index += 1
end
index
z = view(x, index:index+n-1)
copyto!(z, y)
index += n
return index
end


Expand Down
49 changes: 36 additions & 13 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,20 +194,40 @@ end

@testset "to unit vector" begin
@testset "dimension checks" begin
U = UnitVector(3)
@test_throws ArgumentError UnitVector(0)
@test_throws ArgumentError UnitVector(1)
U = UnitVector(2)
x = zeros(3) # incorrect
@test_throws ArgumentError transform(U, x)
@test_throws ArgumentError transform_and_logjac(U, x)
end

@testset "consistency checks" begin
for K in 1:10
for K in 2:11
t = UnitVector(K)
@test dimension(t) == K - 1
if K > 1
test_transformation(t, y -> sum(abs2, y) ≈ 1,
vec_y = y -> y[1:(end-1)])
end
@test dimension(t) == K
test_transformation(t, y -> sum(abs2, y) ≈ 1, test_inverse=false, jac=false)

# because transform is non-bijective, we need to manually test inverse and jac
x = normalize(randn(K)) # if already normalized, inverse is the identity
y = transform(t, x)
@test inverse(t, y) ≈ y ≈ x

# test "jacobian", here lj is the sum of the log jacobian of the transform and
# the log-density of the prior on the discarded parameter (the norm of the vector)
x = randn(K)
r = norm(x)
_, lj = transform_and_logjac(t, x)
J = ForwardDiff.jacobian(x -> vcat(normalize(x), norm(x)), x)
# log of generalized Jacobian determinant:
# - https://encyclopediaofmath.org/wiki/Jacobian#Generalizations_of_the_Jacobian_determinant
# - https://en.wikipedia.org/wiki/Area_formula_(geometric_measure_theory)
lj_transform = logdet(J' * J) / 2
# lp_prior
# un-normalized Chi distribution prior on r
lp_prior = (K - 1) * log(r) - r^2 / 2
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please explain why this correction is needed?

lj_manual = lj_transform + lp_prior
@test lj ≈ lj_manual
end
end
end
Expand Down Expand Up @@ -358,7 +378,8 @@ end
x = randn(dimension(tn))
y = @inferred transform(tn, x)
@test y isa NamedTuple{(:a,:b,:c)}
@test inverse(tn, y) ≈ x
x′ = inverse(tn, y)
@test inverse(tn, transform(tn, x′)) ≈ x′
index = 0
ljacc = 0.0
for (i, t) in enumerate((t1, t2, t3))
Expand Down Expand Up @@ -396,11 +417,13 @@ end
for _ in 1:10
N = rand(3:7)
tt = as((a = as(Tuple(as(Vector, asℝ₊, 2) for _ in 1:N)),
b = as(Tuple(UnitVector(n) for n in 1:N))))
b = as(Tuple(UnitVector(n) for n in 2:N))))
x = randn(dimension(tt))
y = transform(tt, x)
x′ = inverse(tt, y)
@test x ≈ x′
m = sum(2:N)
@test x[1:end-m] ≈ x′[1:end-m]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might as well remove these two lines, the one below is sufficient.

@test inverse(tt, transform(tt, x′)) ≈ x′
end
end

Expand Down Expand Up @@ -506,7 +529,7 @@ end
-(abs2(μ) + abs2(σ) + abs2(β) + α + δ[1] + δ[2])
end
P = TransformedLogDensities.TransformedLogDensity(t, f)
x = zeros(dimension(t))
x = randn(dimension(t))
v = logdensity(P, x)
g = ForwardDiff.gradient(x -> logdensity(P, x), x)

Expand Down Expand Up @@ -619,7 +642,7 @@ end

t = UnitVector(3)
d = dimension(t)
x = [zeros(d), zeros(d)]
x = [randn(d), randn(d)]
@test transform.(t, x) == map(x -> transform(t, x), x)
end

Expand Down Expand Up @@ -733,7 +756,7 @@ end
[98:98] 1 → asℝ
[108:110] 2 → SMatrix{3,3} correlation cholesky factor
[120:121] 3 → 3 element unit simplex transformation
[131:133] 4 → 4 element unit vector transformation"""
[131:134] 4 → 4 element unit vector transformation"""
repr(MIME("text/plain"), t) == repr_t
end

Expand Down