Model construction in RxInfer
Model creation in RxInfer largely depends on the GraphPPL package. RxInfer re-exports the @model macro from GraphPPL and defines extra plugins and data structures on top of the default functionality.
Model creation and construction were largely refactored in GraphPPL v4. Read the Migration Guide for more details.
Also read the Model Specification guide.
@model macro
RxInfer operates with so-called graphical probabilistic models, more specifically factor graphs. Working with graphs directly is, however, tedious and error-prone, especially for large models. To simplify the process, RxInfer exports the @model macro, which translates a textual description of a probabilistic model into a corresponding factor graph representation.
RxInfer.@model — Macro
@model function model_name(model_arguments...)
    # model description
end
The @model macro generates a function that returns an equivalent graph representation of the given probabilistic model description. See the documentation of GraphPPL.@model for more information.
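A minimal sketch of a model definition follows; the model name coin_model and its arguments are illustrative and not part of RxInfer. In the GraphPPL v4 based workflow used on this page, calling the generated function with some of its arguments is expected to return a GraphPPL model generator, which is conditioned on data and materialized into a factor graph later (see the conditioning section below).

```julia
using RxInfer

# Illustrative model: a coin with unknown bias θ and a vector of observed flips y
@model function coin_model(y, a, b)
    θ ~ Beta(a, b)       # prior over the coin bias
    y .~ Bernoulli(θ)    # each observation is an independent Bernoulli draw
end

# Calling the generated function with the prior hyperparameters only is
# expected to give a model generator; no factor graph is built at this point
generator = coin_model(a = 1.0, b = 1.0)
println(typeof(generator))
```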
Supported aliases in the model specification specifically for RxInfer.jl and ReactiveMP.jl:
a || b: alias for the ReactiveMP.OR(a, b) node.
a && b: alias for the ReactiveMP.AND(a, b) node.
a -> b: alias for the ReactiveMP.IMPLY(a, b) node.
¬a and !a: alias for the ReactiveMP.NOT(a) node (Unicode \neg).
For all of these aliases, the operator precedence between ||, &&, -> and ! is the same as in Julia.
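As an illustration only, the hypothetical model or_model below observes y as the logical OR of two latent binary variables. It assumes that the OR node accepts Bernoulli-distributed inputs, and it only builds the graph with RxInfer.create_model; it does not run inference, and whether a particular inference setup works for this model is not covered here.

```julia
using RxInfer

# Hypothetical model: y is observed and equals the logical OR of a and b
@model function or_model(y)
    a ~ Bernoulli(0.5)
    b ~ Bernoulli(0.5)
    y ~ a || b    # rewritten by RxInfer into a ReactiveMP.OR(a, b) node
end

# Materialize the factor graph with y observed as 1.0 ("true");
# no inference is performed here
model = RxInfer.create_model(or_model() | (y = 1.0, ))
```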
Note that GraphPPL also implements the @model macro, but does not export it by default. This was a deliberate choice to allow inference backends (such as RxInfer) to implement custom functionality on top of the default GraphPPL.@model macro. This is done with a custom backend for the GraphPPL.@model macro. Read more about backends in the corresponding section of the GraphPPL documentation.
RxInfer.ReactiveMPGraphPPLBackend — Type
A backend for GraphPPL that uses ReactiveMP for inference.
Conditioning on data
After model creation, RxInfer uses the RxInfer.condition_on function to condition on data. As an alias, it is also possible to use the | operator for the same purpose, with a nicer syntax.
RxInfer.condition_on — Function
condition_on(generator::ModelGenerator; kwargs...)
A function that creates a ConditionedModelGenerator object from a GraphPPL.ModelGenerator. The | operator can be used as a shorthand for this function.
julia> using RxInfer

julia> @model function beta_bernoulli(y, a, b)
           θ ~ Beta(a, b)
           y .~ Bernoulli(θ)
       end

julia> conditioned_model = beta_bernoulli(a = 1.0, b = 2.0) | (y = [ 1.0, 0.0, 1.0 ], )
beta_bernoulli(a = 1.0, b = 2.0) conditioned on:
  y = [1.0, 0.0, 1.0]

julia> RxInfer.create_model(conditioned_model) isa RxInfer.ProbabilisticModel
true
Base.:| — Method
An alias for RxInfer.condition_on.
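Assuming the keyword form of condition_on shown in its signature above, the two calls in the following sketch should produce equivalent ConditionedModelGenerator objects; the operator form simply takes the data as a named tuple on its right-hand side.

```julia
using RxInfer

@model function beta_bernoulli(y, a, b)
    θ ~ Beta(a, b)
    y .~ Bernoulli(θ)
end

generator = beta_bernoulli(a = 1.0, b = 2.0)

# Explicit function call: data is passed as keyword arguments
conditioned_1 = RxInfer.condition_on(generator; y = [1.0, 0.0, 1.0])

# Operator form: data is passed as a named tuple after |
conditioned_2 = generator | (y = [1.0, 0.0, 1.0], )
```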
RxInfer.ConditionedModelGenerator — Type
ConditionedModelGenerator(generator, conditioned_on)
Accepts a model generator and data to condition on. The generator must be a GraphPPL.ModelGenerator object. The conditioned_on argument must be a named tuple or a dictionary with keys corresponding to the names of the input arguments in the model.
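The docstring above also allows the data to be a dictionary. The following sketch assumes that the dictionary keys are Symbols named after the model arguments, and it calls the ConditionedModelGenerator constructor directly instead of using the | operator.

```julia
using RxInfer

@model function beta_bernoulli(y, a, b)
    θ ~ Beta(a, b)
    y .~ Bernoulli(θ)
end

# Data as a dictionary; the Symbol key :y is assumed to match the argument name y
data = Dict(:y => [1.0, 0.0, 1.0])

conditioned = RxInfer.ConditionedModelGenerator(beta_bernoulli(a = 1.0, b = 2.0), data)
model = RxInfer.create_model(conditioned)
```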
Sometimes it might be useful to condition on data that is not available at model creation time. This is especially useful in a reactive inference setting, where data might become available later, e.g. from an asynchronous sensor input. For this reason, RxInfer implements a special deferred data handler that marks a model argument as data but does not specify any particular value or shape for it.
RxInfer.DeferredDataHandler — Type
An object that is used to condition on unknown data. This may be necessary to create a model from a ModelGenerator object for which data is not known at the time of model creation.
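A sketch of deferred conditioning, assuming DeferredDataHandler can be constructed without arguments. The model beta_bernoulli_single is a hypothetical single-observation variant of the beta_bernoulli example; the snippet only builds the graph structure, and feeding actual values later (for example from a streaming source) is outside its scope.

```julia
using RxInfer

# Hypothetical single-observation model
@model function beta_bernoulli_single(y)
    θ ~ Beta(1.0, 1.0)
    y ~ Bernoulli(θ)
end

# y is marked as data, but neither its value nor its shape is given yet
conditioned = beta_bernoulli_single() | (y = RxInfer.DeferredDataHandler(), )
model = RxInfer.create_model(conditioned)
```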
After the model has been conditioned, it can be materialized with the RxInfer.create_model function. This function takes the RxInfer.ConditionedModelGenerator object and materializes it into an RxInfer.ProbabilisticModel.
GraphPPL.create_model — Method
create_model(generator::ConditionedModelGenerator)
Materializes the model specification conditioned on some data into a corresponding factor graph representation. Returns a ProbabilisticModel.
RxInfer.ProbabilisticModel — Type
A structure that holds the factor graph representation of a probabilistic model.
GraphPPL.getmodel — Method
Returns the underlying factor graph model.
RxInfer.getreturnval — Method
Returns the value from the return ... operator inside the model specification.
RxInfer.getvardict — Method
Returns the (nested) dictionary of random variables from the model specification.
RxInfer.getrandomvars — Method
Returns the random variables from the model specification.
RxInfer.getdatavars — Method
Returns the data variables from the model specification.
RxInfer.getconstantvars — Method
Returns the constant variables from the model specification.
RxInfer.getfactornodes — Method
Returns the factor nodes from the model specification.
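The following sketch materializes the beta_bernoulli example and calls each accessor listed above. It assumes that the accessors accept an RxInfer.ProbabilisticModel directly (getreturnval is called on the underlying graph, as in the example at the end of this page); the exact container types they return are not specified here, so the snippet only collects them.

```julia
using RxInfer

@model function beta_bernoulli(y, a, b)
    θ ~ Beta(a, b)
    y .~ Bernoulli(θ)
    return θ    # makes the label of θ available through getreturnval
end

model = RxInfer.create_model(beta_bernoulli(a = 1.0, b = 2.0) | (y = [1.0, 0.0, 1.0], ))

graph      = RxInfer.getmodel(model)                 # underlying GraphPPL factor graph
returnval  = RxInfer.getreturnval(graph)             # value of the return statement
vardict    = RxInfer.getvardict(model)               # (nested) dictionary of variables
randomvars = collect(RxInfer.getrandomvars(model))   # latent variables
datavars   = collect(RxInfer.getdatavars(model))     # observed variables
constvars  = collect(RxInfer.getconstantvars(model)) # constants
factors    = collect(RxInfer.getfactornodes(model))  # factor nodes

println("random: ", length(randomvars), ", data: ", length(datavars), ", factors: ", length(factors))
```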
Additional GraphPPL pipeline stages
RxInfer implements several additional pipeline stages on top of the default parsing stages in GraphPPL. A notable distinction of the RxInfer model specification language is that RxInfer "folds" some mathematical expressions and adds extra brackets to ensure the correct number of arguments for factor nodes. For example, the expression x ~ x1 + x2 + x3 + x4 becomes x ~ ((x1 + x2) + x3) + x4, which ensures that the + function has exactly two arguments.
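To make the folding rule concrete, here is a hypothetical, self-contained Julia function that applies the same left-to-right bracketing to a quoted expression. It is not RxInfer's compose_simple_operators_with_brackets, only a sketch of the idea in Base Julia (Expr and foldl), restricted to the + and * operators.

```julia
# Hypothetical illustration only, NOT RxInfer's implementation:
# rewrite n-ary calls to + and * into nested two-argument calls.
function fold_simple_operators(ex)
    ex isa Expr || return ex                     # symbols and literals stay as they are
    args = map(fold_simple_operators, ex.args)   # fold nested sub-expressions first
    if ex.head == :call && args[1] in (:+, :*) && length(args) > 3
        op = args[1]
        # a + b + c + d  becomes  ((a + b) + c) + d
        return foldl((acc, arg) -> Expr(:call, op, acc, arg), args[3:end]; init = args[2])
    end
    return Expr(ex.head, args...)
end

folded = fold_simple_operators(:(x ~ x1 + x2 + x3 + x4))
@show folded
```

Calls with exactly two operands, such as a + b, are left unchanged because they already match a two-argument factor node.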
RxInfer.error_datavar_constvar_randomvar — Function
warn_datavar_constvar_randomvar(expr::Expr)
An additional pipeline stage for the @model macro from GraphPPL. Notifies the user that the datavar, constvar and randomvar syntax has been removed and is not supported in the current version.
RxInfer.compose_simple_operators_with_brackets — Function
compose_simple_operators_with_brackets(expr::Expr)
An additional pipeline stage for the @model macro from GraphPPL. This pipeline converts simple multi-argument operators to their corresponding bracketed expression. E.g. the expression x ~ x1 + x2 + x3 + x4 becomes x ~ ((x1 + x2) + x3) + x4. The operators to compose are + and *.
RxInfer.inject_tilderhs_aliases — Function
inject_tilderhs_aliases(e::Expr)
A pipeline stage for the @model macro from GraphPPL. This pipeline applies the aliases defined in ReactiveMPNodeAliases to the expression.
RxInfer.ReactiveMPNodeAliases — Constant
Syntactic sugar for ReactiveMP nodes. Replaces a || b with ReactiveMP.OR(a, b), a && b with ReactiveMP.AND(a, b), a -> b with ReactiveMP.IMPLY(a, b) and ¬a with ReactiveMP.NOT(a).
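The following hypothetical sketch shows why this rewriting happens at the macro level: in Julia, a || b and a && b are not function calls but expressions with the heads :|| and :&&, so they cannot be overloaded with methods. The function below is not RxInfer's inject_tilderhs_aliases; it uses the plain symbols OR, AND, IMPLY and NOT as stand-ins for the ReactiveMP node names.

```julia
# Hypothetical illustration only, NOT RxInfer's implementation.
# OR, AND, IMPLY and NOT are plain symbols standing in for ReactiveMP node names.
const NOT_OPERATORS = (:!, Symbol("¬"))

function apply_logical_aliases(ex)
    ex isa Expr || return ex
    args = map(apply_logical_aliases, ex.args)   # rewrite inner expressions first
    if ex.head == :||
        return Expr(:call, :OR, args...)
    elseif ex.head == :&&
        return Expr(:call, :AND, args...)
    elseif ex.head == :->
        # the right-hand side of `a -> b` is stored in a block with a line-number node
        body = args[2]
        if body isa Expr && body.head == :block
            body = only(filter(x -> !(x isa LineNumberNode), body.args))
        end
        return Expr(:call, :IMPLY, args[1], body)
    elseif ex.head == :call && length(args) == 2 && args[1] in NOT_OPERATORS
        return Expr(:call, :NOT, args[2])
    end
    return Expr(ex.head, args...)
end

@show apply_logical_aliases(:(y ~ a || b))
@show apply_logical_aliases(:(y ~ ¬a && b))
```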
Getting access to internal variable data structures
To get access to the internal ReactiveMP data structure of a variable in an RxInfer model, return a so-called label of the variable from the model macro and access it later as follows:
using RxInfer

@model function beta_bernoulli(y)
    θ ~ Beta(1, 1)
    y ~ Bernoulli(θ)
    return θ
end

result = infer(
    model = beta_bernoulli(),
    data = (y = 0.0, )
)
Inference results:
  Posteriors | available for (θ)

graph = RxInfer.getmodel(result.model)
returnval = RxInfer.getreturnval(graph)
θ = returnval
variable = RxInfer.getvariable(RxInfer.getvarref(graph, θ))
ReactiveMP.israndom(variable)
true