Skip to content

Commit d8f4a9e

Browse files
authored
Merge pull request #137 from tpapp/tp/inverse_eltype_cleanup
Refactor `inverse_eltype` calculations to use types.
2 parents 07c1125 + 0e248f1 commit d8f4a9e

File tree

8 files changed

+145
-64
lines changed

8 files changed

+145
-64
lines changed

src/aggregation.jl

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,10 @@ function _domain_label(transformation::ViewTransformation, index::Int)
141141
_array_domain_label(asℝ, dims, index)
142142
end
143143

144-
inverse_eltype(transformation::ViewTransformation, y) = eltype(y)
144+
function inverse_eltype(transformation::ViewTransformation,
145+
::Type{T}) where T <: AbstractArray
146+
_ensure_float(eltype(T))
147+
end
145148

146149
function inverse_at!(x::AbstractVector, index, transformation::ViewTransformation,
147150
y::AbstractArray)
@@ -210,8 +213,9 @@ function transform_with(flag::LogJacFlag, transformation::StaticArrayTransformat
210213
end
211214

212215
function inverse_eltype(transformation::Union{ArrayTransformation,StaticArrayTransformation},
213-
x::AbstractArray)
214-
inverse_eltype(transformation.inner_transformation, first(x)) # FIXME shortcut
216+
::Type{T}) where T <: AbstractArray
217+
inverse_eltype(transformation.inner_transformation,
218+
_ensure_float(eltype(T)))
215219
end
216220

