Custom Functional Form Specification

In a nutshell, a functional form constraint defines a function that approximates the product of colliding messages and computes a posterior marginal that can be used later on during the inference procedure. An important part of the functional form constraint implementation is the prod function in the BayesBase package. For example, if we refer to our custom functional form as f, the whole functional form constraint pipeline reads:

\[q(x) = f\left(\frac{\overrightarrow{\mu}(x)\overleftarrow{\mu}(x)}{\int \overrightarrow{\mu}(x)\overleftarrow{\mu}(x) \mathrm{d}x}\right)\]
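To make the equation above concrete, here is a small numerical sketch in plain Julia using only Distributions.jl (no ReactiveMP). It assumes that both messages are Gaussian densities evaluated on a grid, approximates the normalizing integral with a Riemann sum, and uses a hypothetical moment-matching projection, moment_match, in the role of f.

```julia
using Distributions

# Integration grid (assumption: wide and fine enough for these messages)
xs = range(-10.0, 10.0; length = 20_001)
dx = step(xs)

forward  = Normal(0.0, 1.0)   # forward message
backward = Normal(2.0, 1.0)   # backward message

# Numerator: pointwise product of the two messages
unnormalized = pdf.(forward, xs) .* pdf.(backward, xs)

# Denominator: the integral of the product, approximated with a Riemann sum
normalized = unnormalized ./ (sum(unnormalized) * dx)

# Hypothetical f: project the normalized product onto a Gaussian by matching moments
function moment_match(xs, density, dx)
    m = sum(xs .* density) * dx
    v = sum((xs .- m) .^ 2 .* density) * dx
    return Normal(m, sqrt(v))
end

q = moment_match(xs, normalized, dx)
```

For these two unit-variance messages centred at 0 and 2, the normalized product is itself Gaussian with mean 1 and variance 0.5, so in this case f only recovers those moments.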

Interface

ReactiveMP.jl uses several extra utility functions to define the behaviour of a functional form constraint. Here we briefly describe all of them. If you are only interested in a concrete example, you may head directly to the Custom Functional Form Example at the end of this section.

Abstract super type

ReactiveMP.AbstractFormConstraint (type)
AbstractFormConstraint

Every functional form constraint is a subtype of the AbstractFormConstraint abstract type.

Note: this is not strictly necessary, but it makes automatic dispatch easier and ensures compatibility with CompositeFormConstraint.

ReactiveMP.UnspecifiedFormConstraint (type)
UnspecifiedFormConstraint

One of the form constraint objects. It does not imply any form constraints and simply returns the same object it receives. However, it does not allow DistProduct to be a valid functional form in the inference backend.

ReactiveMP.CompositeFormConstraint (type)
CompositeFormConstraint

Creates a composite form constraint that applies form constraints in order. The composed form constraints must be compatible and have the exact same form_check_strategy.

ReactiveMP.preprocess_form_constraints (function)
preprocess_form_constraints(constraints)

This function preprocesses form constraints and converts the provided objects into a form compatible with ReactiveMP inference backend (if possible). If a tuple of constraints is passed, it creates a CompositeFormConstraint object. Wraps unknown form constraints into a WrappedFormConstraint object.
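A small sketch of the three cases described above. AddOneConstraint, DoubleConstraint and HalveConstraint are hypothetical and act on plain numbers only, so that the effect and the order of application are easy to see. The expected results follow from the descriptions of preprocess_form_constraints, CompositeFormConstraint and WrappedFormConstraint in this section; treat them as an illustration under those descriptions rather than a specification.

```julia
using ReactiveMP

# Hypothetical constraints that subtype `AbstractFormConstraint`
struct AddOneConstraint <: AbstractFormConstraint end
struct DoubleConstraint <: AbstractFormConstraint end

# Hypothetical constraint that does NOT subtype `AbstractFormConstraint`
struct HalveConstraint end

# Composed constraints must share the same check strategy
ReactiveMP.default_form_check_strategy(::AddOneConstraint) = FormConstraintCheckLast()
ReactiveMP.default_form_check_strategy(::DoubleConstraint) = FormConstraintCheckLast()
ReactiveMP.default_form_check_strategy(::HalveConstraint) = FormConstraintCheckLast()

# Toy "projections" on numbers, only to make the order of application visible
ReactiveMP.constrain_form(::AddOneConstraint, x) = x + 1
ReactiveMP.constrain_form(::DoubleConstraint, x) = 2x
ReactiveMP.constrain_form(::HalveConstraint, x) = x / 2

# A subtype of `AbstractFormConstraint` is expected to be returned as is
single = ReactiveMP.preprocess_form_constraints(AddOneConstraint())

# A tuple becomes a `CompositeFormConstraint`: AddOneConstraint first, then DoubleConstraint
composite = ReactiveMP.preprocess_form_constraints((AddOneConstraint(), DoubleConstraint()))
constrain_form(composite, 1)

# An object that is not a subtype is wrapped into a `WrappedFormConstraint`
wrapped = ReactiveMP.preprocess_form_constraints(HalveConstraint())
constrain_form(wrapped, 3.0)
```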


Form check strategy

Every custom functional form must implement a new method for the default_form_check_strategy function that returns either FormConstraintCheckEach or FormConstraintCheckLast.

  • FormConstraintCheckLast: q(x) = f(μ1(x) * μ2(x) * μ3(x))
  • FormConstraintCheckEach: q(x) = f(f(μ1(x) * μ2(x)) * μ3(x))
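The difference between the two strategies can be sketched numerically in plain Julia, outside of ReactiveMP. The helper functions check_last and check_each below are hypothetical and only mirror the two formulas; f is assumed to be a Gaussian moment-matching projection, and each message is represented by its density values on a grid.

```julia
using Distributions

# Hypothetical setup: densities on a fixed grid
xs = range(-15.0, 15.0; length = 30_001)
dx = step(xs)

on_grid(d) = pdf.(d, xs)

# f: normalize the grid density and match its mean and variance with a Gaussian
function f(density)
    p = density ./ (sum(density) * dx)
    m = sum(xs .* p) * dx
    v = sum((xs .- m) .^ 2 .* p) * dx
    return Normal(m, sqrt(v))
end

# Mirrors FormConstraintCheckLast: q(x) = f(μ1(x) * μ2(x) * μ3(x))
check_last(μ1, μ2, μ3) = f(on_grid(μ1) .* on_grid(μ2) .* on_grid(μ3))

# Mirrors FormConstraintCheckEach: q(x) = f(f(μ1(x) * μ2(x)) * μ3(x))
check_each(μ1, μ2, μ3) = f(on_grid(f(on_grid(μ1) .* on_grid(μ2))) .* on_grid(μ3))

# Gaussian messages
g1, g2, g3 = Normal(0.0, 1.0), Normal(1.0, 2.0), Normal(-1.0, 0.5)
last_g, each_g = check_last(g1, g2, g3), check_each(g1, g2, g3)

# A bimodal (non-Gaussian) first message
mix = MixtureModel([Normal(-3.0, 0.5), Normal(3.0, 0.5)], [0.3, 0.7])
last_m, each_m = check_last(mix, g2, g3), check_each(mix, g2, g3)
```

Because products of Gaussians are Gaussian, both strategies agree for Gaussian messages. With the bimodal message, the intermediate projection in check_each discards information about the two modes before μ3 is taken into account, and the two results end up far apart.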

ReactiveMP.default_form_check_strategy (function)
default_form_check_strategy(form_constraint)

Returns a default check strategy (e.g. FormConstraintCheckEach or FormConstraintCheckLast) for a given form constraint object.

ReactiveMP.FormConstraintCheckEach (type)
FormConstraintCheckEach

This form constraint check strategy checks the functional form of the messages' product after each product in an equality chain. Usually, if a variable is connected to multiple nodes, we need to perform multiple prod calls to obtain a posterior marginal. With this strategy, the constrain_form function is executed after each subsequent prod call.

ReactiveMP.FormConstraintCheckLast (type)
FormConstraintCheckLast

This form constraint check strategy checks the functional form of the last messages' product in the equality chain. Usually, if a variable is connected to multiple nodes, we need to perform multiple prod calls to obtain a posterior marginal. With this strategy, the constrain_form function is executed only once, after all subsequent prod calls have been executed.


Prod constraint

Every custom functional form must implement a new method for the default_prod_constraint function that returns a proper prod_constraint object.

ReactiveMP.default_prod_constraint (function)
default_prod_constraint(form_constraint)

Returns a default prod constraint needed to apply a given form_constraint. For most form constraints this function returns GenericProd().
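GenericProd comes from the BayesBase package. As a rough sketch of its behaviour, assuming that no closed-form product rule is registered for a Normal and a Beta distribution, calling prod with GenericProd() is expected to return a lazy BayesBase.ProductOf object rather than a new distribution. This is why a constrain_form implementation may need a dedicated method for ProductOf.

```julia
using BayesBase, Distributions

# Two distributions for which (by assumption) no closed-form product rule exists
left = Normal(0.0, 1.0)
right = Beta(2.0, 2.0)

# GenericProd uses a known product rule if there is one; otherwise it keeps a lazy product
product = prod(GenericProd(), left, right)

# Inspect the type of the result
typeof(product)
```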


Constrain form, a.k.a f

The main function that a custom functional form must implement, which we referred to as f in the beginning of this section, is the constrain_form function.

Custom Functional Form Example

In this demo, we show how to build a custom functional form constraint that is compatible with the ReactiveMP.jl inference backend. An important part of the functional form constraint implementation is the prod function in the BayesBase package. We present a relatively simple use case, which may not be very practical but serves as a straightforward step-by-step guide.

Assume that we want the posterior marginal of some random variable in our model to have a specific Gaussian parameterization, such as mean-precision. Here is how we can achieve this with a custom MeanPrecisionFormConstraint functional form constraint:

using ReactiveMP, ExponentialFamily, Distributions, BayesBase

# First, we define our functional form structure with no fields
struct MeanPrecisionFormConstraint <: AbstractFormConstraint end

ReactiveMP.default_form_check_strategy(::MeanPrecisionFormConstraint) = FormConstraintCheckLast()
ReactiveMP.default_prod_constraint(::MeanPrecisionFormConstraint) = GenericProd()

function ReactiveMP.constrain_form(::MeanPrecisionFormConstraint, distribution)
    # This assumes that the given `distribution` object has `mean` and `precision` defined.
    # These quantities might be approximated using other methods, such as Laplace approximation.
    m = mean(distribution)      # or approximate with some other method
    p = precision(distribution) # or approximate with some other method
    return NormalMeanPrecision(m, p)
end

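function ReactiveMP.constrain_form(::MeanPrecisionFormConstraint, distribution::BayesBase.ProductOf)
    # `ProductOf` is a special case. Read more about this type in the corresponding
    # documentation section of the `BayesBase` package.
    # Its mean and precision are generally not available in closed form and must be
    # approximated (e.g. with a Laplace approximation); this is left unimplemented here.
    error("`MeanPrecisionFormConstraint` does not handle `ProductOf` objects in this example")
end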

constraint = ReactiveMP.preprocess_form_constraints(MeanPrecisionFormConstraint())

constrain_form(constraint, NormalMeanVariance(0, 2))
ExponentialFamily.NormalMeanPrecision{Float64}(μ=0.0, w=0.5)

Wrapped Form Constraints

Some constraint objects might not be subtypes of AbstractFormConstraint. This can occur, for instance, if the object is defined in a different package or needs to subtype a different abstract type. In such cases, ReactiveMP expects users to pass a WrappedFormConstraint object, which wraps the original object and makes it compatible with the ReactiveMP inference backend. Note that the ReactiveMP.preprocess_form_constraints function automatically wraps all objects that are not subtypes of AbstractFormConstraint.

Additionally, objects wrapped by WrappedFormConstraint may implement the ReactiveMP.prepare_context function. Its output is stored in the WrappedFormConstraint together with the original object. If prepare_context is implemented, the constrain_form function takes three arguments: the original constraint, the context, and the object to be constrained.

ReactiveMP.WrappedFormConstraint (type)
WrappedFormConstraint(constraint, context)

This is a wrapper for a form constraint object. It allows passing additional context to the constrain_form function. By default, all objects that are not subtypes of AbstractFormConstraint are wrapped into this object. Use ReactiveMP.prepare_context to provide an extra context for a given form constraint that can be reused between multiple constrain_form calls.

ReactiveMP.prepare_context (function)
prepare_context(constraint)

This function prepares a context for a given form constraint. Returns WrappedFormConstraintNoContext if no context is needed (the default behaviour).

ReactiveMP.constrain_form (method)
constrain_form(wrapped::WrappedFormConstraint, something)

This function unwraps the wrapped object and calls the constrain_form function with the provided context. If no context is provided, it simply calls constrain_form with the wrapped constraint. Otherwise, it passes the context to the constrain_form function as the second argument.
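When a wrapped constraint does not implement prepare_context, no context is stored and the two-argument constrain_form method is called. The sketch below uses a hypothetical MyMeanOnlyFormConstraint, which is not a subtype of AbstractFormConstraint and, for illustration only, returns the mean of the distribution instead of a distribution object.

```julia
using ReactiveMP, BayesBase, Distributions

# Hypothetical constraint: not a subtype of `AbstractFormConstraint`, no `prepare_context`
struct MyMeanOnlyFormConstraint end

ReactiveMP.default_form_check_strategy(::MyMeanOnlyFormConstraint) = FormConstraintCheckLast()
ReactiveMP.default_prod_constraint(::MyMeanOnlyFormConstraint) = GenericProd()

# Without a context, only the two-argument method is needed
ReactiveMP.constrain_form(::MyMeanOnlyFormConstraint, distribution) = mean(distribution)

# The object is wrapped automatically, without a context
constraint = ReactiveMP.preprocess_form_constraints(MyMeanOnlyFormConstraint())

constrain_form(constraint, Normal(3.0, 1.0))
```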

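The following example defines a form constraint that is not a subtype of AbstractFormConstraint and uses prepare_context to provide a random number generator as its context:
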
using ReactiveMP, Distributions, BayesBase, Random

# First, we define our custom form constraint that creates a set of samples
# Note that this is not a subtype of `AbstractFormConstraint`
struct MyCustomSampleListFormConstraint end

# Note that we still need to implement `default_form_check_strategy` and `default_prod_constraint` functions
#  which are necessary for the `ReactiveMP` inference backend
ReactiveMP.default_form_check_strategy(::MyCustomSampleListFormConstraint) = FormConstraintCheckLast()
ReactiveMP.default_prod_constraint(::MyCustomSampleListFormConstraint) = GenericProd()

# We implement the `prepare_context` function, which returns a random number generator
function ReactiveMP.prepare_context(constraint::MyCustomSampleListFormConstraint)
    return Random.default_rng()
end

# We implement the `constrain_form` function, which returns a set of samples
function ReactiveMP.constrain_form(constraint::MyCustomSampleListFormConstraint, context, distribution)
    return rand(context, distribution, 10)
end

constraint = ReactiveMP.preprocess_form_constraints(MyCustomSampleListFormConstraint())

constrain_form(constraint, Normal(0, 10))
10-element Vector{Float64}:
   6.592059788438393
  -7.410036417952506
  -9.658366929105775
  -6.0881098203348465
  16.817399501524616
 -11.790660552816885
  -3.013034381799034
 -16.267075903943592
  -6.954432992702639
  11.324257601586446