Model construction in RxInfer

Model creation in RxInfer depends largely on the GraphPPL package. RxInfer re-exports the @model macro from GraphPPL and defines extra plugins and data structures on top of the default functionality.

Note

Model creation and construction were largely refactored in GraphPPL v4. Read the Migration Guide for more details.

Also read the Model Specification guide.

@model macro

RxInfer operates with so-called graphical probabilistic models, more specifically factor graphs. Working with graphs directly is, however, tedious and error-prone, especially for large models. To simplify the process, RxInfer exports the @model macro, which translates a textual description of a probabilistic model into a corresponding factor graph representation.

RxInfer.@model (Macro)
@model function model_name(model_arguments...)
    # model description
end

The @model macro generates a function that returns an equivalent graph representation of the given probabilistic model description. See the documentation of GraphPPL.@model for more information.

Supported aliases in the model specification, specific to RxInfer.jl and ReactiveMP.jl:

  • a || b: alias for the ReactiveMP.OR(a, b) node.
  • a && b: alias for the ReactiveMP.AND(a, b) node.
  • a -> b: alias for the ReactiveMP.IMPLY(a, b) node.
  • ¬a and !a: alias for the ReactiveMP.NOT(a) node (¬ is typed as \neg).

The operator precedence between ||, &&, -> and ! is the same as in Julia.
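As an illustration of the aliases above, the following sketch uses || inside a hypothetical model logical_or, in which an observed binary value y is the logical OR of two latent Bernoulli variables. It assumes the alias is applied to the right-hand side of ~ as described, and it only materializes the factor graph with RxInfer.create_model; it does not check that inference with these nodes runs end to end.

```julia
using RxInfer

# Hypothetical model: `y` is the logical OR of two latent binary variables.
# The `||` on the right-hand side of `~` is rewritten into a ReactiveMP.OR(a, b) node.
@model function logical_or(y)
    a ~ Bernoulli(0.5)
    b ~ Bernoulli(0.5)
    y ~ a || b
end

# Condition on a single observation and build the corresponding factor graph
model = RxInfer.create_model(logical_or() | (y = 1.0, ))
```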

Note that GraphPPL also implements the @model macro but does not export it by default. This was a deliberate choice that allows inference backends (such as RxInfer) to implement custom functionality on top of the default GraphPPL.@model macro through a custom backend. Read more about backends in the corresponding section of the GraphPPL documentation.

Conditioning on data

After a model has been created, RxInfer uses the RxInfer.condition_on function to condition it on data. The | operator is an alias for the same operation with a more concise syntax.

RxInfer.condition_on (Function)
condition_on(generator::ModelGenerator; kwargs...)

Creates a ConditionedModelGenerator object from a GraphPPL.ModelGenerator. The | operator can be used as a shorthand for this function.

julia> using RxInfer

julia> @model function beta_bernoulli(y, a, b)
           θ ~ Beta(a, b)
           y .~ Bernoulli(θ)
       end

julia> conditioned_model = beta_bernoulli(a = 1.0, b = 2.0) | (y = [ 1.0, 0.0, 1.0 ], )
beta_bernoulli(a = 1.0, b = 2.0) conditioned on: 
  y = [1.0, 0.0, 1.0]

julia> RxInfer.create_model(conditioned_model) isa RxInfer.ProbabilisticModel
true
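A minimal sketch of the explicit function form: the call below passes the same data to RxInfer.condition_on as keyword arguments instead of using |. The model from the example above is redefined so that the block is self-contained; the variable names are illustrative.

```julia
using RxInfer

@model function beta_bernoulli(y, a, b)
    θ ~ Beta(a, b)
    y .~ Bernoulli(θ)
end

# Explicit function call: the data is passed as keyword arguments
conditioned_1 = RxInfer.condition_on(beta_bernoulli(a = 1.0, b = 2.0); y = [1.0, 0.0, 1.0])

# Shorthand with the `|` operator and a named tuple
conditioned_2 = beta_bernoulli(a = 1.0, b = 2.0) | (y = [1.0, 0.0, 1.0], )

# Both conditioned generators can be materialized in the same way
model_1 = RxInfer.create_model(conditioned_1)
model_2 = RxInfer.create_model(conditioned_2)
```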
RxInfer.ConditionedModelGenerator (Type)
ConditionedModelGenerator(generator, conditioned_on)

Accepts a model generator and the data to condition on. The generator must be a GraphPPL.ModelGenerator object. conditioned_on must be a named tuple or a dictionary whose keys correspond to the names of the model's input arguments.

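Because conditioned_on may also be a dictionary, the same data can be passed as a Dict. This sketch assumes that the dictionary keys are Symbols matching the model argument names, as the description above suggests.

```julia
using RxInfer

@model function beta_bernoulli(y, a, b)
    θ ~ Beta(a, b)
    y .~ Bernoulli(θ)
end

# Assumption: keys are Symbols that match the model's argument names
data = Dict(:y => [1.0, 0.0, 1.0])

conditioned = beta_bernoulli(a = 1.0, b = 2.0) | data
model = RxInfer.create_model(conditioned)
```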

Sometimes it is useful to condition on data that is not yet available at model creation time. This is especially common in reactive inference settings, where data may arrive later, e.g. from an asynchronous sensor. For this reason, RxInfer implements a special deferred data handler, which marks a model argument as data but specifies neither its value nor its shape.

RxInfer.DeferredDataHandler (Type)

An object used to condition on unknown data. This may be necessary to create a model from a ModelGenerator object when the data is not known at the time of model creation.

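A minimal sketch of deferred conditioning, assuming that DeferredDataHandler can be constructed without arguments. The hypothetical model beta_bernoulli_scalar uses a scalar observation y so that the sketch does not depend on a data shape, which the handler leaves unspecified.

```julia
using RxInfer

# Hypothetical model with a single scalar observation
@model function beta_bernoulli_scalar(y)
    θ ~ Beta(1.0, 1.0)
    y ~ Bernoulli(θ)
end

# Mark `y` as data without providing any value for it yet
conditioned = beta_bernoulli_scalar() | (y = RxInfer.DeferredDataHandler(), )
model = RxInfer.create_model(conditioned)
```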

After the model has been conditioned, it can be materialized with the RxInfer.create_model function. This function takes an RxInfer.ConditionedModelGenerator object and materializes it into an RxInfer.ProbabilisticModel.

GraphPPL.create_model (Method)
create_model(generator::ConditionedModelGenerator)

Materializes the model specification conditioned on some data into the corresponding factor graph representation. Returns a ProbabilisticModel.


Additional GraphPPL pipeline stages

RxInfer implements several pipeline stages in addition to the default parsing stages of GraphPPL. A notable distinction of the RxInfer model specification language is that RxInfer "folds" some mathematical expressions and adds extra brackets to ensure the correct number of arguments for factor nodes. For example, the expression x ~ x1 + x2 + x3 + x4 becomes x ~ ((x1 + x2) + x3) + x4, which ensures that the + function has exactly two arguments.

RxInfer.error_datavar_constvar_randomvar (Function)
error_datavar_constvar_randomvar(expr::Expr)

An additional pipeline stage for the @model macro from GraphPPL. Notifies the user that the datavar, constvar and randomvar syntax has been removed and is no longer supported in the current version.

RxInfer.compose_simple_operators_with_brackets (Function)
compose_simple_operators_with_brackets(expr::Expr)

An additional pipeline stage for the @model macro from GraphPPL. This stage converts simple multi-argument operators into their corresponding bracketed expressions. E.g. the expression x ~ x1 + x2 + x3 + x4 becomes x ~ ((x1 + x2) + x3) + x4. The operators to compose are + and *.

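To make the folding concrete, here is a small stand-alone sketch in plain Julia that performs the same kind of rewrite on a quoted expression. The function fold_operators is hypothetical and is not RxInfer's implementation; it only illustrates the left fold of n-ary + and * calls into nested binary calls.

```julia
# Hypothetical re-implementation for illustration only; this is not RxInfer's code.
# Recursively rewrites n-ary `+` and `*` calls into left-nested binary calls.
fold_operators(x) = x  # symbols, literals and line number nodes are returned unchanged

function fold_operators(ex::Expr)
    # Fold the children first so that nested sub-expressions are rewritten as well
    args = map(fold_operators, ex.args)
    if ex.head === :call && length(args) > 3 && args[1] in (:+, :*)
        op = args[1]
        # Start with the first two operands, then fold in the remaining ones from the left
        acc = Expr(:call, op, args[2], args[3])
        for operand in args[4:end]
            acc = Expr(:call, op, acc, operand)
        end
        return acc
    end
    return Expr(ex.head, args...)
end

folded = fold_operators(:(x ~ x1 + x2 + x3 + x4))
```

Julia parses x1 + x2 + x3 + x4 as a single call to + with four arguments, which is why a stage like this is needed to obtain binary + nodes.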
RxInfer.inject_tilderhs_aliases (Function)
inject_tilderhs_aliases(e::Expr)

A pipeline stage for the @model macro from GraphPPL. This stage applies the aliases defined in ReactiveMPNodeAliases to the expression.

RxInfer.ReactiveMPNodeAliases (Constant)

Syntactic sugar for ReactiveMP nodes. Replaces a || b with ReactiveMP.OR(a, b), a && b with ReactiveMP.AND(a, b), a -> b with ReactiveMP.IMPLY(a, b), and ¬a (or !a) with ReactiveMP.NOT(a).

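The aliasing can be illustrated with a similar stand-alone sketch. The function rewrite_logic below is hypothetical and is not RxInfer's implementation; it rewrites into bare placeholder names OR, AND, IMPLY and NOT rather than the qualified ReactiveMP node types.

```julia
# Hypothetical re-implementation for illustration only; this is not RxInfer's code.
# Expression heads and operator symbols are built with Symbol(...) for clarity.
const OR_HEAD    = Symbol("||")
const AND_HEAD   = Symbol("&&")
const IMPLY_HEAD = Symbol("->")
const NEGATIONS  = (Symbol("!"), Symbol("¬"))

rewrite_logic(x) = x  # symbols, literals and line number nodes are left unchanged

function rewrite_logic(ex::Expr)
    # Rewrite the children first so that nested logical expressions are handled
    args = map(rewrite_logic, ex.args)
    if ex.head === OR_HEAD
        return Expr(:call, :OR, args...)
    elseif ex.head === AND_HEAD
        return Expr(:call, :AND, args...)
    elseif ex.head === IMPLY_HEAD
        # Julia stores the right-hand side of `a -> b` in a block with a line number node
        body = args[2]
        if body isa Expr && body.head === :block
            body = only(filter(arg -> !(arg isa LineNumberNode), body.args))
        end
        return Expr(:call, :IMPLY, args[1], body)
    elseif ex.head === :call && length(args) == 2 && args[1] in NEGATIONS
        return Expr(:call, :NOT, args[2])
    end
    return Expr(ex.head, args...)
end
```

Because the rewrite operates on the parsed expression, precedence comes directly from Julia's parser: in a || b && !c, && binds tighter than ||, so the result is OR(a, AND(b, NOT(c))).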

Getting access to internal variable data structures

To access the internal ReactiveMP data structure of a variable in an RxInfer model, return the so-called label of the variable from the model function and retrieve it later as follows:

using RxInfer

@model function beta_bernoulli(y)
    θ ~ Beta(1, 1)
    y ~ Bernoulli(θ)
    return θ
end

result = infer(
    model = beta_bernoulli(),
    data  = (y = 0.0, )
)
# Inference results:
#   Posteriors       | available for (θ)

graph     = RxInfer.getmodel(result.model)
returnval = RxInfer.getreturnval(graph)
θ         = returnval
variable  = RxInfer.getvariable(RxInfer.getvarref(graph, θ))
ReactiveMP.israndom(variable)
# true