Model Specification
RxInfer largely depends on GraphPPL for model specification. Read the extensive documentation on model specification in the corresponding section of the GraphPPL documentation. Here we outline only a small portion of the model specification capabilities for beginners.
@model macro
The RxInfer.jl package exports the @model macro for model specification. The @model macro accepts the model specification itself in the form of a regular Julia function.
For example:
@model function model_name(model_arguments...)
    # model specification here
end
where model_arguments... may include both hyperparameters and data. model_arguments are converted to keyword arguments. Positional arguments in the model specification are not supported, so it is not possible to use Julia's multiple dispatch for the model arguments.
The @model macro returns a regular Julia function (in this example model_name()) which can be executed as usual. The only difference is that all arguments of the model function are treated as keyword arguments. Upon calling, the model function returns a so-called model generator object, e.g.:
@model function my_model(observation, hyperparameter)
    # The left-hand side must match the argument name `observation`
    observation ~ Normal(mean = 0.0, var = hyperparameter)
end
model = my_model(hyperparameter = 3)
The model generator is not a real model (yet). For example, in the code above, we haven't specified anything for the observation argument. The generator object allows us to iteratively add extra properties to the model, condition on data, and/or assign extra metadata without actually materializing the entire graph structure. Read more about the model generator here.
A state space model example
Here we give an example of a probabilistic model before presenting the details of the model specification syntax. The model below is a simple state space model with latent random variables x and noisy observations y.
@model function state_space_model(y, trend, variance)
    x[1] ~ Normal(mean = 0.0, variance = 100.0)
    y[1] ~ Normal(mean = x[1], variance = variance)
    for i in 2:length(y)
        x[i] ~ Normal(mean = x[i - 1] + trend, variance = 1.0)
        y[i] ~ Normal(mean = x[i], variance = variance)
    end
end
In this model we assign a prior distribution over the latent state x[1]. All subsequent states x[i] depend on x[i - 1] and trend and are modelled as a simple Gaussian random walk. Observations y are also modelled with a Gaussian distribution, with a prespecified variance hyperparameter.
length(y) can be called only if y has data associated with it. This is not always the case; for example, it is possible to instantiate the model lazily before the data becomes available. In such situations, length(y) will throw an error.
Hyperparameters
Any constant passed to a model as a model argument will be automatically converted to a corresponding constant node in the model's graph.
model = state_space_model(trend = 3.0, variance = 1.0)
In this example we instantiate a model generator with the trend and variance parameters clamped to 3.0 and 1.0, respectively. This means that no inference will be performed for those parameters, and some of the expressions within the model structure might be simplified and compiled out.
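A minimal sketch of what clamping means in practice, assuming the infer(model = ..., data = ...) call pattern used in the Node Contraction section of this page. The state space model is run with trend and variance fixed, and Gaussian posteriors are returned only for the latent states x. The observation values are the same illustrative ones used in the conditioning examples.

```julia
using RxInfer

@model function state_space_model(y, trend, variance)
    x[1] ~ Normal(mean = 0.0, variance = 100.0)
    y[1] ~ Normal(mean = x[1], variance = variance)
    for i in 2:length(y)
        x[i] ~ Normal(mean = x[i - 1] + trend, variance = 1.0)
        y[i] ~ Normal(mean = x[i], variance = variance)
    end
end

# `trend` and `variance` are clamped constants; only `y` is conditioned on data
result = infer(
    model = state_space_model(trend = 3.0, variance = 1.0),
    data  = (y = [0.0, 1.0, 2.0], )
)

# One Gaussian marginal per latent state x[i]; `trend` is a constant and gets no posterior
x_posteriors = result.posteriors[:x]
println(mean.(x_posteriors))
println(var.(x_posteriors))
```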
Conditioning on data
To complete the model specification we need to specify y. In this example, y plays the role of observations. RxInfer provides a convenient mechanism to pass data values to the model with the | operator.
conditioned = model | (y = [ 0.0, 1.0, 2.0 ], )
state_space_model(trend = 3.0, variance = 1.0) conditioned on:
y = [0.0, 1.0, 2.0]
Conditioning on data is a feature of RxInfer, not GraphPPL.
In the example above we conditioned on data in the form of a NamedTuple, but it is also possible to condition on a dictionary whose keys are the names of the corresponding model arguments:
data = Dict(:y => [ 0.0, 1.0, 2.0 ])
conditioned = model | data
state_space_model(trend = 3.0, variance = 1.0) conditioned on:
y = [0.0, 1.0, 2.0]
Sometimes it is useful to indicate that some arguments are data (and thus condition on them) before the actual data becomes available. This situation may occur during reactive inference, when data becomes available after model creation. RxInfer provides a special structure called RxInfer.DeferredDataHandler, which can be used in place of the real data.
For the example above, however, we cannot simply do the following:
model | (y = RxInfer.DeferredDataHandler(), )
because we use length(y) in the model, and this is only possible if y has data associated with it. We can adjust the model specification slightly by adding an extra n parameter to the list of arguments:
@model function state_space_model_with_n(y, n, trend, variance)
    x[1] ~ Normal(mean = 0.0, variance = 100.0)
    y[1] ~ Normal(mean = x[1], variance = variance)
    for i in 2:n
        x[i] ~ Normal(mean = x[i - 1] + trend, variance = 1.0)
        y[i] ~ Normal(mean = x[i], variance = variance)
    end
end
For such a model, we can safely condition on y without providing actual data for it, using RxInfer.DeferredDataHandler instead:
state_space_model_with_n(trend = 3.0, variance = 1.0, n = 10) | (
y = RxInfer.DeferredDataHandler(),
)
state_space_model_with_n(trend = 3.0, variance = 1.0, n = 10) conditioned on:
y = [ deffered data ]
Read more about conditioning on data in this section of the documentation.
Latent variables
Latent variables are created with the ~ operator, which can be read as "is distributed as". For example, to create a latent variable y that is modeled by a Normal distribution whose mean and variance are controlled by the random variables m and v, respectively, we define
y ~ Normal(mean = m, variance = v)
In the example above,
x[1] ~ Normal(mean = 0.0, variance = 100.0)
indicates that x₁ follows a Normal distribution.
The RxInfer.jl package uses the ~ operator for modelling both stochastic and deterministic relationships between random variables. However, GraphPPL.jl also allows the use of the := operator for deterministic relationships.
Relationships between variables
In probabilistic models based on graphs, factor nodes are used to define relationships between random variables, constants, and data variables. A factor node defines a probability distribution over selected latent or data variables. The ~ operator not only creates a latent variable but also defines its functional relationship with other variables, creating a factor node as a result.
In the example above,
x[1] ~ Normal(mean = 0.0, variance = 100.0)
not only creates a latent variable x₁ but also a factor node Normal.
Generally it is not necessary to label all arguments with their names, as in mean = ... or variance = ..., and many factor nodes do not require it. However, for nodes that have many useful parametrizations (e.g. Normal), labeling the arguments is required, which avoids any possible confusion. Read more about Distributions.jl compatibility here.
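To make the role of labels concrete, the sketch below specifies the same prior twice: once with precision = 4.0 and once with variance = 0.25. It is our own illustration, and the model names prior_by_precision and prior_by_variance are hypothetical. Because the labels tell the model which parametrization is meant, both versions should yield the same posterior for x. The expected values come from the standard Gaussian conjugate update, not from the RxInfer documentation.

```julia
using RxInfer

# Prior specified through its precision
@model function prior_by_precision(y)
    x ~ Normal(mean = 0.0, precision = 4.0)
    y ~ Normal(mean = x, variance = 1.0)
end

# The same prior specified through its variance (1 / 4.0)
@model function prior_by_variance(y)
    x ~ Normal(mean = 0.0, variance = 0.25)
    y ~ Normal(mean = x, variance = 1.0)
end

r1 = infer(model = prior_by_precision(), data = (y = 1.0, ))
r2 = infer(model = prior_by_variance(), data = (y = 1.0, ))

# Posterior precision 4 + 1 = 5; posterior mean (4 * 0.0 + 1 * 1.0) / 5 = 0.2
q1, q2 = r1.posteriors[:x], r2.posteriors[:x]
println(mean(q1), " ", var(q1))
println(mean(q2), " ", var(q2))
```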
Deterministic relationships
In contrast to other probabilistic programming languages in Julia, RxInfer does not allow the = operator to create deterministic relationships between (latent) variables. Instead, we use the := operator for this purpose. For example:
t ~ Normal(mean = 0.0, variance = 1.0)
x := exp(t) # x is linked deterministically to t
y ~ Normal(mean = x, variance = 1.0)
Using x = exp(t) directly would be incorrect and would most likely result in a MethodError, because t does not have a definite value at model creation time. Remember that our models create a factor graph under the hood, and latent states do not have a value until inference is performed. See "Using = instead of := for deterministic nodes" for a detailed explanation of this design choice.
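Below is a runnable variant of the := example, assuming the same infer call pattern used elsewhere on this page. We deliberately replace exp(t) with the linear map t + 1.0 (in a hypothetical model called shifted_model). Inference through a nonlinear function such as exp generally requires an approximation method to be specified through metadata, while a linear link keeps this sketch exact.

```julia
using RxInfer

@model function shifted_model(y)
    t ~ Normal(mean = 0.0, variance = 1.0)
    x := t + 1.0                          # deterministic link, becomes a `+` factor node
    y ~ Normal(mean = x, variance = 1.0)
end

result = infer(model = shifted_model(), data = (y = 2.0, ))
q_t = result.posteriors[:t]

# Observing y = 2.0 means t + 1.0 is seen as 2.0 with unit variance,
# so the conjugate update gives t | y ~ N(0.5, 0.5)
println(mean(q_t), " ", var(q_t))
```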
Control flow statements
In general, it is possible to use any Julia code within the model specification function, including control flow statements such as for, while and if statements. However, latent states cannot be used within such statements, because the structure of the graph must be known exactly before inference. Thus it is not possible to write statements like:
c ~ Categorical([ 1/2, 1/2 ])
# This is NOT possible in `RxInfer`'s model specification language
if c > 1
    # ...
end
since the value of c would have to be statically known upon graph creation.
Anonymous factor nodes and latent variables
The @model macro automatically resolves any inner function calls into anonymous factor nodes and latent variables. For example, the following:
y ~ Normal(
    mean = Normal(mean = 0.0, variance = 1.0),
    precision = Gamma(shape = 1.0, rate = 1.0)
)
is equivalent to
tmp1 ~ Normal(mean = 0.0, variance = 1.0)
tmp2 ~ Gamma(shape = 1.0, rate = 1.0)
y ~ Normal(mean = tmp1, precision = tmp2)
The inference backend still performs inference for anonymous latent variables, but it does not provide an easy way to obtain posteriors for them. Note that the inference backend will try to optimize deterministic function calls when all their arguments are known in advance. For example:
y ~ Normal(mean = 0.0, variance = inv(2.0))
should not create an extra factor node for inv, since inv is a deterministic function and all its arguments are known in advance. The same applies to complex initializations involving different types, as in:
y ~ MvNormal(mean = zeros(3), covariance = Matrix(Diagonal(ones(3))))
In this case, the expression Matrix(Diagonal(ones(3))) can (and will) be precomputed upon model creation and does not require probabilistic inference.
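As an illustration of this precomputation, the sketch below uses the same Matrix(Diagonal(ones(3))) covariance for both the prior and the likelihood in a hypothetical model called mv_model. The constant matrix needs no inference, so inference runs only over the latent x. The expected posterior follows from the Gaussian conjugate update with identity covariances. The infer call pattern is assumed from elsewhere on this page.

```julia
using RxInfer, LinearAlgebra

@model function mv_model(y)
    # Matrix(Diagonal(ones(3))) contains no random variables, so it can be
    # evaluated at model creation instead of becoming a factor node
    x ~ MvNormal(mean = zeros(3), covariance = Matrix(Diagonal(ones(3))))
    y ~ MvNormal(mean = x, covariance = Matrix(Diagonal(ones(3))))
end

result = infer(model = mv_model(), data = (y = [1.0, 2.0, 3.0], ))
q_x = result.posteriors[:x]

# Prior N(0, I) combined with likelihood N(y; x, I) gives posterior N(y / 2, I / 2)
println(mean(q_x))
println(cov(q_x))
```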
Indexing operations
Ref expressions, such as x[i], are handled in a special way. Technically, in Julia, x[i] is translated to the function call getindex(x, i). One might therefore expect the @model macro to create a factor node for the getindex function, but this does not happen in practice because this case is treated separately. As a result, the model parser does not create unnecessary nodes when only simple indexing is involved. It also means that all expressions inside x[...] are left untouched during model parsing.
Latent variables cannot be used within square brackets in the model specification, or in control flow statements such as if, for or while.
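The special treatment of x[i] described above is possible because a macro such as @model operates on parsed expressions. In Julia's parser, indexing produces a dedicated :ref expression rather than a :call to getindex. The plain Julia snippet below, which does not use RxInfer, shows this difference. It does not reproduce GraphPPL's actual parsing code.

```julia
# Indexing is parsed into its own expression kind
ex = Meta.parse("x[i]")
@show ex.head    # ex.head = :ref
@show ex.args    # ex.args = Any[:x, :i]

# An explicit getindex call is an ordinary function-call expression
call_ex = Meta.parse("getindex(x, i)")
@show call_ex.head    # call_ex.head = :call

# At run time both forms perform the same lookup
x = [10.0, 20.0, 30.0]
println(x[2] == getindex(x, 2))    # true
```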
Broadcasting syntax
GraphPPL supports broadcasting for the ~ operator in exactly the same way as Julia itself. A user is free to write an expression of the following form:
m ~ Normal(mean = 0.0, precision = 0.0001)
t ~ Gamma(shape = 1.0, rate = 1.0)
y .~ Normal(mean = m, precision = t)
More complex expressions are also allowed:
w ~ Wishart(3, diageye(2))
x[1] ~ MvNormal(mean = zeros(2), precision = diageye(2))
x[2:end] .~ A .* x[1:end-1] # <- State-space model with transition matrix A
y .~ MvNormal(mean = x, precision = w) # <- Observations with unknown precision matrix
Note, however, that the shapes of all variables that take part in the broadcasting operation must be defined in advance. This means that broadcasting cannot be used with deferred data. Read more about how the broadcasting machinery works in Julia in the official documentation.
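The following is a minimal broadcasting sketch, assuming the infer call pattern used elsewhere on this page. Unlike the example above, the observation precision is fixed to 1.0. This is our simplification in a hypothetical model called iid_known_precision, and it keeps the model conjugate, so no factorization constraints or initialization are needed. The shape of y is known because real data is passed.

```julia
using RxInfer

@model function iid_known_precision(y)
    m ~ Normal(mean = 0.0, precision = 0.0001)
    # One Normal factor node per element of `y`, all sharing the same `m`
    y .~ Normal(mean = m, precision = 1.0)
end

result = infer(model = iid_known_precision(), data = (y = [1.0, 2.0, 3.0], ))
q_m = result.posteriors[:m]

# Posterior precision is 0.0001 + 3 * 1.0; posterior mean is sum(y) / (0.0001 + 3.0)
println(mean(q_m), " ", var(q_m))
```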
Distributions.jl compatibility
For some factor nodes we rely on the syntax from Distributions.jl to make it easy for its users to adopt RxInfer.jl. These nodes include, for example, the Beta and Wishart distributions, and they can be created using the ~ syntax with the arguments as specified in Distributions.jl. Unfortunately, RxInfer.jl is not yet compatible with all possible distributions as factor nodes. If you would like to see another node implemented, please file an issue.
To quickly check the list of all available factor nodes that can be used in the model specification language, call ?ReactiveMP.is_predefined_node or Base.doc(ReactiveMP.is_predefined_node).
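The Beta-Bernoulli coin toss, also used in the Model structure visualisation section, shows the Distributions.jl-style positional syntax in action. This sketch uses a hypothetical model called coin_model and assumes the infer call pattern used elsewhere on this page. The expected posterior, Beta(3, 2), is the textbook conjugate update for two successes and one failure under a Beta(1, 1) prior.

```julia
using RxInfer

@model function coin_model(y)
    # Distributions.jl-style positional arguments: Beta(α, β) and Bernoulli(p)
    t ~ Beta(1.0, 1.0)
    for i in eachindex(y)
        y[i] ~ Bernoulli(t)
    end
end

result = infer(model = coin_model(), data = (y = [true, false, true], ))
q_t = result.posteriors[:t]

# Conjugate update: Beta(1 + 2 successes, 1 + 1 failure) = Beta(3, 2), mean 3 / 5
println(q_t)
println(mean(q_t))
```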
Specifically for the Gaussian/Normal case, we have custom implementations that yield higher computational efficiency and improved stability compared to Distributions.jl, whose implementations are optimized for sampling operations. Our aliases for these distributions therefore do not correspond to the implementations from Distributions.jl. However, our model specification language is compatible with the Distributions.jl syntax for normal distributions, which is converted automatically. RxInfer has its own implementation for the following 3 reasons:
- Distributions.jl constructs normal distributions by storing the corresponding covariance matrices in a PDMat object from PDMats.jl. This construction always computes the Cholesky decomposition of the covariance matrix, which is very convenient for sampling-based procedures. However, in RxInfer.jl we mostly base our computations on analytical expressions that do not always need the Cholesky decomposition. To reduce the overhead that Distributions.jl introduces, we therefore have custom implementations.
- Depending on the update rules, we might favor different parameterizations of the normal distribution. ReactiveMP.jl offers a variety of parameterizations that allow efficient computations, converting between parameterizations as little as possible.
- In certain situations stability matters a lot, especially when inverting matrices. PDMats.jl, and hence Distributions.jl, cannot fulfill all our needs here. Therefore we use PositiveFactorizations.jl to cope with the corner cases.
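The custom Gaussian types mentioned above come in several parametrizations. This sketch uses NormalMeanPrecision and NormalMeanVariance, both of which appear elsewhere on this page, to express the same distribution in two ways. It assumes that var and pdf are available after using RxInfer, as the plotting code in the Node Contraction section suggests for pdf.

```julia
using RxInfer

# The same Gaussian written in two parametrizations
d_prec = NormalMeanPrecision(1.0, 4.0)   # mean 1.0, precision 4.0
d_var  = NormalMeanVariance(1.0, 0.25)   # mean 1.0, variance 1 / 4.0

println(var(d_prec))                          # 0.25
println(pdf(d_prec, 0.3) ≈ pdf(d_var, 0.3))   # true
```

Keeping both forms lets message update rules work in whichever parametrization is cheapest for them, which is the motivation given in the second point of the list above.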
Model structure visualisation
Models specified using GraphPPL.jl in RxInfer.jl can be visualized in several ways to help understand their structure and relationships between variables. Let's create a simple model and visualize it.
using RxInfer
@model function coin_toss(y)
    t ~ Beta(1, 1)
    for i in eachindex(y)
        y[i] ~ Bernoulli(t)
    end
end
model_generator = coin_toss() | (y = [ true, false, true ], )
model_to_plot = RxInfer.getmodel(RxInfer.create_model(model_generator))
GraphViz.jl
It is possible to visualize the model structure after conditioning on data with the GraphViz.jl package. Note that this package is not included in the RxInfer package and must be installed separately.
using GraphViz
# Call `load` function from `GraphViz` to visualise the structure of the graph
GraphViz.load(model_to_plot, strategy = :simple)
Cairo
There is an alternative way to visualise the model structure with Cairo and GraphPlot. Note that these packages are also not included in the RxInfer package and must be installed separately.
using Cairo, GraphPlot
# Call `gplot` function from `GraphPlot` to visualise the structure of the graph
GraphPlot.gplot(model_to_plot)
Node Contraction
RxInfer's model specification extension for GraphPPL supports a feature called node contraction. This feature allows you to contract (or replace) a submodel with a corresponding factor node. Node contraction can be useful in several scenarios:
- When running inference in a submodel is computationally expensive
- When a submodel contains many variables whose inference results are not of primary importance
- When specialized message passing update rules can be derived for variables in the Markov blanket of the submodel
Let's illustrate this concept with a simple example. We'll first create a basic submodel and then allow the inference backend to replace it with a corresponding node that has well-defined message update rules.
using RxInfer, Plots
@model function ShiftedNormal(data, mean, precision, shift)
    shifted_mean := mean + shift
    data ~ Normal(mean = shifted_mean, precision = precision)
end
@model function Model(data, precision, shift)
    mean ~ Normal(mean = 15.0, var = 1.0)
    data ~ ShiftedNormal(mean = mean, precision = precision, shift = shift)
end
result = infer(
    model = Model(precision = 1.0, shift = 1.0),
    data = (data = 10.0, )
)
plot(title = "Inference results over `mean`")
plot!(0:0.1:20.0, (x) -> pdf(NormalMeanVariance(15.0, 1.0), x), label = "prior", fill = 0, fillalpha = 0.2)
plot!(0:0.1:20.0, (x) -> pdf(result.posteriors[:mean], x), label = "posterior", fill = 0, fillalpha = 0.2)
vline!([ 10.0 ], label = "data point")
As we can see, we can run inference on this model. We can also visualize the model's structure, as shown in the Model structure visualisation section.
using Cairo, GraphPlot
GraphPlot.gplot(getmodel(result.model))
Now, let's create an optimized version of the ShiftedNormal submodel as a standalone node with its own message passing update rules.
Creating correct message passing update rules is beyond the scope of this section. For more information about custom message passing update rules, refer to the Custom Node section.
@node typeof(ShiftedNormal) Stochastic [ data, mean, precision, shift ]
@rule typeof(ShiftedNormal)(:mean, Marginalisation) (q_data::PointMass, q_precision::PointMass, q_shift::PointMass, ) = begin
    return @call_rule NormalMeanPrecision(:μ, Marginalisation) (q_out = PointMass(mean(q_data) - mean(q_shift)), q_τ = q_precision)
end
result_with_contraction = infer(
    model = Model(precision = 1.0, shift = 1.0),
    data = (data = 10.0, ),
    allow_node_contraction = true
)
plot(title = "Inference results over `mean` with node contraction")
plot!(0:0.1:20.0, (x) -> pdf(NormalMeanVariance(15.0, 1.0), x), label = "prior", fill = 0, fillalpha = 0.2)
plot!(0:0.1:20.0, (x) -> pdf(result_with_contraction.posteriors[:mean], x), label = "posterior", fill = 0, fillalpha = 0.2)
vline!([ 10.0 ], label = "data point")
As you can see, the inference result is identical to the previous case. However, the structure of the model is different:
GraphPlot.gplot(getmodel(result_with_contraction.model))
With node contraction, we no longer have access to the variables defined inside the ShiftedNormal submodel, as it has been contracted to a single factor node. Note that this feature relies heavily on existing message passing update rules for the submodel. However, it can also be combined with another useful inference technique in which no explicit message passing update rules are required.
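The claim that both runs produce identical results can also be checked numerically. The self-contained sketch below repeats the submodel, node and rule definitions from this section and compares the two posteriors over mean. The reference values N(12, 0.5) are our own hand calculation (prior N(15, 1) combined with the shifted observation 10 - 1 = 9 at precision 1), not figures reported by the original example.

```julia
using RxInfer

@model function ShiftedNormal(data, mean, precision, shift)
    shifted_mean := mean + shift
    data ~ Normal(mean = shifted_mean, precision = precision)
end

@model function Model(data, precision, shift)
    mean ~ Normal(mean = 15.0, var = 1.0)
    data ~ ShiftedNormal(mean = mean, precision = precision, shift = shift)
end

# Run without contraction first, mirroring the order used in this section
plain = infer(model = Model(precision = 1.0, shift = 1.0), data = (data = 10.0, ))

# Contracted node and its update rule, copied from this section
@node typeof(ShiftedNormal) Stochastic [ data, mean, precision, shift ]

@rule typeof(ShiftedNormal)(:mean, Marginalisation) (q_data::PointMass, q_precision::PointMass, q_shift::PointMass, ) = begin
    return @call_rule NormalMeanPrecision(:μ, Marginalisation) (q_out = PointMass(mean(q_data) - mean(q_shift)), q_τ = q_precision)
end

contracted = infer(
    model = Model(precision = 1.0, shift = 1.0),
    data = (data = 10.0, ),
    allow_node_contraction = true
)

q_plain = plain.posteriors[:mean]
q_contracted = contracted.posteriors[:mean]
println(mean(q_plain), " ", var(q_plain))
println(mean(q_contracted), " ", var(q_contracted))
```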
We can also verify that node contraction indeed improves the performance of the inference:
using BenchmarkTools
benchmark_without_contraction = @benchmark infer(
    model = Model(precision = 1.0, shift = 1.0),
    data = (data = 10.0, )
)
benchmark_with_contraction = @benchmark infer(
    model = Model(precision = 1.0, shift = 1.0),
    data = (data = 10.0, ),
    allow_node_contraction = true
)
Let's examine the benchmark results:
benchmark_without_contraction
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
Range (min … max): 177.362 μs … 18.887 ms ┊ GC (min … max): 0.00% … 97.80%
Time (median): 186.619 μs ┊ GC (median): 0.00%
Time (mean ± σ): 198.623 μs ± 334.898 μs ┊ GC (mean ± σ): 3.58% ± 2.12%
▃▅█▇▅▄▁
▁▁▂▄█████████▇▆▅▅▄▄▃▃▂▂▂▃▃▃▃▃▄▄▄▄▄▄▃▃▃▂▂▂▂▂▂▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
177 μs Histogram: frequency by time 227 μs <
Memory estimate: 93.66 KiB, allocs estimate: 1433.
benchmark_with_contraction
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
Range (min … max): 111.018 μs … 15.103 ms ┊ GC (min … max): 0.00% … 96.26%
Time (median): 117.390 μs ┊ GC (median): 0.00%
Time (mean ± σ): 126.966 μs ± 263.812 μs ┊ GC (mean ± σ): 4.02% ± 1.94%
▄█▇▄▁
▁▁▃███████▆▅▄▄▃▃▂▂▂▂▂▂▂▂▂▂▃▂▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
111 μs Histogram: frequency by time 157 μs <
Memory estimate: 68.27 KiB, allocs estimate: 977.
As we can see, the inference with node contraction runs faster due to the simplified model structure and optimized message update rules. This performance improvement is reflected in reduced execution time and fewer memory allocations.
Node creation options
GraphPPL allows optional arguments to be passed to the node creation constructor with the where { options... } specification syntax.
Example:
y ~ Normal(mean = y_mean, var = y_var) where { meta = ... }
A list of the available options specific to the ReactiveMP inference engine is presented below.
Metadata option
It is possible to pass any extra metadata to a factor node with the meta option. Metadata can later be accessed in message computation rules.
z ~ f(x, y) where { meta = Linearization() }
d ~ g(a, b) where { meta = Unscented() }
This option might be useful to change message passing rules around a specific factor node. Read more about this feature in Meta Specification section.
Dependencies option
A user can modify the default computational pipeline of a node with the dependencies option. Read more about the different options in the ReactiveMP.jl documentation.
y[k - 1] ~ Probit(x[k]) where {
    # This specification indicates that in order to compute an outbound message from the `in` interface
    # We need an inbound message from the same edge initialized to `NormalMeanPrecision(0.0, 1.0)`
    dependencies = RequireMessageFunctionalDependencies(in = NormalMeanPrecision(0.0, 1.0))
}