diff --git a/src/TransformVariables.jl b/src/TransformVariables.jl
index f43b3f1..4ea38a7 100644
--- a/src/TransformVariables.jl
+++ b/src/TransformVariables.jl
@@ -4,7 +4,7 @@ using ArgCheck: @argcheck
 using DocStringExtensions: FUNCTIONNAME, SIGNATURES, TYPEDEF
 import ForwardDiff
 using LogExpFunctions
-using LinearAlgebra: UpperTriangular, logabsdet
+using LinearAlgebra: UpperTriangular, logabsdet, norm, rmul!
 using Random: AbstractRNG, GLOBAL_RNG
 using StaticArrays: MMatrix, SMatrix, SArray, SVector, pushfirst
 using CompositionsBase
diff --git a/src/special_arrays.jl b/src/special_arrays.jl
index c94e4c7..ac7b224 100644
--- a/src/special_arrays.jl
+++ b/src/special_arrays.jl
@@ -62,18 +62,23 @@ end
 """
     UnitVector(n)
 
-Transform `n-1` real numbers to a unit vector of length `n`, under the
+Transform `n` real numbers to a unit vector of length `n`, under the
 Euclidean norm.
+
+!!! note
+    This transform is non-bijective and is undefined at the origin.
+    If maximizing a target distribution whose density is constant for the unit vector,
+    then the maximizer is at the origin, and behavior is undefined.
 """
 struct UnitVector <: VectorTransform
     n::Int
     function UnitVector(n::Int)
-        @argcheck n ≥ 1 "Dimension should be positive."
+        @argcheck n ≥ 2 "Dimension should be at least 2."
         new(n)
     end
 end
 
-dimension(t::UnitVector) = t.n - 1
+dimension(t::UnitVector) = t.n
 
 function _summary_rows(transformation::UnitVector, mime)
     _summary_row(transformation, "$(transformation.n) element unit vector transformation")
@@ -82,30 +87,40 @@ end
 function transform_with(flag::LogJacFlag, t::UnitVector, x::AbstractVector, index)
     (; n) = t
     T = robust_eltype(x)
-    log_r = zero(T)
     y = Vector{T}(undef, n)
-    ℓ = logjac_zero(flag, T)
-    @inbounds for i in 1:(n - 1)
-        xi = x[index]
-        index += 1
-        y[i], log_r, ℓi = l2_remainder_transform(flag, xi, log_r)
-        ℓ += ℓi
-    end
-    y[end] = exp(log_r / 2)
+    copyto!(y, 1, x, index, n)
+    r = norm(y)
+    __normalize!(y, r)
+    ℓ = flag isa NoLogJac ? flag : -r^2 / 2
+    index += n
     y, ℓ, index
 end
 
+# Adapted from LinearAlgebra.__normalize!
+# MIT license
+# Copyright (c) 2018-2024 LinearAlgebra.jl contributors: https://github.com/JuliaLang/LinearAlgebra.jl/contributors
+@inline function __normalize!(a::AbstractArray, nrm)
+    # The largest positive floating point number whose inverse is less than infinity
+    δ = inv(prevfloat(typemax(nrm)))
+    if nrm ≥ δ # Safe to multiply with inverse
+        invnrm = inv(nrm)
+        rmul!(a, invnrm)
+    else # scale elements to avoid overflow
+        εδ = eps(one(nrm))/δ
+        rmul!(a, εδ)
+        rmul!(a, inv(nrm*εδ))
+    end
+    return a
+end
+
 inverse_eltype(t::UnitVector, y::AbstractVector) = robust_eltype(y)
 
 function inverse_at!(x::AbstractVector, index, t::UnitVector, y::AbstractVector)
     (; n) = t
     @argcheck length(y) == n
-    log_r = zero(eltype(y))
-    @inbounds for yi in axes(y, 1)[1:(end-1)]
-        x[index], log_r = l2_remainder_inverse(y[yi], log_r)
-        index += 1
-    end
-    index
+    copyto!(x, index, y)
+    index += n
+    return index
 end
 
 
diff --git a/test/runtests.jl b/test/runtests.jl
index dd20c07..b7d3acf 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -194,20 +194,40 @@ end
 
 @testset "to unit vector" begin
     @testset "dimension checks" begin
