Constraint Specification
Variational inference on factor graphs requires specifying the structure of the approximate posterior distribution q: how it factorizes and what functional forms its individual marginals take. These choices determine the variational family over which an objective function, such as the Bethe Free Energy (BFE), is optimized. In GraphPPL they are called constraints and are specified with the @constraints macro.
For more background on the Bethe Free Energy and its connection to message passing on factor graphs, see:
- Yedidia et al. (2005) on belief propagation and regional approximations to the variational free energy;
- Dauwels (2007) on variational message passing on Forney-style factor graphs;
- Senoz et al. (2021) on constraint manipulation and message passing on factor graphs.
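For orientation, the Bethe Free Energy in the form given by Yedidia et al. (2005) can be written as below. The symbols are notation introduced here for illustration and are not GraphPPL identifiers: $f_a$ denotes a factor, $q_a$ the node-local joint belief over the variables $x_a$ attached to that factor, $q_i$ the marginal of variable $x_i$, and $d_i$ the number of factors connected to $x_i$. For continuous variables the sums become integrals.

\[F[q] = \sum_a \sum_{x_a} q_a(x_a) \ln \frac{q_a(x_a)}{f_a(x_a)} - \sum_i (d_i - 1) \sum_{x_i} q_i(x_i) \ln q_i(x_i)\]

Roughly speaking, constraints restrict the set of beliefs $q_a$ and $q_i$ over which this objective is minimized.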
There are two types of constraints:
- Factorization constraints define how the variational posterior factorizes around factor nodes (e.g. $q(x, y, z) = q(x, y)q(z)$).
- Functional form constraints specify the distributional family for a variable's posterior (e.g. $q(x) \sim \mathrm{Normal}$).
The constraints macro
The constraints macro accepts a high-level constraint specification and converts it into a structure that GraphPPL models can interpret. For example, suppose we have the following toy model, which defines a Gaussian distribution over x with mean y and variance z:
using GraphPPL
using Distributions
import GraphPPL: @model
@model function toy_model(x, y, z)
    x ~ Normal(y, z)
end
Suppose we want to apply the following constraints over the variational posterior q:
\[q(x, y, z) = q(x, y)q(z) \\ q(x) \sim \mathrm{Normal}\]
We can write this in the constraints macro using the following code:
@constraints begin
    q(x, y, z) = q(x, y)q(z)
    q(x) :: Normal
end
Constraints:
  q(x, y, z) = q(x, y)q(z)
  q(x) :: Distributions.Normal
We can reference variables in the constraints macro by the names they have in the model specification. This raises the question of how to specify constraints over variables in submodels, since these variables are not available in the scope of the outer model specification. To this end, we can nest our constraints in the same way in which we have nested our models, and use a for q in submodel block to specify constraints over submodels. For example, suppose we have the following model:
@model function toy_model(x, y, z)
    x ~ Normal(y, z)
    y ~ Normal(0, 1)
end
@model function outer_toy_model(a, b, c)
    a ~ toy_model(y = b, z = c)
end
We can specify constraints over the toy_model submodel using the following code:
@constraints begin
    for q in toy_model
        q(x, y, z) = q(x, y)q(z)
        q(x) :: Normal
    end
end
Constraints:
  q(toy_model) =
    q(x, y, z) = q(x, y)q(z)
    q(x) :: Distributions.Normal
The submodel constraint specification applies to all submodels with the same name. However, you might want to specify constraints over one specific submodel instance. To this end, we can use the for q in (submodel, index) syntax, which applies the constraints only to the submodel with the corresponding index. For example, suppose we have the following model:
@model function toy_model(x, y, z)
    x ~ Normal(y, z)
    y ~ Normal(0, 1)
end
@model function outer_toy_model(a, b, c)
    a ~ toy_model(y = b, z = c)
    a ~ toy_model(y = b, z = c)
end
We can specify constraints over the first toy_model submodel using the following code:
@constraints begin
    for q in (toy_model, 1)
        q(x, y, z) = q(x, y)q(z)
        q(x) :: Normal
    end
end
Constraints:
  q((toy_model, 1)) =
    q(x, y, z) = q(x, y)q(z)
    q(x) :: Distributions.Normal
Constraints over vector variables
When a model contains vector (or array) latent variables, we can specify factorization constraints over individual elements using the begin and end indexing syntax. For example, consider a random walk model where latent states x are coupled through sequential dependencies:
@model function random_walk_model(y, n)
    local x
    # Prior on the first state; Normal is used throughout for consistency
    x[1] ~ Normal(0.0, 1.0)
    for i in 2:n
        x[i] ~ Normal(x[i - 1], 1.0)
    end
    for i in 1:n
        y[i] ~ Normal(x[i], 1.0)
    end
end
Since the latent states x are coupled through the random walk prior, they are not conditionally independent. To enforce a mean-field factorization over the elements of x, we can write:
@constraints begin
    q(x) = q(x[begin])..q(x[end])
end
Constraints:
  q(x) = q(x[(begin)..(end)])
This specifies that the joint posterior over x factorizes into independent marginals for each element: $q(\mathbf{x}) = \prod_i q(x_i)$. The .. operator creates a factorization range from the first to the last element of the vector.
Alternatively, MeanField() can be used as a shorthand to factorize all variables into independent marginals:
@constraints begin
    q(x) = MeanField()
end
Constraints:
  q(x) = MeanField()
For a full example using vector variable constraints in practice, see the Gamma Mixture example in the RxInfer documentation.
Stacked functional form constraints
In the constraints macro, we can specify multiple functional form constraints over the same variable. For example:
@constraints begin
    q(x) :: Normal :: Beta
end
Constraints:
  q(x) :: (Distributions.Normal, Distributions.Beta)
In this constraint the posterior over x is first constrained to be a normal distribution, and the result is then constrained to be a beta distribution. This can be useful for creating a chain of constraints that are applied in order. The resulting constraint is stored as a tuple of constraints.
The inference backend must support stacked constraints for this feature to work. Some combinations of stacked constraints might not be supported or theoretically sound.
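The "applied in order" semantics can be pictured with a small, backend-independent sketch. The two functions below are hypothetical stand-ins for functional form constraints (they are not GraphPPL or backend API), and apply_stacked is an assumed helper that folds a tuple of such functions over a value from left to right.

```julia
# Hypothetical stand-ins for functional form constraints: each one maps a
# (here vector-valued) posterior representation to a constrained one.
clamp_nonnegative(v) = max.(v, 0.0)
normalize_sum(v) = v ./ sum(v)

# Apply a tuple of constraints from left to right, feeding each result
# into the next constraint in the chain.
apply_stacked(constraints::Tuple, q) = foldl((acc, c) -> c(acc), constraints; init = q)

q = [-1.0, 1.0, 3.0]
forward = apply_stacked((clamp_nonnegative, normalize_sum), q)
backward = apply_stacked((normalize_sum, clamp_nonnegative), q)

println(forward)
println(backward)
```

Swapping the order of the two stand-ins changes the result, which is why the order in a stacked specification matters.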
Default constraints
Constraints passed to a model can address all instances of a submodel at a specific layer of the hierarchy, but there is no guarantee that every instance of that submodel lives at that layer. To cover this case, we can specify default constraints that apply to all instances of a specific submodel, wherever they occur. For example, we can define the following model, which contains a recursive_model instance at every layer of the hierarchy:
@model function recursive_model(n, x, y)
    z ~ Gamma(1, 1)
    if n > 0
        y ~ Normal(recursive_model(n = n - 1, x = x), z)
    else
        y ~ Normal(0, z)
    end
end
We can specify default constraints over the recursive_model submodel using the following code:
GraphPPL.default_constraints(::typeof(recursive_model)) = @constraints begin
    q(x, y, z) = q(x)q(y)q(z)
end
When a recursive_model is now created, the default constraints are applied to every instance of the recursive_model submodel. Note that default constraints are overridden by constraints passed to the top-level model if they concern the same instance of a submodel.
Prespecified constraints
GraphPPL provides a set of prespecified constraints over the variational posterior. These are convenience aliases for equivalent explicit constraint specifications. The following prespecified constraints are available:
GraphPPL.MeanField — Type
MeanField
Generic factorisation constraint used to specify a mean-field factorisation for recognition distribution q. This constraint ignores default_constraints from submodels and forces everything to be factorized.
See also: BetheFactorization
GraphPPL.BetheFactorization — Function
BetheFactorization
Generic factorisation constraint used to specify the Bethe factorisation for recognition distribution q. An alias to UnspecifiedConstraints.
See also: MeanField
This means that we can write the following:
@constraints begin
    q(x, y, z) = MeanField() # Equivalent to q(x, y, z) = q(x)q(y)q(z)
    q(a, b, c) = BetheFactorization() # Equivalent to q(a, b, c) = q(a, b, c), can be used to overwrite default constraints.
end
Constraints:
  q(x, y, z) = MeanField()
  q(a, b, c) = Constraints:
Plugin's internals
GraphPPL.@constraints — Macro
@constraints begin
    q(x, y, z) = q(x, y)q(z)
    q(x) :: Normal
end
Specify factorization and functional form constraints on the variational posterior for use with Bethe Free Energy-based inference.
Factorization constraints define how the variational posterior factorizes (e.g. q(x, y, z) = q(x, y)q(z)). Functional form constraints specify the distributional form of the variational posterior for a variable (e.g. q(x) :: Normal).
Use for q in submodel ... end blocks to apply constraints to variables within submodels.
See the Constraint Specification section for more details.
GraphPPL.VariationalConstraintsPlugin — Type
VariationalConstraintsPlugin(constraints)
A plugin that adds properties related to variational inference to the factor nodes, for use by the variational inference procedure.
GraphPPL.Constraints — Type
Constraints
An instance of Constraints represents a set of constraints to be applied to a variational posterior in a factor graph model.
GraphPPL.SpecificSubModelConstraints — Type
SpecificSubModelConstraints
A SpecificSubModelConstraints represents a set of constraints to be applied to a specific submodel. The submodel is specified by the tag field, which contains the identifier of the submodel.
See also: GraphPPL.GeneralSubModelConstraints
GraphPPL.GeneralSubModelConstraints — Type
GeneralSubModelConstraints
A GeneralSubModelConstraints represents a set of constraints to be applied to a set of submodels. The submodels are specified by the fform field, which contains the identifier of the submodel. The constraints field contains the constraints to be applied to all instances of this submodel on this level in the model hierarchy.
See also: GraphPPL.SpecificSubModelConstraints
GraphPPL.FactorizationConstraint — Type
FactorizationConstraint{V, F}
A FactorizationConstraint represents a single factorization constraint in a variational posterior constraint specification. We use type parametrization to dispatch on different types of constraints; for example, q(x, y) = MeanField() is treated differently from q(x, y) = q(x)q(y).
The FactorizationConstraint constructor checks for obvious errors, such as duplicate variables in the constraint specification, and verifies that the left-hand side and the right-hand side contain the same variables.
See also: GraphPPL.FactorizationConstraintEntry
GraphPPL.FactorizationConstraintEntry — Type
FactorizationConstraintEntry
A FactorizationConstraintEntry is a group of variables (represented as a Vector of IndexedVariable objects) that represents a factor group in a factorization constraint.
See also: GraphPPL.FactorizationConstraint
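The kind of sanity checks described for the FactorizationConstraint constructor can be sketched in plain Julia. This is only an illustration under simplifying assumptions: check_factorization is a hypothetical helper, variables are represented as bare symbols rather than IndexedVariable objects, and the error messages are invented here.

```julia
# Hypothetical helper; a simplified illustration, not GraphPPL's implementation.
# lhs: variables on the left-hand side, e.g. q(x, y, z)
# rhs: one vector of variables per factor, e.g. q(x, y)q(z)
function check_factorization(lhs::Vector{Symbol}, rhs::Vector{Vector{Symbol}})
    rhs_vars = reduce(vcat, rhs; init = Symbol[])
    allunique(lhs) || error("duplicate variable on the left-hand side")
    allunique(rhs_vars) || error("a variable appears in more than one factor")
    Set(lhs) == Set(rhs_vars) || error("left- and right-hand sides contain different variables")
    return true
end

# q(x, y, z) = q(x, y)q(z) passes both checks.
ok = check_factorization([:x, :y, :z], [[:x, :y], [:z]])

# q(x, y, z) = q(x, y)q(y) is rejected because y appears twice on the right.
rejected = try
    check_factorization([:x, :y, :z], [[:x, :y], [:y]])
    false
catch
    true
end
```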
GraphPPL.MarginalFormConstraint — Type
A MarginalFormConstraint represents a single functional form constraint in a variational marginal constraint specification. We use type parametrization to dispatch on different types of constraints; for example, q(x, y) :: MvNormal should be treated differently from q(x) :: Normal.
GraphPPL.MessageFormConstraint — Type
A MessageFormConstraint represents a single constraint on the messages in a message passing scheme. These constraints closely resemble MarginalFormConstraint, but apply to messages rather than to marginals.
GraphPPL.materialize_constraints! — Function
materialize_constraints!(model::Model, node_label::NodeLabel, node_data::NodeData)
Materializes the factorization constraint in node_data in model at node_label. This function converts the BitSet representation of a constraint in node_data to the tuple representation containing all interface names.
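The conversion from a BitSet representation to a tuple of interface names can be pictured with the following sketch. It assumes a simplified layout with one BitSet of interface indices per factorization cluster, which may differ from the internal layout GraphPPL actually uses; the interface names and the to_names helper are hypothetical.

```julia
# Hypothetical interface names for a node with three interfaces.
interfaces = (:out, :mean, :var)

# Assumed, simplified representation: one BitSet of interface indices per cluster.
clusters = [BitSet([1, 2]), BitSet([3])]

# Replace every index by its interface name, cluster by cluster.
to_names(clusters, interfaces) = Tuple(Tuple(interfaces[i] for i in c) for c in clusters)

named = to_names(clusters, interfaces)
println(named)
```

Here the first two interfaces end up in one joint cluster and the third is factorized out on its own.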
GraphPPL.factorization_split — Function
factorization_split(left, right)
Creates a new FactorizationConstraintEntry that contains a SplittedRange splitting left and right. This function is used to convert two FactorizationConstraintEntry objects (for example q(x[begin])..q(x[end])) into a single FactorizationConstraintEntry containing the SplittedRange.
See also: GraphPPL.SplittedRange
GraphPPL.SplittedRange — Type
SplittedRange{L, R}
SplittedRange represents a range of split variables in the factorization specification language. Such variables are specified not to be in the same factorization cluster.
See also: GraphPPL.CombinedRange
GraphPPL.CombinedRange — Type
CombinedRange{L, R}
CombinedRange represents a range of combined variables in the factorization specification language. Such variables are specified to be in the same factorization cluster.
See also: GraphPPL.SplittedRange
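The difference between the two range types can be illustrated by the clusters they imply for a vector of three elements. The helpers below are hypothetical and operate on plain index ranges; they only mirror the descriptions above and are not GraphPPL functions.

```julia
# Hypothetical helpers operating on plain index ranges.
# A split range puts every element into its own factorization cluster.
split_clusters(r::UnitRange) = [[i] for i in r]

# A combined range keeps all elements in a single factorization cluster.
combined_clusters(r::UnitRange) = [collect(r)]

println(split_clusters(1:3))
println(combined_clusters(1:3))
```

In this picture, q(x[begin])..q(x[end]) corresponds to the split case, with one cluster per element of x.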