# Model Specification

`RxInfer` largely depends on `GraphPPL` for model specification. Read the extensive documentation on model specification in the corresponding section of the `GraphPPL` documentation. Here we outline only a small portion of the model specification capabilities for beginners.

## `@model` macro

The `RxInfer.jl` package exports the `@model` macro for model specification. This `@model` macro accepts the model specification itself in the form of a regular Julia function.

For example:

```
@model function model_name(model_arguments...)
    # model specification here
end
```

where `model_arguments...` may include both hyperparameters and data.

`model_arguments` are converted to keyword arguments. Positional arguments in the model specification are not supported. Thus, it is not possible to use Julia's multiple dispatch for the model arguments.
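
Julia dispatches only on positional arguments, so a function whose arguments are all keywords cannot have separate methods for different argument types. The following plain-Julia sketch (independent of `RxInfer`; `model_kind` is a hypothetical function) illustrates why keyword-only model arguments rule out multiple dispatch:

```julia
# Plain Julia, no RxInfer involved. `model_kind` is a hypothetical function.
# Keyword arguments do not take part in method dispatch: the second
# definition below replaces the first one instead of adding a new method.
model_kind(; x) = "generic"
model_kind(; x::Int) = "integer"

println(length(methods(model_kind)))   # 1
println(model_kind(x = 1))              # integer

# A non-Int value is now rejected instead of reaching the "generic" version
err = try
    model_kind(x = 1.5)
catch e
    e
end
println(err isa TypeError)              # true
```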

The `@model` macro returns a regular Julia function (in this example `model_name()`) which can be executed as usual. The only difference is that all arguments of the model function are treated as keyword arguments. Upon calling, the model function returns a so-called model generator object, e.g.:

```
@model function my_model(observation, hyperparameter)
    observation ~ Normal(mean = 0.0, var = hyperparameter)
end
```

`model = my_model(hyperparameter = 3)`

The model generator is not a real model (yet). For example, in the code above, we haven't specified anything for the `observation`. The generator object allows us to iteratively add extra properties to the model, condition on data, and/or assign extra metadata without actually materializing the entire graph structure. Read more about the model generator in the corresponding section of the documentation.

## A state space model example

Here we give an example of a probabilistic model before presenting the details of the model specification syntax. The model below is a simple state space model with latent random variables `x` and noisy observations `y`.

```
@model function state_space_model(y, trend, variance)
    x[1] ~ Normal(mean = 0.0, variance = 100.0)
    y[1] ~ Normal(mean = x[1], variance = variance)
    for i in 2:length(y)
        x[i] ~ Normal(mean = x[i - 1] + trend, variance = 1.0)
        y[i] ~ Normal(mean = x[i], variance = variance)
    end
end
```

In this model, we assign a prior distribution over the latent state `x[1]`. All subsequent states `x[i]` depend on `x[i - 1]` and `trend`, and are modelled as a simple Gaussian random walk with drift `trend`. Observations `y` are modelled with the `Normal` (Gaussian) distribution as well, with a prespecified `variance` hyperparameter.

`length(y)` can be called only if `y` has data associated with it. This is not always the case: for example, it is possible to instantiate the model lazily, before the data becomes available. In such situations, `length(y)` will throw an error.

### Hyperparameters

Any constant passed to a model as a model argument will be automatically converted to a corresponding constant node in the model's graph.

`model = state_space_model(trend = 3.0, variance = 1.0)`

In this example, we instantiate a model generator with the `trend` and `variance` parameters *clamped* to `3.0` and `1.0`, respectively. This means that no inference will be performed for those parameters and that some of the expressions within the model structure might be simplified and compiled out.

### Conditioning on data

To fully complete the model specification, we need to specify `y`. In this example, `y` plays the role of observations. `RxInfer` provides a convenient mechanism to pass data values to the model with the `|` operator.

`conditioned = model | (y = [ 0.0, 1.0, 2.0 ], )`

```
state_space_model(trend = 3.0, variance = 1.0) conditioned on:
y = [0.0, 1.0, 2.0]
```

Conditioning on data is a feature of `RxInfer`, not `GraphPPL`.

In the example above, we conditioned on data in the form of a `NamedTuple`, but it is also possible to condition on a dictionary whose keys represent the names of the corresponding model arguments:

```
data = Dict(:y => [ 0.0, 1.0, 2.0 ])
conditioned = model | data
```

```
state_space_model(trend = 3.0, variance = 1.0) conditioned on:
y = [0.0, 1.0, 2.0]
```
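
Once observations are available, inference can be run on this model. A minimal sketch, assuming the `infer` function that is also used in the Node Contraction section of this page, and passing the observations through its `data` keyword instead of the `|` operator:

```julia
using RxInfer, Statistics

@model function state_space_model(y, trend, variance)
    x[1] ~ Normal(mean = 0.0, variance = 100.0)
    y[1] ~ Normal(mean = x[1], variance = variance)
    for i in 2:length(y)
        x[i] ~ Normal(mean = x[i - 1] + trend, variance = 1.0)
        y[i] ~ Normal(mean = x[i], variance = variance)
    end
end

# Hyperparameters go to the model generator, observations go to `data`
result = infer(
    model = state_space_model(trend = 3.0, variance = 1.0),
    data  = (y = [0.0, 1.0, 2.0], ),
)

# One posterior marginal per latent state x[1], x[2], x[3]
posterior_x = result.posteriors[:x]
println(mean.(posterior_x))
```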

Sometimes it might be useful to indicate that some arguments are data (and thus condition on them) before the actual data becomes available. This situation may occur during reactive inference, when data becomes available *after* model creation. `RxInfer` provides a special structure called `RxInfer.DeferredDataHandler`, which can be used instead of the real data.

For the example above, however, we cannot simply do the following:

`model | (y = RxInfer.DeferredDataHandler(), )`

because we use `length(y)` in the model, which is only possible if `y` has associated data. We can adjust the model specification slightly by adding an extra `n` parameter to the list of arguments:

```
@model function state_space_model_with_n(y, n, trend, variance)
    x[1] ~ Normal(mean = 0.0, variance = 100.0)
    y[1] ~ Normal(mean = x[1], variance = variance)
    for i in 2:n
        x[i] ~ Normal(mean = x[i - 1] + trend, variance = 1.0)
        y[i] ~ Normal(mean = x[i], variance = variance)
    end
end
```

For such a model, we can safely condition on `y` without providing actual data for it, using `RxInfer.DeferredDataHandler` instead:

```
state_space_model_with_n(trend = 3.0, variance = 1.0, n = 10) | (
    y = RxInfer.DeferredDataHandler(),
)
```

```
state_space_model_with_n(trend = 3.0, variance = 1.0, n = 10) conditioned on:
y = [ deffered data ]
```

Read more about conditioning on data in the corresponding section of the documentation.

### Latent variables

Latent variables are created with the `~` operator, which can be read as *is distributed as*. For example, to create a latent variable `y` that is modeled by a Normal distribution whose mean and variance are controlled by the random variables `m` and `v`, respectively, we define

`y ~ Normal(mean = m, variance = v)`

In the example above, `x[1] ~ Normal(mean = 0.0, variance = 100.0)` indicates that `x₁` follows a Normal distribution.

The `RxInfer.jl` package uses the `~` operator for modelling both stochastic and deterministic relationships between random variables. However, `GraphPPL.jl` also allows using the `:=` operator for deterministic relationships.

## Relationships between variables

In probabilistic models based on graphs, factor nodes are used to define a relationship between random variables and/or constants and data variables. A factor node defines a probability distribution over selected latent or data variables. The `~` operator not only creates a latent variable but also defines a functional relationship between it and other variables, creating a factor node as a result.

In the example above, `x[1] ~ Normal(mean = 0.0, variance = 100.0)` creates not only a latent variable `x₁` but also a factor node `Normal`.

Generally, it is not necessary to label all the arguments with their names, as in `mean = ...` or `variance = ...`, and many factor nodes do not require it. However, for nodes that have many different useful parametrizations (e.g. `Normal`), labeling the arguments is a requirement that helps to avoid any possible confusion. Read more about `Distributions.jl` compatibility in the corresponding section below.

### Deterministic relationships

In contrast to other probabilistic programming languages in Julia, `RxInfer` does not allow the use of the `=` operator for creating deterministic relationships between (latent) variables. Instead, we can use the `:=` operator for this purpose. For example:

```
t ~ Normal(mean = 0.0, variance = 1.0)
x := exp(t) # x is linked deterministically to t
y ~ Normal(mean = x, variance = 1.0)
```

Using `x = exp(t)` directly would be incorrect and would most likely result in a `MethodError`, because `t` does not have a definite value at model creation time (remember that our models create a factor graph under the hood, and latent states do not have a value until inference is performed). At model creation time, `t` holds a reference to a node in the graph instead of an actual value sampled from the `Normal` distribution.
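
The failure mode can be reproduced in plain Julia. In this sketch, `VariableReference` is a hypothetical stand-in for whatever `t` holds during model creation; it is not `RxInfer`'s actual internal type:

```julia
# Hypothetical stand-in for a graph variable: a reference, not a number
struct VariableReference
    name::Symbol
end

t = VariableReference(:t)

# `exp` has no method for such a reference, so `x = exp(t)` fails
err = try
    exp(t)
catch e
    e
end

println(err isa MethodError)   # true
```

The `:=` operator avoids this by recording `exp` as a deterministic factor node in the graph instead of calling it on `t` immediately.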

### Control flow statements

In general, it is possible to use any Julia code within the model specification function, including control flow statements such as `for`, `while`, and `if` statements. However, it is not possible to use any latent states within such statements. This is because the structure of the graph must be known exactly before inference. Thus, it is **not possible** to write statements like:

```
c ~ Categorical([ 1/2, 1/2 ])
# This is NOT possible in `RxInfer`'s model specification language
if c > 1
    # ...
end
```

since `c` does not have a value at graph creation time, while the condition of the `if` statement must be statically known when the graph is created.
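
Control flow over values that *are* known when the graph is built, such as hyperparameters or the length of attached data, is fine. A sketch, assuming (as with the `n` argument of `state_space_model_with_n` above) that a constant model argument is available as an ordinary Julia value inside the model body; `coin_model` and its `informative_prior` flag are hypothetical:

```julia
using RxInfer, Statistics

@model function coin_model(y, informative_prior)
    # Branching on a hyperparameter: its value is known at graph creation time
    if informative_prior
        θ ~ Beta(10.0, 10.0)
    else
        θ ~ Beta(1.0, 1.0)
    end
    # Looping over observations: `length(y)` is known once data is attached
    for i in 1:length(y)
        y[i] ~ Bernoulli(θ)
    end
end

result = infer(
    model = coin_model(informative_prior = true),
    data  = (y = [1.0, 0.0, 1.0], ),
)

println(mean(result.posteriors[:θ]))
```

With `informative_prior = true`, the conjugate Beta-Bernoulli update combines the `Beta(10, 10)` prior with two successes and one failure into `Beta(12, 11)`, whose mean is 12/23 ≈ 0.522.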

### Anonymous factor nodes and latent variables

The `@model` macro automatically resolves any inner function calls into anonymous factor nodes and latent variables. For example, the following:

```
y ~ Normal(
    mean = Normal(mean = 0.0, variance = 1.0),
    precision = Gamma(shape = 1.0, rate = 1.0)
)
```

is equivalent to

```
tmp1 ~ Normal(mean = 0.0, variance = 1.0)
tmp2 ~ Gamma(shape = 1.0, rate = 1.0)
y ~ Normal(mean = tmp1, precision = tmp2)
```

The inference backend still performs inference for anonymous latent variables; however, it does not provide an easy way to obtain posteriors for them. Note that the inference backend will try to optimize deterministic function calls in cases where all arguments are known in advance. For example:

`y ~ Normal(mean = 0.0, variance = inv(2.0))`

should not create an extra factor node for `inv`, since `inv` is a deterministic function and all of its arguments are known in advance. The same applies to complex initializations involving different types, as in:

`y ~ MvNormal(mean = zeros(3), covariance = Matrix(Diagonal(ones(3))))`

In this case, the expression `Matrix(Diagonal(ones(3)))` can (and will) be precomputed upon model creation and does not require probabilistic inference.
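
Both argument expressions above are ordinary Julia computations on constants. Evaluating them directly, as in this sketch, shows the plain values that can stand in for them at model creation time, which is why no `inv`, `Matrix`, or `Diagonal` factor nodes are needed:

```julia
using LinearAlgebra

# Only constants are involved, so these can be computed once, before inference
precomputed_variance   = inv(2.0)
precomputed_covariance = Matrix(Diagonal(ones(3)))

println(precomputed_variance)   # 0.5
println(precomputed_covariance == [1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0])   # true
```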

### Indexing operations

The `ref` expressions, such as `x[i]`, are handled in a special way. Technically, in Julia, the `x[i]` call is translated to a function call `getindex(x, i)`. Thus, the `@model` macro would be expected to create a factor node for the `getindex` function, but this does not happen in practice because this case is treated separately. This means that the model parser does not create unnecessary nodes when only simple indexing is involved. It also means that all expressions inside the `x[...]` list are left untouched during model parsing.
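
The distinction between the `ref` expression and the `getindex` call can be seen in plain Julia, independent of the `@model` macro:

```julia
# `x[i]` is parsed as a `:ref` expression...
ex = :(x[i])
println(ex.head)   # ref
println(ex.args)   # Any[:x, :i]

# ...which Julia evaluates as a call to `getindex`
x = [10.0, 20.0, 30.0]
i = 2
println(x[i] == getindex(x, i))   # true
```

Because the macro sees the `:ref` expression before it is evaluated, it can treat indexing separately instead of turning `getindex` into a factor node.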

It is not allowed to use latent variables within square brackets in the model specification, or in control flow statements such as `if`, `for`, or `while`.

### Broadcasting syntax

`GraphPPL` supports broadcasting for the `~` operator in exactly the same way as Julia itself. A user is free to write an expression of the following form:

```
m ~ Normal(mean = 0.0, precision = 0.0001)
t ~ Gamma(shape = 1.0, rate = 1.0)
y .~ Normal(mean = m, precision = t)
```

More complex expressions are also allowed:

```
w ~ Wishart(3, diageye(2))
x[1] ~ MvNormal(mean = zeros(2), precision = diageye(2))
x[2:end] .~ A .* x[1:end-1] # <- State-space model with transition matrix A
y .~ MvNormal(mean = x, precision = w) # <- Observations with unknown precision matrix
```

Note, however, that the shapes of all variables that take part in the broadcasting operation must be defined in advance. This means that it is not possible to use broadcasting with deferred data. Read more about how the broadcasting machinery works in Julia in the official documentation.
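
A minimal sketch of the first broadcasting form above. It assumes a *fixed* observation precision of `1.0` so that plain belief propagation suffices; with an unknown `t`, as in the original snippet, inference would additionally need factorization constraints and initialization. The model name `iid_observations` is hypothetical:

```julia
using RxInfer, Statistics

@model function iid_observations(y)
    m ~ Normal(mean = 0.0, precision = 0.0001)
    # One Normal factor per element of `y`, all sharing the same `m`
    y .~ Normal(mean = m, precision = 1.0)
end

# `y` has a known shape here because the data is attached up front
result = infer(
    model = iid_observations(),
    data  = (y = [1.0, 2.0, 3.0], ),
)

println(mean(result.posteriors[:m]))
```

The conjugate Gaussian update gives a posterior precision of 0.0001 + 3 = 3.0001 and a posterior mean of (1 + 2 + 3) / 3.0001 ≈ 1.99993.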

## `Distributions.jl` compatibility

For some factor nodes, we rely on the syntax from `Distributions.jl` to make it easy for users of that package to adopt `RxInfer.jl`. These nodes include, for example, the `Beta` and `Wishart` distributions. They can be created using the `~` syntax with the arguments as specified in `Distributions.jl`. Unfortunately, `RxInfer.jl` is not yet compatible with all possible distributions to be used as factor nodes. If you would like to see another node implemented, please file an issue.

To quickly check the list of all available factor nodes that can be used in the model specification language, call `?ReactiveMP.is_predefined_node` or `Base.doc(ReactiveMP.is_predefined_node)`.

Specifically for the Gaussian/Normal case, we have custom implementations that yield higher computational efficiency and improved stability in comparison to `Distributions.jl`, whose implementations are optimized for sampling operations. Our aliases for these distributions therefore do not correspond to the implementations from `Distributions.jl`. However, our model specification language is compatible with the `Distributions.jl` syntax for normal distributions, which is converted automatically. `RxInfer` has its own implementation for the following 3 reasons:

- `Distributions.jl` constructs normal distributions by saving the corresponding covariance matrices in a `PDMat` object from `PDMats.jl`. This construction always computes the Cholesky decompositions of the covariance matrices, which is very convenient for sampling-based procedures. However, in `RxInfer.jl` we mostly base our computations on analytical expressions, which do not always need the Cholesky decomposition. To reduce the overhead that `Distributions.jl` introduces, we therefore have custom implementations.
- Depending on the update rules, we might favor different parameterizations of the normal distributions. `ReactiveMP.jl` supports a variety of parameterizations that allow us to perform efficient computations, converting between parameterizations as little as possible.
- In certain situations we value stability a lot, especially when inverting matrices. `PDMats.jl`, and hence `Distributions.jl`, is not capable of fulfilling all our needs here. Therefore, we use `PositiveFactorizations.jl` to cope with the corner cases.
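
As a sketch of the second point, the `NormalMeanVariance`, `NormalMeanPrecision`, and `NormalWeightedMeanPrecision` types (exported by `RxInfer` and used elsewhere on this page) can describe the same Gaussian with different parameters. This assumes `mean` and `var` from `Statistics` are defined for these types, as they are for `Distributions.jl` distributions:

```julia
using RxInfer, Statistics

# The same univariate Gaussian (mean 1.0, variance 2.0) in three parameterizations
d_variance  = NormalMeanVariance(1.0, 2.0)           # (mean, variance)
d_precision = NormalMeanPrecision(1.0, 0.5)          # (mean, precision = 1 / variance)
d_weighted  = NormalWeightedMeanPrecision(0.5, 0.5)  # (precision * mean, precision)

for d in (d_variance, d_precision, d_weighted)
    println((mean(d), var(d)))   # (1.0, 2.0) for each parameterization
end
```

A message update that naturally produces a precision can then return `NormalMeanPrecision` directly, without converting to a covariance-based representation.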

## Model structure visualisation

It is also possible to visualize the model structure after conditioning on data. For that, we need two extra packages installed: `Cairo` and `GraphPlot`. Note that these packages are not included in the `RxInfer` package and must be installed separately.

```
using Cairo, GraphPlot
# `Create` the actual graph of the model conditioned on the data
model = RxInfer.create_model(conditioned)
# Call `gplot` function from `GraphPlot` to visualise the structure of the graph
GraphPlot.gplot(RxInfer.getmodel(model))
```

## Node Contraction

RxInfer's model specification extension for GraphPPL supports a feature called *node contraction*. This feature allows you to *contract* (or *replace*) a submodel with a corresponding factor node. Node contraction can be useful in several scenarios:

- When running inference in a submodel is computationally expensive
- When a submodel contains many variables whose inference results are not of primary importance
- When specialized message passing update rules can be derived for variables in the Markov blanket of the submodel

Let's illustrate this concept with a simple example. We'll first create a basic submodel and then allow the inference backend to replace it with a corresponding node that has well-defined message update rules.

```
using RxInfer, Plots
@model function ShiftedNormal(data, mean, precision, shift)
    shifted_mean := mean + shift
    data ~ Normal(mean = shifted_mean, precision = precision)
end
@model function Model(data, precision, shift)
    mean ~ Normal(mean = 15.0, var = 1.0)
    data ~ ShiftedNormal(mean = mean, precision = precision, shift = shift)
end
result = infer(
    model = Model(precision = 1.0, shift = 1.0),
    data = (data = 10.0, )
)
plot(title = "Inference results over `mean`")
plot!(0:0.1:20.0, (x) -> pdf(NormalMeanVariance(15.0, 1.0), x), label = "prior", fill = 0, fillalpha = 0.2)
plot!(0:0.1:20.0, (x) -> pdf(result.posteriors[:mean], x), label = "posterior", fill = 0, fillalpha = 0.2)
vline!([ 10.0 ], label = "data point")
```

As we can see, we can run inference on this model. We can also visualize the model's structure, as shown in the Model structure visualisation section.

```
using Cairo, GraphPlot
GraphPlot.gplot(getmodel(result.model))
```

Now, let's create an optimized version of the `ShiftedNormal` submodel as a standalone node with its own message passing update rules.

Creating correct message passing update rules is beyond the scope of this section. For more information about custom message passing update rules, refer to the Custom Node section.

```
@node typeof(ShiftedNormal) Stochastic [ data, mean, precision, shift ]
@rule typeof(ShiftedNormal)(:mean, Marginalisation) (q_data::PointMass, q_precision::PointMass, q_shift::PointMass, ) = begin
    return @call_rule NormalMeanPrecision(:μ, Marginalisation) (q_out = PointMass(mean(q_data) - mean(q_shift)), q_τ = q_precision)
end
result_with_contraction = infer(
    model = Model(precision = 1.0, shift = 1.0),
    data = (data = 10.0, ),
    allow_node_contraction = true
)
plot(title = "Inference results over `mean` with node contraction")
plot!(0:0.1:20.0, (x) -> pdf(NormalMeanVariance(15.0, 1.0), x), label = "prior", fill = 0, fillalpha = 0.2)
plot!(0:0.1:20.0, (x) -> pdf(result_with_contraction.posteriors[:mean], x), label = "posterior", fill = 0, fillalpha = 0.2)
vline!([ 10.0 ], label = "data point")
```

As you can see, the inference result is identical to the previous case. However, the structure of the model is different:

`GraphPlot.gplot(getmodel(result_with_contraction.model))`

With node contraction, we no longer have access to the variables defined inside the `ShiftedNormal` submodel, as it has been contracted to a single factor node. It's worth noting that this feature heavily relies on existing message passing update rules for the submodel. However, it can also be combined with another useful inference technique where no explicit message passing update rules are required.

We can also verify that node contraction indeed improves the performance of the inference:

```
using BenchmarkTools
benchmark_without_contraction = @benchmark infer(
    model = Model(precision = 1.0, shift = 1.0),
    data = (data = 10.0, )
)
benchmark_with_contraction = @benchmark infer(
    model = Model(precision = 1.0, shift = 1.0),
    data = (data = 10.0, ),
    allow_node_contraction = true
)
```

Let's examine the benchmark results:

`benchmark_without_contraction`

```
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 181.419 μs … 737.307 μs ┊ GC (min … max): 0.00% … 0.00%
Time (median): 220.933 μs ┊ GC (median): 0.00%
Time (mean ± σ): 220.502 μs ± 19.573 μs ┊ GC (mean ± σ): 0.00% ± 0.00%
▁▁ ▁▄▅██▇█▆▄▂
▂▂▂▃▄▆████▅▅▄▄▃▃▂▃▃▄▅▆██████████▇▆▅▄▅▅▅▆▆▆▆▅▅▄▄▃▃▃▂▂▂▂▂▂▂▂▁▂▂ ▄
181 μs Histogram: frequency by time 270 μs <
Memory estimate: 93.66 KiB, allocs estimate: 1433.
```

`benchmark_with_contraction`

```
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 113.622 μs … 588.599 μs ┊ GC (min … max): 0.00% … 0.00%
Time (median): 137.752 μs ┊ GC (median): 0.00%
Time (mean ± σ): 136.930 μs ± 14.889 μs ┊ GC (mean ± σ): 0.00% ± 0.00%
▄▅▂ ▃▄▆▇██▇▇▃▁
▂▆███▇▆▅▄▃▃▃▃▃▄▅▇▇██████████▇▅▄▃▃▃▂▃▃▃▃▃▃▃▃▃▃▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁ ▄
114 μs Histogram: frequency by time 181 μs <
Memory estimate: 68.27 KiB, allocs estimate: 977.
```

As we can see, the inference with node contraction runs faster due to the simplified model structure and optimized message update rules. This performance improvement is reflected in reduced execution time and fewer memory allocations.

### Node creation options

`GraphPPL` allows passing optional arguments to the node creation constructor with the `where { options... }` specification syntax.

Example:

`y ~ Normal(mean = y_mean, var = y_var) where { meta = ... }`

A list of the available options specific to the `ReactiveMP` inference engine is presented below.

#### Metadata option

It is possible to pass any extra metadata to a factor node with the `meta` option. Metadata can later be accessed in message computation rules.

```
z ~ f(x, y) where { meta = Linearization() }
d ~ g(a, b) where { meta = Unscented() }
```

This option might be useful to change the message passing rules around a specific factor node. Read more about this feature in the Meta Specification section.

#### Dependencies option

A user can modify the default computational pipeline of a node with the `dependencies` option. Read more about the different options in the `ReactiveMP.jl` documentation.

```
y[k - 1] ~ Probit(x[k]) where {
    # This specification indicates that in order to compute an outbound message from the `in` interface
    # we need an inbound message from the same edge initialized to `NormalMeanPrecision(0.0, 1.0)`
    dependencies = RequireMessageFunctionalDependencies(in = NormalMeanPrecision(0.0, 1.0))
}
```