Constraints Specification
RxInfer.jl uses the `@constraints` macro from GraphPPL to add extra constraints during the inference process. For details on using the `@constraints` macro, check out the official documentation of GraphPPL.
Background and example
Here we briefly cover the mathematical aspects of constraints specification. For additional information and relevant links, please refer to the Bethe Free Energy section. In essence, RxInfer performs Variational Inference (via message passing) given specific constraints $\mathcal{Q}$:
\[q^* = \arg\min_{q(s) \in \mathcal{Q}} F[q](\hat{y})\,, \qquad F[q](\hat{y}) = \mathbb{E}_{q(s)}\left[\log \frac{q(s)}{p(s, y=\hat{y})} \right]\,.\]
The `@model` macro specifies a generative model $p(s, y)$, where $s$ is a set of random variables and $y$ is a set of observations. In a nutshell, the goal of probabilistic programming is to find $p(s|y)$. RxInfer approximates $p(s|y)$ with a proxy distribution $q(s)$ using a KL divergence and Bethe Free Energy optimisation procedure. By default there are no extra factorization constraints on $q(s)$, and the optimal solution is $q(s) = p(s|y)$.
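The unconstrained optimum is the exact posterior, and a standard rewriting of the free energy above shows why. It uses $p(s, y=\hat{y}) = p(s|y=\hat{y})\,p(y=\hat{y})$:

\[F[q](\hat{y}) = \mathbb{E}_{q(s)}\left[\log \frac{q(s)}{p(s|y=\hat{y})} \right] - \log p(y=\hat{y}) = \mathrm{KL}\left[q(s)\,\|\,p(s|y=\hat{y})\right] - \log p(y=\hat{y})\,.\]

The term $\log p(y=\hat{y})$ does not depend on $q$. The KL divergence is non-negative and equals zero only when $q(s) = p(s|y=\hat{y})$. Minimising $F$ over an unrestricted family therefore recovers the exact posterior, while restricting $\mathcal{Q}$ trades this exactness for tractability.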
For certain problems, it may be necessary to adjust the set of constraints $\mathcal{Q}$ (also known as the variational family of distributions). This can either improve accuracy at the expense of computational resources or reduce accuracy to conserve computational resources. Sometimes we are compelled to impose certain constraints, because without them the problem becomes too challenging to solve within a reasonable timeframe.
For instance, consider the following model:
using RxInfer
@model function iid_normal(y)
μ ~ Normal(mean = 0.0, variance = 1.0)
τ ~ Gamma(shape = 1.0, rate = 1.0)
y .~ Normal(mean = μ, precision = τ)
end
In this model, we characterize all observations in a dataset `y` as draws from a Normal distribution with mean μ and precision τ. It is reasonable to assume that the latent variables μ and τ are independent in the posterior, so that their joint posterior distribution factorizes as:
\[q(μ, τ) = q(μ)q(τ)\,.\]
Written as a variational family of distributions, this assumption reads:
\[\mathcal{Q} = \left\{ q : q(μ, τ) = q(μ)q(τ) \right\}\,.\]
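To make the factorization concrete, consider what it implies for this conjugate model. The family $\mathcal{Q}$ leads to the textbook closed-form coordinate-ascent (mean-field) updates:

- $q(μ) = \mathcal{N}(m, λ^{-1})$ with $λ = 1 + N\,\mathbb{E}[τ]$ and $m = \mathbb{E}[τ]\sum_i y_i / λ$.
- $q(τ) = \mathrm{Gamma}(a, b)$ (shape, rate) with $a = 1 + N/2$ and $b = 1 + \tfrac{1}{2}\sum_i \left[(y_i - m)^2 + λ^{-1}\right]$.

The constants 1 come from the priors in `iid_normal`. The plain-Julia sketch below iterates these updates by hand. It is an illustration only, not how RxInfer is implemented: RxInfer derives the corresponding updates through message passing. The function name `mean_field_updates`, the seed and the starting value of $\mathbb{E}[τ]$ are assumptions made for this example.

```julia
using Random, Statistics

# Hand-written coordinate-ascent updates for the mean-field family
# q(μ, τ) = q(μ)q(τ) of the `iid_normal` model:
#   μ ~ Normal(mean = 0, variance = 1), τ ~ Gamma(shape = 1, rate = 1)
# (hypothetical helper for illustration, not part of RxInfer)
function mean_field_updates(y; iterations = 25, E_τ = 1.0)
    N = length(y)
    m, λ = 0.0, 1.0   # mean and precision of q(μ)
    a, b = 1.0, 1.0   # shape and rate of q(τ)
    for _ in 1:iterations
        # Update q(μ) while holding q(τ) fixed
        λ = 1.0 + N * E_τ
        m = E_τ * sum(y) / λ
        # Update q(τ) while holding q(μ) fixed, using E[(y_i - μ)^2] = (y_i - m)^2 + 1/λ
        a = 1.0 + N / 2
        b = 1.0 + 0.5 * sum((y .- m) .^ 2 .+ 1 / λ)
        E_τ = a / b
    end
    return (m = m, λ = λ, a = a, b = b)
end

# Synthetic data with the same true values as in the example (seed chosen arbitrarily)
rng = MersenneTwister(42)
y = 3.1415 .+ randn(rng, 1000) ./ sqrt(2.7182)

qs = mean_field_updates(y)
println("q(μ): mean = ", qs.m, ", std = ", sqrt(1 / qs.λ))
println("q(τ): mean = ", qs.a / qs.b, ", std = ", sqrt(qs.a) / qs.b)
```

In RxInfer you do not write these updates yourself; you only declare the factorization. For the same data, we would expect the mean-field result from RxInfer to converge to essentially the same values.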
We can express this constraint with the `@constraints` macro:
constraints = @constraints begin
q(μ, τ) = q(μ)q(τ)
end
Constraints:
q(μ, τ) = q(μ)q(τ)
and pass the created `constraints` object to the `infer` function:
# We need to specify initial marginals, since with these constraints
# the problem becomes inherently iterative (we could also initialize `q(μ)` instead)
init = @initialization begin
q(τ) = vague(Gamma)
end
result = infer(
model = iid_normal(),
# Sample data from mean `3.1415` and precision `2.7182`
data = (y = rand(NormalMeanPrecision(3.1415, 2.7182), 1000), ),
constraints = constraints,
initialization = init,
iterations = 25
)
Inference results:
Posteriors | available for (μ, τ)
println("Estimated mean of `μ` is ", mean(result.posteriors[:μ][end]), " with standard deviation ", std(result.posteriors[:μ][end]))
println("Estimated mean of `τ` is ", mean(result.posteriors[:τ][end]), " with standard deviation ", std(result.posteriors[:τ][end]))
Estimated mean of `μ` is 3.126979569345421 with standard deviation 0.019526121605635525
Estimated mean of `τ` is 2.6218171852751597 with standard deviation 0.11713415337225881
We observe that the estimates deviate slightly from the true values used to generate the data (3.1415 and 2.7182). Part of this comes from sampling noise in the finite dataset. Such deviations are also a known characteristic of inference with the aforementioned constraints, often referred to as Mean Field constraints.
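One way to check these estimates is sketched below. It assumes the `free_energy = true` keyword of `infer`, which records the Bethe Free Energy at every iteration in `result.free_energy`. The sketch does three things:

- It fixes the random seed.
- It compares the posterior means with plain sample statistics of the same dataset.
- It prints the first and last free-energy values.

The seed, the `dataset` name and the tolerances in the checks are illustrative choices, not part of the original example.

```julia
using RxInfer, Random
import Statistics

@model function iid_normal(y)
    μ ~ Normal(mean = 0.0, variance = 1.0)
    τ ~ Gamma(shape = 1.0, rate = 1.0)
    y .~ Normal(mean = μ, precision = τ)
end

constraints = @constraints begin
    q(μ, τ) = q(μ)q(τ)
end

init = @initialization begin
    q(τ) = vague(Gamma)
end

# Fix the seed so that the synthetic dataset is reproducible
dataset = rand(MersenneTwister(42), NormalMeanPrecision(3.1415, 2.7182), 1000)

result = infer(
    model = iid_normal(),
    data = (y = dataset, ),
    constraints = constraints,
    initialization = init,
    iterations = 25,
    free_energy = true  # record the Bethe Free Energy after every iteration
)

# Compare the posterior means with simple sample statistics of the data
sample_mean = Statistics.mean(dataset)
sample_precision = 1 / Statistics.var(dataset)
posterior_mean_μ = mean(result.posteriors[:μ][end])
posterior_mean_τ = mean(result.posteriors[:τ][end])

println("sample mean = ", sample_mean, ", posterior mean of μ = ", posterior_mean_μ)
println("sample precision = ", sample_precision, ", posterior mean of τ = ", posterior_mean_τ)
println("free energy: first = ", first(result.free_energy), ", last = ", last(result.free_energy))
```

With 1000 observations the prior has little influence, so the posterior means should sit close to the sample statistics. A decreasing free energy is the expected behaviour of the iterative mean-field scheme.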
General syntax
You can use the `@constraints` macro with either a regular Julia function or a single `begin ... end` block. Both ways are valid, as shown below:
# `functional` style
@constraints function create_my_constraints()
q(μ, τ) = q(μ)q(τ)
end
# `block` style
myconstraints = @constraints begin
q(μ, τ) = q(μ)q(τ)
end
The function-based syntax can also take arguments, like this:
@constraints function make_constraints(mean_field)
# Specify mean-field only if the flag is `true`
if mean_field
q(μ, τ) = q(μ)q(τ)
end
end
myconstraints = make_constraints(true)
Constraints:
q(μ, τ) = q(μ)q(τ)
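You use the object returned by a function-style definition exactly like one created with the block form. The sketch below reuses the `iid_normal` model and the initialization from the example above, repeated here so the snippet is self-contained. It passes `make_constraints(true)` directly to `infer`.

```julia
using RxInfer

@model function iid_normal(y)
    μ ~ Normal(mean = 0.0, variance = 1.0)
    τ ~ Gamma(shape = 1.0, rate = 1.0)
    y .~ Normal(mean = μ, precision = τ)
end

# Function-style constraints with an argument, as defined above
@constraints function make_constraints(mean_field)
    if mean_field
        q(μ, τ) = q(μ)q(τ)
    end
end

init = @initialization begin
    q(τ) = vague(Gamma)
end

# The object returned by `make_constraints(true)` is passed to `infer`
# in the same way as one created with the `begin ... end` block form
result = infer(
    model = iid_normal(),
    data = (y = rand(NormalMeanPrecision(3.1415, 2.7182), 1000), ),
    constraints = make_constraints(true),
    initialization = init,
    iterations = 25
)

println("Estimated mean of `μ` is ", mean(result.posteriors[:μ][end]))
```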
Marginal and messages form constraints
To specify marginal or message form constraints, the `@constraints` macro uses the `::` operator (in a somewhat similar way to how Julia uses it for type annotations in multiple dispatch). Read more about available functional form constraints in the Built-In Functional Forms section.
As an example, the following constraint:
@constraints begin
q(x) :: PointMassFormConstraint()
end
Constraints:
q(x) :: PointMassFormConstraint()
indicates that the resulting marginal of the variable (or array of variables) named `x` must be approximated with a `PointMass` object. Message-passing-based algorithms compute posterior marginals as the normalized product of the two colliding messages on the corresponding edge of a factor graph. In short, `q(x) :: PointMassFormConstraint()` reads as:
\[\mathrm{approximate~} q(x) = \frac{\overrightarrow{\mu}(x)\overleftarrow{\mu}(x)}{\int \overrightarrow{\mu}(x)\overleftarrow{\mu}(x) \mathrm{d}x}\mathrm{~as~PointMass}\]
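The formula can be illustrated without RxInfer. The plain-Julia sketch below does the following:

- It takes two hypothetical Gaussian messages on the edge of `x` (the numbers are arbitrary).
- It forms their normalized product.
- It locates the maximiser of the resulting density, which is where a point-mass approximation would place all of its mass.

As far as we understand, RxInfer finds this point numerically for general distributions. The closed form used here only holds because both messages are Gaussian.

```julia
# Two colliding Gaussian messages on the edge of `x`, given as (mean, precision).
# These values are arbitrary and purely illustrative.
fwd_mean, fwd_precision = 1.0, 2.0   # forward message  →μ(x)
bwd_mean, bwd_precision = 3.0, 6.0   # backward message ←μ(x)

# The normalized product of two Gaussian densities is again Gaussian:
# precisions add, and the mean is the precision-weighted average of the means.
q_precision = fwd_precision + bwd_precision
q_mean = (fwd_precision * fwd_mean + bwd_precision * bwd_mean) / q_precision

# Unnormalized log-density of the product, log →μ(x) + log ←μ(x) up to a constant
logq(x) = -fwd_precision / 2 * (x - fwd_mean)^2 - bwd_precision / 2 * (x - bwd_mean)^2

# A point-mass approximation keeps only the maximiser of q(x).
# Here we find it with a crude grid search; for a Gaussian it coincides with q_mean.
grid = range(-5.0, 10.0; length = 15_001)
point_mass_location = grid[argmax(logq.(grid))]

println("q(x) = N(mean = ", q_mean, ", precision = ", q_precision, ")")
println("point-mass location ≈ ", point_mass_location)
```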
Sometimes it might be useful to set a functional form constraint on messages too. This helps, for example, if it is essential to keep a specific Gaussian parametrisation or if some messages are intractable and need approximation. To set a message form constraint, the `@constraints` macro uses `μ(...)` instead of `q(...)`:
@constraints begin
q(x) :: PointMassFormConstraint()
μ(x) :: SampleListFormConstraint(1000)
# it is possible to assign different form constraints on the same variable
# both for the marginal and for the messages
end
Constraints:
q(x) :: PointMassFormConstraint()
μ(x) :: SampleListFormConstraint(Random._GLOBAL_RNG(), AutoProposal(), BayesBase.BootstrapImportanceSampling())
The `@constraints` macro understands "stacked" form constraints. For example, the following form constraint
@constraints begin
q(x) :: SampleListFormConstraint(1000) :: PointMassFormConstraint()
end
Constraints:
q(x) :: (SampleListFormConstraint(Random._GLOBAL_RNG(), AutoProposal(), BayesBase.BootstrapImportanceSampling()), PointMassFormConstraint())
indicates that `q(x)` must first be approximated with a `SampleList`, and the result of this approximation should in turn be approximated as a `PointMass`.
Not all combinations of "stacked" form constraints are compatible with each other.
You can find more information about built-in functional form constraints in the Built-in Functional Forms section. In addition, the ReactiveMP library documentation explains the functional form interfaces. It also shows how to build a custom functional form constraint that is compatible with the RxInfer.jl and ReactiveMP.jl inference engine.
Factorization constraints on posterior distribution q
As mentioned above, inference may not be tractable for every model without extra factorization constraints. To circumvent this, RxInfer.jl allows you to specify extra factorization constraints, for example:
@constraints begin
q(x, y) = q(x)q(y)
end
Constraints:
q(x, y) = q(x)q(y)
specifies a so-called mean-field assumption on variables `x` and `y` in the model. Furthermore, if `x` is an array of variables in our model, we may induce an extra mean-field assumption on `x` in the following way:
@constraints begin
q(x) = q(x[begin])..q(x[end])
q(x, y) = q(x)q(y)
end
Constraints:
q(x) = q(x[(begin)..(end)])
q(x, y) = q(x)q(y)
These constraints specify a mean-field assumption between variables `x` and `y` (each either a single variable or a collection of variables). They additionally specify a mean-field assumption on the variables $x_i$.
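These two constraints can be written in the same notation as the variational family $\mathcal{Q}$ above. Writing the elements of the array as $x_1, \dots, x_N$, they describe the family

\[\mathcal{Q} = \left\{ q : q(x, y) = q(y)\prod_{i=1}^{N} q(x_i) \right\}\,.\]

If `y` is itself a collection of variables, `q(y)` stays a joint distribution over its elements, since no constraint splits it further.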
The `@constraints` macro does not support matrix-based collections of variables. For example, it is not possible to write `q(x[begin, begin])..q(x[end, end])`; use `q(x[begin])..q(x[end])` instead.
Read more about the `@constraints` macro in the official documentation of GraphPPL.
Constraints in submodels
RxInfer allows you to define your generative model hierarchically, using previously defined `@model` modules as submodels in larger models. Because of this, users need to specify their constraints hierarchically as well to avoid ambiguities. Consider the following example:
@model function inner_inner(τ, y)
y ~ Normal(mean = τ[1], var = τ[2])
end
@model function inner(θ, α)
β ~ Normal(mean = 0.0, var = 1.0)
α ~ Gamma(shape = β, rate = 1.0)
α ~ inner_inner(τ = θ)
end
@model function outer()
local w
for i = 1:5
w[i] ~ inner(θ = Gamma(shape = 1.0, rate = 1.0))
end
y ~ inner(θ = w[2:3])
end
To access the variables in the submodels, we use the `for q in __submodel__` syntax, which allows us to specify constraints over variables in the context of an inner submodel:
@constraints begin
for q in inner
q(α) :: PointMassFormConstraint()
q(α, β) = q(α)q(β)
end
end
Constraints:
q(inner) =
q(α, β) = q(α)q(β)
q(α) :: PointMassFormConstraint()
Similarly, we can specify constraints over variables in the context of the innermost submodel by using the `for q in __submodel__` syntax twice:
@constraints begin
for q in inner
for q in inner_inner
q(y, τ) = q(y)q(τ[1])q(τ[2])
end
q(α) :: PointMassFormConstraint()
q(α, β) = q(α)q(β)
end
end
Constraints:
q(inner) =
q(α, β) = q(α)q(β)
q(α) :: PointMassFormConstraint()
q(inner_inner) =
q(y, τ) = q(y)q(τ[1])q(τ[2])
The `for q in __submodel__` syntax applies the constraints specified in its code block to all instances of `__submodel__` in the current context. To apply constraints to a specific instance of a submodel, we can use the `for q in (__submodel__, __identifier__)` syntax, where `__identifier__` is an integer counter. For example, to specify constraints on the first instance of `inner` in our `outer` model, we can use the following syntax:
@constraints begin
for q in (inner, 1)
q(α) :: PointMassFormConstraint()
q(α, β) = q(α)q(β)
end
end
Constraints:
q((inner, 1)) =
q(α, β) = q(α)q(β)
q(α) :: PointMassFormConstraint()
Factorization constraints specified in a context propagate to their child submodels. This means that we can specify a factorization constraint between variables whose connecting factor node lives in a submodel, without having to specify the constraint in the submodel itself. For example, to impose a factorization constraint between `w[2]` and `w[3]` in our `outer` model, we can specify it in the context of `outer`. RxInfer will recognize that these variables are connected through the `Normal` node in the `inner_inner` submodel:
@constraints begin
q(w) = q(w[begin])..q(w[end])
end
Constraints:
q(w) = q(w[(begin)..(end)])
Default constraints
Sometimes, a submodel is used in multiple contexts, on multiple levels of hierarchy and in different submodels. In such cases, it becomes cumbersome to specify constraints for each instance of the submodel and track its usage throughout the model. To alleviate this, RxInfer allows users to specify default constraints for a submodel. These constraints will be applied to all instances of the submodel unless overridden by specific constraints. To specify default constraints for a submodel, override the `GraphPPL.default_constraints` function for the submodel:
RxInfer.GraphPPL.default_constraints(::typeof(inner)) = @constraints begin
q(α) :: PointMassFormConstraint()
q(α, β) = q(α)q(β)
end
More information can be found in the GraphPPL documentation.
Constraints on the data
By default, RxInfer assumes that, since the data comes into the model as observed, the posterior marginal distribution of the data is independent of other marginals and is a Dirac-delta distribution. However, this assumption breaks when we pass missing data into our model. When the data is missing, there may be a joint dependency between the data and latent variables, as the missing data essentially behaves as a latent variable. In such cases, we can wrap the data in an `UnfactorizedData`. This notifies the inference engine that the data should not be factorized out, and we can specify a custom factorization constraint on these variables using the `@constraints` macro.
RxInfer.UnfactorizedData (Type)
UnfactorizedData{D}
A wrapper struct to wrap data that should not be factorized out by default during inference. When performing Bayesian inference with message passing, every factor node contains a local factorization constraint on the variational posterior distribution. We usually regard data as an independent component in the variational posterior distribution. However, in some cases, for example when we are predicting data, we do not want to factorize out the data. In such cases, we can wrap the data with the `UnfactorizedData` struct to prevent the factorization and craft a custom node-local factorization with the `@constraints` macro.
unfactorized_example_constraints = @constraints begin
q(y[1:1000], μ, τ) = q(y[1:1000])q(μ)q(τ)
q(y[1001:1100], μ, τ) = q(y[1001:1100], μ)q(τ)
end
result = infer(
model = iid_normal(),
data = (y = UnfactorizedData(vcat(rand(NormalMeanPrecision(3.1415, 2.7182), 1000), [missing for _ in 1:100])),),
constraints = unfactorized_example_constraints,
initialization = init,
iterations = 25
)
Inference results:
Posteriors | available for (μ, τ)
Predictions | available for (y)
Prespecified constraints
GraphPPL exports some prespecified constraints that can be used in the `@constraints` macro. These constraints can also be passed directly as top-level constraints to the `infer` function. For example, to specify a mean-field assumption on all variables in the model, we can use the `MeanField` constraint:
result = infer(
model = iid_normal(),
data = (y = rand(NormalMeanPrecision(3.1415, 2.7182), 1000), ),
constraints = MeanField(), # instead of using `@constraints` macro
initialization = init,
iterations = 25
)
Inference results:
Posteriors | available for (μ, τ)
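In `iid_normal`, the observations `y` are already factorized out as data. We would therefore expect `MeanField()` and the explicit `q(μ, τ) = q(μ)q(τ)` constraint to describe the same family and yield the same posteriors. The sketch below checks this on one fixed dataset. The helper `run_with` and the seed are hypothetical additions for this example.

```julia
using RxInfer, Random

@model function iid_normal(y)
    μ ~ Normal(mean = 0.0, variance = 1.0)
    τ ~ Gamma(shape = 1.0, rate = 1.0)
    y .~ Normal(mean = μ, precision = τ)
end

init = @initialization begin
    q(τ) = vague(Gamma)
end

# Use one fixed dataset for both runs so that only the constraints differ
dataset = rand(MersenneTwister(1234), NormalMeanPrecision(3.1415, 2.7182), 1000)

explicit_constraints = @constraints begin
    q(μ, τ) = q(μ)q(τ)
end

# Hypothetical helper: run the same inference with different constraints
run_with(constraints) = infer(
    model = iid_normal(),
    data = (y = dataset, ),
    constraints = constraints,
    initialization = init,
    iterations = 25
)

result_explicit  = run_with(explicit_constraints)
result_meanfield = run_with(MeanField())

μ_explicit  = mean(result_explicit.posteriors[:μ][end])
μ_meanfield = mean(result_meanfield.posteriors[:μ][end])
println("explicit constraints: ", μ_explicit, ", MeanField(): ", μ_meanfield)
```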