diff --git a/HISTORY.md b/HISTORY.md index a45644a64..ac3e40970 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -4,38 +4,25 @@ **Breaking changes** -### AD testing utilities - -`DynamicPPL.TestUtils.AD.run_ad` now links the VarInfo by default. -To disable this, pass the `linked=false` keyword argument. -If the calculated value or gradient is incorrect, it also throws a `DynamicPPL.TestUtils.AD.ADIncorrectException` rather than a test failure. -This exception contains the actual and expected gradient so you can inspect it if needed; see the documentation for more information. -From a practical perspective, this means that if you need to add this to a test suite, you need to use `@test run_ad(...) isa Any` rather than just `run_ad(...)`. +### Submodels: conditioning -### SimpleVarInfo linking / invlinking - -Linking a linked SimpleVarInfo, or invlinking an unlinked SimpleVarInfo, now displays a warning instead of an error. +Variables in a submodel can now be conditioned and fixed in a correct way. +See https://github.com/TuringLang/DynamicPPL.jl/issues/857 for a full illustration, but essentially it means you can now do this: -### VarInfo constructors - -`VarInfo(vi::VarInfo, values)` has been removed. You can replace this directly with `unflatten(vi, values)` instead. - -The `metadata` argument to `VarInfo([rng, ]model[, sampler, context, metadata])` has been removed. -If you were not using this argument (most likely), then there is no change needed. -If you were using the `metadata` argument to specify a blank `VarNamedVector`, then you should replace calls to `VarInfo` with `DynamicPPL.typed_vector_varinfo` instead (see 'Other changes' below). - -The `UntypedVarInfo` constructor and type is no longer exported. -If you needed to construct one, you should now use `DynamicPPL.untyped_varinfo` instead. - -The `TypedVarInfo` constructor and type is no longer exported. -The _type_ has been replaced with `DynamicPPL.NTVarInfo`. 
-The _constructor_ has been replaced with `DynamicPPL.typed_varinfo`. +```julia +@model function inner() + x ~ Normal() + return y ~ Normal() +end +@model function outer() + return a ~ to_submodel(inner() | (x=1.0,)) +end +``` -Note that the exact kind of VarInfo returned by `VarInfo(rng, model, ...)` is an implementation detail. -Previously, it was guaranteed that this would always be a VarInfo whose metadata was a `NamedTuple` containing `Metadata` structs. -Going forward, this is no longer the case, and you should only assume that the returned object obeys the `AbstractVarInfo` interface. +and the `a.x` variable will be correctly conditioned. +(Previously, you would have to condition `inner()` with the variable `a.x`, meaning that you would need to know what prefix to use before you had actually prefixed it.) -### VarName prefixing behaviour +### Submodel prefixing The way in which VarNames in submodels are prefixed has been changed. This is best explained through an example. @@ -77,9 +64,62 @@ outer() | (@varname(var"a.x") => 1.0,) outer() | (a.x=1.0,) ``` -If you are sampling from a model with submodels, this doesn't affect the way you interact with the `MCMCChains.Chains` object, because VarNames are converted into Symbols when stored in the chain. +In a similar way, if the variable on the left-hand side of your tilde statement is not just a single identifier, any fields or indices it accesses are now properly respected. +Consider the following setup: + +```julia +using DynamicPPL, Distributions +@model inner() = x ~ Normal() +@model function outer() + a = Vector{Float64}(undef, 1) + a[1] ~ to_submodel(inner()) + return a +end +``` + +In this case, the variable sampled is actually the `x` field of the first element of `a`: + +```julia +julia> only(keys(VarInfo(outer()))) == @varname(a[1].x) +true +``` + +Before this version, it used to be a single variable called `var"a[1].x"`. 
+ +Note that if you are sampling from a model with submodels, this doesn't affect the way you interact with the `MCMCChains.Chains` object, because VarNames are converted into Symbols when stored in the chain. (This behaviour will likely be changed in the future, in that Chains should be indexable by VarNames and not just Symbols, but that has not been implemented yet.) +### AD testing utilities + +`DynamicPPL.TestUtils.AD.run_ad` now links the VarInfo by default. +To disable this, pass the `linked=false` keyword argument. +If the calculated value or gradient is incorrect, it also throws a `DynamicPPL.TestUtils.AD.ADIncorrectException` rather than a test failure. +This exception contains the actual and expected gradient so you can inspect it if needed; see the documentation for more information. +From a practical perspective, this means that if you need to add this to a test suite, you need to use `@test run_ad(...) isa Any` rather than just `run_ad(...)`. + +### SimpleVarInfo linking / invlinking + +Linking a linked SimpleVarInfo, or invlinking an unlinked SimpleVarInfo, now displays a warning instead of an error. + +### VarInfo constructors + +`VarInfo(vi::VarInfo, values)` has been removed. You can replace this directly with `unflatten(vi, values)` instead. + +The `metadata` argument to `VarInfo([rng, ]model[, sampler, context, metadata])` has been removed. +If you were not using this argument (most likely), then there is no change needed. +If you were using the `metadata` argument to specify a blank `VarNamedVector`, then you should replace calls to `VarInfo` with `DynamicPPL.typed_vector_varinfo` instead (see 'Other changes' below). + +The `UntypedVarInfo` constructor and type is no longer exported. +If you needed to construct one, you should now use `DynamicPPL.untyped_varinfo` instead. + +The `TypedVarInfo` constructor and type is no longer exported. +The _type_ has been replaced with `DynamicPPL.NTVarInfo`. 
+The _constructor_ has been replaced with `DynamicPPL.typed_varinfo`. + +Note that the exact kind of VarInfo returned by `VarInfo(rng, model, ...)` is an implementation detail. +Previously, it was guaranteed that this would always be a VarInfo whose metadata was a `NamedTuple` containing `Metadata` structs. +Going forward, this is no longer the case, and you should only assume that the returned object obeys the `AbstractVarInfo` interface. + **Other changes** While these are technically breaking, they are only internal changes and do not affect the public API. diff --git a/docs/Project.toml b/docs/Project.toml index 40a719e03..93f449308 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,4 +1,5 @@ [deps] +AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" diff --git a/docs/make.jl b/docs/make.jl index c69b72fb8..7984fa1d1 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -24,7 +24,9 @@ makedocs(; format=Documenter.HTML(; size_threshold=2^10 * 400), modules=[DynamicPPL, Base.get_extension(DynamicPPL, :DynamicPPLMCMCChainsExt)], pages=[ - "Home" => "index.md", "API" => "api.md", "Internals" => ["internals/varinfo.md"] + "Home" => "index.md", + "API" => "api.md", + "Internals" => ["internals/varinfo.md", "internals/submodel_condition.md"], ], checkdocs=:exports, doctest=false, diff --git a/docs/src/api.md b/docs/src/api.md index ec741c9ad..08522e2ce 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -78,9 +78,9 @@ decondition ## Fixing and unfixing -We can also _fix_ a collection of variables in a [`Model`](@ref) to certain using [`fix`](@ref). +We can also _fix_ a collection of variables in a [`Model`](@ref) to certain values using [`DynamicPPL.fix`](@ref). 
-This might seem quite similar to the aforementioned [`condition`](@ref) and its siblings, +This is quite similar to the aforementioned [`condition`](@ref) and its siblings, but they are indeed different operations: - `condition`ed variables are considered to be _observations_, and are thus @@ -89,19 +89,19 @@ but they are indeed different operations: - `fix`ed variables are considered to be _constant_, and are thus not included in any log-probability computations. -The differences are more clearly spelled out in the docstring of [`fix`](@ref) below. +The differences are more clearly spelled out in the docstring of [`DynamicPPL.fix`](@ref) below. ```@docs -fix +DynamicPPL.fix DynamicPPL.fixed ``` -The difference between [`fix`](@ref) and [`condition`](@ref) is described in the docstring of [`fix`](@ref) above. +The difference between [`DynamicPPL.fix`](@ref) and [`DynamicPPL.condition`](@ref) is described in the docstring of [`DynamicPPL.fix`](@ref) above. -Similarly, we can [`unfix`](@ref) variables, i.e. return them to their original meaning: +Similarly, we can revert this with [`DynamicPPL.unfix`](@ref), i.e. return the variables to their original meaning: ```@docs -unfix +DynamicPPL.unfix ``` ## Predicting diff --git a/docs/src/internals/submodel_condition.md b/docs/src/internals/submodel_condition.md new file mode 100644 index 000000000..ecb9d452b --- /dev/null +++ b/docs/src/internals/submodel_condition.md @@ -0,0 +1,356 @@ +# How `PrefixContext` and `ConditionContext` interact + +```@meta +ShareDefaultModule = true +``` + +## PrefixContext + +`PrefixContext` is a context that, as the name suggests, prefixes all variables inside a model with a given symbol. +Thus, for example: + +```@example +using DynamicPPL, Distributions + +@model function f() + x ~ Normal() + return y ~ Normal() +end + +@model function g() + return a ~ to_submodel(f()) +end +``` + +inside the submodel `f`, the variables `x` and `y` become `a.x` and `a.y` respectively. 
+This is easiest to observe by running the model: + +```@example +vi = VarInfo(g()) +keys(vi) +``` + +!!! note + + In this case, where `to_submodel` is called without any other arguments, the prefix to be used is automatically inferred from the name of the variable on the left-hand side of the tilde. + We will return to the 'manual prefixing' case later. + +The phrase 'becoming' a different variable is a little underspecified: it is useful to pinpoint the exact location where the prefixing occurs, which is `tilde_assume`. +The method responsible for it is `tilde_assume(::PrefixContext, right, vn, vi)`: this attaches the prefix in the context to the `VarName` argument, before recursively calling `tilde_assume` with the new prefixed `VarName`. +This means that even though a statement `x ~ dist` still enters the tilde pipeline at the top level as `x`, if the model evaluation context contains a `PrefixContext`, any function from `tilde_assume` onwards will see `a.x` instead. + +## ConditionContext + +`ConditionContext` is a context which stores values of variables that are to be conditioned on. +These values may be stored as a `Dict` which maps `VarName`s to values, or alternatively as a `NamedTuple`. +The latter only works correctly if all `VarName`s are 'basic', in that they have an identity optic (i.e., something like `a.x` or `a[1]` is forbidden). +Because of this limitation, we will only use `Dict` in this example. + +!!! note + + If a `ConditionContext` with a `NamedTuple` encounters anything to do with a prefix, its internal `NamedTuple` is converted to a `Dict` anyway, so it is quite reasonable to ignore the `NamedTuple` case in this exposition. 
+ +One can inspect the conditioning values with, for example: + +```@example +@model function d() + x ~ Normal() + return y ~ Normal() +end + +cond_model = d() | (@varname(x) => 1.0) +cond_ctx = cond_model.context +``` + +There are several internal functions that are used to determine whether a variable is conditioned, and if so, what its value is. + +```@example +DynamicPPL.hasconditioned_nested(cond_ctx, @varname(x)) +``` + +```@example +DynamicPPL.getconditioned_nested(cond_ctx, @varname(x)) +``` + +These functions are in turn used by the function `DynamicPPL.contextual_isassumption`, which is largely the same as `hasconditioned_nested`, but also checks whether the value is `missing` (in which case it isn't really conditioned). + +```@example +DynamicPPL.contextual_isassumption(cond_ctx, @varname(x)) +``` + +!!! note + + Notice that (neglecting `missing` values) the return value of `contextual_isassumption` is the _opposite_ of `hasconditioned_nested`, i.e. for a variable that _is_ conditioned on, `contextual_isassumption` returns `false`. + +If a variable `x` is conditioned on, then the effect of this is to set the value of `x` to the given value (while still including its contribution to the log probability density). +Since `x` is no longer a random variable, if we were to evaluate the model, we would find only one key in the `VarInfo`: + +```@example +keys(VarInfo(cond_model)) +``` + +## Joint behaviour: desiderata at the model level + +When paired together, these two contexts have the potential to cause substantial confusion: `PrefixContext` modifies the variable names that are seen, which may cause them to be out of sync with the values contained inside the `ConditionContext`. + +We begin by mentioning some high-level desiderata for their joint behaviour. +Take these models, for example: + +```@example +# We define a helper function to unwrap a layer of SamplingContext, to +# avoid cluttering the print statements. 
+unwrap_sampling_context(ctx::DynamicPPL.SamplingContext) = ctx.context +unwrap_sampling_context(ctx::DynamicPPL.AbstractContext) = ctx +@model function inner() + println("inner context: $(unwrap_sampling_context(__context__))") + x ~ Normal() + return y ~ Normal() +end + +@model function outer() + println("outer context: $(unwrap_sampling_context(__context__))") + return a ~ to_submodel(inner()) +end + +# 'Outer conditioning' +with_outer_cond = outer() | (@varname(a.x) => 1.0) + +# 'Inner conditioning' +inner_cond = inner() | (@varname(x) => 1.0) +@model function outer2() + println("outer context: $(unwrap_sampling_context(__context__))") + return a ~ to_submodel(inner_cond) +end +with_inner_cond = outer2() +``` + +We want that: + + 1. `keys(VarInfo(outer()))` should return `[a.x, a.y]`; + 2. `keys(VarInfo(with_outer_cond))` should return `[a.y]`; + 3. `keys(VarInfo(with_inner_cond))` should return `[a.y]`, + +**In other words, we can condition submodels either from the outside (point (2)) or from the inside (point (3)), and the variable name we use to specify the conditioning should match the level at which we perform the conditioning.** + +This is an incredibly salient point because it means that submodels can be treated as individual, opaque objects, and we can condition them without needing to know what it will be prefixed with, or the context in which that submodel is being used. +For example, this means we can reuse `inner_cond` in another model with a different prefix, and it will _still_ have its inner `x` value be conditioned, despite the prefix differing. + +!!! info + + In the current version of DynamicPPL, these criteria are all fulfilled. However, this was not the case in the past: in particular, point (3) was not fulfilled, and users had to condition the internal submodel with the prefixes that were used outside. 
(See [this GitHub issue](https://github.com/TuringLang/DynamicPPL.jl/issues/857) for more information; this issue was the direct motivation for this documentation page.) + +## Desiderata at the context level + +The above section describes how we expect conditioning and prefixing to behave from a user's perspective. +We now turn to the question of how we implement this in terms of DynamicPPL contexts. +We do not specify the implementation details here, but we will sketch out something resembling an API that will allow us to achieve the target behaviour. + +**Point (1)** does not involve any conditioning, only prefixing; it is therefore already satisfied by virtue of the `tilde_assume` method shown above. + +**Points (2) and (3)** are more tricky. +As the reader may surmise, the difference between them is the order in which the contexts are stacked. + +For the _outer_ conditioning case (point (2)), the `ConditionContext` will contain a `VarName` that is already prefixed. +When we enter the inner submodel, this `ConditionContext` has to be passed down and somehow combined with the `PrefixContext` that is created when we enter the submodel. +We make the claim here that the best way to do this is to nest the `PrefixContext` _inside_ the `ConditionContext`. +This is indeed what happens, as can be demonstrated by running the model. + +```@example +with_outer_cond(); +nothing; +``` + +!!! info + +    The `; nothing` at the end is purely to circumvent a Documenter.jl quirk where stdout is only shown if the return value of the final statement is `nothing`. +    If these documentation pages are moved to Quarto, it will be possible to remove this. + +For the _inner_ conditioning case (point (3)), the outer model is not run with any special context. +The inner model will itself contain a `ConditionContext`, which will contain a `VarName` that is not prefixed. +When we run the model, this `ConditionContext` should then be nested _inside_ a `PrefixContext` to form the final evaluation context. 
+Again, we can run the model to see this in action: + +```@example +with_inner_cond(); +nothing; +``` + +Putting all of the information so far together, what it means is that if we have these two inner contexts (taken from above): + +```@example +using DynamicPPL: PrefixContext, ConditionContext, DefaultContext + +inner_ctx_with_outer_cond = ConditionContext( + Dict(@varname(a.x) => 1.0), PrefixContext(@varname(a)) +) +inner_ctx_with_inner_cond = PrefixContext( + @varname(a), ConditionContext(Dict(@varname(x) => 1.0)) +) +``` + +then we want both of these to be `true` (and thankfully, they are!): + +```@example +DynamicPPL.hasconditioned_nested(inner_ctx_with_outer_cond, @varname(a.x)) +``` + +```@example +DynamicPPL.hasconditioned_nested(inner_ctx_with_inner_cond, @varname(a.x)) +``` + +This allows us to finally specify our task as follows: + +(1) Given the correct arguments, we need to make sure that `hasconditioned_nested` and `getconditioned_nested` behave correctly. + +(2) We need to make sure that both the correct arguments are supplied. In order to do so: + + - (2a) We need to make sure that when evaluating a submodel, the context stack is arranged such that `PrefixContext` is applied _inside_ the parent model's context, but _outside_ the submodel's own context. + + - (2b) We also need to make sure that the `VarName` passed to it is prefixed correctly. + +## How do we do it? + +(1) `hasconditioned_nested` and `getconditioned_nested` accomplish this by first 'collapsing' the context stack, i.e. they go through the context stack, remove all `PrefixContext`s, and apply those prefixes to any conditioned variables below it in the stack. +Once the `PrefixContext`s have been removed, one can then iterate through the context stack and check if any of the `ConditionContext`s contain the variable, or get the value itself. +For more details the reader is encouraged to read the source code. 
+ +(2a) We ensure that the context stack is correctly arranged by relying on the behaviour of `make_evaluate_args_and_kwargs`. +This function is called whenever a model (which itself contains a context) is evaluated with a separate ('external') context, and makes sure to arrange both of these contexts such that _the model's context is nested inside the external context_. +Thus, as long as prefixing is implemented by applying a `PrefixContext` on the outermost layer of the _inner_ model context, this will be correctly combined with an external context to give the behaviour seen above. + +(2b) At first glance, it seems like `tilde_assume` can take care of the `VarName` prefixing for us (as described in the first section). +However, this is not actually the case: `contextual_isassumption`, which is the function that calls `hasconditioned_nested`, is much higher in the call stack than `tilde_assume` is. +So, we need to explicitly prefix it before passing it to `contextual_isassumption`. +This is done inside the `@model` macro, or technically, its subsidiary function `isassumption`. + +## Nested submodels + +Just in case the above wasn't complicated enough, we need to also be very careful when dealing with nested submodels, which have multiple layers of `PrefixContext`s which may be interspersed with `ConditionContext`s. +For example, in this series of nested submodels, + +```@example +@model function charlie() + x ~ Normal() + y ~ Normal() + return z ~ Normal() +end +@model function bravo() + return b ~ to_submodel(charlie() | (@varname(x) => 1.0)) +end +@model function alpha() + return a ~ to_submodel(bravo() | (@varname(b.y) => 1.0)) +end +``` + +we expect that the only variable to be sampled should be `z` inside `charlie`, or rather, `a.b.z` once it has been through the prefixes. + +```@example +keys(VarInfo(alpha())) +``` + +The general strategy that we adopt is similar to above. 
+Following the principle that `PrefixContext` should be nested inside the outer context, but outside the inner submodel's context, we can infer that the correct context inside `charlie` should be: + +```@example +big_ctx = PrefixContext( + @varname(a), + ConditionContext( + Dict(@varname(b.y) => 1.0), + PrefixContext(@varname(b), ConditionContext(Dict(@varname(x) => 1.0))), + ), +) +``` + +We need several things to work correctly here: we need the `VarName` prefixing to behave correctly, and then we need to implement `hasconditioned_nested` and `getconditioned_nested` on the resulting prefixed `VarName`. +It turns out that the prefixing itself is enough to illustrate the most important point in this section, namely, the need to traverse the context stack in a _different direction_ to what most of DynamicPPL does. + +Let's work with a function called `myprefix(::AbstractContext, ::VarName)` (to avoid confusion with any existing DynamicPPL function). +We should like `myprefix(big_ctx, @varname(x))` to return `@varname(a.b.x)`. +Consider the following naive implementation, which mirrors a lot of code in the tilde-pipeline: + +```@example +using DynamicPPL: NodeTrait, IsLeaf, IsParent, childcontext, AbstractContext +using AbstractPPL: AbstractPPL + +function myprefix(ctx::DynamicPPL.AbstractContext, vn::VarName) + return myprefix(NodeTrait(ctx), ctx, vn) +end +function myprefix(::IsLeaf, ::AbstractContext, vn::VarName) + return vn +end +function myprefix(::IsParent, ctx::AbstractContext, vn::VarName) + return myprefix(childcontext(ctx), vn) +end +function myprefix(ctx::DynamicPPL.PrefixContext, vn::VarName) + # The functionality to actually manipulate the VarNames is in AbstractPPL + new_vn = AbstractPPL.prefix(vn, ctx.vn_prefix) + # Then pass to the child context + return myprefix(childcontext(ctx), new_vn) +end + +myprefix(big_ctx, @varname(x)) +``` + +This implementation clearly is not correct, because it applies the _inner_ `PrefixContext` before the outer one. 
+ +The right way to implement `myprefix` is to, essentially, reverse the order of two lines above: + +```@example +function myprefix(ctx::DynamicPPL.PrefixContext, vn::VarName) +    # Pass to the child context first +    new_vn = myprefix(childcontext(ctx), vn) +    # Then apply this context's prefix +    return AbstractPPL.prefix(new_vn, ctx.vn_prefix) +end + +myprefix(big_ctx, @varname(x)) +``` + +This is a much better result! +The implementation of related functions such as `hasconditioned_nested` and `getconditioned_nested`, under the hood, uses a similar recursion scheme, so you will find that this is a common pattern when reading the source code of various prefixing-related functions. +When editing this code, it is worth being mindful of this as a potential source of incorrectness. + +!!! info + +    If you have encountered left and right folds, the above discussion illustrates the difference between them: the wrong implementation of `myprefix` uses a left fold (which collects prefixes in the opposite order from which they are encountered), while the correct implementation uses a right fold. + +## Loose ends 1: Manual prefixing + +Sometimes users may want to manually prefix a model, for example: + +```@example +@model function inner_manual() +    x ~ Normal() +    return y ~ Normal() +end + +@model function outer_manual() +    return _unused ~ to_submodel(prefix(inner_manual(), :a), false) +end +``` + +In this case, the `VarName` on the left-hand side of the tilde is not used, and the prefix is instead specified using the `prefix` function. + +The way to deal with this follows on from the previous discussion. +Specifically, we said that: + +> [...] as long as prefixing is implemented by applying a `PrefixContext` on the outermost layer of the _inner_ model context, this will be correctly combined [...] + +When automatic prefixing is used, this application of `PrefixContext` occurs inside the `tilde_assume!!` method. 
+In the manual prefixing case, we need to make sure that `prefix(submodel::Model, ::Symbol)` does the same thing, i.e. it inserts a `PrefixContext` at the outermost layer of `submodel`'s context. +We can see that this is precisely what happens: + +```@example +@model f() = x ~ Normal() + +model = f() +prefixed_model = prefix(model, :a) + +(model.context, prefixed_model.context) +``` + +## Loose ends 2: FixedContext + +Finally, note that all of the above also applies to the interaction between `PrefixContext` and `FixedContext`, except that the functions have different names. +(`FixedContext` behaves the same way as `ConditionContext`, except that unlike conditioned variables, fixed variables do not contribute to the log probability density.) +This generally results in a large amount of code duplication, but the concepts that underlie both contexts are exactly the same. diff --git a/src/compiler.jl b/src/compiler.jl index 4771b0171..6f7489b8e 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -53,7 +53,9 @@ function isassumption( vn=AbstractPPL.drop_escape(varname(expr, need_concretize(expr))), ) return quote - if $(DynamicPPL.contextual_isassumption)(__context__, $vn) + if $(DynamicPPL.contextual_isassumption)( + __context__, $(DynamicPPL.prefix)(__context__, $vn) + ) # Considered an assumption by `__context__` which means either: # 1. We hit the default implementation, e.g. using `DefaultContext`, # which in turn means that we haven't considered if it's one of @@ -87,67 +89,45 @@ isassumption(expr) = :(false) contextual_isassumption(context, vn) Return `true` if `vn` is considered an assumption by `context`. - -The default implementation for `AbstractContext` always returns `true`. 
""" -contextual_isassumption(::IsLeaf, context, vn) = true -function contextual_isassumption(::IsParent, context, vn) - return contextual_isassumption(childcontext(context), vn) -end function contextual_isassumption(context::AbstractContext, vn) - return contextual_isassumption(NodeTrait(context), context, vn) -end -function contextual_isassumption(context::ConditionContext, vn) - if hasconditioned(context, vn) - val = getconditioned(context, vn) + if hasconditioned_nested(context, vn) + val = getconditioned_nested(context, vn) # TODO: Do we even need the `>: Missing`, i.e. does it even help the compiler? if eltype(val) >: Missing && val === missing return true else return false end + else + return true end - - # We might have nested contexts, e.g. `ConditionContext{.., <:PrefixContext{..., <:ConditionContext}}` - # so we defer to `childcontext` if we haven't concluded that anything yet. - return contextual_isassumption(childcontext(context), vn) -end -function contextual_isassumption(context::PrefixContext, vn) - return contextual_isassumption(childcontext(context), prefix(context, vn)) end isfixed(expr, vn) = false -isfixed(::Union{Symbol,Expr}, vn) = :($(DynamicPPL.contextual_isfixed)(__context__, $vn)) +function isfixed(::Union{Symbol,Expr}, vn) + return :($(DynamicPPL.contextual_isfixed)( + __context__, $(DynamicPPL.prefix)(__context__, $vn) + )) +end """ contextual_isfixed(context, vn) Return `true` if `vn` is considered fixed by `context`. 
""" -contextual_isfixed(::IsLeaf, context, vn) = false -function contextual_isfixed(::IsParent, context, vn) - return contextual_isfixed(childcontext(context), vn) -end function contextual_isfixed(context::AbstractContext, vn) - return contextual_isfixed(NodeTrait(context), context, vn) -end -function contextual_isfixed(context::PrefixContext, vn) - return contextual_isfixed(childcontext(context), prefix(context, vn)) -end -function contextual_isfixed(context::FixedContext, vn) - if hasfixed(context, vn) - val = getfixed(context, vn) + if hasfixed_nested(context, vn) + val = getfixed_nested(context, vn) # TODO: Do we even need the `>: Missing`, i.e. does it even help the compiler? if eltype(val) >: Missing && val === missing return false else return true end + else + return false end - - # We might have nested contexts, e.g. `FixedContext{.., <:PrefixContext{..., <:FixedContext}}` - # so we defer to `childcontext` if we haven't concluded that anything yet. - return contextual_isfixed(childcontext(context), vn) end # If we're working with, say, a `Symbol`, then we're not going to `view`. @@ -467,13 +447,17 @@ function generate_tilde(left, right) ) $isassumption = $(DynamicPPL.isassumption(left, vn)) if $(DynamicPPL.isfixed(left, vn)) - $left = $(DynamicPPL.getfixed_nested)(__context__, $vn) + $left = $(DynamicPPL.getfixed_nested)( + __context__, $(DynamicPPL.prefix)(__context__, $vn) + ) elseif $isassumption $(generate_tilde_assume(left, dist, vn)) else # If `vn` is not in `argnames`, we need to make sure that the variable is defined. 
if !$(DynamicPPL.inargnames)($vn, __model__) - $left = $(DynamicPPL.getconditioned_nested)(__context__, $vn) + $left = $(DynamicPPL.getconditioned_nested)( + __context__, $(DynamicPPL.prefix)(__context__, $vn) + ) end $value, __varinfo__ = $(DynamicPPL.tilde_observe!!)( diff --git a/src/context_implementations.jl b/src/context_implementations.jl index e4ba5d252..eb025dec8 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -85,12 +85,23 @@ function tilde_assume(rng::Random.AbstractRNG, ::LikelihoodContext, sampler, rig end function tilde_assume(context::PrefixContext, right, vn, vi) - return tilde_assume(context.context, right, prefix(context, vn), vi) + # Note that we can't use something like this here: + # new_vn = prefix(context, vn) + # return tilde_assume(childcontext(context), right, new_vn, vi) + # This is because `prefix` applies _all_ prefixes in a given context to a + # variable name. Thus, if we had two levels of nested prefixes e.g. + # `PrefixContext{:a}(PrefixContext{:b}(DefaultContext()))`, then the + # first call would apply the prefix `a.b._`, and the recursive call + # would apply the prefix `b._`, resulting in `b.a.b._`. + # This is why we need a special function, `prefix_and_strip_contexts`. + new_vn, new_context = prefix_and_strip_contexts(context, vn) + return tilde_assume(new_context, right, new_vn, vi) end function tilde_assume( rng::Random.AbstractRNG, context::PrefixContext, sampler, right, vn, vi ) - return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), vi) + new_vn, new_context = prefix_and_strip_contexts(context, vn) + return tilde_assume(rng, new_context, sampler, right, new_vn, vi) end """ @@ -104,12 +115,27 @@ probability of `vi` with the returned value. """ function tilde_assume!!(context, right, vn, vi) return if is_rhs_model(right) - # Prefix the variables using the `vn`. - rand_like!!( - right, - should_auto_prefix(right) ? 
PrefixContext{Symbol(vn)}(context) : context, -            vi, -        ) +        # Here, we apply the PrefixContext _not_ to the parent `context`, but +        # to the context of the submodel being evaluated. This means that later +        # on in `make_evaluate_args_and_kwargs`, the context stack will be +        # correctly arranged such that it goes like this: +        #    parent_context[1] -> parent_context[2] -> ... -> PrefixContext -> +        #      submodel_context[1] -> submodel_context[2] -> ... -> leafcontext +        # See the docstring of `make_evaluate_args_and_kwargs`, and the internal +        # DynamicPPL documentation on submodel conditioning, for more details. +        # +        # NOTE: This relies on the existence of `right.model.model`. Right now, +        # the only thing that can return true for `is_rhs_model` is something +        # (a `Sampleable`) that has a `model` field that itself (a +        # `ReturnedModelWrapper`) has a `model` field. This may or may not +        # change in the future. +        if should_auto_prefix(right) +            dppl_model = right.model.model # This isa DynamicPPL.Model +            prefixed_submodel_context = PrefixContext(vn, dppl_model.context) +            new_dppl_model = contextualize(dppl_model, prefixed_submodel_context) +            right = to_submodel(new_dppl_model, true) +        end +        rand_like!!(right, context, vi)      else          value, logp, vi = tilde_assume(context, right, vn, vi)          value, acclogp_assume!!(context, vi, logp) diff --git a/src/contexts.jl b/src/contexts.jl index 58ac612b8..8ac085663 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -237,27 +237,34 @@ function setchildcontext(parent::MiniBatchContext, child)  end    """ -    PrefixContext{Prefix}(context) +    PrefixContext(vn::VarName[, context::AbstractContext]) +    PrefixContext(vn::Val{sym}[, context::AbstractContext]) where {sym}    Create a context that allows you to use the wrapped `context` when running the model and -adds the `Prefix` to all parameters. +prefixes all parameters with the VarName `vn`. + +`PrefixContext(Val(:a), context)` is equivalent to `PrefixContext(@varname(a), context)`. 
+If `context` is not provided, it defaults to `DefaultContext()`. This context is useful in nested models to ensure that the names of the parameters are unique. See also: [`to_submodel`](@ref) """ -struct PrefixContext{Prefix,C} <: AbstractContext +struct PrefixContext{Tvn<:VarName,C<:AbstractContext} <: AbstractContext + vn_prefix::Tvn context::C end -function PrefixContext{Prefix}(context::AbstractContext) where {Prefix} - return PrefixContext{Prefix,typeof(context)}(context) +PrefixContext(vn::VarName) = PrefixContext(vn, DefaultContext()) +function PrefixContext(::Val{sym}, context::AbstractContext) where {sym} + return PrefixContext(VarName{sym}(), context) end +PrefixContext(::Val{sym}) where {sym} = PrefixContext(VarName{sym}()) NodeTrait(::PrefixContext) = IsParent() childcontext(context::PrefixContext) = context.context -function setchildcontext(::PrefixContext{Prefix}, child) where {Prefix} - return PrefixContext{Prefix}(child) +function setchildcontext(ctx::PrefixContext, child::AbstractContext) + return PrefixContext(ctx.vn_prefix, child) end """ @@ -265,8 +272,8 @@ end Apply the prefixes in the context `ctx` to the variable name `vn`. """ -function prefix(ctx::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} - return AbstractPPL.prefix(prefix(childcontext(ctx), vn), VarName{Symbol(Prefix)}()) +function prefix(ctx::PrefixContext, vn::VarName) + return AbstractPPL.prefix(prefix(childcontext(ctx), vn), ctx.vn_prefix) end function prefix(ctx::AbstractContext, vn::VarName) return prefix(NodeTrait(ctx), ctx, vn) @@ -277,11 +284,52 @@ function prefix(::IsParent, ctx::AbstractContext, vn::VarName) end """ - prefix(model::Model, x) + prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName) + +Same as `prefix`, but additionally returns a new context stack that has all the +PrefixContexts removed. -Return `model` but with all random variables prefixed by `x`. 
+NOTE: This does _not_ modify any variables in any `ConditionContext` and +`FixedContext` that may be present in the context stack. This is because this +function is only used in `tilde_assume`, which is lower in the tilde-pipeline +than `contextual_isassumption` and `contextual_isfixed` (the functions which +actually use the `ConditionContext` and `FixedContext` values). Thus, by this +time, any `ConditionContext`s and `FixedContext`s present have already served +their purpose. -If `x` is known at compile-time, use `Val{x}()` to avoid runtime overheads for prefixing. +If you call this function, you must therefore be careful to ensure that you _do +not_ need to modify any inner `ConditionContext`s and `FixedContext`s. If you +_do_ need to modify them, then you may need to use +`prefix_cond_and_fixed_variables` instead. +""" +function prefix_and_strip_contexts(ctx::PrefixContext, vn::VarName) + child_context = childcontext(ctx) + # vn_prefixed contains the prefixes from all lower levels + vn_prefixed, child_context_without_prefixes = prefix_and_strip_contexts( + child_context, vn + ) + return AbstractPPL.prefix(vn_prefixed, ctx.vn_prefix), child_context_without_prefixes +end +function prefix_and_strip_contexts(ctx::AbstractContext, vn::VarName) + return prefix_and_strip_contexts(NodeTrait(ctx), ctx, vn) +end +prefix_and_strip_contexts(::IsLeaf, ctx::AbstractContext, vn::VarName) = (vn, ctx) +function prefix_and_strip_contexts(::IsParent, ctx::AbstractContext, vn::VarName) + vn, new_ctx = prefix_and_strip_contexts(childcontext(ctx), vn) + return vn, setchildcontext(ctx, new_ctx) +end + +""" + prefix(model::Model, x::VarName) + prefix(model::Model, x::Val{sym}) + prefix(model::Model, x::Any) + +Return `model` but with all random variables prefixed by `x`, where `x` is either: +- a `VarName` (e.g. `@varname(a)`), +- a `Val{sym}` (e.g. `Val(:a)`), or +- for any other type, `x` is converted to a Symbol and then to a `VarName`. 
Note that + this will introduce runtime overheads so is not recommended unless absolutely + necessary. # Examples @@ -291,17 +339,19 @@ julia> using DynamicPPL: prefix julia> @model demo() = x ~ Dirac(1) demo (generic function with 2 methods) -julia> rand(prefix(demo(), :my_prefix)) +julia> rand(prefix(demo(), @varname(my_prefix))) (var"my_prefix.x" = 1,) -julia> # One can also use `Val` to avoid runtime overheads. - rand(prefix(demo(), Val(:my_prefix))) +julia> rand(prefix(demo(), Val(:my_prefix))) (var"my_prefix.x" = 1,) ``` """ -prefix(model::Model, x) = contextualize(model, PrefixContext{Symbol(x)}(model.context)) -function prefix(model::Model, ::Val{x}) where {x} - return contextualize(model, PrefixContext{Symbol(x)}(model.context)) +prefix(model::Model, x::VarName) = contextualize(model, PrefixContext(x, model.context)) +function prefix(model::Model, x::Val{sym}) where {sym} + return contextualize(model, PrefixContext(VarName{sym}(), model.context)) +end +function prefix(model::Model, x) + return contextualize(model, PrefixContext(VarName{Symbol(x)}(), model.context)) end """ @@ -370,7 +420,9 @@ Return value of `vn` in `context`. 
function getconditioned(context::AbstractContext, vn::VarName) return error("context $(context) does not contain value for $vn") end -getconditioned(context::ConditionContext, vn::VarName) = getvalue(context.values, vn) +function getconditioned(context::ConditionContext, vn::VarName) + return getvalue(context.values, vn) +end """ hasconditioned_nested(context, vn) @@ -388,7 +440,7 @@ function hasconditioned_nested(::IsParent, context, vn) return hasconditioned(context, vn) || hasconditioned_nested(childcontext(context), vn) end function hasconditioned_nested(context::PrefixContext, vn) - return hasconditioned_nested(childcontext(context), prefix(context, vn)) + return hasconditioned_nested(collapse_prefix_stack(context), vn) end """ @@ -406,7 +458,7 @@ function getconditioned_nested(::IsLeaf, context, vn) return error("context $(context) does not contain value for $vn") end function getconditioned_nested(context::PrefixContext, vn) - return getconditioned_nested(childcontext(context), prefix(context, vn)) + return getconditioned_nested(collapse_prefix_stack(context), vn) end function getconditioned_nested(::IsParent, context, vn) return if hasconditioned(context, vn) @@ -476,6 +528,9 @@ function conditioned(context::ConditionContext) # precedence over decendants of `context`. 
return _merge(context.values, conditioned(childcontext(context))) end +function conditioned(context::PrefixContext) + return conditioned(collapse_prefix_stack(context)) +end struct FixedContext{Values,Ctx<:AbstractContext} <: AbstractContext values::Values @@ -539,7 +594,7 @@ function hasfixed_nested(::IsParent, context, vn) return hasfixed(context, vn) || hasfixed_nested(childcontext(context), vn) end function hasfixed_nested(context::PrefixContext, vn) - return hasfixed_nested(childcontext(context), prefix(context, vn)) + return hasfixed_nested(collapse_prefix_stack(context), vn) end """ @@ -557,7 +612,7 @@ function getfixed_nested(::IsLeaf, context, vn) return error("context $(context) does not contain value for $vn") end function getfixed_nested(context::PrefixContext, vn) - return getfixed_nested(childcontext(context), prefix(context, vn)) + return getfixed_nested(collapse_prefix_stack(context), vn) end function getfixed_nested(::IsParent, context, vn) return if hasfixed(context, vn) @@ -652,3 +707,113 @@ function fixed(context::FixedContext) # precedence over decendants of `context`. return _merge(context.values, fixed(childcontext(context))) end +function fixed(context::PrefixContext) + return fixed(collapse_prefix_stack(context)) +end + +""" + collapse_prefix_stack(context::AbstractContext) + +Apply `PrefixContext`s to any conditioned or fixed values inside them, and remove +the `PrefixContext`s from the context stack. + +!!! note + If you are reading this docstring, you might probably be interested in a more +thorough explanation of how PrefixContext and ConditionContext / FixedContext +interact with one another, especially in the context of submodels. + The DynamicPPL documentation contains [a separate page on this +topic](https://turinglang.org/DynamicPPL.jl/previews/PR892/internals/submodel_condition/) +which explains this in much more detail. 
+ +```jldoctest +julia> using DynamicPPL: collapse_prefix_stack + +julia> c1 = PrefixContext(@varname(a), ConditionContext((x=1, ))); + +julia> collapse_prefix_stack(c1) +ConditionContext(Dict(a.x => 1), DefaultContext()) + +julia> # Here, `x` gets prefixed only with `a`, whereas `y` is prefixed with both. + c2 = PrefixContext(@varname(a), ConditionContext((x=1, ), PrefixContext(@varname(b), ConditionContext((y=2,))))); + +julia> collapsed = collapse_prefix_stack(c2); + +julia> # `collapsed` really looks something like this: + # ConditionContext(Dict{VarName{:a}, Int64}(a.b.y => 2, a.x => 1), DefaultContext()) + # To avoid fragility arising from the order of the keys in the doctest, we test + # this indirectly: + collapsed.values[@varname(a.x)], collapsed.values[@varname(a.b.y)] +(1, 2) +``` +""" +function collapse_prefix_stack(context::PrefixContext) + # Collapse the child context (thus applying any inner prefixes first) + collapsed = collapse_prefix_stack(childcontext(context)) + # Prefix any conditioned variables with the current prefix + # Note: prefix_conditioned_variables is O(N) in the depth of the context stack. + # So is this function. In the worst case scenario, this is O(N^2) in the + # depth of the context stack. + return prefix_cond_and_fixed_variables(collapsed, context.vn_prefix) +end +function collapse_prefix_stack(context::AbstractContext) + return collapse_prefix_stack(NodeTrait(collapse_prefix_stack, context), context) +end +collapse_prefix_stack(::IsLeaf, context) = context +function collapse_prefix_stack(::IsParent, context) + new_child_context = collapse_prefix_stack(childcontext(context)) + return setchildcontext(context, new_child_context) +end + +""" + prefix_cond_and_fixed_variables(context::AbstractContext, prefix::VarName) + +Prefix all the conditioned and fixed variables in a given context with a single +`prefix`. 
+ +```jldoctest +julia> using DynamicPPL: prefix_cond_and_fixed_variables, ConditionContext + +julia> c1 = ConditionContext((a=1, )) +ConditionContext((a = 1,), DefaultContext()) + +julia> prefix_cond_and_fixed_variables(c1, @varname(y)) +ConditionContext(Dict(y.a => 1), DefaultContext()) +``` +""" +function prefix_cond_and_fixed_variables(ctx::ConditionContext, prefix::VarName) + # Replace the prefix of the conditioned variables + vn_dict = to_varname_dict(ctx.values) + prefixed_vn_dict = Dict( + AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict + ) + # Prefix the child context as well + prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) + return ConditionContext(prefixed_vn_dict, prefixed_child_ctx) +end +function prefix_cond_and_fixed_variables(ctx::FixedContext, prefix::VarName) + # Replace the prefix of the fixed variables + vn_dict = to_varname_dict(ctx.values) + prefixed_vn_dict = Dict( + AbstractPPL.prefix(vn, prefix) => value for (vn, value) in vn_dict + ) + # Prefix the child context as well + prefixed_child_ctx = prefix_cond_and_fixed_variables(childcontext(ctx), prefix) + return FixedContext(prefixed_vn_dict, prefixed_child_ctx) +end +function prefix_cond_and_fixed_variables(c::AbstractContext, prefix::VarName) + return prefix_cond_and_fixed_variables( + NodeTrait(prefix_cond_and_fixed_variables, c), c, prefix + ) +end +function prefix_cond_and_fixed_variables( + ::IsLeaf, context::AbstractContext, prefix::VarName +) + return context +end +function prefix_cond_and_fixed_variables( + ::IsParent, context::AbstractContext, prefix::VarName +) + return setchildcontext( + context, prefix_cond_and_fixed_variables(childcontext(context), prefix) + ) +end diff --git a/src/model.jl b/src/model.jl index b4d5f6bb7..c7c4bdf57 100644 --- a/src/model.jl +++ b/src/model.jl @@ -425,29 +425,32 @@ julia> # Returns all the variables we have conditioned on + their values. 
conditioned(condition(m, x=100.0, m=1.0)) (x = 100.0, m = 1.0) -julia> # Nested ones also work (note that `PrefixContext` does nothing to the result). - cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext((m=1.0,)))), x=100.0); +julia> # Nested ones also work. + # (Note that `PrefixContext` also prefixes the variables of any + # ConditionContext that is _inside_ it; because of this, the type of the + # container has to be broadened to a `Dict`.) + cm = condition(contextualize(m, PrefixContext(@varname(a), ConditionContext((m=1.0,)))), x=100.0); -julia> conditioned(cm) -(x = 100.0, m = 1.0) +julia> Set(keys(conditioned(cm))) == Set([@varname(a.m), @varname(x)]) +true -julia> # Since we conditioned on `m`, not `a.m` as it will appear after prefixed, - # `a.m` is treated as a random variable. +julia> # Since we conditioned on `a.m`, it is not treated as a random variable. + # However, `a.x` will still be a random variable. keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:m}}}: - a.m - -julia> # If we instead condition on `a.m`, `m` in the model will be considered an observation. - cm = condition(contextualize(m, PrefixContext{:a}(ConditionContext(Dict(@varname(a.m) => 1.0)))), x=100.0); +1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: + a.x -julia> conditioned(cm)[@varname(x)] -100.0 +julia> # We can also condition on `a.m` _outside_ of the PrefixContext: + cm = condition(contextualize(m, PrefixContext(@varname(a))), (@varname(a.m) => 1.0)); -julia> conditioned(cm)[@varname(a.m)] -1.0 +julia> conditioned(cm) +Dict{VarName{:a, Accessors.PropertyLens{:m}}, Float64} with 1 entry: + a.m => 1.0 -julia> keys(VarInfo(cm)) # No variables are sampled -VarName[] +julia> # Now `a.x` will be sampled. 
+ keys(VarInfo(cm)) +1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: + a.x ``` """ conditioned(model::Model) = conditioned(model.context) @@ -765,29 +768,27 @@ julia> # Returns all the variables we have fixed on + their values. fixed(fix(m, x=100.0, m=1.0)) (x = 100.0, m = 1.0) -julia> # Nested ones also work (note that `PrefixContext` does nothing to the result). - cm = fix(contextualize(m, PrefixContext{:a}(fix(m=1.0))), x=100.0); +julia> # The rest of this is the same as the `condition` example above. + cm = fix(contextualize(m, PrefixContext(@varname(a), fix(m=1.0))), x=100.0); -julia> fixed(cm) -(x = 100.0, m = 1.0) - -julia> # Since we fixed on `m`, not `a.m` as it will appear after prefixed, - # `a.m` is treated as a random variable. - keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:m}}}: - a.m +julia> Set(keys(fixed(cm))) == Set([@varname(a.m), @varname(x)]) +true -julia> # If we instead fix on `a.m`, `m` in the model will be considered an observation. - cm = fix(contextualize(m, PrefixContext{:a}(fix(@varname(a.m) => 1.0,))), x=100.0); +julia> keys(VarInfo(cm)) +1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: + a.x -julia> fixed(cm)[@varname(x)] -100.0 +julia> # We can also condition on `a.m` _outside_ of the PrefixContext: + cm = fix(contextualize(m, PrefixContext(@varname(a))), (@varname(a.m) => 1.0)); -julia> fixed(cm)[@varname(a.m)] -1.0 +julia> fixed(cm) +Dict{VarName{:a, Accessors.PropertyLens{:m}}, Float64} with 1 entry: + a.m => 1.0 -julia> keys(VarInfo(cm)) # <= no variables are sampled -VarName[] +julia> # Now `a.x` will be sampled. 
+ keys(VarInfo(cm)) +1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: + a.x ``` """ fixed(model::Model) = fixed(model.context) diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index f6b9c4479..5f1ec95ec 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -223,12 +223,12 @@ end prefix_submodel_context(prefix, left, ctx) = prefix_submodel_context(prefix, ctx) function prefix_submodel_context(prefix, ctx) # E.g. `prefix="asd[$i]"` or `prefix=asd` with `asd` to be evaluated. - return :($(PrefixContext){$(Symbol)($(esc(prefix)))}($ctx)) + return :($(PrefixContext)($(Val)($(Symbol)($(esc(prefix)))), $ctx)) end function prefix_submodel_context(prefix::Union{AbstractString,Symbol}, ctx) # E.g. `prefix="asd"`. - return :($(PrefixContext){$(esc(Meta.quot(Symbol(prefix))))}($ctx)) + return :($(PrefixContext)($(esc(Meta.quot(Val(Symbol(prefix))))), $ctx)) end function prefix_submodel_context(prefix::Bool, ctx) diff --git a/src/utils.jl b/src/utils.jl index 56c3d70af..71919480c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1286,7 +1286,10 @@ broadcast_safe(x::Distribution) = (x,) broadcast_safe(x::AbstractContext) = (x,) # Convert (x=1,) to Dict(@varname(x) => 1) -_nt_to_varname_dict(nt) = Dict(VarName{k}() => v for (k, v) in pairs(nt)) +function to_varname_dict(nt::NamedTuple) + return Dict{VarName,Any}(VarName{k}() => v for (k, v) in pairs(nt)) +end +to_varname_dict(d::AbstractDict) = d # Version of `merge` used by `conditioned` and `fixed` to handle # the scenario where we might try to merge a dict with an empty # tuple. 
@@ -1294,9 +1297,9 @@ _nt_to_varname_dict(nt) = Dict(VarName{k}() => v for (k, v) in pairs(nt)) _merge(left::NamedTuple, right::NamedTuple) = merge(left, right) _merge(left::AbstractDict, right::AbstractDict) = merge(left, right) _merge(left::AbstractDict, ::NamedTuple{()}) = left -_merge(left::AbstractDict, right::NamedTuple) = merge(left, _nt_to_varname_dict(right)) +_merge(left::AbstractDict, right::NamedTuple) = merge(left, to_varname_dict(right)) _merge(::NamedTuple{()}, right::AbstractDict) = right -_merge(left::NamedTuple, right::AbstractDict) = merge(_nt_to_varname_dict(left), right) +_merge(left::NamedTuple, right::AbstractDict) = merge(to_varname_dict(left), right) """ unique_syms(vns::T) where {T<:NTuple{N,VarName}} diff --git a/test/contexts.jl b/test/contexts.jl index 11e591f8f..1ba099a37 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -1,4 +1,5 @@ using Test, DynamicPPL, Accessors +using AbstractPPL: getoptic using DynamicPPL: leafcontext, setleafcontext, @@ -10,12 +11,18 @@ using DynamicPPL: IsParent, PointwiseLogdensityContext, contextual_isassumption, + FixedContext, ConditionContext, decondition_context, hasconditioned, getconditioned, + conditioned, + fixed, hasconditioned_nested, - getconditioned_nested + getconditioned_nested, + collapse_prefix_stack, + prefix_cond_and_fixed_variables, + getvalue using EnzymeCore @@ -50,14 +57,15 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() :testparent => DynamicPPL.TestUtils.TestParentContext(DefaultContext()), :sampling => SamplingContext(), :minibatch => MiniBatchContext(DefaultContext(), 0.0), - :prefix => PrefixContext{:x}(DefaultContext()), + :prefix => PrefixContext(@varname(x)), :pointwiselogdensity => PointwiseLogdensityContext(), :condition1 => ConditionContext((x=1.0,)), :condition2 => ConditionContext( (x=1.0,), DynamicPPL.TestUtils.TestParentContext(ConditionContext((y=2.0,))) ), :condition3 => ConditionContext( - (x=1.0,), 
PrefixContext{:a}(ConditionContext(Dict(@varname(a.y) => 2.0))) + (x=1.0,), + PrefixContext(@varname(a), ConditionContext(Dict(@varname(y) => 2.0))), ), :condition4 => ConditionContext((x=[1.0, missing],)), ) @@ -70,91 +78,52 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end end - @testset "contextual_isassumption" begin - @testset "$(name)" for (name, context) in contexts - # Any `context` should return `true` by default. - @test contextual_isassumption(context, VarName{gensym(:x)}()) - - if any(Base.Fix2(isa, ConditionContext), context) - # We have a `ConditionContext` among us. - # Let's first extract the conditioned variables. - conditioned_values = DynamicPPL.conditioned(context) + @testset "extracting conditioned values" begin + # This testset tests `contextual_isassumption`, `getconditioned_nested`, and + # `hasconditioned_nested`. - # The conditioned values might be a NamedTuple, or a Dict. - # We convert to a Dict for consistency - if conditioned_values isa NamedTuple - conditioned_values = Dict( - VarName{sym}() => val for (sym, val) in pairs(conditioned_values) - ) - end - - for (vn, val) in pairs(conditioned_values) - # We need to drop the prefix of `var` since in `contextual_isassumption` - # it will be threaded through the `PrefixContext` before it reaches - # `ConditionContext` with the conditioned variable. - vn_without_prefix = if getoptic(vn) isa PropertyLens - # Hacky: This assumes that there is exactly one level of prefixing - # that we need to undo. This is appropriate for the :condition3 - # test case above, but is not generally correct. - AbstractPPL.unprefix(vn, VarName{getsym(vn)}()) - else - vn - end - - @show DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) - # Let's check elementwise. 
- for vn_child in - DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) - if getoptic(vn_child)(val) === missing - @test contextual_isassumption(context, vn_child) - else - @test !contextual_isassumption(context, vn_child) - end - end - end - end - end - end - - @testset "getconditioned_nested & hasconditioned_nested" begin - @testset "$name" for (name, context) in contexts + @testset "$(name)" for (name, context) in contexts + # If the varname doesn't exist, it should always be an assumption. fake_vn = VarName{gensym(:x)}() + @test contextual_isassumption(context, fake_vn) @test !hasconditioned_nested(context, fake_vn) @test_throws ErrorException getconditioned_nested(context, fake_vn) if any(Base.Fix2(isa, ConditionContext), context) - # `ConditionContext` specific. - + # We have a `ConditionContext` among us. # Let's first extract the conditioned variables. conditioned_values = DynamicPPL.conditioned(context) + # The conditioned values might be a NamedTuple, or a Dict. # We convert to a Dict for consistency - if conditioned_values isa NamedTuple - conditioned_values = Dict( - VarName{sym}() => val for (sym, val) in pairs(conditioned_values) - ) - end + conditioned_values = DynamicPPL.to_varname_dict(conditioned_values) + + # Extract all conditioned variables. We also use varname_leaves + # here to split up arrays which could potentially have some, + # but not all, elements being `missing`. + conditioned_vns = mapreduce( + p -> DynamicPPL.TestUtils.varname_leaves(p.first, p.second), + vcat, + pairs(conditioned_values), + ) - for (vn, val) in pairs(conditioned_values) - # We need to drop the prefix of `var` since in `contextual_isassumption` - # it will be threaded through the `PrefixContext` before it reaches - # `ConditionContext` with the conditioned variable. - vn_without_prefix = if getoptic(vn) isa PropertyLens - # Hacky: This assumes that there is exactly one level of prefixing - # that we need to undo. 
This is appropriate for the :condition3 - # test case above, but is not generally correct. - AbstractPPL.unprefix(vn, VarName{getsym(vn)}()) + # We can now loop over them to check which ones are missing. We use + # `getvalue` to handle the awkward case where sometimes + # `conditioned_values` contains the full Varname (e.g. `a.x`) and + # sometimes only the main symbol (e.g. it contains `x` when + # `vn` is `x[1]`) + for vn in conditioned_vns + val = DynamicPPL.getvalue(conditioned_values, vn) + # These VarNames are present in the conditioning values, so + # we should always be able to extract the value. + @test hasconditioned_nested(context, vn) + @test getconditioned_nested(context, vn) === val + # However, the return value of contextual_isassumption depends on + # whether the value is missing or not. + if ismissing(val) + @test contextual_isassumption(context, vn) else - vn - end - - for vn_child in - DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) - # `vn_child` should be in `context`. - @test hasconditioned_nested(context, vn_child) - # Value should be the same as extracted above. 
- @test getconditioned_nested(context, vn_child) === - getoptic(vn_child)(val) + @test !contextual_isassumption(context, vn) end end end @@ -163,39 +132,68 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @testset "PrefixContext" begin @testset "prefixing" begin - ctx = @inferred PrefixContext{:a}( - PrefixContext{:b}( - PrefixContext{:c}( - PrefixContext{:d}( - PrefixContext{:e}(PrefixContext{:f}(DefaultContext())) + ctx = @inferred PrefixContext( + @varname(a), + PrefixContext( + @varname(b), + PrefixContext( + @varname(c), + PrefixContext( + @varname(d), + PrefixContext( + @varname(e), PrefixContext(@varname(f), DefaultContext()) + ), ), ), ), ) - vn = VarName{:x}() + vn = @varname(x) vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) @test vn_prefixed == @varname(a.b.c.d.e.f.x) - vn = VarName{:x}(((1,),)) + vn = @varname(x[1]) vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) @test vn_prefixed == @varname(a.b.c.d.e.f.x[1]) end @testset "nested within arbitrary context stacks" begin vn = @varname(x[1]) - ctx1 = PrefixContext{:a}(DefaultContext()) + ctx1 = PrefixContext(@varname(a)) @test DynamicPPL.prefix(ctx1, vn) == @varname(a.x[1]) ctx2 = SamplingContext(ctx1) @test DynamicPPL.prefix(ctx2, vn) == @varname(a.x[1]) - ctx3 = PrefixContext{:b}(ctx2) + ctx3 = PrefixContext(@varname(b), ctx2) @test DynamicPPL.prefix(ctx3, vn) == @varname(b.a.x[1]) ctx4 = DynamicPPL.ValuesAsInModelContext(OrderedDict(), false, ctx3) @test DynamicPPL.prefix(ctx4, vn) == @varname(b.a.x[1]) end + @testset "prefix_and_strip_contexts" begin + vn = @varname(x[1]) + ctx1 = PrefixContext(@varname(a)) + new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx1, vn) + @test new_vn == @varname(a.x[1]) + @test new_ctx == DefaultContext() + + ctx2 = SamplingContext(PrefixContext(@varname(a))) + new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx2, vn) + @test new_vn == @varname(a.x[1]) + @test new_ctx == SamplingContext() + + ctx3 = PrefixContext(@varname(a), 
ConditionContext((a=1,))) + new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx3, vn) + @test new_vn == @varname(a.x[1]) + @test new_ctx == ConditionContext((a=1,)) + + ctx4 = SamplingContext(PrefixContext(@varname(a), ConditionContext((a=1,)))) + new_vn, new_ctx = DynamicPPL.prefix_and_strip_contexts(ctx4, vn) + @test new_vn == @varname(a.x[1]) + @test new_ctx == SamplingContext(ConditionContext((a=1,))) + end + @testset "evaluation: $(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - prefix = :my_prefix - context = DynamicPPL.PrefixContext{prefix}(SamplingContext()) + prefix_vn = @varname(my_prefix) + context = DynamicPPL.PrefixContext(prefix_vn, SamplingContext()) # Sample with the context. varinfo = DynamicPPL.VarInfo() DynamicPPL.evaluate!!(model, varinfo, context) @@ -204,7 +202,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() # Extract the ground truth varnames vns_expected = Set([ - AbstractPPL.prefix(vn, VarName{prefix}()) for + AbstractPPL.prefix(vn, prefix_vn) for vn in DynamicPPL.TestUtils.varnames(model) ]) @@ -343,4 +341,103 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @test logprior(model_fixed, (; m)) == logprior(condition(model; s=s), (; m)) end end + + @testset "PrefixContext + Condition/FixedContext interactions" begin + @testset "prefix_cond_and_fixed_variables" begin + c1 = ConditionContext((c=1, d=2)) + c1_prefixed = prefix_cond_and_fixed_variables(c1, @varname(a)) + @test c1_prefixed isa ConditionContext + @test childcontext(c1_prefixed) isa DefaultContext + @test c1_prefixed.values[@varname(a.c)] == 1 + @test c1_prefixed.values[@varname(a.d)] == 2 + + c2 = FixedContext((f=1, g=2)) + c2_prefixed = prefix_cond_and_fixed_variables(c2, @varname(a)) + @test c2_prefixed isa FixedContext + @test childcontext(c2_prefixed) isa DefaultContext + @test c2_prefixed.values[@varname(a.f)] == 1 + @test c2_prefixed.values[@varname(a.g)] == 2 + + c3 = ConditionContext((c=1, d=2), 
FixedContext((f=1, g=2))) + c3_prefixed = prefix_cond_and_fixed_variables(c3, @varname(a)) + c3_prefixed_child = childcontext(c3_prefixed) + @test c3_prefixed isa ConditionContext + @test c3_prefixed.values[@varname(a.c)] == 1 + @test c3_prefixed.values[@varname(a.d)] == 2 + @test c3_prefixed_child isa FixedContext + @test c3_prefixed_child.values[@varname(a.f)] == 1 + @test c3_prefixed_child.values[@varname(a.g)] == 2 + @test childcontext(c3_prefixed_child) isa DefaultContext + end + + @testset "collapse_prefix_stack" begin + # Utility function to make sure that there are no PrefixContexts in + # the context stack. + function has_no_prefixcontexts(ctx::AbstractContext) + return !(ctx isa PrefixContext) && ( + NodeTrait(ctx) isa IsLeaf || has_no_prefixcontexts(childcontext(ctx)) + ) + end + + # Prefix -> Condition + c1 = PrefixContext(@varname(a), ConditionContext((c=1, d=2))) + c1 = collapse_prefix_stack(c1) + @test has_no_prefixcontexts(c1) + c1_vals = conditioned(c1) + @test length(c1_vals) == 2 + @test getvalue(c1_vals, @varname(a.c)) == 1 + @test getvalue(c1_vals, @varname(a.d)) == 2 + + # Condition -> Prefix + c2 = ConditionContext((c=1, d=2), PrefixContext(@varname(a))) + c2 = collapse_prefix_stack(c2) + @test has_no_prefixcontexts(c2) + c2_vals = conditioned(c2) + @test length(c2_vals) == 2 + @test getvalue(c2_vals, @varname(c)) == 1 + @test getvalue(c2_vals, @varname(d)) == 2 + + # Prefix -> Fixed + c3 = PrefixContext(@varname(a), FixedContext((f=1, g=2))) + c3 = collapse_prefix_stack(c3) + c3_vals = fixed(c3) + @test length(c3_vals) == 2 + @test length(c3_vals) == 2 + @test getvalue(c3_vals, @varname(a.f)) == 1 + @test getvalue(c3_vals, @varname(a.g)) == 2 + + # Fixed -> Prefix + c4 = FixedContext((f=1, g=2), PrefixContext(@varname(a))) + c4 = collapse_prefix_stack(c4) + @test has_no_prefixcontexts(c4) + c4_vals = fixed(c4) + @test length(c4_vals) == 2 + @test getvalue(c4_vals, @varname(f)) == 1 + @test getvalue(c4_vals, @varname(g)) == 2 + + # Prefix -> 
Condition -> Prefix -> Condition + c5 = PrefixContext( + @varname(a), + ConditionContext( + (c=1,), PrefixContext(@varname(b), ConditionContext((d=2,))) + ), + ) + c5 = collapse_prefix_stack(c5) + @test has_no_prefixcontexts(c5) + c5_vals = conditioned(c5) + @test length(c5_vals) == 2 + @test getvalue(c5_vals, @varname(a.c)) == 1 + @test getvalue(c5_vals, @varname(a.b.d)) == 2 + + # Prefix -> Condition -> Prefix -> Fixed + c6 = PrefixContext( + @varname(a), + ConditionContext((c=1,), PrefixContext(@varname(b), FixedContext((d=2,)))), + ) + c6 = collapse_prefix_stack(c6) + @test has_no_prefixcontexts(c6) + @test conditioned(c6) == Dict(@varname(a.c) => 1) + @test fixed(c6) == Dict(@varname(a.b.d) => 2) + end + end end diff --git a/test/runtests.jl b/test/runtests.jl index 3473d5594..72f33f2d0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -67,6 +67,7 @@ include("test_util.jl") include("threadsafe.jl") include("debug_utils.jl") include("deprecated.jl") + include("submodels.jl") end if GROUP == "All" || GROUP == "Group2" diff --git a/test/submodels.jl b/test/submodels.jl new file mode 100644 index 000000000..e79eed2c3 --- /dev/null +++ b/test/submodels.jl @@ -0,0 +1,199 @@ +module DPPLSubmodelTests + +using DynamicPPL +using Distributions +using Test + +@testset "submodels.jl" begin + @testset "$op with AbstractPPL API" for op in [condition, fix] + x_val = 1.0 + x_logp = op == condition ? 
logpdf(Normal(), x_val) : 0.0 + + @testset "Auto prefix" begin + @model function inner() + x ~ Normal() + y ~ Normal() + return (x, y) + end + @model function outer() + return a ~ to_submodel(inner()) + end + inner_op = op(inner(), (@varname(x) => x_val)) + @model function outer2() + return a ~ to_submodel(inner_op) + end + with_inner_op = outer2() + with_outer_op = op(outer(), (@varname(a.x) => x_val)) + + # No conditioning/fixing + @test Set(keys(VarInfo(outer()))) == Set([@varname(a.x), @varname(a.y)]) + + # With conditioning/fixing + models = [("inner", with_inner_op), ("outer", with_outer_op)] + @testset "$name" for (name, model) in models + # Test that the value was correctly set + @test model()[1] == x_val + # Test that the logp was correctly set + vi = VarInfo(model) + @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(a.y)]) + # Check the keys + @test Set(keys(VarInfo(model))) == Set([@varname(a.y)]) + end + end + + @testset "No prefix" begin + @model function inner() + x ~ Normal() + y ~ Normal() + return (x, y) + end + @model function outer() + return a ~ to_submodel(inner(), false) + end + @model function outer2() + return a ~ to_submodel(inner_op, false) + end + with_inner_op = outer2() + inner_op = op(inner(), (@varname(x) => x_val)) + with_outer_op = op(outer(), (@varname(x) => x_val)) + + # No conditioning/fixing + @test Set(keys(VarInfo(outer()))) == Set([@varname(x), @varname(y)]) + + # With conditioning/fixing + models = [("inner", with_inner_op), ("outer", with_outer_op)] + @testset "$name" for (name, model) in models + # Test that the value was correctly set + @test model()[1] == x_val + # Test that the logp was correctly set + vi = VarInfo(model) + @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(y)]) + # Check the keys + @test Set(keys(VarInfo(model))) == Set([@varname(y)]) + end + end + + @testset "Manual prefix" begin + @model function inner() + x ~ Normal() + y ~ Normal() + return (x, y) + end + @model function outer() + 
+                return a ~ to_submodel(prefix(inner(), :b), false)
+            end
+            inner_op = op(inner(), (@varname(x) => x_val))
+            @model function outer2()
+                return a ~ to_submodel(prefix(inner_op, :b), false)
+            end
+            with_inner_op = outer2()
+            with_outer_op = op(outer(), (@varname(b.x) => x_val))
+
+            # No conditioning/fixing
+            @test Set(keys(VarInfo(outer()))) == Set([@varname(b.x), @varname(b.y)])
+
+            # With conditioning/fixing
+            models = [("inner", with_inner_op), ("outer", with_outer_op)]
+            @testset "$name" for (name, model) in models
+                # Test that the value was correctly set
+                @test model()[1] == x_val
+                # Test that the logp was correctly set
+                vi = VarInfo(model)
+                @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(b.y)])
+                # Check the keys
+                @test Set(keys(VarInfo(model))) == Set([@varname(b.y)])
+            end
+        end
+
+        @testset "Complex prefixes" begin
+            mutable struct P
+                a::Float64
+                b::Float64
+            end
+            @model function f()
+                x = Vector{Float64}(undef, 1)
+                x[1] ~ Normal()
+                y ~ Normal()
+                return x[1]
+            end
+            @model function g()
+                p = P(1.0, 2.0)
+                p.a ~ to_submodel(f())
+                p.b ~ Normal()
+                return (p.a, p.b)
+            end
+            expected_vns = Set([@varname(p.a.x[1]), @varname(p.a.y), @varname(p.b)])
+            @test Set(keys(VarInfo(g()))) == expected_vns
+
+            # Check that we can condition/fix on any of them from the outside
+            for vn in expected_vns
+                op_g = op(g(), (vn => 1.0))
+                vi = VarInfo(op_g)
+                @test Set(keys(vi)) == symdiff(expected_vns, Set([vn]))
+            end
+        end
+
+        @testset "Nested submodels" begin
+            @model function f()
+                x ~ Normal()
+                return y ~ Normal()
+            end
+            @model function g()
+                return _unused ~ to_submodel(prefix(f(), :b), false)
+            end
+            @model function h()
+                return a ~ to_submodel(g())
+            end
+
+            # No conditioning
+            vi = VarInfo(h())
+            @test Set(keys(vi)) == Set([@varname(a.b.x), @varname(a.b.y)])
+            @test getlogp(vi) ==
+                logpdf(Normal(), vi[@varname(a.b.x)]) +
+                logpdf(Normal(), vi[@varname(a.b.y)])
+
+            # Conditioning/fixing at the top level
+            op_h = op(h(), (@varname(a.b.x) => x_val))
+
+            # Conditioning/fixing at the second level
+            op_g = op(g(), (@varname(b.x) => x_val))
+            @model function h2()
+                return a ~ to_submodel(op_g)
+            end
+
+            # Conditioning/fixing at the very bottom
+            op_f = op(f(), (@varname(x) => x_val))
+            @model function g2()
+                return _unused ~ to_submodel(prefix(op_f, :b), false)
+            end
+            @model function h3()
+                return a ~ to_submodel(g2())
+            end
+
+            models = [("top", op_h), ("middle", h2()), ("bottom", h3())]
+            @testset "$name" for (name, model) in models
+                vi = VarInfo(model)
+                @test Set(keys(vi)) == Set([@varname(a.b.y)])
+                @test getlogp(vi) == x_logp + logpdf(Normal(), vi[@varname(a.b.y)])
+            end
+        end
+    end
+
+    @testset "conditioning via model arguments" begin
+        @model function f(x)
+            x ~ Normal()
+            return y ~ Normal()
+        end
+        @model function g(inner_x)
+            return a ~ to_submodel(f(inner_x))
+        end
+
+        vi = VarInfo(g(1.0))
+        @test Set(keys(vi)) == Set([@varname(a.y)])
+
+        vi = VarInfo(g(missing))
+        @test Set(keys(vi)) == Set([@varname(a.x), @varname(a.y)])
+    end
+end
+
+end