From 2be1bbb55d1d00357b59465241d387eec9e1470c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 20 Mar 2025 23:31:27 +0000 Subject: [PATCH 1/4] Clean up varinfo get/set functions --- HISTORY.md | 11 +++++ Project.toml | 2 +- src/abstract_varinfo.jl | 10 +++- src/sampler.jl | 27 ++++++++--- src/varinfo.jl | 100 ++++++---------------------------------- test/model.jl | 6 +-- test/test_util.jl | 3 +- 7 files changed, 60 insertions(+), 99 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 67e5801a1..3ea8071f3 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,16 @@ # DynamicPPL Changelog +## 0.35.5 + +Several internal methods have been removed: + + - `DynamicPPL.getall(vi::AbstractVarInfo)` has been removed. You can directly replace this with `getindex_internal(vi, Colon())`. + - `DynamicPPL.setall!(vi::AbstractVarInfo, values)` has been removed. Rewrite the calling function to not assume mutation and use `unflatten(vi, values)` instead. + - `DynamicPPL.replace_values(md::Metadata, values)` and `DynamicPPL.replace_values(nt::NamedTuple, values)` (where the `nt` is a NamedTuple of Metadatas) have been removed. Use `DynamicPPL.unflatten_metadata` as a direct replacement. + - `DynamicPPL.set_values!!(vi::AbstractVarInfo, values)` has been renamed to `DynamicPPL.set_initial_values(vi::AbstractVarInfo, values)`; it also no longer mutates the varinfo argument. + +The **exported** method `VarInfo(vi::VarInfo, values)` has been deprecated, and will be removed in the next minor version. You can replace this directly with `unflatten(vi, values)` instead. + ## 0.35.4 Fixed a type instability in an implementation of `with_logabsdet_jacobian`, which resulted in the log-jacobian returned being an Int in some cases and a Float in others. diff --git a/Project.toml b/Project.toml index 86bcacd7f..05d33ec36 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.35.4" +version = "0.35.5" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index aa4c3f98d..44edaa4e9 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -162,8 +162,16 @@ Base.getindex(vi::AbstractVarInfo, ::Colon) = values_as(vi, Vector) """ getindex_internal(vi::AbstractVarInfo, vn::VarName) getindex_internal(vi::AbstractVarInfo, vns::Vector{<:VarName}) + getindex_internal(vi::AbstractVarInfo, ::Colon) -Return the current value(s) of `vn` (`vns`) in `vi` as represented internally in `vi`. +Return the internal value of the varname `vn`, varnames `vns`, or all varnames +in `vi` respectively. The internal value is the value of the variables that is +stored in the varinfo object; this may be the actual realisation of the random +variable (i.e. the value sampled from the distribution), or it may have been +transformed to Euclidean space, depending on whether the varinfo was linked. + +See https://turinglang.org/docs/developers/transforms/dynamicppl/ for more +information on how transformed variables are stored in DynamicPPL. See also: [`getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution)`](@ref) """ diff --git a/src/sampler.jl b/src/sampler.jl index aa3a637ee..ff008cc93 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -151,7 +151,21 @@ By default, it returns an instance of [`SampleFromPrior`](@ref). """ initialsampler(spl::Sampler) = SampleFromPrior() -function set_values!!(varinfo::AbstractVarInfo, initial_params::AbstractVector) +""" + set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector) + set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple) + +Take the values inside `initial_params`, replace the corresponding values in +the given VarInfo object, and return a new VarInfo object with the updated values. + +This differs from `DynamicPPL.unflatten` in two ways: + +1. It works with `NamedTuple` arguments. +2. For the `AbstractVector` method, if any of the elements are missing, it will not +overwrite the original value in the VarInfo (it will just use the original +value instead). +""" +function set_initial_values(varinfo::AbstractVarInfo, initial_params::AbstractVector) throw( ArgumentError( "`initial_params` must be a vector of type `Union{Real,Missing}`. " * @@ -160,7 +174,7 @@ function set_values!!(varinfo::AbstractVarInfo, initial_params::AbstractVector) ) end -function set_values!!( +function set_initial_values( varinfo::AbstractVarInfo, initial_params::AbstractVector{<:Union{Real,Missing}} ) flattened_param_vals = varinfo[:] @@ -180,11 +194,12 @@ function set_values!!( end # Update in `varinfo`. - setall!(varinfo, flattened_param_vals) - return varinfo + new_varinfo = unflatten(varinfo, flattened_param_vals) + return new_varinfo end -function set_values!!(varinfo::AbstractVarInfo, initial_params::NamedTuple) +function set_initial_values(varinfo::AbstractVarInfo, initial_params::NamedTuple) + varinfo = deepcopy(varinfo) vars_in_varinfo = keys(varinfo) for v in keys(initial_params) vn = VarName{v}() @@ -219,7 +234,7 @@ function initialize_parameters!!(vi::AbstractVarInfo, initial_params, model::Mod end # Set the values in `vi`. - vi = set_values!!(vi, initial_params) + vi = set_initial_values(vi, initial_params) # `invlink` if needed. if linked diff --git a/src/varinfo.jl b/src/varinfo.jl index 2fd5894aa..94b1f1c07 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -106,14 +106,6 @@ const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{ # multiple times. transformation(::VarInfo) = DynamicTransformation() -# TODO(mhauru) Isn't this the same as unflatten and/or replace_values? -function VarInfo(old_vi::VarInfo, x::AbstractVector) - md = replace_values(old_vi.metadata, x) - return VarInfo( - md, Base.RefValue{eltype(x)}(getlogp(old_vi)), Ref(get_num_produce(old_vi)) - ) -end - # No-op if we're already working with a `VarNamedVector`. metadata_to_varnamedvector(vnv::VarNamedVector) = vnv function metadata_to_varnamedvector(md::Metadata) @@ -243,9 +235,8 @@ end return :($(exprs...),) end -# For Metadata unflatten and replace_values are the same. For VarNamedVector they are not. function unflatten_metadata(md::Metadata, x::AbstractVector) - return replace_values(md, x) + return Metadata(md.idcs, md.vns, md.ranges, x, md.dists, md.orders, md.flags) end unflatten_metadata(vnv::VarNamedVector, x::AbstractVector) = unflatten(vnv, x) @@ -255,31 +246,6 @@ function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext return VarInfo(rng, model, SampleFromPrior(), context) end -function replace_values(metadata::Metadata, x) - return Metadata( - metadata.idcs, - metadata.vns, - metadata.ranges, - x, - metadata.dists, - metadata.orders, - metadata.flags, - ) -end - -@generated function replace_values(metadata::NamedTuple{names}, x) where {names} - exprs = [] - offset = :(0) - for f in names - mdf = :(metadata.$f) - len = :(sum(length, $mdf.ranges)) - push!(exprs, :($f = replace_values($mdf, x[($offset + 1):($offset + $len)]))) - offset = :($offset + $len) - end - length(exprs) == 0 && return :(NamedTuple()) - return :($(exprs...),) -end - #### #### Internal functions #### @@ -652,10 +618,20 @@ getindex_internal(vi::VarInfo, vn::VarName) = getindex_internal(getmetadata(vi, # what a bijector would result in, even if the input is a view (`SubArray`). # TODO(torfjelde): An alternative is to implement `view` directly instead. getindex_internal(md::Metadata, vn::VarName) = getindex(md.vals, getrange(md, vn)) - function getindex_internal(vi::VarInfo, vns::Vector{<:VarName}) return mapreduce(Base.Fix1(getindex_internal, vi), vcat, vns) end +getindex_internal(vi::VarInfo, ::Colon) = getindex_internal(vi.metadata, Colon()) +# NOTE: `mapreduce` over `NamedTuple` results in worse type-inference. +# See for example https://github.com/JuliaLang/julia/pull/46381. +function getindex_internal(vi::TypedVarInfo, ::Colon) + return reduce(vcat, map(Base.Fix2(getindex_internal, Colon()), vi.metadata)) +end +function getindex_internal(md::Metadata, ::Colon) + return mapreduce( + Base.Fix1(getindex_internal, md), vcat, md.vns; init=similar(md.vals, 0) + ) +end """ setval!(vi::VarInfo, val, vn::VarName) @@ -672,56 +648,6 @@ function setval!(md::Metadata, val, vn::VarName) return md.vals[getrange(md, vn)] = tovec(val) end -""" - getall(vi::VarInfo) - -Return the values of all the variables in `vi`. - -The values may or may not be transformed to Euclidean space. -""" -getall(vi::VarInfo) = getall(vi.metadata) -# NOTE: `mapreduce` over `NamedTuple` results in worse type-inference. -# See for example https://github.com/JuliaLang/julia/pull/46381. -getall(vi::TypedVarInfo) = reduce(vcat, map(getall, vi.metadata)) -function getall(md::Metadata) - return mapreduce( - Base.Fix1(getindex_internal, md), vcat, md.vns; init=similar(md.vals, 0) - ) -end -getall(vnv::VarNamedVector) = getindex_internal(vnv, Colon()) - -""" - setall!(vi::VarInfo, val) - -Set the values of all the variables in `vi` to `val`. - -The values may or may not be transformed to Euclidean space. -""" -setall!(vi::VarInfo, val) = _setall!(vi.metadata, val) - -function _setall!(metadata::Metadata, val) - for r in metadata.ranges - metadata.vals[r] .= val[r] - end -end -function _setall!(vnv::VarNamedVector, val) - # TODO(mhauru) Do something more efficient here. - for i in 1:length_internal(vnv) - setindex_internal!(vnv, val[i], i) - end -end -@generated function _setall!(metadata::NamedTuple{names}, val) where {names} - expr = Expr(:block) - start = :(1) - for f in names - length = :(sum(length, metadata.$f.ranges)) - finish = :($start + $length - 1) - push!(expr.args, :(copyto!(metadata.$f.vals, 1, val, $start, $length))) - start = :($start + $length) - end - return expr -end - function settrans!!(vi::VarInfo, trans::Bool, vn::VarName) settrans!!(getmetadata(vi, vn), trans, vn) return vi @@ -2114,7 +2040,7 @@ function _setval_and_resample_kernel!( end values_as(vi::VarInfo) = vi.metadata -values_as(vi::VarInfo, ::Type{Vector}) = copy(getall(vi)) +values_as(vi::VarInfo, ::Type{Vector}) = copy(getindex_internal(vi, Colon())) function values_as(vi::UntypedVarInfo, ::Type{NamedTuple}) iter = values_from_metadata(vi.metadata) return NamedTuple(map(p -> Symbol(p.first) => p.second, iter)) diff --git a/test/model.jl b/test/model.jl index 3b24b26bf..a863b6596 100644 --- a/test/model.jl +++ b/test/model.jl @@ -163,12 +163,12 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() Random.seed!(100 + i) vi = VarInfo() model(Random.default_rng(), vi, sampler) - vals = DynamicPPL.getall(vi) + vals = vi[:] Random.seed!(100 + i) vi = VarInfo() model(Random.default_rng(), vi, sampler) - @test DynamicPPL.getall(vi) == vals + @test vi[:] == vals end end end @@ -240,7 +240,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() for i in 1:10 # Sample with large variations. r_raw = randn(length(vi[:])) * 10 - DynamicPPL.setall!(vi, r_raw) + vi = DynamicPPL.unflatten(vi, r_raw) @test vi[@varname(m)] == r_raw[1] @test vi[@varname(x)] != r_raw[2] model(vi) diff --git a/test/test_util.jl b/test/test_util.jl index 69f9f7656..87c69b5fe 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -8,9 +8,10 @@ end const gdemo_default = gdemo_d() +# TODO(penelopeysm): Remove this (and also test/compat/ad.jl) function test_model_ad(model, logp_manual) vi = VarInfo(model) - x = DynamicPPL.getall(vi) + x = vi[:] # Log probabilities using the model. ℓ = DynamicPPL.LogDensityFunction(model, vi) From ebe409396c88d54bb0c6359a0109acdac7fa4f68 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 21 Mar 2025 14:49:05 +0000 Subject: [PATCH 2/4] Re-add but deprecate VarInfo(::VarInfo, ::AbstractVector) --- src/varinfo.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/varinfo.jl b/src/varinfo.jl index 94b1f1c07..0c033e504 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -100,6 +100,8 @@ const TypedVarInfo = VarInfo{<:NamedTuple} const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{ VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}} } +# TODO: Remove this +@deprecate VarInfo(vi::VarInfo, x::AbstractVector) unflatten(vi, x) # NOTE: This is kind of weird, but it effectively preserves the "old" # behavior where we're allowed to call `link!` on the same `VarInfo` From e1a9ba58811c41bfffb25830218d4c6304872aec Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 21 Mar 2025 20:01:59 +0000 Subject: [PATCH 3/4] Increase atol for truncated bijector test --- test/simple_varinfo.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index e67b5656a..9656c4c38 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -257,7 +257,8 @@ # `getlogp` should be equal to the logjoint with log-absdet-jac correction. lp = getlogp(svi) - @test lp ≈ lp_true + # needs higher atol because of https://github.com/TuringLang/Bijectors.jl/issues/375 + @test lp ≈ lp_true atol=1.2e-5 end end end From d74413695ca289829900b5d9535f30131280038a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 21 Mar 2025 20:05:32 +0000 Subject: [PATCH 4/4] Update test/simple_varinfo.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/simple_varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 9656c4c38..8e48814a4 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -258,7 +258,7 @@ # `getlogp` should be equal to the logjoint with log-absdet-jac correction. lp = getlogp(svi) # needs higher atol because of https://github.com/TuringLang/Bijectors.jl/issues/375 - @test lp ≈ lp_true atol=1.2e-5 + @test lp ≈ lp_true atol = 1.2e-5 end end end