Skip to content
Closed
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[compat]
Flux = "0.9"
LogDensityProblems = "^0.9.0"
julia = "^1"

Expand Down
8 changes: 4 additions & 4 deletions src/aggregation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,18 +131,18 @@ Return a transformation that transforms consecutive groups of real numbers to a
julia> t = as((asℝ₊, UnitVector(3)));

julia> dimension(t)
3
4

julia> transform(t, zeros(dimension(t)))
(1.0, [0.0, 0.0, 1.0])
(1.0, [1.0, 0.0, 0.0])

julia> t2 = as((σ = asℝ₊, u = UnitVector(3)));

julia> dimension(t2)
3
4

julia> transform(t2, zeros(dimension(t2)))
(σ = 1.0, u = [0.0, 0.0, 1.0])
(σ = 1.0, u = [1.0, 0.0, 0.0])
```
"""
as(transformations::NTransforms) = TransformTuple(transformations)
Expand Down
29 changes: 12 additions & 17 deletions src/special_arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ Inverse of [`l2_remainder_transform`](@ref) in `x` and `y`.
"""
UnitVector(n)

Transform `n-1` real numbers to a unit vector of length `n`, under the
Transform `n` real numbers to a unit vector of length `n`, under the
Euclidean norm.
"""
@calltrans struct UnitVector <: VectorTransform
Expand All @@ -46,35 +46,30 @@ Euclidean norm.
end
end

dimension(t::UnitVector) = t.n - 1
dimension(t::UnitVector) = t.n

function transform_with(flag::LogJacFlag, t::UnitVector, x::AbstractVector, index)
@unpack n = t
T = extended_eltype(x)
r = one(T)
y = Vector{T}(undef, n)
index′ = index + n
vx = view(x, index:(index′ - 1))
nx² = sum(abs2, vx)
y = nx² > 0 ? vx ./ √nx² : [one(T); zeros(T, n - 1)]
ℓ = logjac_zero(flag, T)
@inbounds for i in 1:(n - 1)
xi = x[index]
index += 1
y[i], r, ℓi = l2_remainder_transform(flag, xi, r)
ℓ += ℓi
if !(flag isa NoLogJac)
ℓ -= nx² / 2
end
y[end] = √r
y, ℓ, index
y, ℓ, index′
end

inverse_eltype(t::UnitVector, y::AbstractVector) = extended_eltype(y)

function inverse_at!(x::AbstractVector, index, t::UnitVector, y::AbstractVector)
@unpack n = t
@argcheck length(y) == n
r = one(eltype(y))
@inbounds for yi in axes(y, 1)[1:(end-1)]
x[index], r = l2_remainder_inverse(y[yi], r)
index += 1
end
index
index′ = index + n
setindex!(x, y, index:(index′ - 1))
index′
end


Expand Down
27 changes: 18 additions & 9 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ end
@testset "to unit vector" begin
@testset "dimension checks" begin
U = UnitVector(3)
x = zeros(3) # incorrect
x = zeros(2) # incorrect
@test_throws ArgumentError U(x)
@test_throws ArgumentError transform(U, x)
@test_throws ArgumentError transform_and_logjac(U, x)
Expand All @@ -94,10 +94,18 @@ end
@testset "consistency checks" begin
for K in 1:10
t = UnitVector(K)
@test dimension(t) == K - 1
@test dimension(t) == K
if K > 1
test_transformation(t, y -> sum(abs2, y) ≈ 1,
vec_y = y -> y[1:(end-1)])
test_transformation(t, y -> sum(abs2, y) ≈ 1;
test_inverse = false, test_logjac = false)
x = randn(K)
y = transform(t, x)
x2 = @inferred inverse(t, y)
@test x2 ≈ y
ι = inverse(t)
@test y ≈ ι(y)
@test transform_and_logjac(t, x)[2] ≈ -sum(abs2, x) ./ 2
@test transform(t, zeros(K)) ≈ [1; zeros(K-1)]
end
end
end
Expand Down Expand Up @@ -242,7 +250,8 @@ end
x = randn(dimension(tn))
y = @inferred transform(tn, x)
@test y isa NamedTuple{(:a,:b,:c)}
@test inverse(tn, y) ≈ x
x2 = inverse(tn, y)
@test inverse(tn, transform(tn, x2)) ≈ x2
index = 0
ljacc = 0.0
for (i, t) in enumerate((t1, t2, t3))
Expand Down Expand Up @@ -284,7 +293,7 @@ end
x = randn(dimension(tt))
y = tt(x)
x′ = inverse(tt, y)
@test x ≈ x′
@test inverse(tt, transform(tt, x′)) ≈ x′
end
end

Expand Down Expand Up @@ -385,11 +394,11 @@ end
u = UnitVector(3), L = CorrCholeskyFactor(4),
δ = as((asℝ₋, as𝕀))))
function f(θ)
@unpack μ, σ, β, α, δ = θ
-(abs2(μ) + abs2(σ) + abs2(β) + α + δ[1] + δ[2])
@unpack μ, σ, β, α, u, δ = θ
-(abs2(μ) + abs2(σ) + abs2(β) + α + sum(u) + δ[1] + δ[2])
end
P = TransformedLogDensity(t, f)
x = zeros(dimension(t))
x = randn(dimension(t)) .* 1e-5
v = logdensity(P, x)
g = ForwardDiff.gradient(x -> logdensity(P, x), x)

Expand Down
13 changes: 8 additions & 5 deletions test/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ automatic differentiation.
`test_inverse` determines whether the inverse is tested.
"""
function test_transformation(t::AbstractTransform, is_valid_y;
vec_y = identity, N = 1000, test_inverse = true)
vec_y = identity, N = 1000,
test_inverse = true, test_logjac = true)
for _ in 1:N
x = t isa ScalarTransform ? randn() : randn(dimension(t))
if t isa ScalarTransform
Expand All @@ -37,10 +38,12 @@ function test_transformation(t::AbstractTransform, is_valid_y;
@test t(x) == y # callable
y2, lj = @inferred transform_and_logjac(t, x)
@test y2 == y
if t isa ScalarTransform
@test lj ≈ AD_logjac(t, x)
else
@test lj ≈ AD_logjac(t, x, vec_y)
if test_logjac
if t isa ScalarTransform
@test lj ≈ AD_logjac(t, x)
else
@test lj ≈ AD_logjac(t, x, vec_y)
end
end
if test_inverse
x2 = inverse(t, y)
Expand Down