Inference without explicit message update rules
RxInfer utilizes the ReactiveMP.jl package as its inference backend. Typically, running inference with ReactiveMP.jl requires users to define a factor node using the @node macro and to specify the corresponding message update rules with the @rule macro. Detailed instructions on this can be found in this section of the documentation. However, in this tutorial, we will explore an alternative approach that allows inference with default message update rules for custom factor nodes by defining only BayesBase.logpdf and BayesBase.insupport for a factor node, without needing explicit @rule specifications.
In the context of message-passing based Bayesian inference, custom message update rules enhance precision and efficiency. These rules leverage the specific mathematical properties of the model's distributions and relationships, leading to more accurate updates and faster convergence. By incorporating domain-specific knowledge, custom rules improve the robustness and reliability of the inference process, particularly in complex models where default rules may be inadequate or inefficient.
A simple prior-likelihood model
We start with a simple model with a hidden variable p and observations y. Later in the tutorial we explore more advanced use cases. In this particular case we assume that p follows a prior distribution and that the observations y are drawn from a likelihood distribution. The model can be defined as follows:
using RxInfer
@model function simple_model(y, prior, likelihood)
    p ~ prior
    y .~ likelihood(p)
end
Node specifications
Next, we define structures for both the prior and the likelihood. Let's start with the prior. Assume that the p parameter is best described by a Beta distribution. We can define it as follows:
The Distributions.jl package already provides a fully-featured implementation of the Beta and Bernoulli distributions, including functions like logpdf and support checks. The example below redefines the Beta distribution structure and related functions solely for illustrative purposes. In practice, you often won't need to define these distributions yourself, as many of them have already been included in Distributions.jl.
using Distributions, BayesBase
struct BetaDistribution{A, B} <: ContinuousUnivariateDistribution
    a::A
    b::B
end
# Reuse `logpdf` from `Distributions.jl` for illustrative purposes
BayesBase.logpdf(d::BetaDistribution, x) = logpdf(Beta(d.a, d.b), x)
BayesBase.insupport(d::BetaDistribution, x::Real) = 0 <= x <= 1
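As a quick sanity check (illustrative only), we can verify that the wrapper agrees with the reference Beta implementation from Distributions.jl:
# Illustrative check: the wrapper should agree with `Beta` from `Distributions.jl`
d = BetaDistribution(2.0, 3.0)
BayesBase.logpdf(d, 0.5) ≈ logpdf(Beta(2.0, 3.0), 0.5) # true
BayesBase.insupport(d, 0.5) # true
BayesBase.insupport(d, 1.5) # false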
Next, we assume that y is a discrete dataset of true and false values. The logical choice for the likelihood distribution is the Bernoulli distribution.
struct BernoulliDistribution{P} <: DiscreteUnivariateDistribution
    p::P
end
# Reuse `logpdf` from `Distributions.jl` for illustrative purposes
BayesBase.logpdf(d::BernoulliDistribution, x) = logpdf(Bernoulli(d.p), x)
BayesBase.insupport(d::BernoulliDistribution, x) = x === true || x === false
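Again, a quick illustrative check against the reference Bernoulli implementation from Distributions.jl:
# Illustrative check: the wrapper should agree with `Bernoulli` from `Distributions.jl`
d = BernoulliDistribution(0.25)
BayesBase.logpdf(d, true) ≈ logpdf(Bernoulli(0.25), true) # true
BayesBase.insupport(d, false) # true
BayesBase.insupport(d, 0.5) # false, only `true` and `false` are supported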
The next step is to register these structures as valid factor nodes:
@node BetaDistribution Stochastic [out, a, b]
@node BernoulliDistribution Stochastic [out, p]
When specifying a node for our custom distributions, we must follow a specific edge ordering. The first edge is always out, which represents a sample in the logpdf function. All remaining edges must match the parameters of the distribution in the exact same order. For example, for the BetaDistribution, the node function is defined as (out, a, b) -> logpdf(BetaDistribution(a, b), out). This ensures that the node specification and the logpdf function correctly map the distribution parameters to the sample output.
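For illustration, here is a sketch of such a node function (the actual ReactiveMP internals may differ):
# A sketch of the node function implied by `@node BetaDistribution Stochastic [out, a, b]`;
# the actual `ReactiveMP` internals may differ
node_fn = (out, a, b) -> BayesBase.logpdf(BetaDistribution(a, b), out)
node_fn(0.5, 2.0, 3.0) ≈ logpdf(Beta(2.0, 3.0), 0.5) # true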
Although Beta is a conjugate prior for the parameter of the Bernoulli distribution, ReactiveMP and RxInfer are unaware of this and cannot exploit this information. To utilize conjugacy, refer to the custom node creation section of the documentation.
Generating a synthetic dataset
Previously, we assumed that our dataset consists of discrete values: true and false. We can generate a synthetic dataset with these values as follows:
using StableRNGs, Plots
hidden_p = 1 / 3.1415 # a value between `0` and `1`
ndatapoints = 1_000 # number of observations
dataset = rand(StableRNG(42), Bernoulli(hidden_p), ndatapoints)
bar(["true", "false"], [ count(==(true), dataset), count(==(false), dataset) ], label = "dataset")
Inference with a rule fallback
Now, we can run inference with RxInfer. Since explicit rules for our nodes have not been defined, we can instruct the ReactiveMP backend to use fallback message update rules. Refer to the ReactiveMP documentation for available fallbacks. In this example, we will use the NodeFunctionRuleFallback structure, which uses the logpdf of the stochastic node to approximate messages.
NodeFunctionRuleFallback employs a simple approximation for outbound messages, which may significantly degrade inference accuracy. Whenever possible, it is recommended to define proper message update rules.
To complete the inference setup, we must define an approximation method for posteriors using the @constraints macro. We will utilize the ExponentialFamilyProjection library to project an arbitrary function onto a member of the exponential family. More information on ExponentialFamilyProjection can be found in the Non-conjugate Inference section and in its official documentation.
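For intuition, here is a minimal standalone sketch of what such a projection does, assuming the project_to function described in the ExponentialFamilyProjection documentation (illustrative only; the exact API may differ between versions):
# A standalone sketch (assuming the `project_to` API from the
# `ExponentialFamilyProjection` documentation): project an arbitrary
# log-density onto the `Beta` family
using ExponentialFamilyProjection, Distributions
targetlogpdf = (x) -> logpdf(Beta(4.0, 2.0), x)
approximation = project_to(ProjectedTo(Beta), targetlogpdf)
# `approximation` should be a `Beta` distribution close to `Beta(4.0, 2.0)`
Within RxInfer, the same projection is requested declaratively through constraints: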
using ExponentialFamilyProjection
@constraints function projection_constraints()
    # Use `Beta` from `Distributions.jl` as it is compatible with the `ExponentialFamilyProjection` library
    q(p) :: ProjectedTo(Beta)
end
projection_constraints (generic function with 1 method)
With all components ready, we can proceed with the inference procedure:
result = infer(
    model = simple_model(prior = BetaDistribution(1, 1), likelihood = BernoulliDistribution),
    data = (y = dataset, ),
    constraints = projection_constraints(),
    options = (
        rulefallback = NodeFunctionRuleFallback(),
    )
)
Inference results:
Posteriors | available for (p)
For rulefallback = NodeFunctionRuleFallback() to function correctly, the node must be defined as Stochastic and the underlying object must be a subtype of Distribution from Distributions.jl.
Result analysis
We can perform a simple analysis and compare the inferred value with the hidden value used to generate the actual dataset:
using Plots, StatsPlots
plot(result.posteriors[:p], label = "posterior of p", fill = 0, fillalpha = 0.2)
vline!([ hidden_p ], label = "hidden p")
As shown, the estimated posterior is quite close to the actual hidden value of p that was used to generate the dataset.
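We can also compare the posterior mean with the hidden value numerically:
# Compare the posterior mean of `p` with the hidden value
println("hidden p = ", hidden_p)
println("posterior mean of p = ", mean(result.posteriors[:p]))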
Fusing deterministic transformations with stochastic nodes
One of the limitations of the NodeFunctionRuleFallback implementation is that it does not support Deterministic or Delta nodes. However, it is possible to combine a deterministic transformation with a stochastic node, such as a Gaussian. For instance, consider a dataset drawn from the Normal distribution, where the mean parameter has been transformed by a known function, and the true hidden variable is h.
using ExponentialFamily, Distributions, Plots, StableRNGs
hidden_h = 2.3
hidden_t = 0.5
known_transformation(h) = exp(h)
hidden_mean = known_transformation(hidden_h)
ndatapoints = 50
dataset = rand(StableRNG(42), NormalMeanPrecision(hidden_mean, hidden_t), ndatapoints)
histogram(dataset; normalize = :pdf)
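As a quick sanity check of the synthetic data, the sample mean should be close to exp(2.3) ≈ 9.97 and the sample variance close to 1 / hidden_t = 2:
# The sample mean should be close to `exp(hidden_h)` and the
# sample variance close to `1 / hidden_t`
using Statistics
println("sample mean = ", mean(dataset), " (expected ≈ ", hidden_mean, ")")
println("sample variance = ", var(dataset), " (expected ≈ ", 1 / hidden_t, ")")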
The model can be defined as follows:
using RxInfer
@model function mymodel(y, prior_h, prior_t)
    h ~ prior_h
    t ~ prior_t
    y .~ Normal(mean = known_transformation(h), precision = t)
end
Inference in this model is challenging because the known_transformation function is explicitly used as a factor node, requiring special approximation rules. These rules are covered in a separate section. Here, we demonstrate a different approach that modifies the model structure to run inference without needing to approximate messages around a deterministic node.
First, we define our custom transformed Normal distribution:
using BayesBase
struct TransformedNormalDistribution{H, T} <: ContinuousUnivariateDistribution
    h::H
    t::T
end
# We integrate the `known_transformation` within the `logpdf` function
# This way, it won't be an explicit factor node but hidden within the `logpdf` of another node
BayesBase.logpdf(dist::TransformedNormalDistribution, x) = logpdf(NormalMeanPrecision(known_transformation(dist.h), dist.t), x)
BayesBase.insupport(dist::TransformedNormalDistribution, x) = true
@node TransformedNormalDistribution Stochastic [out, h, t]
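A quick illustrative check confirms that the fused node evaluates the Normal log-density at the transformed mean:
# Illustrative check: the fused node evaluates the `Normal` log-density
# at the transformed mean `exp(h)`
d = TransformedNormalDistribution(2.3, 0.5)
BayesBase.logpdf(d, 10.0) ≈ logpdf(NormalMeanPrecision(exp(2.3), 0.5), 10.0) # true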
Next, we tweak the model structure:
@model function mymodel(y, prior_h, prior_t)
    h ~ prior_h
    t ~ prior_t
    y .~ TransformedNormalDistribution(h, t)
end
We use the following priors, constraints, and initialization:
using ExponentialFamilyProjection
prior_h = LogNormal(0, 1)
prior_t = Gamma(1, 1)
constraints = @constraints begin
    q(h, t) = q(h)q(t)
    q(h) :: ProjectedTo(LogNormal)
    q(t) :: ProjectedTo(Gamma)
end

initialization = @initialization begin
    q(t) = Gamma(1, 1)
end
Initial state:
q(t) = Gamma{Float64}(α=1.0, θ=1.0)
The ProjectedTo specification has a parameters field that allows for different hyperparameters, which may improve accuracy or convergence speed. Refer to the ExponentialFamilyProjection documentation for more information.
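For example, one might tune the projection as follows (a sketch only; the ProjectionParameters keyword names used here are assumptions, so consult the ExponentialFamilyProjection documentation for the exact API):
# A sketch of tuning the projection; the `ProjectionParameters` keywords
# below are assumptions, see the `ExponentialFamilyProjection` documentation
custom_parameters = ProjectionParameters(niterations = 300, tolerance = 1e-6)
ProjectedTo(LogNormal; parameters = custom_parameters)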
Inference with a rule fallback
Now we are ready to run the inference procedure:
result = infer(
    model = mymodel(prior_h = prior_h, prior_t = prior_t),
    data = (y = dataset,),
    constraints = constraints,
    initialization = initialization,
    iterations = 50,
    options = (
        rulefallback = NodeFunctionRuleFallback(),
    )
)
Inference results:
Posteriors | available for (h, t)
Result analysis
Finally, let's plot the resulting posteriors for each VMP iteration:
@gif for (i, q) in enumerate(zip(result.posteriors[:h], result.posteriors[:t]))
    p1 = plot(1:0.01:3, q[1], label = "q(h) iteration $i", fill = 0, fillalpha = 0.2)
    p1 = vline!([hidden_h], label = "hidden h")
    p2 = plot(0:0.01:1, q[2], label = "q(t) iteration $i", fill = 0, fillalpha = 0.2)
    p2 = vline!([hidden_t], label = "hidden t")
    plot(p1, p2)
end fps = 15
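Beyond the visual comparison, a quick numeric check of the final posteriors:
# Compare the final posterior means with the hidden values
println("hidden h = ", hidden_h, ", posterior mean of h = ", mean(last(result.posteriors[:h])))
println("hidden t = ", hidden_t, ", posterior mean of t = ", mean(last(result.posteriors[:t])))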
We can see that the inference procedure is able to recover the actual value of the hidden h that was used to generate the synthetic dataset. In conclusion, this example demonstrates that by integrating deterministic transformations within the logpdf function of a stochastic node, we can bypass the limitations of NodeFunctionRuleFallback in handling deterministic nodes.