Skip to content

Clean up varinfo get/set functions #853

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 21, 2025
Merged

Clean up varinfo get/set functions #853

merged 4 commits into from
Mar 21, 2025

Conversation

penelopeysm
Copy link
Member

Closes #795

As discussed in #795 and also https://github.com/TuringLang/DynamicPPL.jl/pull/793/files#r1936255984 there are many internal DPPL functions which more or less do the same thing. This PR cleans them up a bit.

More comments to be made on the diff.

Copy link
Contributor

github-actions bot commented Mar 20, 2025

Benchmark Report for Commit d74413695ca289829900b5d9535f30131280038a

Computer Information

Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

|                 Model | Dimension |  AD Backend |      VarInfo Type | Linked | Eval Time / Ref Time | AD Time / Eval Time |
|-----------------------|-----------|-------------|-------------------|--------|----------------------|---------------------|
| Simple assume observe |         1 | forwarddiff |             typed |  false |                  9.4 |                 1.6 |
|           Smorgasbord |       201 | forwarddiff |             typed |  false |                608.6 |                40.2 |
|           Smorgasbord |       201 | forwarddiff | simple_namedtuple |   true |                414.3 |                44.9 |
|           Smorgasbord |       201 | forwarddiff |           untyped |   true |               1215.6 |                27.0 |
|           Smorgasbord |       201 | forwarddiff |       simple_dict |   true |               3599.0 |                21.3 |
|           Smorgasbord |       201 | reversediff |             typed |   true |               1426.2 |                29.3 |
|           Smorgasbord |       201 |    mooncake |             typed |   true |                926.6 |                 5.4 |
|    Loop univariate 1k |      1000 |    mooncake |             typed |   true |               5467.1 |                 4.1 |
|       Multivariate 1k |      1000 |    mooncake |             typed |   true |               1076.8 |                 8.4 |
|   Loop univariate 10k |     10000 |    mooncake |             typed |   true |              60794.3 |                 3.7 |
|      Multivariate 10k |     10000 |    mooncake |             typed |   true |               8738.1 |                 9.8 |
|               Dynamic |        10 |    mooncake |             typed |   true |                128.6 |                13.1 |
|              Submodel |         1 |    mooncake |             typed |   true |                 24.9 |                 7.3 |
|                   LDA |        12 | reversediff |             typed |   true |                453.9 |                 4.8 |

Copy link

codecov bot commented Mar 20, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 84.87%. Comparing base (e4fa7f2) to head (d744136).
Report is 1 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #853      +/-   ##
==========================================
+ Coverage   84.43%   84.87%   +0.43%     
==========================================
  Files          34       34              
  Lines        3849     3815      -34     
