Custom Functional Form Specification
In a nutshell, a functional form constraint defines a function that approximates the product of colliding messages and computes the posterior marginal, which can then be used later in the inference procedure. An important part of the functional form constraint implementation is the prod function in the BayesBase package. For example, if we refer to our CustomFunctionalForm as f, we can see the whole functional form constraint pipeline as:
\[q(x) = f\left(\frac{\overrightarrow{\mu}(x)\overleftarrow{\mu}(x)}{\int \overrightarrow{\mu}(x)\overleftarrow{\mu}(x) \mathrm{d}x}\right)\]
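The two steps of this pipeline, the normalized product and the application of f, can be sketched for two univariate Gaussian messages. This is a minimal sketch under stated assumptions: the message values are illustrative, and the default ReactiveMP.UnspecifiedFormConstraint stands in for f, so f leaves the normalized product unchanged.

```julia
using ReactiveMP, ExponentialFamily, Distributions, BayesBase

# Two colliding messages on the same variable x (illustrative values)
forward_message  = NormalMeanVariance(0.0, 1.0)
backward_message = NormalMeanVariance(1.0, 1.0)

# Product step: BayesBase multiplies the two messages. A closed-form rule
# exists for two Gaussians, so the result is again a normalized Gaussian.
product = prod(GenericProd(), forward_message, backward_message)

# f step: apply a form constraint to the normalized product.
# `UnspecifiedFormConstraint` is the default and returns its input as-is.
q = ReactiveMP.constrain_form(ReactiveMP.UnspecifiedFormConstraint(), product)

println("mean = ", mean(q), ", precision = ", precision(q))
```

For these two messages the posterior has mean 0.5 and precision 2.0, because the precisions of the two Gaussian messages add up.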
Interface
ReactiveMP.jl, however, uses some extra utility functions to define the behaviour of a functional form constraint. Here we briefly describe all of these utility functions. If you are only interested in a concrete example, you may head directly to the Custom Functional Form Example at the end of this section.
Abstract super type
ReactiveMP.AbstractFormConstraint — Type
AbstractFormConstraint
Abstract supertype for all form constraints. Subtype this to create custom form constraints that can be used with constrain_form and ReactiveMP.MessageProductContext.
Subtyping is not strictly required (any object works via ReactiveMP.WrappedFormConstraint), but it makes dispatch easier and is needed for ReactiveMP.CompositeFormConstraint composition via +.
ReactiveMP.UnspecifiedFormConstraint — Type
UnspecifiedFormConstraint
The default form constraint: it does nothing and returns the distribution as-is. It is used when no form constraint has been specified in the ReactiveMP.MessageProductContext.
ReactiveMP.CompositeFormConstraint — Type
CompositeFormConstraint
A form constraint that chains multiple constraints together, applying them in order via constrain_form. Create one by combining constraints with + (e.g. constraint_a + constraint_b). All composed constraints must share the same default_form_check_strategy.
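The composition rule can be illustrated with two toy constraints that act on plain numbers, chosen only to make the order of application visible. AddOneConstraint and DoubleConstraint are hypothetical names, and the sketch assumes that "in order" means left to right, so constraint_a + constraint_b applies constraint_a first.

```julia
using ReactiveMP

# Two hypothetical toy constraints that act on plain numbers
struct AddOneConstraint <: ReactiveMP.AbstractFormConstraint end
struct DoubleConstraint <: ReactiveMP.AbstractFormConstraint end

# Composed constraints must share the same check strategy
ReactiveMP.default_form_check_strategy(::AddOneConstraint) = ReactiveMP.FormConstraintCheckLast()
ReactiveMP.default_form_check_strategy(::DoubleConstraint) = ReactiveMP.FormConstraintCheckLast()

ReactiveMP.constrain_form(::AddOneConstraint, x::Real) = x + 1
ReactiveMP.constrain_form(::DoubleConstraint, x::Real) = 2x

# `+` builds a CompositeFormConstraint
composite = AddOneConstraint() + DoubleConstraint()

# Applying the composite should match chaining the constraints by hand
result  = ReactiveMP.constrain_form(composite, 1.0)
chained = ReactiveMP.constrain_form(DoubleConstraint(), ReactiveMP.constrain_form(AddOneConstraint(), 1.0))

println(result, " == ", chained)
```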
ReactiveMP.preprocess_form_constraints — Function
preprocess_form_constraints(constraints)
Converts form constraints into a form compatible with the ReactiveMP inference backend. A tuple of constraints becomes a ReactiveMP.CompositeFormConstraint. Objects that are not subtypes of AbstractFormConstraint get wrapped into a ReactiveMP.WrappedFormConstraint.
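The three cases described above can be checked directly. SubtypedConstraint and PlainConstraint are hypothetical names, and the sketch assumes that a single AbstractFormConstraint subtype is passed through unchanged.

```julia
using ReactiveMP

# A hypothetical constraint that subtypes AbstractFormConstraint ...
struct SubtypedConstraint <: ReactiveMP.AbstractFormConstraint end
# ... and a hypothetical one that does not
struct PlainConstraint end

ReactiveMP.default_form_check_strategy(::SubtypedConstraint) = ReactiveMP.FormConstraintCheckLast()
ReactiveMP.default_form_check_strategy(::PlainConstraint) = ReactiveMP.FormConstraintCheckLast()

# Subtype: expected to be returned as-is
single = ReactiveMP.preprocess_form_constraints(SubtypedConstraint())
# Not a subtype: wrapped into a WrappedFormConstraint
wrapped = ReactiveMP.preprocess_form_constraints(PlainConstraint())
# Tuple: combined into a CompositeFormConstraint
chained = ReactiveMP.preprocess_form_constraints((SubtypedConstraint(), SubtypedConstraint()))

println(typeof(single), "\n", typeof(wrapped), "\n", typeof(chained))
```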
Form check strategy
Every custom functional form must implement a new method for the default_form_check_strategy function that returns either FormConstraintCheckEach or FormConstraintCheckLast.
FormConstraintCheckLast:
\[q(x) = f\left(\mu_1(x)\mu_2(x)\mu_3(x)\right)\]
FormConstraintCheckEach:
\[q(x) = f\left(f\left(\mu_1(x)\mu_2(x)\right)\mu_3(x)\right)\]
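The difference between the two strategies only shows up when f is not exact. The sketch below mimics both folds with plain numbers standing in for messages, * standing in for the product and rounding to one decimal digit standing in for f. The helpers check_last and check_each are hypothetical and are not part of the ReactiveMP API.

```julia
# Hypothetical helpers that mimic the two strategies on plain numbers
check_last(f, messages) = f(foldl(*, messages))              # f(μ1 * μ2 * μ3)
check_each(f, messages) = foldl((acc, m) -> f(acc * m), messages)  # f(f(μ1 * μ2) * μ3)

# A toy, inexact "constraint": keep one decimal digit
f(x) = round(x; digits = 1)

messages = [1.25, 1.25, 2.0]

q_last = check_last(f, messages)  # f(3.125) = 3.1
q_each = check_each(f, messages)  # f(f(1.5625) * 2.0) = f(3.2) = 3.2

println(q_last, " vs ", q_each)
```

With an exact f, such as converting a Gaussian between parameterizations, both strategies give the same result, and FormConstraintCheckLast is then the cheaper choice.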
ReactiveMP.default_form_check_strategy — Function
default_form_check_strategy(form_constraint)
Returns the default check strategy (either FormConstraintCheckEach or FormConstraintCheckLast) for a given form constraint. Override this for custom constraints to control when they are applied.
ReactiveMP.FormConstraintCheckEach — Type
FormConstraintCheckEach
Form constraint check strategy that applies constrain_form after each pairwise product inside ReactiveMP.compute_product_of_two_messages. Use this when intermediate results need to stay in a specific functional form (e.g. to prevent numerical issues during long product chains).
See also: FormConstraintCheckLast, ReactiveMP.MessageProductContext
ReactiveMP.FormConstraintCheckLast — Type
FormConstraintCheckLast
Form constraint check strategy that applies constrain_form once at the very end of ReactiveMP.compute_product_of_messages, after all pairwise products have been folded. This is the default strategy and is more efficient when the intermediate form does not matter.
See also: FormConstraintCheckEach, ReactiveMP.MessageProductContext
ReactiveMP.FormConstraintCheckPickDefault — Type
FormConstraintCheckPickDefault
A meta-strategy that defers to the default check strategy of the given form constraint, as defined by default_form_check_strategy.
Prod constraint
Every custom functional form must also implement a new method for the default_prod_constraint function that returns the product strategy (a prod_constraint object, such as BayesBase.GenericProd()) needed to apply the constraint.
ReactiveMP.default_prod_constraint — Function
default_prod_constraint(form_constraint)
Returns the default product strategy needed to apply a given form_constraint. For most form constraints this returns BayesBase.GenericProd().
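Both required declarations fit in two lines. In the sketch below, MyStrictConstraint is a hypothetical constraint that asks for FormConstraintCheckEach, so that every intermediate product is constrained, together with the generic BayesBase product strategy.

```julia
using ReactiveMP, BayesBase

# A hypothetical constraint whose intermediate products must stay in a fixed form
struct MyStrictConstraint <: ReactiveMP.AbstractFormConstraint end

# Apply the constraint after every pairwise product ...
ReactiveMP.default_form_check_strategy(::MyStrictConstraint) = ReactiveMP.FormConstraintCheckEach()
# ... and compute those products with the generic BayesBase strategy
ReactiveMP.default_prod_constraint(::MyStrictConstraint) = GenericProd()

println(ReactiveMP.default_form_check_strategy(MyStrictConstraint()))
println(ReactiveMP.default_prod_constraint(MyStrictConstraint()))
```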
Constrain form, a.k.a f
The main function that a custom functional form must implement, which we referred to as f at the beginning of this section, is the constrain_form function.
ReactiveMP.constrain_form — Function
constrain_form(constraint, distribution)
Applies the form constraint to distribution and returns the constrained result. This is the main extension point for custom form constraints: implement a method of this function for your constraint type and the distribution types you want to support.
See also: AbstractFormConstraint, ReactiveMP.MessageProductContext
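Because constrain_form relies on multiple dispatch, supporting several distribution types amounts to adding methods. In the sketch below, MeanVarianceFormConstraint and the GaussianParameterizations alias are hypothetical names. The conversion reads off mean and var, which is exact for these Gaussian types, and a fallback method rejects any other input.

```julia
using ReactiveMP, ExponentialFamily, Distributions, BayesBase

# Hypothetical constraint: represent univariate Gaussians in mean-variance form
struct MeanVarianceFormConstraint <: ReactiveMP.AbstractFormConstraint end

ReactiveMP.default_form_check_strategy(::MeanVarianceFormConstraint) = ReactiveMP.FormConstraintCheckLast()
ReactiveMP.default_prod_constraint(::MeanVarianceFormConstraint) = GenericProd()

# Hypothetical alias for the Gaussian types this constraint supports
const GaussianParameterizations = Union{NormalMeanPrecision, NormalMeanVariance, NormalWeightedMeanPrecision, Normal}

# Supported types: re-parameterize using the exact mean and variance
function ReactiveMP.constrain_form(::MeanVarianceFormConstraint, d::GaussianParameterizations)
    return NormalMeanVariance(mean(d), var(d))
end

# Any other type: fail loudly instead of silently passing the input through
function ReactiveMP.constrain_form(::MeanVarianceFormConstraint, d)
    error("MeanVarianceFormConstraint does not support $(typeof(d))")
end

from_precision = ReactiveMP.constrain_form(MeanVarianceFormConstraint(), NormalMeanPrecision(1.0, 4.0))
from_normal    = ReactiveMP.constrain_form(MeanVarianceFormConstraint(), Normal(2.0, 3.0))

println(from_precision, "\n", from_normal)
```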
Custom Functional Form Example
In this demo, we show how to build a custom functional form constraint that is compatible with the ReactiveMP.jl inference backend. An important part of the functional form constraint implementation is the prod function in the BayesBase package. We present a relatively simple use case, which may not be very practical but serves as a straightforward step-by-step guide.
Assume that we want the posterior marginal of some random variable in our model to have a specific Gaussian parameterization, such as mean-precision. Here is how we can achieve this with a custom MeanPrecisionFormConstraint functional form constraint:
using ReactiveMP, ExponentialFamily, Distributions, BayesBase
# First, we define our functional form structure with no fields
struct MeanPrecisionFormConstraint <: AbstractFormConstraint end
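# Apply the constraint once, after all messages have been multiplied
ReactiveMP.default_form_check_strategy(::MeanPrecisionFormConstraint) = FormConstraintCheckLast()
# Use the generic product strategy from BayesBase
ReactiveMP.default_prod_constraint(::MeanPrecisionFormConstraint) = GenericProd()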
function ReactiveMP.constrain_form(::MeanPrecisionFormConstraint, distribution)
# This assumes that the given `distribution` object has `mean` and `precision` defined.
# These quantities might be approximated using other methods, such as Laplace approximation.
m = mean(distribution) # or approximate with some other method
p = precision(distribution) # or approximate with some other method
return NormalMeanPrecision(m, p)
end
function ReactiveMP.constrain_form(::MeanPrecisionFormConstraint, distribution::BayesBase.ProductOf)
# `ProductOf` is a special case. Read more about this type in the corresponding
# documentation section of the `BayesBase` package.
# ...
end
constraint = ReactiveMP.preprocess_form_constraints(MeanPrecisionFormConstraint())
constrain_form(constraint, NormalMeanVariance(0, 2))
ExponentialFamily.NormalMeanPrecision{Float64}(μ=0.0, w=0.5)
Wrapped Form Constraints
Some constraint objects might not be subtypes of AbstractFormConstraint. This can occur, for instance, if the object is defined in a different package or needs to subtype a different abstract type. In such cases, ReactiveMP expects users to pass a WrappedFormConstraint object, which wraps the original object and makes it compatible with the ReactiveMP inference backend. Note that the ReactiveMP.preprocess_form_constraints function automatically wraps all objects that are not subtypes of AbstractFormConstraint.
Additionally, objects wrapped by WrappedFormConstraint may implement the ReactiveMP.prepare_context function. The output of this function is stored in the WrappedFormConstraint along with the original object. If prepare_context is implemented, the constrain_form function takes three arguments: the original constraint, the context, and the object that needs to be constrained.
ReactiveMP.WrappedFormConstraint — Type
WrappedFormConstraint(constraint, context)
A wrapper that pairs a form constraint with an optional precomputed context. Any object that is not a subtype of AbstractFormConstraint gets automatically wrapped into this during ReactiveMP.preprocess_form_constraints. Use ReactiveMP.prepare_context to provide extra context that can be reused across multiple constrain_form calls.
ReactiveMP.prepare_context — Function
prepare_context(constraint)
Prepares a reusable context for a given form constraint. Returns WrappedFormConstraintNoContext by default (i.e. no context needed). Override this to precompute things that should be shared across multiple constrain_form calls.
ReactiveMP.constrain_form — Method
constrain_form(wrapped::WrappedFormConstraint, something)
Unwraps the constraint and delegates to constrain_form with the inner constraint. If a context was provided via ReactiveMP.prepare_context, it is passed as the second argument.
using ReactiveMP, Distributions, BayesBase, Random
# First, we define our custom form constraint that creates a set of samples
# Note that this is not a subtype of `AbstractFormConstraint`
struct MyCustomSampleListFormConstraint end
# Note that we still need to implement `default_form_check_strategy` and `default_prod_constraint` functions
# which are necessary for the `ReactiveMP` inference backend
ReactiveMP.default_form_check_strategy(::MyCustomSampleListFormConstraint) = FormConstraintCheckLast()
ReactiveMP.default_prod_constraint(::MyCustomSampleListFormConstraint) = GenericProd()
# We implement the `prepare_context` function, which returns a random number generator
function ReactiveMP.prepare_context(constraint::MyCustomSampleListFormConstraint)
return Random.default_rng()
end
# We implement the `constrain_form` function, which returns a set of samples
function ReactiveMP.constrain_form(constraint::MyCustomSampleListFormConstraint, context, distribution)
return rand(context, distribution, 10)
end
constraint = ReactiveMP.preprocess_form_constraints(MyCustomSampleListFormConstraint())
constrain_form(constraint, Normal(0, 10))
10-element Vector{Float64}:
-5.401945250970769
22.93740772659237
5.846274361393279
-14.142396288777714
16.77663463181217
4.557072965397123
3.5947554099251136
-18.35485488635651
-10.670262428046996
-2.271186514984308
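Since the context is stored in the wrapper and reused, its state carries over between constrain_form calls. The sketch below is a hypothetical variant of the example above (SeededSampleListFormConstraint is not part of ReactiveMP) whose context is a seeded Random.MersenneTwister. It assumes that prepare_context is called once, when preprocess_form_constraints wraps the object, so two freshly wrapped constraints draw identical samples while repeated calls on the same wrapper do not.

```julia
using ReactiveMP, Distributions, BayesBase, Random

# Hypothetical variant of the sampling constraint with a seeded context
struct SeededSampleListFormConstraint end

ReactiveMP.default_form_check_strategy(::SeededSampleListFormConstraint) = ReactiveMP.FormConstraintCheckLast()
ReactiveMP.default_prod_constraint(::SeededSampleListFormConstraint) = GenericProd()

# The context is a seeded generator, created when the constraint is wrapped
ReactiveMP.prepare_context(::SeededSampleListFormConstraint) = Random.MersenneTwister(42)

# Three-argument form: constraint, context, object to constrain
function ReactiveMP.constrain_form(::SeededSampleListFormConstraint, rng, distribution)
    return rand(rng, distribution, 10)
end

first_wrap  = ReactiveMP.preprocess_form_constraints(SeededSampleListFormConstraint())
second_wrap = ReactiveMP.preprocess_form_constraints(SeededSampleListFormConstraint())

s1 = ReactiveMP.constrain_form(first_wrap, Normal(0, 10))
s2 = ReactiveMP.constrain_form(second_wrap, Normal(0, 10))  # fresh context, same seed
s3 = ReactiveMP.constrain_form(first_wrap, Normal(0, 10))   # reused context, generator has advanced

println(s1 == s2, " ", s1 == s3)
```

A seeded context is a simple way to make sampling-based constraints reproducible across runs without touching the global random number generator.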