# Inference without explicit message update rules

`RxInfer`

utilizes the `ReactiveMP.jl`

package as its inference backend. Typically, running inference with `ReactiveMP.jl`

requires users to define a factor node using the `@node`

macro and specify corresponding message update rules with the `@rule`

macro. Detailed instructions on this can be found in this section of the documentation. However, in this tutorial, we will explore an alternative approach that allows inference with default message update rule for custom factor nodes by defining only `BayesBase.logpdf`

and `BayesBase.insupport`

for a factor node, without needing explicit `@rule`

specifications.

In the context of message-passing based Bayesian inference, custom message update rules enhance precision and efficiency. These rules leverage the specific mathematical properties of the model's distributions and relationships, leading to more accurate updates and faster convergence. By incorporating domain-specific knowledge, custom rules improve the robustness and reliability of the inference process, particularly in complex models where default rules may be inadequate or inefficient.

## A simple prior-likelihood model

We start a simple model with a hidden variable `p`

and observations `y`

. Later in the tutorial we explore more advanced use-cases. In this particular case we assume that `p`

follows a prior distribution and `y`

are drawn from a likelihood distribution. The model can be defined as follows:

```
using RxInfer
@model function simple_model(y, prior, likelihood)
p ~ prior
y .~ likelihood(p)
end
```

## Node specifications

Next, we define structures for both the `prior`

and the `likelihood`

. Let's start with the prior. Assume that the `p`

parameter is best described by a `Beta`

distribution. We can define it as follows:

The `Distributions.jl`

package already provides a fully-featured implementation of `Beta`

and `Bernoulli`

distributions, including functions like `logpdf`

and support checks. The example below redefines the `Beta`

distribution structure and related functions solely for illustrative purposes. In practice, you often won't need to define these distributions yourself, as many of them has already been included in `Distributions.jl`

.

```
using Distributions, BayesBase
struct BetaDistribution{A, B} <: ContinuousUnivariateDistribution
a::A
b::B
end
# Reuse `logpdf` from `Distributions.jl` for illustrative purposes
BayesBase.logpdf(d::BetaDistribution, x) = logpdf(Beta(d.a, d.b), x)
BayesBase.insupport(d::BetaDistribution, x::Real) = 0 <= x <= 1
```

Next, we assume that `y`

is a discrete dataset of `true`

and `false`

values. The logical choice for the `likelihood`

distribution is the `Bernoulli`

distribution.

```
struct BernoulliDistribution{P} <: DiscreteUnivariateDistribution
p::P
end
# Reuse `logpdf` from `Distributions.jl` for illustrative purposes
BayesBase.logpdf(d::BernoulliDistribution, x) = logpdf(Bernoulli(d.p), x)
BayesBase.insupport(d::BernoulliDistribution, x) = x === true || x === false
```

The next step is to register these structures as valid factor nodes:

```
@node BetaDistribution Stochastic [out, a, b]
@node BernoulliDistribution Stochastic [out, p]
```

When specifying a node for our custom distributions, we must follow a specific edge ordering. The first edge is always `out`

, which represents a sample in the `logpdf`

function. All remaining edges must match the parameters of the distribution in the exact same order. For example, for the `BetaDistribution`

, the node function is defined as `(out, a, b) -> logpdf(BetaDistribution(a, b), out)`

. This ensures that the node specification and the `logpdf`

function correctly maps the distribution parameters to the sample output.

Although `Beta`

is a conjugate prior for the parameter of the `Bernoulli`

distribution, `ReactiveMP`

and `RxInfer`

are unaware of this and cannot exploit this information. To utilize conjugacy, refer to the custom node creation section of the documentation.

## Generating a synthetic dataset

Previously, we assumed that our dataset consists of discrete values: `true`

and `false`

. We can generate a synthetic dataset with these values as follows:

```
using StableRNGs, Plots
hidden_p = 1 / 3.1415 # a value between `0` and `1`
ndatapoints = 1_000 # number of observarions
dataset = rand(StableRNG(42), Bernoulli(hidden_p), ndatapoints)
bar(["true", "false"], [ count(==(true), dataset), count(==(false), dataset) ], label = "dataset")
```

## Inference with a rule fallback

Now, we can run inference with `RxInfer`

. Since explicit rules for our nodes have not defined, we can instruct the `ReactiveMP`

backend to use fallback message update rules. Refer to the `ReactiveMP`

documentation for available fallbacks. In this example, we will use the `NodeFunctionRuleFallback`

structure, which uses the `logpdf`

of the stochastic node to approximate messages.

`NodeFunctionRuleFallback`

employs a simple approximation for outbound messages, which may significantly degrade inference accuracy. Whenever possible, it is recommended to define proper message update rules.

To complete the inference setup, we must define an approximation method for posteriors using the `@constraints`

macro. We will utilize the `ExponentialFamilyProjection`

library to project an arbitrary function onto a member of the exponential family. More information on `ExponentialFamilyProjection`

can be found in the Non-conjugate Inference section and in its official documentation.

```
using ExponentialFamilyProjection
@constraints function projection_constraints()
# Use `Beta` from `Distributions.jl` as it is compatible with the `ExponentialFamilyProjection` library
q(p) :: ProjectedTo(Beta)
end
```

