Skip to content

Commit 82462cd

Browse files
committed
AbstractPPL 0.11; change prefixing behaviour
1 parent 324e623 commit 82462cd

16 files changed

+191
-146
lines changed

HISTORY.md

+49
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,54 @@
11
# DynamicPPL Changelog
22

3+
## 0.36.0
4+
5+
**Breaking changes**
6+
7+
### VarName prefixing behaviour
8+
9+
The way in which VarNames in submodels are prefixed has been changed.
10+
This is best explained through an example.
11+
Consider this model and submodel:
12+
13+
```julia
14+
using DynamicPPL, Distributions
15+
@model inner() = x ~ Normal()
16+
@model outer() = a ~ to_submodel(inner())
17+
```
18+
19+
In previous versions, the inner variable `x` would be saved as `a.x`.
20+
However, this was represented as a single symbol `Symbol("a.x")`:
21+
22+
```julia
23+
julia> dump(keys(VarInfo(outer()))[1])
24+
VarName{Symbol("a.x"), typeof(identity)}
25+
optic: identity (function of type typeof(identity))
26+
```
27+
28+
Now, the inner variable is stored as a field `x` on the VarName `a`:
29+
30+
```julia
31+
julia> dump(keys(VarInfo(outer()))[1])
32+
VarName{:a, Accessors.PropertyLens{:x}}
33+
optic: Accessors.PropertyLens{:x} (@o _.x)
34+
```
35+
36+
In practice, this means that if you are trying to condition a variable in the submodel, you now need to use
37+
38+
```julia
39+
outer() | (@varname(a.x) => 1.0,)
40+
```
41+
42+
instead of either of these (which would have worked previously)
43+
44+
```julia
45+
outer() | (@varname(var"a.x") => 1.0,)
46+
outer() | (a.x = 1.0,)
47+
```
48+
49+
If you are sampling from a model with submodels, this doesn't affect the way you interact with the `MCMCChains.Chains` object, because VarNames are converted into Symbols when stored in the chain.
50+
(This behaviour will likely be changed in the future, in that Chains should be indexable by VarNames and not just Symbols, but that has not been implemented yet.)
51+
352
## 0.35.5
453

554
Several internal methods have been removed:

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ DynamicPPLMooncakeExt = ["Mooncake"]
4444
[compat]
4545
ADTypes = "1"
4646
AbstractMCMC = "5"
47-
AbstractPPL = "0.10.1"
47+
AbstractPPL = "0.11"
4848
Accessors = "0.1"
4949
BangBang = "0.4.1"
5050
Bijectors = "0.13.18, 0.14, 0.15"

src/DynamicPPL.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ using DocStringExtensions
2121

2222
using Random: Random
2323

24+
# For extending
25+
import AbstractPPL: predict, prefix
26+
2427
# TODO: Remove these when it's possible.
2528
import Bijectors: link, invlink
2629

@@ -39,8 +42,6 @@ import Base:
3942
keys,
4043
haskey
4144

42-
import AbstractPPL: predict
43-
4445
# VarInfo
4546
export AbstractVarInfo,
4647
VarInfo,

src/compiler.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ function contextual_isassumption(context::ConditionContext, vn)
113113
return contextual_isassumption(childcontext(context), vn)
114114
end
115115
function contextual_isassumption(context::PrefixContext, vn)
116-
return contextual_isassumption(childcontext(context), prefix(context, vn))
116+
return contextual_isassumption(childcontext(context), prefix_with_context(context, vn))
117117
end
118118

119119
isfixed(expr, vn) = false
@@ -132,7 +132,7 @@ function contextual_isfixed(context::AbstractContext, vn)
132132
return contextual_isfixed(NodeTrait(context), context, vn)
133133
end
134134
function contextual_isfixed(context::PrefixContext, vn)
135-
return contextual_isfixed(childcontext(context), prefix(context, vn))
135+
return contextual_isfixed(childcontext(context), prefix_with_context(context, vn))
136136
end
137137
function contextual_isfixed(context::FixedContext, vn)
138138
if hasfixed(context, vn)

