Static Inference

This guide explains how to use the infer function for static datasets. We'll show how RxInfer can estimate posterior beliefs given a set of observations. We'll use a simple Beta-Bernoulli model as an example, which has been covered in the Getting Started section, but keep in mind that these techniques can apply to any model.

Also read about Streaming Inference or check out more complex examples.

Model specification

Also read the Model Specification section.

In static inference, we want to update our prior beliefs about certain hidden states given some dataset. To achieve this, we include data as an argument in our model specification:

using RxInfer

@model function beta_bernoulli(y, a, b)
    θ ~ Beta(a, b)
    for i in 1:length(y)
        y[i] ~ Bernoulli(θ)
    end
end

In this model, we assume that y is a collection of data points, and a and b are just numbers. To run inference in this model, we have to call the infer function with the data argument provided.

Dataset of observations

For demonstration purposes, we will use a hand-crafted dataset:

using Distributions, StableRNGs

hidden_θ       = 1 / 3.1415
distribution   = Bernoulli(hidden_θ)
rng            = StableRNG(43)
n_observations = 1_000
dataset        = rand(rng, distribution, n_observations)
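
Before running inference, we can optionally inspect the generated data. A quick sanity check (not required for the inference itself) is to compare the empirical frequency of true values with the hidden parameter:

# First few samples and the empirical frequency of `true`
first(dataset, 5)
mean(dataset)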

Calling the inference procedure

Everything is ready to run inference in our simple model. In order to run inference with a static dataset using the infer function, we need to use the data argument. The data argument expects a NamedTuple where keys correspond to the names of the model arguments. In our case, the model arguments were a, b and y. We treat a and b as hyperparameters and pass them directly to the model constructor, and we treat y as our observations, thus passing it to the data argument as follows:

results = infer(
    model = beta_bernoulli(a = 1.0, b = 1.0),
    data  = (y = dataset, )
)
Inference results:
  Posteriors       | available for (θ)
Note

y inside the @model specification is not the same data collection as the one provided in the data argument. Inside the @model, y is a collection of nodes in the corresponding factor graph, but it has exactly the same shape as the collection provided in the data argument, hence we can use some basic Julia functions on it, e.g. length.

Note that we could also pass a and b as data:

results = infer(
    model = beta_bernoulli(),
    data  = (y = dataset, a = 1.0, b = 1.0)
)
Inference results:
  Posteriors       | available for (θ)

The infer function, however, requires at least one data argument to be present in the supplied data. The difference between hyperparameters and observations is purely semantic and should not have any real influence on the result of the inference procedure.

Note

The inference procedure uses a reactive message passing protocol and may decide to optimize and precompute certain messages that use fixed hyperparameters, hence changing the order of computed messages. The order of computations may change the convergence properties for some complex models.
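
For this simple model, however, the two ways of passing a and b should yield the same posterior. A quick sketch to check this (the variable names below are hypothetical and used only for illustration):

results_as_hyperparameters = infer(
    model = beta_bernoulli(a = 1.0, b = 1.0),
    data  = (y = dataset, )
)
results_as_data = infer(
    model = beta_bernoulli(),
    data  = (y = dataset, a = 1.0, b = 1.0)
)

# Both calls should produce the same Beta posterior for θ
results_as_hyperparameters.posteriors[:θ] == results_as_data.posteriors[:θ]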

In the case of inference with static datasets, the infer function returns the InferenceResult structure. This structure has the .posteriors field, which is a Dict-like structure that maps names of latent states to their corresponding posteriors. For example:

results.posteriors[:θ]
Beta{Float64}(α=336.0, β=666.0)
RxInfer.InferenceResult (Type)
InferenceResult

This structure is used as a return value from the infer function for static datasets.

Public Fields

  • posteriors: Dict or NamedTuple of 'random variable' - 'posterior' pairs. See the returnvars argument for infer.
  • predictions: (optional) Dict or NamedTuple of 'data variable' - 'prediction' pairs. See the predictvars argument for infer.
  • free_energy: (optional) An array of Bethe Free Energy values per VMP iteration. See the free_energy argument for infer.
  • model: FactorGraphModel object reference.
  • error: (optional) A reference to an exception, that might have occurred during the inference. See the catch_exception argument for infer.

See also: infer

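
Since the returned posterior is a regular distribution object, standard functions such as mean and std can be used to summarize it; for example (the exact numbers depend on the dataset):

mean(results.posteriors[:θ])
std(results.posteriors[:θ])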

We can also visualize our posterior results with the Plots.jl package. We used Beta(a = 1.0, b = 1.0) as a prior; let's compare our prior and posterior beliefs:

using Plots

rθ = range(0, 1, length = 1000)

p = plot()
p = plot!(p, rθ, (x) -> pdf(Beta(1.0, 1.0), x), title="Prior", fillalpha=0.3, fillrange = 0, label="P(θ)", c=1,)
p = plot!(p, rθ, (x) -> pdf(results.posteriors[:θ], x), title="Posterior", fillalpha=0.3, fillrange = 0, label="P(θ|y)", c=3)
p = vline!(p, [ hidden_θ ], label = "Real (hidden) θ")
[Figure: prior P(θ) vs posterior P(θ|y), with the real (hidden) θ marked]

Missing data points and predictions

RxInfer also supports datasets with missing entries. In that case, the missing observations contribute no information to the posterior, and the inference procedure additionally computes predictive distributions for the data:

result = infer(
    model = beta_bernoulli(a = 1.0, b = 1.0),
    data  = (y = [ true, false, missing, true, false ], )
)
Inference results:
  Posteriors       | available for (θ)
  Predictions      | available for (y)

In principle, the entire dataset may consist of missing entries:

result = infer(
    model = beta_bernoulli(a = 1.0, b = 1.0),
    data  = (y = [ missing, missing, missing, missing, missing ], )
)
Inference results:
  Posteriors       | available for (θ)
  Predictions      | available for (y)

In this case, the resulting posterior is simply equal to the prior (as expected, since no extra information can be extracted from the observations):

result.posteriors[:θ]
Beta{Float64}(α=1.0, β=1.0)

In addition, in the presence of missing data points RxInfer will also attempt to compute predictive distributions:

result.predictions[:y]
5-element Vector{Bernoulli{Float64}}:
 Bernoulli{Float64}(p=0.5)
 Bernoulli{Float64}(p=0.5)
 Bernoulli{Float64}(p=0.5)
 Bernoulli{Float64}(p=0.5)
 Bernoulli{Float64}(p=0.5)
# Sample y₃
rand(result.predictions[:y][3])
true
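
Since the predictions are regular Bernoulli distributions, we can, for example, extract the predicted probability of true for every entry by broadcasting mean over the collection (a small illustrative sketch):

# Predicted probability of `true` for each entry of `y`
mean.(result.predictions[:y])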

Variational Inference with static datasets

The example above is quite simple and performs exact Bayesian inference. However, for more complex models, we may need to specify variational constraints and perform variational inference. To demonstrate this, we will use a slightly more complex model, where we need to estimate the mean and the precision of IID samples drawn from a Normal distribution:

@model function iid_estimation(y)
    μ  ~ Normal(mean = 0.0, precision = 0.1)
    τ  ~ Gamma(shape = 1.0, rate = 1.0)
    y .~ Normal(mean = μ, precision = τ)
end
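
The broadcasted statement y .~ Normal(mean = μ, precision = τ) creates one observation node per element of y. For illustration only, a hypothetical equivalent written with an explicit loop (mirroring the Beta-Bernoulli example above) would look like this:

@model function iid_estimation_loop(y)
    μ ~ Normal(mean = 0.0, precision = 0.1)
    τ ~ Gamma(shape = 1.0, rate = 1.0)
    for i in 1:length(y)
        y[i] ~ Normal(mean = μ, precision = τ)
    end
end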

In this model, we have two latent variables μ and τ and a set of observations y. Note that we used the broadcasting syntax, which is roughly equivalent to the explicit for loop sketched above. Let's try to run inference in this model, but first we need to create our observations:

# The `ExponentialFamily` package exports different parametrizations
# for the Normal distribution
using ExponentialFamily

hidden_μ       = 3.1415
hidden_τ       = 2.7182
distribution   = NormalMeanPrecision(hidden_μ, hidden_τ)
rng            = StableRNG(42)
n_observations = 1_000
dataset        = rand(rng, distribution, n_observations)

And finally we run the inference procedure:

results = infer(
    model = iid_estimation(),
    data  = (y = dataset, )
)
ERROR: Variables [ μ, τ ] have not been updated after an update event. 
Therefore, make sure to initialize all required marginals and messages. See `initialization` keyword argument for the inference function. 
See the official documentation for detailed information regarding the initialization.

Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35

Huh? We get an error saying that the inference could not update the latent variables. This happened because our model contains loops in its structure and therefore requires initialization. Read more about initialization in the corresponding section of the documentation.

We have two options here: either we initialize the messages and perform Loopy Belief Propagation in this model, or we break the loops with variational constraints and perform variational inference. In this tutorial, we choose the second option. For this, we need to specify factorization constraints with the @constraints macro and initial marginals with the @initialization macro.

# Specify mean-field constraint over the joint variational posterior
constraints = @constraints begin
    q(μ, τ) = q(μ)q(τ)
end
# Specify initial posteriors for variational iterations
initialization = @initialization begin
    q(μ) = vague(NormalMeanPrecision)
    q(τ) = vague(GammaShapeRate)
end

With this, we can use the constraints and initialization keyword arguments in the infer function. We also specify the number of variational iterations with the iterations keyword argument:

results = infer(
    model          = iid_estimation(),
    data           = (y = dataset, ),
    constraints    = constraints,
    iterations     = 100,
    initialization = initialization
)
Inference results:
  Posteriors       | available for (μ, τ)

Nice! Now we have some results. Let's, for example, inspect the posterior results for μ.

results.posteriors[:μ]
100-element Vector{NormalWeightedMeanPrecision{Float64}}:
 NormalWeightedMeanPrecision{Float64}(xi=3.1276134803795e-9, w=0.10000000100200404)
 NormalWeightedMeanPrecision{Float64}(xi=155.36497103500977, w=49.87459713359878)
 NormalWeightedMeanPrecision{Float64}(xi=7667.879879783272, w=2456.6745376352883)
 NormalWeightedMeanPrecision{Float64}(xi=8056.651469322076, w=2581.226095926443)
 NormalWeightedMeanPrecision{Float64}(xi=8057.059156601909, w=2581.3567075436076)
 NormalWeightedMeanPrecision{Float64}(xi=8057.059563495237, w=2581.3568379008652)
 NormalWeightedMeanPrecision{Float64}(xi=8057.059563901298, w=2581.356838030949)
 NormalWeightedMeanPrecision{Float64}(xi=8057.0595639017065, w=2581.356838031086)
 NormalWeightedMeanPrecision{Float64}(xi=8057.059563901715, w=2581.356838031092)
 NormalWeightedMeanPrecision{Float64}(xi=8057.059563901707, w=2581.356838031086)
 ⋮
 NormalWeightedMeanPrecision{Float64}(xi=8057.059563901707, w=2581.356838031086)
 NormalWeightedMeanPrecision{Float64}(xi=8057.059563901715, w=2581.356838031092)
 NormalWeightedMeanPrecision{Float64}(xi=8057.059563901707, w=2581.356838031086)
 NormalWeightedMeanPrecision{Float64}(xi=8057.059563901715, w=2581.356838031092)
 NormalWeightedMeanPrecision{Float64}(xi=8057.059563901707, w=2581.356838031086)
 NormalWeightedMeanPrecision{Float64}(xi=8057.059563901715, w=2581.356838031092)
 NormalWeightedMeanPrecision{Float64}(xi=8057.059563901707, w=2581.356838031086)
 NormalWeightedMeanPrecision{Float64}(xi=8057.059563901715, w=2581.356838031092)
 NormalWeightedMeanPrecision{Float64}(xi=8057.059563901707, w=2581.356838031086)

In contrast to the previous example, we now have an array of posteriors for μ, not just a single value. Each posterior in the collection corresponds to an intermediate update from one variational iteration. Let's visualize how our posterior over μ changes during the variational optimization:

@gif for (i, intermediate_posterior) in enumerate(results.posteriors[:μ])
    rμ = range(0, 5, length = 1000)
    plot(rμ, (x) -> pdf(intermediate_posterior, x), title="Posterior on iteration $(i)", fillalpha=0.3, fillrange = 0, label="P(μ|y)", c=3)
    vline!([hidden_μ], label = "Real (hidden) μ")
end
[Animation: posterior P(μ|y) at each variational iteration, with the real (hidden) μ marked]
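
Instead of an animation, we can also track a summary statistic per iteration; for example (a small optional sketch), the posterior mean of μ after each variational iteration:

# Posterior mean of μ per variational iteration
mean.(results.posteriors[:μ])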

It seems that the posterior converged to a stable distribution pretty fast. We will verify the convergence in the next section. If, for example, we are not interested in the intermediate updates but only in the final posterior, we can use the returnvars option of the infer function with the KeepLast option for μ:

results_keep_last = infer(
    model          = iid_estimation(),
    data           = (y = dataset, ),
    constraints    = constraints,
    iterations     = 100,
    returnvars     = (μ = KeepLast(), ),
    initialization = initialization
)
Inference results:
  Posteriors       | available for (μ)

We can also verify that we got exactly the same result:

results_keep_last.posteriors[:μ] == last(results.posteriors[:μ])
true

Let's also verify that the posteriors are consistent with the real hidden values used in the dataset generation:

println("Real (hidden) μ was ", hidden_μ)
println("Inferred mean for μ is ", mean(last(results.posteriors[:μ])), " with standard deviation ", std(last(results.posteriors[:μ])))

println("Real (hidden) τ was ", hidden_τ)
println("Inferred mean for τ is ", mean(last(results.posteriors[:τ])), " with standard deviation ", std(last(results.posteriors[:τ])))
Real (hidden) μ was 3.1415
Inferred mean for μ is 3.121249819163777 with standard deviation 0.019682305930741596
Real (hidden) τ was 2.7182
Inferred mean for τ is 2.5812568380310656 with standard deviation 0.11532205069721162
We can also plot the final posteriors for both parameters:

rμ = range(2, 4, length = 1000)
pμ = plot(rμ, (x) -> pdf(last(results.posteriors[:μ]), x), title="Posterior for μ", fillalpha=0.3, fillrange = 0, label="P(μ|y)", c=3)
pμ = vline!(pμ, [ hidden_μ ], label = "Real (hidden) μ")

rτ = range(2, 4, length = 1000)
pτ = plot(rτ, (x) -> pdf(last(results.posteriors[:τ]), x), title="Posterior for τ", fillalpha=0.3, fillrange = 0, label="P(τ|y)", c=3)
pτ = vline!(pτ, [ hidden_τ ], label = "Real (hidden) τ")

plot(pμ, pτ)
[Figure: final posteriors for μ and τ, with the real (hidden) values marked]

Nice result! Our posteriors are pretty close to the actual values of the parameters used for dataset generation.

Convergence and Bethe Free Energy

Read also the Bethe Free Energy section.

In contrast to Loopy Belief Propagation, variational inference is designed to converge to a stable point. In order to verify the convergence for this particular model, we can check the convergence of the Bethe Free Energy values. By default, the infer function does not compute the Bethe Free Energy values; in order to compute those, we must explicitly set the free_energy flag to true:

results = infer(
    model          = iid_estimation(),
    data           = (y = dataset, ),
    constraints    = constraints,
    iterations     = 100,
    initialization = initialization,
    free_energy    = true
)
Inference results:
  Posteriors       | available for (μ, τ)
  Free Energy:     | Real[14763.3, 2437.5, 952.711, 952.108, 952.108, 952.108, 952.108, 952.108, 952.108, 952.108  …  952.108, 952.108, 952.108, 952.108, 952.108, 952.108, 952.108, 952.108, 952.108, 952.108]

Now, we can access the free_energy field of the results and verify if the inference procedure has converged or not:

plot(results.free_energy, label = "Bethe Free Energy")
[Figure: Bethe Free Energy per variational iteration (100 iterations)]

Well, it seems that 100 iterations were far too many for this simple problem, and we could use far fewer iterations to converge to a stable point. The animation above also suggested that the posterior for μ converged to a stable point rather quickly.

# Let's try to use only 5 iterations
results = infer(
    model          = iid_estimation(),
    data           = (y = dataset, ),
    constraints    = constraints,
    iterations     = 5,
    initialization = initialization,
    free_energy    = true
)
Inference results:
  Posteriors       | available for (μ, τ)
  Free Energy:     | Real[14763.3, 2437.5, 952.711, 952.108, 952.108]
plot(results.free_energy, label = "Bethe Free Energy")
[Figure: Bethe Free Energy per variational iteration (5 iterations)]
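
Besides visual inspection, a simple programmatic check (a sketch; the tolerance below is arbitrary) is to look at how much the Bethe Free Energy changes between successive iterations:

# Absolute change of the Bethe Free Energy between successive iterations
bfe_changes = abs.(diff(results.free_energy))
# Consider the procedure converged once the change drops below a small tolerance
println("Converged: ", last(bfe_changes) < 1e-6)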

Callbacks

The infer function has its own lifecycle consisting of multiple steps. A user is free to inject extra logic into the inference procedure, e.g. for debugging purposes. By supplying callbacks, users can inject custom logic at specific moments of the inference procedure. Here are the available callbacks that can be used with static datasets:


before_model_creation()

Called before the model is created; accepts no arguments.

after_model_creation(model::ProbabilisticModel)

Called right after the model has been created; accepts a single argument, the model.

before_inference(model::ProbabilisticModel)

Called before the inference procedure starts; accepts a single argument, the model.

after_inference(model::ProbabilisticModel)

Called after the inference procedure ends; accepts a single argument, the model.

before_iteration(model::ProbabilisticModel, iteration::Int)

Called before each iteration; accepts two arguments: the model and the current iteration number.

after_iteration(model::ProbabilisticModel, iteration::Int)

Called after each iteration; accepts two arguments: the model and the current iteration number.

before_data_update(model::ProbabilisticModel, data)

Called before each data update; accepts two arguments: the model and the updated data.

after_data_update(model::ProbabilisticModel, data)

Called after each data update; accepts two arguments: the model and the updated data.

on_marginal_update(model::ProbabilisticModel, name, update)

Called after each marginal update; accepts three arguments: the model, the name of the updated marginal, and the updated marginal itself.


Here is an example usage of the outlined callbacks:

function before_model_creation()
    println("The model is about to be created")
end

function after_model_creation(model::ProbabilisticModel)
    println("The model has been created")
    println("  The number of factor nodes is: ", length(RxInfer.getfactornodes(model)))
    println("  The number of latent states is: ", length(RxInfer.getrandomvars(model)))
    println("  The number of data points is: ", length(RxInfer.getdatavars(model)))
    println("  The number of constants is: ", length(RxInfer.getconstantvars(model)))
end

function before_inference(model::ProbabilisticModel)
    println("The inference procedure is about to start")
end

function after_inference(model::ProbabilisticModel)
    println("The inference procedure has ended")
end

function before_iteration(model::ProbabilisticModel, iteration::Int)
    println("The iteration ", iteration, " is about to start")
end

function after_iteration(model::ProbabilisticModel, iteration::Int)
    println("The iteration ", iteration, " has ended")
end

function before_data_update(model::ProbabilisticModel, data)
    println("The data is about to be processed")
end

function after_data_update(model::ProbabilisticModel, data)
    println("The data has been processed")
end

function on_marginal_update(model::ProbabilisticModel, name, update)
    println("New marginal update for ", name, " ", update)
end
on_marginal_update (generic function with 1 method)
results = infer(
    model          = iid_estimation(),
    data           = (y = dataset, ),
    constraints    = constraints,
    iterations     = 5,
    initialization = initialization,
    free_energy    = true,
    callbacks      = (
        before_model_creation = before_model_creation,
        after_model_creation = after_model_creation,
        before_inference = before_inference,
        after_inference = after_inference,
        before_iteration = before_iteration,
        after_iteration = after_iteration,
        before_data_update = before_data_update,
        after_data_update = after_data_update,
        on_marginal_update = on_marginal_update
    )
)
The model is about to be created
The model has been created
  The number of factor nodes is: 1002
  The number of latent states is: 2
  The number of data points is: 1000
  The number of constants is: 4
The inference procedure is about to start
The iteration 1 is about to start
The data is about to be processed
New marginal update for τ Marginal(GammaShapeRate{Float64}(a=501.0, b=5.000000000050656e14))
New marginal update for μ Marginal(NormalWeightedMeanPrecision{Float64}(xi=3.1276134803795e-9, w=0.10000000100200404))
The data has been processed
The iteration 1 has ended
The iteration 2 is about to start
The data is about to be processed
New marginal update for τ Marginal(GammaShapeRate{Float64}(a=501.0, b=10065.375288830106))
New marginal update for μ Marginal(NormalWeightedMeanPrecision{Float64}(xi=155.36497103500977, w=49.87459713359878))
The data has been processed
The iteration 2 has ended
The iteration 3 is about to start
The data is about to be processed
New marginal update for τ Marginal(GammaShapeRate{Float64}(a=501.0, b=203.94251927819624))
New marginal update for μ Marginal(NormalWeightedMeanPrecision{Float64}(xi=7667.879879783272, w=2456.6745376352883))
The data has been processed
The iteration 3 has ended
The iteration 4 is about to start
The data is about to be processed
New marginal update for τ Marginal(GammaShapeRate{Float64}(a=501.0, b=194.10132685523482))
New marginal update for μ Marginal(NormalWeightedMeanPrecision{Float64}(xi=8056.651469322076, w=2581.226095926443))
The data has been processed
The iteration 4 has ended
The iteration 5 is about to start
The data is about to be processed
New marginal update for τ Marginal(GammaShapeRate{Float64}(a=501.0, b=194.09150532601225))
New marginal update for μ Marginal(NormalWeightedMeanPrecision{Float64}(xi=8057.059156601909, w=2581.3567075436076))
The data has been processed
The iteration 5 has ended
The inference procedure has ended
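
Callbacks do not have to print; they can also collect information for later analysis. Below is a small sketch (reusing the same model, constraints and initialization as above, and relying only on the on_marginal_update signature described earlier) that records every marginal update in a dictionary:

marginal_history = Dict{Any, Vector{Any}}()

function record_marginal_update(model::ProbabilisticModel, name, update)
    push!(get!(marginal_history, name, Any[]), update)
end

results = infer(
    model          = iid_estimation(),
    data           = (y = dataset, ),
    constraints    = constraints,
    iterations     = 5,
    initialization = initialization,
    callbacks      = (on_marginal_update = record_marginal_update, )
)

# One key per latent variable, each holding one recorded update per iteration
keys(marginal_history)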

Where to go next?

This guide covered some fundamental usages of the infer function in the context of inference with static datasets, but did not cover all of its available keyword arguments. Read more about the other keyword arguments in the Overview section or check out the Streaming Inference section. Also check out more complex examples.