Model Specification

RxInfer largely depends on GraphPPL for model specification. For extensive documentation on model specification, see the corresponding section of the GraphPPL documentation. Here we outline only a small portion of the model specification capabilities, aimed at beginners.

@model macro

The RxInfer.jl package exports the @model macro for model specification. The @model macro accepts the model specification itself in the form of a regular Julia function.

For example:

@model function model_name(model_arguments...)
    # model specification here
end

where model_arguments... may include both hyperparameters and data.

Note

model_arguments are converted to keyword arguments. Positional arguments in the model specification are not supported, so it is not possible to use Julia's multiple dispatch on the model arguments.

The @model macro returns a regular Julia function (in this example, model_name()) which can be called as usual. The only difference is that all arguments of the model function are treated as keyword arguments. When called, the model function returns a so-called model generator object, e.g.:

@model function my_model(observation, hyperparameter)
    observation ~ Normal(mean = 0.0, var = hyperparameter)
end
model = my_model(hyperparameter = 3)

The model generator is not a real model (yet). For example, in the code above, we haven't specified anything for observation. The generator object allows us to iteratively add extra properties to the model, condition on data, and/or assign extra metadata without materializing the entire graph structure. Read more about the model generator here.

A state space model example

Here we give an example of a probabilistic model before presenting the details of the model specification syntax. The model below is a simple state space model with latent random variables x and noisy observations y.

@model function state_space_model(y, trend, variance)
    x[1] ~ Normal(mean = 0.0, variance = 100.0)
    y[1] ~ Normal(mean = x[1], variance = variance)
    for i in 2:length(y)
        x[i] ~ Normal(mean = x[i - 1] + trend, variance = 1.0)
        y[i] ~ Normal(mean = x[i], variance = variance)
    end
end

In this model we assign a prior distribution over the latent state x[1]. All subsequent states x[i] depend on x[i - 1] and trend and are modelled as a simple Gaussian random walk. Observations y are also modelled with a Gaussian distribution, with a prespecified variance hyperparameter.
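To see the model in action, the sketch below runs inference on it with RxInfer's infer function. It assumes the infer(model = ..., data = ...) interface used later on this page; the observation values are made up for illustration and trend = 1.0 is an arbitrary choice. Because the model is a linear Gaussian chain, the posteriors obtained by message passing are exact.

```julia
using RxInfer

# Same state space model as above
@model function state_space_model(y, trend, variance)
    x[1] ~ Normal(mean = 0.0, variance = 100.0)
    y[1] ~ Normal(mean = x[1], variance = variance)
    for i in 2:length(y)
        x[i] ~ Normal(mean = x[i - 1] + trend, variance = 1.0)
        y[i] ~ Normal(mean = x[i], variance = variance)
    end
end

# Made-up observations that drift upwards by roughly 1.0 per step
observations = [0.1, 1.2, 1.9, 3.1, 4.0]

result = infer(
    model = state_space_model(trend = 1.0, variance = 1.0),
    data  = (y = observations, )
)

# `result.posteriors[:x]` holds one marginal posterior per time step
for (k, q) in enumerate(result.posteriors[:x])
    println("x[$k]: mean = ", round(mean(q), digits = 3), ", variance = ", round(var(q), digits = 3))
end
```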

Note

length(y) can be called only if y has data associated with it. This is not always the case: for example, it is possible to instantiate the model lazily, before the data becomes available. In such situations, length(y) will throw an error.

Hyperparameters

Any constant passed to a model as a model argument will be automatically converted to a corresponding constant node in the model's graph.

model = state_space_model(trend = 3.0, variance = 1.0)

In this example we instantiate a model generator with the trend and variance parameters clamped to 3.0 and 1.0, respectively. This means that no inference will be performed for those parameters, and some of the expressions within the model structure might be simplified and compiled out.

Conditioning on data

To complete the model specification, we need to specify y. In this example, y plays the role of observations. RxInfer provides a convenient mechanism to pass data values to the model with the | operator.

conditioned = model | (y = [ 0.0, 1.0, 2.0 ], )
state_space_model(trend = 3.0, variance = 1.0) conditioned on: 
  y = [0.0, 1.0, 2.0]
Note

Conditioning on data is a feature of RxInfer, not GraphPPL.

In the example above we conditioned on data in the form of a NamedTuple, but it is also possible to condition on a dictionary whose keys are the names of the corresponding model arguments:

data        = Dict(:y => [ 0.0, 1.0, 2.0 ])
conditioned = model | data
state_space_model(trend = 3.0, variance = 1.0) conditioned on: 
  y = [0.0, 1.0, 2.0]

Sometimes it is useful to indicate that some arguments are data (and thus condition on them) before the actual data becomes available. This situation may occur during reactive inference, when data becomes available after model creation. RxInfer provides a special structure called RxInfer.DeferredDataHandler, which can be used instead of the real data.

For the example above, however, we cannot simply do the following:

model | (y = RxInfer.DeferredDataHandler(), )

because we use length(y) in the model, which is only possible if y has data associated with it. We can adjust the model specification slightly by adding an extra n parameter to the list of arguments:

@model function state_space_model_with_n(y, n, trend, variance)
    x[1] ~ Normal(mean = 0.0, variance = 100.0)
    y[1] ~ Normal(mean = x[1], variance = variance)
    for i in 2:n
        x[i] ~ Normal(mean = x[i - 1] + trend, variance = 1.0)
        y[i] ~ Normal(mean = x[i], variance = variance)
    end
end

For such a model, we can safely condition on y without providing actual data for it, using RxInfer.DeferredDataHandler instead:

state_space_model_with_n(trend = 3.0, variance = 1.0, n = 10) | (
    y = RxInfer.DeferredDataHandler(),
)
state_space_model_with_n(trend = 3.0, variance = 1.0, n = 10) conditioned on: 
  y = [ deffered data ]

Read more about conditioning on data in this section of the documentation.

Latent variables

Latent variables are created with the ~ operator, which can be read as "is distributed as". For example, to create a latent variable y which is modelled by a Normal distribution whose mean and variance are controlled by the random variables m and v, respectively, we write

y ~ Normal(mean = m, variance = v)

In the example above

x[1] ~ Normal(mean = 0.0, variance = 100.0)

indicates that x₁ is distributed according to a Normal distribution.

Note

The RxInfer.jl package uses the ~ operator for modelling both stochastic and deterministic relationships between random variables. However, GraphPPL.jl also allows the := operator to be used for deterministic relationships.

Relationships between variables

In probabilistic models based on graphs, factor nodes are used to define a relationship between random variables and/or constants and data variables. A factor node defines a probability distribution over selected latent or data variables. The ~ operator not only creates a latent variable but also defines its functional relationship with other variables, and creates a factor node as a result.

In the example above

x[1] ~ Normal(mean = 0.0, variance = 100.0)

not only creates a latent variable x₁ but also a factor node Normal.

Note

Generally, it is not necessary to label every argument with its name (as in mean = ... or variance = ...), and many factor nodes do not require it. However, for nodes that have several useful parametrizations (e.g. Normal), labelling the arguments is required, because it removes any ambiguity about which parametrization is meant. Read more about Distributions compatibility here.
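As an illustration of why the labels matter, the following sketch writes the same Gaussian prior (variance 4.0, i.e. precision 0.25) once with var = and once with precision =. The model names and the observed value are hypothetical. For this conjugate model both runs should give the same posterior: the posterior precision is 0.25 + 1 = 1.25, so the posterior mean is 2.0 / 1.25 = 1.6 and the variance is 0.8.

```julia
using RxInfer

# Prior written with the (mean, var) parametrization
@model function prior_variance(y)
    x ~ Normal(mean = 0.0, var = 4.0)
    y ~ Normal(mean = x, var = 1.0)
end

# The same prior written with the (mean, precision) parametrization
@model function prior_precision(y)
    x ~ Normal(mean = 0.0, precision = 0.25)
    y ~ Normal(mean = x, var = 1.0)
end

r1 = infer(model = prior_variance(),  data = (y = 2.0, ))
r2 = infer(model = prior_precision(), data = (y = 2.0, ))

println("var-labelled prior:       mean = ", mean(r1.posteriors[:x]), ", variance = ", var(r1.posteriors[:x]))
println("precision-labelled prior: mean = ", mean(r2.posteriors[:x]), ", variance = ", var(r2.posteriors[:x]))
```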

Deterministic relationships

In contrast to other probabilistic programming languages in Julia, RxInfer does not allow the use of the = operator for creating deterministic relationships between (latent) variables. Instead, we can use the := operator for this purpose. For example:

t ~ Normal(mean = 0.0, variance = 1.0)
x := exp(t) # x is linked deterministically to t
y ~ Normal(mean = x, variance = 1.0)

Using x = exp(t) directly would be incorrect and would most likely result in a MethodError, because t does not have a definite value at model creation time (remember that our models create a factor graph under the hood, and latent states do not have a value until inference is performed). At model creation time, t holds a reference to a node in the graph rather than an actual value sampled from the Normal distribution.
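A minimal sketch of a working := relationship follows. It swaps exp(t) for the linear map t + 1.0 (an assumption made for this example) so that inference stays exact without any approximation metadata; the model name linear_link and the observed value are hypothetical. Since t ~ N(0, 1) and y ~ N(t + 1, 1) with y = 5, the posterior over t should be N(2, 0.5).

```julia
using RxInfer

@model function linear_link(y)
    t ~ Normal(mean = 0.0, variance = 1.0)
    x := t + 1.0   # deterministic link: x is t shifted by a constant
    y ~ Normal(mean = x, variance = 1.0)
end

result = infer(model = linear_link(), data = (y = 5.0, ))

q_t = result.posteriors[:t]
println("posterior over t: mean = ", mean(q_t), ", variance = ", var(q_t))
```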

Control flow statements

In general, it is possible to use any Julia code within the model specification function, including control flow statements such as for, while and if. However, latent states cannot be used within such statements, because the exact structure of the graph must be known before inference. Thus it is not possible to write statements like:

c ~ Categorical([ 1/2, 1/2 ])
# This is NOT possible in `RxInfer`'s model specification language
if c > 1
# ...
end

since c must be statically known upon graph creation.
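Control flow on values that are known when the graph is built, such as hyperparameters, is fine. The sketch below, with a hypothetical informative flag and made-up coin flips, switches between two Beta priors with an ordinary if. For this Beta-Bernoulli model the posteriors are Beta(3, 2) and Beta(12, 11), with means 0.6 and 12/23 ≈ 0.522.

```julia
using RxInfer

@model function switchable_coin(y, informative)
    # `informative` is a plain Bool known when the graph is built,
    # so branching on it is allowed (unlike branching on θ itself)
    if informative
        θ ~ Beta(10.0, 10.0)
    else
        θ ~ Beta(1.0, 1.0)
    end
    for i in eachindex(y)
        y[i] ~ Bernoulli(θ)
    end
end

flips = [1.0, 1.0, 0.0]   # made-up coin flips, 1.0 = heads

r_flat = infer(model = switchable_coin(informative = false), data = (y = flips, ))
r_info = infer(model = switchable_coin(informative = true),  data = (y = flips, ))

println("flat prior:        ", mean(r_flat.posteriors[:θ]))
println("informative prior: ", mean(r_info.posteriors[:θ]))
```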

Anonymous factor nodes and latent variables

The @model macro automatically resolves any inner function calls into anonymous factor nodes and latent variables. For example the following:

y ~ Normal(
    mean = Normal(mean = 0.0, variance = 1.0), 
    precision = Gamma(shape = 1.0, rate = 1.0)
)

is equivalent to

tmp1 ~ Normal(mean = 0.0, variance = 1.0)
tmp2 ~ Gamma(shape = 1.0, rate = 1.0)
y    ~ Normal(mean = tmp1, precision = tmp2)

The inference backend still performs inference for anonymous latent variables; however, it does not provide an easy way to obtain posteriors for them. Note that the inference backend will try to optimize deterministic function calls when all arguments are known in advance. For example:

y ~ Normal(mean = 0.0, variance = inv(2.0))

should not create an extra factor node for inv, since inv is a deterministic function and all of its arguments are known in advance. The same applies to complex initializations involving different types, as in:

y ~ MvNormal(mean = zeros(3), covariance = Matrix(Diagonal(ones(3))))

In this case, the expression Matrix(Diagonal(ones(3))) can (and will) be precomputed upon model creation and does not require probabilistic inference.

Indexing operations

The ref expressions, such as x[i], are handled in a special way. Technically, in Julia, x[i] is translated into the function call getindex(x, i). The @model macro would therefore be expected to create a factor node for getindex, but this does not happen in practice because this case is treated separately. As a result, the model parser does not create unnecessary nodes when only simple indexing is involved. It also means that all expressions inside x[...] are left untouched during model parsing.

Warning

Latent variables cannot be used within square brackets in the model specification, nor in control flow statements such as if, for or while.

Broadcasting syntax

GraphPPL supports broadcasting for the ~ operator in exactly the same way as Julia itself. A user is free to write an expression of the following form:

m  ~ Normal(mean = 0.0, precision = 0.0001)
t  ~ Gamma(shape = 1.0, rate = 1.0)
y .~ Normal(mean = m, precision = t)

More complex expressions are also allowed:

w         ~ Wishart(3, diageye(2))
x[1]      ~ MvNormal(mean = zeros(2), precision = diageye(2))
x[2:end] .~ A .* x[1:end-1] # <- State-space model with transition matrix A
y        .~ MvNormal(mean = x, precision = w) # <- Observations with unknown precision matrix

Note, however, that the shapes of all variables that take part in a broadcasting operation must be known in advance. This means that broadcasting cannot be used with deferred data. Read more about how the broadcasting machinery works in Julia in the official documentation.
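The first broadcasting example, with m and t, can be completed into a runnable model, sketched below under a few assumptions. The model name iid_normal and the data are hypothetical. Because m and t are coupled through the likelihood, the sketch uses a mean-field factorization (constraints = MeanField()) together with an @initialization block, a fixed number of iterations, and returnvars = KeepLast() to keep only the final posteriors; these are assumed to be the RxInfer 3.x infer keywords. Since y has data attached, its shape is known and broadcasting is allowed.

```julia
using RxInfer

@model function iid_normal(y)
    m  ~ Normal(mean = 0.0, precision = 0.0001)
    t  ~ Gamma(shape = 1.0, rate = 1.0)
    y .~ Normal(mean = m, precision = t)
end

# Made-up data: 21 evenly spaced points centred at 2.0
observations = 2.0 .+ collect(range(-1.0, 1.0; length = 21))

# Initial marginals for the iterative (variational) updates
init = @initialization begin
    q(m) = NormalMeanVariance(0.0, 100.0)
    q(t) = GammaShapeRate(1.0, 1.0)
end

result = infer(
    model          = iid_normal(),
    data           = (y = observations, ),
    constraints    = MeanField(),   # approximate q(m, t) by q(m) q(t)
    initialization = init,
    iterations     = 10,
    returnvars     = (m = KeepLast(), t = KeepLast())
)

q_m = result.posteriors[:m]
q_t = result.posteriors[:t]
println("E[m] = ", mean(q_m), ", E[t] = ", mean(q_t))
```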

Distributions.jl compatibility

For some factor nodes we rely on the syntax from Distributions.jl, to make it easy for Distributions.jl users to adopt RxInfer.jl. These nodes include, for example, the Beta and Wishart distributions, which can be created using the ~ syntax with the arguments as specified in Distributions.jl. Unfortunately, RxInfer.jl is not yet compatible with all possible distributions as factor nodes. If you would like to see another node implemented, please file an issue.

Note

To quickly check the list of all available factor nodes that can be used in the model specification language call ?ReactiveMP.is_predefined_node or Base.doc(ReactiveMP.is_predefined_node).

Specifically for the Gaussian/Normal case, we have custom implementations that yield higher computational efficiency and improved stability compared to Distributions.jl, whose implementations are optimized for sampling operations. Our aliases for these distributions therefore do not correspond to the implementations from Distributions.jl. However, our model specification language is compatible with the Distributions.jl syntax for normal distributions, which is converted automatically. RxInfer has its own implementation for the following three reasons:

  1. Distributions.jl constructs normal distributions by storing the corresponding covariance matrices in a PDMat object from PDMats.jl. This construction always computes the Cholesky decomposition of the covariance matrix, which is very convenient for sampling-based procedures. However, in RxInfer.jl we mostly base our computations on analytical expressions that do not always need the Cholesky decomposition. To reduce the overhead that Distributions.jl introduces, we therefore have custom implementations.
  2. Depending on the update rules, we might favor different parameterizations of the normal distribution. ReactiveMP.jl offers a variety of parameterizations that allow efficient computations, with as few conversions between parameterizations as possible.
  3. In certain situations we value stability a lot, especially when inverting matrices. PDMats.jl, and hence Distributions.jl, cannot fulfill all of our needs here. Therefore we use PositiveFactorizations.jl to cope with the corner cases.
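To illustrate point 2, the sketch below builds one and the same Gaussian (mean 1.0, variance 2.0) in three of the parametrizations available in RxInfer: mean-variance, mean-precision and weighted-mean-precision, where the weighted mean is precision × mean. It only assumes that these constructors and the mean, var and pdf functions are available after using RxInfer.

```julia
using RxInfer

# One Gaussian with mean 1.0 and variance 2.0, written in three parametrizations
d_mv  = NormalMeanVariance(1.0, 2.0)           # (mean, variance)
d_mp  = NormalMeanPrecision(1.0, 0.5)          # (mean, precision), precision = 1 / variance
d_wmp = NormalWeightedMeanPrecision(0.5, 0.5)  # (precision * mean, precision)

for d in (d_mv, d_mp, d_wmp)
    println(nameof(typeof(d)), ": mean = ", mean(d), ", variance = ", var(d), ", pdf(0.3) = ", pdf(d, 0.3))
end
```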

Model structure visualisation

It is also possible to visualize the model structure after conditioning on data. For that we need two extra packages: Cairo and GraphPlot. Note that these packages are not included in the RxInfer package and must be installed separately.

using Cairo, GraphPlot

# `Create` the actual graph of the model conditioned on the data
model = RxInfer.create_model(conditioned)

# Call `gplot` function from `GraphPlot` to visualise the structure of the graph
GraphPlot.gplot(RxInfer.getmodel(model))

Node Contraction

RxInfer's model specification extension for GraphPPL supports a feature called node contraction. This feature allows you to contract (or replace) a submodel with a corresponding factor node. Node contraction can be useful in several scenarios:

  • When running inference in a submodel is computationally expensive
  • When a submodel contains many variables whose inference results are not of primary importance
  • When specialized message passing update rules can be derived for variables in the Markov blanket of the submodel

Let's illustrate this concept with a simple example. We'll first create a basic submodel and then allow the inference backend to replace it with a corresponding node that has well-defined message update rules.

using RxInfer, Plots

@model function ShiftedNormal(data, mean, precision, shift)
    shifted_mean := mean + shift
    data ~ Normal(mean = shifted_mean, precision = precision)
end

@model function Model(data, precision, shift)
    mean ~ Normal(mean = 15.0, var = 1.0)
    data ~ ShiftedNormal(mean = mean, precision = precision, shift = shift)
end

result = infer(
    model = Model(precision = 1.0, shift = 1.0),
    data  = (data = 10.0, )
)

plot(title = "Inference results over `mean`")
plot!(0:0.1:20.0, (x) -> pdf(NormalMeanVariance(15.0, 1.0), x), label = "prior", fill = 0, fillalpha = 0.2)
plot!(0:0.1:20.0, (x) -> pdf(result.posteriors[:mean], x), label = "posterior", fill = 0, fillalpha = 0.2)
vline!([ 10.0 ], label = "data point")

As we can see, we can run inference on this model. We can also visualize the model's structure, as shown in the Model structure visualisation section.

using Cairo, GraphPlot

GraphPlot.gplot(getmodel(result.model))

Now, let's create an optimized version of the ShiftedNormal submodel as a standalone node with its own message passing update rules.

Note

Creating correct message passing update rules is beyond the scope of this section. For more information about custom message passing update rules, refer to the Custom Node section.

@node typeof(ShiftedNormal) Stochastic [ data, mean, precision, shift ]

@rule typeof(ShiftedNormal)(:mean, Marginalisation) (q_data::PointMass, q_precision::PointMass, q_shift::PointMass, ) = begin
    return @call_rule NormalMeanPrecision(:μ, Marginalisation) (q_out = PointMass(mean(q_data) - mean(q_shift)), q_τ = q_precision)
end

result_with_contraction = infer(
    model = Model(precision = 1.0, shift = 1.0),
    data  = (data = 10.0, ),
    allow_node_contraction = true
)

plot(title = "Inference results over `mean` with node contraction")
plot!(0:0.1:20.0, (x) -> pdf(NormalMeanVariance(15.0, 1.0), x), label = "prior", fill = 0, fillalpha = 0.2)
plot!(0:0.1:20.0, (x) -> pdf(result_with_contraction.posteriors[:mean], x), label = "posterior", fill = 0, fillalpha = 0.2)
vline!([ 10.0 ], label = "data point")

As you can see, the inference result is identical to the previous case. However, the structure of the model is different:

GraphPlot.gplot(getmodel(result_with_contraction.model))

With node contraction, we no longer have access to the variables defined inside the ShiftedNormal submodel, as it has been contracted to a single factor node. It's worth noting that this feature heavily relies on existing message passing update rules for the submodel. However, it can also be combined with another useful inference technique where no explicit message passing update rules are required.
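Because the prior and the likelihood in this example are both Gaussian, the posterior over mean can also be worked out by hand, which gives a reference value for both runs (with and without node contraction). Writing μ for mean, s = 1 for shift, τ = 1 for precision and y = 10 for the data point:

$$
p(\mu \mid y) \propto \mathcal{N}(\mu \mid 15, 1)\,\mathcal{N}\!\left(y \mid \mu + s, \tau^{-1}\right)
$$

$$
\tau_{\mathrm{post}} = 1 + \tau = 2, \qquad
\mu_{\mathrm{post}} = \frac{1 \cdot 15 + \tau\,(y - s)}{\tau_{\mathrm{post}}} = \frac{15 + 9}{2} = 12
$$

So the posterior over mean is N(12, 0.5): a Gaussian halfway between the prior mean 15 and the shifted data point y - s = 9.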

We can also verify that node contraction indeed improves the performance of the inference:

using BenchmarkTools

benchmark_without_contraction = @benchmark infer(
    model = Model(precision = 1.0, shift = 1.0),
    data  = (data = 10.0, )
)

benchmark_with_contraction = @benchmark infer(
    model = Model(precision = 1.0, shift = 1.0),
    data  = (data = 10.0, ),
    allow_node_contraction = true
)

Let's examine the benchmark results:

benchmark_without_contraction
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  181.419 μs … 737.307 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     220.933 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   220.502 μs ±  19.573 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

         ▁▁             ▁▄▅███▆▄▂                               
  ▂▂▂▃▄▆████▅▅▄▄▃▃▂▃▃▄▅▆█████████▇▆▅▄▅▅▅▆▆▆▆▅▅▄▄▃▃▃▂▂▂▂▂▂▂▂▁▂▂ ▄
  181 μs           Histogram: frequency by time          270 μs <

 Memory estimate: 93.66 KiB, allocs estimate: 1433.
benchmark_with_contraction
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  113.622 μs … 588.599 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     137.752 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   136.930 μs ±  14.889 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

    ▄▅▂             ▃▄▆█▇▇▃▁                                   
  ▂▆███▇▆▅▄▃▃▃▃▃▄▅▇▇████████▇▅▄▃▃▃▂▃▃▃▃▃▃▃▃▃▃▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁ ▄
  114 μs           Histogram: frequency by time          181 μs <

 Memory estimate: 68.27 KiB, allocs estimate: 977.

As we can see, inference with node contraction runs faster thanks to the simplified model structure and the dedicated message update rules. In this benchmark the median time drops from about 221 μs to about 138 μs, the memory estimate from 93.66 KiB to 68.27 KiB, and the number of allocations from 1433 to 977.

Node creation options

GraphPPL allows passing optional arguments to the node creation constructor with the where { options... } syntax.

Example:

y ~ Normal(mean = y_mean, var = y_var) where { meta = ... }

A list of the available options specific to the ReactiveMP inference engine is presented below.

Metadata option

It is possible to pass any extra metadata to a factor node with the meta option. The metadata can later be accessed in message computation rules.

z ~ f(x, y) where { meta = Linearization() }
d ~ g(a, b) where { meta = Unscented() }

This option might be useful to change message passing rules around a specific factor node. Read more about this feature in Meta Specification section.

Dependencies option

A user can modify the default computational pipeline of a node with the dependencies option. Read more about the different options in the ReactiveMP.jl documentation.

y[k - 1] ~ Probit(x[k]) where {
    # This specification indicates that in order to compute an outbound message from the `in` interface
    # we need an inbound message on the same edge, initialized to `NormalMeanPrecision(0.0, 1.0)`
    dependencies = RequireMessageFunctionalDependencies(in = NormalMeanPrecision(0.0, 1.0))
}

Read also