Inference execution
The RxInfer inference API supports different types of message-passing algorithms (including hybrid algorithms combining several different types). While RxInfer implements several algorithms to cater to different computational needs and scenarios, the core message-passing algorithms that form the foundation of our inference capabilities are:
- Belief propagation
- Variational message passing (VMP)
Whereas belief propagation computes exact inference for the random variables of interest, variational message passing (VMP) is an approximate method that can be applied to a larger range of models.
The inference engine itself isn't aware of different algorithm types and simply performs message passing between nodes. However, during the model specification stage the user may specify different factorisation constraints around factor nodes with the help of the @constraints macro. Different factorisation constraints lead to different message-passing update rules. See the corresponding section of the documentation for more on constraints specification.
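For example, a mean-field factorisation constraint could look like this (a minimal sketch; the names x and τ are illustrative and not tied to any particular model in this section):
constraints = @constraints begin
    # Assume the joint q(x, τ) factorises as q(x)q(τ); messages around
    # the corresponding factor node are then computed with VMP update rules
    q(x, τ) = q(x)q(τ)
end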
Automatic inference specification
RxInfer exports the infer function to quickly run and test your model with both static and asynchronous (real-time) datasets. See more information about the infer function in the separate documentation section:
RxInfer.infer — Function
infer(
model;
data = nothing,
datastream = nothing,
autoupdates = nothing,
initialization = nothing,
constraints = nothing,
meta = nothing,
options = nothing,
returnvars = nothing,
predictvars = nothing,
historyvars = nothing,
keephistory = nothing,
iterations = nothing,
free_energy = false,
free_energy_diagnostics = DefaultObjectiveDiagnosticChecks,
showprogress = false,
callbacks = nothing,
addons = nothing,
postprocess = DefaultPostprocess(),
warn = true,
events = nothing,
uselock = false,
autostart = true,
catch_exception = false
)
This function provides a generic way to perform probabilistic inference for batch/static and streamline/online scenarios. It returns either an InferenceResult (batch setting) or an RxInferenceEngine (streamline setting) based on the parameters used.
Arguments
Check the official documentation for more information about some of the arguments.
- model: specifies a model generator, required
- data: NamedTuple or Dict with data, required (or datastream or predictvars)
- datastream: a stream of NamedTuples with data, required (or data)
- autoupdates = nothing: auto-updates specification, required for streamline inference, see @autoupdates
- initialization = nothing: initialization specification object, optional, see @initialization
- constraints = nothing: constraints specification object, optional, see @constraints
- meta = nothing: meta specification object, optional, may be required for some models, see @meta
- options = nothing: model creation options, optional, see ReactiveMPInferenceOptions
- returnvars = nothing: return structure info, optional, defaults to return everything at each iteration
- predictvars = nothing: return structure info, optional (exclusive for batch inference)
- historyvars = nothing: history structure info, optional, defaults to no history (exclusive for streamline inference)
- keephistory = nothing: history buffer size, defaults to empty buffer (exclusive for streamline inference)
- iterations = nothing: number of iterations, optional, defaults to nothing; the inference engine does not distinguish between variational message passing, loopy belief propagation, or expectation propagation iterations
- free_energy = false: compute the Bethe free energy, optional, defaults to false. Can be passed a floating-point type, e.g. Float64, for better efficiency, but this disables automatic differentiation packages such as ForwardDiff.jl
- free_energy_diagnostics = DefaultObjectiveDiagnosticChecks: free energy diagnostic checks, optional, by default checks for possible NaNs and Infs; nothing disables all checks
- showprogress = false: show progress module, optional, defaults to false (exclusive for batch inference)
- catch_exception: specifies whether exceptions during the inference procedure should be caught, optional, defaults to false (exclusive for batch inference)
- callbacks = nothing: inference cycle callbacks, optional
- addons = nothing: inject and send extra computation information along messages
- postprocess = DefaultPostprocess(): inference results postprocessing step, optional
- events = nothing: inference cycle events, optional (exclusive for streamline inference)
- uselock = false: specifies whether to use a lock structure for the inference; if set to true, uses Base.Threads.SpinLock. Accepts a custom AbstractLock. (exclusive for streamline inference)
- autostart = true: specifies whether to call RxInfer.start on the created engine automatically (exclusive for streamline inference)
- warn = true: enables/disables warnings
Note on NamedTuples
When passing a NamedTuple as a value for some argument, make sure you use a trailing comma for NamedTuples with a single entry. The reason is that Julia treats the expressions returnvars = (x = KeepLast()) and returnvars = (x = KeepLast(), ) differently. The first expression creates (or overwrites!) a local/global variable named x with contents KeepLast(). The second expression (note the trailing comma) creates a NamedTuple with x as a key and KeepLast() as the value assigned to that key.
(x = KeepLast()) # defines a variable `x` with the value `KeepLast()`
KeepLast()
(x = KeepLast(), ) # defines a NamedTuple with `x` as one of the keys and value `KeepLast()`
(x = KeepLast(),)
model
Also read the Model Specification section.
The model argument accepts a model specification as its input. The easiest way to create a model is with the @model macro. For example:
@model function beta_bernoulli(y, a, b)
x ~ Beta(a, b)
y .~ Bernoulli(x)
end
result = infer(
model = beta_bernoulli(a = 1, b = 1),
data = (y = [ true, false, false ], )
)
result.posteriors[:x]
Beta{Float64}(α=2.0, β=3.0)
The model keyword argument does not accept a ProbabilisticModel instance as a value, as it needs to inject constraints and meta during the inference procedure.
data
Either the data or the datastream keyword argument is required. Specifying both data and datastream is not supported and will result in an error.
The behavior of the data keyword argument depends on the inference setting (batch or streamline).
The data keyword argument must be a NamedTuple (or Dict) where keys (of Symbol type) correspond to some arguments defined in the model specification. For example, if a model defines y in its argument list
@model function beta_bernoulli(y, a, b)
x ~ Beta(a, b)
y .~ Bernoulli(x)
end
and you want to condition on this argument, then the data field must have a :y key (of Symbol type) which holds the data. The values in data must have the exact same shape as the corresponding variable container. E.g., in the example above y is used in a broadcasting operation, so it must be a collection of values. The a and b arguments, however, could be just single numbers:
result = infer(
model = beta_bernoulli(),
data = (y = [ true, false, false ], a = 1, b = 1)
)
result.posteriors[:x]
Beta{Float64}(α=2.0, β=3.0)
datastream
Also read the Streamlined Inference section.
The datastream keyword argument must be an observable that supports the subscribe! and unsubscribe! functions (e.g., streams from the Rocket.jl package). The elements of the observable must be of type NamedTuple where keys (of Symbol type) correspond to input arguments defined in the model specification, except for those listed in the @autoupdates specification. For example, if a model defines y as its argument (and y is not part of the @autoupdates specification), the named tuples from the observable must have a :y key (of Symbol type). The values in the named tuple must have the exact same shape as the corresponding variable container.
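A minimal sketch of such a stream, built with the from observable factory from Rocket.jl (the observation values here are made up for illustration):
using Rocket

# Each element is a NamedTuple whose key matches the model argument name `y`
observations = [ (y = true, ), (y = false, ), (y = false, ) ]
datastream   = from(observations)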
initialization
Also read the Initialization section.
For specific types of inference algorithms, such as variational message passing, it might be required to initialize (some of) the marginals before running the inference procedure in order to break the dependency loop. If this is not done, the inference algorithm will not execute due to the lack of information, and messages and/or marginals will not be updated. In order to specify these initial marginals and messages, you can use the initialization argument in combination with the @initialization macro, such as
init = @initialization begin
# initialize the marginal distribution of x as a vague Normal distribution
# if x is a vector, then it simply uses the same value for all elements
# However, it is also possible to provide a vector of distributions to set each element individually
q(x) = vague(NormalMeanPrecision)
end
Initial state:
q(x) = NormalMeanPrecision{Float64}(μ=0.0, w=1.0e-12)
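The resulting init object is then passed to infer via the initialization keyword; a sketch, where my_model and dataset are hypothetical placeholders:
result = infer(
    model = my_model(),              # hypothetical model with a latent variable x
    data = (y = dataset, ),          # hypothetical data
    initialization = init,
    iterations = 10
)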
returnvars
returnvars specifies the latent variables of interest and their posterior updates. Its behavior depends on the inference type: streamline or batch.
Batch inference:
- Accepts a NamedTuple or Dict of return variable specifications.
- Two specifications are available: KeepLast (saves the last update) and KeepEach (saves all updates).
- When iterations is set, returns every update for each iteration (equivalent to KeepEach()); if nothing, saves the last update (equivalent to KeepLast()).
- Use iterations = 1 to force KeepEach() for a single iteration or set returnvars = KeepEach() manually.
result = infer(
...,
returnvars = (
x = KeepLast(),
τ = KeepEach()
)
)
Shortcut for setting the same option for all variables:
result = infer(
...,
returnvars = KeepLast() # or KeepEach()
)
Streamline inference:
- For each symbol in returnvars, infer creates an observable stream of posterior updates.
- Agents can subscribe to these updates using the Rocket.jl package (see the sketch below).
engine = infer(
...,
autoupdates = my_autoupdates,
returnvars = (:x, :τ),
autostart = false
)
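Once the engine has been created, the posterior streams can be subscribed to before starting the engine; a sketch using Rocket.jl, assuming the engine above exposes its posterior streams via engine.posteriors:
using Rocket

# Subscribe to the stream of posterior updates for x
subscription = subscribe!(engine.posteriors[:x], (posterior) -> println("New posterior for x: ", posterior))

# `autostart = false` was used above, so start the engine manually
RxInfer.start(engine)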
RxInfer.KeepLast — Type
Instructs the inference engine to keep only the last marginal update and disregard intermediate updates.
RxInfer.KeepEach — Type
Instructs the inference engine to keep each marginal update for all intermediate iterations.
predictvars
predictvars specifies the variables which should be predicted. Similar to returnvars, predictvars accepts a NamedTuple or Dict. There are two specifications:
- KeepLast: saves the last update for a variable, ignoring any intermediate results during iterations
- KeepEach: saves all updates for a variable for all iterations
result = infer(
...,
predictvars = (
o = KeepLast(),
τ = KeepEach()
)
)
The predictvars argument is exclusive to the batch setting.
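Prediction is typically combined with partially observed data; a sketch reusing the beta_bernoulli model from above, where a missing entry stands in for the data point to be predicted (the exact data and the predictions field access are illustrative assumptions):
result = infer(
    model = beta_bernoulli(a = 1, b = 1),
    data = (y = [ true, false, missing ], ),   # the last observation is unobserved
    predictvars = (y = KeepLast(), )
)
result.predictions[:y]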
historyvars
Also read the Keeping the history of posteriors section.
historyvars specifies the variables of interest and the amount of information to keep in history about the posterior updates when performing streamline inference. The specification is similar to returnvars in the batch setting. historyvars requires keephistory to be greater than zero.
historyvars accepts a NamedTuple or Dict of return variable specifications. There are two specifications:
- KeepLast: saves the last update for a variable, ignoring any intermediate results during iterations
- KeepEach: saves all updates for a variable for all iterations
result = infer(
...,
autoupdates = my_autoupdates,
historyvars = (
x = KeepLast(),
τ = KeepEach()
),
keephistory = 10
)
It is also possible to set either historyvars = KeepLast() or historyvars = KeepEach(), which acts as an alias and sets the given option for all random variables in the model.
result = infer(
...,
autoupdates = my_autoupdates,
historyvars = KeepLast(),
keephistory = 10
)
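After the engine has processed some observations, the collected history can be read back; a hypothetical sketch, assuming the engine above has been started and has consumed data, and that the history buffer is exposed as a history field:
# Stop the engine and inspect the buffered history of posterior updates
RxInfer.stop(result)
result.history[:x]   # holds at most `keephistory` recent updates for x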
keephistory
Specifies the buffer size of the updates history, both for the historyvars and the free_energy buffers, in streamline inference.
The historyvars and keephistory arguments are exclusive to the streamline setting.
iterations
Specifies the number of variational (or loopy belief propagation) iterations. By default set to nothing, which is equivalent to doing 1 iteration. However, if set explicitly to 1, the default setting for returnvars changes from KeepLast to KeepEach.
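For example, to run several VMP iterations and inspect the posterior after each one (a sketch):
result = infer(
    ...,
    iterations = 5,
    returnvars = KeepEach()   # keep the posterior from every iteration to inspect convergence
)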
free_energy
Batch inference:
- Specifies if the infer function should return Bethe Free Energy (BFE) values.
- Optionally accepts a floating-point type (e.g., Float64) for improved BFE computation performance, but restricts the use of automatic differentiation packages.
Streamline inference:
- Specifies if the infer function should create an observable stream of Bethe Free Energy (BFE) values, computed at each VMP iteration.
- When free_energy = true and keephistory > 0, additional fields are exposed in the engine for accessing the history of BFE updates (see the sketch below):
  - engine.free_energy_history: averaged BFE history over VMP iterations.
  - engine.free_energy_final_only_history: BFE history of values computed in the last VMP iteration for each observation.
  - engine.free_energy_raw_history: raw BFE history.
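A sketch of an engine created with BFE tracking enabled (the autoupdates object is assumed to exist; the field names come from the list above):
engine = infer(
    ...,
    autoupdates = my_autoupdates,
    free_energy = true,
    keephistory = 100
)
# After the engine has processed some observations:
engine.free_energy_history   # averaged BFE over VMP iterations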
free_energy_diagnostics
This setting specifies either a single diagnostic check or a tuple of diagnostic checks for the stream of Bethe Free Energy values. By default it checks for NaNs and Infs. See also RxInfer.ObjectiveDiagnosticCheckNaNs and RxInfer.ObjectiveDiagnosticCheckInfs. Pass nothing to disable all checks.
options
RxInfer.ReactiveMPInferenceOptions — Type
ReactiveMPInferenceOptions(; kwargs...)
Creates a model inference options object. The list of available options is presented below.
Options
- limit_stack_depth: limits the stack depth for computing messages; helps with StackOverflowError for some huge models, but reduces the performance of the inference backend. Accepts an integer argument that specifies the maximum recursion depth. Lower is better for avoiding stack overflow errors, but worse for performance.
- warn: (optional) flag to suppress warnings. Warnings are not displayed if set to false. Defaults to true.
Advanced options
- scheduler: changes the scheduler of reactive streams; see the Rocket.jl documentation for more info. Defaults to AsapScheduler.
- rulefallback: specifies a global message update rule fallback for cases when a specific message update rule is not available. Consult the ReactiveMP documentation for the list of available callbacks.
See also: infer
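For example, to limit the stack depth for message computations (a sketch):
result = infer(
    ...,
    options = (limit_stack_depth = 100, )   # trades some performance for protection against StackOverflowError
)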
catch_exception
The catch_exception keyword argument specifies whether exceptions during the batch inference procedure should be caught in the error field of the result. By default, if an exception occurs during the inference procedure, the result will be lost. Set catch_exception = true to obtain a partial inference result in case an exception occurs. Use the RxInfer.issuccess and RxInfer.iserror functions to check if the inference completed successfully or failed. If an error occurs, the error field will store a tuple, where the first element is the exception itself and the second element is the caught backtrace. Use the stacktrace function with the backtrace as an argument to recover the stacktrace of the error. Use the Base.showerror function to display the error.
RxInfer.issuccess — Function
Checks that the InferenceResult object does not contain an error.
RxInfer.iserror — Function
Checks if the InferenceResult object contains an error.
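A sketch of inspecting a failed result, grounded in the fields described above:
result = infer(
    ...,
    catch_exception = true
)
if RxInfer.iserror(result)
    exception, backtrace = result.error
    Base.showerror(stderr, exception)   # display the error itself
    stacktrace(backtrace)               # recover the stacktrace of the error
end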
callbacks
The inference function has its own lifecycle. The user is free to provide some (or none) of the callbacks to inject extra logging or other procedures into the inference function, e.g.:
result = infer(
    ...,
    callbacks = (
        on_marginal_update = (model, name, update) -> println("$(name) has been updated: $(update)"),
        after_inference = (args...) -> println("Inference has been completed")
    )
)
The callbacks keyword argument accepts a named tuple of 'name = callback' pairs. The list of all possible callbacks for the different inference settings (batch or streamline) and their arguments is presented below:
before_model_creation()
after_model_creation(model::ProbabilisticModel)
Exclusive for batch inference
on_marginal_update(model::ProbabilisticModel, name::Symbol, update)
before_inference(model::ProbabilisticModel)
before_iteration(model::ProbabilisticModel, iteration::Int)::Bool
before_data_update(model::ProbabilisticModel, data)
after_data_update(model::ProbabilisticModel, data)
after_iteration(model::ProbabilisticModel, iteration::Int)::Bool
after_inference(model::ProbabilisticModel)
The before_iteration and after_iteration callbacks are allowed to return a true/false value, where true indicates that the iterations must be halted and no further inference should be made.
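For instance, the following sketch halts the procedure after the third iteration (a hypothetical early-stopping rule):
result = infer(
    ...,
    iterations = 100,
    callbacks = (
        # returning `true` halts the remaining iterations
        after_iteration = (model, iteration) -> iteration >= 3,
    )
)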
Exclusive for streamline inference
before_autostart(engine::RxInferenceEngine)
after_autostart(engine::RxInferenceEngine)
addons
The addons field extends the default message computation rules with some extra information, e.g. computing log-scaling factors of messages or saving debug information. It accepts a single addon or a tuple of addons. Using addons automatically changes the default value of the postprocess argument to NoopPostprocess.
postprocess
Also read the Inference results postprocessing section.
The postprocess keyword argument controls whether the inference results must be modified in some way before exiting the inference function. By default, the inference function uses the DefaultPostprocess strategy, which removes the Marginal wrapper type from the results. Change this setting to NoopPostprocess if you would like to keep the Marginal wrapper type, which might be useful in combination with the addons argument. If the addons argument has been used, the default strategy automatically changes to NoopPostprocess, as noted above.
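A sketch combining the two arguments (AddonLogScale is one of the addons provided by ReactiveMP; treat its availability here as an assumption):
result = infer(
    ...,
    addons = (AddonLogScale(), ),      # attach log-scaling factors to messages
    postprocess = NoopPostprocess()    # keep the Marginal wrappers so the addon data stays accessible
)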
Where to go next?
Read more about the other keyword arguments in the Streamlined (online) inference section, check out the Static Inference section, or explore some of the more advanced examples.