From 4401070cca31f85b2f7931e9eeb2331e7fe9e239 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 29 Apr 2025 13:12:40 +0200 Subject: [PATCH 01/12] Load norm and rmul! from LinearAlgebra --- src/TransformVariables.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/TransformVariables.jl b/src/TransformVariables.jl index f43b3f1..4ea38a7 100644 --- a/src/TransformVariables.jl +++ b/src/TransformVariables.jl @@ -4,7 +4,7 @@ using ArgCheck: @argcheck using DocStringExtensions: FUNCTIONNAME, SIGNATURES, TYPEDEF import ForwardDiff using LogExpFunctions -using LinearAlgebra: UpperTriangular, logabsdet +using LinearAlgebra: UpperTriangular, logabsdet, norm, rmul! using Random: AbstractRNG, GLOBAL_RNG using StaticArrays: MMatrix, SMatrix, SArray, SVector, pushfirst using CompositionsBase From 423bc8cd1e507b31ca418e9bd002446ea680379b Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 29 Apr 2025 13:13:12 +0200 Subject: [PATCH 02/12] Make dimension of transform equal to output size --- src/special_arrays.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/special_arrays.jl b/src/special_arrays.jl index c94e4c7..ecf1189 100644 --- a/src/special_arrays.jl +++ b/src/special_arrays.jl @@ -73,7 +73,7 @@ struct UnitVector <: VectorTransform end end -dimension(t::UnitVector) = t.n - 1 +dimension(t::UnitVector) = t.n function _summary_rows(transformation::UnitVector, mime) _summary_row(transformation, "$(transformation.n) element unit vector transformation") From 629f4a2c2ffd43a2c6eda640e5fdb3cf3fa1a6fc Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 29 Apr 2025 13:13:48 +0200 Subject: [PATCH 03/12] Update unit vector transform to use normalization --- src/special_arrays.jl | 32 +++++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/src/special_arrays.jl b/src/special_arrays.jl index ecf1189..9b1deb9 100644 --- a/src/special_arrays.jl +++ b/src/special_arrays.jl @@ -82,19 +82,33 @@ end function transform_with(flag::LogJacFlag, t::UnitVector, x::AbstractVector, index) (; n) = t T = robust_eltype(x) - log_r = zero(T) y = Vector{T}(undef, n) - ℓ = logjac_zero(flag, T) - @inbounds for i in 1:(n - 1) - xi = x[index] - index += 1 - y[i], log_r, ℓi = l2_remainder_transform(flag, xi, log_r) - ℓ += ℓi - end - y[end] = exp(log_r / 2) + z = view(x, index:index+n-1) + r = norm(z) + copyto!(y, z) + __normalize!(y, r) + ℓ = flag isa NoLogJac ? flag : -r^2 / 2 + index += n y, ℓ, index end +# Adapted from LinearAlgebra.__normalize! +# MIT license +# Copyright (c) 2018-2024 LinearAlgebra.jl contributors: https://github.com/JuliaLang/LinearAlgebra.jl/contributors +@inline function __normalize!(a::AbstractArray, nrm) + # The largest positive floating point number whose inverse is less than infinity + δ = inv(prevfloat(typemax(nrm))) + if nrm ≥ δ # Safe to multiply with inverse + invnrm = inv(nrm) + rmul!(a, invnrm) + else # scale elements to avoid overflow + εδ = eps(one(nrm))/δ + rmul!(a, εδ) + rmul!(a, inv(nrm*εδ)) + end + return a +end + inverse_eltype(t::UnitVector, y::AbstractVector) = robust_eltype(y) function inverse_at!(x::AbstractVector, index, t::UnitVector, y::AbstractVector) From 54b3882e0cfd5340c49e2b39a52fbc6e5af2d70b Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 29 Apr 2025 13:14:18 +0200 Subject: [PATCH 04/12] Update inverse transform to be identity --- src/special_arrays.jl | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/special_arrays.jl b/src/special_arrays.jl index 9b1deb9..725e55a 100644 --- a/src/special_arrays.jl +++ b/src/special_arrays.jl @@ -114,12 +114,10 @@ inverse_eltype(t::UnitVector, y::AbstractVector) = robust_eltype(y) function inverse_at!(x::AbstractVector, index, t::UnitVector, y::AbstractVector) (; n) = t @argcheck length(y) == n - log_r = zero(eltype(y)) - @inbounds for yi in axes(y, 1)[1:(end-1)] - x[index], log_r = l2_remainder_inverse(y[yi], log_r) - index += 1 - end - index + z = view(x, index:index+n-1) + copyto!(z, y) + index += n + return index end From 0ddf68de40fdf3a7f2cdba0a0faaac2f764595f0 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 29 Apr 2025 13:14:43 +0200 Subject: [PATCH 05/12] Require unit vectors be length >2 --- src/special_arrays.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/special_arrays.jl b/src/special_arrays.jl index 725e55a..a774df3 100644 --- a/src/special_arrays.jl +++ b/src/special_arrays.jl @@ -68,7 +68,7 @@ Euclidean norm. struct UnitVector <: VectorTransform n::Int function UnitVector(n::Int) - @argcheck n ≥ 1 "Dimension should be positive." + @argcheck n ≥ 2 "Dimension should be at least 2." new(n) end end From 41ac480d62054563a728a876542267f3c9fb8cc6 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 29 Apr 2025 13:14:53 +0200 Subject: [PATCH 06/12] Update unit vector docstring --- src/special_arrays.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/special_arrays.jl b/src/special_arrays.jl index a774df3..3610fd7 100644 --- a/src/special_arrays.jl +++ b/src/special_arrays.jl @@ -62,8 +62,13 @@ end """ UnitVector(n) -Transform `n-1` real numbers to a unit vector of length `n`, under the +Transform `n` real numbers to a unit vector of length `n`, under the Euclidean norm. + +!!! note + This transform is non-bijective and is undefined at the origin. + If maximizing a target distribution whose density is constant for the unit vector, + then the maximizer is at the origin, and behavior is undefined. """ struct UnitVector <: VectorTransform n::Int From aae6c5128b296ccf20f5092abc5b5eaac92f098c Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 29 Apr 2025 13:15:16 +0200 Subject: [PATCH 07/12] Update UnitVector tests --- test/runtests.jl | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index dd20c07..0a2590a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -201,13 +201,31 @@ end end @testset "consistency checks" begin - for K in 1:10 + for K in 2:11 t = UnitVector(K) - @test dimension(t) == K - 1 - if K > 1 - test_transformation(t, y -> sum(abs2, y) ≈ 1, - vec_y = y -> y[1:(end-1)]) - end + @test dimension(t) == K + test_transformation(t, y -> sum(abs2, y) ≈ 1, test_inverse=false, jac=false) + + # because transform is non-bijective, we need to manually test inverse and jac + x = normalize(randn(K)) # if already normalized, inverse is the identity + y = transform(t, x) + @test inverse(t, y) ≈ y ≈ x + + # test "jacobian", here lj is the sum of the log jacobian of the transform and + # the log-density of the prior on the discarded parameter (the norm of the vector) + x = randn(K) + r = norm(x) + _, lj = transform_and_logjac(t, x) + J = ForwardDiff.jacobian(x -> vcat(normalize(x), norm(x)), x) + # log of generalized Jacobian determinant: + # - https://encyclopediaofmath.org/wiki/Jacobian#Generalizations_of_the_Jacobian_determinant + # - https://en.wikipedia.org/wiki/Area_formula_(geometric_measure_theory) + lj_transform = logdet(J' * J) / 2 + # lp_prior + # un-normalized Chi distribution prior on r + lp_prior = (K - 1) * log(r) - r^2 / 2 + lj_manual = lj_transform + lp_prior + @test lj ≈ lj_manual end end end From 6cd99b8f8a6f8e0a14b91f7b60577eafb205ad87 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 29 Apr 2025 13:15:42 +0200 Subject: [PATCH 08/12] Update UnitVector dimension tests --- test/runtests.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 0a2590a..ec0919f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -194,7 +194,9 @@ end @testset "to unit vector" begin @testset "dimension checks" begin - U = UnitVector(3) + @test_throws ArgumentError UnitVector(0) + @test_throws ArgumentError UnitVector(1) + U = UnitVector(2) x = zeros(3) # incorrect @test_throws ArgumentError transform(U, x) @test_throws ArgumentError transform_and_logjac(U, x) From 2ab58bb180f2ebc63027e9339e9613042419fc8f Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 29 Apr 2025 13:17:20 +0200 Subject: [PATCH 09/12] Test that inverse is a "right inverse" --- test/runtests.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index ec0919f..b51541f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -378,7 +378,8 @@ end x = randn(dimension(tn)) y = @inferred transform(tn, x) @test y isa NamedTuple{(:a,:b,:c)} - @test inverse(tn, y) ≈ x + x′ = inverse(tn, y) + @test inverse(tn, transform(tn, x′)) ≈ x′ index = 0 ljacc = 0.0 for (i, t) in enumerate((t1, t2, t3)) @@ -416,11 +417,13 @@ end for _ in 1:10 N = rand(3:7) tt = as((a = as(Tuple(as(Vector, asℝ₊, 2) for _ in 1:N)), - b = as(Tuple(UnitVector(n) for n in 1:N)))) + b = as(Tuple(UnitVector(n) for n in 2:N)))) x = randn(dimension(tt)) y = transform(tt, x) x′ = inverse(tt, y) - @test x ≈ x′ + m = sum(2:N) + @test x[1:end-m] ≈ x′[1:end-m] + @test inverse(tt, transform(tt, x′)) ≈ x′ end end From 340deb9aab8a86fb86498c213be8c296ea41fcff Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 29 Apr 2025 13:17:35 +0200 Subject: [PATCH 10/12] Avoid test at singularity of UnitVector --- test/runtests.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index b51541f..45626c6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -529,7 +529,7 @@ end -(abs2(μ) + abs2(σ) + abs2(β) + α + δ[1] + δ[2]) end P = TransformedLogDensities.TransformedLogDensity(t, f) - x = zeros(dimension(t)) + x = randn(dimension(t)) v = logdensity(P, x) g = ForwardDiff.gradient(x -> logdensity(P, x), x) @@ -642,7 +642,7 @@ end t = UnitVector(3) d = dimension(t) - x = [zeros(d), zeros(d)] + x = [randn(d), randn(d)] @test transform.(t, x) == map(x -> transform(t, x), x) end From a7c1e73819a0f0cbf47a7dd3073a7d5e8eed8848 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 29 Apr 2025 13:17:44 +0200 Subject: [PATCH 11/12] Update show test --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 45626c6..b7d3acf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -756,7 +756,7 @@ end [98:98] 1 → asℝ [108:110] 2 → SMatrix{3,3} correlation cholesky factor [120:121] 3 → 3 element unit simplex transformation - [131:133] 4 → 4 element unit vector transformation""" + [131:134] 4 → 4 element unit vector transformation""" repr(MIME("text/plain"), t) == repr_t end From b9f36edd3345719ae0c3bc21ac9d48dab4ec6176 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 29 Apr 2025 15:11:12 +0200 Subject: [PATCH 12/12] Apply suggestions from code review Co-authored-by: David Widmann --- src/special_arrays.jl | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/special_arrays.jl b/src/special_arrays.jl index 3610fd7..ac7b224 100644 --- a/src/special_arrays.jl +++ b/src/special_arrays.jl @@ -88,9 +88,8 @@ function transform_with(flag::LogJacFlag, t::UnitVector, x::AbstractVector, inde (; n) = t T = robust_eltype(x) y = Vector{T}(undef, n) - z = view(x, index:index+n-1) - r = norm(z) - copyto!(y, z) + copyto!(y, 1, x, index, n) + r = norm(y) __normalize!(y, r) ℓ = flag isa NoLogJac ? flag : -r^2 / 2 index += n @@ -119,8 +118,7 @@ inverse_eltype(t::UnitVector, y::AbstractVector) = robust_eltype(y) function inverse_at!(x::AbstractVector, index, t::UnitVector, y::AbstractVector) (; n) = t @argcheck length(y) == n - z = view(x, index:index+n-1) - copyto!(z, y) + copyto!(x, index, y) index += n return index end