src/context_implementations.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,14 @@ function tilde_assume(rng::Random.AbstractRNG, ::LikelihoodContext, sampler, rig
8585
end
8686

8787
function tilde_assume(context::PrefixContext, right, vn, vi)
88-
return tilde_assume(context.context, right, prefix(context, vn), vi)
88+
return tilde_assume(context.context, right, prefix_with_context(context, vn), vi)
8989
end
9090
function tilde_assume(
9191
rng::Random.AbstractRNG, context::PrefixContext, sampler, right, vn, vi
9292
)
93-
return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), vi)
93+
return tilde_assume(
94+
rng, context.context, sampler, right, prefix_with_context(context, vn), vi
95+
)
9496
end
9597

9698
"""

src/contexts.jl

+20-20
Original file line numberDiff line numberDiff line change
@@ -260,25 +260,25 @@ function setchildcontext(::PrefixContext{Prefix}, child) where {Prefix}
260260
return PrefixContext{Prefix}(child)
261261
end
262262

263-
const PREFIX_SEPARATOR = Symbol(".")
264-
265-
@generated function PrefixContext{PrefixOuter}(
266-
context::PrefixContext{PrefixInner}
267-
) where {PrefixOuter,PrefixInner}
268-
return :(PrefixContext{$(QuoteNode(Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)))}(
269-
context.context
270-
))
271-
end
263+
"""
264+
prefix_with_context(ctx::AbstractContext, vn::VarName)
272265
273-
function prefix(ctx::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym}
274-
vn_prefixed_inner = prefix(childcontext(ctx), vn)
275-
return VarName{Symbol(Prefix, PREFIX_SEPARATOR, getsym(vn_prefixed_inner))}(
276-
getoptic(vn_prefixed_inner)
266+
Apply the prefixes in the context `ctx` to the variable name `vn`.
267+
"""
268+
function prefix_with_context(
269+
ctx::PrefixContext{Prefix}, vn::VarName{Sym}
270+
) where {Prefix,Sym}
271+
return AbstractPPL.prefix(
272+
prefix_with_context(childcontext(ctx), vn), VarName{Symbol(Prefix)}()
277273
)
278274
end
279-
prefix(ctx::AbstractContext, vn::VarName) = prefix(NodeTrait(ctx), ctx, vn)
280-
prefix(::IsLeaf, ::AbstractContext, vn::VarName) = vn
281-
prefix(::IsParent, ctx::AbstractContext, vn::VarName) = prefix(childcontext(ctx), vn)
275+
function prefix_with_context(ctx::AbstractContext, vn::VarName)
276+
return prefix_with_context(NodeTrait(ctx), ctx, vn)
277+
end
278+
prefix_with_context(::IsLeaf, ::AbstractContext, vn::VarName) = vn
279+
function prefix_with_context(::IsParent, ctx::AbstractContext, vn::VarName)
280+
return prefix_with_context(childcontext(ctx), vn)
281+
end
282282

283283
"""
284284
prefix(model::Model, x)
@@ -392,7 +392,7 @@ function hasconditioned_nested(::IsParent, context, vn)
392392
return hasconditioned(context, vn) || hasconditioned_nested(childcontext(context), vn)
393393
end
394394
function hasconditioned_nested(context::PrefixContext, vn)
395-
return hasconditioned_nested(childcontext(context), prefix(context, vn))
395+
return hasconditioned_nested(childcontext(context), prefix_with_context(context, vn))
396396
end
397397

398398
"""
@@ -410,7 +410,7 @@ function getconditioned_nested(::IsLeaf, context, vn)
410410
return error("context $(context) does not contain value for $vn")
411411
end
412412
function getconditioned_nested(context::PrefixContext, vn)
413-
return getconditioned_nested(childcontext(context), prefix(context, vn))
413+
return getconditioned_nested(childcontext(context), prefix_with_context(context, vn))
414414
end
415415
function getconditioned_nested(::IsParent, context, vn)
416416
return if hasconditioned(context, vn)
@@ -543,7 +543,7 @@ function hasfixed_nested(::IsParent, context, vn)
543543
return hasfixed(context, vn) || hasfixed_nested(childcontext(context), vn)
544544
end
545545
function hasfixed_nested(context::PrefixContext, vn)
546-
return hasfixed_nested(childcontext(context), prefix(context, vn))
546+
return hasfixed_nested(childcontext(context), prefix_with_context(context, vn))
547547
end
548548

549549
"""
@@ -561,7 +561,7 @@ function getfixed_nested(::IsLeaf, context, vn)
561561
return error("context $(context) does not contain value for $vn")
562562
end
563563
function getfixed_nested(context::PrefixContext, vn)
564-
return getfixed_nested(childcontext(context), prefix(context, vn))
564+
return getfixed_nested(childcontext(context), prefix_with_context(context, vn))
565565
end
566566
function getfixed_nested(::IsParent, context, vn)
567567
return if hasfixed(context, vn)

src/debug_utils.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ function DynamicPPL.setchildcontext(context::DebugContext, child)
183183
end
184184

185185
function record_varname!(context::DebugContext, varname::VarName, dist)
186-
prefixed_varname = prefix(context, varname)
186+
prefixed_varname = DynamicPPL.prefix_with_context(context, varname)
187187
if haskey(context.varnames_seen, prefixed_varname)
188188
if context.error_on_failure
189189
error("varname $prefixed_varname used multiple times in model")

src/model.jl

+17-41
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ julia> model() ≠ 1.0
243243
true
244244
245245
julia> # To condition the variable inside `demo_inner` we need to refer to it as `inner.m`.
246-
conditioned_model = model | (var"inner.m" = 1.0, );
246+
conditioned_model = model | (@varname(inner.m) => 1.0, );
247247
248248
julia> conditioned_model()
249249
1.0
@@ -255,15 +255,6 @@ julia> conditioned_model_fail()
255255
ERROR: ArgumentError: `~` with a model on the right-hand side of an observe statement is not supported
256256
[...]
257257
```
258-
259-
And similarly when using `Dict`:
260-
261-
```jldoctest condition
262-
julia> conditioned_model_dict = model | (@varname(var"inner.m") => 1.0);
263-
264-
julia> conditioned_model_dict()
265-
1.0
266-
```
267258
"""
268259
function AbstractPPL.condition(model::Model, values...)
269260
# Positional arguments - need to handle cases carefully
@@ -443,16 +434,16 @@ julia> conditioned(cm)
443434
julia> # Since we conditioned on `m`, not `a.m` as it will appear after prefixed,
444435
# `a.m` is treated as a random variable.
445436
keys(VarInfo(cm))
446-
1-element Vector{VarName{Symbol("a.m"), typeof(identity)}}:
437+
1-element Vector{VarName{:a, Accessors.PropertyLens{:m}}}:
447438
a.m
448439
449440
julia> # If we instead condition on `a.m`, `m` in the model will be considered an observation.
450-
cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext((var"a.m"=1.0,)))), x=100.0);
441+
cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext(Dict(@varname(a.m) => 1.0)))), x=100.0);
451442
452-
julia> conditioned(cm).x
443+
julia> conditioned(cm)[@varname(x)]
453444
100.0
454445
455-
julia> conditioned(cm).var"a.m"
446+
julia> conditioned(cm)[@varname(a.m)]
456447
1.0
457448
458449
julia> keys(VarInfo(cm)) # No variables are sampled
@@ -583,7 +574,7 @@ julia> model = demo_outer();
583574
julia> model() ≠ 1.0
584575
true
585576
586-
julia> fixed_model = fix(model, var"inner.m" = 1.0, );
577+
julia> fixed_model = fix(model, (@varname(inner.m) => 1.0, ));
587578
588579
julia> fixed_model()
589580
1.0
@@ -599,24 +590,9 @@ julia> fixed_model()
599590
2.0
600591
```
601592
602-
And similarly when using `Dict`:
603-
604-
```jldoctest fix
605-
julia> fixed_model_dict = fix(model, @varname(var"inner.m") => 1.0);
606-
607-
julia> fixed_model_dict()
608-
1.0
609-
610-
julia> fixed_model_dict = fix(model, @varname(inner) => 2.0);
611-
612-
julia> fixed_model_dict()
613-
2.0
614-
```
615-
616593
## Difference from `condition`
617594
618-
A very similar functionality is also provided by [`condition`](@ref) which,
619-
not surprisingly, _conditions_ variables instead of fixing them. The only
595+
A very similar functionality is also provided by [`condition`](@ref). The only
620596
difference between fixing and conditioning is as follows:
621597
- `condition`ed variables are considered to be observations, and are thus
622598
included in the computation [`logjoint`](@ref) and [`loglikelihood`](@ref),
@@ -798,16 +774,16 @@ julia> fixed(cm)
798774
julia> # Since we fixed on `m`, not `a.m` as it will appear after prefixed,
799775
# `a.m` is treated as a random variable.
800776
keys(VarInfo(cm))
801-
1-element Vector{VarName{Symbol("a.m"), typeof(identity)}}:
777+
1-element Vector{VarName{:a, Accessors.PropertyLens{:m}}}:
802778
a.m
803779
804780
julia> # If we instead fix on `a.m`, `m` in the model will be considered an observation.
805-
cm = fix(contextualize(m, PrefixContext{:a}(fix(var"a.m"=1.0))), x=100.0);
781+
cm = fix(contextualize(m, PrefixContext{:a}(fix(@varname(a.m) => 1.0,))), x=100.0);
806782
807-
julia> fixed(cm).x
783+
julia> fixed(cm)[@varname(x)]
808784
100.0
809785
810-
julia> fixed(cm).var"a.m"
786+
julia> fixed(cm)[@varname(a.m)]
811787
1.0
812788
813789
julia> keys(VarInfo(cm)) # <= no variables are sampled
@@ -1365,7 +1341,7 @@ When we sample from the model `demo2(missing, 0.4)` random variable `x` will be
13651341
```jldoctest submodel-to_submodel
13661342
julia> vi = VarInfo(demo2(missing, 0.4));
13671343
1368-
julia> @varname(var\"a.x\") in keys(vi)
1344+
julia> @varname(a.x) in keys(vi)
13691345
true
13701346
```
13711347
@@ -1379,7 +1355,7 @@ false
13791355
We can check that the log joint probability of the model accumulated in `vi` is correct:
13801356
13811357
```jldoctest submodel-to_submodel
1382-
julia> x = vi[@varname(var\"a.x\")];
1358+
julia> x = vi[@varname(a.x)];
13831359
13841360
julia> getlogp(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4)
13851361
true
@@ -1417,10 +1393,10 @@ julia> @model function demo2(x, y, z)
14171393
14181394
julia> vi = VarInfo(demo2(missing, missing, 0.4));
14191395
1420-
julia> @varname(var"sub1.x") in keys(vi)
1396+
julia> @varname(sub1.x) in keys(vi)
14211397
true
14221398
1423-
julia> @varname(var"sub2.x") in keys(vi)
1399+
julia> @varname(sub2.x) in keys(vi)
14241400
true
14251401
```
14261402
@@ -1437,9 +1413,9 @@ false
14371413
We can check that the log joint probability of the model accumulated in `vi` is correct:
14381414
14391415
```jldoctest submodel-to_submodel-prefix
1440-
julia> sub1_x = vi[@varname(var"sub1.x")];
1416+
julia> sub1_x = vi[@varname(sub1.x)];
14411417
1442-
julia> sub2_x = vi[@varname(var"sub2.x")];
1418+
julia> sub2_x = vi[@varname(sub2.x)];
14431419
14441420
julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x);
14451421

0 commit comments

Comments
 (0)