Static Inference
This guide explains how to use the infer
function for static datasets. We'll show how RxInfer
can estimate posterior beliefs given a set of observations. We'll use a simple Beta-Bernoulli model as an example, which has been covered in the Getting Started section, but keep in mind that these techniques can apply to any model.
Also read about Streaming Inference or check out more complex examples.
Model specification
Also read the Model Specification section.
In static inference, we want to update our prior beliefs about certain hidden states given some dataset. To achieve this, we include data as an argument in our model specification:
using RxInfer
@model function beta_bernoulli(y, a, b)
θ ~ Beta(a, b)
for i in 1:length(y)
y[i] ~ Bernoulli(θ)
end
end
In this model, we assume that y
is a collection of data points, and a
and b
are just numbers. To run inference in this model, we have to call the infer
function with the data
argument provided.
Dataset of observations
For demonstration purposes, we will use a hand-crafted dataset:
using Distributions, StableRNGs
hidden_θ = 1 / 3.1415
distribution = Bernoulli(hidden_θ)
rng = StableRNG(43)
n_observations = 1_000
dataset = rand(rng, distribution, n_observations)
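As a quick sanity check (a minimal sketch; the exact numbers depend on the rng seed), we can peek at the generated data and compare the empirical frequency of true with hidden_θ ≈ 0.318:
# Peek at the data and compare the empirical frequency of `true`
# with the hidden parameter (hidden_θ ≈ 0.318)
println("First five observations: ", dataset[1:5])
println("Empirical frequency of `true`: ", count(dataset) / n_observations)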
Calling the inference procedure
Everything is ready to run inference in our simple model. In order to run inference with a static dataset using the infer function, we need to use the data argument. The data argument expects a NamedTuple where keys correspond to the names of the model arguments. In our case, the model arguments were a, b, and y. We treat a and b as hyperparameters and pass them directly to the model constructor, and we treat y as our observations, thus we pass it to the data argument as follows:
results = infer(
model = beta_bernoulli(a = 1.0, b = 1.0),
data = (y = dataset, )
)
Inference results:
Posteriors | available for (θ)
Note that y inside the @model specification is not the same data collection as provided in the data argument. Inside the @model, y is a collection of nodes in the corresponding factor graph, but it has exactly the same shape as the collection provided in the data argument, hence we can use some basic Julia functions on it, e.g. length.
Note that we could also pass a and b as data:
results = infer(
model = beta_bernoulli(),
data = (y = dataset, a = 1.0, b = 1.0)
)
Inference results:
Posteriors | available for (θ)
The infer function, however, requires at least one data argument to be present in the supplied data. The difference between hyperparameters and observations is purely semantic and should not have a real influence on the result of the inference procedure.
Note, however, that the inference procedure uses a reactive message passing protocol and may decide to optimize and precompute certain messages that use fixed hyperparameters, hence changing the order of computed messages. The order of computations may change the convergence properties for some complex models.
In the case of inference with static datasets, the infer function returns an InferenceResult structure. This structure has the .posteriors field, which is a Dict-like structure that maps names of latent states to their corresponding posteriors. For example:
results.posteriors[:θ]
Beta{Float64}(α=336.0, β=666.0)
RxInfer.InferenceResult — Type
This structure is used as a return value from the infer function for static datasets.
Public Fields
posteriors: Dict or NamedTuple of 'random variable' - 'posterior' pairs. See the returnvars argument for infer.
predictions: (optional) Dict or NamedTuple of 'data variable' - 'prediction' pairs. See the predictvars argument for infer.
free_energy: (optional) An array of Bethe Free Energy values per VMP iteration. See the free_energy argument for infer.
model: FactorGraphModel object reference.
error: (optional) A reference to an exception that might have occurred during the inference. See the catch_exception argument for infer.
See also: infer
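Standard functions from Distributions.jl can be applied directly to the returned posterior; for instance, a minimal sketch of extracting summary statistics (exact numbers depend on the dataset):
# Summary statistics of the inferred posterior for θ
post_θ = results.posteriors[:θ]
println("Posterior mean for θ: ", mean(post_θ))
println("Posterior std for θ:  ", std(post_θ))
println("Real (hidden) θ was:  ", hidden_θ)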
We can also visualize our posterior results with the Plots.jl package. We used Beta(a = 1.0, b = 1.0) as a prior; let's compare our prior and posterior beliefs:
using Plots
rθ = range(0, 1, length = 1000)
p = plot()
p = plot!(p, rθ, (x) -> pdf(Beta(1.0, 1.0), x), title="Prior", fillalpha=0.3, fillrange = 0, label="P(θ)", c=1,)
p = plot!(p, rθ, (x) -> pdf(results.posteriors[:θ], x), title="Posterior", fillalpha=0.3, fillrange = 0, label="P(θ|y)", c=3)
p = vline!(p, [ hidden_θ ], label = "Real (hidden) θ")
Missing data points and predictions
A dataset may also contain missing entries. RxInfer handles them automatically and, as shown below, computes predictive distributions for the missing data points:
result = infer(
model = beta_bernoulli(a = 1.0, b = 1.0),
data = (y = [ true, false, missing, true, false ], )
)
Inference results:
Posteriors | available for (θ)
Predictions | available for (y)
In principle, the entire dataset may consist of missing
entries:
result = infer(
model = beta_bernoulli(a = 1.0, b = 1.0),
data = (y = [ missing, missing, missing, missing, missing ], )
)
Inference results:
Posteriors | available for (θ)
Predictions | available for (y)
In this case, the resulting posterior is simply equal to the prior (as expected, since no extra information can be extracted from the observations):
result.posteriors[:θ]
Beta{Float64}(α=1.0, β=1.0)
In addition, in the presence of missing
data points RxInfer
will also attempt to compute predictive distributions:
result.predictions[:y]
5-element Vector{Bernoulli{Float64}}:
Bernoulli{Float64}(p=0.5)
Bernoulli{Float64}(p=0.5)
Bernoulli{Float64}(p=0.5)
Bernoulli{Float64}(p=0.5)
Bernoulli{Float64}(p=0.5)
# Sample y₃
rand(result.predictions[:y][3])
true
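We can also draw one sample from every predictive distribution at once with broadcasting (a minimal sketch):
# Sample each predictive distribution for `y` in one go
rand.(result.predictions[:y])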
Variational Inference with static datasets
The example above is quite simple and performs exact Bayesian inference. However, for more complex models, we may need to specify variational constraints and perform variational inference. To demonstrate this, we will use a slightly more complex model, where we need to estimate the mean and the precision of IID samples drawn from a Normal distribution:
@model function iid_estimation(y)
μ ~ Normal(mean = 0.0, precision = 0.1)
τ ~ Gamma(shape = 1.0, rate = 1.0)
y .~ Normal(mean = μ, precision = τ)
end
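The broadcasted line y .~ Normal(mean = μ, precision = τ) attaches a Normal factor to every entry of y. For reference, a minimal sketch of a roughly equivalent loop-based variant, mirroring the Beta-Bernoulli model above:
# Hypothetical loop-based variant of the same model, for illustration only
@model function iid_estimation_loop(y)
    μ ~ Normal(mean = 0.0, precision = 0.1)
    τ ~ Gamma(shape = 1.0, rate = 1.0)
    for i in 1:length(y)
        y[i] ~ Normal(mean = μ, precision = τ)
    end
end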
In this model, we have two latent variables, μ and τ, and a set of observations y. Note that we used the broadcasting syntax, which is roughly equivalent to the manual for loop shown in the previous example (and sketched above). Let's try to run inference in this model, but first we need to create our observations:
# The `ExponentialFamily` package exports different parametrizations
# for the Normal distribution
using ExponentialFamily
hidden_μ = 3.1415
hidden_τ = 2.7182
distribution = NormalMeanPrecision(hidden_μ, hidden_τ)
rng = StableRNG(42)
n_observations = 1_000
dataset = rand(rng, distribution, n_observations)
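As before, a quick sanity check (a minimal sketch; exact values depend on the rng seed): since τ is a precision, the empirical standard deviation of the samples should be close to 1 / sqrt(hidden_τ):
# Empirical statistics of the generated dataset
println("Empirical mean: ", mean(dataset))
println("Empirical std:  ", std(dataset))
println("Expected std (1 / sqrt(hidden_τ)): ", 1 / sqrt(hidden_τ))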
And finally we run the inference procedure:
results = infer(
model = iid_estimation(),
data = (y = dataset, )
)
ERROR: Variables [ μ, τ ] have not been updated after an update event.
Therefore, make sure to initialize all required marginals and messages. See `initialization` keyword argument for the inference function.
See the official documentation for detailed information regarding the initialization.
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
Huh? We get an error saying that the inference could not update the latent variables. This happened because our model contains loops in its structure and therefore requires initialization. Read more about initialization in the corresponding section of the documentation.
We have two options here: either we initialize the messages and perform Loopy Belief Propagation in this model, or we break the loops with variational constraints and perform variational inference. In this tutorial, we choose the second option. For this, we need to specify factorization constraints with the @constraints macro and initial marginals for the variational iterations with the @initialization macro.
# Specify mean-field constraint over the joint variational posterior
constraints = @constraints begin
q(μ, τ) = q(μ)q(τ)
end
# Specify initial posteriors for variational iterations
initialization = @initialization begin
q(μ) = vague(NormalMeanPrecision)
q(τ) = vague(GammaShapeRate)
end
With this, we can use the constraints
and initialization
keyword arguments in the infer
function. We also specify the number of variational iterations with the iterations
keyword argument:
results = infer(
model = iid_estimation(),
data = (y = dataset, ),
constraints = constraints,
iterations = 100,
initialization = initialization
)
Inference results:
Posteriors | available for (μ, τ)
Nice! Now we have some results. For example, let's inspect the posterior results for μ.
results.posteriors[:μ]
100-element Vector{NormalWeightedMeanPrecision{Float64}}:
NormalWeightedMeanPrecision{Float64}(xi=3.1276134803795e-9, w=0.10000000100200404)
NormalWeightedMeanPrecision{Float64}(xi=155.36497103500977, w=49.87459713359878)
NormalWeightedMeanPrecision{Float64}(xi=7667.879879783272, w=2456.6745376352883)
NormalWeightedMeanPrecision{Float64}(xi=8056.651469322076, w=2581.226095926443)
NormalWeightedMeanPrecision{Float64}(xi=8057.059156601909, w=2581.3567075436076)
NormalWeightedMeanPrecision{Float64}(xi=8057.059563495237, w=2581.3568379008652)
NormalWeightedMeanPrecision{Float64}(xi=8057.059563901298, w=2581.356838030949)
NormalWeightedMeanPrecision{Float64}(xi=8057.0595639017065, w=2581.356838031086)
NormalWeightedMeanPrecision{Float64}(xi=8057.059563901715, w=2581.356838031092)
NormalWeightedMeanPrecision{Float64}(xi=8057.059563901707, w=2581.356838031086)
⋮
NormalWeightedMeanPrecision{Float64}(xi=8057.059563901707, w=2581.356838031086)
NormalWeightedMeanPrecision{Float64}(xi=8057.059563901715, w=2581.356838031092)
NormalWeightedMeanPrecision{Float64}(xi=8057.059563901707, w=2581.356838031086)
NormalWeightedMeanPrecision{Float64}(xi=8057.059563901715, w=2581.356838031092)
NormalWeightedMeanPrecision{Float64}(xi=8057.059563901707, w=2581.356838031086)
NormalWeightedMeanPrecision{Float64}(xi=8057.059563901715, w=2581.356838031092)
NormalWeightedMeanPrecision{Float64}(xi=8057.059563901707, w=2581.356838031086)
NormalWeightedMeanPrecision{Float64}(xi=8057.059563901715, w=2581.356838031092)
NormalWeightedMeanPrecision{Float64}(xi=8057.059563901707, w=2581.356838031086)
In contrast to the previous example, we now have an array of posteriors for μ, not just a single value. Each posterior in the collection corresponds to the intermediate variational update at each variational iteration. Let's visualize how our posterior over μ has been changing during the variational optimization:
@gif for (i, intermediate_posterior) in enumerate(results.posteriors[:μ])
rμ = range(0, 5, length = 1000)
plot(rμ, (x) -> pdf(intermediate_posterior, x), title="Posterior on iteration $(i)", fillalpha=0.3, fillrange = 0, label="P(μ|y)", c=3)
vline!([hidden_μ], label = "Real (hidden) μ")
end
It seems that the posterior has converged to a stable distribution pretty fast. We are going to verify the convergence in the next section. If, for example, we are not interested in intermediate updates, but only in the final posterior, we can use the returnvars option in the infer function with the KeepLast setting for μ:
results_keep_last = infer(
model = iid_estimation(),
data = (y = dataset, ),
constraints = constraints,
iterations = 100,
returnvars = (μ = KeepLast(), ),
initialization = initialization
)
Inference results:
Posteriors | available for (μ)
We can also verify that we got exactly the same result:
results_keep_last.posteriors[:μ] == last(results.posteriors[:μ])
true
Let's also verify that the posteriors are consistent with the real hidden values used in the dataset generation:
println("Real (hidden) μ was ", hidden_μ)
println("Inferred mean for μ is ", mean(last(results.posteriors[:μ])), " with standard deviation ", std(last(results.posteriors[:μ])))
println("Real (hidden) τ was ", hidden_τ)
println("Inferred mean for τ is ", mean(last(results.posteriors[:τ])), " with standard deviation ", std(last(results.posteriors[:τ])))
Real (hidden) μ was 3.1415
Inferred mean for μ is 3.121249819163777 with standard deviation 0.019682305930741596
Real (hidden) τ was 2.7182
Inferred mean for τ is 2.5812568380310656 with standard deviation 0.11532205069721162
rμ = range(2, 4, length = 1000)
pμ = plot(rμ, (x) -> pdf(last(results.posteriors[:μ]), x), title="Posterior for μ", fillalpha=0.3, fillrange = 0, label="P(μ|y)", c=3)
pμ = vline!(pμ, [ hidden_μ ], label = "Real (hidden) μ")
rτ = range(2, 4, length = 1000)
pτ = plot(rτ, (x) -> pdf(last(results.posteriors[:τ]), x), title="Posterior for τ", fillalpha=0.3, fillrange = 0, label="P(τ|y)", c=3)
pτ = vline!(pτ, [ hidden_τ ], label = "Real (hidden) τ")
plot(pμ, pτ)
Nice result! Our posteriors are pretty close to the actual values of the parameters used for dataset generation.
Convergence and Bethe Free Energy
Read also the Bethe Free Energy section.
In contrast to Loopy Belief Propagation, variational inference is expected to converge to a stable point during the optimization. In order to verify the convergence for this particular model, we can check the convergence of the Bethe Free Energy values. By default, the infer function does not compute the Bethe Free Energy values. In order to compute them, we must set the free_energy flag explicitly to true:
results = infer(
model = iid_estimation(),
data = (y = dataset, ),
constraints = constraints,
iterations = 100,
initialization = initialization,
free_energy = true
)
Inference results:
Posteriors | available for (μ, τ)
Free Energy: | Real[14763.3, 2437.5, 952.711, 952.108, 952.108, 952.108, 952.108, 952.108, 952.108, 952.108 … 952.108, 952.108, 952.108, 952.108, 952.108, 952.108, 952.108, 952.108, 952.108, 952.108]
Now, we can access the free_energy field of the results and verify whether the inference procedure has converged:
plot(results.free_energy, label = "Bethe Free Energy")
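To quantify how quickly the optimization settles, here is a minimal sketch using only base Julia on the returned vector of values:
bfe = results.free_energy
deltas = abs.(diff(bfe))       # change of the Bethe Free Energy between consecutive iterations
k = findfirst(<(1e-6), deltas) # first iteration with a negligible change (or `nothing`)
println("The Bethe Free Energy stops changing noticeably around iteration ", k === nothing ? "n/a" : k + 1)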
Well, it seems that 100 iterations is much more than needed for this simple problem, and we could use far fewer iterations to converge to a stable point. The animation above also suggested that the posterior for μ converged to a stable point pretty fast.
# Let's try to use only 5 iterations
results = infer(
model = iid_estimation(),
data = (y = dataset, ),
constraints = constraints,
iterations = 5,
initialization = initialization,
free_energy = true
)
Inference results:
Posteriors | available for (μ, τ)
Free Energy: | Real[14763.3, 2437.5, 952.711, 952.108, 952.108]
plot(results.free_energy, label = "Bethe Free Energy")
Callbacks
The infer function has its own lifecycle, consisting of multiple steps. A user is free to inject extra logic during the inference procedure, e.g. for debugging purposes. By supplying callbacks, users can run custom logic at specific moments during the inference procedure. Here are the available callbacks that can be used with static datasets:
before_model_creation()
Called before the model is created; does not accept any arguments.
after_model_creation(model::ProbabilisticModel)
Called right after the model has been created; accepts a single argument, the model.
before_inference(model::ProbabilisticModel)
Called before the inference procedure starts; accepts a single argument, the model.
after_inference(model::ProbabilisticModel)
Called after the inference procedure ends; accepts a single argument, the model.
before_iteration(model::ProbabilisticModel, iteration::Int)
Called before each iteration; accepts two arguments: the model and the current iteration number.
after_iteration(model::ProbabilisticModel, iteration::Int)
Called after each iteration; accepts two arguments: the model and the current iteration number.
before_data_update(model::ProbabilisticModel, data)
Called before each data update; accepts two arguments: the model and the updated data.
after_data_update(model::ProbabilisticModel, data)
Called after each data update; accepts two arguments: the model and the updated data.
on_marginal_update(model::ProbabilisticModel, name, update)
Called after each marginal update; accepts three arguments: the model, the name of the updated marginal, and the updated marginal itself.
Here is an example usage of the outlined callbacks:
function before_model_creation()
println("The model is about to be created")
end
function after_model_creation(model::ProbabilisticModel)
println("The model has been created")
println(" The number of factor nodes is: ", length(RxInfer.getfactornodes(model)))
println(" The number of latent states is: ", length(RxInfer.getrandomvars(model)))
println(" The number of data points is: ", length(RxInfer.getdatavars(model)))
println(" The number of constants is: ", length(RxInfer.getconstantvars(model)))
end
function before_inference(model::ProbabilisticModel)
println("The inference procedure is about to start")
end
function after_inference(model::ProbabilisticModel)
println("The inference procedure has ended")
end
function before_iteration(model::ProbabilisticModel, iteration::Int)
println("The iteration ", iteration, " is about to start")
end
function after_iteration(model::ProbabilisticModel, iteration::Int)
println("The iteration ", iteration, " has ended")
end
function before_data_update(model::ProbabilisticModel, data)
println("The data is about to be processed")
end
function after_data_update(model::ProbabilisticModel, data)
println("The data has been processed")
end
function on_marginal_update(model::ProbabilisticModel, name, update)
println("New marginal update for ", name, " ", update)
end
on_marginal_update (generic function with 1 method)
results = infer(
model = iid_estimation(),
data = (y = dataset, ),
constraints = constraints,
iterations = 5,
initialization = initialization,
free_energy = true,
callbacks = (
before_model_creation = before_model_creation,
after_model_creation = after_model_creation,
before_inference = before_inference,
after_inference = after_inference,
before_iteration = before_iteration,
after_iteration = after_iteration,
before_data_update = before_data_update,
after_data_update = after_data_update,
on_marginal_update = on_marginal_update
)
)
The model is about to be created
The model has been created
The number of factor nodes is: 1002
The number of latent states is: 2
The number of data points is: 1000
The number of constants is: 4
The inference procedure is about to start
The iteration 1 is about to start
The data is about to be processed
New marginal update for τ Marginal(GammaShapeRate{Float64}(a=501.0, b=5.000000000050656e14))
New marginal update for μ Marginal(NormalWeightedMeanPrecision{Float64}(xi=3.1276134803795e-9, w=0.10000000100200404))
The data has been processed
The iteration 1 has ended
The iteration 2 is about to start
The data is about to be processed
New marginal update for τ Marginal(GammaShapeRate{Float64}(a=501.0, b=10065.375288830106))
New marginal update for μ Marginal(NormalWeightedMeanPrecision{Float64}(xi=155.36497103500977, w=49.87459713359878))
The data has been processed
The iteration 2 has ended
The iteration 3 is about to start
The data is about to be processed
New marginal update for τ Marginal(GammaShapeRate{Float64}(a=501.0, b=203.94251927819624))
New marginal update for μ Marginal(NormalWeightedMeanPrecision{Float64}(xi=7667.879879783272, w=2456.6745376352883))
The data has been processed
The iteration 3 has ended
The iteration 4 is about to start
The data is about to be processed
New marginal update for τ Marginal(GammaShapeRate{Float64}(a=501.0, b=194.10132685523482))
New marginal update for μ Marginal(NormalWeightedMeanPrecision{Float64}(xi=8056.651469322076, w=2581.226095926443))
The data has been processed
The iteration 4 has ended
The iteration 5 is about to start
The data is about to be processed
New marginal update for τ Marginal(GammaShapeRate{Float64}(a=501.0, b=194.09150532601225))
New marginal update for μ Marginal(NormalWeightedMeanPrecision{Float64}(xi=8057.059156601909, w=2581.3567075436076))
The data has been processed
The iteration 5 has ended
The inference procedure has ended
Where to go next?
This guide covered some fundamental usages of the infer function in the context of inference with static datasets, but did not cover all of its available keyword arguments. Read more about the other keyword arguments in the Overview section or check out the Streaming Inference section. Also check out more complex examples.