-        U = UnitVector(3)
+        @test_throws ArgumentError UnitVector(0)
+        @test_throws ArgumentError UnitVector(1)
+        U = UnitVector(2)
         x = zeros(3) # incorrect
         @test_throws ArgumentError transform(U, x)
         @test_throws ArgumentError transform_and_logjac(U, x)
     end
 
     @testset "consistency checks" begin
-        for K in 1:10
+        for K in 2:11
             t = UnitVector(K)
-            @test dimension(t) == K - 1
-            if K > 1
-                test_transformation(t, y -> sum(abs2, y) ≈ 1,
-                                    vec_y = y -> y[1:(end-1)])
-            end
+            @test dimension(t) == K
+            test_transformation(t, y -> sum(abs2, y) ≈ 1, test_inverse=false, jac=false)
+
+            # because transform is non-bijective, we need to manually test inverse and jac
+            x = normalize(randn(K)) # if already normalized, inverse is the identity
+            y = transform(t, x)
+            @test inverse(t, y) ≈ y ≈ x
+
+            # test "jacobian", here lj is the sum of the log jacobian of the transform and
+            # the log-density of the prior on the discarded parameter (the norm of the vector)
+            x = randn(K)
+            r = norm(x)
+            _, lj = transform_and_logjac(t, x)
+            J = ForwardDiff.jacobian(x -> vcat(normalize(x), norm(x)), x)
+            # log of generalized Jacobian determinant:
+            # - https://encyclopediaofmath.org/wiki/Jacobian#Generalizations_of_the_Jacobian_determinant
+            # - https://en.wikipedia.org/wiki/Area_formula_(geometric_measure_theory)
+            lj_transform = logdet(J' * J) / 2
+            # lp_prior
+            # un-normalized Chi distribution prior on r
+            lp_prior = (K - 1) * log(r) - r^2 / 2
+            lj_manual = lj_transform + lp_prior
+            @test lj ≈ lj_manual
         end
     end
 end
@@ -358,7 +378,8 @@ end
     x = randn(dimension(tn))
     y = @inferred transform(tn, x)
     @test y isa NamedTuple{(:a,:b,:c)}
-    @test inverse(tn, y) ≈ x
+    x′ = inverse(tn, y)
+    @test inverse(tn, transform(tn, x′)) ≈ x′
     index = 0
     ljacc = 0.0
     for (i, t) in enumerate((t1, t2, t3))
@@ -396,11 +417,13 @@ end
     for _ in 1:10
         N = rand(3:7)
         tt = as((a = as(Tuple(as(Vector, asℝ₊, 2) for _ in 1:N)),
-                 b = as(Tuple(UnitVector(n) for n in 1:N))))
+                 b = as(Tuple(UnitVector(n) for n in 2:N))))
         x = randn(dimension(tt))
         y = transform(tt, x)
         x′ = inverse(tt, y)
-        @test x ≈ x′
+        m = sum(2:N)
+        @test x[1:end-m] ≈ x′[1:end-m]
+        @test inverse(tt, transform(tt, x′)) ≈ x′
     end
 end
 
@@ -506,7 +529,7 @@ end
         -(abs2(μ) + abs2(σ) + abs2(β) + α + δ[1] + δ[2])
     end
     P = TransformedLogDensities.TransformedLogDensity(t, f)
-    x = zeros(dimension(t))
+    x = randn(dimension(t))
     v = logdensity(P, x)
     g = ForwardDiff.gradient(x -> logdensity(P, x), x)
 
@@ -619,7 +642,7 @@ end
 
     t = UnitVector(3)
     d = dimension(t)
-    x = [zeros(d), zeros(d)]
+    x = [randn(d), randn(d)]
     @test transform.(t, x) == map(x -> transform(t, x), x)
 end
 
@@ -733,7 +756,7 @@ end
     [98:98] 1 → asℝ
     [108:110] 2 → SMatrix{3,3} correlation cholesky factor
     [120:121] 3 → 3 element unit simplex transformation
-    [131:133] 4 → 4 element unit vector transformation"""
+    [131:134] 4 → 4 element unit vector transformation"""
     repr(MIME("text/plain"), t) == repr_t
 end
 