Bethe Free Energy implementation in RxInfer
The following text introduces the Bethe Free Energy. We start by defining a factorized model and move from the Variational Free Energy to a definition of the Bethe Free Energy.
Factorized model
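Before we can define a model, we must identify all variables that are relevant to the problem at hand. We distinguish between variables that can be directly observed,
$$y = (y_1, \dots, y_j, \dots, y_m)\,,$$
and variables that cannot be observed directly, also known as latent variables,
$$x = (x_1, \dots, x_i, \dots, x_n)\,.$$
We then define a model that factorizes over constituent smaller factors (functions), as
$$f(y, x) = \prod_a f_a(y_a, x_a)\,,$$
where $y_a$ and $x_a$ denote the observed and latent variables on which factor $f_a$ depends.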
Individual factors may represent stochastic functions, such as conditional or prior distributions, but also potential functions or deterministic relationships. A factor may depend on multiple observed and/or latent variables (or none).
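The factorization can be made concrete with a small discrete example. The following Python sketch is hypothetical: the variables, factor tables, and the helper `model` are invented for illustration and are not part of RxInfer. It represents each factor by its scope (the variables it depends on) and a function, substitutes an observation, and evaluates $f(y, x)$ as the product of the factors.

```python
from itertools import product

# Hypothetical toy model: binary latent variables x1, x2 and a binary observed y.
# Each factor is (scope, function): the names of its variables and its value.
FACTORS = [
    (("x1",), lambda x1: 0.6 if x1 == 1 else 0.4),            # prior on x1
    (("x1", "x2"), lambda x1, x2: 0.9 if x1 == x2 else 0.1),  # conditional of x2 given x1
    (("x2", "y"), lambda x2, y: 0.8 if x2 == y else 0.2),     # likelihood of y given x2
]

def model(assignment):
    """Evaluate f(y, x) = prod_a f_a(y_a, x_a) for a full assignment of variables."""
    value = 1.0
    for scope, fn in FACTORS:
        value *= fn(*(assignment[name] for name in scope))
    return value

# Substitute the observation y = 1 and tabulate f(y=1, x) over all latent states.
y_hat = 1
table = {(x1, x2): model({"x1": x1, "x2": x2, "y": y_hat})
         for x1, x2 in product((0, 1), repeat=2)}
print(table)
```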
Variational Free Energy
The Variational Free Energy (VFE) then defines a functional objective that includes the model and a variational distribution $q(x)$ over the latent variables,
$$F[q](\hat{y}) = \mathbb{E}_{q(x)}\left[\log \frac{q(x)}{f(y = \hat{y}, x)}\right]\,.$$
A functional maps a function to a scalar. Here, the VFE is a function of the variational distribution (as indicated by square brackets) and returns a number.
The VFE is also a function of the observed data $\hat{y}$, as indicated by round brackets, where the observed values are substituted for $y$ in the factorized model.
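For a discrete toy model, the VFE can be computed directly by summing over latent states. In the hypothetical Python sketch below, the model numbers and helper names are invented for illustration. The same $q(x)$ yields different VFE values for two observations $\hat{y}$, which reflects the dependence on the data indicated by the round brackets.

```python
import math
from itertools import product

LATENT_STATES = list(product((0, 1), repeat=2))  # all values of x = (x1, x2)

def f(y, x1, x2):
    """Hypothetical toy model f(y, x) = f_a(x1) f_b(x1, x2) f_c(x2, y)."""
    prior = 0.6 if x1 == 1 else 0.4
    coupling = 0.9 if x1 == x2 else 0.1
    likelihood = 0.8 if x2 == y else 0.2
    return prior * coupling * likelihood

def vfe(q, y_hat):
    """F[q](y_hat) = E_q[log q(x) - log f(y=y_hat, x)] for a discrete q(x)."""
    return sum(q[x] * (math.log(q[x]) - math.log(f(y_hat, *x)))
               for x in LATENT_STATES if q[x] > 0)

# One variational distribution q(x), evaluated against two different observations.
q = {(0, 0): 0.1, (0, 1): 0.2, (1, 0): 0.3, (1, 1): 0.4}
print(vfe(q, 0), vfe(q, 1))
```

Only the likelihood factor depends on $y$ here, so the two values differ by $0.2 \log 4$. That is the $q$-weighted log-likelihood ratio.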
Variational inference
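The goal of variational inference is to find a variational distribution that minimizes the VFE,
$$q^*(x) = \arg\min_{q \in \mathcal{Q}} F[q](\hat{y})\,,$$
where $\mathcal{Q}$ denotes the family of admissible variational distributions. This objective can be optimized (under specific constraints) with the use of variational calculus. Constraints are implied by the domain over which the variational distribution is optimized, and can be enforced by Lagrange multipliers.
For the VFE, constraints enforce, e.g., the normalization of the variational distribution. The variational distribution that minimizes the VFE then approximates the true (but often unobtainable) posterior distribution. To see why, note that if the model $f$ is a normalized joint distribution, the VFE decomposes as $F[q](\hat{y}) = \mathrm{KL}\left[q(x) \,\|\, p(x \mid \hat{y})\right] - \log p(\hat{y})$. The second term does not depend on $q$, so minimizing the VFE minimizes the KL divergence to the posterior.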
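When the factors form a normalized joint distribution, $F[q](\hat{y})$ equals $\mathrm{KL}[q \,\|\, p(x \mid \hat{y})] - \log p(\hat{y})$, so the VFE is minimized by the exact posterior. The following hypothetical Python sketch checks this numerically on a small discrete model; the model and helper names are invented for illustration. It verifies the identity for an arbitrary $q$ and confirms that the VFE equals $-\log p(\hat{y})$ at the posterior.

```python
import math
from itertools import product

LATENT_STATES = list(product((0, 1), repeat=2))

def f(y, x1, x2):
    """Hypothetical normalized joint p(y, x1, x2) written as a product of factors."""
    prior = 0.6 if x1 == 1 else 0.4
    coupling = 0.9 if x1 == x2 else 0.1
    likelihood = 0.8 if x2 == y else 0.2
    return prior * coupling * likelihood

def vfe(q, y_hat):
    """F[q](y_hat) = E_q[log q(x) - log f(y=y_hat, x)]."""
    return sum(q[x] * (math.log(q[x]) - math.log(f(y_hat, *x)))
               for x in LATENT_STATES if q[x] > 0)

def kl(q, p):
    """KL[q || p] for discrete distributions over LATENT_STATES."""
    return sum(q[x] * math.log(q[x] / p[x]) for x in LATENT_STATES if q[x] > 0)

y_hat = 1
evidence = sum(f(y_hat, *x) for x in LATENT_STATES)              # p(y_hat)
posterior = {x: f(y_hat, *x) / evidence for x in LATENT_STATES}  # p(x | y_hat)

q = {(0, 0): 0.1, (0, 1): 0.2, (1, 0): 0.3, (1, 1): 0.4}  # an arbitrary q(x)
print(vfe(q, y_hat), kl(q, posterior) - math.log(evidence))
print(vfe(posterior, y_hat), -math.log(evidence))        # KL term vanishes here
```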
Bethe approximation
Optimization of the VFE is still a daunting task, because the variational distribution is a joint distribution over possibly many latent variables. Instead of optimizing the joint variational distribution directly, a factorized variational distribution is often chosen. The factorized variational distribution is then optimized with respect to its constituent factors.
A popular choice of factorization is the Bethe approximation, which is constructed from the factorization of the model itself,
$$q(x) \triangleq \frac{\prod_a q_a(x_a)}{\prod_i q_i(x_i)^{d_i - 1}}\,.$$
The numerator iterates over the factors in the model, and carves the joint variational distribution into smaller variational distributions that are more manageable to optimize.
The denominator of the Bethe approximation iterates over all individual latent variables and discounts them. The discounting exponent $d_i - 1$ is the degree of variable $x_i$ minus one, where the degree $d_i$ counts the number of factors in which $x_i$ appears. Each variable appears in $d_i$ factor distributions in the numerator, and this discount ensures that it is effectively counted only once.
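The construction can be checked numerically on a tree-structured model. The following Python sketch uses a hypothetical three-variable chain; all names are invented for illustration. It computes the degrees $d_i$ from the factor scopes, plugs the exact posterior marginals into the Bethe formula, and compares the result with the exact joint posterior.

```python
import math
from itertools import product

VARS = ("x1", "x2", "x3")
# Hypothetical chain (tree-structured) model over three binary latent variables.
# Each factor is (scope, function); observations are assumed already substituted.
FACTORS = [
    (("x1",), lambda a: 0.7 if a == 0 else 0.3),
    (("x1", "x2"), lambda a, b: 2.0 if a == b else 0.5),
    (("x2", "x3"), lambda a, b: 1.5 if a == b else 1.0),
    (("x3",), lambda a: 0.2 if a == 0 else 0.8),
]
STATES = [dict(zip(VARS, vals)) for vals in product((0, 1), repeat=len(VARS))]

def unnormalized(s):
    """Product of all factors f_a(x_a) for the joint assignment s."""
    return math.prod(fn(*(s[v] for v in scope)) for scope, fn in FACTORS)

Z = sum(unnormalized(s) for s in STATES)
joint = [unnormalized(s) / Z for s in STATES]  # exact posterior p(x)

def marginal(scope, values):
    """Exact marginal of the posterior over the variables in `scope`."""
    return sum(p for s, p in zip(STATES, joint)
               if tuple(s[v] for v in scope) == values)

# Degree d_i: number of factors whose scope contains variable i.
degree = {v: sum(v in scope for scope, _ in FACTORS) for v in VARS}

def bethe(s):
    """q(x) = prod_a q_a(x_a) / prod_i q_i(x_i)^(d_i - 1), built from exact marginals."""
    num = math.prod(marginal(scope, tuple(s[v] for v in scope)) for scope, _ in FACTORS)
    den = math.prod(marginal((v,), (s[v],)) ** (degree[v] - 1) for v in VARS)
    return num / den

max_error = max(abs(bethe(s) - p) for s, p in zip(STATES, joint))
print(degree, max_error)
```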
The Bethe approximation thus constrains the variational distribution to a factorized form. If the graphical representation of the model is a tree, the true posterior factorizes in exactly this form and the approximation is exact. However, the true posterior distribution might not factorize in this way, e.g. if the graphical representation of the model contains cycles. In these cases the Bethe approximation trades the exact solution for computational tractability.
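On a model whose graph contains a cycle, plugging the exact marginals into the Bethe formula no longer reproduces the joint posterior. The Python sketch below applies the same hypothetical construction to three pairwise factors arranged in a loop, each favouring equal neighbours. The reconstructed values differ from the exact joint and do not even sum to one.

```python
import math
from itertools import product

VARS = ("x1", "x2", "x3")

def agree(a, b):
    """Hypothetical pairwise factor that favours equal neighbouring values."""
    return math.e if a == b else 1.0

# The three pairwise factors form the cycle x1 - x2 - x3 - x1.
FACTORS = [(("x1", "x2"), agree), (("x2", "x3"), agree), (("x1", "x3"), agree)]
STATES = [dict(zip(VARS, vals)) for vals in product((0, 1), repeat=len(VARS))]

def unnormalized(s):
    return math.prod(fn(*(s[v] for v in scope)) for scope, fn in FACTORS)

Z = sum(unnormalized(s) for s in STATES)
joint = [unnormalized(s) / Z for s in STATES]  # exact joint posterior

def marginal(scope, values):
    return sum(p for s, p in zip(STATES, joint)
               if tuple(s[v] for v in scope) == values)

degree = {v: sum(v in scope for scope, _ in FACTORS) for v in VARS}

def bethe(s):
    """Bethe reconstruction from the exact factor and variable marginals."""
    num = math.prod(marginal(scope, tuple(s[v] for v in scope)) for scope, _ in FACTORS)
    den = math.prod(marginal((v,), (s[v],)) ** (degree[v] - 1) for v in VARS)
    return num / den

total = sum(bethe(s) for s in STATES)  # not 1 on a loopy graph
max_error = max(abs(bethe(s) - p) for s, p in zip(STATES, joint))
print(total, max_error)
```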
Bethe Free Energy
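The Bethe Free Energy (BFE) substitutes the Bethe approximation in the VFE, which then fragments over factors and variables, as
$$F_B[q](\hat{y}) = \sum_a \mathbb{E}_{q_a(x_a)}\left[-\log f_a(y_a = \hat{y}_a, x_a)\right] - \sum_a \mathbb{H}[q_a] + \sum_i (d_i - 1)\, \mathbb{H}[q_i]\,,$$
where $\mathbb{H}[q] = -\mathbb{E}_q[\log q]$ denotes the entropy. The expectations reduce to local ones because each $q_a$ is assumed to marginalize to the corresponding $q_i$. The first term of the BFE specifies an average energy, which internalizes the factors of the model. The last two terms specify entropies: the entropies of the factor distributions $q_a$, and the entropies of the variable distributions $q_i$ weighted by $d_i - 1$.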
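The BFE can be evaluated from local distributions alone. The following Python sketch uses a hypothetical chain model, and its helpers are invented for illustration. It implements the three terms of $F_B$ for discrete tables. With the exact marginals of a tree-structured model as inputs, it recovers $-\log Z$, the minimum of the VFE. With uniform (but locally consistent) distributions, it returns a larger value.

```python
import math
from itertools import product

VARS = ("x1", "x2", "x3")
# Hypothetical chain model; observations are assumed already substituted.
FACTORS = [
    (("x1",), lambda a: 0.7 if a == 0 else 0.3),
    (("x1", "x2"), lambda a, b: 2.0 if a == b else 0.5),
    (("x2", "x3"), lambda a, b: 1.5 if a == b else 1.0),
    (("x3",), lambda a: 0.2 if a == 0 else 0.8),
]
STATES = [dict(zip(VARS, vals)) for vals in product((0, 1), repeat=len(VARS))]

def unnormalized(s):
    return math.prod(fn(*(s[v] for v in scope)) for scope, fn in FACTORS)

Z = sum(unnormalized(s) for s in STATES)
joint = [unnormalized(s) / Z for s in STATES]

def local_states(scope):
    return list(product((0, 1), repeat=len(scope)))

def exact_marginal(scope):
    """Table {values: probability} of the exact posterior marginal over `scope`."""
    table = {vals: 0.0 for vals in local_states(scope)}
    for s, p in zip(STATES, joint):
        table[tuple(s[v] for v in scope)] += p
    return table

def entropy(table):
    return -sum(p * math.log(p) for p in table.values() if p > 0)

degree = {v: sum(v in scope for scope, _ in FACTORS) for v in VARS}

def bethe_free_energy(factor_beliefs, var_beliefs):
    """F_B = sum_a E_{q_a}[-log f_a] - sum_a H[q_a] + sum_i (d_i - 1) H[q_i]."""
    energy = sum(-q * math.log(fn(*vals))
                 for (scope, fn), qa in zip(FACTORS, factor_beliefs)
                 for vals, q in qa.items())
    factor_entropy = sum(entropy(qa) for qa in factor_beliefs)
    var_entropy = sum((degree[v] - 1) * entropy(var_beliefs[v]) for v in VARS)
    return energy - factor_entropy + var_entropy

exact_fb = bethe_free_energy([exact_marginal(scope) for scope, _ in FACTORS],
                             {v: exact_marginal((v,)) for v in VARS})
uniform_fb = bethe_free_energy(
    [{vals: 0.5 ** len(scope) for vals in local_states(scope)} for scope, _ in FACTORS],
    {v: {(0,): 0.5, (1,): 0.5} for v in VARS})
print(exact_fb, -math.log(Z), uniform_fb)
```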
Crucially, the BFE can be iteratively optimized with respect to each individual variational distribution ($q_a$ or $q_i$) in turn. Optimization of the BFE is thus more manageable than direct optimization of the VFE.
For iterative optimization of the BFE, the variational distributions must first be initialized. The `infer` function uses the `initialization` keyword argument to initialize the variational distributions of the BFE.
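The iterative scheme and the role of a starting point can be sketched in Python. As an assumption, this sketch swaps in loopy sum-product belief propagation on a hypothetical pairwise model (a three-variable cycle). Fixed points of this algorithm are stationary points of the BFE (Yedidia et al., 2005), but RxInfer's own update rules and schedule are not reproduced here. The argument `init` is a hypothetical stand-in that sets the starting messages. It is only loosely analogous to what the `initialization` keyword of `infer` configures.

```python
import math

# Hypothetical pairwise model on a 3-cycle of binary variables 0, 1, 2:
# unary factors PHI[i] and pairwise factors psi(a, b) = exp(J) if a == b else 1.
J = 0.5
PHI = {0: (0.7, 0.3), 1: (0.4, 0.6), 2: (0.5, 0.5)}
EDGES = [(0, 1), (1, 2), (0, 2)]
NEIGHBORS = {i: [j for e in EDGES if i in e for j in e if j != i] for i in PHI}

def psi(a, b):
    return math.exp(J) if a == b else 1.0

def normalize(v):
    total = sum(v)
    return tuple(x / total for x in v)

def run_bp(init, iterations=50):
    """Loopy sum-product belief propagation starting from the message
    initialization `init`; m[(i, j)] is the message from variable i to j."""
    m = {(i, j): init for i in PHI for j in NEIGHBORS[i]}
    for _ in range(iterations):
        new = {}
        for i, j in m:
            # Messages into i from all neighbours except the recipient j.
            incoming = [m[(k, i)] for k in NEIGHBORS[i] if k != j]
            new[(i, j)] = normalize([
                sum(PHI[i][a] * psi(a, b) * math.prod(msg[a] for msg in incoming)
                    for a in (0, 1))
                for b in (0, 1)
            ])
        m = new  # parallel (synchronous) schedule
    # Variable beliefs: proportional to the unary factor times all incoming messages.
    return {i: normalize([PHI[i][a] * math.prod(m[(k, i)][a] for k in NEIGHBORS[i])
                          for a in (0, 1)])
            for i in PHI}

beliefs_uniform = run_bp((0.5, 0.5))
beliefs_skewed = run_bp((0.9, 0.1))
print(beliefs_uniform)
print(beliefs_skewed)
```

With the weak coupling chosen here (`J = 0.5`), both initializations reach the same beliefs. In strongly coupled loopy models, different initializations can lead to different fixed points or to no convergence at all.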
For disambiguation, note that the initialization of the variational distribution is a different design consideration than the choice of priors. A prior specifies a factor in the model definition, while initialization concerns factors in the variational distribution.
Further reading
- Pearl (1986) on the original foundations of Bayesian networks and belief propagation;
- Yedidia et al. (2005) on the connections between belief propagation and regional approximations to the VFE;
- Dauwels (2007) on variational message passing on Forney-style factor graphs (FFGs);
- Senoz et al. (2021) on constraint manipulation and message passing on FFGs.
Implementation details
`RxInfer` implements Bethe Free Energy optimization implicitly via the message passing technique. That means that the inference engine does not compute BFE values unless explicitly requested. The `infer` function has a `free_energy` flag, which indicates whether BFE values must be computed explicitly or not. Note, however, that due to the reactive nature of the message passing implementation in `RxInfer`, the computed BFE value may not represent its actual state. This may happen when updates for certain posteriors arrive more often than updates for other posteriors, which usually occurs in models with loops in their structure. To account for this, instead of checking whether the BFE value decreases at every iteration, it is advised to check whether it converges.
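A convergence check of this kind could look like the following hypothetical Python helper, which inspects a trace of free energy values. The trace in the usage is synthetic rather than RxInfer output, and it deliberately rises once. A test for monotonic decrease would therefore reject it, while the convergence test accepts it. The tolerance and window size are arbitrary choices.

```python
def has_converged(values, tol=1e-6, window=3):
    """Hypothetical helper: True if each of the last `window` successive
    changes in a free energy trace is smaller than `tol` in absolute value.
    It deliberately does not require the trace to decrease monotonically."""
    if len(values) < window + 1:
        return False
    tail = values[-(window + 1):]
    return all(abs(b - a) < tol for a, b in zip(tail, tail[1:]))

# Synthetic trace for illustration: it rises once before settling.
trace = [10.0, 7.5, 7.9, 7.4, 7.41, 7.4100001, 7.41000005, 7.41000004]
is_monotone = all(b <= a for a, b in zip(trace, trace[1:]))
print(is_monotone, has_converged(trace), has_converged(trace[:5]))
```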
RxInfer.BetheFreeEnergy (Type)
`BetheFreeEnergy(skip_strategy, scheduler)`
Implements a reactive stream for Bethe Free Energy values. Must be used in combination with the `score` function of `ReactiveMP.jl`.
Arguments
- `::Type{T}`: a type of the counting real number, e.g. `Float64`. Set to `Real` by default, since otherwise the inference procedure is not automatically differentiable.
- `skip_strategy`: a strategy that defines which posterior marginals to skip, e.g. `SkipInitial()`.
- `scheduler`: a scheduler for the underlying stream, e.g. `AsapScheduler()`.
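The reactive behaviour described here, and the earlier caveat about BFE values that may not represent the actual state, can be illustrated with a toy Python class. It is a hypothetical analogue, not RxInfer's implementation. It assumes that the objective keeps the latest score of every local term and re-emits their sum whenever any term changes. It models a skip strategy as simply ignoring scores flagged as initial. The names `ToyBetheFreeEnergyStream`, `update`, and `is_initial` are invented for this sketch.

```python
class ToyBetheFreeEnergyStream:
    """Toy combine-latest objective (hypothetical, not RxInfer's implementation):
    remembers the latest score of each local term and emits their sum whenever
    a term updates, once every term has produced at least one score."""

    def __init__(self, term_names, skip_initial=True):
        self.latest = {name: None for name in term_names}
        self.skip_initial = skip_initial  # stand-in for a marginal skip strategy
        self.emitted = []

    def update(self, name, score, is_initial=False):
        if self.skip_initial and is_initial:
            return  # ignore scores computed from initial (not yet updated) marginals
        self.latest[name] = score
        if all(value is not None for value in self.latest.values()):
            self.emitted.append(sum(self.latest.values()))

stream = ToyBetheFreeEnergyStream(["f_a", "f_b", "x1"])
stream.update("f_a", 1.0, is_initial=True)  # skipped
stream.update("f_a", 1.2)
stream.update("f_b", 0.5)
stream.update("x1", -0.3)  # all terms known: first emission
stream.update("f_a", 1.1)  # re-emits using the older f_b and x1 scores
print(stream.emitted)
```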
RxInfer.BetheFreeEnergyDefaultMarginalSkipStrategy (Constant)
Default marginal skip strategy for the Bethe Free Energy objective.
RxInfer.BetheFreeEnergyDefaultScheduler (Constant)
Default scheduler for the Bethe Free Energy objective.
RxInfer.ReactiveMPFreeEnergyPlugin (Type)
A plugin for the GraphPPL graph engine that adds the Bethe Free Energy objective computation to the nodes of the model.
Extra diagnostic checks
`RxInfer` verifies intermediate computations of the BFE on each iteration. By default, `RxInfer` throws an exception if local factor node or variable node computations result in either `NaN` or `Inf`. Note that the verification happens only if the computation of the BFE has been requested explicitly.
RxInfer.apply_diagnostic_check (Function)
`apply_diagnostic_check(check, stream)`
This function applies a `check` to the `stream`. Does nothing if `check` is of type `Nothing`.
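A toy Python analogue makes the documented contract concrete. A `None` check leaves the stream untouched, while any other check is applied to each value as it flows through. The function below only mirrors the name of `RxInfer.apply_diagnostic_check`. Its body, the helper `check_nan`, and the use of a lazy `map` as the stream are assumptions of this sketch.

```python
import math

def apply_diagnostic_check(check, stream):
    """Toy Python analogue (not RxInfer's code): return `stream` unchanged when
    `check` is None, otherwise a lazy stream passing every value through `check`."""
    if check is None:
        return stream
    return map(check, stream)

def check_nan(value):
    """Hypothetical check: raise on NaN, otherwise pass the value through."""
    if math.isnan(value):
        raise ValueError("objective term returned NaN")
    return value

scores = [0.5, float("nan"), 2.0]
unchecked = apply_diagnostic_check(None, scores)  # no-op: the same list object
try:
    list(apply_diagnostic_check(check_nan, scores))  # the error surfaces on consumption
    error = None
except ValueError as err:
    error = str(err)
print(unchecked is scores, error)
```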
RxInfer.ObjectiveDiagnosticCheckNaNs (Type)
`ObjectiveDiagnosticCheckNaNs`
If enabled, checks that both variable-bound and factor-bound score functions in the objective computation do not return `NaN`s. Throws an error if a `NaN` is found.
RxInfer.ObjectiveDiagnosticCheckInfs (Type)
`ObjectiveDiagnosticCheckInfs`
If enabled, checks that both variable-bound and factor-bound score functions in the objective computation do not return `Inf`s. Throws an error if an `Inf` is found.
RxInfer.DefaultObjectiveDiagnosticChecks (Constant)
`const DefaultObjectiveDiagnosticChecks = (ObjectiveDiagnosticCheckNaNs(), ObjectiveDiagnosticCheckInfs())`
A constant that defines the default objective diagnostic checks.
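The default pair of checks can be mimicked in Python to show the intended behaviour. This is a hypothetical analogue, not RxInfer's implementation. `CheckNaNs`, `CheckInfs`, `DEFAULT_CHECKS`, and `apply_checks` are invented names, and raising `FloatingPointError` is merely a choice made for this sketch.

```python
import math

class CheckNaNs:
    """Toy analogue of ObjectiveDiagnosticCheckNaNs: raise if a score is NaN."""
    def __call__(self, value):
        if math.isnan(value):
            raise FloatingPointError("objective score is NaN")
        return value

class CheckInfs:
    """Toy analogue of ObjectiveDiagnosticCheckInfs: raise if a score is +Inf or -Inf."""
    def __call__(self, value):
        if math.isinf(value):
            raise FloatingPointError("objective score is Inf")
        return value

# Mirrors the documented default pair (NaN check, then Inf check).
DEFAULT_CHECKS = (CheckNaNs(), CheckInfs())

def apply_checks(checks, stream):
    """Chain every check in `checks` onto `stream`; None disables checking."""
    if checks is None:
        return stream
    for check in checks:
        stream = map(check, stream)
    return stream

def first_error(scores, checks=DEFAULT_CHECKS):
    """Consume the checked stream and return the first error message, if any."""
    try:
        list(apply_checks(checks, scores))
    except FloatingPointError as err:
        return str(err)
    return None

print(first_error([0.1, 0.2]), first_error([0.1, float("inf")]), first_error([float("nan")]))
```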