Clean up varinfo get/set functions #853
15cfc4d to 55dbe6a (compare)
Codecov Report

All modified and coverable lines are covered by tests ✅

Coverage Diff

| | main | #853 | +/- |
|---|---|---|---|
| Coverage | 84.43% | 84.87% | +0.43% |
| Files | 34 | 34 | |
| Lines | 3849 | 3815 | -34 |
| Hits | 3250 | 3238 | -12 |
| Misses | 599 | 577 | -22 |
-function set_values!!(
+function set_initial_values(
     varinfo::AbstractVarInfo, initial_params::AbstractVector{<:Union{Real,Missing}}
This function, originally called `set_values!!`, is very similar to `unflatten` but does a couple of extra things. Consequently, I renamed the function and added a docstring to make it clear how it differs.
# TODO(mhauru) Isn't this the same as unflatten and/or replace_values?
function VarInfo(old_vi::VarInfo, x::AbstractVector)
    md = replace_values(old_vi.metadata, x)
    return VarInfo(
        md, Base.RefValue{eltype(x)}(getlogp(old_vi)), Ref(get_num_produce(old_vi))
    )
end
This constructor, as far as I can tell, wasn't being used anywhere. Unfortunately, because VarInfo is exported, this constructor is public and so this PR has to be a breaking change (all other changes are purely internal).
This is fine IMO, looks like we'll touch it again when the `num_produce` is removed?
Hmm that's a good point, VarInfo constructors will change when num_produce is changed - but that might be a while since Markus is off for a week and it's also a tricky piece of work. Maybe it'd be easier to leave this method in (maybe with a depwarn) and then release this PR as a patch - what do you think?
yeah, that works for me, seems this is the only breaking change?
Yup, I think so.
Ok - I added this back in with a `@deprecate`, and we can remove it in 0.36.0.
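For readers unfamiliar with `@deprecate`, here is a minimal sketch of the pattern in plain Julia. `ToyInfo`, `with_values`, and `rebuild_toy` are hypothetical stand-ins rather than DynamicPPL's API, and the deprecation actually added in this PR may forward to a different function.

```julia
# Toy stand-in for a container of flattened values plus a log-probability.
# (Hypothetical type, for illustration only.)
struct ToyInfo
    vals::Vector{Float64}
    logp::Float64
end

# Hypothetical preferred replacement: return a new ToyInfo holding the new values.
with_values(info::ToyInfo, x::AbstractVector) = ToyInfo(collect(Float64, x), info.logp)

# Keep the old entry point alive for one more release. Calls still work and
# forward to `with_values`; a warning is printed when Julia runs with --depwarn=yes.
@deprecate rebuild_toy(info::ToyInfo, x::AbstractVector) with_values(info, x)

old = ToyInfo([0.0, 0.0], -1.5)
new = rebuild_toy(old, [1.0, 2.0])  # deprecated call, same result as with_values
```

Because the old method keeps working, a change like this can ship in a patch release, and the method can be deleted in the next breaking release.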
 function unflatten_metadata(md::Metadata, x::AbstractVector)
-    return replace_values(md, x)
+    return Metadata(md.idcs, md.vns, md.ranges, x, md.dists, md.orders, md.flags)
 end
This was the only place that `replace_values` was being called, so I just inlined its definition.
@generated function replace_values(metadata::NamedTuple{names}, x) where {names}
    exprs = []
    offset = :(0)
    for f in names
        mdf = :(metadata.$f)
        len = :(sum(length, $mdf.ranges))
        push!(exprs, :($f = replace_values($mdf, x[($offset + 1):($offset + $len)])))
        offset = :($offset + $len)
    end
    length(exprs) == 0 && return :(NamedTuple())
    return :($(exprs...),)
end
This method is not needed because it's exactly the same thing as `unflatten_metadata`.
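The offset bookkeeping that both methods perform can be sketched in plain Julia as follows. This is an illustrative toy, not DynamicPPL code: `unflatten_toy` is a hypothetical name, and a `NamedTuple` of plain vectors stands in for the per-variable `Metadata` objects.

```julia
# Split a flat vector `x` into chunks that mirror the lengths of the fields of `nt`.
# Unlike the @generated method in the diff, this sketch loops at run time,
# so it is simpler to read but not type-stable.
function unflatten_toy(nt::NamedTuple{names}, x::AbstractVector) where {names}
    total = sum(length, values(nt); init=0)
    length(x) == total || throw(ArgumentError("expected $total values, got $(length(x))"))
    offset = 0
    chunks = Any[]
    for name in names
        len = length(nt[name])
        # Each field receives the next `len` entries of `x`.
        push!(chunks, x[(offset + 1):(offset + len)])
        offset += len
    end
    return NamedTuple{names}(Tuple(chunks))
end

template = (m = [0.0], x = [0.0, 0.0])
filled = unflatten_toy(template, [1.0, 2.0, 3.0])
```

The template is left untouched and a new `NamedTuple` is returned, which mirrors the non-mutating style of `unflatten`.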
getindex_internal(vi::VarInfo, ::Colon) = getindex_internal(vi.metadata, Colon())
# NOTE: `mapreduce` over `NamedTuple` results in worse type-inference.
# See for example https://github.com/JuliaLang/julia/pull/46381.
function getindex_internal(vi::TypedVarInfo, ::Colon)
    return reduce(vcat, map(Base.Fix2(getindex_internal, Colon()), vi.metadata))
end
function getindex_internal(md::Metadata, ::Colon)
    return mapreduce(
        Base.Fix1(getindex_internal, md), vcat, md.vns; init=similar(md.vals, 0)
    )
end
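The `reduce(vcat, map(...))` pattern used for `TypedVarInfo` above can be tried in isolation. The sketch below is a toy under stated assumptions: `flatten_toy` is a hypothetical name, and a `NamedTuple` of plain vectors stands in for the real metadata.

```julia
# Concatenate the per-field value vectors of a NamedTuple into one flat vector.
# `map` over the tuple of fields followed by `reduce(vcat, ...)` follows the
# same shape as the TypedVarInfo method in the diff.
flatten_toy(nt::NamedTuple) = reduce(vcat, map(copy, values(nt)))

toy_metadata = (m = [1.0], x = [2.0, 3.0])
flat = flatten_toy(toy_metadata)
```

Note that this sketch throws on an empty `NamedTuple`, since `reduce` has no initial value to fall back on.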
`getall(vi)` is replaced with `getindex_internal(vi, :)`. This is for two reasons:
(1) `getall(::VarNamedVector)` already defers to `getindex_internal`
(2) IMO, the `internal` part of the name makes it clearer that one is accessing internal values.
"""
    setall!(vi::VarInfo, val)

Set the values of all the variables in `vi` to `val`.

The values may or may not be transformed to Euclidean space.
"""
setall!(vi::VarInfo, val) = _setall!(vi.metadata, val)

function _setall!(metadata::Metadata, val)
    for r in metadata.ranges
        metadata.vals[r] .= val[r]
    end
end
function _setall!(vnv::VarNamedVector, val)
    # TODO(mhauru) Do something more efficient here.
    for i in 1:length_internal(vnv)
        setindex_internal!(vnv, val[i], i)
    end
end
@generated function _setall!(metadata::NamedTuple{names}, val) where {names}
    expr = Expr(:block)
    start = :(1)
    for f in names
        length = :(sum(length, metadata.$f.ranges))
        finish = :($start + $length - 1)
        push!(expr.args, :(copyto!(metadata.$f.vals, 1, val, $start, $length)))
        start = :($start + $length)
    end
    return expr
end
`setall!` (and `_setall!`) was just a mutating version of `unflatten` (and `unflatten_metadata`). The solution to this code duplication is to remove the mutating version and make people use `unflatten`. There weren't any parts of the codebase where this could conceivably cause performance problems.
55dbe6a to 2e19236 (compare)
LGTM!
2e19236 to 2be1bbb (compare)
Pull Request Test Coverage Report for Build 13999996003

Warning: This coverage report may be inaccurate. This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.
Thanks! (The benchmark failed due to a random error, so I've rerun it. Good to merge after it passes.)
Test failures seem to be related to the change in TuringLang/Bijectors.jl#325. Repro:

using DynamicPPL, Test
model = DynamicPPL.TestUtils.demo_dynamic_constraint()
svi = DynamicPPL.settrans!!(SimpleVarInfo(), true)
svi = last(DynamicPPL.evaluate!!(model, svi, SamplingContext()))
svi = DynamicPPL.unflatten(svi, [6.515552440303498, -24.79099078521386])
retval, svi = DynamicPPL.evaluate!!(model, svi, DefaultContext())
retval_unconstrained, lp_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian(
model, retval.m, retval.x
)
lp = getlogp(svi)
@show lp
@show lp_true
@show lp - lp_true
@test lp ≈ lp_true
#=
Test Failed at REPL[32]:1
Expression: lp ≈ lp_true
Evaluated: -45.039641779355634 ≈ -45.03965336794111
=#

Prior to that Bijectors commit, both

Although the example above uses SimpleVarInfo, the same is true of VarInfo. It seems to boil down to the fact that the forward transform and reverse transform give slightly different results. In the tilde pipeline, we use
I've isolated the issue to Bijectors (TuringLang/Bijectors.jl#375), so on the DynamicPPL side we just need to increase the `atol` on the test and wait for the upstream fix.
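To illustrate why raising `atol` resolves the failure, here is a small sketch using the two values from the test output. The tolerance `1e-4` is an assumed value chosen for illustration; the tolerance actually used in the DynamicPPL test suite may differ.

```julia
using Test

# Values reported in the failing test output.
lp = -45.039641779355634
lp_true = -45.03965336794111

# With no tolerances given, `≈` uses a relative tolerance of sqrt(eps(Float64)),
# which is much tighter than the observed gap of roughly 1.2e-5.
@show abs(lp - lp_true)
@show lp ≈ lp_true

# An explicit absolute tolerance (assumed value) lets the comparison pass.
@test lp ≈ lp_true atol = 1e-4
```

When `atol` is supplied, the default relative tolerance becomes zero, so the check reduces to `abs(lp - lp_true) <= atol`.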
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Big fan of this PR I am :)
Closes #795
As discussed in #795 and also https://github.com/TuringLang/DynamicPPL.jl/pull/793/files#r1936255984 there are many internal DPPL functions which more or less do the same thing. This PR cleans them up a bit.
More comments to be made on the diff.