==========================================
- Hits         3250     3238      -12     
+ Misses        599      577      -22     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Comment on lines -163 to 178
function set_values!!(
function set_initial_values(
varinfo::AbstractVarInfo, initial_params::AbstractVector{<:Union{Real,Missing}}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function, originally called set_values!!, is very similar to unflatten but does a couple of extra things. Consequently, I renamed the function and added a docstring to make it clear how it differs.

Comment on lines -109 to -115
# TODO(mhauru) Isn't this the same as unflatten and/or replace_values?
function VarInfo(old_vi::VarInfo, x::AbstractVector)
md = replace_values(old_vi.metadata, x)
return VarInfo(
md, Base.RefValue{eltype(x)}(getlogp(old_vi)), Ref(get_num_produce(old_vi))
)
end
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This constructor, as far as I can tell, wasn't being used anywhere. Unfortunately, because VarInfo is exported, this constructor is public and so this PR has to be a breaking change (all other changes are purely internal).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is fine IMO, looks like we'll touch it again when the num_produce is removed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm that's a good point, VarInfo constructors will change when num_produce is changed - but that might be a while since Markus is off for a week and it's also a tricky piece of work. Maybe it'd be easier to leave this method in (maybe with a depwarn) and then release this PR as a patch - what do you think?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, that works for me, seems this is the only breaking change?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, I think so.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok - I added this back in with a @deprecate, and we can remove it in 0.36.0.

Comment on lines 238 to 240
function unflatten_metadata(md::Metadata, x::AbstractVector)
return replace_values(md, x)
return Metadata(md.idcs, md.vns, md.ranges, x, md.dists, md.orders, md.flags)
end
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was the only place that replace_values was being called, so I just inlined its definition

Comment on lines -270 to -281
@generated function replace_values(metadata::NamedTuple{names}, x) where {names}
exprs = []
offset = :(0)
for f in names
mdf = :(metadata.$f)
len = :(sum(length, $mdf.ranges))
push!(exprs, :($f = replace_values($mdf, x[($offset + 1):($offset + $len)])))
offset = :($offset + $len)
end
length(exprs) == 0 && return :(NamedTuple())
return :($(exprs...),)
end
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method is not needed because it's exactly the same thing as unflatten_metadata

Comment on lines +624 to +634
getindex_internal(vi::VarInfo, ::Colon) = getindex_internal(vi.metadata, Colon())
# NOTE: `mapreduce` over `NamedTuple` results in worse type-inference.
# See for example https://github.com/JuliaLang/julia/pull/46381.
function getindex_internal(vi::TypedVarInfo, ::Colon)
return reduce(vcat, map(Base.Fix2(getindex_internal, Colon()), vi.metadata))
end
function getindex_internal(md::Metadata, ::Colon)
return mapreduce(
Base.Fix1(getindex_internal, md), vcat, md.vns; init=similar(md.vals, 0)
)
end
Copy link
Member Author

@penelopeysm penelopeysm Mar 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getall(vi) is replaced with getindex_internal(vi, :). This is for two reasons:

(1) getall(::VarNamedVector) already defers to getindex_internal
(2) IMO, the internal part of the name makes it clearer that one is accessing internal values.

Comment on lines -693 to -723
"""
setall!(vi::VarInfo, val)

Set the values of all the variables in `vi` to `val`.

The values may or may not be transformed to Euclidean space.
"""
setall!(vi::VarInfo, val) = _setall!(vi.metadata, val)

function _setall!(metadata::Metadata, val)
for r in metadata.ranges
metadata.vals[r] .= val[r]
end
end
function _setall!(vnv::VarNamedVector, val)
# TODO(mhauru) Do something more efficient here.
for i in 1:length_internal(vnv)
setindex_internal!(vnv, val[i], i)
end
end
@generated function _setall!(metadata::NamedTuple{names}, val) where {names}
expr = Expr(:block)
start = :(1)
for f in names
length = :(sum(length, metadata.$f.ranges))
finish = :($start + $length - 1)
push!(expr.args, :(copyto!(metadata.$f.vals, 1, val, $start, $length)))
start = :($start + $length)
end
return expr
end
Copy link
Member Author

@penelopeysm penelopeysm Mar 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

setall! (and _setall!) was just a mutating version of unflatten (and unflatten_metadata). The solution to this code duplication is to remove the mutating version and make people use unflatten. There weren't any parts of the codebase where this could conceivably cause performance problems.

Copy link
Member

@sunxd3 sunxd3 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@penelopeysm penelopeysm changed the base branch from breaking to main March 21, 2025 14:42
@penelopeysm penelopeysm requested a review from sunxd3 March 21, 2025 14:49
@coveralls
Copy link

coveralls commented Mar 21, 2025

Pull Request Test Coverage Report for Build 13999996003

Warning: This coverage report may be inaccurate.

This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.

Details

  • 14 of 14 (100.0%) changed or added relevant lines in 2 files are covered.
  • 1 unchanged line in 1 file lost coverage.
  • Overall coverage increased (+0.4%) to 84.965%

Files with Coverage Reduction New Missed Lines %
src/varinfo.jl 1 84.97%
Totals Coverage Status
Change from base Build 13957720054: 0.4%
Covered Lines: 3238
Relevant Lines: 3811

💛 - Coveralls

Copy link
Member

@sunxd3 sunxd3 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks! (benchmark failed due to a random error, I rerun it, good to merge after it pass)

@penelopeysm penelopeysm added this pull request to the merge queue Mar 21, 2025
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Mar 21, 2025
@penelopeysm penelopeysm added this pull request to the merge queue Mar 21, 2025
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Mar 21, 2025
@penelopeysm
Copy link
Member Author

penelopeysm commented Mar 21, 2025

Test failures seem to be related to the change in TuringLang/Bijectors.jl#325.

Repro:

using DynamicPPL, Test

model = DynamicPPL.TestUtils.demo_dynamic_constraint()
svi = DynamicPPL.settrans!!(SimpleVarInfo(), true)
svi = last(DynamicPPL.evaluate!!(model, svi, SamplingContext()))

svi = DynamicPPL.unflatten(svi, [6.515552440303498, -24.79099078521386])
retval, svi = DynamicPPL.evaluate!!(model, svi, DefaultContext())
retval_unconstrained, lp_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian(
    model, retval.m, retval.x
)
lp = getlogp(svi)
@show lp
@show lp_true
@show lp - lp_true
@test lp  lp_true
#=
Test Failed at REPL[32]:1
  Expression: lp ≈ lp_true
   Evaluated: -45.039641779355634 ≈ -45.03965336794111
=#

Prior to that Bijectors commit, both lp and lp_true would be -45.03965336794111 which matches the value of lp_true above - this means that it's the calculation of lp that's changed, i.e. something inside the tilde pipeline.

Although the example above uses SimpleVarInfo, the same is true of VarInfo.

It seems to boil down to the fact that the forward transform and reverse transform give slightly different results. In the tilde pipeline, we use invlink_with_logpdf i.e. the new formula. In logprior_true_with_logabsdet_jacobian we use link i.e. the old formula.

@penelopeysm
Copy link
Member Author

I've isolated the issue to Bijectors, TuringLang/Bijectors.jl#375 so on the DynamicPPL side we just need to increase the atol on the test and wait for upstream fix

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@penelopeysm penelopeysm added this pull request to the merge queue Mar 21, 2025
Merged via the queue into main with commit 0810e14 Mar 21, 2025
18 checks passed
@penelopeysm penelopeysm deleted the py/varinfos-1 branch March 21, 2025 21:00
@mhauru
Copy link
Member

mhauru commented Apr 1, 2025

Big fan of this PR I am :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Unify unflatten, replace_value, and others
4 participants