Creating your own custom nodes
Welcome to the RxInfer documentation on creating custom factor graph nodes. In RxInfer, factor nodes represent functional relationships between variables, also known as factors. Together, these factors define your probabilistic model. Quite often these factors represent distributions, denoting how a certain parameter affects another. However, other factors are also possible, such as ones specifying linear or non-linear relationships. RxInfer already supports many factor nodes. However, depending on the problem that you are trying to solve, you may need to create a custom node that better fits the specific requirements of your model. This tutorial will guide you through the process of defining a custom node in RxInfer, step by step. By the end of this tutorial, you will be able to create your own custom node and integrate it into your model.
In addition, read another section on a different way of running inference with custom stochastic nodes without explicit rule specification here.
To create a custom node in RxInfer, 4 steps are required:
- Create your custom node in RxInfer using the @node macro.
- Define the corresponding message passing update rules with the @rule macro. These rules specify how the node processes information in the form of messages, and how it communicates the results to adjacent parts of the model.
- Specify computations for marginal distributions of the relevant variables with the @marginalrule macro.
- Implement the computation of the Free Energy in a node with the @average_energy macro.
Throughout this tutorial, we will create a node for the Bernoulli distribution. The Bernoulli distribution is commonly used in statistical modeling to model a binary outcome, such as a coin flip. By recreating this node, we will be able to demonstrate the process of creating a custom node, from notifying RxInfer of the node's existence to implementing the required methods. While this tutorial focuses on the Bernoulli distribution, the principles can be applied to creating custom nodes for other distributions as well. So let's get started!
Problem statement
Jane wants to determine whether a coin is a fair coin, meaning that it is equally likely to land on heads or tails. In order to determine this, she will throw the coin $K=20$ times and write down how often it lands on heads and tails. The result of this experiment is a realization of the underlying stochastic process. Jane models the outcome of the experiment $x_k\in\{0,1\}$ using the Bernoulli distribution as
\[p(x_k \mid \pi) = \mathrm{Ber}(x_k \mid \pi) = \pi^{x_k} (1-\pi)^{1-x_k},\]
where $\pi \in[0,1]$ denotes the probability that she throws heads, also known as the success probability. Jane also has a prior belief (initial guess) about the value of $\pi$ which she models using the Beta distribution as
\[p(\pi) = \mathrm{Beta}(\pi \mid 4, 8).\]
With this prior belief, the total probabilistic model that she has for this experiment is given by
\[p(x_{1:K}, \pi) = p(\pi) \prod_{k=1}^K p(x_k \mid \pi).\]
Jane is interested in determining the fairness of the coin. Therefore she aims to infer (calculate) the posterior belief of $\pi$, $p(\pi \mid x_{1:K})$, denoting how $\pi$ is distributed after we have seen the data.
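Because the Beta prior is conjugate to the Bernoulli likelihood, this posterior has a closed form, which is a useful reference point for the inference results obtained later in the tutorial. A short derivation, writing $h = \sum_{k=1}^K x_k$ for the number of heads (a symbol introduced here for convenience):
\[\begin{aligned}
p(\pi \mid x_{1:K}) &\propto p(\pi) \prod_{k=1}^K p(x_k \mid \pi) \\
&\propto \pi^{4-1} (1-\pi)^{8-1} \, \pi^{h} (1-\pi)^{K-h} \\
&\propto \mathrm{Beta}(\pi \mid 4 + h, 8 + K - h).
\end{aligned}\]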
Step 1: Creating the custom node
In this example we will assume that the Bernoulli node and distribution do not yet exist. In reality, RxInfer already defines the node for the Bernoulli distribution from the Distributions.jl package.
First things first, let's import RxInfer:
using RxInfer
In order to define a custom node using the @node macro from ReactiveMP, we need the following three arguments:
- The name of the node.
- Whether the node is Deterministic or Stochastic.
- The interfaces of the node and any potential aliases.
For the name of the node we wish to use MyBernoulli in this tutorial (Bernoulli already exists). However, the corresponding distribution does not yet exist. Therefore we need to specify it first as
# struct for Bernoulli distribution with success probability π
# (the outcome is binary, so the distribution is discrete)
struct MyBernoulli{T <: Real} <: DiscreteUnivariateDistribution
    π :: T
end
# for simplicity, let's also specify the mean of the distribution
Distributions.mean(d::MyBernoulli) = d.π
You can use regular functions, e.g. +, as a node type. Their Julia type, however, is written with the typeof(_) specification, e.g. typeof(+).
For our node we are dealing with a stochastic node, because the node forms a probabilistic relationship. This means that for a given value of $\pi$, we do not know the corresponding value of the output, but we do have some belief about it. Deterministic nodes include, for example, linear and non-linear transformations, such as + or *.
The interfaces specify which variables are connected to the node. By convention, the first interface is the output. The ordering is important for both the model specification and the rule definitions. As an example, consider the NormalMeanVariance factor node. This factor node has interfaces [out, μ, v] and can be called in the model specification language as x ~ NormalMeanVariance(μ, v). It is also possible to use aliases for the interfaces, which can be specified in a tuple as you will see below.
Concluding, we can create the MyBernoulli factor node as
@node MyBernoulli Stochastic [out, (π, aliases = [p])]
Cool! Step 1 is done, we have created a custom node.
Step 2: Defining rules for our node
In order for RxInfer to perform probabilistic inference and compute posterior distributions, such as $p(\pi\mid x_{1:K})$, we need to tell it how to perform inference locally around our node. This localization is what makes RxInfer achieve high performance. In our message passing-based paradigm, we need to describe how the node processes incoming information in the form of messages (or marginals). Here we will highlight two different message passing strategies: sum-product message passing and variational message passing.
Sum-product message passing update rules
In sum-product message passing we compute outgoing messages from our node as
\[\vec{\mu}(x) \propto \int \mathrm{Ber}(x\mid \pi) \vec{\mu}(\pi) \mathrm{d}\pi\]
\[\overleftarrow{\mu}(\pi) \propto \sum_{x \in \{0,1\}} \mathrm{Ber}(x\mid \pi) \overleftarrow{\mu}(x)\]
The integral does not always have a tractable solution. However, for some forms of the incoming messages, it does.
For the case of a Beta message coming into our node, the outgoing message will be the predictive distribution of the Bernoulli distribution with a Beta prior. Here we obtain $\pi = \frac{\alpha}{\alpha + \beta}$, which coincides with the mean of the Beta distribution. Hence, we can write down the first update rule using the @rule macro as
@rule MyBernoulli(:out, Marginalisation) (m_π :: Beta,) = MyBernoulli(mean(m_π))
Here, :out refers to the interface of the outgoing message. The second argument denotes the incoming messages (which can be typed) as a tuple. Therefore make sure that it has a trailing , when there is a single message coming in. m_π is shorthand for the incoming message on interface π. As we will see later, the structured approximation update rule, which uses the marginal on π instead of the incoming message, will have q_π as its parameter.
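The Beta rule above relies on the integral $\int \mathrm{Ber}(x \mid \pi)\,\mathrm{Beta}(\pi \mid \alpha, \beta)\,\mathrm{d}\pi$ evaluating to a Bernoulli distribution whose parameter is the Beta mean. The following sketch checks this numerically for $x = 1$ in plain Julia, without RxInfer. The helper names `beta_kernel` and `integrate01` are hypothetical and introduced only for this illustration, and the values 4 and 8 are simply borrowed from the prior in the problem statement.

```julia
# Unnormalised Beta(a, b) density; its normaliser cancels in the ratio below.
beta_kernel(p, a, b) = p^(a - 1) * (1 - p)^(b - 1)

# Midpoint-rule approximation of the integral of f over [0, 1].
function integrate01(f; n = 100_000)
    h = 1 / n
    return h * sum(f((i - 0.5) * h) for i in 1:n)
end

a, b = 4.0, 8.0
normaliser = integrate01(p -> beta_kernel(p, a, b))

# Sum-product message towards `out`, evaluated at x = 1:
# ∫ Ber(1 | p) Beta(p | a, b) dp = ∫ p Beta(p | a, b) dp
p_heads = integrate01(p -> p * beta_kernel(p, a, b)) / normaliser

println("numerical:   ", p_heads)
println("closed form: ", a / (a + b))
```

Both printed numbers should agree up to the accuracy of the numerical integration, which is why the rule can simply return MyBernoulli(mean(m_π)).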
The second rule is also straightforward; if π is a PointMass and therefore fixed, the outgoing message will be MyBernoulli(π):
@rule MyBernoulli(:out, Marginalisation) (m_π :: PointMass,) = MyBernoulli(mean(m_π))
Continuing with the sum-product update rules, we now have to define the update rules towards the π interface. We can only do exact inference if the incoming message is known, which in the case of the Bernoulli distribution means that the out message is a PointMass distribution that is either 0 or 1. The updated Beta distribution for π will be:
\[\overleftarrow{\mu}(\pi) \propto \mathrm{Beta}(\pi \mid 1 + x, 2 - x)\]
which gives us the following update rule:
@rule MyBernoulli(:π, Marginalisation) (m_out :: PointMass,) = begin
    p = mean(m_out)
    return Beta(one(p) + p, 2one(p) - p)
end
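To see where this Beta message comes from, substitute a message that puts all its mass on the observed value $\hat{x} \in \{0, 1\}$ (notation introduced here) into the sum-product rule towards $\pi$. The sum then collapses to a single term:
\[\overleftarrow{\mu}(\pi) \propto \mathrm{Ber}(\hat{x} \mid \pi) = \pi^{\hat{x}} (1-\pi)^{1-\hat{x}} = \pi^{(1+\hat{x})-1} (1-\pi)^{(2-\hat{x})-1} \propto \mathrm{Beta}(\pi \mid 1+\hat{x}, 2-\hat{x}).\]
For heads ($\hat{x} = 1$) this gives $\mathrm{Beta}(\pi \mid 2, 1)$ and for tails ($\hat{x} = 0$) it gives $\mathrm{Beta}(\pi \mid 1, 2)$.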
Variational message passing update rules
We will now cover our second set of update rules. The sum-product messages are not always tractable and therefore we may need to resort to approximations. Here we highlight the variational approximation. In variational message passing we compute outgoing messages from our node as
\[\vec{\nu}(x) \propto \exp \int q(\pi) \ln \mathrm{Ber}(x\mid \pi) \mathrm{d}\pi\]
\[\overleftarrow{\nu}(\pi) \propto \exp \sum_{x \in \{0,1\}} q(x) \ln \mathrm{Ber}(x\mid \pi)\]
These messages depend on the marginals on the adjacent edges and not on the incoming messages as was the case with sum-product message passing. Update rules that operate on the marginals instead of the incoming messages are specified with the q_{interface} argument names. With these update rules, we can often support a wider family of distributions. Below we directly give the variational update rules. Deriving them yourself will be a nice challenge.
# rules towards out
@rule MyBernoulli(:out, Marginalisation) (q_π :: PointMass,) = MyBernoulli(mean(q_π))
@rule MyBernoulli(:out, Marginalisation) (q_π :: Any,) = begin
    rho_1 = mean(log, q_π)       # E[ln(π)]
    rho_2 = mean(mirrorlog, q_π) # E[ln(1 - π)]
    # subtract the maximum before exponentiating for numerical stability
    m = max(rho_1, rho_2)
    tmp = exp(rho_1 - m)
    p = clamp(tmp / (tmp + exp(rho_2 - m)), tiny, one(m))
    return MyBernoulli(p)
end
# rules towards π
@rule MyBernoulli(:π, Marginalisation) (q_out :: Any,) = begin
    p = mean(q_out)
    return Beta(one(p) + p, 2one(p) - p)
end
Typically, the type of the variational distributions q_ does not matter in the real computations, but only their statistics, e.g. mean or var. Thus, in this case, we may safely use ::Any.
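The variational rule towards out normalises two exponentiated expectations, and it subtracts their maximum before calling exp. The sketch below, in plain Julia and independent of RxInfer, illustrates why that shift matters. The function names `naive_probability` and `stable_probability` are hypothetical, and the input values are arbitrary stand-ins for the two expected log terms.

```julia
# Naive normalisation: exponentiate first, then normalise.
naive_probability(rho_1, rho_2) = exp(rho_1) / (exp(rho_1) + exp(rho_2))

# Stable normalisation: shift both log-weights by their maximum before
# exponentiating, so the largest term becomes exp(0) = 1.
function stable_probability(rho_1, rho_2)
    m = max(rho_1, rho_2)
    w_1 = exp(rho_1 - m)
    w_2 = exp(rho_2 - m)
    return w_1 / (w_1 + w_2)
end

# Moderate values: both versions agree.
println(naive_probability(-1.0, -2.0), " vs ", stable_probability(-1.0, -2.0))

# Extreme values: exp(-1000.0) underflows to 0.0, so the naive version is 0/0.
println(naive_probability(-1000.0, -1001.0), " vs ", stable_probability(-1000.0, -1001.0))
```

The shift leaves the ratio unchanged mathematically, because the common factor exp(-m) cancels between numerator and denominator.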
In the example that we will show later on, we solely use sum-product message passing. Variational message passing requires us to set the local constraints in our model, something which is out of scope of this tutorial.
Step 3: Defining joint marginals for our node
The entire probabilistic model can be scored using the Bethe free energy, which bounds the negative log-evidence from above for acyclic graphs. The Bethe free energy consists of the sum of negative node-local entropies, node-local average energies and edge-specific entropies. Formally we can denote this by
\[F[q,f] = - \sum_{a\in\mathcal{V}} \mathrm{H}[q_a(s_a)] - \sum_{a\in\mathcal{V}}\mathrm{E}_{q_a(s_a)}[\ln f_a(s_a)] + \sum_{i\in\mathcal{E}}\mathrm{H}[q_i(s_i)]\]
Here we call $q_a(s_a)$ the joint marginals around a node and $-\mathrm{E}_{q_a(s_a)}[\ln f_a(s_a)]$ we term the average energy.
In order to be able to compute the Bethe free energy, we first need to describe how to compute $q_a(s_a)$, defined in our case (up to normalization) as
\[q(x_k, \pi) \propto \vec{\mu}(\pi) \overleftarrow{\mu}(x_k) \mathrm{Ber}(x_k \mid \pi)\]
To calculate the updated posterior marginal for our custom distribution, we need to return joint posterior marginals for the interfaces of our node. In our case, the posterior marginal for the observation is still the same PointMass distribution. However, to calculate the posterior marginal over π, we use RxInfer's built-in prod functionality to multiply the Beta prior with the Beta likelihood. This gives us the updated posterior distribution, which is also a Beta distribution. We use the PreserveTypeProd(Distribution) parameter to ensure that we multiply the two distributions analytically. This is done as follows:
@marginalrule MyBernoulli(:out_π) (m_out::PointMass, m_π::Beta) = begin
    r = mean(m_out)
    p = prod(PreserveTypeProd(Distribution), Beta(one(r) + r, 2one(r) - r), m_π)
    return (out = m_out, p = p)
end
In this code :out_π describes the arguments of the joint marginal distribution. The second argument contains the incoming messages. Here we know from the model specification that we observe out and therefore this has to be a PointMass. Because it is a PointMass, the joint marginal automatically factorizes as $q(x_k, \pi) = q(x_k)q(\pi)$. These are the distributions that we return in the form of a NamedTuple. A NamedTuple is used only in cases where we know that the joint marginal factorizes further; typically the result should be a full joint distribution. For computing $q(\pi)$ we need to compute the product $\vec{\mu}(\pi)\overleftarrow{\mu}(\pi)$. We already know what $\overleftarrow{\mu}(\pi)$ looks like from the previous step, so we can just use the prod function.
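The product of the two Beta messages can be worked out by hand, which shows what the analytical product is expected to return. Assuming the incoming message on π is $\vec{\mu}(\pi) = \mathrm{Beta}(\pi \mid \alpha, \beta)$ (the symbols $\alpha$ and $\beta$ stand for whatever parameters that message carries) and the observed value is $x_k$:
\[\begin{aligned}
q(\pi) &\propto \mathrm{Beta}(\pi \mid \alpha, \beta) \, \mathrm{Beta}(\pi \mid 1 + x_k, 2 - x_k) \\
&\propto \pi^{\alpha - 1} (1-\pi)^{\beta - 1} \, \pi^{x_k} (1-\pi)^{1 - x_k} \\
&\propto \mathrm{Beta}(\pi \mid \alpha + x_k, \beta + 1 - x_k).
\end{aligned}\]
In words, each observed head adds one to the first parameter and each observed tail adds one to the second.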
Step 4: Defining the average energy for our node
To complete the computation of the Bethe free energy, we also need to compute the average energy term. The average energy in our MyBernoulli example can be computed as $-\mathrm{E}_{q(x_k, \pi)}[\ln p(x_k \mid \pi)]$. However, because we know that we observe $x_k$ and therefore $q(x_k, \pi)$ factorizes, we can instead compute
\[\begin{aligned}
-\mathrm{E}_{q(x_k)q(\pi)}[\ln p(x_k \mid \pi)] &= -\mathrm{E}_{q(x_k)q(\pi)} [\ln (\pi^{x_k} (1-\pi)^{1 - x_k})] \\
&= -\mathrm{E}_{q(x_k)q(\pi)} [x_k \ln(\pi) + (1-x_k) \ln(1-\pi)] \\
&= -\mathrm{E}_{q(x_k)}[x_k] \mathrm{E}_{q(\pi)} [\ln(\pi)] - (1-\mathrm{E}_{q(x_k)}[x_k]) \mathrm{E}_{q(\pi)}[\ln(1-\pi)]
\end{aligned}\]
This is what we implement below. Note that mean(mirrorlog, q(x)) is equal to $\mathrm{E}_{q(x)}[\ln(1-x)]$.
@average_energy MyBernoulli (q_out::Any, q_π::Any) = -mean(q_out) * mean(log, q_π) - (1.0 - mean(q_out)) * mean(mirrorlog, q_π)
In the case that the interfaces do not factorize, we would get something like @average_energy MyBernoulli (q_out_π::Any,) = begin ... end.
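To get a feel for the numbers behind the factorized average energy, the sketch below evaluates it in plain Julia for a Beta marginal with integer parameters. For integers, the digamma identity $\psi(a) - \psi(a+b) = -\sum_{k=a}^{a+b-1} 1/k$ gives the expected log terms without any special-function package. The helper names `mean_log`, `mean_mirrorlog` and `bernoulli_average_energy` are hypothetical stand-ins for this illustration only, and the parameters 4 and 8 are borrowed from the prior in the problem statement.

```julia
# E[ln π] under Beta(a, b) for integer a, b:
# ψ(a) - ψ(a + b) = -(1/a + 1/(a+1) + ... + 1/(a+b-1))
mean_log(a::Integer, b::Integer) = -sum(1 / k for k in a:(a + b - 1))

# E[ln(1 - π)] under Beta(a, b) follows by swapping the roles of a and b.
mean_mirrorlog(a::Integer, b::Integer) = mean_log(b, a)

# Average energy for a factorized marginal q(x) q(π), with x_mean = E[x].
bernoulli_average_energy(x_mean, a, b) =
    -x_mean * mean_log(a, b) - (1 - x_mean) * mean_mirrorlog(a, b)

U_heads = bernoulli_average_energy(1.0, 4, 8)  # observation x = 1
U_tails = bernoulli_average_energy(0.0, 4, 8)  # observation x = 0
println("U(heads) = ", U_heads, ", U(tails) = ", U_tails)
```

With a Beta(4, 8) marginal, which favours tails, observing heads yields the larger average energy, as one would expect from a term that measures how poorly the observation fits the current belief.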
Using our node in a model
With all the necessary functions defined, we can proceed to test our custom node in an experiment. For this experiment, we will generate a dataset from a Bernoulli distribution with a fixed success probability of 0.75. Next, we will define a probabilistic model that has a Beta prior and a MyBernoulli likelihood. The Beta prior will be used to model our prior belief about the probability of success. The MyBernoulli likelihood will be used to model the generative process of the observed data. We start by generating the dataset:
using Random
rng = MersenneTwister(42)
n = 500
π_real = 0.75
distribution = Bernoulli(π_real)
dataset = float.(rand(rng, distribution, n))
Next, we define our model. Note that we use the MyBernoulli node in the model. The model consists of a single latent variable π, which has a Beta prior and is the parameter of the MyBernoulli likelihood. The MyBernoulli node takes the value of π as its parameter and returns a binary observation. We set the hyperparameters of the Beta prior to 4 and 8, respectively, which corresponds to a distribution biased towards lower values of π (its mean is $4/(4+8) = 1/3$). The model is defined as follows:
@model function coin_model_mybernoulli(y)
    # We endow the π parameter of our model with some prior
    π ~ Beta(4.0, 8.0)
    # We assume that outcome of each coin flip is governed by the MyBernoulli distribution
    for i in eachindex(y)
        y[i] ~ MyBernoulli(π)
    end
end
Finally, we can run inference with this model and the generated dataset:
result_mybernoulli = infer(
    model = coin_model_mybernoulli(),
    data = (y = dataset, ),
)
Inference results:
Posteriors | available for (π)
We have now completed our experiment and obtained the posterior marginal distribution for π through inference. To evaluate the performance of our inference, we can compare the estimated posterior to the true value. In our experiment, the true value for π is 0.75, and we can see that the estimated posterior has a mean close to this value, which shows that our custom node was able to successfully pass messages towards the π variable in order to learn the true value of the parameter.
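Since the Beta prior is conjugate to the Bernoulli likelihood, the posterior can also be computed by simply counting outcomes, which gives an independent reference for what inference should return. The sketch below does this in plain Julia. Note that `flips` is a hypothetical stand-in generated with base Julia's random number generator, so it is not the same sequence as `dataset` above, and the variable names are chosen only for this illustration.

```julia
using Random

rng = MersenneTwister(42)
# Hypothetical stand-in for `dataset`: 500 flips with success probability 0.75.
flips = Float64.(rand(rng, 500) .< 0.75)

heads = sum(flips)

# Conjugate update of the Beta(4, 8) prior: heads increment the first
# parameter, tails increment the second.
a_post = 4.0 + heads
b_post = 8.0 + length(flips) - heads
post_mean = a_post / (a_post + b_post)

println("Posterior: Beta(", a_post, ", ", b_post, "), mean = ", post_mean)
```

The same counting argument applied to `dataset` gives the Beta parameters that the posterior from infer should match, up to floating-point details.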
using Plots
rθ = range(0, 1, length = 1000)
p = plot(title = "Inference results")
plot!(rθ, (x) -> pdf(result_mybernoulli.posteriors[:π], x), fillalpha=0.3, fillrange = 0, label="p(π|x)", c=3)
vline!([π_real], label="Real π")
As a sanity check, we can create the same model with the RxInfer built-in node Bernoulli and compare the resulting posterior distribution with the one obtained using our custom MyBernoulli node. This will give us confidence that our custom node is working correctly. We use the Bernoulli node with the same Beta prior and the observed data, and then run inference. We can compare the two posterior distributions and observe that they are exactly the same, which indicates that our custom node is performing as expected.
@model function coin_model(y)
    p ~ Beta(4.0, 8.0)
    for i in eachindex(y)
        y[i] ~ Bernoulli(p)
    end
end
result_bernoulli = infer(
    model = coin_model(),
    data = (y = dataset, ),
)
if !(result_bernoulli.posteriors[:p] == result_mybernoulli.posteriors[:π])
    error("Results are not identical")
else
    println("Results are identical 🎉🎉🎉")
end
Results are identical 🎉🎉🎉
Congratulations! You have successfully implemented your own custom node in RxInfer. We went from the definition of a node to the implementation of the update rules and marginal posterior calculations. Finally we tested our custom node in a model and checked that we implemented everything correctly.
Custom node experimental functionality
The functionality described below is experimental and subject to change in future releases. Use it with caution in production code.
Rules that require a reference to a node object
In some advanced scenarios, you might need access to the node object itself within a message passing rule. This can be useful when:
- You need to inspect the current state of other variables in the model
- You want to implement complex message passing schemes that depend on the global model state
- You're experimenting with custom inference algorithms that require access to the factor graph structure
Here's how to implement a rule with node access. First we define a custom node and a simple model that uses this node:
using RxInfer
struct MyExperimentalNode end
@node MyExperimentalNode Stochastic [ out, θ ]
@model function my_experimental_model(y)
    θ ~ Normal(mean = 0.0, variance = 1.0)
    y ~ MyExperimentalNode(θ)
end
Second, we instruct the inference backend to pass a reference to the node into the rule.
# Enable node reference passing for this node type
ReactiveMP.call_rule_is_node_required(::Type{<:MyExperimentalNode}) = ReactiveMP.CallRuleNodeRequired()
Enabling node reference passing can negatively impact performance as it requires additional bookkeeping during inference.
Setting call_rule_is_node_required for existing nodes (like NormalMeanVariance) affects all models globally, including code that depends on your package. It is only safe to set this for your own custom nodes.
The call_rule_is_node_required function is used to instruct the inference backend to pass the node object to the rule. After this is set, we can use the getnode() function to access the node object within the rule.
@rule MyExperimentalNode(:θ, Marginalisation) (q_out::Any, ) = begin
    node = getnode()
    # Access interface index
    ii = ReactiveMP.interfaceindex(node, :θ)
    # Get interface object
    θi = ReactiveMP.getinterfaces(node)[ii]
    # Get variable object
    θv = ReactiveMP.getvariable(θi)
    # By default, `getmarginal` ignores marginals set in the @initialization block
    # `IncludeAll` overrides this behavior and includes all marginals
    qθ = Rocket.getrecent(ReactiveMP.getmarginal(θv, IncludeAll()))
    # This is a simple rule that returns a NormalMeanVariance distribution
    # It could be replaced with any other rule that returns a distribution
    return NormalMeanVariance(mean(qθ) + mean(q_out), var(qθ))
end
Running inference with the custom node and rule
Here's a full example showing how to use this functionality:
initialization = @initialization begin
    q(θ) = NormalMeanVariance(3.14, 2.71)
end
result = infer(
    model = my_experimental_model(),
    data = (y = 1.0, ),
    initialization = initialization
)
Because the rule reads the initialized marginal of θ through the node object, a successful inference run indicates that node reference passing is working as expected. This feature opens up possibilities for advanced inference scenarios, but should be used judiciously. Consider whether your use case truly requires access to the node object, as simpler solutions using standard message passing rules are often sufficient and more maintainable.