Skip to content

Introduce unit_vector_norm. #139

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
May 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.8.17"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
CompositionsBase = "a33af91c-f02d-484b-be07-31d278c5ca2b"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand All @@ -24,6 +25,7 @@ InverseFunctionsExt = "InverseFunctions"
[compat]
ArgCheck = "1, 2"
ChangesOfVariables = "0.1"
Compat = "4.10.0"
CompositionsBase = "0.1.2"
DocStringExtensions = "0.8, 0.9"
ForwardDiff = "0.10, 1"
Expand Down
11 changes: 11 additions & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,22 +41,32 @@ Further worked examples of using this package can be found in the [DynamicHMCExa

# General interface

## Transformations

```@docs
dimension
transform
transform_and_logjac
```

## Inverses

```@docs
inverse
inverse!
inverse_eltype
```

## Integration into Bayesian inference

```@docs
transform_logdensity
TransformVariables.logprior
TransformVariables.nonzero_logprior
```

## Miscellaneous

```@docs
domain_label
```
Expand Down Expand Up @@ -136,6 +146,7 @@ produces positive quantities with the dimension of length.
## Special arrays

```@docs
unit_vector_norm
UnitVector
UnitSimplex
CorrCholeskyFactor
Expand Down
3 changes: 2 additions & 1 deletion src/TransformVariables.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
module TransformVariables

using ArgCheck: @argcheck
import Compat
using DocStringExtensions: FUNCTIONNAME, SIGNATURES, TYPEDEF
import ForwardDiff
using LogExpFunctions
using LinearAlgebra: UpperTriangular, logabsdet
using LinearAlgebra: UpperTriangular, logabsdet, norm, rmul!
using Random: AbstractRNG, GLOBAL_RNG
using StaticArrays: MMatrix, SMatrix, SArray, SVector, pushfirst
using CompositionsBase
Expand Down
28 changes: 28 additions & 0 deletions src/generic.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
export dimension, transform, transform_and_logjac, transform_logdensity, inverse, inverse!,
inverse_eltype, as, domain_label

Compat.@compat public logprior, nonzero_logprior

###
### log absolute Jacobian determinant
###
Expand Down Expand Up @@ -109,6 +111,8 @@ The user interface consists of
- [`transform_and_logjac`](@ref)
- [`inverse`](@ref), [`inverse!`](@ref)
- [`inverse_eltype`](@ref).
- [`nonzero_logprior`](@ref).
- [`logprior`](@ref)
"""
abstract type AbstractTransform end

Expand Down Expand Up @@ -166,6 +170,30 @@ inverse(t)(y) == inverse(t, y) == inverse(transform(t))(y)
"""
inverse(t::AbstractTransform) = Base.Fix1(inverse, t)

"""
$(SIGNATURES)

Return the log prior correction used in [`transform_and_logjac`](@ref). The second
argument is the output of a transformation.

The log jacobian determinant is corrected by this value, usually for the purpose of
making a distribution proper. Can only be nonzero when [`nonzero_logprior`](@ref) is
true.
"""
logprior(t::AbstractTransform, y) = false # =0, to avoid unnecessary promotions

"""
$(SIGNATURES)

Return `true` only if there are potential inputs for which [`logprior`](@ref) is
nonzero.

!!! note
Currently the only transformation that has a log prior correction is
[`unit_vector_norm`](@ref).
"""
nonzero_logprior(t::AbstractTransform) = false

"""
$(TYPEDEF)

Expand Down
93 changes: 91 additions & 2 deletions src/special_arrays.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export UnitVector, UnitSimplex, CorrCholeskyFactor, corr_cholesky_factor
export UnitVector, unit_vector_norm, UnitSimplex, CorrCholeskyFactor, corr_cholesky_factor

####
#### building blocks
Expand Down Expand Up @@ -68,6 +68,7 @@ Euclidean norm.
struct UnitVector <: VectorTransform
n::Int
function UnitVector(n::Int)
Base.depwarn("UnitVector is deprecated. See `unit_vector_norm`.", :UnitVector)
@argcheck n ≥ 1 "Dimension should be positive."
new(n)
end
Expand Down Expand Up @@ -111,6 +112,95 @@ function inverse_at!(x::AbstractVector, index, t::UnitVector, y::AbstractVector)
index
end

####
#### unit_vector_norm
####

struct UnitVectorNorm <: VectorTransform
n::Int
chi_prior::Bool
function UnitVectorNorm(n::Int; chi_prior::Bool = true)
@argcheck n ≥ 2 "Dimension should be at least 2."
new(n, chi_prior)
end
end

"""
$(SIGNATURES)

Transform `n ≥ 2` real numbers to a unit vector of length `n` and a radius, under the
Euclidean norm. Returns the tuple `(normalized_vector, radius)`.

When `chi_prior = true`, a prior correction is applied to the radius, which only
affects the log Jacobian determinant. The purpose of this is to make the
distribution proper. If you wish to use another prior, set this to `false` and use
a manual correction, see also [`logprior`](@ref).

!!! note
At the origin, this transform is non-bijective and non-differentiable. If
maximizing a target distribution whose density is constant for the unit vector,
then the maximizer using the Chi prior is at the origin, and behavior is undefined.

!!! note
While ``n = 1`` would be technically possible, for practical purposes it would
likely suffer from numerical issues, since the transform is undefined at ``x = 0``,
and for a Markov chain to travel from ``y=[-1]`` to ``y=[1]``, it would have to leap
over the origin, which is only even possible due to discretization and likely will
often not work. Because of this, it is disallowed.
"""
unit_vector_norm(n::Int; chi_prior::Bool = true) = UnitVectorNorm(n; chi_prior)

nonzero_logprior(t::UnitVectorNorm) = t.chi_prior

function logprior(t::UnitVectorNorm, (y, r)::Tuple{AbstractVector,Real})
(; n, chi_prior) = t
if chi_prior
(t.n - 1) * log(r) - r^2 / 2
else
float(zero(r))
end
end

dimension(t::UnitVectorNorm) = t.n

function _summary_rows(t::UnitVectorNorm, mime)
_summary_row(t, "$(t.n) element (unit vector, norm) transformation")
end

function transform_with(flag::LogJacFlag, t::UnitVectorNorm, x::AbstractVector, index)
(; n, chi_prior) = t
T = robust_eltype(x)
log_r = zero(T)
y = Vector{T}(undef, n)
copyto!(y, 1, x, index, n)
r = norm(y, 2)
__normalize!(y, r)
ℓ = flag isa NoLogJac ? flag : (chi_prior ? -r^2 / 2 : -(t.n - 1) * log(r))
index += n
(y, r), ℓ, index
end

function inverse_eltype(t::UnitVectorNorm,
::Type{Tuple{V,T}}) where {V <: AbstractVector,T}
_ensure_float(eltype(T))
end

function inverse_at!(x::AbstractVector, index, t::UnitVectorNorm,
(y, r)::Tuple{AbstractVector,Real})
(; n) = t
@argcheck length(y) == n
@argcheck r ≥ 0
_x = @view x[index:(index + n - 1)]
if r == 0
fill!(_x, zero(eltype(x)))
else
copyto!(_x, y)
yN = norm(y, 2)
@argcheck isapprox(yN, 1; atol = √eps(r) * n) # somewhat generous tolerance
__normalize!(_x, yN / r)
end
index + n
end

####
#### UnitSimplex
Expand Down Expand Up @@ -220,7 +310,6 @@ function _summary_rows(transformation::CorrCholeskyFactor, mime)
_summary_row(transformation, "$(n)×$(n) correlation cholesky factor")
end


"""
$(SIGNATURES)

Expand Down
18 changes: 18 additions & 0 deletions src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,24 @@ Number of elements (strictly) above the diagonal in an ``n×n`` matrix.
"""
unit_triangular_dimension(n::Int) = n * (n-1) ÷ 2


# Adapted from LinearAlgebra.__normalize!
# MIT license
# Copyright (c) 2018-2024 LinearAlgebra.jl contributors: https://github.com/JuliaLang/LinearAlgebra.jl/contributors
@inline function __normalize!(a::AbstractArray, nrm)
# The largest positive floating point number whose inverse is less than infinity
δ = inv(prevfloat(typemax(nrm)))
if nrm ≥ δ # Safe to multiply with inverse
invnrm = inv(nrm)
rmul!(a, invnrm)
else # scale elements to avoid overflow
εδ = eps(one(nrm))/δ
rmul!(a, εδ)
rmul!(a, inv(nrm*εδ))
end
return a
end

###
### type calculations
###
Expand Down
Loading
Loading