# Static Inference

This guide explains how to use the `infer` function for static datasets. We'll show how `RxInfer` can estimate posterior beliefs given a set of observations. We'll use a simple Beta-Bernoulli model as an example, which was covered in the Getting Started section, but keep in mind that these techniques apply to any model.

Also read about Streaming Inference or check out more complex examples.

## Model specification

Also read the Model Specification section.

In static inference, we want to update our prior beliefs about certain hidden states given some dataset. To achieve this, we include data as an argument in our model specification:

```
using RxInfer

@model function beta_bernoulli(y, a, b)
    θ ~ Beta(a, b)
    for i in 1:length(y)
        y[i] ~ Bernoulli(θ)
    end
end
```

In this model, we assume that `y` is a collection of data points, and `a` and `b` are just numbers. To run inference in this model, we have to call the `infer` function with the `data` argument provided.

## Dataset of observations

For demonstration purposes, we will use a hand-crafted dataset:

```
using Distributions, StableRNGs
hidden_θ = 1 / 3.1415
distribution = Bernoulli(hidden_θ)
rng = StableRNG(43)
n_observations = 1_000
dataset = rand(rng, distribution, n_observations)
```

## Calling the inference procedure

Everything is ready to run inference in our simple model. To run inference with a static dataset using the `infer` function, we need to use the `data` argument. The `data` argument expects a `NamedTuple` whose keys correspond to the names of the model arguments. In our case, the model arguments are `a`, `b` and `y`. We treat `a` and `b` as hyperparameters and pass them directly to the model constructor, and we treat `y` as our observations, so we pass it to the `data` argument as follows:

```
results = infer(
    model = beta_bernoulli(a = 1.0, b = 1.0),
    data  = (y = dataset, )
)
```

```
Inference results:
Posteriors | available for (θ)
```

The `y` inside the `@model` specification is not the same data collection as the one provided in the `data` argument. Inside the `@model`, `y` is a collection of nodes in the corresponding factor graph, but it has exactly the same shape as the collection provided in the `data` argument; hence, we can use basic Julia functions such as `length` on it.

Note that we could also pass `a` and `b` as data:

```
results = infer(
    model = beta_bernoulli(),
    data  = (y = dataset, a = 1.0, b = 1.0)
)
```

```
Inference results:
Posteriors | available for (θ)
```

The `infer` function, however, requires at least one data argument to be present in the supplied `data`. The difference between *hyperparameters* and *observations* is purely semantic and should not have a real influence on the result of the inference procedure.

The inference procedure uses a **reactive message passing** protocol and may decide to optimize and precompute certain messages that depend on fixed hyperparameters, which changes the order in which messages are computed. For some complex models, the order of computations may affect the convergence properties.

In the case of inference with static datasets, the `infer` function returns the `InferenceResult` structure. This structure has the `.posteriors` field, which is a `Dict`-like structure that maps the names of latent states to their corresponding posteriors. For example:

`results.posteriors[:θ]`

`Beta{Float64}(α=336.0, β=666.0)`
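Because the Beta prior is conjugate to the Bernoulli likelihood, this particular posterior can also be checked by hand: starting from `Beta(a, b)`, each `true` observation adds one to α and each `false` observation adds one to β. With `a = b = 1.0` and 1000 observations, `Beta(α=336.0, β=666.0)` therefore corresponds to 335 `true` and 665 `false` values in `dataset`. The sketch below is plain Julia and independent of RxInfer; the helper name `beta_bernoulli_posterior` is our own.

```julia
# Closed-form Beta-Bernoulli update (plain Julia, no packages required):
# every `true` adds 1 to α, every `false` adds 1 to β.
function beta_bernoulli_posterior(y::AbstractVector{Bool}, a::Real, b::Real)
    n_true = count(y)              # number of successes
    n_false = length(y) - n_true   # number of failures
    return (α = a + n_true, β = b + n_false)
end

post = beta_bernoulli_posterior([true, false, false, true, false], 1.0, 1.0)
println(post)
```

Calling `beta_bernoulli_posterior(dataset, 1.0, 1.0)` on the tutorial's `dataset` (a `Vector{Bool}` sampled from `Bernoulli`) should give the same α and β as the message-passing result above.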

`RxInfer.InferenceResult` (type)

This structure is used as a return value from the `infer` function for static datasets.

**Public Fields**

- `posteriors`: `Dict` or `NamedTuple` of 'random variable' - 'posterior' pairs. See the `returnvars` argument for `infer`.
- `predictions`: (optional) `Dict` or `NamedTuple` of 'data variable' - 'prediction' pairs. See the `predictvars` argument for `infer`.
- `free_energy`: (optional) An array of Bethe Free Energy values per VMP iteration. See the `free_energy` argument for `infer`.
- `model`: `FactorGraphModel` object reference.
- `error`: (optional) A reference to an exception that might have occurred during the inference. See the `catch_exception` argument for `infer`.

See also: `infer`

We can also visualize our posterior results with the `Plots.jl` package. We used `Beta(1.0, 1.0)` as a prior; let's compare our prior and posterior beliefs:

```
using Plots

rθ = range(0, 1, length = 1000)

# A single plot holds both densities, so it gets a single title
p = plot(title = "Prior and posterior for θ")
p = plot!(p, rθ, (x) -> pdf(Beta(1.0, 1.0), x), fillalpha=0.3, fillrange = 0, label="P(θ)", c=1)
p = plot!(p, rθ, (x) -> pdf(results.posteriors[:θ], x), fillalpha=0.3, fillrange = 0, label="P(θ|y)", c=3)
p = vline!(p, [ hidden_θ ], label = "Real (hidden) θ")
```
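The plot can also be summarized numerically. A `Beta(α, β)` distribution has mean α/(α+β) and standard deviation sqrt(αβ / ((α+β)²(α+β+1))), so the posterior `Beta(336, 666)` has a mean of about 0.335 and a standard deviation of about 0.015, which puts the hidden value 1/3.1415 ≈ 0.318 a little over one posterior standard deviation away. The plain-Julia helpers below (`beta_mean`, `beta_std`, our own names) spell this out.

```julia
# Mean and standard deviation of a Beta(α, β) distribution, written out explicitly.
beta_mean(α, β) = α / (α + β)
beta_std(α, β) = sqrt(α * β / ((α + β)^2 * (α + β + 1)))

α_post, β_post = 336.0, 666.0   # posterior parameters printed above
θ_true = 1 / 3.1415             # same value as `hidden_θ` used to generate the data

m = beta_mean(α_post, β_post)
s = beta_std(α_post, β_post)
z = (m - θ_true) / s            # distance from the hidden value, in posterior std units
println("posterior mean = ", m, ", std = ", s, ", z = ", z)
```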

### Missing data points and predictions

Observations passed through the `data` argument may contain `missing` entries:

```
result = infer(
    model = beta_bernoulli(a = 1.0, b = 1.0),
    data  = (y = [ true, false, missing, true, false ], )
)
```

```
Inference results:
Posteriors | available for (θ)
Predictions | available for (y)
```

In principle, the entire dataset may consist of `missing` entries:

```
result = infer(
    model = beta_bernoulli(a = 1.0, b = 1.0),
    data  = (y = [ missing, missing, missing, missing, missing ], )
)
```

```
Inference results:
Posteriors | available for (θ)
Predictions | available for (y)
```

In this case, the resulting posterior is simply equal to the prior (as expected, since no extra information can be extracted from the observations):

`result.posteriors[:θ]`

`Beta{Float64}(α=1.0, β=1.0)`

In addition, in the presence of `missing` data points, `RxInfer` will also attempt to compute predictive distributions:

`result.predictions[:y]`

```
5-element Vector{Bernoulli{Float64}}:
Bernoulli{Float64}(p=0.5)
Bernoulli{Float64}(p=0.5)
Bernoulli{Float64}(p=0.5)
Bernoulli{Float64}(p=0.5)
Bernoulli{Float64}(p=0.5)
```

```
# Sample y₃
rand(result.predictions[:y][3])
```

`false`
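For this conjugate model, the outputs above can be reproduced by hand: `missing` entries carry no information, so the posterior is the prior updated with the observed values only, and a missing entry is predicted by `Bernoulli(α / (α + β))`, the posterior mean of θ. The plain-Julia sketch below illustrates that reasoning; it is not RxInfer's implementation (RxInfer obtains these results through message passing), and the names `posterior_with_missing` and `predictive_p` are our own.

```julia
# Conjugate Beta-Bernoulli update that ignores `missing` entries.
function posterior_with_missing(y::AbstractVector, a::Real, b::Real)
    n_true = count(x -> x === true, y)    # observed successes
    n_false = count(x -> x === false, y)  # observed failures; `missing` matches neither
    return (α = a + n_true, β = b + n_false)
end

# Posterior predictive probability that an unobserved y equals `true`
predictive_p(post) = post.α / (post.α + post.β)

post_partial = posterior_with_missing([true, false, missing, true, false], 1.0, 1.0)
post_all_missing = posterior_with_missing([missing, missing, missing, missing, missing], 1.0, 1.0)
println(post_partial, " -> p = ", predictive_p(post_partial))
println(post_all_missing, " -> p = ", predictive_p(post_all_missing))
```

With an all-`missing` dataset the sketch returns the prior `Beta(1, 1)` and a predictive probability of 0.5, in line with `result.posteriors[:θ]` and `result.predictions[:y]` shown above.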

## Variational Inference with static datasets

The example above is quite simple and performs exact Bayesian inference. However, for more complex models, we may need to specify variational constraints and perform variational inference. To demonstrate this, we will use a slightly more complex model, where we need to estimate the mean and the precision of IID samples drawn from a Normal distribution:

```
@model function iid_estimation(y)
    μ ~ Normal(mean = 0.0, precision = 0.1)
    τ ~ Gamma(shape = 1.0, rate = 1.0)
    y .~ Normal(mean = μ, precision = τ)
end
```

In this model, we have two latent variables `μ` and `τ` and a set of observations `y`. Note that we used the broadcasting syntax, which is roughly equivalent to the manual `for` loop shown in the previous example. Let's try to run inference in this model, but first we need to create our observations:

```
# The `ExponentialFamily` package exports different parametrizations
# for the Normal distribution
using ExponentialFamily

hidden_μ = 3.1415
hidden_τ = 2.7182
distribution = NormalMeanPrecision(hidden_μ, hidden_τ)
rng = StableRNG(42)
n_observations = 1_000
dataset = rand(rng, distribution, n_observations)
```

And finally we run the inference procedure:

```
results = infer(
    model = iid_estimation(),
    data  = (y = dataset, )
)
```

```
ERROR: Variables [ μ, τ ] have not been updated after an update event.
Therefore, make sure to initialize all required marginals and messages. See `initialization` keyword argument for the inference function.
See the official documentation for detailed information regarding the initialization.
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
```

Huh? We get an error saying that the inference could not update the latent variables. This happened because our model contains loops in its structure and therefore requires initialization. Read more about initialization in the corresponding section of the documentation.

We have two options here: either we initialize the messages and perform Loopy Belief Propagation in this model, or we break the loops with variational constraints and perform variational inference. In this tutorial, we choose the second option. For this, we need to specify factorization constraints with the `@constraints` macro.

```
# Specify mean-field constraint over the joint variational posterior
constraints = @constraints begin
    q(μ, τ) = q(μ)q(τ)
end

# Specify initial posteriors for variational iterations
initialization = @initialization begin
    q(μ) = vague(NormalMeanPrecision)
    q(τ) = vague(GammaShapeRate)
end
```

With this, we can use the `constraints` and `initialization` keyword arguments in the `infer` function. We also specify the number of variational iterations with the `iterations` keyword argument:

```
results = infer(
    model          = iid_estimation(),
    data           = (y = dataset, ),
    constraints    = constraints,
    iterations     = 100,
    initialization = initialization
)
```

```
Inference results:
Posteriors | available for (μ, τ)
```

Nice! Now we have some results. Let's, for example, inspect the posterior results for `μ`.

`results.posteriors[:μ]`

```
100-element Vector{NormalWeightedMeanPrecision{Float64}}:
NormalWeightedMeanPrecision{Float64}(xi=3.1276134803795e-9, w=0.10000000100200404)
NormalWeightedMeanPrecision{Float64}(xi=155.36497103500977, w=49.87459713359878)
NormalWeightedMeanPrecision{Float64}(xi=7667.879879783272, w=2456.6745376352883)
NormalWeightedMeanPrecision{Float64}(xi=8056.651469322076, w=2581.226095926443)
NormalWeightedMeanPrecision{Float64}(xi=8057.059156601909, w=2581.3567075436076)
NormalWeightedMeanPrecision{Float64}(xi=8057.059563495237, w=2581.3568379008652)
NormalWeightedMeanPrecision{Float64}(xi=8057.059563901298, w=2581.356838030949)
NormalWeightedMeanPrecision{Float64}(xi=8057.0595639017065, w=2581.356838031086)
NormalWeightedMeanPrecision{Float64}(xi=8057.059563901715, w=2581.356838031092)
NormalWeightedMeanPrecision{Float64}(xi=8057.059563901707, w=2581.356838031086)
⋮
NormalWeightedMeanPrecision{Float64}(xi=8057.059563901707, w=2581.356838031086)
NormalWeightedMeanPrecision{Float64}(xi=8057.059563901715, w=2581.356838031092)
NormalWeightedMeanPrecision{Float64}(xi=8057.059563901707, w=2581.356838031086)
NormalWeightedMeanPrecision{Float64}(xi=8057.059563901715, w=2581.356838031092)
NormalWeightedMeanPrecision{Float64}(xi=8057.059563901707, w=2581.356838031086)
NormalWeightedMeanPrecision{Float64}(xi=8057.059563901715, w=2581.356838031092)
NormalWeightedMeanPrecision{Float64}(xi=8057.059563901707, w=2581.356838031086)
NormalWeightedMeanPrecision{Float64}(xi=8057.059563901715, w=2581.356838031092)
NormalWeightedMeanPrecision{Float64}(xi=8057.059563901707, w=2581.356838031086)
```

In contrast to the previous example, we now have an array of posteriors for `μ`, not just a single value. Each posterior in the collection corresponds to the intermediate variational update at each variational iteration. Let's visualize how our posterior over `μ` has been changing during the variational optimization:

```
@gif for (i, intermediate_posterior) in enumerate(results.posteriors[:μ])
    rμ = range(0, 5, length = 1000)
    plot(rμ, (x) -> pdf(intermediate_posterior, x), title="Posterior on iteration $(i)", fillalpha=0.3, fillrange = 0, label="P(μ|y)", c=3)
    vline!([hidden_μ], label = "Real (hidden) μ")
end
```

It seems that the posterior converged to a stable distribution pretty fast. We are going to verify the convergence in the next section. If we are not interested in the intermediate updates but only in the final posterior, we can use the `returnvars` option in the `infer` function with the `KeepLast` option for `μ`:

```
results_keep_last = infer(
    model          = iid_estimation(),
    data           = (y = dataset, ),
    constraints    = constraints,
    iterations     = 100,
    returnvars     = (μ = KeepLast(), ),
    initialization = initialization
)
```

```
Inference results:
Posteriors | available for (μ)
```

We can also verify that we got exactly the same result:

`results_keep_last.posteriors[:μ] == last(results.posteriors[:μ])`

`true`

Let's also verify that the posteriors are consistent with the real hidden values used in the dataset generation:

```
println("Real (hidden) μ was ", hidden_μ)
println("Inferred mean for μ is ", mean(last(results.posteriors[:μ])), " with standard deviation ", std(last(results.posteriors[:μ])))
println("Real (hidden) τ was ", hidden_τ)
println("Inferred mean for τ is ", mean(last(results.posteriors[:τ])), " with standard deviation ", std(last(results.posteriors[:τ])))
```

```
Real (hidden) μ was 3.1415
Inferred mean for μ is 3.121249819163777 with standard deviation 0.019682305930741596
Real (hidden) τ was 2.7182
Inferred mean for τ is 2.5812568380310656 with standard deviation 0.11532205069721162
```
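The printed numbers follow directly from the parameters of the final posteriors. `NormalWeightedMeanPrecision` stores the precision `w` and the weighted mean `xi = w * mean`, so the mean is `xi / w` and the standard deviation is `1 / sqrt(w)`; a Gamma distribution in shape-rate form `(a, b)` has mean `a / b` and standard deviation `sqrt(a) / b`. The helper names below are our own, and the μ parameters are copied from the last entry of `results.posteriors[:μ]` printed earlier.

```julia
# (xi, w) parametrization: w is the precision and xi = w * mean.
normal_wmp_mean_std(xi, w) = (mean = xi / w, std = 1 / sqrt(w))

# Gamma in shape-rate form (a, b).
gamma_shape_rate_mean_std(a, b) = (mean = a / b, std = sqrt(a) / b)

# Parameters of the final q(μ) printed earlier in this guide
q_μ = normal_wmp_mean_std(8057.059563901707, 2581.356838031086)
println("μ: mean = ", q_μ.mean, ", std = ", q_μ.std)
```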

```
rμ = range(2, 4, length = 1000)
pμ = plot(rμ, (x) -> pdf(last(results.posteriors[:μ]), x), title="Posterior for μ", fillalpha=0.3, fillrange = 0, label="P(μ|y)", c=3)
pμ = vline!(pμ, [ hidden_μ ], label = "Real (hidden) μ")
rτ = range(2, 4, length = 1000)
pτ = plot(rτ, (x) -> pdf(last(results.posteriors[:τ]), x), title="Posterior for τ", fillalpha=0.3, fillrange = 0, label="P(τ|y)", c=3)
pτ = vline!(pτ, [ hidden_τ ], label = "Real (hidden) τ")
plot(pμ, pτ)
```

Nice result! Our posteriors are pretty close to the actual values of the parameters used for dataset generation.
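For intuition about what the mean-field constraint `q(μ, τ) = q(μ)q(τ)` implies, the standard closed-form coordinate-ascent updates for this conjugate Normal-Gamma model can be written out by hand. This is an independent plain-Julia sketch, not RxInfer's implementation: it simulates its own data with `MersenneTwister(42)` (so the numbers will not match `dataset` exactly), and the vague starting precision `1e-10` for q(μ) is an arbitrary choice. Note that the shape of q(τ) is `1 + n/2 = 501`, the same shape RxInfer reports for q(τ) in the callback output at the end of this guide.

```julia
using Random

# Mean-field updates for
#   μ ~ Normal(mean = 0, precision = λ0),  τ ~ Gamma(shape = a0, rate = b0),
#   y[i] ~ Normal(mean = μ, precision = τ),  with q(μ, τ) = q(μ) q(τ).
function mean_field_iid(y; λ0 = 0.1, a0 = 1.0, b0 = 1.0, iterations = 20)
    n = length(y)
    m_μ, w_μ = 0.0, 1e-10        # vague initial q(μ): mean 0, tiny precision
    a_τ, b_τ = a0, b0
    for _ in 1:iterations
        # q(τ) update, using E[(y_i - μ)^2] = (y_i - m_μ)^2 + 1 / w_μ
        a_τ = a0 + n / 2
        b_τ = b0 + 0.5 * (sum(abs2, y .- m_μ) + n / w_μ)
        Eτ = a_τ / b_τ
        # q(μ) update: the prior mean is 0, so only the data term enters the weighted mean
        w_μ = λ0 + n * Eτ
        m_μ = Eτ * sum(y) / w_μ
    end
    return (m_μ = m_μ, w_μ = w_μ, a_τ = a_τ, b_τ = b_τ)
end

μ_true, τ_true = 3.1415, 2.7182
rng_sim = MersenneTwister(42)
y_sim = μ_true .+ randn(rng_sim, 1_000) ./ sqrt(τ_true)

q = mean_field_iid(y_sim)
println("E[μ] = ", q.m_μ, ", std[μ] = ", 1 / sqrt(q.w_μ), ", E[τ] = ", q.a_τ / q.b_τ)
```

Each update uses only expectations under the other factor, which is exactly what the factorization buys us: the loop between μ and τ is replaced by alternating closed-form updates.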

## Convergence and Bethe Free Energy

Read also the Bethe Free Energy section.

In contrast to Loopy Belief Propagation, variational inference is expected to converge to a stable point. To verify convergence for this particular model, we can check the Bethe Free Energy values across iterations. By default, the `infer` function does **not** compute the Bethe Free Energy values. To compute them, we must explicitly set the `free_energy` flag to `true`:

```
results = infer(
    model          = iid_estimation(),
    data           = (y = dataset, ),
    constraints    = constraints,
    iterations     = 100,
    initialization = initialization,
    free_energy    = true
)
```

```
Inference results:
Posteriors | available for (μ, τ)
Free Energy: | Real[14763.3, 2437.5, 952.711, 952.108, 952.108, 952.108, 952.108, 952.108, 952.108, 952.108 … 952.108, 952.108, 952.108, 952.108, 952.108, 952.108, 952.108, 952.108, 952.108, 952.108]
```

Now we can access the `free_energy` field of the `results` and verify whether the inference procedure has converged:

`plot(results.free_energy, label = "Bethe Free Energy")`

Well, it seems that `100` iterations were too many for this simple problem and that far fewer iterations would suffice to converge to a stable point. The animation above also suggested that the posterior for `μ` converged to a stable point pretty fast.

```
# Let's try to use only 5 iterations
results = infer(
    model          = iid_estimation(),
    data           = (y = dataset, ),
    constraints    = constraints,
    iterations     = 5,
    initialization = initialization,
    free_energy    = true
)
```

```
Inference results:
Posteriors | available for (μ, τ)
Free Energy: | Real[14763.3, 2437.5, 952.711, 952.108, 952.108]
```

`plot(results.free_energy, label = "Bethe Free Energy")`
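Reading convergence off a plot works well in a tutorial; in a script, a numeric criterion may be more convenient. Below is a minimal sketch that assumes a simple rule of our own choosing (not an RxInfer feature): report the first iteration at which the free energy changed by less than `tol` compared to the previous iteration. The values are the rounded ones printed for the 5-iteration run above.

```julia
# First iteration whose free energy differs from the previous one by less than `tol`,
# or `nothing` if the change never drops below `tol`.
function first_converged_iteration(free_energy::AbstractVector{<:Real}; tol::Real = 1e-3)
    for i in 2:length(free_energy)
        if abs(free_energy[i] - free_energy[i - 1]) < tol
            return i
        end
    end
    return nothing
end

fe = [14763.3, 2437.5, 952.711, 952.108, 952.108]   # rounded values printed above
println(first_converged_iteration(fe))              # strict tolerance
println(first_converged_iteration(fe; tol = 1.0))   # looser tolerance
```

With an actual run, one would pass `results.free_energy` instead of the hard-coded vector. The tolerance is a judgment call: free energy values in the hundreds make an absolute threshold of `1e-3` fairly strict.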

## Callbacks

The `infer` function has its own lifecycle, consisting of multiple steps. A user is free to inject extra logic during the inference procedure, e.g. for debugging purposes. By supplying callbacks, users can run custom logic at specific moments during the inference procedure. Here are the available callbacks that can be used together with static datasets:

- `before_model_creation()`: called before the model is created; does not accept any arguments.
- `after_model_creation(model::ProbabilisticModel)`: called right after the model has been created; accepts a single argument, the `model`.
- `before_inference(model::ProbabilisticModel)`: called before the inference procedure starts; accepts a single argument, the `model`.
- `after_inference(model::ProbabilisticModel)`: called after the inference procedure ends; accepts a single argument, the `model`.
- `before_iteration(model::ProbabilisticModel, iteration::Int)`: called before each iteration; accepts two arguments, the `model` and the current iteration number.
- `after_iteration(model::ProbabilisticModel, iteration::Int)`: called after each iteration; accepts two arguments, the `model` and the current iteration number.
- `before_data_update(model::ProbabilisticModel, data)`: called before each data update; accepts two arguments, the `model` and the updated data.
- `after_data_update(model::ProbabilisticModel, data)`: called after each data update; accepts two arguments, the `model` and the updated data.
- `on_marginal_update(model::ProbabilisticModel, name, update)`: called after each marginal update; accepts three arguments: the `model`, the name of the updated marginal, and the updated marginal itself.

Here is an example usage of the outlined callbacks:

```
function before_model_creation()
    println("The model is about to be created")
end

function after_model_creation(model::ProbabilisticModel)
    println("The model has been created")
    println("  The number of factor nodes is: ", length(RxInfer.getfactornodes(model)))
    println("  The number of latent states is: ", length(RxInfer.getrandomvars(model)))
    println("  The number of data points is: ", length(RxInfer.getdatavars(model)))
    println("  The number of constants is: ", length(RxInfer.getconstantvars(model)))
end

function before_inference(model::ProbabilisticModel)
    println("The inference procedure is about to start")
end

function after_inference(model::ProbabilisticModel)
    println("The inference procedure has ended")
end

function before_iteration(model::ProbabilisticModel, iteration::Int)
    println("The iteration ", iteration, " is about to start")
end

function after_iteration(model::ProbabilisticModel, iteration::Int)
    println("The iteration ", iteration, " has ended")
end

function before_data_update(model::ProbabilisticModel, data)
    println("The data is about to be processed")
end

function after_data_update(model::ProbabilisticModel, data)
    println("The data has been processed")
end

function on_marginal_update(model::ProbabilisticModel, name, update)
    println("New marginal update for ", name, " ", update)
end
```


```
results = infer(
    model          = iid_estimation(),
    data           = (y = dataset, ),
    constraints    = constraints,
    iterations     = 5,
    initialization = initialization,
    free_energy    = true,
    callbacks      = (
        before_model_creation = before_model_creation,
        after_model_creation  = after_model_creation,
        before_inference      = before_inference,
        after_inference       = after_inference,
        before_iteration      = before_iteration,
        after_iteration       = after_iteration,
        before_data_update    = before_data_update,
        after_data_update     = after_data_update,
        on_marginal_update    = on_marginal_update
    )
)
```

```
The model is about to be created
The model has been created
The number of factor nodes is: 1002
The number of latent states is: 2
The number of data points is: 1000
The number of constants is: 4
The inference procedure is about to start
The iteration 1 is about to start
The data is about to be processed
New marginal update for τ Marginal(GammaShapeRate{Float64}(a=501.0, b=5.000000000050656e14))
New marginal update for μ Marginal(NormalWeightedMeanPrecision{Float64}(xi=3.1276134803795e-9, w=0.10000000100200404))
The data has been processed
The iteration 1 has ended
The iteration 2 is about to start
The data is about to be processed
New marginal update for τ Marginal(GammaShapeRate{Float64}(a=501.0, b=10065.375288830106))
New marginal update for μ Marginal(NormalWeightedMeanPrecision{Float64}(xi=155.36497103500977, w=49.87459713359878))
The data has been processed
The iteration 2 has ended
The iteration 3 is about to start
The data is about to be processed
New marginal update for τ Marginal(GammaShapeRate{Float64}(a=501.0, b=203.94251927819624))
New marginal update for μ Marginal(NormalWeightedMeanPrecision{Float64}(xi=7667.879879783272, w=2456.6745376352883))
The data has been processed
The iteration 3 has ended
The iteration 4 is about to start
The data is about to be processed
New marginal update for τ Marginal(GammaShapeRate{Float64}(a=501.0, b=194.10132685523482))
New marginal update for μ Marginal(NormalWeightedMeanPrecision{Float64}(xi=8056.651469322076, w=2581.226095926443))
The data has been processed
The iteration 4 has ended
The iteration 5 is about to start
The data is about to be processed
New marginal update for τ Marginal(GammaShapeRate{Float64}(a=501.0, b=194.09150532601225))
New marginal update for μ Marginal(NormalWeightedMeanPrecision{Float64}(xi=8057.059156601909, w=2581.3567075436076))
The data has been processed
The iteration 5 has ended
The inference procedure has ended
```
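Callbacks do not have to print; a closure can collect information for later inspection. The sketch below builds a callback with the documented `after_iteration(model, iteration)` argument order that records iteration numbers into a captured vector. The factory name `make_iteration_recorder` is our own, and the manual loop with `nothing` in place of the model only stands in for `infer` so the snippet runs on its own. The callback returns `nothing` explicitly so that no unexpected value is handed back to the caller.

```julia
# Build an `after_iteration`-style callback that records iteration numbers
# into a vector captured by the closure.
function make_iteration_recorder()
    seen = Int[]
    callback = (model, iteration::Int) -> (push!(seen, iteration); nothing)
    return seen, callback
end

seen, record_iteration = make_iteration_recorder()

# Stand-in for the inference loop: RxInfer would pass the model object,
# here `nothing` is passed just to exercise the callback.
for iteration in 1:5
    record_iteration(nothing, iteration)
end
println(seen)
```

With RxInfer, the closure would be supplied in the same way as the callbacks above, e.g. `callbacks = (after_iteration = record_iteration, )`, and `seen` inspected after `infer` returns.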

## Where to go next?

This guide covered some fundamental usages of the `infer` function in the context of inference with static datasets, but did not cover all the available keyword arguments of the function. Read more about the other keyword arguments in the Overview section or check out the Streaming Inference section. Also check out more complex examples.