217221
function inverse_at!(x::AbstractVector, index,
@@ -341,12 +345,17 @@ internally.
341345
342346
*Performs no argument validation, caller should do this.*
343347
"""
344-
_inverse_eltype_tuple(ts::NTransforms, ys::Tuple) =
345-
reduce(promote_type, map(inverse_eltype, ts, ys))
346-
# NOTE: See https://github.com/tpapp/TransformVariables.jl/pull/80
347-
# `map` and `reduce` both have specializations on `Tuple`s that make them type stable
348-
# even when the `Tuple` is heterogenous, but that is not currently the case with
349-
# `mapreduce`, therefore separate `reduce` and `map` are preferred as a workaround.
348+
function _inverse_eltype_tuple(ts::NTransforms{N}, ::Type{T}) where {N,T<:Tuple}
349+
@argcheck T <: NTuple{N,Any} "Incompatible input length."
350+
__inverse_eltype_tuple(ts, T)
351+
end
352+
function __inverse_eltype_tuple(ts::NTransforms, ::Type{Tuple{}})
353+
Union{}
354+
end
355+
function __inverse_eltype_tuple(ts::NTransforms, ::Type{T}) where {T<:Tuple}
356+
promote_type(inverse_eltype(Base.first(ts), fieldtype(T, 1)),
357+
__inverse_eltype_tuple(Base.tail(ts), Tuple{Base.tail(fieldtypes(T))...}))
358+
end
350359

351360
"""
352361
$(SIGNATURES)
@@ -366,10 +375,9 @@ function transform_with(flag::LogJacFlag, tt::TransformTuple{<:Tuple}, x, index)
366375
transform_tuple(flag, tt.transformations, x, index)
367376
end
368377

369-
function inverse_eltype(tt::TransformTuple{<:Tuple}, y::Tuple)
378+
function inverse_eltype(tt::TransformTuple{<:Tuple}, ::Type{T}) where T <: Tuple
370379
(; transformations) = tt
371-
@argcheck length(transformations) == length(y)
372-
_inverse_eltype_tuple(transformations, y)
380+
_inverse_eltype_tuple(transformations, T)
373381
end
374382

375383
function inverse_at!(x::AbstractVector, index, tt::TransformTuple{<:Tuple}, y::Tuple)
@@ -378,19 +386,19 @@ function inverse_at!(x::AbstractVector, index, tt::TransformTuple{<:Tuple}, y::T
378386
_inverse!_tuple(x, index, tt.transformations, y)
379387
end
380388

381-
as(transformations::NamedTuple{N,<:NTransforms}) where N =
389+
function as(transformations::NamedTuple{N,<:NTransforms}) where N
382390
TransformTuple(transformations)
391+
end
383392

384393
function transform_with(flag::LogJacFlag, tt::TransformTuple{<:NamedTuple}, x, index)
385394
(; transformations) = tt
386395
y, ℓ, index′ = transform_tuple(flag, values(transformations), x, index)
387396
NamedTuple{keys(transformations)}(y), ℓ, index′
388397
end
389398

390-
function inverse_eltype(tt::TransformTuple{<:NamedTuple}, y::NamedTuple)
399+
function inverse_eltype(tt::TransformTuple{<:NamedTuple}, ::Type{NamedTuple{N,T}}) where {N,T}
391400
(; transformations) = tt
392-
@argcheck _same_set_of_names(transformations, y)
393-
_inverse_eltype_tuple(values(transformations), values(NamedTuple{keys(transformations)}(y)))
401+
_inverse_eltype_tuple(values(transformations), T)
394402
end
395403

396404
function inverse_at!(x::AbstractVector, index, tt::TransformTuple{<:NamedTuple}, y::NamedTuple)

src/constant.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ function transform_with(logjac_flag::LogJacFlag, t::Constant, x::AbstractVector,
1919
t.value, logjac_zero(logjac_flag, eltype(x)), index
2020
end
2121

22-
inverse_eltype(t::Constant, _) = Union{}
22+
inverse_eltype(t::Constant, ::Type) = Union{}
2323

2424
function inverse_at!(x::AbstractVector, index, t::Constant, y)
2525
@argcheck t.value == y

src/generic.jl

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -181,16 +181,31 @@ end
181181
inverse(f::CallableInverse) = Base.Fix1(transform, f.x)
182182

183183
"""
184-
$(FUNCTIONNAME)(t::AbstractTransform, y)
184+
```
185+
$(FUNCTIONNAME)(t::AbstractTransform, y)
186+
$(FUNCTIONNAME)(t::AbstractTransform, ::Type{T})
187+
```
185188
186-
The element type for vector `x` so that `inverse!(x, t, y)` works.
189+
The element type for vector `x` so that `inverse!(x, t, y::T)` works.
187190
188-
!!! note
189-
It is not guaranteed that the result is the narrowest possible type, and may change
190-
without warning between versions. Some effort is made to come up with a reasonable
191-
concrete type even in corner cases.
191+
# Notes
192+
193+
1. It is not guaranteed that the result is the narrowest possible type, and may change
194+
without warning between versions. Some effort is made to come up with a reasonable
195+
concrete type even in corner cases.
196+
197+
2. Transformations should provide a method for *types*, not values.
198+
199+
3. No dimension or input compatibility checks are guaranteed to be performed, even for
200+
values.
192201
"""
193-
function inverse_eltype end
202+
function inverse_eltype(t::AbstractTransform, y::T) where T
203+
inverse_eltype(t, T)
204+
end
205+
206+
function inverse_eltype(t::AbstractTransform, T::Type)
207+
throw(MethodError(inverse_eltype, (t, T)))
208+
end
194209

195210
"""
196211
$(SIGNATURES)
@@ -283,15 +298,8 @@ end
283298

284299
# We want to avoid vectors with non-numerical element types
285300
# Ref https://github.com/tpapp/TransformVariables.jl/issues/132
286-
function inverse(t::VectorTransform, y)
287-
inverse!(Vector{_float_or_Float64(inverse_eltype(t, y))}(undef, dimension(t)), t, y)
288-
end
289-
function _float_or_Float64(::Type{T}) where T
290-
if T !== Union{} && T <: Number # heuristic: it is assumed that every `Number` type defines `float`
291-
return float(T)
292-
else
293-
return Float64
294-
end
301+
function inverse(t::VectorTransform, y::T) where T
302+
inverse!(Vector{_ensure_float(inverse_eltype(t, T))}(undef, dimension(t)), t, y)
295303
end
296304

297305
"""

src/scalar.jl

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@ function inverse_at!(x::AbstractVector, index::Int, t::ScalarTransform, y::Real)
3030
index + 1
3131
end
3232

33-
function inverse_eltype(t::ScalarTransform, y::Real)
33+
function inverse_eltype(t::ScalarTransform, ::Type{T}) where T <: Real
3434
# NOTE this is a shortcut to get sensible types for all subtypes of ScalarTransform, which
3535
# we test for. If it breaks it should be extended accordingly.
36-
return Base.promote_typejoin_union(Base.promote_op(inverse, typeof(t), typeof(y)))
36+
return Base.promote_typejoin_union(Base.promote_op(inverse, typeof(t), T))
3737
end
3838

3939
_domain_label(::ScalarTransform, index::Int) = ()
@@ -66,43 +66,50 @@ $(TYPEDEF)
6666
6767
Exponential transformation `x ↦ eˣ`. Maps from all reals to the positive reals.
6868
"""
69-
struct TVExp <: ScalarTransform
70-
end
69+
struct TVExp <: ScalarTransform end
70+
7171
transform(::TVExp, x::Real) = exp(x)
72+
7273
transform_and_logjac(t::TVExp, x::Real) = transform(t, x), x
7374

7475
function inverse(::TVExp, x::Number)
7576
log(x)
7677
end
78+
7779
inverse_and_logjac(t::TVExp, x::Number) = inverse(t, x), -log(x)
7880

7981
"""
8082
$(TYPEDEF)
8183
8284
Logistic transformation `x ↦ logit(x)`. Maps from all reals to (0, 1).
8385
"""
84-
struct TVLogistic <: ScalarTransform
85-
end
86+
struct TVLogistic <: ScalarTransform end
87+
8688
transform(::TVLogistic, x::Real) = logistic(x)
89+
8790
transform_and_logjac(t::TVLogistic, x::Real) = transform(t, x), logistic_logjac(x)
8891

8992
function inverse(::TVLogistic, x::Number)
9093
logit(x)
9194
end
95+
9296
inverse_and_logjac(t::TVLogistic, x::Number) = inverse(t, x), logit_logjac(x)
9397

9498
"""
9599
$(TYPEDEF)
96100
97-
Shift transformation `x ↦ x + shift`.
101+
Shift transformation `x ↦ x + shift`.
98102
"""
99103
struct TVShift{T <: Real} <: ScalarTransform
100104
shift::T
101105
end
106+
102107
transform(t::TVShift, x::Real) = x + t.shift
108+
103109
transform_and_logjac(t::TVShift, x::Real) = transform(t, x), logjac_zero(LogJac(), typeof(x))
104110

105111
inverse(t::TVShift, x::Number) = x - t.shift
112+
106113
inverse_and_logjac(t::TVShift, x::Number) = inverse(t, x), logjac_zero(LogJac(), typeof(x))
107114

108115
"""
@@ -117,12 +124,15 @@ struct TVScale{T} <: ScalarTransform
117124
new(scale)
118125
end
119126
end
127+
120128
TVScale(scale::T) where {T} = TVScale{T}(scale)
121129

122130
transform(t::TVScale, x::Real) = t.scale * x
123-
transform_and_logjac(t::TVScale{<:Real}, x::Real) = transform(t, x), log(t.scale)
131+
132+
transform_and_logjac(t::TVScale{<:Real}, x::Real) = transform(t, x), log(t.scale)
124133

125134
inverse(t::TVScale, x::Number) = x / t.scale
135+
126136
inverse_and_logjac(t::TVScale{<:Real}, x::Number) = inverse(t, x), -log(t.scale)
127137

128138
"""
@@ -155,15 +165,15 @@ struct CompositeScalarTransform{Ts <: Tuple} <: ScalarTransform
155165
end
156166

157167
transform(t::CompositeScalarTransform, x) = foldr(transform, t.transforms, init=x)
158-
function transform_and_logjac(ts::CompositeScalarTransform, x)
168+
function transform_and_logjac(ts::CompositeScalarTransform, x)
159169
foldr(ts.transforms, init=(x, logjac_zero(LogJac(), typeof(x)))) do t, (x, logjac)
160170
nx, nlogjac = transform_and_logjac(t, x)
161171
(nx, logjac + nlogjac)
162172
end
163173
end
164174

165175
inverse(ts::CompositeScalarTransform, x) = foldl((y, t) -> inverse(t, y), ts.transforms, init=x)
166-
function inverse_and_logjac(ts::CompositeScalarTransform, x)
176+
function inverse_and_logjac(ts::CompositeScalarTransform, x)
167177
foldl(ts.transforms, init=(x, logjac_zero(LogJac(), typeof(x)))) do (x, logjac), t
168178
nx, nlogjac = inverse_and_logjac(t, x)
169179
(nx, logjac + nlogjac)
@@ -283,7 +293,7 @@ function Base.show(io::IO, ct::CompositeScalarTransform{Tuple{TVShift{T}, TVNeg,
283293
print(io, "as(Real, -∞, ", ct.transforms[1].shift, ")")
284294
end
285295
function Base.show(io::IO, ct::CompositeScalarTransform{Tuple{TVShift{T1}, TVScale{T2}, TVLogistic}}) where {T1, T2}
286-
print(io, "as(Real, ", ct.transforms[1].shift, ", ", ct.transforms[1].shift +
296+
print(io, "as(Real, ", ct.transforms[1].shift, ", ", ct.transforms[1].shift +
287297
ct.transforms[2].scale, ")")
288298
end
289299

src/special_arrays.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,10 @@ function transform_with(flag::LogJacFlag, t::UnitVector, x::AbstractVector, inde
9595
y, ℓ, index
9696
end
9797

98-
inverse_eltype(t::UnitVector, y::AbstractVector) = robust_eltype(y)
98+
function inverse_eltype(t::UnitVector,
99+
::Type{T}) where T <: AbstractVector
100+
_ensure_float(eltype(T))
101+
end
99102

100103
function inverse_at!(x::AbstractVector, index, t::UnitVector, y::AbstractVector)
101104
(; n) = t
@@ -157,7 +160,10 @@ function transform_with(flag::LogJacFlag, t::UnitSimplex, x::AbstractVector, ind
157160
y, ℓ, index
158161
end
159162

160-
inverse_eltype(t::UnitSimplex, y::AbstractVector) = robust_eltype(y)
163+
function inverse_eltype(t::UnitSimplex,
164+
::Type{T}) where T <: AbstractVector
165+
_ensure_float(eltype(T))
166+
end
161167

162168
function inverse_at!(x::AbstractVector, index, t::UnitSimplex, y::AbstractVector)
163169
(; n) = t
@@ -297,7 +303,10 @@ function transform_with(flag::LogJacFlag, transformation::StaticCorrCholeskyFact
297303
UpperTriangular(SMatrix{S,S}(U)), ℓ, index′
298304
end
299305

300-
inverse_eltype(t::Union{CorrCholeskyFactor,StaticCorrCholeskyFactor}, U::UpperTriangular) = robust_eltype(U)
306+
function inverse_eltype(t::Union{CorrCholeskyFactor,StaticCorrCholeskyFactor},
307+
::Type{T}) where {T<:UpperTriangular}
308+
_ensure_float(eltype(T))
309+
end
301310

302311
function inverse_at!(x::AbstractVector, index,
303312
t::Union{CorrCholeskyFactor,StaticCorrCholeskyFactor}, U::UpperTriangular)

src/utilities.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,31 @@ function robust_eltype(::Type{S}) where S
3737
end
3838

3939
robust_eltype(x::T) where T = robust_eltype(T)
40+
41+
"""
42+
$(SIGNATURES)
43+
44+
Regularize input type, preferring a floating point, falling back to `Float64`.
45+
46+
Internal, not exported.
47+
48+
# Motivation
49+
50+
Type calculations occasionally give types that are too narrow (eg `Union{}` for empty
51+
vectors) or broad. Since this package is primarily intended for *numerical*
52+
calculations, we fall back to something sensible. This function implements the
53+
heuristics for this, and is currently used in inverse element type calculations.
54+
"""
55+
function _ensure_float(::Type{T}) where T
56+
if T <: Number # heuristic: it is assumed that every `Number` type defines `float`
57+
return float(T)
58+
else
59+
return Float64
60+
end
61+
end
62+
63+
# pass through containers
64+
_ensure_float(::Type{T}) where {T<:AbstractArray} = T
65+
66+
# special case Union{}
67+
_ensure_float(::Type{Union{}}) = Float64

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
44
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
55
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
66
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
7+
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
78
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
89
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
910
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"

0 commit comments

Comments
 (0)