# Constraints Specification

`RxInfer.jl` uses a macro called `@constraints` from `GraphPPL` to add extra constraints during the inference process. For details on using the `@constraints` macro, you can check out the official documentation of GraphPPL.

## Background and example

Here we briefly cover the mathematical aspects of constraints specification. For additional information and relevant links, please refer to the Bethe Free Energy section. In essence, `RxInfer` performs Variational Inference (via message passing) given specific constraints $\mathcal{Q}$:

\[q^* = \arg\min_{q(s) \in \mathcal{Q}}F[q](\hat{y})\,, \qquad F[q](\hat{y}) = \mathbb{E}_{q(s)}\left[\log \frac{q(s)}{p(s, y=\hat{y})} \right]\,.\]

The `@model` macro specifies a generative model `p(s, y)`, where `s` is a set of random variables and `y` is a set of observations. In a nutshell, the goal of probabilistic programming is to find `p(s|y)`. `RxInfer` approximates `p(s|y)` with a proxy distribution `q(s)` by minimising the KL divergence through a Bethe Free Energy optimisation procedure. By default there are no extra factorization constraints on `q(s)` and the optimal solution is `q(s) = p(s|y)`.
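To see why the unconstrained optimum is the exact posterior, the free energy defined above can be rewritten using only the product rule $p(s, y=\hat{y}) = p(s \mid y=\hat{y})\,p(y=\hat{y})$. This is a standard identity, shown here as a short worked step:

\[F[q](\hat{y}) = \mathbb{E}_{q(s)}\left[\log \frac{q(s)}{p(s \mid y=\hat{y})} \right] - \log p(y=\hat{y}) = \mathrm{KL}\left[q(s) \,\|\, p(s \mid y=\hat{y})\right] - \log p(y=\hat{y})\,.\]

The KL divergence is non-negative and vanishes only when `q(s) = p(s|y)`, so without extra constraints the minimiser is the posterior and the minimum equals the negative log-evidence. Restricting $\mathcal{Q}$ can only increase the attainable minimum.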

For certain problems, it may be necessary to adjust the set of constraints $\mathcal{Q}$ (also known as the variational family of distributions) to either improve accuracy at the expense of computational resources or reduce accuracy to conserve computational resources. Sometimes, we are compelled to impose certain constraints because otherwise, the problem becomes too challenging to solve within a reasonable timeframe.

For instance, consider the following model:

```
using RxInfer

@model function iid_normal(y)
    μ ~ Normal(mean = 0.0, variance = 1.0)
    τ ~ Gamma(shape = 1.0, rate = 1.0)
    y .~ Normal(mean = μ, precision = τ)
end
```

In this model, we characterize all observations in a dataset `y` as draws from a `Normal` distribution with mean `μ` and precision `τ`. It's reasonable to assume that the latent variables `μ` and `τ` are independent in the posterior, so that their approximate joint posterior distribution factorizes as:

\[q(μ, τ) = q(μ)q(τ)\,.\]

Written as a variational family of distributions, this assumption is expressed as:

\[\mathcal{Q} = \left\{ q : q(μ, τ) = q(μ)q(τ) \right\}\,.\]

We can express this constraint with the `@constraints` macro:

```
constraints = @constraints begin
    q(μ, τ) = q(μ)q(τ)
end
```

```
Constraints:
q(μ, τ) = q(μ)q(τ)
```

and pass the created `constraints` object to the `infer` function:

```
# We need to specify initial marginals, since with the constraints
# the problem becomes inherently iterative (we could initialise `q(μ)` instead)
init = @initialization begin
    q(τ) = vague(Gamma)
end

result = infer(
    model = iid_normal(),
    # Sample data with mean `3.1415` and precision `2.7182`
    data = (y = rand(NormalMeanPrecision(3.1415, 2.7182), 1000), ),
    constraints = constraints,
    initialization = init,
    iterations = 25
)
```

```
Inference results:
Posteriors | available for (μ, τ)
```

```
println("Estimated mean of `μ` is ", mean(result.posteriors[:μ][end]), " with standard deviation ", std(result.posteriors[:μ][end]))
println("Estimated mean of `τ` is ", mean(result.posteriors[:τ][end]), " with standard deviation ", std(result.posteriors[:τ][end]))
```

```
Estimated mean of `μ` is 3.1282785803162447 with standard deviation 0.020092150335224317
Estimated mean of `τ` is 2.4761206627373396 with standard deviation 0.11062491279187155
```

We observe that the estimates deviate slightly from the true values used to generate the data. This behavior is a known characteristic of inference under the constraints above, often referred to as *Mean Field* constraints.
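To make the iterative nature of the mean-field solution concrete, here is a minimal plain-Julia sketch of the coordinate-wise updates for the `iid_normal` model. It is not how `RxInfer` is implemented internally. The function name `mean_field_iid_normal`, the starting point `Gamma(shape = 1, rate = 1)` for `q(τ)`, and the fixed random seed are assumptions made for illustration. Each update of `q(μ)` needs the current expectation of `τ` and vice versa, which is why an initial marginal has to be supplied.

```julia
using Random, Statistics

# Mean-field coordinate updates for
#   μ ~ Normal(mean = 0, variance = 1)
#   τ ~ Gamma(shape = 1, rate = 1)
#   y[i] ~ Normal(mean = μ, precision = τ)
# under the factorization q(μ, τ) = q(μ)q(τ).
function mean_field_iid_normal(y; iterations = 25)
    n = length(y)
    sum_y = sum(y)
    a, b = 1.0, 1.0   # q(τ) = Gamma(shape = a, rate = b), illustrative starting point
    m, v = 0.0, 1.0   # q(μ) = Normal(mean = m, variance = v)
    for _ in 1:iterations
        Eτ = a / b                               # E[τ] under the current q(τ)
        v = 1 / (1 + n * Eτ)                     # prior precision 1 plus n * E[τ]
        m = v * Eτ * sum_y                       # prior mean is 0, only the data term remains
        a = 1 + n / 2                            # prior shape 1 plus n / 2
        b = 1 + (sum(abs2, y .- m) + n * v) / 2  # E[(y[i] - μ)^2] = (y[i] - m)^2 + v
    end
    return (μ_mean = m, μ_std = sqrt(v), τ_mean = a / b, τ_std = sqrt(a) / b)
end

rng = MersenneTwister(42)
y = 3.1415 .+ randn(rng, 1000) ./ sqrt(2.7182)   # mean 3.1415, precision 2.7182
estimates = mean_field_iid_normal(y)
println(estimates)
```

The posterior mean of `μ` is the sample mean shrunk slightly towards the prior mean `0`, and the estimate of `τ` follows the sample precision, so both carry the sampling error of the 1000 observations in addition to the effect of the factorization.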

## General syntax

You can use the `@constraints` macro with either a regular Julia function or a single `begin ... end` block. Both ways are valid, as shown below:

```
# `functional` style
@constraints function create_my_constraints()
    q(μ, τ) = q(μ)q(τ)
end

# `block` style
myconstraints = @constraints begin
    q(μ, τ) = q(μ)q(τ)
end
```

The function-based syntax can also take arguments, like this:

```
@constraints function make_constraints(mean_field)
    # Specify mean-field only if the flag is `true`
    if mean_field
        q(μ, τ) = q(μ)q(τ)
    end
end

myconstraints = make_constraints(true)
```

```
Constraints:
q(μ, τ) = q(μ)q(τ)
```

## Marginal and messages form constraints

To specify marginal or messages form constraints, the `@constraints` macro uses the `::` operator (somewhat similar to how Julia uses it for type annotations in multiple dispatch). Read more about the available functional form constraints in the Built-In Functional Forms section.

As an example, the following constraint:

```
@constraints begin
    q(x) :: PointMassFormConstraint()
end
```

```
Constraints:
q(x) :: PointMassFormConstraint()
```

indicates that the resulting marginal of the variable (or array of variables) named `x` must be approximated with a `PointMass` object. Message passing based algorithms compute posterior marginals as a normalized product of two colliding messages on corresponding edges of a factor graph. In short, `q(x) :: PointMassFormConstraint()` reads as:

\[\mathrm{approximate~} q(x) = \frac{\overrightarrow{\mu}(x)\overleftarrow{\mu}(x)}{\int \overrightarrow{\mu}(x)\overleftarrow{\mu}(x) \mathrm{d}x}\mathrm{~as~PointMass}\]
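As a hedged illustration of what such an approximation can look like: assuming the point mass is placed at the maximiser of the product of messages (to our understanding this is the default behaviour described in the ReactiveMP documentation, but the exact optimisation settings are configurable and not covered here), the constrained marginal becomes a Dirac delta, where $\hat{x}$ is notation introduced only for this example:

\[q(x) = \delta(x - \hat{x})\,, \qquad \hat{x} = \arg\max_{x} \log\left(\overrightarrow{\mu}(x)\overleftarrow{\mu}(x)\right)\,.\]

Under this assumption the normalisation integral in the expression above does not need to be computed, since it does not change the location of the maximum.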

Sometimes it might be useful to set a functional form constraint on messages too, for example if it is essential to keep a specific Gaussian parametrisation or if some messages are intractable and need approximation. To set a messages form constraint, the `@constraints` macro uses `μ(...)` instead of `q(...)`:

```
@constraints begin
    q(x) :: PointMassFormConstraint()
    μ(x) :: SampleListFormConstraint(1000)
    # it is possible to assign different form constraints on the same variable
    # both for the marginal and for the messages
end
```

```
Constraints:
q(x) :: PointMassFormConstraint()
μ(x) :: SampleListFormConstraint(Random._GLOBAL_RNG(), AutoProposal(), BayesBase.BootstrapImportanceSampling())
```

The `@constraints` macro understands "stacked" form constraints. For example, the following form constraint

```
@constraints begin
    q(x) :: SampleListFormConstraint(1000) :: PointMassFormConstraint()
end
```

```
Constraints:
q(x) :: (SampleListFormConstraint(Random._GLOBAL_RNG(), AutoProposal(), BayesBase.BootstrapImportanceSampling()), PointMassFormConstraint())
```

indicates that `q(x)` must first be approximated with a `SampleList`, and that the result of this approximation should in turn be approximated as a `PointMass`.

Not all combinations of "stacked" form constraints are compatible with each other.

You can find more information about built-in functional form constraints in the Built-in Functional Forms section. In addition, the ReactiveMP library documentation explains the functional form interfaces and shows how to build a custom functional form constraint that is compatible with the `RxInfer.jl` and `ReactiveMP.jl` inference engine.

## Factorization constraints on posterior distribution `q`

As mentioned above, inference may not be tractable for every model without extra factorization constraints. To circumvent this, `RxInfer.jl` allows for extra factorization constraints, for example:

```
@constraints begin
    q(x, y) = q(x)q(y)
end
```

```
Constraints:
q(x, y) = q(x)q(y)
```

specifies a so-called mean-field assumption on variables `x` and `y` in the model. Furthermore, if `x` is an array of variables in our model, we may impose an extra mean-field assumption on the elements of `x` in the following way:

```
@constraints begin
    q(x) = q(x[begin])..q(x[end])
    q(x, y) = q(x)q(y)
end
```

```
Constraints:
q(x) = q(x[(begin)..(end)])
q(x, y) = q(x)q(y)
```

These constraints specify a mean-field assumption between variables `x` and `y` (either single variable or collection of variables) and additionally specify a mean-field assumption on the variables $x_i$.
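Written out in the same notation as the variational family in the background section, the two constraints above combine into a single family (a restatement of the constraints, with the product running over all elements of `x`):

\[\mathcal{Q} = \left\{ q : q(x, y) = q(y)\prod_{i} q(x_i) \right\}\,.\]

Dropping the first constraint would instead keep the joint `q(x)` over all elements intact and only separate it from `q(y)`.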

The `@constraints` macro does not support matrix-based collections of variables. For example, it is not possible to write `q(x[begin, begin])..q(x[end, end])`. Use `q(x[begin])..q(x[end])` instead.

Read more about the `@constraints` macro in the official documentation of GraphPPL.

## Constraints in submodels

`RxInfer` allows you to define your generative model hierarchically, using previously defined `@model` modules as submodels in larger models. Because of this, users need to specify their constraints hierarchically as well to avoid ambiguities. Consider the following example:

```
@model function inner_inner(τ, y)
    y ~ Normal(τ[1], τ[2])
end

@model function inner(θ, α)
    β ~ Normal(0, 1)
    α ~ Gamma(β, 1)
    α ~ inner_inner(τ = θ)
end

@model function outer()
    local w
    for i = 1:5
        w[i] ~ inner(θ = Gamma(1, 1))
    end
    y ~ inner(θ = w[2:3])
end
```

To access the variables in the submodels, we use the `for q in __submodel__` syntax, which allows us to specify constraints over variables in the context of an inner submodel:

```
@constraints begin
    for q in inner
        q(α) :: PointMassFormConstraint()
        q(α, β) = q(α)q(β)
    end
end
```

```
Constraints:
q(inner) =
q(α, β) = q(α)q(β)
q(α) :: PointMassFormConstraint()
```

Similarly, we can specify constraints over variables in the context of the innermost submodel by using the `for q in __submodel__` syntax twice:

```
@constraints begin
    for q in inner
        for q in inner_inner
            q(y, τ) = q(y)q(τ[1])q(τ[2])
        end
        q(α) :: PointMassFormConstraint()
        q(α, β) = q(α)q(β)
    end
end
```

```
Constraints:
q(inner) =
q(α, β) = q(α)q(β)
q(α) :: PointMassFormConstraint()
q(inner_inner) =
q(y, τ) = q(y)q(τ[1])q(τ[2])
```

The `for q in __submodel__` syntax applies the constraints specified in its code block to all instances of `__submodel__` in the current context. If we want to apply constraints to a specific instance of a submodel, we can use the `for q in (__submodel__, __identifier__)` syntax, where `__identifier__` is a counter integer. For example, if we want to specify constraints on the first instance of `inner` in our `outer` model, we can do so with the following syntax:

```
@constraints begin
    for q in (inner, 1)
        q(α) :: PointMassFormConstraint()
        q(α, β) = q(α)q(β)
    end
end
```

```
Constraints:
q((inner, 1)) =
q(α, β) = q(α)q(β)
q(α) :: PointMassFormConstraint()
```

Factorization constraints specified in a context propagate to their child submodels. This means that we can specify factorization constraints over variables whose connecting factor node lives in a submodel, without having to specify the factorization constraint in the submodel itself. For example, if we want to specify a factorization constraint between `w[2]` and `w[3]` in our `outer` model, we can specify it in the context of `outer`, and `RxInfer` will recognize that these variables are connected through the `Normal` node in the `inner_inner` submodel:

```
@constraints begin
    q(w) = q(w[begin])..q(w[end])
end
```

```
Constraints:
q(w) = q(w[(begin)..(end)])
```

## Default constraints

Sometimes, a submodel is used in multiple contexts, on multiple levels of hierarchy and in different submodels. In such cases, it becomes cumbersome to specify constraints for each instance of the submodel and track its usage throughout the model. To alleviate this, `RxInfer` allows users to specify default constraints for a submodel. These constraints will be applied to all instances of the submodel unless overridden by specific constraints. To specify default constraints for a submodel, override the `GraphPPL.default_constraints` function for the submodel:

```
RxInfer.GraphPPL.default_constraints(::typeof(inner)) = @constraints begin
    q(α) :: PointMassFormConstraint()
    q(α, β) = q(α)q(β)
end
```

More information can be found in the GraphPPL documentation.

## Prespecified constraints

`GraphPPL` exports some prespecified constraints that can be used inside the `@constraints` macro. These constraints can also be passed directly as top-level constraints to the `infer` function. For example, to specify a mean-field assumption on all variables in the model, we can use the `MeanField` constraint:

```
result = infer(
    model = iid_normal(),
    data = (y = rand(NormalMeanPrecision(3.1415, 2.7182), 1000), ),
    constraints = MeanField(), # instead of using `@constraints` macro
    initialization = init,
    iterations = 25
)
```

```
Inference results:
Posteriors | available for (μ, τ)
```