`projection_constraints (generic function with 1 method)`

With all components ready, we can proceed with the inference procedure:

```
result = infer(
model = simple_model(prior = BetaDistribution(1, 1), likelihood = BernoulliDistribution),
data = (y = dataset, ),
constraints = projection_constraints(),
options = (
rulefallback = NodeFunctionRuleFallback(),
)
)
```

```
Inference results:
Posteriors | available for (p)
```

For `rulefallback = NodeFunctionRuleFallback()`

to function correctly, the node must be defined as `Stochastic`

and the underlying object must be a subtype of `Distribution`

from `Distributions.jl`

.

### Result analysis

We can perform a simple analysis and compare the inferred value with the hidden value used to generate the actual dataset:

```
using Plots, StatsPlots
plot(result.posteriors[:p], label = "posterior of p", fill = 0, fillalpha = 0.2)
vline!([ hidden_p ], label = "hidden p")
```

As shown, the estimated posterior is quite close to the actual hidden value of `p`

used during the inference procedure.

## Fusing deterministic transformations with stochastic nodes

One of the limitations of the `NodeFunctionRuleFallback`

implementation is that it does not support `Deterministic`

or `Delta`

nodes. However, it is possible to combine a deterministic transformation with a stochastic node, such as `Gaussian`

. For instance, consider a dataset drawn from the `Normal`

distribution, where the `mean`

parameter has been transformed by a known function, and the true hidden variable is `h`

.

```
using ExponentialFamily, Distributions, Plots, StableRNGs
hidden_h = 2.3
hidden_t = 0.5
known_transformation(h) = exp(h)
hidden_mean = known_transformation(hidden_h)
ndatapoints = 50
dataset = rand(StableRNG(42), NormalMeanPrecision(hidden_mean, hidden_t), ndatapoints)
histogram(dataset; normalize = :pdf)
```

The model can be defined as follows:

```
using RxInfer
@model function mymodel(y, prior_h, prior_t)
h ~ prior_h
t ~ prior_t
y .~ Normal(mean = known_transformation(h), precision = t)
end
```

Inference in this model is challenging because the `known_transformation`

function is explicitly used as a factor node, requiring special approximation rules. These rules are covered in a separate section. Here, we demonstrate a different approach that modifies the model structure to run inference without needing to approximate messages around a deterministic node.

First, we define our custom transformed `Normal`

distribution:

```
using BayesBase
struct TransformedNormalDistribution{H, T} <: ContinuousUnivariateDistribution
h::H
t::T
end
# We integrate the `known_transformation` within the `logpdf` function
# This way, it won't be an explicit factor node but hidden within the `logpdf` of another node
BayesBase.logpdf(dist::TransformedNormalDistribution, x) = logpdf(NormalMeanPrecision(known_transformation(dist.h), dist.t), x)
BayesBase.insupport(dist::TransformedNormalDistribution, x) = true
@node TransformedNormalDistribution Stochastic [out, h, t]
```

Next, we tweak the model structure:

```
@model function mymodel(y, prior_h, prior_t)
h ~ prior_h
t ~ prior_t
y .~ TransformedNormalDistribution(h, t)
end
```

We use the following priors, constraints, and initialization:

```
using ExponentialFamilyProjection
prior_h = LogNormal(0, 1)
prior_t = Gamma(1, 1)
constraints = @constraints begin
q(h, t) = q(h)q(t)
q(h) :: ProjectedTo(LogNormal)
q(t) :: ProjectedTo(Gamma)
end
initialization = @initialization begin
q(t) = Gamma(1, 1)
end
```

```
Initial state:
q(t) = Gamma{Float64}(α=1.0, θ=1.0)
```

The `ProjectedTo`

macro has a `parameters`

field that allows for different hyperparameters, which may improve accuracy or convergence speed. Refer to the `ExponentialFamilyProjection`

documentation for more information.

## Inference with a rule fallback

Now we are ready to run the inference procedure:

```
result = infer(
model = mymodel(prior_h = prior_h, prior_t = prior_t),
data = (y = dataset,),
constraints = constraints,
initialization = initialization,
iterations = 50,
options = (
rulefallback = NodeFunctionRuleFallback(),
)
)
```

```
Inference results:
Posteriors | available for (h, t)
```

### Result analysis

Finally, let's plot the resulting posteriors for each VMP iteration:

```
@gif for (i, q) in enumerate(zip(result.posteriors[:h], result.posteriors[:t]))
p1 = plot(1:0.01:3, q[1], label = "q(h) iteration $i", fill = 0, fillalpha = 0.2)
p1 = vline!([hidden_h], label = "hidden h")
p2 = plot(0:0.01:1, q[2], label = "q(t) iteration $i", fill = 0, fillalpha = 0.2)
p2 = vline!([hidden_t], label = "hidden t")
plot(p1, p2)
end fps = 15
```

We can see that the inference results are able to recover the actual value of hidden `h`

that has been used to generate the synthetic dataset. In conclusion, this example demonstrates that by integrating deterministic transformations within the logpdf function of a stochastic node, we can bypass the limitations of `NodeFunctionRuleFallback`

in handling deterministic nodes.