From 85c3a9e25892c262b5002b18c8adaff200b63756 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 16 Apr 2025 16:38:16 +0100 Subject: [PATCH 01/17] Fix conditioning in submodels --- src/compiler.jl | 14 ++++++--- src/context_implementations.jl | 27 ++++++++++++---- src/contexts.jl | 57 +++++++++++++++++++++++++++++----- src/utils.jl | 9 ++++-- 4 files changed, 87 insertions(+), 20 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 4771b0171..ff8903045 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -53,7 +53,9 @@ function isassumption( vn=AbstractPPL.drop_escape(varname(expr, need_concretize(expr))), ) return quote - if $(DynamicPPL.contextual_isassumption)(__context__, $vn) + if $(DynamicPPL.contextual_isassumption)( + __context__, $(DynamicPPL.prefix)(__context__, $vn) + ) # Considered an assumption by `__context__` which means either: # 1. We hit the default implementation, e.g. using `DefaultContext`, # which in turn means that we haven't considered if it's one of @@ -112,8 +114,10 @@ function contextual_isassumption(context::ConditionContext, vn) # so we defer to `childcontext` if we haven't concluded that anything yet. return contextual_isassumption(childcontext(context), vn) end -function contextual_isassumption(context::PrefixContext, vn) - return contextual_isassumption(childcontext(context), prefix(context, vn)) +function contextual_isassumption(context::PrefixContext{Prefix}, vn) where {Prefix} + return contextual_isassumption( + prefix_conditioned_variables(childcontext(context), VarName{Prefix}()), vn + ) end isfixed(expr, vn) = false @@ -473,7 +477,9 @@ function generate_tilde(left, right) else # If `vn` is not in `argnames`, we need to make sure that the variable is defined. if !$(DynamicPPL.inargnames)($vn, __model__) - $left = $(DynamicPPL.getconditioned_nested)(__context__, $vn) + $left = $(DynamicPPL.getconditioned_nested)( + __context__, $(DynamicPPL.prefix)(__context__, $vn) + ) end $value, __varinfo__ = $(DynamicPPL.tilde_observe!!)( diff --git a/src/context_implementations.jl b/src/context_implementations.jl index e4ba5d252..2e7b85162 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -104,12 +104,27 @@ probability of `vi` with the returned value. """ function tilde_assume!!(context, right, vn, vi) return if is_rhs_model(right) - # Prefix the variables using the `vn`. - rand_like!!( - right, - should_auto_prefix(right) ? PrefixContext{Symbol(vn)}(context) : context, - vi, - ) + # Here, we apply the PrefixContext _not_ to the parent `context`, but + # to the context of the submodel being evaluated. This means that later= + # on in `make_evaluate_args_and_kwargs`, the context stack will be + # correctly arranged such that it goes like this: + # parent_context[1] -> parent_context[2] -> ... -> PrefixContext -> + # submodel_context[1] -> submodel_context[2] -> ... -> leafcontext + # See the docstring of `make_evaluate_args_and_kwargs`, and the internal + # DynamicPPL documentation on submodel conditioning, for more details. + # + # NOTE: This relies on the existence of `right.model.model`. Right now, + # the only thing that can return true for `is_rhs_model` is something + # (a `Sampleable`) that has a `model` field that itself (a + # `ReturnedModelWrapper`) has a `model` field. This may or may not + # change in the future. + if should_auto_prefix(right) + dppl_model = right.model.model # This isa DynamicPPL.Model + prefixed_submodel_context = PrefixContext{getsym(vn)}(dppl_model.context) + new_dppl_model = contextualize(dppl_model, prefixed_submodel_context) + right = to_submodel(new_dppl_model, true) + end + rand_like!!(right, context, vi) else value, logp, vi = tilde_assume(context, right, vn, vi) value, acclogp_assume!!(context, vi, logp) diff --git a/src/contexts.jl b/src/contexts.jl index 58ac612b8..7ba79bc7f 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -265,8 +265,8 @@ end Apply the prefixes in the context `ctx` to the variable name `vn`. """ -function prefix(ctx::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} - return AbstractPPL.prefix(prefix(childcontext(ctx), vn), VarName{Symbol(Prefix)}()) +function prefix(ctx::PrefixContext{Prefix}, vn::VarName) where {Prefix} + return AbstractPPL.prefix(prefix(childcontext(ctx), vn), VarName{Prefix}()) end function prefix(ctx::AbstractContext, vn::VarName) return prefix(NodeTrait(ctx), ctx, vn) @@ -351,6 +351,43 @@ NodeTrait(::ConditionContext) = IsParent() childcontext(context::ConditionContext) = context.context setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.values, child) +""" + prefix_conditioned_variables(context::AbstractContext, prefix::VarName) + +Prefix all the conditioned variables in a given context with `prefix`. + +```jldoctest +julia> using DynamicPPL: prefix_conditioned_variables, ConditionContext + +julia> c1 = ConditionContext((a=1, )) +ConditionContext((a = 1,), DefaultContext()) + +julia> prefix_conditioned_variables(c1, @varname(y)) +ConditionContext(Dict(y.a => 1), DefaultContext()) +``` +""" +function prefix_conditioned_variables(ctx::ConditionContext, prefix::VarName) + # Replace the prefix of the conditioned variables + vn_dict = to_varname_dict(ctx.values) + prefixed_vn_dict = Dict( + AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict + ) + # Prefix the child context as well + prefixed_child_ctx = prefix_conditioned_variables(childcontext(ctx), prefix) + return ConditionContext(prefixed_vn_dict, prefixed_child_ctx) +end +function prefix_conditioned_variables(c::AbstractContext, prefix::VarName) + return prefix_conditioned_variables( + NodeTrait(prefix_conditioned_variables, c), c, prefix + ) +end +prefix_conditioned_variables(::IsLeaf, context::AbstractContext, prefix::VarName) = context +function prefix_conditioned_variables(::IsParent, context::AbstractContext, prefix::VarName) + return setchildcontext( + context, prefix_conditioned_variables(childcontext(context), prefix) + ) +end + """ hasconditioned(context::AbstractContext, vn::VarName) @@ -370,7 +407,9 @@ Return value of `vn` in `context`. function getconditioned(context::AbstractContext, vn::VarName) return error("context $(context) does not contain value for $vn") end -getconditioned(context::ConditionContext, vn::VarName) = getvalue(context.values, vn) +function getconditioned(context::ConditionContext, vn::VarName) + return getvalue(context.values, vn) +end """ hasconditioned_nested(context, vn) @@ -387,8 +426,10 @@ hasconditioned_nested(::IsLeaf, context, vn) = hasconditioned(context, vn) function hasconditioned_nested(::IsParent, context, vn) return hasconditioned(context, vn) || hasconditioned_nested(childcontext(context), vn) end -function hasconditioned_nested(context::PrefixContext, vn) - return hasconditioned_nested(childcontext(context), prefix(context, vn)) +function hasconditioned_nested(context::PrefixContext{Prefix}, vn) where {Prefix} + return hasconditioned_nested( + prefix_conditioned_variables(childcontext(context), VarName{Prefix}()), vn + ) end """ @@ -405,8 +446,10 @@ end function getconditioned_nested(::IsLeaf, context, vn) return error("context $(context) does not contain value for $vn") end -function getconditioned_nested(context::PrefixContext, vn) - return getconditioned_nested(childcontext(context), prefix(context, vn)) +function getconditioned_nested(context::PrefixContext{Prefix}, vn) where {Prefix} + return getconditioned_nested( + prefix_conditioned_variables(childcontext(context), VarName{Prefix}()), vn + ) end function getconditioned_nested(::IsParent, context, vn) return if hasconditioned(context, vn) diff --git a/src/utils.jl b/src/utils.jl index 56c3d70af..71919480c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1286,7 +1286,10 @@ broadcast_safe(x::Distribution) = (x,) broadcast_safe(x::AbstractContext) = (x,) # Convert (x=1,) to Dict(@varname(x) => 1) -_nt_to_varname_dict(nt) = Dict(VarName{k}() => v for (k, v) in pairs(nt)) +function to_varname_dict(nt::NamedTuple) + return Dict{VarName,Any}(VarName{k}() => v for (k, v) in pairs(nt)) +end +to_varname_dict(d::AbstractDict) = d # Version of `merge` used by `conditioned` and `fixed` to handle # the scenario where we might try to merge a dict with an empty # tuple. @@ -1294,9 +1297,9 @@ _nt_to_varname_dict(nt) = Dict(VarName{k}() => v for (k, v) in pairs(nt)) _merge(left::NamedTuple, right::NamedTuple) = merge(left, right) _merge(left::AbstractDict, right::AbstractDict) = merge(left, right) _merge(left::AbstractDict, ::NamedTuple{()}) = left -_merge(left::AbstractDict, right::NamedTuple) = merge(left, _nt_to_varname_dict(right)) +_merge(left::AbstractDict, right::NamedTuple) = merge(left, to_varname_dict(right)) _merge(::NamedTuple{()}, right::AbstractDict) = right -_merge(left::NamedTuple, right::AbstractDict) = merge(_nt_to_varname_dict(left), right) +_merge(left::NamedTuple, right::AbstractDict) = merge(to_varname_dict(left), right) """ unique_syms(vns::T) where {T<:NTuple{N,VarName}} From ab3ac228e5ddeaa1e16c722afd9b76588ec08292 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 17 Apr 2025 01:39:36 +0100 Subject: [PATCH 02/17] Simplify contextual_isassumption --- src/compiler.jl | 44 ++++++++------------------------------------ 1 file changed, 8 insertions(+), 36 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index ff8903045..eb71404a4 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -89,35 +89,19 @@ isassumption(expr) = :(false) contextual_isassumption(context, vn) Return `true` if `vn` is considered an assumption by `context`. - -The default implementation for `AbstractContext` always returns `true`. """ -contextual_isassumption(::IsLeaf, context, vn) = true -function contextual_isassumption(::IsParent, context, vn) - return contextual_isassumption(childcontext(context), vn) -end function contextual_isassumption(context::AbstractContext, vn) - return contextual_isassumption(NodeTrait(context), context, vn) -end -function contextual_isassumption(context::ConditionContext, vn) - if hasconditioned(context, vn) - val = getconditioned(context, vn) + if hasconditioned_nested(context, vn) + val = getconditioned_nested(context, vn) # TODO: Do we even need the `>: Missing`, i.e. does it even help the compiler? if eltype(val) >: Missing && val === missing return true else return false end + else + return true end - - # We might have nested contexts, e.g. `ConditionContext{.., <:PrefixContext{..., <:ConditionContext}}` - # so we defer to `childcontext` if we haven't concluded that anything yet. - return contextual_isassumption(childcontext(context), vn) -end -function contextual_isassumption(context::PrefixContext{Prefix}, vn) where {Prefix} - return contextual_isassumption( - prefix_conditioned_variables(childcontext(context), VarName{Prefix}()), vn - ) end isfixed(expr, vn) = false @@ -128,30 +112,18 @@ isfixed(::Union{Symbol,Expr}, vn) = :($(DynamicPPL.contextual_isfixed)(__context Return `true` if `vn` is considered fixed by `context`. """ -contextual_isfixed(::IsLeaf, context, vn) = false -function contextual_isfixed(::IsParent, context, vn) - return contextual_isfixed(childcontext(context), vn) -end function contextual_isfixed(context::AbstractContext, vn) - return contextual_isfixed(NodeTrait(context), context, vn) -end -function contextual_isfixed(context::PrefixContext, vn) - return contextual_isfixed(childcontext(context), prefix(context, vn)) -end -function contextual_isfixed(context::FixedContext, vn) - if hasfixed(context, vn) - val = getfixed(context, vn) + if hasfixed_nested(context, vn) + val = getfixed_nested(context, vn) # TODO: Do we even need the `>: Missing`, i.e. does it even help the compiler? if eltype(val) >: Missing && val === missing return false else return true end + else + return false end - - # We might have nested contexts, e.g. `FixedContext{.., <:PrefixContext{..., <:FixedContext}}` - # so we defer to `childcontext` if we haven't concluded that anything yet. - return contextual_isfixed(childcontext(context), vn) end # If we're working with, say, a `Symbol`, then we're not going to `view`. From f84988fcc8a7ba671c4f5f75ead927008851de83 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 17 Apr 2025 01:15:35 +0100 Subject: [PATCH 03/17] Add documentation --- docs/make.jl | 4 +- docs/src/internals/submodel_condition.md | 233 +++++++++++++++++++++++ 2 files changed, 236 insertions(+), 1 deletion(-) create mode 100644 docs/src/internals/submodel_condition.md diff --git a/docs/make.jl b/docs/make.jl index c69b72fb8..7984fa1d1 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -24,7 +24,9 @@ makedocs(; format=Documenter.HTML(; size_threshold=2^10 * 400), modules=[DynamicPPL, Base.get_extension(DynamicPPL, :DynamicPPLMCMCChainsExt)], pages=[ - "Home" => "index.md", "API" => "api.md", "Internals" => ["internals/varinfo.md"] + "Home" => "index.md", + "API" => "api.md", + "Internals" => ["internals/varinfo.md", "internals/submodel_condition.md"], ], checkdocs=:exports, doctest=false, diff --git a/docs/src/internals/submodel_condition.md b/docs/src/internals/submodel_condition.md new file mode 100644 index 000000000..fe2f2c1df --- /dev/null +++ b/docs/src/internals/submodel_condition.md @@ -0,0 +1,233 @@ +# How `PrefixContext` and `ConditionContext` interact + +```@meta +ShareDefaultModule = true +``` + +## PrefixContext + +`PrefixContext` is a context that, as the name suggests, prefixes all variables inside a model with a given symbol. +Thus, for example: + +```@example +using DynamicPPL, Distributions + +@model function f() + x ~ Normal() + return y ~ Normal() +end + +@model function g() + return a ~ to_submodel(f()) +end +``` + +inside the submodel `f`, the variables `x` and `y` become `a.x` and `a.y` respectively. +This is easiest to observe by running the model: + +```@example +vi = VarInfo(g()) +keys(vi) +``` + +!!! note + + In this case, where `to_submodel` is called without any other arguments, the prefix to be used is automatically inferred from the name of the variable on the left-hand side of the tilde. + We will return to the 'manual prefixing' case later. + +What does it really mean to 'become' a different variable? +We can see this from [the definition of `tilde_assume`, for example](https://github.com/TuringLang/DynamicPPL.jl/blob/60ee68e2ce28a15c6062c243019e6208d16802a5/src/context_implementations.jl#L87-L89): + +``` +function tilde_assume(context::PrefixContext, right, vn, vi) + return tilde_assume(context.context, right, prefix(context, vn), vi) +end +``` + +Functionally, this means that even though the _initial_ entry to the tilde-pipeline has `vn` as `x` and `y`, once the `PrefixContext` has been applied, the later functions will see `a.x` and `a.y` instead. + +## ConditionContext + +`ConditionContext` is a context which stores values of variables that are to be conditioned on. +These values may be stored as a `Dict` which maps `VarName`s to values, or alternatively as a `NamedTuple`. +The latter only works correctly if all `VarName`s are 'basic', in that they have an identity optic (i.e., something like `a.x` or `a[1]` is forbidden). +Because of this limitation, we will only use `Dict` in this example. + +!!! note + + If a `ConditionContext` with a `NamedTuple` encounters anything to do with a prefix, its internal `NamedTuple` is converted to a `Dict` anyway, so it is quite reasonable to ignore the `NamedTuple` case in this exposition. + +One can inspect the conditioning values with, for example: + +```@example +@model function d() + x ~ Normal() + return y ~ Normal() +end + +cond_model = d() | (@varname(x) => 1.0) +cond_ctx = cond_model.context +``` + +There are several internal functions that are used to determine whether a variable is conditioned, and if so, what its value is. + +```@example +DynamicPPL.hasconditioned_nested(cond_ctx, @varname(x)) +``` + +```@example +DynamicPPL.getconditioned_nested(cond_ctx, @varname(x)) +``` + +These functions are in turn used by the function `DynamicPPL.contextual_isassumption`, which is largely the same as `hasconditioned_nested`, but also checks whether the value is `missing` (in which case it isn't really conditioned). + +```@example +DynamicPPL.contextual_isassumption(cond_ctx, @varname(x)) +``` + +!!! note + + Notice that (neglecting `missing` values) the return value of `contextual_isassumption` is the _opposite_ of `hasconditioned_nested`, i.e. for a variable that _is_ conditioned on, `contextual_isassumption` returns `false`. + +If a variable `x` is conditioned on, then the effect of this is to set the value of `x` to the given value (while still including its contribution to the log probability density). +Since `x` is no longer a random variable, if we were to evaluate the model, we would find only one key in the `VarInfo`: + +```@example +keys(VarInfo(cond_model)) +``` + +## Joint behaviour: desiderata at the model level + +When paired together, these two contexts have the potential to cause substantial confusion: `PrefixContext` modifies the variable names that are seen, which may cause them to be out of sync with the values contained inside the `ConditionContext`. + +We begin by mentioning some high-level desiderata for their joint behaviour. +Take these models, for example: + +```@example +# We define a helper function to unwrap a layer of SamplingContext, to +# avoid cluttering the print statements. +unwrap_sampling_context(ctx::DynamicPPL.SamplingContext) = ctx.context +unwrap_sampling_context(ctx::DynamicPPL.AbstractContext) = ctx +@model function inner() + println("inner context: $(unwrap_sampling_context(__context__))") + x ~ Normal() + return y ~ Normal() +end + +@model function outer() + println("outer context: $(unwrap_sampling_context(__context__))") + return a ~ to_submodel(inner()) +end + +# 'Outer conditioning' +with_outer_cond = outer() | (@varname(a.x) => 1.0) + +# 'Inner conditioning' +inner_cond = inner() | (@varname(x) => 1.0) +@model function outer2() + println("outer context: $(unwrap_sampling_context(__context__))") + return a ~ to_submodel(inner_cond) +end +with_inner_cond = outer2() +``` + +We want that: + + 1. `keys(VarInfo(outer()))` should return `[a.x, a.y]`; + 2. `keys(VarInfo(with_outer_cond))` should return `[a.y]`; + 3. `keys(VarInfo(with_inner_cond))` should return `[a.y]`, + +**In other words, we can condition submodels either from the outside (point (2)) or from the inside (point (3)), and the variable name we use to specify the conditioning should match the level at which we perform the conditioning.** + +This is an incredibly salient point because it means that submodels can be treated as individual, opaque objects, and we can condition them without needing to know what it will be prefixed with, or the context in which that submodel is being used. +For example, this means we can reuse `inner_cond` in another model with a different prefix, and it will _still_ have its inner `x` value be conditioned, despite the prefix differing. + +!!! info + + In the current version of DynamicPPL, these criteria are all fulfilled. However, this was not the case in the past: in particular, point (3) was not fulfilled, and users had to condition the internal submodel with the prefixes that were used outside. (See [this GitHub issue](https://github.com/TuringLang/DynamicPPL.jl/issues/857) for more information; this issue was the direct motivation for this documentation page.) + +## Desiderata at the context level + +The above section describes how we expect conditioning and prefixing to behave from a user's perpective. +We now turn to the question of how we implement this in terms of DynamicPPL contexts. +We do not specify the implementation details here, but we will sketch out something resembling an API that will allow us to achieve the target behaviour. + +**Point (1)** does not involve any conditioning, only prefixing; it is therefore already satisfied by virtue of the `tilde_assume` method shown above. + +**Points (2) and (3)** are more tricky. +As the reader may surmise, the difference between them is the order in which the contexts are stacked. + +For the _outer_ conditioning case (point (2)), the `ConditionContext` will contain a `VarName` that is already prefixed. +When we enter the inner submodel, this `ConditionContext` has to be passed down and somehow combined with the `PrefixContext` that is created when we enter the submodel. +We make the claim here that the best way to do this is to nest the `PrefixContext` _inside_ the `ConditionContext`. +This is indeed what happens, as can be demonstrated by running the model. + +```@example +with_outer_cond(); +nothing; +``` + +!!! info + + The `; nothing` at the end is purely to circumvent a Documenter.jl quirk where stdout is only shown if the return value of the final statement is `nothing`. + If these documentation pages are moved to Quarto, it will be possible to remove this. + +For the _inner_ conditioning case (point (3)), the outer model is not run with any special context. +The inner model will itself contain a `ConditionContext` will contain a `VarName` that is not prefixed. +When we run the model, this `ConditionContext` should be then nested _inside_ a `PrefixContext` to form the final evaluation context. +Again, we can run the model to see this in action: + +```@example +with_inner_cond(); +nothing; +``` + +Putting all of the information so far together, what it means is that if we have these two inner contexts (taken from above): + +```@example +using DynamicPPL: PrefixContext, ConditionContext, DefaultContext + +inner_ctx_with_outer_cond = ConditionContext( + Dict(@varname(a.x) => 1.0), PrefixContext{:a}(DefaultContext()) +) +inner_ctx_with_inner_cond = PrefixContext{:a}( + ConditionContext(Dict(@varname(x) => 1.0), DefaultContext()) +) +``` + +then we want both of these to be `true` (and thankfully, they are!): + +```@example +DynamicPPL.hasconditioned_nested(inner_ctx_with_outer_cond, @varname(a.x)) +``` + +```@example +DynamicPPL.hasconditioned_nested(inner_ctx_with_inner_cond, @varname(a.x)) +``` + +Essentially, our job is threefold: + + - Firstly, given the correct arguments, we need to make sure that `hasconditioned_nested` and `getconditioned_nested` behave correctly. + + - Secondly, we need to make sure that both the correct arguments are supplied. In order to do so: + + + We need to make sure that when evaluating a submodel, the context stack is arranged such that prefixes are applied _inside_ the parent model's context, but _outside_ the submodel's own context. + + We also need to make sure that the `VarName` passed to it is prefixed correctly. This is, in fact, _not_ handled by `tilde_assume`, because `contextual_isassumption` is much higher in the call stack than `tilde_assume` is. So, we need to explicitly prefix it. + +## How do we do it? + +`hasconditioned_nested` accomplishes this by doing the following: + + - If the outermost layer is a `ConditionContext`, it checks whether the variable is contained in its values. + - If the outermost layer is a `PrefixContext`, it goes through the `PrefixContext`'s child context and prefixes any inner conditioned variables, before checking whether the variable is contained. + +We ensure that the context stack is correctly arranged by relying on the behaviour of `make_evaluate_args_and_kwargs`. +This function is called whenever a model (which itself contains a context) is evaluated with a separate ('outer') context, and makes sure to arrange it such that the model's context is nested inside the outer context. +Thus, as long as prefixing is implemented by applying a `PrefixContext` on the outermost layer of the _inner_ model context, this will be correctly combined with an outer context to give the behaviour seen above. + +And finally, we ensure that the `VarName` is correctly prefixed by modifying the `@model` macro (or, technically, its subsidiary `isassumption`) to explicitly prefix the variable before passing it to `contextual_isassumption`. + +## FixedContext + +Finally, note that all of the above also applies to the interaction between `PrefixContext` and `FixedContext`, except that the functions have different names. +(`FixedContext` behaves the same way as `ConditionContext`, except that unlike conditioned variables, fixed variables do not contribute to the log probability density.) From 32437fb4d2f4a53a275b1dbfce4d0a73104cd7c7 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 17 Apr 2025 16:11:32 +0100 Subject: [PATCH 04/17] Fix some tests --- src/contexts.jl | 5 +++ src/model.jl | 35 ++++++++------- test/contexts.jl | 113 ++++++++++++++++------------------------------- 3 files changed, 63 insertions(+), 90 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 7ba79bc7f..d6fbc50a7 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -519,6 +519,11 @@ function conditioned(context::ConditionContext) # precedence over decendants of `context`. return _merge(context.values, conditioned(childcontext(context))) end +function conditioned(context::PrefixContext{Prefix}) where {Prefix} + return conditioned( + prefix_conditioned_variables(childcontext(context), VarName{Prefix}()) + ) +end struct FixedContext{Values,Ctx<:AbstractContext} <: AbstractContext values::Values diff --git a/src/model.jl b/src/model.jl index b4d5f6bb7..4122a1eae 100644 --- a/src/model.jl +++ b/src/model.jl @@ -425,29 +425,34 @@ julia> # Returns all the variables we have conditioned on + their values. conditioned(condition(m, x=100.0, m=1.0)) (x = 100.0, m = 1.0) -julia> # Nested ones also work (note that `PrefixContext` does nothing to the result). +julia> # Nested ones also work. + # (Note that `PrefixContext` also prefixes the variables of any + # ConditionContext that is _inside_ it; because of this, the type of the + # container has to be broadened to a `Dict`.) cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext((m=1.0,)))), x=100.0); julia> conditioned(cm) -(x = 100.0, m = 1.0) +Dict{VarName, Any} with 2 entries: + a.m => 1.0 + x => 100.0 -julia> # Since we conditioned on `m`, not `a.m` as it will appear after prefixed, - # `a.m` is treated as a random variable. +julia> # Since we conditioned on `a.m`, it is not treated as a random variable. + # However, `a.x` will still be a random variable. keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:m}}}: - a.m +1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: + a.x -julia> # If we instead condition on `a.m`, `m` in the model will be considered an observation. - cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext(Dict(@varname(a.m) => 1.0)))), x=100.0); +julia> # We can also condition on `a.m` _outside_ of the PrefixContext: + cm = condition(contextualize(m, PrefixContext{:a}(DefaultContext())), (@varname(a.m) => 1.0)); -julia> conditioned(cm)[@varname(x)] -100.0 - -julia> conditioned(cm)[@varname(a.m)] -1.0 +julia> conditioned(cm) +Dict{VarName{:a, Accessors.PropertyLens{:m}}, Float64} with 1 entry: + a.m => 1.0 -julia> keys(VarInfo(cm)) # No variables are sampled -VarName[] +julia> # Now `a.x` will be sampled. + keys(VarInfo(cm)) +1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: + a.x ``` """ conditioned(model::Model) = conditioned(model.context) diff --git a/test/contexts.jl b/test/contexts.jl index 11e591f8f..ffb720e22 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -1,4 +1,5 @@ using Test, DynamicPPL, Accessors +using AbstractPPL: getoptic using DynamicPPL: leafcontext, setleafcontext, @@ -57,7 +58,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() (x=1.0,), DynamicPPL.TestUtils.TestParentContext(ConditionContext((y=2.0,))) ), :condition3 => ConditionContext( - (x=1.0,), PrefixContext{:a}(ConditionContext(Dict(@varname(a.y) => 2.0))) + (x=1.0,), PrefixContext{:a}(ConditionContext(Dict(@varname(y) => 2.0))) ), :condition4 => ConditionContext((x=[1.0, missing],)), ) @@ -70,91 +71,53 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end end - @testset "contextual_isassumption" begin - @testset "$(name)" for (name, context) in contexts - # Any `context` should return `true` by default. - @test contextual_isassumption(context, VarName{gensym(:x)}()) - - if any(Base.Fix2(isa, ConditionContext), context) - # We have a `ConditionContext` among us. - # Let's first extract the conditioned variables. - conditioned_values = DynamicPPL.conditioned(context) - - # The conditioned values might be a NamedTuple, or a Dict. - # We convert to a Dict for consistency - if conditioned_values isa NamedTuple - conditioned_values = Dict( - VarName{sym}() => val for (sym, val) in pairs(conditioned_values) - ) - end - - for (vn, val) in pairs(conditioned_values) - # We need to drop the prefix of `var` since in `contextual_isassumption` - # it will be threaded through the `PrefixContext` before it reaches - # `ConditionContext` with the conditioned variable. - vn_without_prefix = if getoptic(vn) isa PropertyLens - # Hacky: This assumes that there is exactly one level of prefixing - # that we need to undo. This is appropriate for the :condition3 - # test case above, but is not generally correct. - AbstractPPL.unprefix(vn, VarName{getsym(vn)}()) - else - vn - end - - @show DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) - # Let's check elementwise. - for vn_child in - DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) - if getoptic(vn_child)(val) === missing - @test contextual_isassumption(context, vn_child) - else - @test !contextual_isassumption(context, vn_child) - end - end - end - end - end - end + @testset "extracting conditioned values" begin + # This testset tests `contextual_isassumption`, `getconditioned_nested`, and + # `hasconditioned_nested`. - @testset "getconditioned_nested & hasconditioned_nested" begin - @testset "$name" for (name, context) in contexts + @testset "$(name)" for (name, context) in contexts + # If the varname doesn't exist, it should always be an assumption. fake_vn = VarName{gensym(:x)}() + @test contextual_isassumption(context, fake_vn) @test !hasconditioned_nested(context, fake_vn) @test_throws ErrorException getconditioned_nested(context, fake_vn) if any(Base.Fix2(isa, ConditionContext), context) - # `ConditionContext` specific. - + # We have a `ConditionContext` among us. # Let's first extract the conditioned variables. conditioned_values = DynamicPPL.conditioned(context) + # The conditioned values might be a NamedTuple, or a Dict. # We convert to a Dict for consistency - if conditioned_values isa NamedTuple - conditioned_values = Dict( - VarName{sym}() => val for (sym, val) in pairs(conditioned_values) - ) - end - - for (vn, val) in pairs(conditioned_values) - # We need to drop the prefix of `var` since in `contextual_isassumption` - # it will be threaded through the `PrefixContext` before it reaches - # `ConditionContext` with the conditioned variable. - vn_without_prefix = if getoptic(vn) isa PropertyLens - # Hacky: This assumes that there is exactly one level of prefixing - # that we need to undo. This is appropriate for the :condition3 - # test case above, but is not generally correct. - AbstractPPL.unprefix(vn, VarName{getsym(vn)}()) + conditioned_values = DynamicPPL.to_varname_dict(conditioned_values) + + # Extract all conditioned variables. We also use varname_leaves + # here to split up arrays which could potentially have some, + # but not all, elements being `missing`. + conditioned_vns = mapreduce( + p -> DynamicPPL.TestUtils.varname_leaves(p.first, p.second), + vcat, + pairs(conditioned_values), + ) + @show conditioned_vns + + # We can now loop over them to check which ones are missing. We use + # `getvalue` to handle the awkward case where sometimes + # `conditioned_values` contains the full Varname (e.g. `a.x`) and + # sometimes only the main symbol (e.g. it contains `x` when + # `vn` is `x[1]`) + for vn in conditioned_vns + val = DynamicPPL.getvalue(conditioned_values, vn) + # These VarNames are present in the conditioning values, so + # we should always be able to extract the value. + @test hasconditioned_nested(context, vn) + @test getconditioned_nested(context, vn) === val + # However, the return value of contextual_isassumption depends on + # whether the value is missing or not. + if ismissing(val) + @test contextual_isassumption(context, vn) else - vn - end - - for vn_child in - DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) - # `vn_child` should be in `context`. - @test hasconditioned_nested(context, vn_child) - # Value should be the same as extracted above. - @test getconditioned_nested(context, vn_child) === - getoptic(vn_child)(val) + @test !contextual_isassumption(context, vn) end end end From 30368c797c9475c8d45cc5a4b5c5d82dd7720c12 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 17 Apr 2025 17:58:41 +0100 Subject: [PATCH 05/17] Add tests; fix a bunch of nested submodel issues --- src/context_implementations.jl | 18 ++++- src/contexts.jl | 55 ++++++++++++++-- test/runtests.jl | 1 + test/submodels.jl | 117 +++++++++++++++++++++++++++++++++ 4 files changed, 183 insertions(+), 8 deletions(-) create mode 100644 test/submodels.jl diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 2e7b85162..7b17b9e8c 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -57,6 +57,7 @@ function tilde_assume(context::AbstractContext, args...) return tilde_assume(NodeTrait(tilde_assume, context), context, args...) end function tilde_assume(::IsLeaf, context::AbstractContext, right, vn, vi) + @show "isleaf", vn return assume(right, vn, vi) end function tilde_assume(::IsParent, context::AbstractContext, args...) @@ -85,12 +86,25 @@ function tilde_assume(rng::Random.AbstractRNG, ::LikelihoodContext, sampler, rig end function tilde_assume(context::PrefixContext, right, vn, vi) - return tilde_assume(context.context, right, prefix(context, vn), vi) + # The slightly tricky thing about PrefixContext is that they are applied + # from the outside in, so `PrefixContext{:a}(PrefixContext{:b}(ctx))` means + # that variables get prefixed like `a.b.x`. + # This motivates the implementation shown here, where the function + # `prefix_and_strip_contexts` is responsible for not only adding the + # prefixes, but also removing the `PrefixContext`s from the context stack + # so that they don't get applied twice when recursing. + # TODO(penelopeysm): It would be nice to switch this round, but it's a very + # tricky task. Essentially it forces us to use a foldr inside + # `prefix_and_strip_contexts`, rather than a foldl which is what most of + # DynamicPPL uses. + new_vn, new_context = prefix_and_strip_contexts(context, vn) + return tilde_assume(new_context, right, new_vn, vi) end function tilde_assume( rng::Random.AbstractRNG, context::PrefixContext, sampler, right, vn, vi ) - return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), vi) + new_vn, new_context = prefix_and_strip_contexts(context, vn) + return tilde_assume(rng, new_context, sampler, right, new_vn, vi) end """ diff --git a/src/contexts.jl b/src/contexts.jl index d6fbc50a7..299a93199 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -276,6 +276,30 @@ function prefix(::IsParent, ctx::AbstractContext, vn::VarName) return prefix(childcontext(ctx), vn) end +""" + prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName) + +Same as `prefix`, but additionally returns a new context stack that has all the +PrefixContexts removed. +""" +function prefix_and_strip_contexts(ctx::PrefixContext{Prefix}, vn::VarName) where {Prefix} + child_context = childcontext(ctx) + # vn_prefixed contains the prefixes from all lower levels + vn_prefixed, child_context_without_prefixes = prefix_and_strip_contexts( + child_context, vn + ) + return AbstractPPL.prefix(vn_prefixed, VarName{Prefix}()), + child_context_without_prefixes +end +function prefix_and_strip_contexts(ctx::AbstractContext, vn::VarName) + return prefix_and_strip_contexts(NodeTrait(ctx), ctx, vn) +end +prefix_and_strip_contexts(::IsLeaf, ctx::AbstractContext, vn::VarName) = (vn, ctx) +function prefix_and_strip_contexts(::IsParent, ctx::AbstractContext, vn::VarName) + vn, new_ctx = prefix_and_strip_contexts(childcontext(ctx), vn) + return vn, setchildcontext(ctx, new_ctx) +end + """ prefix(model::Model, x) @@ -351,6 +375,29 @@ NodeTrait(::ConditionContext) = IsParent() childcontext(context::ConditionContext) = context.context setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.values, child) +""" + collapse_prefix_and_condition(context::AbstractContext) + +Apply `PrefixContext`s to any conditioned values inside them, and remove +the `PrefixContext`s from the context stack. + +```jldoctest +julia> using DynamicPPL: collapse_prefix_and_condition + +julia> c1 = PrefixContext({:a}(ConditionContext((x=1, ))) +``` +""" +function collapse_prefix_and_condition(context::PrefixContext{Prefix}) where {Prefix} + # Collapse the child context (thus applying any inner prefixes first) + collapsed = collapse_prefix_and_condition(childcontext(context)) + # Prefix any conditioned variables with the current prefix + # Note: prefix_conditioned_variables is O(N) in the depth of the context stack. + # So is this function. In the worst case scenario, this is O(N^2) in the + # depth of the context stack. + return prefix_conditioned_variables(collapsed, VarName{Prefix}()) +end +collapse_prefix_and_condition(context::AbstractContext) = context + """ prefix_conditioned_variables(context::AbstractContext, prefix::VarName) @@ -427,9 +474,7 @@ function hasconditioned_nested(::IsParent, context, vn) return hasconditioned(context, vn) || hasconditioned_nested(childcontext(context), vn) end function hasconditioned_nested(context::PrefixContext{Prefix}, vn) where {Prefix} - return hasconditioned_nested( - prefix_conditioned_variables(childcontext(context), VarName{Prefix}()), vn - ) + return hasconditioned_nested(collapse_prefix_and_condition(context), vn) end """ @@ -447,9 +492,7 @@ function getconditioned_nested(::IsLeaf, context, vn) return error("context $(context) does not contain value for $vn") end function getconditioned_nested(context::PrefixContext{Prefix}, vn) where {Prefix} - return getconditioned_nested( - prefix_conditioned_variables(childcontext(context), VarName{Prefix}()), vn - ) + return getconditioned_nested(collapse_prefix_and_condition(context), vn) end function getconditioned_nested(::IsParent, context, vn) return if hasconditioned(context, vn) diff --git a/test/runtests.jl b/test/runtests.jl index 3473d5594..72f33f2d0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -67,6 +67,7 @@ include("test_util.jl") include("threadsafe.jl") include("debug_utils.jl") include("deprecated.jl") + include("submodels.jl") end if GROUP == "All" || GROUP == "Group2" diff --git a/test/submodels.jl b/test/submodels.jl new file mode 100644 index 000000000..721284fb8 --- /dev/null +++ b/test/submodels.jl @@ -0,0 +1,117 @@ +module DPPLSubmodelTests + +using DynamicPPL +using Distributions +using Test + +@testset "submodels.jl" begin + @testset "Conditioning variables" begin + @testset "Auto prefix" begin + @model function inner() + x ~ Normal() + return y ~ Normal() + end + @model function outer() + return a ~ to_submodel(inner()) + end + inner_cond = inner() | (@varname(x) => 1.0) + with_outer_cond = outer() | (@varname(a.x) => 1.0) + + # No conditioning + @test Set(keys(VarInfo(outer()))) == Set([@varname(a.x), @varname(a.y)]) + # Conditioning from the outside + @test Set(keys(VarInfo(with_outer_cond))) == Set([@varname(a.y)]) + # Conditioning from the inside + @model function outer2() + return a ~ to_submodel(inner_cond) + end + with_inner_cond = outer2() + @test Set(keys(VarInfo(with_inner_cond))) == Set([@varname(a.y)]) + end + + @testset "No prefix" begin + @model function inner() + x ~ Normal() + return y ~ Normal() + end + @model function outer() + return a ~ to_submodel(inner(), false) + end + inner_cond = inner() | (@varname(x) => 1.0) + with_outer_cond = outer() | (@varname(x) => 1.0) + + # No conditioning + @test Set(keys(VarInfo(outer()))) == Set([@varname(x), @varname(y)]) + # Conditioning from the outside + @test Set(keys(VarInfo(with_outer_cond))) == Set([@varname(y)]) + # Conditioning from the inside + @model function outer2() + return a ~ to_submodel(inner_cond, false) + end + with_inner_cond = outer2() + @test Set(keys(VarInfo(with_inner_cond))) == Set([@varname(y)]) + end + + @testset "Manual prefix" begin + @model function inner() + x ~ Normal() + return y ~ Normal() + end + @model function outer() + return a ~ to_submodel(prefix(inner(), :b), false) + end + inner_cond = inner() | (@varname(x) => 1.0) + with_outer_cond = outer() | (@varname(b.x) => 1.0) + + # No conditioning + @test Set(keys(VarInfo(outer()))) == Set([@varname(b.x), @varname(b.y)]) + # Conditioning from the outside + @test Set(keys(VarInfo(with_outer_cond))) == Set([@varname(b.y)]) + # Conditioning from the inside + @model function outer2() + return a ~ to_submodel(prefix(inner_cond, :b), false) + end + with_inner_cond = outer2() + @test Set(keys(VarInfo(with_inner_cond))) == Set([@varname(b.y)]) + end + + @testset "Nested submodels" begin + @model function f() + x ~ Normal() + return y ~ Normal() + end + @model function g() + return _unused ~ to_submodel(prefix(f(), :b), false) + end + @model function h() + return a ~ to_submodel(g()) + end + + # No conditioning + @test Set(keys(VarInfo(h()))) == Set([@varname(a.b.x), @varname(a.b.y)]) + + # Conditioning at the top level + condition_h = h() | (@varname(a.b.x) => 1.0) + @test Set(keys(VarInfo(condition_h))) == Set([@varname(a.b.y)]) + + # Conditioning at the second level + condition_g = g() | (@varname(b.x) => 1.0) + @model function h2() + return a ~ to_submodel(condition_g) + end + @test Set(keys(VarInfo(h2()))) == Set([@varname(a.b.y)]) + + # Conditioning at the very bottom + condition_f = f() | (@varname(x) => 1.0) + @model function g2() + return _unused ~ to_submodel(prefix(condition_f, :b), false) + end + @model function h3() + return a ~ to_submodel(g2()) + end + @test Set(keys(VarInfo(h3()))) == Set([@varname(a.b.y)]) + end + end +end + +end From 32c6a37f8a19f4bf0adcada2b9702426b142b1e9 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 17 Apr 2025 18:35:31 +0100 Subject: [PATCH 06/17] Fix fix as well --- src/compiler.jl | 10 ++- src/context_implementations.jl | 1 - src/contexts.jl | 149 ++++++++++++++++++--------------- test/submodels.jl | 113 +++++++++++++++---------- 4 files changed, 160 insertions(+), 113 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index eb71404a4..6f7489b8e 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -105,7 +105,11 @@ function contextual_isassumption(context::AbstractContext, vn) end isfixed(expr, vn) = false -isfixed(::Union{Symbol,Expr}, vn) = :($(DynamicPPL.contextual_isfixed)(__context__, $vn)) +function isfixed(::Union{Symbol,Expr}, vn) + return :($(DynamicPPL.contextual_isfixed)( + __context__, $(DynamicPPL.prefix)(__context__, $vn) + )) +end """ contextual_isfixed(context, vn) @@ -443,7 +447,9 @@ function generate_tilde(left, right) ) $isassumption = $(DynamicPPL.isassumption(left, vn)) if $(DynamicPPL.isfixed(left, vn)) - $left = $(DynamicPPL.getfixed_nested)(__context__, $vn) + $left = $(DynamicPPL.getfixed_nested)( + __context__, $(DynamicPPL.prefix)(__context__, $vn) + ) elseif $isassumption $(generate_tilde_assume(left, dist, vn)) else diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 7b17b9e8c..8cebd5f81 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -57,7 +57,6 @@ function tilde_assume(context::AbstractContext, args...) return tilde_assume(NodeTrait(tilde_assume, context), context, args...) end function tilde_assume(::IsLeaf, context::AbstractContext, right, vn, vi) - @show "isleaf", vn return assume(right, vn, vi) end function tilde_assume(::IsParent, context::AbstractContext, args...) diff --git a/src/contexts.jl b/src/contexts.jl index 299a93199..36e28ff72 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -375,66 +375,6 @@ NodeTrait(::ConditionContext) = IsParent() childcontext(context::ConditionContext) = context.context setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.values, child) -""" - collapse_prefix_and_condition(context::AbstractContext) - -Apply `PrefixContext`s to any conditioned values inside them, and remove -the `PrefixContext`s from the context stack. - -```jldoctest -julia> using DynamicPPL: collapse_prefix_and_condition - -julia> c1 = PrefixContext({:a}(ConditionContext((x=1, ))) -``` -""" -function collapse_prefix_and_condition(context::PrefixContext{Prefix}) where {Prefix} - # Collapse the child context (thus applying any inner prefixes first) - collapsed = collapse_prefix_and_condition(childcontext(context)) - # Prefix any conditioned variables with the current prefix - # Note: prefix_conditioned_variables is O(N) in the depth of the context stack. - # So is this function. In the worst case scenario, this is O(N^2) in the - # depth of the context stack. - return prefix_conditioned_variables(collapsed, VarName{Prefix}()) -end -collapse_prefix_and_condition(context::AbstractContext) = context - -""" - prefix_conditioned_variables(context::AbstractContext, prefix::VarName) - -Prefix all the conditioned variables in a given context with `prefix`. - -```jldoctest -julia> using DynamicPPL: prefix_conditioned_variables, ConditionContext - -julia> c1 = ConditionContext((a=1, )) -ConditionContext((a = 1,), DefaultContext()) - -julia> prefix_conditioned_variables(c1, @varname(y)) -ConditionContext(Dict(y.a => 1), DefaultContext()) -``` -""" -function prefix_conditioned_variables(ctx::ConditionContext, prefix::VarName) - # Replace the prefix of the conditioned variables - vn_dict = to_varname_dict(ctx.values) - prefixed_vn_dict = Dict( - AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict - ) - # Prefix the child context as well - prefixed_child_ctx = prefix_conditioned_variables(childcontext(ctx), prefix) - return ConditionContext(prefixed_vn_dict, prefixed_child_ctx) -end -function prefix_conditioned_variables(c::AbstractContext, prefix::VarName) - return prefix_conditioned_variables( - NodeTrait(prefix_conditioned_variables, c), c, prefix - ) -end -prefix_conditioned_variables(::IsLeaf, context::AbstractContext, prefix::VarName) = context -function prefix_conditioned_variables(::IsParent, context::AbstractContext, prefix::VarName) - return setchildcontext( - context, prefix_conditioned_variables(childcontext(context), prefix) - ) -end - """ hasconditioned(context::AbstractContext, vn::VarName) @@ -474,7 +414,7 @@ function hasconditioned_nested(::IsParent, context, vn) return hasconditioned(context, vn) || hasconditioned_nested(childcontext(context), vn) end function hasconditioned_nested(context::PrefixContext{Prefix}, vn) where {Prefix} - return hasconditioned_nested(collapse_prefix_and_condition(context), vn) + return hasconditioned_nested(collapse_prefix_stack(context), vn) end """ @@ -492,7 +432,7 @@ function getconditioned_nested(::IsLeaf, context, vn) return error("context $(context) does not contain value for $vn") end function getconditioned_nested(context::PrefixContext{Prefix}, vn) where {Prefix} - return getconditioned_nested(collapse_prefix_and_condition(context), vn) + return getconditioned_nested(collapse_prefix_stack(context), vn) end function getconditioned_nested(::IsParent, context, vn) return if hasconditioned(context, vn) @@ -563,9 +503,7 @@ function conditioned(context::ConditionContext) return _merge(context.values, conditioned(childcontext(context))) end function conditioned(context::PrefixContext{Prefix}) where {Prefix} - return conditioned( - prefix_conditioned_variables(childcontext(context), VarName{Prefix}()) - ) + return conditioned(collapse_prefix_stack(context)) end struct FixedContext{Values,Ctx<:AbstractContext} <: AbstractContext @@ -630,7 +568,7 @@ function hasfixed_nested(::IsParent, context, vn) return hasfixed(context, vn) || hasfixed_nested(childcontext(context), vn) end function hasfixed_nested(context::PrefixContext, vn) - return hasfixed_nested(childcontext(context), prefix(context, vn)) + return hasfixed_nested(collapse_prefix_stack(context), vn) end """ @@ -648,7 +586,7 @@ function getfixed_nested(::IsLeaf, context, vn) return error("context $(context) does not contain value for $vn") end function getfixed_nested(context::PrefixContext, vn) - return getfixed_nested(childcontext(context), prefix(context, vn)) + return getfixed_nested(collapse_prefix_stack(context), vn) end function getfixed_nested(::IsParent, context, vn) return if hasfixed(context, vn) @@ -743,3 +681,80 @@ function fixed(context::FixedContext) # precedence over decendants of `context`. return _merge(context.values, fixed(childcontext(context))) end + +""" + collapse_prefix_stack(context::AbstractContext) + +Apply `PrefixContext`s to any conditioned or fixed values inside them, and remove +the `PrefixContext`s from the context stack. + +```jldoctest +julia> using DynamicPPL: collapse_prefix_stack + +julia> c1 = PrefixContext({:a}(ConditionContext((x=1, ))) +``` +""" +function collapse_prefix_stack(context::PrefixContext{Prefix}) where {Prefix} + # Collapse the child context (thus applying any inner prefixes first) + collapsed = collapse_prefix_stack(childcontext(context)) + # Prefix any conditioned variables with the current prefix + # Note: prefix_conditioned_variables is O(N) in the depth of the context stack. + # So is this function. In the worst case scenario, this is O(N^2) in the + # depth of the context stack. + return prefix_cond_and_fixed_variables(collapsed, VarName{Prefix}()) +end +collapse_prefix_stack(context::AbstractContext) = context + +""" + prefix_cond_and_fixed_variables(context::AbstractContext, prefix::VarName) + +Prefix all the conditioned and fixed variables in a given context with a single +`prefix`. + +```jldoctest +julia> using DynamicPPL: prefix_cond_and_fixed_variables, ConditionContext + +julia> c1 = ConditionContext((a=1, )) +ConditionContext((a = 1,), DefaultContext()) + +julia> prefix_cond_and_fixed_variables(c1, @varname(y)) +ConditionContext(Dict(y.a => 1), DefaultContext()) +``` +""" +function prefix_cond_and_fixed_variables(ctx::ConditionContext, prefix::VarName) + # Replace the prefix of the conditioned variables + vn_dict = to_varname_dict(ctx.values) + prefixed_vn_dict = Dict( + AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict + ) + # Prefix the child context as well + prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) + return ConditionContext(prefixed_vn_dict, prefixed_child_ctx) +end +function prefix_cond_and_fixed_variables(ctx::FixedContext, prefix::VarName) + # Replace the prefix of the conditioned variables + vn_dict = to_varname_dict(ctx.values) + prefixed_vn_dict = Dict( + AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict + ) + # Prefix the child context as well + prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) + return FixedContext(prefixed_vn_dict, prefixed_child_ctx) +end +function prefix_cond_and_fixed_variables(c::AbstractContext, prefix::VarName) + return prefix_cond_and_fixed_variables( + NodeTrait(prefix_cond_and_fixed_variables, c), c, prefix + ) +end +function prefix_cond_and_fixed_variables( + ::IsLeaf, context::AbstractContext, prefix::VarName +) + return context +end +function prefix_cond_and_fixed_variables( + ::IsParent, context::AbstractContext, prefix::VarName +) + return setchildcontext( + context, prefix_cond_and_fixed_variables(childcontext(context), prefix) + ) +end diff --git a/test/submodels.jl b/test/submodels.jl index 721284fb8..da91b8b93 100644 --- a/test/submodels.jl +++ b/test/submodels.jl @@ -5,74 +5,101 @@ using Distributions using Test @testset "submodels.jl" begin - @testset "Conditioning variables" begin + @testset "$op" for op in [condition, fix] + x_val = 1.0 + x_logp = op == condition ? logpdf(Normal(), x_val) : 0.0 + @testset "Auto prefix" begin @model function inner() x ~ Normal() - return y ~ Normal() + y ~ Normal() + return (x, y) end @model function outer() return a ~ to_submodel(inner()) end - inner_cond = inner() | (@varname(x) => 1.0) - with_outer_cond = outer() | (@varname(a.x) => 1.0) + inner_op = op(inner(), (@varname(x) => x_val)) + @model function outer2() + return a ~ to_submodel(inner_op) + end + with_inner_op = outer2() + with_outer_op = op(outer(), (@varname(a.x) => x_val)) - # No conditioning + # No conditioning/fixing @test Set(keys(VarInfo(outer()))) == Set([@varname(a.x), @varname(a.y)]) - # Conditioning from the outside - @test Set(keys(VarInfo(with_outer_cond))) == Set([@varname(a.y)]) - # Conditioning from the inside - @model function outer2() - return a ~ to_submodel(inner_cond) + + # With conditioning/fixing + @testset "$model" for model in [with_inner_op, with_outer_op] + # Test that the value was correctly set + @test model()[1] == x_val + # Test that the logp was correctly set + vi = VarInfo(model) + @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(a.y)]) + # Check the keys + @test Set(keys(VarInfo(model))) == Set([@varname(a.y)]) end - with_inner_cond = outer2() - @test Set(keys(VarInfo(with_inner_cond))) == Set([@varname(a.y)]) end @testset "No prefix" begin @model function inner() x ~ Normal() - return y ~ Normal() + y ~ Normal() + return x end @model function outer() return a ~ to_submodel(inner(), false) end - inner_cond = inner() | (@varname(x) => 1.0) - with_outer_cond = outer() | (@varname(x) => 1.0) + @model function outer2() + return a ~ to_submodel(inner_op, false) + end + with_inner_op = outer2() + inner_op = op(inner(), (@varname(x) => x_val)) + with_outer_op = op(outer(), (@varname(x) => x_val)) - # No conditioning + # No conditioning/fixing @test Set(keys(VarInfo(outer()))) == Set([@varname(x), @varname(y)]) - # Conditioning from the outside - @test Set(keys(VarInfo(with_outer_cond))) == Set([@varname(y)]) - # Conditioning from the inside - @model function outer2() - return a ~ to_submodel(inner_cond, false) + + # With conditioning/fixing + @testset "$model" for model in [with_inner_op, with_outer_op] + # Test that the value was correctly set + @test model()[1] == x_val + # Test that the logp was correctly set + vi = VarInfo(model) + @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(y)]) + # Check the keys + @test Set(keys(VarInfo(model))) == Set([@varname(y)]) end - with_inner_cond = outer2() - @test Set(keys(VarInfo(with_inner_cond))) == Set([@varname(y)]) end @testset "Manual prefix" begin @model function inner() x ~ Normal() - return y ~ Normal() + y ~ Normal() + return x end @model function outer() return a ~ to_submodel(prefix(inner(), :b), false) end - inner_cond = inner() | (@varname(x) => 1.0) - with_outer_cond = outer() | (@varname(b.x) => 1.0) + inner_op = op(inner(), (@varname(x) => x_val)) + @model function outer2() + return a ~ to_submodel(prefix(inner_op, :b), false) + end + with_inner_op = outer2() + with_outer_op = op(outer(), (@varname(b.x) => x_val)) - # No conditioning + # No conditioning/fixing @test Set(keys(VarInfo(outer()))) == Set([@varname(b.x), @varname(b.y)]) - # Conditioning from the outside - @test Set(keys(VarInfo(with_outer_cond))) == Set([@varname(b.y)]) - # Conditioning from the inside - @model function outer2() - return a ~ to_submodel(prefix(inner_cond, :b), false) + + # With conditioning/fixing + @testset "$model" for model in [with_inner_op, with_outer_op] + # Test that the value was correctly set + @test model()[1] == x_val + # Test that the logp was correctly set + vi = VarInfo(model) + @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(b.y)]) + # Check the keys + @test Set(keys(VarInfo(model))) == Set([@varname(b.y)]) end - with_inner_cond = outer2() - @test Set(keys(VarInfo(with_inner_cond))) == Set([@varname(b.y)]) end @testset "Nested submodels" begin @@ -90,21 +117,21 @@ using Test # No conditioning @test Set(keys(VarInfo(h()))) == Set([@varname(a.b.x), @varname(a.b.y)]) - # Conditioning at the top level - condition_h = h() | (@varname(a.b.x) => 1.0) - @test Set(keys(VarInfo(condition_h))) == Set([@varname(a.b.y)]) + # Conditioning/fixing at the top level + op_h = op(h(), (@varname(a.b.x) => x_val)) + @test Set(keys(VarInfo(op_h))) == Set([@varname(a.b.y)]) - # Conditioning at the second level - condition_g = g() | (@varname(b.x) => 1.0) + # Conditioning/fixing at the second level + op_g = op(g(), (@varname(b.x) => x_val)) @model function h2() - return a ~ to_submodel(condition_g) + return a ~ to_submodel(op_g) end @test Set(keys(VarInfo(h2()))) == Set([@varname(a.b.y)]) - # Conditioning at the very bottom - condition_f = f() | (@varname(x) => 1.0) + # Conditioning/fixing at the very bottom + op_f = op(f(), (@varname(x) => x_val)) @model function g2() - return _unused ~ to_submodel(prefix(condition_f, :b), false) + return _unused ~ to_submodel(prefix(op_f, :b), false) end @model function h3() return a ~ to_submodel(g2()) From 523f411e28b141abafe594c73e745064b416e0c3 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 17 Apr 2025 19:35:41 +0100 Subject: [PATCH 07/17] Fix doctests --- src/contexts.jl | 31 ++++++++++++++++++++++++++++--- src/model.jl | 38 +++++++++++++++++--------------------- 2 files changed, 45 insertions(+), 24 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 36e28ff72..a044e85cb 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -502,7 +502,7 @@ function conditioned(context::ConditionContext) # precedence over decendants of `context`. return _merge(context.values, conditioned(childcontext(context))) end -function conditioned(context::PrefixContext{Prefix}) where {Prefix} +function conditioned(context::PrefixContext) return conditioned(collapse_prefix_stack(context)) end @@ -681,6 +681,9 @@ function fixed(context::FixedContext) # precedence over decendants of `context`. return _merge(context.values, fixed(childcontext(context))) end +function fixed(context::PrefixContext) + return fixed(collapse_prefix_stack(context)) +end """ collapse_prefix_stack(context::AbstractContext) @@ -691,7 +694,22 @@ the `PrefixContext`s from the context stack. ```jldoctest julia> using DynamicPPL: collapse_prefix_stack -julia> c1 = PrefixContext({:a}(ConditionContext((x=1, ))) +julia> c1 = PrefixContext{:a}(ConditionContext((x=1, ))); + +julia> collapse_prefix_stack(c1) +ConditionContext(Dict(a.x => 1), DefaultContext()) + +julia> # Here, `x` gets prefixed only with `a`, whereas `y` is prefixed with both. + c2 = PrefixContext{:a}(ConditionContext((x=1, ), PrefixContext{:b}(ConditionContext((y=2,))))); + +julia> collapsed = collapse_prefix_stack(c2); + +julia> # `collapsed` really looks something like this: + # ConditionContext(Dict{VarName{:a}, Int64}(a.b.y => 2, a.x => 1), DefaultContext()) + # To avoid fragility arising from the order of the keys in the doctest, we test + # this indirectly: + collapsed.values[@varname(a.x)], collapsed.values[@varname(a.b.y)] +(1, 2) ``` """ function collapse_prefix_stack(context::PrefixContext{Prefix}) where {Prefix} @@ -703,7 +721,14 @@ function collapse_prefix_stack(context::PrefixContext{Prefix}) where {Prefix} # depth of the context stack. return prefix_cond_and_fixed_variables(collapsed, VarName{Prefix}()) end -collapse_prefix_stack(context::AbstractContext) = context +function collapse_prefix_stack(context::AbstractContext) + return collapse_prefix_stack(NodeTrait(collapse_prefix_stack, context), context) +end +collapse_prefix_stack(::IsLeaf, context) = context +function collapse_prefix_stack(::IsParent, context) + new_child_context = collapse_prefix_stack(childcontext(context)) + return setchildcontext(context, new_child_context) +end """ prefix_cond_and_fixed_variables(context::AbstractContext, prefix::VarName) diff --git a/src/model.jl b/src/model.jl index 4122a1eae..cfe87ad44 100644 --- a/src/model.jl +++ b/src/model.jl @@ -431,10 +431,8 @@ julia> # Nested ones also work. # container has to be broadened to a `Dict`.) cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext((m=1.0,)))), x=100.0); -julia> conditioned(cm) -Dict{VarName, Any} with 2 entries: - a.m => 1.0 - x => 100.0 +julia> Set(keys(conditioned(cm))) == Set([@varname(a.m), @varname(x)]) +true julia> # Since we conditioned on `a.m`, it is not treated as a random variable. # However, `a.x` will still be a random variable. @@ -770,29 +768,27 @@ julia> # Returns all the variables we have fixed on + their values. fixed(fix(m, x=100.0, m=1.0)) (x = 100.0, m = 1.0) -julia> # Nested ones also work (note that `PrefixContext` does nothing to the result). +julia> # The rest of this is the same as the `condition` example above. cm = fix(contextualize(m, PrefixContext{:a}(fix(m=1.0))), x=100.0); -julia> fixed(cm) -(x = 100.0, m = 1.0) - -julia> # Since we fixed on `m`, not `a.m` as it will appear after prefixed, - # `a.m` is treated as a random variable. - keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:m}}}: - a.m +julia> Set(keys(fixed(cm))) == Set([@varname(a.m), @varname(x)]) +true -julia> # If we instead fix on `a.m`, `m` in the model will be considered an observation. - cm = fix(contextualize(m, PrefixContext{:a}(fix(@varname(a.m) => 1.0,))), x=100.0); +julia> keys(VarInfo(cm)) +1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: + a.x -julia> fixed(cm)[@varname(x)] -100.0 +julia> # We can also condition on `a.m` _outside_ of the PrefixContext: + cm = fix(contextualize(m, PrefixContext{:a}(DefaultContext())), (@varname(a.m) => 1.0)); -julia> fixed(cm)[@varname(a.m)] -1.0 +julia> fixed(cm) +Dict{VarName{:a, Accessors.PropertyLens{:m}}, Float64} with 1 entry: + a.m => 1.0 -julia> keys(VarInfo(cm)) # <= no variables are sampled -VarName[] +julia> # Now `a.x` will be sampled. + keys(VarInfo(cm)) +1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: + a.x ``` """ fixed(model::Model) = fixed(model.context) From 31a65a2d886b4264e1e9518fdbaa1e1658e294a3 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 17 Apr 2025 22:06:54 +0100 Subject: [PATCH 08/17] Add unit tests for new functions --- test/contexts.jl | 126 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 125 insertions(+), 1 deletion(-) diff --git a/test/contexts.jl b/test/contexts.jl index ffb720e22..081e59775 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -11,12 +11,18 @@ using DynamicPPL: IsParent, PointwiseLogdensityContext, contextual_isassumption, + FixedContext, ConditionContext, decondition_context, hasconditioned, getconditioned, + conditioned, + fixed, hasconditioned_nested, - getconditioned_nested + getconditioned_nested, + collapse_prefix_stack, + prefix_cond_and_fixed_variables, + getvalue using EnzymeCore @@ -156,6 +162,29 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test DynamicPPL.prefix(ctx4, vn) == @varname(b.a.x[1]) end + @testset "prefix_and_strip_contexts" begin + vn = @varname(x[1]) + ctx1 = PrefixContext{:a}(DefaultContext()) + new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx1, vn) + @test new_vn == @varname(a.x[1]) + @test new_ctx == DefaultContext() + + ctx2 = SamplingContext(PrefixContext{:a}(DefaultContext())) + new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx2, vn) + @test new_vn == @varname(a.x[1]) + @test new_ctx == SamplingContext() + + ctx3 = PrefixContext{:a}(ConditionContext((a=1,))) + new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx3, vn) + @test new_vn == @varname(a.x[1]) + @test new_ctx == ConditionContext((a=1,)) + + ctx4 = SamplingContext(PrefixContext{:a}(ConditionContext((a=1,)))) + new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx4, vn) + @test new_vn == @varname(a.x[1]) + @test new_ctx == SamplingContext(ConditionContext((a=1,))) + end + @testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS prefix = :my_prefix context = DynamicPPL.PrefixContext{prefix}(SamplingContext()) @@ -306,4 +335,99 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m)) end end + + @testset "PrefixContext + Condition/FixedContext interactions" begin + @testset "prefix_cond_and_fixed_variables" begin + c1 = ConditionContext((c=1, d=2)) + c1_prefixed = prefix_cond_and_fixed_variables(c1, @varname(a)) + @test c1_prefixed isa ConditionContext + @test childcontext(c1_prefixed) isa DefaultContext + @test c1_prefixed.values[@varname(a.c)] == 1 + @test c1_prefixed.values[@varname(a.d)] == 2 + + c2 = FixedContext((f=1, g=2)) + c2_prefixed = prefix_cond_and_fixed_variables(c2, @varname(a)) + @test c2_prefixed isa FixedContext + @test childcontext(c2_prefixed) isa DefaultContext + @test c2_prefixed.values[@varname(a.f)] == 1 + @test c2_prefixed.values[@varname(a.g)] == 2 + + c3 = ConditionContext((c=1, d=2), FixedContext((f=1, g=2))) + c3_prefixed = prefix_cond_and_fixed_variables(c3, @varname(a)) + c3_prefixed_child = childcontext(c3_prefixed) + @test c3_prefixed isa ConditionContext + @test c3_prefixed.values[@varname(a.c)] == 1 + @test c3_prefixed.values[@varname(a.d)] == 2 + @test c3_prefixed_child isa FixedContext + @test c3_prefixed_child.values[@varname(a.f)] == 1 + @test c3_prefixed_child.values[@varname(a.g)] == 2 + @test childcontext(c3_prefixed_child) isa DefaultContext + end + + @testset "collapse_prefix_stack" begin + # Utility function to make sure that there are no PrefixContexts in + # the context stack. + function has_no_prefixcontexts(ctx::AbstractContext) + return !(ctx isa PrefixContext) && ( + NodeTrait(ctx) isa IsLeaf || has_no_prefixcontexts(childcontext(ctx)) + ) + end + + # Prefix -> Condition + c1 = PrefixContext{:a}(ConditionContext((c=1, d=2))) + c1 = collapse_prefix_stack(c1) + @test has_no_prefixcontexts(c1) + c1_vals = conditioned(c1) + @test length(c1_vals) == 2 + @test getvalue(c1_vals, @varname(a.c)) == 1 + @test getvalue(c1_vals, @varname(a.d)) == 2 + + # Condition -> Prefix + c2 = (ConditionContext((c=1, d=2), PrefixContext{:a}(DefaultContext()))) + c2 = collapse_prefix_stack(c2) + @test has_no_prefixcontexts(c2) + c2_vals = conditioned(c2) + @test length(c2_vals) == 2 + @test getvalue(c2_vals, @varname(c)) == 1 + @test getvalue(c2_vals, @varname(d)) == 2 + + # Prefix -> Fixed + c3 = PrefixContext{:a}(FixedContext((f=1, g=2))) + c3 = collapse_prefix_stack(c3) + c3_vals = fixed(c3) + @test length(c3_vals) == 2 + @test length(c3_vals) == 2 + @test getvalue(c3_vals, @varname(a.f)) == 1 + @test getvalue(c3_vals, @varname(a.g)) == 2 + + # Fixed -> Prefix + c4 = (FixedContext((f=1, g=2), PrefixContext{:a}(DefaultContext()))) + c4 = collapse_prefix_stack(c4) + @test has_no_prefixcontexts(c4) + c4_vals = fixed(c4) + @test length(c4_vals) == 2 + @test getvalue(c4_vals, @varname(f)) == 1 + @test getvalue(c4_vals, @varname(g)) == 2 + + # Prefix -> Condition -> Prefix -> Condition + c5 = PrefixContext{:a}( + ConditionContext((c=1,), PrefixContext{:b}(ConditionContext((d=2,)))) + ) + c5 = collapse_prefix_stack(c5) + @test has_no_prefixcontexts(c5) + c5_vals = conditioned(c5) + @test length(c5_vals) == 2 + @test getvalue(c5_vals, @varname(a.c)) == 1 + @test getvalue(c5_vals, @varname(a.b.d)) == 2 + + # Prefix -> Condition -> Prefix -> Fixed + c6 = PrefixContext{:a}( + ConditionContext((c=1,), PrefixContext{:b}(FixedContext((d=2,)))) + ) + c6 = collapse_prefix_stack(c6) + @test has_no_prefixcontexts(c6) + @test conditioned(c6) == Dict(@varname(a.c) => 1) + @test fixed(c6) == Dict(@varname(a.b.d) => 2) + end + end end From 709389ce2aaaac62da3178d356a5e8a316fc2a81 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 17 Apr 2025 22:15:38 +0100 Subject: [PATCH 09/17] Add changelog entry --- HISTORY.md | 19 +++++++++++++++++++ src/contexts.jl | 8 ++++++++ 2 files changed, 27 insertions(+) diff --git a/HISTORY.md b/HISTORY.md index a45644a64..70561d684 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -4,6 +4,25 @@ **Breaking changes** +### Submodels + +Variables in a submodel can now be conditioned and fixed in a correct way. +See https://github.com/TuringLang/DynamicPPL.jl/issues/857 for a full illustration, but essentially it means you can now do this: + +```julia +@model function inner() + x ~ Normal() + return y ~ Normal() +end +inner_conditioned = inner() | (x=1.0,) +@model function outer() + return a ~ to_submodel(inner_conditioned) +end +``` + +and the `inner.x` variable will be correctly conditioned. +(Previously, you would have to condition `inner()` with the variable `a.x`, meaning that you would need to know what prefix to use before you had actually prefixed it.) + ### AD testing utilities `DynamicPPL.TestUtils.AD.run_ad` now links the VarInfo by default. diff --git a/src/contexts.jl b/src/contexts.jl index a044e85cb..c71a7467e 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -691,6 +691,14 @@ end Apply `PrefixContext`s to any conditioned or fixed values inside them, and remove the `PrefixContext`s from the context stack. +!!! note + If you are reading this docstring, you might probably be interested in a more +thorough explanation of how PrefixContext and ConditionContext / FixedContext +interact with one another, especially in the context of submodels. + The DynamicPPL documentation contains [a separate page on this +topic](https://turinglang.org/DynamicPPL.jl/previews/PR892/internals/submodel_condition/) +which explains this in much more detail. + ```jldoctest julia> using DynamicPPL: collapse_prefix_stack From a2c460e0b0f7f7fab198a062929702d5842f6be7 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Thu, 17 Apr 2025 23:27:14 +0100 Subject: [PATCH 10/17] Update changelog Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- HISTORY.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 70561d684..17b0b2611 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -14,9 +14,8 @@ See https://github.com/TuringLang/DynamicPPL.jl/issues/857 for a full illustrati x ~ Normal() return y ~ Normal() end -inner_conditioned = inner() | (x=1.0,) @model function outer() - return a ~ to_submodel(inner_conditioned) + return a ~ to_submodel(inner() | (x=1.0,)) end ``` From 5c0c0f1f68788dd32cd7fa4c24148aeca000aeda Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 18 Apr 2025 01:15:56 +0100 Subject: [PATCH 11/17] Finish docs --- docs/Project.toml | 1 + docs/src/api.md | 12 +- docs/src/internals/submodel_condition.md | 170 +++++++++++++++++++---- src/context_implementations.jl | 20 ++- src/contexts.jl | 13 ++ 5 files changed, 175 insertions(+), 41 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index 40a719e03..93f449308 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,4 +1,5 @@ [deps] +AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" diff --git a/docs/src/api.md b/docs/src/api.md index ec741c9ad..acca6e3af 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -78,7 +78,7 @@ decondition ## Fixing and unfixing -We can also _fix_ a collection of variables in a [`Model`](@ref) to certain using [`fix`](@ref). +We can also _fix_ a collection of variables in a [`Model`](@ref) to certain using [`DynamicPPL.fix`](@ref). This might seem quite similar to the aforementioned [`condition`](@ref) and its siblings, but they are indeed different operations: @@ -89,19 +89,19 @@ but they are indeed different operations: - `fix`ed variables are considered to be _constant_, and are thus not included in any log-probability computations. -The differences are more clearly spelled out in the docstring of [`fix`](@ref) below. +The differences are more clearly spelled out in the docstring of [`DynamicPPL.fix`](@ref) below. ```@docs -fix +DynamicPPL.fix DynamicPPL.fixed ``` -The difference between [`fix`](@ref) and [`condition`](@ref) is described in the docstring of [`fix`](@ref) above. +The difference between [`DynamicPPL.fix`](@ref) and [`DynamicPPL.condition`](@ref) is described in the docstring of [`DynamicPPL.fix`](@ref) above. -Similarly, we can [`unfix`](@ref) variables, i.e. return them to their original meaning: +Similarly, we can revert this with [`DynamicPPL.unfix`](@ref), i.e. return the variables to their original meaning: ```@docs -unfix +DynamicPPL.unfix ``` ## Predicting diff --git a/docs/src/internals/submodel_condition.md b/docs/src/internals/submodel_condition.md index fe2f2c1df..042a0f77a 100644 --- a/docs/src/internals/submodel_condition.md +++ b/docs/src/internals/submodel_condition.md @@ -35,16 +35,9 @@ keys(vi) In this case, where `to_submodel` is called without any other arguments, the prefix to be used is automatically inferred from the name of the variable on the left-hand side of the tilde. We will return to the 'manual prefixing' case later. -What does it really mean to 'become' a different variable? -We can see this from [the definition of `tilde_assume`, for example](https://github.com/TuringLang/DynamicPPL.jl/blob/60ee68e2ce28a15c6062c243019e6208d16802a5/src/context_implementations.jl#L87-L89): - -``` -function tilde_assume(context::PrefixContext, right, vn, vi) - return tilde_assume(context.context, right, prefix(context, vn), vi) -end -``` - -Functionally, this means that even though the _initial_ entry to the tilde-pipeline has `vn` as `x` and `y`, once the `PrefixContext` has been applied, the later functions will see `a.x` and `a.y` instead. +The phrase 'becoming' a different variable is a little underspecified: it is useful to pinpoint the exact location where the prefixing occurs, which is `tilde_assume`. +The method responsible for it is `tilde_assume(::PrefixContext, right, vn, vi)`: this attaches the prefix in the context to the `VarName` argument, before recursively calling `tilde_assume` with the new prefixed `VarName`. +This means that even though a statement `x ~ dist` still enters the tilde pipeline at the top level as `x`, if the model evaluation context contains a `PrefixContext`, any function from `tilde_assume` onwards will see `a.x` instead. ## ConditionContext @@ -205,29 +198,158 @@ DynamicPPL.hasconditioned_nested(inner_ctx_with_outer_cond, @varname(a.x)) DynamicPPL.hasconditioned_nested(inner_ctx_with_inner_cond, @varname(a.x)) ``` -Essentially, our job is threefold: +This allows us to finally specify our task as follows: - - Firstly, given the correct arguments, we need to make sure that `hasconditioned_nested` and `getconditioned_nested` behave correctly. +(1) Given the correct arguments, we need to make sure that `hasconditioned_nested` and `getconditioned_nested` behave correctly. - - Secondly, we need to make sure that both the correct arguments are supplied. In order to do so: - - + We need to make sure that when evaluating a submodel, the context stack is arranged such that prefixes are applied _inside_ the parent model's context, but _outside_ the submodel's own context. - + We also need to make sure that the `VarName` passed to it is prefixed correctly. This is, in fact, _not_ handled by `tilde_assume`, because `contextual_isassumption` is much higher in the call stack than `tilde_assume` is. So, we need to explicitly prefix it. +(2) We need to make sure that both the correct arguments are supplied. In order to do so: + + - (2a) We need to make sure that when evaluating a submodel, the context stack is arranged such that `PrefixContext` is applied _inside_ the parent model's context, but _outside_ the submodel's own context. + + - (2b) We also need to make sure that the `VarName` passed to it is prefixed correctly. ## How do we do it? -`hasconditioned_nested` accomplishes this by doing the following: +(1) `hasconditioned_nested` and `getconditioned_nested` accomplish this by first 'collapsing' the context stack, i.e. they go through the context stack, remove all `PrefixContext`s, and apply those prefixes to any conditioned variables below it in the stack. +Once the `PrefixContext`s have been removed, one can then iterate through the context stack and check if any of the `ConditionContext`s contain the variable, or get the value itself. +For more details the reader is encouraged to read the source code. + +(2a) We ensure that the context stack is correctly arranged by relying on the behaviour of `make_evaluate_args_and_kwargs`. +This function is called whenever a model (which itself contains a context) is evaluated with a separate ('external') context, and makes sure to arrange both of these contexts such that _the model's context is nested inside the external context_. +Thus, as long as prefixing is implemented by applying a `PrefixContext` on the outermost layer of the _inner_ model context, this will be correctly combined with an external context to give the behaviour seen above. + +(2b) At first glance, it seems like `tilde_assume` can take care of the `VarName` prefixing for us (as described in the first section). +However, this is not actually the case: `contextual_isassumption`, which is the function that calls `hasconditioned_nested`, is much higher in the call stack than `tilde_assume` is. +So, we need to explicitly prefix it before passing it to `contextual_isassumption`. +This is done inside the `@model` macro, or technically, its subsidiary function `isassumption`. + +## Nested submodels + +Just in case the above wasn't complicated enough, we need to also be very careful when dealing with nested submodels, which have multiple layers of `PrefixContext`s which may be interspersed with `ConditionContext`s. +For example, in this series of nested submodels, + +```@example +@model function charlie() + x ~ Normal() + y ~ Normal() + return z ~ Normal() +end +@model function bravo() + return b ~ to_submodel(charlie() | (@varname(x) => 1.0)) +end +@model function alpha() + return a ~ to_submodel(bravo() | (@varname(b.y) => 1.0)) +end +``` + +we expect that the only variable to be sampled should be `z` inside `charlie`, or rather, `a.b.z` once it has been through the prefixes. + +```@example +keys(VarInfo(alpha())) +``` + +The general strategy that we adopt is similar to above. +Following the principle that `PrefixContext` should be nested inside the outer context, but outside the inner submodel's context, we can infer that the correct context inside `charlie` should be: + +```@example +big_ctx = PrefixContext{:a}( + ConditionContext( + Dict(@varname(b.y) => 1.0), + PrefixContext{:b}(ConditionContext(Dict(@varname(x) => 1.0))), + ), +) +``` + +We need several things to work correctly here: we need the `VarName` prefixing to behave correctly, and then we need to implement `hasconditioned_nested` and `getconditioned_nested` on the resulting prefixed `VarName`. +It turns out that the prefixing itself is enough to illustrate the most important point in this section, namely, the need to traverse the context stack in a _different direction_ to what most of DynamicPPL does. + +Let's work with a function called `myprefix(::AbstractContext, ::VarName)` (to avoid confusion with any existing DynamicPPL function). +We should like `myprefix(big_ctx, @varname(x))` to return `@varname(a.b.x)`. +Consider the following naive implementation, which mirrors a lot of code in the tilde-pipeline: + +```@example +using DynamicPPL: NodeTrait, IsLeaf, IsParent, childcontext, AbstractContext +using AbstractPPL: AbstractPPL + +function myprefix(ctx::DynamicPPL.AbstractContext, vn::VarName) + return myprefix(NodeTrait(ctx), ctx, vn) +end +function myprefix(::IsLeaf, ::AbstractContext, vn::VarName) + return vn +end +function myprefix(::IsParent, ctx::AbstractContext, vn::VarName) + return myprefix(childcontext(ctx), vn) +end +function myprefix(ctx::DynamicPPL.PrefixContext{Prefix}, vn::VarName) where {Prefix} + # The functionality to actually manipulate the VarNames is in AbstractPPL + new_vn = AbstractPPL.prefix(vn, VarName{Prefix}()) + # Then pass to the child context + return myprefix(childcontext(ctx), new_vn) +end + +myprefix(big_ctx, @varname(x)) +``` + +This implementation clearly is not correct, because it applies the _inner_ `PrefixContext` before the outer one. + +The right way to implement `myprefix` is to, essentially, reverse the order of two lines above: + +```@example +function myprefix(ctx::DynamicPPL.PrefixContext{Prefix}, vn::VarName) where {Prefix} + # Pass to the child context first + new_vn = myprefix(childcontext(ctx), vn) + # Then apply this context's prefix + return AbstractPPL.prefix(new_vn, VarName{Prefix}()) +end + +myprefix(big_ctx, @varname(x)) +``` + +This is a much better result! +The implementation of related functions such as `hasconditioned_nested` and `getconditioned_nested`, under the hood, use a similar recursion scheme, so you will find that this is a common pattern when reading the source code of various prefixing-related functions, you will find that this is a common pattern +When editing this code, it is worth being mindful of this as a potential source of incorrectness. + +!!! info + + If you have encountered left and right folds, the above discussion illustrates the difference between them: the wrong implementation of `myprefix` uses a left fold (which collects prefixes in the opposite order from which they are encountered), while the correct implementation uses a right fold. - - If the outermost layer is a `ConditionContext`, it checks whether the variable is contained in its values. - - If the outermost layer is a `PrefixContext`, it goes through the `PrefixContext`'s child context and prefixes any inner conditioned variables, before checking whether the variable is contained. +## Loose ends 1: Manual prefixing -We ensure that the context stack is correctly arranged by relying on the behaviour of `make_evaluate_args_and_kwargs`. -This function is called whenever a model (which itself contains a context) is evaluated with a separate ('outer') context, and makes sure to arrange it such that the model's context is nested inside the outer context. -Thus, as long as prefixing is implemented by applying a `PrefixContext` on the outermost layer of the _inner_ model context, this will be correctly combined with an outer context to give the behaviour seen above. +Sometimes users may want to manually prefix a model, for example: -And finally, we ensure that the `VarName` is correctly prefixed by modifying the `@model` macro (or, technically, its subsidiary `isassumption`) to explicitly prefix the variable before passing it to `contextual_isassumption`. +```@example +@model function inner_manual() + x ~ Normal() + return y ~ Normal() +end + +@model function outer_manual() + return _unused ~ to_submodel(prefix(inner_manual(), :a), false) +end +``` + +In this case, the `VarName` on the left-hand side of the tilde is not used, and the prefix is instead specified using the `prefix` function. + +The way to deal with this follows on from the previous discussion. +Specifically, we said that: + +> [...] as long as prefixing is implemented by applying a `PrefixContext` on the outermost layer of the _inner_ model context, this will be correctly combined [...] + +When automatic prefixing is used, this application of `PrefixContext` occurs inside the `tilde_assume!!` method. +In the manual prefixing case, we need to make sure that `prefix(submodel::Model, ::Symbol)` does the same thing, i.e. it inserts a `PrefixContext` at the outermost layer of `submodel`'s context. +We can see that this is precisely what happens: + +```@example +@model f() = x ~ Normal() + +model = f() +prefixed_model = prefix(model, :a) + +(model.context, prefixed_model.context) +``` -## FixedContext +## Loose ends 2: FixedContext Finally, note that all of the above also applies to the interaction between `PrefixContext` and `FixedContext`, except that the functions have different names. (`FixedContext` behaves the same way as `ConditionContext`, except that unlike conditioned variables, fixed variables do not contribute to the log probability density.) +This generally results in a large amount of code duplication, but the concepts that underlie both contexts are exactly the same. diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 8cebd5f81..876f8a616 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -85,17 +85,15 @@ function tilde_assume(rng::Random.AbstractRNG, ::LikelihoodContext, sampler, rig end function tilde_assume(context::PrefixContext, right, vn, vi) - # The slightly tricky thing about PrefixContext is that they are applied - # from the outside in, so `PrefixContext{:a}(PrefixContext{:b}(ctx))` means - # that variables get prefixed like `a.b.x`. - # This motivates the implementation shown here, where the function - # `prefix_and_strip_contexts` is responsible for not only adding the - # prefixes, but also removing the `PrefixContext`s from the context stack - # so that they don't get applied twice when recursing. - # TODO(penelopeysm): It would be nice to switch this round, but it's a very - # tricky task. Essentially it forces us to use a foldr inside - # `prefix_and_strip_contexts`, rather than a foldl which is what most of - # DynamicPPL uses. + # Note that we can't use something like this here: + # new_vn = prefix(context, vn) + # return tilde_assume(childcontext(context), right, new_vn, vi) + # This is because `prefix` applies _all_ prefixes in a given context to a + # variable name. Thus, if we had two levels of nested prefixes e.g. + # `PrefixContext{:a}(PrefixContext{:b}(DefaultContext()))`, then the + # first call would apply the prefix `a.b._`, and the recursive call + # would apply the prefix `b._`, resulting in `b.a.b._`. + # This is why we need a special function, `prefix_and_strip_contexts`. new_vn, new_context = prefix_and_strip_contexts(context, vn) return tilde_assume(new_context, right, new_vn, vi) end diff --git a/src/contexts.jl b/src/contexts.jl index c71a7467e..dbd6f9b23 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -281,6 +281,19 @@ end Same as `prefix`, but additionally returns a new context stack that has all the PrefixContexts removed. + +NOTE: This does _not_ modify any variables in any `ConditionContext` and +`FixedContext` that may be present in the context stack. This is because this +function is only used in `tilde_assume`, which is lower in the tilde-pipeline +than `contextual_isassumption` and `contextual_isfixed` (the functions which +actually use the `ConditionContext` and `FixedContext` values). Thus, by this +time, any `ConditionContext`s and `FixedContext`s present have already served +their purpose. + +If you call this function, you must therefore be careful to ensure that you _do +not_ need to modify any inner `ConditionContext`s and `FixedContext`s. If you +_do_ need to modify them, then you may need to use +`prefix_cond_and_fixed_variables` instead. """ function prefix_and_strip_contexts(ctx::PrefixContext{Prefix}, vn::VarName) where {Prefix} child_context = childcontext(ctx) From a3bc52ec8a96f52af2c51fa31fb9410fe0918d76 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 18 Apr 2025 01:22:46 +0100 Subject: [PATCH 12/17] Add a test for conditioning submodel via arguments --- test/submodels.jl | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/test/submodels.jl b/test/submodels.jl index da91b8b93..844e12e13 100644 --- a/test/submodels.jl +++ b/test/submodels.jl @@ -5,7 +5,7 @@ using Distributions using Test @testset "submodels.jl" begin - @testset "$op" for op in [condition, fix] + @testset "$op with AbstractPPL API" for op in [condition, fix] x_val = 1.0 x_logp = op == condition ? logpdf(Normal(), x_val) : 0.0 @@ -139,6 +139,22 @@ using Test @test Set(keys(VarInfo(h3()))) == Set([@varname(a.b.y)]) end end + + @testset "conditioning via model arguments" begin + @model function f(x) + x ~ Normal() + return y ~ Normal() + end + @model function g(inner_x) + return a ~ to_submodel(f(inner_x)) + end + + vi = VarInfo(g(1.0)) + @test Set(keys(vi)) == Set([@varname(a.y)]) + + vi = VarInfo(g(missing)) + @test Set(keys(vi)) == Set([@varname(a.x), @varname(a.y)]) + end end end From 8c3bff42d9e06e8c70104cbade24e7c842fff4ba Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 18 Apr 2025 01:31:19 +0100 Subject: [PATCH 13/17] Clean new tests up a bit --- test/submodels.jl | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/test/submodels.jl b/test/submodels.jl index 844e12e13..834216223 100644 --- a/test/submodels.jl +++ b/test/submodels.jl @@ -29,7 +29,8 @@ using Test @test Set(keys(VarInfo(outer()))) == Set([@varname(a.x), @varname(a.y)]) # With conditioning/fixing - @testset "$model" for model in [with_inner_op, with_outer_op] + models = [("inner", with_inner_op), ("outer", with_outer_op)] + @testset "$name" for (name, model) in models # Test that the value was correctly set @test model()[1] == x_val # Test that the logp was correctly set @@ -44,7 +45,7 @@ using Test @model function inner() x ~ Normal() y ~ Normal() - return x + return (x, y) end @model function outer() return a ~ to_submodel(inner(), false) @@ -60,7 +61,8 @@ using Test @test Set(keys(VarInfo(outer()))) == Set([@varname(x), @varname(y)]) # With conditioning/fixing - @testset "$model" for model in [with_inner_op, with_outer_op] + models = [("inner", with_inner_op), ("outer", with_outer_op)] + @testset "$name" for (name, model) in models # Test that the value was correctly set @test model()[1] == x_val # Test that the logp was correctly set @@ -75,7 +77,7 @@ using Test @model function inner() x ~ Normal() y ~ Normal() - return x + return (x, y) end @model function outer() return a ~ to_submodel(prefix(inner(), :b), false) @@ -91,7 +93,8 @@ using Test @test Set(keys(VarInfo(outer()))) == Set([@varname(b.x), @varname(b.y)]) # With conditioning/fixing - @testset "$model" for model in [with_inner_op, with_outer_op] + models = [("inner", with_inner_op), ("outer", with_outer_op)] + @testset "$name" for (name, model) in models # Test that the value was correctly set @test model()[1] == x_val # Test that the logp was correctly set @@ -115,18 +118,20 @@ using Test end # No conditioning - @test Set(keys(VarInfo(h()))) == Set([@varname(a.b.x), @varname(a.b.y)]) + vi = VarInfo(h()) + @test Set(keys(vi)) == Set([@varname(a.b.x), @varname(a.b.y)]) + @test getlogp(vi) == + logpdf(Normal(), vi[@varname(a.b.x)]) + + logpdf(Normal(), vi[@varname(a.b.y)]) # Conditioning/fixing at the top level op_h = op(h(), (@varname(a.b.x) => x_val)) - @test Set(keys(VarInfo(op_h))) == Set([@varname(a.b.y)]) # Conditioning/fixing at the second level op_g = op(g(), (@varname(b.x) => x_val)) @model function h2() return a ~ to_submodel(op_g) end - @test Set(keys(VarInfo(h2()))) == Set([@varname(a.b.y)]) # Conditioning/fixing at the very bottom op_f = op(f(), (@varname(x) => x_val)) @@ -136,7 +141,13 @@ using Test @model function h3() return a ~ to_submodel(g2()) end - @test Set(keys(VarInfo(h3()))) == Set([@varname(a.b.y)]) + + models = [("top", op_h), ("middle", h2()), ("bottom", h3())] + @testset "$name" for (name, model) in models + vi = VarInfo(model) + @test Set(keys(vi)) == Set([@varname(a.b.y)]) + @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(a.b.y)]) + end end end From b545a93c2eec92e54ef826352d11f910feaafb52 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 21 Apr 2025 18:15:31 +0100 Subject: [PATCH 14/17] Fix for VarNames with non-identity lenses --- src/context_implementations.jl | 2 +- test/submodels.jl | 30 ++++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 876f8a616..e9fad54b6 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -131,7 +131,7 @@ function tilde_assume!!(context, right, vn, vi) # change in the future. if should_auto_prefix(right) dppl_model = right.model.model # This isa DynamicPPL.Model - prefixed_submodel_context = PrefixContext{getsym(vn)}(dppl_model.context) + prefixed_submodel_context = PrefixContext{Symbol(vn)}(dppl_model.context) new_dppl_model = contextualize(dppl_model, prefixed_submodel_context) right = to_submodel(new_dppl_model, true) end diff --git a/test/submodels.jl b/test/submodels.jl index 834216223..6a8a2c889 100644 --- a/test/submodels.jl +++ b/test/submodels.jl @@ -105,6 +105,36 @@ using Test end end + @testset "Complex prefixes" begin + mutable struct P + a::Float64 + b::Float64 + end + @model function f() + x = Vector{Float64}(undef, 1) + x[1] ~ Normal() + y ~ Normal() + return x[1] + end + @model function g() + p = P(1.0, 2.0) + p.a ~ to_submodel(f()) + p.b ~ Normal() + return (p.a, p.b) + end + expected_vns = Set([ + @varname(var"p.a".x[1]), @varname(var"p.a".y), @varname(p.b) + ]) + @test Set(keys(VarInfo(g()))) == expected_vns + + # Check that we can condition/fix on any of them from the outside + for vn in expected_vns + op_g = op(g(), (vn => 1.0)) + vi = VarInfo(op_g) + @test Set(keys(vi)) == symdiff(expected_vns, Set([vn])) + end + end + @testset "Nested submodels" begin @model function f() x ~ Normal() From 70d124a00594c064f55522868e690a048f1c54dd Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 22 Apr 2025 17:58:06 +0100 Subject: [PATCH 15/17] Apply suggestions from code review Co-authored-by: Markus Hauru --- HISTORY.md | 2 +- docs/src/api.md | 4 ++-- test/contexts.jl | 1 - 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 17b0b2611..aab758dae 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -19,7 +19,7 @@ end end ``` -and the `inner.x` variable will be correctly conditioned. +and the `a.x` variable will be correctly conditioned. (Previously, you would have to condition `inner()` with the variable `a.x`, meaning that you would need to know what prefix to use before you had actually prefixed it.) ### AD testing utilities diff --git a/docs/src/api.md b/docs/src/api.md index acca6e3af..08522e2ce 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -78,9 +78,9 @@ decondition ## Fixing and unfixing -We can also _fix_ a collection of variables in a [`Model`](@ref) to certain using [`DynamicPPL.fix`](@ref). +We can also _fix_ a collection of variables in a [`Model`](@ref) to certain values using [`DynamicPPL.fix`](@ref). -This might seem quite similar to the aforementioned [`condition`](@ref) and its siblings, +This is quite similar to the aforementioned [`condition`](@ref) and its siblings, but they are indeed different operations: - `condition`ed variables are considered to be _observations_, and are thus diff --git a/test/contexts.jl b/test/contexts.jl index 081e59775..f59c80082 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -105,7 +105,6 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() vcat, pairs(conditioned_values), ) - @show conditioned_vns # We can now loop over them to check which ones are missing. We use # `getvalue` to handle the awkward case where sometimes From be8108b33649b3b0151c1a9ccd8364d6d6213e80 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 22 Apr 2025 17:59:27 +0100 Subject: [PATCH 16/17] Apply suggestions from code review --- docs/src/internals/submodel_condition.md | 2 +- src/contexts.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/src/internals/submodel_condition.md b/docs/src/internals/submodel_condition.md index 042a0f77a..d35627289 100644 --- a/docs/src/internals/submodel_condition.md +++ b/docs/src/internals/submodel_condition.md @@ -306,7 +306,7 @@ myprefix(big_ctx, @varname(x)) ``` This is a much better result! -The implementation of related functions such as `hasconditioned_nested` and `getconditioned_nested`, under the hood, use a similar recursion scheme, so you will find that this is a common pattern when reading the source code of various prefixing-related functions, you will find that this is a common pattern +The implementation of related functions such as `hasconditioned_nested` and `getconditioned_nested`, under the hood, use a similar recursion scheme, so you will find that this is a common pattern when reading the source code of various prefixing-related functions. When editing this code, it is worth being mindful of this as a potential source of incorrectness. !!! info diff --git a/src/contexts.jl b/src/contexts.jl index dbd6f9b23..4b3baffd4 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -426,7 +426,7 @@ hasconditioned_nested(::IsLeaf, context, vn) = hasconditioned(context, vn) function hasconditioned_nested(::IsParent, context, vn) return hasconditioned(context, vn) || hasconditioned_nested(childcontext(context), vn) end -function hasconditioned_nested(context::PrefixContext{Prefix}, vn) where {Prefix} +function hasconditioned_nested(context::PrefixContext, vn) return hasconditioned_nested(collapse_prefix_stack(context), vn) end @@ -444,7 +444,7 @@ end function getconditioned_nested(::IsLeaf, context, vn) return error("context $(context) does not contain value for $vn") end -function getconditioned_nested(context::PrefixContext{Prefix}, vn) where {Prefix} +function getconditioned_nested(context::PrefixContext, vn) return getconditioned_nested(collapse_prefix_stack(context), vn) end function getconditioned_nested(::IsParent, context, vn) From fcb44e5e3c163ad0ef251093b2de48d6914c99ad Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 23 Apr 2025 12:08:52 +0100 Subject: [PATCH 17/17] Make PrefixContext contain a varname rather than symbol (#896) --- HISTORY.md | 90 +++++++++++++++--------- docs/src/internals/submodel_condition.md | 19 ++--- src/context_implementations.jl | 2 +- src/contexts.jl | 65 ++++++++++------- src/model.jl | 8 +-- src/submodel_macro.jl | 4 +- test/contexts.jl | 63 ++++++++++------- test/submodels.jl | 4 +- 8 files changed, 150 insertions(+), 105 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index aab758dae..ac3e40970 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -4,7 +4,7 @@ **Breaking changes** -### Submodels +### Submodels: conditioning Variables in a submodel can now be conditioned and fixed in a correct way. See https://github.com/TuringLang/DynamicPPL.jl/issues/857 for a full illustration, but essentially it means you can now do this: @@ -22,38 +22,7 @@ end and the `a.x` variable will be correctly conditioned. (Previously, you would have to condition `inner()` with the variable `a.x`, meaning that you would need to know what prefix to use before you had actually prefixed it.) -### AD testing utilities - -`DynamicPPL.TestUtils.AD.run_ad` now links the VarInfo by default. -To disable this, pass the `linked=false` keyword argument. -If the calculated value or gradient is incorrect, it also throws a `DynamicPPL.TestUtils.AD.ADIncorrectException` rather than a test failure. -This exception contains the actual and expected gradient so you can inspect it if needed; see the documentation for more information. -From a practical perspective, this means that if you need to add this to a test suite, you need to use `@test run_ad(...) isa Any` rather than just `run_ad(...)`. - -### SimpleVarInfo linking / invlinking - -Linking a linked SimpleVarInfo, or invlinking an unlinked SimpleVarInfo, now displays a warning instead of an error. - -### VarInfo constructors - -`VarInfo(vi::VarInfo, values)` has been removed. You can replace this directly with `unflatten(vi, values)` instead. - -The `metadata` argument to `VarInfo([rng, ]model[, sampler, context, metadata])` has been removed. -If you were not using this argument (most likely), then there is no change needed. -If you were using the `metadata` argument to specify a blank `VarNamedVector`, then you should replace calls to `VarInfo` with `DynamicPPL.typed_vector_varinfo` instead (see 'Other changes' below). - -The `UntypedVarInfo` constructor and type is no longer exported. -If you needed to construct one, you should now use `DynamicPPL.untyped_varinfo` instead. - -The `TypedVarInfo` constructor and type is no longer exported. -The _type_ has been replaced with `DynamicPPL.NTVarInfo`. -The _constructor_ has been replaced with `DynamicPPL.typed_varinfo`. - -Note that the exact kind of VarInfo returned by `VarInfo(rng, model, ...)` is an implementation detail. -Previously, it was guaranteed that this would always be a VarInfo whose metadata was a `NamedTuple` containing `Metadata` structs. -Going forward, this is no longer the case, and you should only assume that the returned object obeys the `AbstractVarInfo` interface. - -### VarName prefixing behaviour +### Submodel prefixing The way in which VarNames in submodels are prefixed has been changed. This is best explained through an example. @@ -95,9 +64,62 @@ outer() | (@varname(var"a.x") => 1.0,) outer() | (a.x=1.0,) ``` -If you are sampling from a model with submodels, this doesn't affect the way you interact with the `MCMCChains.Chains` object, because VarNames are converted into Symbols when stored in the chain. +In a similar way, if the variable on the left-hand side of your tilde statement is not just a single identifier, any fields or indices it accesses are now properly respected. +Consider the following setup: + +```julia +using DynamicPPL, Distributions +@model inner() = x ~ Normal() +@model function outer() + a = Vector{Float64}(undef, 1) + a[1] ~ to_submodel(inner()) + return a +end +``` + +In this case, the variable sampled is actually the `x` field of the first element of `a`: + +```julia +julia> only(keys(VarInfo(outer()))) == @varname(a[1].x) +true +``` + +Before this version, it used to be a single variable called `var"a[1].x"`. + +Note that if you are sampling from a model with submodels, this doesn't affect the way you interact with the `MCMCChains.Chains` object, because VarNames are converted into Symbols when stored in the chain. (This behaviour will likely be changed in the future, in that Chains should be indexable by VarNames and not just Symbols, but that has not been implemented yet.) +### AD testing utilities + +`DynamicPPL.TestUtils.AD.run_ad` now links the VarInfo by default. +To disable this, pass the `linked=false` keyword argument. +If the calculated value or gradient is incorrect, it also throws a `DynamicPPL.TestUtils.AD.ADIncorrectException` rather than a test failure. +This exception contains the actual and expected gradient so you can inspect it if needed; see the documentation for more information. +From a practical perspective, this means that if you need to add this to a test suite, you need to use `@test run_ad(...) isa Any` rather than just `run_ad(...)`. + +### SimpleVarInfo linking / invlinking + +Linking a linked SimpleVarInfo, or invlinking an unlinked SimpleVarInfo, now displays a warning instead of an error. + +### VarInfo constructors + +`VarInfo(vi::VarInfo, values)` has been removed. You can replace this directly with `unflatten(vi, values)` instead. + +The `metadata` argument to `VarInfo([rng, ]model[, sampler, context, metadata])` has been removed. +If you were not using this argument (most likely), then there is no change needed. +If you were using the `metadata` argument to specify a blank `VarNamedVector`, then you should replace calls to `VarInfo` with `DynamicPPL.typed_vector_varinfo` instead (see 'Other changes' below). + +The `UntypedVarInfo` constructor and type is no longer exported. +If you needed to construct one, you should now use `DynamicPPL.untyped_varinfo` instead. + +The `TypedVarInfo` constructor and type is no longer exported. +The _type_ has been replaced with `DynamicPPL.NTVarInfo`. +The _constructor_ has been replaced with `DynamicPPL.typed_varinfo`. + +Note that the exact kind of VarInfo returned by `VarInfo(rng, model, ...)` is an implementation detail. +Previously, it was guaranteed that this would always be a VarInfo whose metadata was a `NamedTuple` containing `Metadata` structs. +Going forward, this is no longer the case, and you should only assume that the returned object obeys the `AbstractVarInfo` interface. + **Other changes** While these are technically breaking, they are only internal changes and do not affect the public API. diff --git a/docs/src/internals/submodel_condition.md b/docs/src/internals/submodel_condition.md index d35627289..ecb9d452b 100644 --- a/docs/src/internals/submodel_condition.md +++ b/docs/src/internals/submodel_condition.md @@ -181,10 +181,10 @@ Putting all of the information so far together, what it means is that if we have using DynamicPPL: PrefixContext, ConditionContext, DefaultContext inner_ctx_with_outer_cond = ConditionContext( - Dict(@varname(a.x) => 1.0), PrefixContext{:a}(DefaultContext()) + Dict(@varname(a.x) => 1.0), PrefixContext(@varname(a)) ) -inner_ctx_with_inner_cond = PrefixContext{:a}( - ConditionContext(Dict(@varname(x) => 1.0), DefaultContext()) +inner_ctx_with_inner_cond = PrefixContext( + @varname(a), ConditionContext(Dict(@varname(x) => 1.0)) ) ``` @@ -252,10 +252,11 @@ The general strategy that we adopt is similar to above. Following the principle that `PrefixContext` should be nested inside the outer context, but outside the inner submodel's context, we can infer that the correct context inside `charlie` should be: ```@example -big_ctx = PrefixContext{:a}( +big_ctx = PrefixContext( + @varname(a), ConditionContext( Dict(@varname(b.y) => 1.0), - PrefixContext{:b}(ConditionContext(Dict(@varname(x) => 1.0))), + PrefixContext(@varname(b), ConditionContext(Dict(@varname(x) => 1.0))), ), ) ``` @@ -280,9 +281,9 @@ end function myprefix(::IsParent, ctx::AbstractContext, vn::VarName) return myprefix(childcontext(ctx), vn) end -function myprefix(ctx::DynamicPPL.PrefixContext{Prefix}, vn::VarName) where {Prefix} +function myprefix(ctx::DynamicPPL.PrefixContext, vn::VarName) # The functionality to actually manipulate the VarNames is in AbstractPPL - new_vn = AbstractPPL.prefix(vn, VarName{Prefix}()) + new_vn = AbstractPPL.prefix(vn, ctx.vn_prefix) # Then pass to the child context return myprefix(childcontext(ctx), new_vn) end @@ -295,11 +296,11 @@ This implementation clearly is not correct, because it applies the _inner_ `Pref The right way to implement `myprefix` is to, essentially, reverse the order of two lines above: ```@example -function myprefix(ctx::DynamicPPL.PrefixContext{Prefix}, vn::VarName) where {Prefix} +function myprefix(ctx::DynamicPPL.PrefixContext, vn::VarName) # Pass to the child context first new_vn = myprefix(childcontext(ctx), vn) # Then apply this context's prefix - return AbstractPPL.prefix(new_vn, VarName{Prefix}()) + return AbstractPPL.prefix(new_vn, ctx.vn_prefix) end myprefix(big_ctx, @varname(x)) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index e9fad54b6..eb025dec8 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -131,7 +131,7 @@ function tilde_assume!!(context, right, vn, vi) # change in the future. if should_auto_prefix(right) dppl_model = right.model.model # This isa DynamicPPL.Model - prefixed_submodel_context = PrefixContext{Symbol(vn)}(dppl_model.context) + prefixed_submodel_context = PrefixContext(vn, dppl_model.context) new_dppl_model = contextualize(dppl_model, prefixed_submodel_context) right = to_submodel(new_dppl_model, true) end diff --git a/src/contexts.jl b/src/contexts.jl index 4b3baffd4..8ac085663 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -237,27 +237,34 @@ function setchildcontext(parent::MiniBatchContext, child) end """ - PrefixContext{Prefix}(context) + PrefixContext(vn::VarName[, context::AbstractContext]) + PrefixContext(vn::Val{sym}[, context::AbstractContext]) where {sym} Create a context that allows you to use the wrapped `context` when running the model and -adds the `Prefix` to all parameters. +prefixes all parameters with the VarName `vn`. + +`PrefixContext(Val(:a), context)` is equivalent to `PrefixContext(@varname(a), context)`. +If `context` is not provided, it defaults to `DefaultContext()`. This context is useful in nested models to ensure that the names of the parameters are unique. See also: [`to_submodel`](@ref) """ -struct PrefixContext{Prefix,C} <: AbstractContext +struct PrefixContext{Tvn<:VarName,C<:AbstractContext} <: AbstractContext + vn_prefix::Tvn context::C end -function PrefixContext{Prefix}(context::AbstractContext) where {Prefix} - return PrefixContext{Prefix,typeof(context)}(context) +PrefixContext(vn::VarName) = PrefixContext(vn, DefaultContext()) +function PrefixContext(::Val{sym}, context::AbstractContext) where {sym} + return PrefixContext(VarName{sym}(), context) end +PrefixContext(::Val{sym}) where {sym} = PrefixContext(VarName{sym}()) NodeTrait(::PrefixContext) = IsParent() childcontext(context::PrefixContext) = context.context -function setchildcontext(::PrefixContext{Prefix}, child) where {Prefix} - return PrefixContext{Prefix}(child) +function setchildcontext(ctx::PrefixContext, child::AbstractContext) + return PrefixContext(ctx.vn_prefix, child) end """ @@ -265,8 +272,8 @@ end Apply the prefixes in the context `ctx` to the variable name `vn`. """ -function prefix(ctx::PrefixContext{Prefix}, vn::VarName) where {Prefix} - return AbstractPPL.prefix(prefix(childcontext(ctx), vn), VarName{Prefix}()) +function prefix(ctx::PrefixContext, vn::VarName) + return AbstractPPL.prefix(prefix(childcontext(ctx), vn), ctx.vn_prefix) end function prefix(ctx::AbstractContext, vn::VarName) return prefix(NodeTrait(ctx), ctx, vn) @@ -295,14 +302,13 @@ not_ need to modify any inner `ConditionContext`s and `FixedContext`s. If you _do_ need to modify them, then you may need to use `prefix_cond_and_fixed_variables` instead. """ -function prefix_and_strip_contexts(ctx::PrefixContext{Prefix}, vn::VarName) where {Prefix} +function prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName) child_context = childcontext(ctx) # vn_prefixed contains the prefixes from all lower levels vn_prefixed, child_context_without_prefixes = prefix_and_strip_contexts( child_context, vn ) - return AbstractPPL.prefix(vn_prefixed, VarName{Prefix}()), - child_context_without_prefixes + return AbstractPPL.prefix(vn_prefixed, ctx.vn_prefix), child_context_without_prefixes end function prefix_and_strip_contexts(ctx::AbstractContext, vn::VarName) return prefix_and_strip_contexts(NodeTrait(ctx), ctx, vn) @@ -314,11 +320,16 @@ function prefix_and_strip_contexts(::IsParent, ctx::AbstractContext, vn::VarName end """ - prefix(model::Model, x) - -Return `model` but with all random variables prefixed by `x`. + prefix(model::Model, x::VarName) + prefix(model::Model, x::Val{sym}) + prefix(model::Model, x::Any) -If `x` is known at compile-time, use `Val{x}()` to avoid runtime overheads for prefixing. +Return `model` but with all random variables prefixed by `x`, where `x` is either: +- a `VarName` (e.g. `@varname(a)`), +- a `Val{sym}` (e.g. `Val(:a)`), or +- for any other type, `x` is converted to a Symbol and then to a `VarName`. Note that + this will introduce runtime overheads so is not recommended unless absolutely + necessary. # Examples @@ -328,17 +339,19 @@ julia> using DynamicPPL: prefix julia> @model demo() = x ~ Dirac(1) demo (generic function with 2 methods) -julia> rand(prefix(demo(), :my_prefix)) +julia> rand(prefix(demo(), @varname(my_prefix))) (var"my_prefix.x" = 1,) -julia> # One can also use `Val` to avoid runtime overheads. - rand(prefix(demo(), Val(:my_prefix))) +julia> rand(prefix(demo(), Val(:my_prefix))) (var"my_prefix.x" = 1,) ``` """ -prefix(model::Model, x) = contextualize(model, PrefixContext{Symbol(x)}(model.context)) -function prefix(model::Model, ::Val{x}) where {x} - return contextualize(model, PrefixContext{Symbol(x)}(model.context)) +prefix(model::Model, x::VarName) = contextualize(model, PrefixContext(x, model.context)) +function prefix(model::Model, x::Val{sym}) where {sym} + return contextualize(model, PrefixContext(VarName{sym}(), model.context)) +end +function prefix(model::Model, x) + return contextualize(model, PrefixContext(VarName{Symbol(x)}(), model.context)) end """ @@ -715,13 +728,13 @@ which explains this in much more detail. ```jldoctest julia> using DynamicPPL: collapse_prefix_stack -julia> c1 = PrefixContext{:a}(ConditionContext((x=1, ))); +julia> c1 = PrefixContext(@varname(a), ConditionContext((x=1, ))); julia> collapse_prefix_stack(c1) ConditionContext(Dict(a.x => 1), DefaultContext()) julia> # Here, `x` gets prefixed only with `a`, whereas `y` is prefixed with both. - c2 = PrefixContext{:a}(ConditionContext((x=1, ), PrefixContext{:b}(ConditionContext((y=2,))))); + c2 = PrefixContext(@varname(a), ConditionContext((x=1, ), PrefixContext(@varname(b), ConditionContext((y=2,))))); julia> collapsed = collapse_prefix_stack(c2); @@ -733,14 +746,14 @@ julia> # `collapsed` really looks something like this: (1, 2) ``` """ -function collapse_prefix_stack(context::PrefixContext{Prefix}) where {Prefix} +function collapse_prefix_stack(context::PrefixContext) # Collapse the child context (thus applying any inner prefixes first) collapsed = collapse_prefix_stack(childcontext(context)) # Prefix any conditioned variables with the current prefix # Note: prefix_conditioned_variables is O(N) in the depth of the context stack. # So is this function. In the worst case scenario, this is O(N^2) in the # depth of the context stack. - return prefix_cond_and_fixed_variables(collapsed, VarName{Prefix}()) + return prefix_cond_and_fixed_variables(collapsed, context.vn_prefix) end function collapse_prefix_stack(context::AbstractContext) return collapse_prefix_stack(NodeTrait(collapse_prefix_stack, context), context) diff --git a/src/model.jl b/src/model.jl index cfe87ad44..c7c4bdf57 100644 --- a/src/model.jl +++ b/src/model.jl @@ -429,7 +429,7 @@ julia> # Nested ones also work. # (Note that `PrefixContext` also prefixes the variables of any # ConditionContext that is _inside_ it; because of this, the type of the # container has to be broadened to a `Dict`.) - cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext((m=1.0,)))), x=100.0); + cm = condition(contextualize(m, PrefixContext(@varname(a), ConditionContext((m=1.0,)))), x=100.0); julia> Set(keys(conditioned(cm))) == Set([@varname(a.m), @varname(x)]) true @@ -441,7 +441,7 @@ julia> # Since we conditioned on `a.m`, it is not treated as a random variable. a.x julia> # We can also condition on `a.m` _outside_ of the PrefixContext: - cm = condition(contextualize(m, PrefixContext{:a}(DefaultContext())), (@varname(a.m) => 1.0)); + cm = condition(contextualize(m, PrefixContext(@varname(a))), (@varname(a.m) => 1.0)); julia> conditioned(cm) Dict{VarName{:a, Accessors.PropertyLens{:m}}, Float64} with 1 entry: @@ -769,7 +769,7 @@ julia> # Returns all the variables we have fixed on + their values. (x = 100.0, m = 1.0) julia> # The rest of this is the same as the `condition` example above. - cm = fix(contextualize(m, PrefixContext{:a}(fix(m=1.0))), x=100.0); + cm = fix(contextualize(m, PrefixContext(@varname(a), fix(m=1.0))), x=100.0); julia> Set(keys(fixed(cm))) == Set([@varname(a.m), @varname(x)]) true @@ -779,7 +779,7 @@ julia> keys(VarInfo(cm)) a.x julia> # We can also condition on `a.m` _outside_ of the PrefixContext: - cm = fix(contextualize(m, PrefixContext{:a}(DefaultContext())), (@varname(a.m) => 1.0)); + cm = fix(contextualize(m, PrefixContext(@varname(a))), (@varname(a.m) => 1.0)); julia> fixed(cm) Dict{VarName{:a, Accessors.PropertyLens{:m}}, Float64} with 1 entry: diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index f6b9c4479..5f1ec95ec 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -223,12 +223,12 @@ end prefix_submodel_context(prefix, left, ctx) = prefix_submodel_context(prefix, ctx) function prefix_submodel_context(prefix, ctx) # E.g. `prefix="asd[$i]"` or `prefix=asd` with `asd` to be evaluated. - return :($(PrefixContext){$(Symbol)($(esc(prefix)))}($ctx)) + return :($(PrefixContext)($(Val)($(Symbol)($(esc(prefix)))), $ctx)) end function prefix_submodel_context(prefix::Union{AbstractString,Symbol}, ctx) # E.g. `prefix="asd"`. - return :($(PrefixContext){$(esc(Meta.quot(Symbol(prefix))))}($ctx)) + return :($(PrefixContext)($(esc(Meta.quot(Val(Symbol(prefix))))), $ctx)) end function prefix_submodel_context(prefix::Bool, ctx) diff --git a/test/contexts.jl b/test/contexts.jl index f59c80082..1ba099a37 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -57,14 +57,15 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() :testparent => DynamicPPL.TestUtils.TestParentContext(DefaultContext()), :sampling => SamplingContext(), :minibatch => MiniBatchContext(DefaultContext(), 0.0), - :prefix => PrefixContext{:x}(DefaultContext()), + :prefix => PrefixContext(@varname(x)), :pointwiselogdensity => PointwiseLogdensityContext(), :condition1 => ConditionContext((x=1.0,)), :condition2 => ConditionContext( (x=1.0,), DynamicPPL.TestUtils.TestParentContext(ConditionContext((y=2.0,))) ), :condition3 => ConditionContext( - (x=1.0,), PrefixContext{:a}(ConditionContext(Dict(@varname(y) => 2.0))) + (x=1.0,), + PrefixContext(@varname(a), ConditionContext(Dict(@varname(y) => 2.0))), ), :condition4 => ConditionContext((x=[1.0, missing],)), ) @@ -131,31 +132,37 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @testset "PrefixContext" begin @testset "prefixing" begin - ctx = @inferred PrefixContext{:a}( - PrefixContext{:b}( - PrefixContext{:c}( - PrefixContext{:d}( - PrefixContext{:e}(PrefixContext{:f}(DefaultContext())) + ctx = @inferred PrefixContext( + @varname(a), + PrefixContext( + @varname(b), + PrefixContext( + @varname(c), + PrefixContext( + @varname(d), + PrefixContext( + @varname(e), PrefixContext(@varname(f), DefaultContext()) + ), ), ), ), ) - vn = VarName{:x}() + vn = @varname(x) vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) @test vn_prefixed == @varname(a.b.c.d.e.f.x) - vn = VarName{:x}(((1,),)) + vn = @varname(x[1]) vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) @test vn_prefixed == @varname(a.b.c.d.e.f.x[1]) end @testset "nested within arbitrary context stacks" begin vn = @varname(x[1]) - ctx1 = PrefixContext{:a}(DefaultContext()) + ctx1 = PrefixContext(@varname(a)) @test DynamicPPL.prefix(ctx1, vn) == @varname(a.x[1]) ctx2 = SamplingContext(ctx1) @test DynamicPPL.prefix(ctx2, vn) == @varname(a.x[1]) - ctx3 = PrefixContext{:b}(ctx2) + ctx3 = PrefixContext(@varname(b), ctx2) @test DynamicPPL.prefix(ctx3, vn) == @varname(b.a.x[1]) ctx4 = DynamicPPL.ValuesAsInModelContext(OrderedDict(), false, ctx3) @test DynamicPPL.prefix(ctx4, vn) == @varname(b.a.x[1]) @@ -163,30 +170,30 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @testset "prefix_and_strip_contexts" begin vn = @varname(x[1]) - ctx1 = PrefixContext{:a}(DefaultContext()) + ctx1 = PrefixContext(@varname(a)) new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx1, vn) @test new_vn == @varname(a.x[1]) @test new_ctx == DefaultContext() - ctx2 = SamplingContext(PrefixContext{:a}(DefaultContext())) + ctx2 = SamplingContext(PrefixContext(@varname(a))) new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx2, vn) @test new_vn == @varname(a.x[1]) @test new_ctx == SamplingContext() - ctx3 = PrefixContext{:a}(ConditionContext((a=1,))) + ctx3 = PrefixContext(@varname(a), ConditionContext((a=1,))) new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx3, vn) @test new_vn == @varname(a.x[1]) @test new_ctx == ConditionContext((a=1,)) - ctx4 = SamplingContext(PrefixContext{:a}(ConditionContext((a=1,)))) + ctx4 = SamplingContext(PrefixContext(@varname(a), ConditionContext((a=1,)))) new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx4, vn) @test new_vn == @varname(a.x[1]) @test new_ctx == SamplingContext(ConditionContext((a=1,))) end @testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - prefix = :my_prefix - context = DynamicPPL.PrefixContext{prefix}(SamplingContext()) + prefix_vn = @varname(my_prefix) + context = DynamicPPL.PrefixContext(prefix_vn, SamplingContext()) # Sample with the context. varinfo = DynamicPPL.VarInfo() DynamicPPL.evaluate!!(model, varinfo, context) @@ -195,7 +202,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() # Extract the ground truth varnames vns_expected = Set([ - AbstractPPL.prefix(vn, VarName{prefix}()) for + AbstractPPL.prefix(vn, prefix_vn) for vn in DynamicPPL.TestUtils.varnames(model) ]) @@ -373,7 +380,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end # Prefix -> Condition - c1 = PrefixContext{:a}(ConditionContext((c=1, d=2))) + c1 = PrefixContext(@varname(a), ConditionContext((c=1, d=2))) c1 = collapse_prefix_stack(c1) @test has_no_prefixcontexts(c1) c1_vals = conditioned(c1) @@ -382,7 +389,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test getvalue(c1_vals, @varname(a.d)) == 2 # Condition -> Prefix - c2 = (ConditionContext((c=1, d=2), PrefixContext{:a}(DefaultContext()))) + c2 = ConditionContext((c=1, d=2), PrefixContext(@varname(a))) c2 = collapse_prefix_stack(c2) @test has_no_prefixcontexts(c2) c2_vals = conditioned(c2) @@ -391,7 +398,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test getvalue(c2_vals, @varname(d)) == 2 # Prefix -> Fixed - c3 = PrefixContext{:a}(FixedContext((f=1, g=2))) + c3 = PrefixContext(@varname(a), FixedContext((f=1, g=2))) c3 = collapse_prefix_stack(c3) c3_vals = fixed(c3) @test length(c3_vals) == 2 @@ -400,7 +407,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test getvalue(c3_vals, @varname(a.g)) == 2 # Fixed -> Prefix - c4 = (FixedContext((f=1, g=2), PrefixContext{:a}(DefaultContext()))) + c4 = FixedContext((f=1, g=2), PrefixContext(@varname(a))) c4 = collapse_prefix_stack(c4) @test has_no_prefixcontexts(c4) c4_vals = fixed(c4) @@ -409,8 +416,11 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test getvalue(c4_vals, @varname(g)) == 2 # Prefix -> Condition -> Prefix -> Condition - c5 = PrefixContext{:a}( - ConditionContext((c=1,), PrefixContext{:b}(ConditionContext((d=2,)))) + c5 = PrefixContext( + @varname(a), + ConditionContext( + (c=1,), PrefixContext(@varname(b), ConditionContext((d=2,))) + ), ) c5 = collapse_prefix_stack(c5) @test has_no_prefixcontexts(c5) @@ -420,8 +430,9 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test getvalue(c5_vals, @varname(a.b.d)) == 2 # Prefix -> Condition -> Prefix -> Fixed - c6 = PrefixContext{:a}( - ConditionContext((c=1,), PrefixContext{:b}(FixedContext((d=2,)))) + c6 = PrefixContext( + @varname(a), + ConditionContext((c=1,), PrefixContext(@varname(b), FixedContext((d=2,)))), ) c6 = collapse_prefix_stack(c6) @test has_no_prefixcontexts(c6) diff --git a/test/submodels.jl b/test/submodels.jl index 6a8a2c889..e79eed2c3 100644 --- a/test/submodels.jl +++ b/test/submodels.jl @@ -122,9 +122,7 @@ using Test p.b ~ Normal() return (p.a, p.b) end - expected_vns = Set([ - @varname(var"p.a".x[1]), @varname(var"p.a".y), @varname(p.b) - ]) + expected_vns = Set([@varname(p.a.x[1]), @varname(p.a.y), @varname(p.b)]) @test Set(keys(VarInfo(g()))) == expected_vns # Check that we can condition/fix on any of them from the outside