Custom Functional Form Specification
In a nutshell, a functional form constraint defines a function that approximates the product of colliding messages and computes a posterior marginal that can be used later on during the inference procedure. An important part of the functional form constraint implementation is the prod function in the BayesBase package. For example, if we refer to our CustomFunctionalForm as f, we can see the whole functional form constraints pipeline as:
\[q(x) = f\left(\frac{\overrightarrow{\mu}(x)\overleftarrow{\mu}(x)}{\int \overrightarrow{\mu}(x)\overleftarrow{\mu}(x) \mathrm{d}x}\right)\]
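To make the equation concrete, here is a minimal plain-Julia sketch with no ReactiveMP dependency. It assumes both messages are Gaussian. The normalized product of two Gaussians in mean-precision form is again Gaussian: precisions add, and the mean is the precision-weighted average of the means. A hypothetical f then re-expresses the result in mean-variance form. GaussMP, GaussMV, prod_gauss and f are illustrative names, not ExponentialFamily or BayesBase types.

```julia
# Hypothetical minimal Gaussian types (NOT the ExponentialFamily types).
struct GaussMP   # mean-precision parameterization
    m::Float64
    w::Float64
end

struct GaussMV   # mean-variance parameterization
    m::Float64
    v::Float64
end

# Normalized product of two Gaussian densities:
# precisions add, the mean is the precision-weighted average of the means.
function prod_gauss(a::GaussMP, b::GaussMP)
    w = a.w + b.w
    m = (a.w * a.m + b.w * b.m) / w
    return GaussMP(m, w)
end

# A toy functional form `f`: re-express the result in mean-variance form.
f(q::GaussMP) = GaussMV(q.m, 1 / q.w)

forward  = GaussMP(0.0, 1.0)   # plays the role of the forward message
backward = GaussMP(4.0, 3.0)   # plays the role of the backward message

q = f(prod_gauss(forward, backward))   # → GaussMV(3.0, 0.25)
```

Here the product has precision 1 + 3 = 4 and mean (1·0 + 3·4)/4 = 3, so f returns variance 1/4.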
Interface
ReactiveMP.jl, however, uses some extra utility functions to define functional form constraint behaviour. Here we briefly describe all of these utility functions. If you are only interested in a concrete example, you may head directly to the Custom Functional Form Example at the end of this section.
Abstract super type
ReactiveMP.AbstractFormConstraint (Type)
Every functional form constraint is a subtype of the AbstractFormConstraint abstract type.
Note: this is not strictly necessary, but it makes automatic dispatch easier and keeps the constraint compatible with the CompositeFormConstraint.
ReactiveMP.UnspecifiedFormConstraint (Type)
One of the form constraint objects. It does not impose any form constraint and simply returns the same object it receives. However, it does not allow DistProduct to be a valid functional form in the inference backend.
ReactiveMP.CompositeFormConstraint (Type)
Creates a composite form constraint that applies form constraints in order. The composed form constraints must be compatible and have exactly the same form_check_strategy.
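Applying form constraints "in order" can be pictured as a left fold: the output of one constraint is the input of the next. The sketch below is a hypothetical plain-Julia mirror of that idea, not the ReactiveMP source. ToyFormConstraint, ClampBelow, RoundTo, ToyComposite and toy_constrain are illustrative names, and the sketch does not model the requirement that all parts share one form_check_strategy.

```julia
# Hypothetical mirror of the idea behind CompositeFormConstraint (not ReactiveMP code).
abstract type ToyFormConstraint end

struct ClampBelow <: ToyFormConstraint   # replaces values below `lo` with `lo`
    lo::Float64
end

struct RoundTo <: ToyFormConstraint      # rounds to `digits` decimal places
    digits::Int
end

struct ToyComposite <: ToyFormConstraint
    constraints::Tuple
end

toy_constrain(c::ClampBelow, x) = max(x, c.lo)
toy_constrain(c::RoundTo, x) = round(x; digits = c.digits)

# Apply the parts from left to right: each output feeds the next constraint.
toy_constrain(c::ToyComposite, x) =
    foldl((acc, ci) -> toy_constrain(ci, acc), c.constraints; init = x)

clamp_then_round = ToyComposite((ClampBelow(0.26), RoundTo(1)))
round_then_clamp = ToyComposite((RoundTo(1), ClampBelow(0.26)))

toy_constrain(clamp_then_round, 0.1)   # → 0.3
toy_constrain(round_then_clamp, 0.1)   # → 0.26
```

The same two parts in a different order give different results for the same input, which is why the order of a composite matters.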
ReactiveMP.preprocess_form_constraints (Function)
preprocess_form_constraints(constraints)
This function preprocesses form constraints and converts the provided objects into a form compatible with the ReactiveMP inference backend (if possible). If a tuple of constraints is passed, it creates a CompositeFormConstraint object. It wraps unknown form constraints into a WrappedFormConstraint object.
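The three cases handled by preprocess_form_constraints map naturally onto Julia multiple dispatch. The following is a hypothetical sketch, not the ReactiveMP implementation. The Sketch* types, sketch_preprocess and ForeignConstraint are illustrative names. Recursively preprocessing the tuple elements is an assumption of this sketch and is not stated above.

```julia
# Hypothetical mirror of the dispatch described for preprocess_form_constraints.
abstract type AbstractSketchConstraint end

struct SketchUnspecified <: AbstractSketchConstraint end

struct SketchComposite <: AbstractSketchConstraint
    constraints::Tuple
end

struct SketchWrapped <: AbstractSketchConstraint
    constraint::Any
end

# Objects that already subtype the abstract type are kept as-is.
sketch_preprocess(c::AbstractSketchConstraint) = c
# Tuples become a composite (assumption: elements are preprocessed recursively).
sketch_preprocess(cs::Tuple) = SketchComposite(map(sketch_preprocess, cs))
# Anything else, e.g. a type defined in another package, gets wrapped.
sketch_preprocess(c) = SketchWrapped(c)

struct ForeignConstraint end   # pretend this type comes from another package

p1 = sketch_preprocess(SketchUnspecified())                        # kept as-is
p2 = sketch_preprocess(ForeignConstraint())                        # wrapped
p3 = sketch_preprocess((SketchUnspecified(), ForeignConstraint())) # composite
```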
Form check strategy
Every custom functional form must implement a new method for the default_form_check_strategy function that returns either FormConstraintCheckEach or FormConstraintCheckLast.
FormConstraintCheckLast: q(x) = f(μ1(x) * μ2(x) * μ3(x))
FormConstraintCheckEach: q(x) = f(f(μ1(x) * μ2(x)) * μ3(x))
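The two strategies give different posteriors only when f does not commute with repeated products. The following plain-Julia sketch is not ReactiveMP code, and GaussianMP, prod_mp and f_floor are hypothetical names. Gaussian messages multiply by adding precisions, and a deliberately lossy toy f keeps only the integer part of the precision. The sketch also counts how often f is called.

```julia
# Hypothetical Gaussian in mean-precision form (NOT an ExponentialFamily type).
struct GaussianMP
    m::Float64
    w::Float64
end

# Normalized product: precisions add, mean is the precision-weighted average.
prod_mp(a::GaussianMP, b::GaussianMP) =
    GaussianMP((a.w * a.m + b.w * b.m) / (a.w + b.w), a.w + b.w)

# A deliberately lossy toy `f`: keep only the integer part of the precision.
const calls = Ref(0)
function f_floor(q::GaussianMP)
    calls[] += 1
    return GaussianMP(q.m, floor(q.w))
end

messages = [GaussianMP(0.0, 0.7), GaussianMP(0.0, 0.7), GaussianMP(0.0, 0.7)]

# FormConstraintCheckLast: multiply everything, then constrain once.
calls[] = 0
q_last = f_floor(reduce(prod_mp, messages))
calls_last = calls[]

# FormConstraintCheckEach: constrain after every pairwise product.
calls[] = 0
q_each = foldl((acc, msg) -> f_floor(prod_mp(acc, msg)), messages[2:end]; init = messages[1])
calls_each = calls[]
```

Here q_last ends with precision 2.0 after a single call of f. In contrast, q_each ends with precision 1.0 after two calls, because the intermediate result was already truncated.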
ReactiveMP.default_form_check_strategy (Function)
default_form_check_strategy(form_constraint)
Returns a default check strategy (e.g. FormConstraintCheckEach or FormConstraintCheckLast) for a given form constraint object.
ReactiveMP.FormConstraintCheckEach (Type)
This form constraint check strategy checks the functional form of the messages product after each product in an equality chain. If a variable is connected to multiple nodes, we usually need to perform multiple prod calls to obtain a posterior marginal. With this form check strategy, the constrain_form function is executed after each subsequent prod call.
ReactiveMP.FormConstraintCheckLast (Type)
This form constraint check strategy checks the functional form of the last messages product in the equality chain. If a variable is connected to multiple nodes, we usually need to perform multiple prod calls to obtain a posterior marginal. With this form check strategy, the constrain_form function is executed only once, after all subsequent prod calls have been executed.
ReactiveMP.FormConstraintCheckPickDefault (Type)
This form constraint check strategy simply falls back to the default check strategy for a given form constraint.
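Falling back to a default strategy can also be expressed with dispatch. The following is a hypothetical sketch, not the ReactiveMP implementation. All Toy* names and toy_resolve are illustrative. An explicitly chosen strategy is returned unchanged, while the "pick default" marker defers to the constraint's own default.

```julia
# Hypothetical strategy markers (NOT the ReactiveMP types).
abstract type ToyCheckStrategy end
struct ToyCheckEach <: ToyCheckStrategy end
struct ToyCheckLast <: ToyCheckStrategy end
struct ToyCheckPickDefault <: ToyCheckStrategy end

# A toy constraint whose default strategy is "check last".
struct ToySampleConstraint end
toy_default_strategy(::ToySampleConstraint) = ToyCheckLast()

# An explicit strategy is used as-is ...
toy_resolve(strategy::ToyCheckStrategy, constraint) = strategy
# ... while PickDefault falls back to the constraint's default (more specific method).
toy_resolve(::ToyCheckPickDefault, constraint) = toy_default_strategy(constraint)

toy_resolve(ToyCheckEach(), ToySampleConstraint())         # → ToyCheckEach()
toy_resolve(ToyCheckPickDefault(), ToySampleConstraint())  # → ToyCheckLast()
```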
Prod constraint
Every custom functional form must implement a new method for the default_prod_constraint function that returns a proper prod_constraint object.
ReactiveMP.default_prod_constraint (Function)
default_prod_constraint(form_constraint)
Returns the default prod constraint needed to apply a given form_constraint. For most form constraints this function returns GenericProd.
Constrain form, a.k.a f
The main function that a custom functional form must implement, which we referred to as f at the beginning of this section, is the constrain_form function.
ReactiveMP.constrain_form (Function)
constrain_form(constraint, something)
This function applies a given form constraint to a given object.
Custom Functional Form Example
In this demo, we show how to build a custom functional form constraint that is compatible with the ReactiveMP.jl inference backend. An important part of the functional form constraint implementation is the prod function in the BayesBase package. We present a relatively simple use case, which may not be very practical but serves as a straightforward step-by-step guide.
Assume that we want the posterior marginal of some random variable in our model to have a specific Gaussian parameterization, such as mean-precision. Here is how we can achieve this with our custom MeanPrecisionFormConstraint functional form constraint:
using ReactiveMP, ExponentialFamily, Distributions, BayesBase

# First, we define our functional form structure with no fields
struct MeanPrecisionFormConstraint <: AbstractFormConstraint end

ReactiveMP.default_form_check_strategy(::MeanPrecisionFormConstraint) = FormConstraintCheckLast()
ReactiveMP.default_prod_constraint(::MeanPrecisionFormConstraint) = GenericProd()

function ReactiveMP.constrain_form(::MeanPrecisionFormConstraint, distribution)
    # This assumes that the given `distribution` object has `mean` and `precision` defined.
    # These quantities might be approximated using other methods, such as Laplace approximation.
    m = mean(distribution)      # or approximate with some other method
    p = precision(distribution) # or approximate with some other method
    return NormalMeanPrecision(m, p)
end

function ReactiveMP.constrain_form(::MeanPrecisionFormConstraint, distribution::BayesBase.ProductOf)
    # `ProductOf` is a special case. Read more about this type in the corresponding
    # documentation section of the `BayesBase` package.
    # ...
end
constraint = ReactiveMP.preprocess_form_constraints(MeanPrecisionFormConstraint())
constrain_form(constraint, NormalMeanVariance(0, 2))
ExponentialFamily.NormalMeanPrecision{Float64}(μ=0.0, w=0.5)
Wrapped Form Constraints
Some constraint objects might not be subtypes of AbstractFormConstraint. This can occur, for instance, if the object is defined in a different package or needs to subtype a different abstract type. In such cases, ReactiveMP expects users to pass a WrappedFormConstraint object, which wraps the original object and makes it compatible with the ReactiveMP inference backend. Note that the ReactiveMP.preprocess_form_constraints function automatically wraps all objects that are not subtypes of AbstractFormConstraint.
Additionally, objects wrapped by WrappedFormConstraint may implement the ReactiveMP.prepare_context function. This function's output is stored in the WrappedFormConstraint along with the original object. If prepare_context is implemented, the constrain_form function takes three arguments: the original constraint, the context, and the object that needs to be constrained.
ReactiveMP.WrappedFormConstraint (Type)
WrappedFormConstraint(constraint, context)
This is a wrapper for a form constraint object. It allows passing additional context to the constrain_form function. By default, all objects that are not subtypes of AbstractFormConstraint are wrapped into this object. Use ReactiveMP.prepare_context to provide extra context for a given form constraint that can be reused between multiple constrain_form calls.
ReactiveMP.prepare_context (Function)
prepare_context(constraint)
This function prepares a context for a given form constraint. It returns WrappedFormConstraintNoContext if no context is needed (the default behaviour).
ReactiveMP.constrain_form (Method)
constrain_form(wrapped::WrappedFormConstraint, something)
This function unwraps the wrapped object and calls the constrain_form function with the provided context. If no context is provided, it simply calls constrain_form with the wrapped constraint. Otherwise, it passes the context to the constrain_form function as the second argument.
using ReactiveMP, Distributions, BayesBase, Random

# First, we define our custom form constraint that creates a set of samples
# Note that this is not a subtype of `AbstractFormConstraint`
struct MyCustomSampleListFormConstraint end

# Note that we still need to implement `default_form_check_strategy` and `default_prod_constraint` functions
# which are necessary for the `ReactiveMP` inference backend
ReactiveMP.default_form_check_strategy(::MyCustomSampleListFormConstraint) = FormConstraintCheckLast()
ReactiveMP.default_prod_constraint(::MyCustomSampleListFormConstraint) = GenericProd()

# We implement the `prepare_context` function, which returns a random number generator
function ReactiveMP.prepare_context(constraint::MyCustomSampleListFormConstraint)
    return Random.default_rng()
end

# We implement the `constrain_form` function, which returns a set of samples
function ReactiveMP.constrain_form(constraint::MyCustomSampleListFormConstraint, context, distribution)
    return rand(context, distribution, 10)
end
constraint = ReactiveMP.preprocess_form_constraints(MyCustomSampleListFormConstraint())
constrain_form(constraint, Normal(0, 10))
10-element Vector{Float64}:
10.10069465442664
19.419371597098397
13.00177609273929
-10.758450030955542
7.047409702090875
-19.506827135698387
7.3727202607858375
-13.857520455891638
11.055239843880283
4.0793311359172755
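To see the unwrapping rule from the constrain_form(wrapped::WrappedFormConstraint, something) description in isolation, here is a hypothetical plain-Julia mirror of it. It is not ReactiveMP code. CtxNoContext stands in for WrappedFormConstraintNoContext, and CtxWrapped, ctx_prepare, ctx_wrap and ctx_apply are illustrative names. A seeded MersenneTwister replaces Random.default_rng() so that the context is reproducible.

```julia
using Random

# Hypothetical stand-ins (NOT ReactiveMP's WrappedFormConstraint machinery).
struct CtxNoContext end   # plays the role of WrappedFormConstraintNoContext

struct CtxWrapped{C, X}
    constraint::C
    context::X
end

# By default no context is needed; the context is computed once, at wrap time.
ctx_prepare(constraint) = CtxNoContext()
ctx_wrap(constraint) = CtxWrapped(constraint, ctx_prepare(constraint))

# No context: call the two-argument method on the original constraint.
ctx_apply(w::CtxWrapped{C, CtxNoContext}, x) where {C} = ctx_apply(w.constraint, x)
# With context: pass the stored context as the second argument.
ctx_apply(w::CtxWrapped, x) = ctx_apply(w.constraint, w.context, x)

struct RoundConstraint end                       # needs no context
ctx_apply(::RoundConstraint, x) = round(x)

struct NoisyConstraint end                       # reuses an RNG as its context
ctx_prepare(::NoisyConstraint) = MersenneTwister(42)
ctx_apply(::NoisyConstraint, rng, x) = x + 0.01 * randn(rng)

ctx_apply(ctx_wrap(RoundConstraint()), 2.7)      # → 3.0
noisy = ctx_wrap(NoisyConstraint())
ctx_apply(noisy, 1.0)                            # 1.0 plus a small random perturbation
```

Because the RNG is stored in the wrapper, repeated ctx_apply calls on the same noisy object advance one shared generator rather than creating a new one each time.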