Inference execution

The RxInfer inference API supports different types of message-passing algorithms (including hybrid algorithms combining several different types). While RxInfer implements several algorithms to cater to different computational needs and scenarios, the core message-passing algorithms that form the foundation of our inference capabilities are:

  • Belief Propagation
  • Variational Message Passing

Whereas belief propagation computes exact inference for the random variables of interest (on tree-structured models), variational message passing (VMP) is an approximation method that can be applied to a larger range of models.

The inference engine itself isn't aware of different algorithm types and simply performs message passing between nodes. However, during the model specification stage the user may specify different factorisation constraints around factor nodes with the help of the @constraints macro. Different factorisation constraints lead to different message passing update rules. See the constraints specification section for more details.
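A minimal sketch of how a factorisation constraint changes the update rules, assuming a hypothetical model iid_normal (i.i.d. Gaussian observations with unknown mean μ and precision τ; not part of the original text). The mean-field constraint q(μ, τ) = q(μ)q(τ) makes the engine use variational message passing around the Normal likelihood node, which in turn requires initial marginals for μ and τ.

```julia
using RxInfer

# Hypothetical example model: i.i.d. observations with unknown mean and precision
@model function iid_normal(y)
    μ  ~ Normal(mean = 0.0, variance = 100.0)
    τ  ~ Gamma(shape = 1.0, rate = 1.0)
    y .~ Normal(mean = μ, precision = τ)
end

# Mean-field factorisation: the joint posterior over μ and τ is approximated as q(μ)q(τ)
constraints = @constraints begin
    q(μ, τ) = q(μ)q(τ)
end

# VMP needs initial marginals to break the dependency loop between μ and τ
init = @initialization begin
    q(μ) = vague(NormalMeanVariance)
    q(τ) = vague(GammaShapeRate)
end

result = infer(
    model          = iid_normal(),
    data           = (y = randn(100),),
    constraints    = constraints,
    initialization = init,
    iterations     = 10,
)

# With `iterations` set, each variable holds one posterior per iteration
println(last(result.posteriors[:μ]))
println(last(result.posteriors[:τ]))
```

Removing the @constraints block leaves the engine without the mean-field factorisation, so the same model would no longer be handled by these VMP rules.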

Automatic inference specification

RxInfer exports the infer function to quickly run and test your model with both static and asynchronous (real-time) datasets. See more information about the infer function on the separate documentation section:

RxInfer.infer (Function)
infer(
    model; 
    data = nothing,
    datastream = nothing,
    autoupdates = nothing,
    initialization = nothing,
    constraints = nothing,
    meta = nothing,
    options = nothing,
    returnvars = nothing, 
    predictvars = nothing, 
    historyvars = nothing,
    keephistory = nothing,
    iterations = nothing,
    free_energy = false,
    free_energy_diagnostics = DefaultObjectiveDiagnosticChecks,
    showprogress = false,
    callbacks = nothing,
    addons = nothing,
    postprocess = DefaultPostprocess(),
    warn = true,
    events = nothing,
    uselock = false,
    autostart = true,
    catch_exception = false
)

This function provides a generic way to perform probabilistic inference for batch/static and streamline/online scenarios. Returns either an InferenceResult (batch setting) or RxInferenceEngine (streamline setting) based on the parameters used.

Arguments

Check the official documentation for more information about some of the arguments.

  • model: specifies a model generator, required
  • data: NamedTuple or Dict with data, required (or datastream or predictvars)
  • datastream: a stream of NamedTuple with data, required (or data)
  • autoupdates = nothing: auto-updates specification, required for streamline inference, see @autoupdates
  • initialization = nothing: initialization specification object, optional, see @initialization
  • constraints = nothing: constraints specification object, optional, see @constraints
  • meta = nothing: meta specification object, optional, may be required for some models, see @meta
  • options = nothing: model creation options, optional, see ReactiveMPInferenceOptions
  • returnvars = nothing: return structure info, optional, defaults to return everything at each iteration
  • predictvars = nothing: return structure info, optional (exclusive for batch inference)
  • historyvars = nothing: history structure info, optional, defaults to no history (exclusive for streamline inference)
  • keephistory = nothing: history buffer size, defaults to empty buffer (exclusive for streamline inference)
  • iterations = nothing: number of iterations, optional, defaults to nothing. The inference engine does not distinguish between variational message passing, loopy belief propagation, or expectation propagation iterations
  • free_energy = false: compute the Bethe free energy, optional, defaults to false. Can be passed a floating point type, e.g. Float64, for better efficiency, but disables automatic differentiation packages, such as ForwardDiff.jl
  • free_energy_diagnostics = DefaultObjectiveDiagnosticChecks: free energy diagnostic checks, optional, by default checks for possible NaNs and Infs. nothing disables all checks.
  • showprogress = false: show progress module, optional, defaults to false (exclusive for batch inference)
  • catch_exception = false: specifies whether exceptions during the inference procedure should be caught, optional, defaults to false (exclusive for batch inference)
  • callbacks = nothing: inference cycle callbacks, optional
  • addons = nothing: inject and send extra computation information along messages, optional
  • postprocess = DefaultPostprocess(): inference results postprocessing step, optional
  • events = nothing: inference cycle events, optional (exclusive for streamline inference)
  • uselock = false: specifies whether to use a lock for the inference; if set to true, uses Base.Threads.SpinLock. Also accepts a custom AbstractLock. (exclusive for streamline inference)
  • autostart = true: specifies whether to call RxInfer.start on the created engine automatically or not (exclusive for streamline inference)
  • warn = true: enables/disables warnings

Note on NamedTuples

When passing a NamedTuple as the value of some argument, make sure you use a trailing comma for NamedTuples with a single entry. The reason is that Julia treats returnvars = (x = KeepLast()) and returnvars = (x = KeepLast(), ) differently. The first expression creates (or overwrites!) a local/global variable named x with the value KeepLast(). The second expression (note the trailing comma) creates a NamedTuple with x as a key and KeepLast() as the value assigned to this key.

julia> (x = KeepLast()) # defines a variable `x` with the value `KeepLast()`
KeepLast()

julia> (x = KeepLast(), ) # defines a NamedTuple with `x` as one of the keys and value `KeepLast()`
(x = KeepLast(),)
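The same distinction can be checked programmatically. The sketch below also shows the leading-semicolon form (; x = KeepLast()), which is standard Julia syntax (since Julia 1.5) for building a NamedTuple without a trailing comma; this alternative is an addition and is not mentioned elsewhere in this section.

```julia
using RxInfer  # provides KeepLast

a = (x = KeepLast())    # assignment: binds the variable `x`; `a` is just KeepLast()
b = (x = KeepLast(),)   # trailing comma: a NamedTuple with the single key `:x`
c = (; x = KeepLast())  # leading semicolon: also a NamedTuple with the key `:x`

println(typeof(a))
println(typeof(b))
```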
  • model

Also read the Model Specification section.

The model argument accepts a model specification as its input. The easiest way to create the model is to use the @model macro. For example:

@model function beta_bernoulli(y, a, b)
    x  ~ Beta(a, b)
    y .~ Bernoulli(x)
end

result = infer(
    model = beta_bernoulli(a = 1, b = 1),
    data  = (y = [ true, false, false ], )
)

result.posteriors[:x]
Beta{Float64}(α=2.0, β=3.0)
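The Beta(2, 3) result can be checked by hand: with a Beta(a, b) prior and Bernoulli observations, the posterior is Beta(a + number of true values, b + number of false values). A small sketch comparing that closed-form conjugate update with the output of infer:

```julia
using RxInfer

@model function beta_bernoulli(y, a, b)
    x  ~ Beta(a, b)
    y .~ Bernoulli(x)
end

y = [true, false, false]
a, b = 1, 1

result = infer(model = beta_bernoulli(a = a, b = b), data = (y = y,))

# Conjugate update: α = a + (number of `true`), β = b + (number of `false`)
α_expected = a + count(y)
β_expected = b + count(!, y)

q = result.posteriors[:x]
println(params(q), " vs ", (α_expected, β_expected))
println("posterior mean: ", mean(q))  # α / (α + β) = 2 / 5
```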
Note

The model keyword argument does not accept a ProbabilisticModel instance as a value, as it needs to inject constraints and meta during the inference procedure.

  • data

Either the data or the datastream keyword argument is required. Specifying both data and datastream is not supported and will result in an error.

Note

The behavior of the data keyword argument depends on the inference setting (batch or streamline).

The data keyword argument must be a NamedTuple (or Dict) where keys (of Symbol type) correspond to some arguments defined in the model specification. For example, if a model defines y in its argument list

@model function beta_bernoulli(y, a, b)
    x  ~ Beta(a, b)
    y .~ Bernoulli(x)
end

and you want to condition on this argument, then the data field must have a :y key (of Symbol type) which holds the data. The values in the data must have exactly the same shape as the corresponding variable container. E.g., in the example above y is used in a broadcasting operation, so it must be a collection of values. The a and b arguments, however, can be plain numbers:

result = infer(
    model = beta_bernoulli(),
    data  = (y = [ true, false, false ], a = 1, b = 1)
)

result.posteriors[:x]
Beta{Float64}(α=2.0, β=3.0)
  • datastream

Also read the Streamlined Inference section.

The datastream keyword argument must be an observable that supports subscribe! and unsubscribe! functions (e.g., streams from the Rocket.jl package). The elements of the observable must be of type NamedTuple where keys (of Symbol type) correspond to input arguments defined in the model specification, except for those which are listed in the @autoupdates specification. For example, if a model defines y as its argument (which is not part of the @autoupdates specification), the named tuple from the observable must have a :y key (of Symbol type). The values in the named tuple must have exactly the same shape as the corresponding variable container.
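As a sketch of streamline inference, assuming a hypothetical single-observation model beta_bernoulli_online and that RxInfer re-exports the Rocket.jl functions used below (from, map, lambda, subscribe!, unsubscribe!): the @autoupdates block feeds the parameters of the current posterior q(θ) back in as the next prior, so only y has to be present in each element of the datastream.

```julia
using RxInfer  # assumed to re-export Rocket.jl (from, map, lambda, subscribe!)

# Hypothetical model that consumes one observation at a time
@model function beta_bernoulli_online(y, a, b)
    θ ~ Beta(a, b)
    y ~ Bernoulli(θ)
end

# After each observation, the parameters of the current posterior q(θ) become the next prior
autoupdates = @autoupdates begin
    a, b = params(q(θ))
end

# Marginal used to compute the very first values of a and b
init = @initialization begin
    q(θ) = Beta(1.0, 1.0)
end

observations = [true, false, false, true]

# Every element of the stream is a NamedTuple with a :y key (a and b come from autoupdates)
datastream = from(observations) |> map(NamedTuple{(:y,), Tuple{Bool}}, (d) -> (y = d,))

engine = infer(
    model          = beta_bernoulli_online(),
    datastream     = datastream,
    autoupdates    = autoupdates,
    initialization = init,
    returnvars     = (:θ,),
    autostart      = false,   # subscribe first, then start manually
)

# Collect every posterior update of θ emitted by the engine
updates = []
subscription = subscribe!(engine.posteriors[:θ], lambda(on_next = (q) -> push!(updates, q)))

RxInfer.start(engine)   # processes the (synchronous) stream
unsubscribe!(subscription)

println(updates)
```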

  • initialization

Also read the Initialization section.

For specific types of inference algorithms, such as variational message passing, it might be required to initialize (some of) the marginals before running the inference procedure in order to break the dependency loop. If this is not done, the inference algorithm will not be executed due to the lack of information, and messages and/or marginals will not be updated. To specify these initial marginals and messages, use the initialization argument in combination with the @initialization macro, for example:

init = @initialization begin
    # initialize the marginal distribution of x as a vague Normal distribution
    # if x is a vector, then it simply uses the same value for all elements
    # However, it is also possible to provide a vector of distributions to set each element individually
    q(x) = vague(NormalMeanPrecision)
end
Initial state: 
  q(x) = NormalMeanPrecision{Float64}(μ=0.0, w=1.0e-12)
  • returnvars

returnvars specifies latent variables of interest and their posterior updates. Its behavior depends on the inference type: streamline or batch.

Batch inference:

  • Accepts a NamedTuple or Dict of return variable specifications.
  • Two specifications available: KeepLast (saves the last update) and KeepEach (saves all updates).
  • When iterations is set, returns every update for each iteration (equivalent to KeepEach()); if nothing, saves the last update (equivalent to KeepLast()).
  • Use iterations = 1 to force KeepEach() for a single iteration or set returnvars = KeepEach() manually.
result = infer(
    ...,
    returnvars = (
        x = KeepLast(),
        τ = KeepEach()
    )
)

Shortcut for setting the same option for all variables:

result = infer(
    ...,
    returnvars = KeepLast()  # or KeepEach()
)
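A minimal sketch, reusing the beta_bernoulli model from the model section, of how iterations changes the default return structure and how returnvars = KeepLast() overrides it:

```julia
using RxInfer

@model function beta_bernoulli(y, a, b)
    x  ~ Beta(a, b)
    y .~ Bernoulli(x)
end

data = (y = [true, false, false],)

# iterations = nothing (default): only the last update is kept (KeepLast)
r_default = infer(model = beta_bernoulli(a = 1, b = 1), data = data)

# iterations set explicitly: every update is kept, one per iteration (KeepEach)
r_each = infer(model = beta_bernoulli(a = 1, b = 1), data = data, iterations = 3)

# explicit override: keep only the last update for all variables
r_last = infer(
    model      = beta_bernoulli(a = 1, b = 1),
    data       = data,
    iterations = 3,
    returnvars = KeepLast(),
)

println(typeof(r_default.posteriors[:x]))
println(length(r_each.posteriors[:x]))
```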

Streamline inference:

  • For each symbol in returnvars, infer creates an observable stream of posterior updates.
  • Agents can subscribe to these updates using the Rocket.jl package.
engine = infer(
    ...,
    autoupdates = my_autoupdates,
    returnvars = (:x, :τ),
    autostart  = false
)
RxInfer.KeepLast (Type)

Instructs the inference engine to keep only the last marginal update and disregard intermediate updates.

RxInfer.KeepEach (Type)

Instructs the inference engine to keep each marginal update for all intermediate iterations.

  • predictvars

predictvars specifies the variables which should be predicted. Similar to returnvars, predictvars accepts a NamedTuple or Dict. There are two specifications:

  • KeepLast: saves the last update for a variable, ignoring any intermediate results during iterations
  • KeepEach: saves all updates for a variable for all iterations
result = infer(
    ...,
    predictvars = (
        o = KeepLast(),
        τ = KeepEach()
    )
)
Note

The predictvars argument is exclusive for batch setting.

  • historyvars

Also read the Keeping the history of posteriors section.

historyvars specifies the variables of interest and the amount of information to keep in history about the posterior updates when performing streamline inference. The specification is similar to that of returnvars in the batch setting. historyvars requires keephistory to be greater than zero.

historyvars accepts a NamedTuple, a Dict, or a single return variable specification (KeepLast() or KeepEach()). There are two specifications:

  • KeepLast: saves the last update for a variable, ignoring any intermediate results during iterations
  • KeepEach: saves all updates for a variable for all iterations
result = infer(
    ...,
    autoupdates = my_autoupdates,
    historyvars = (
        x = KeepLast(),
        τ = KeepEach()
    ),
    keephistory = 10
)

It is also possible to set either historyvars = KeepLast() or historyvars = KeepEach() that acts as an alias and sets the given option for all random variables in the model.

result = infer(
    ...,
    autoupdates = my_autoupdates,
    historyvars = KeepLast(),
    keephistory = 10
)
  • keephistory

Specifies the buffer size for the updates history both for the historyvars and the free_energy buffers in streamline inference.

Note

The historyvars and keephistory arguments are exclusive to the streamline setting.

  • iterations

Specifies the number of variational (or loopy belief propagation) iterations. By default it is set to nothing, which is equivalent to performing 1 iteration. However, if set explicitly to 1, the default setting for returnvars changes from KeepLast to KeepEach.

  • free_energy

Batch inference:

Specifies if the infer function should return Bethe Free Energy (BFE) values.

  • Optionally accepts a floating-point type (e.g., Float64) for improved BFE computation performance, but restricts the use of automatic differentiation packages.

Streamline inference:

Specifies if the infer function should create an observable stream of Bethe Free Energy (BFE) values, computed at each VMP iteration.

  • When free_energy = true and keephistory > 0, additional fields are exposed in the engine for accessing the history of BFE updates.
    • engine.free_energy_history: Averaged BFE history over VMP iterations.
    • engine.free_energy_final_only_history: BFE history of values computed in the last VMP iteration for each observation.
    • engine.free_energy_raw_history: Raw BFE history.
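For the batch setting, a minimal sketch (reusing the beta_bernoulli model) that requests Bethe free energy values; result.free_energy then holds one value per iteration:

```julia
using RxInfer

@model function beta_bernoulli(y, a, b)
    x  ~ Beta(a, b)
    y .~ Bernoulli(x)
end

result = infer(
    model       = beta_bernoulli(a = 1, b = 1),
    data        = (y = [true, false, false],),
    iterations  = 5,
    free_energy = true,   # or a floating-point type such as Float64
)

# One Bethe free energy value per iteration
println(result.free_energy)
```

Because this model is tree-structured and belief propagation is exact on it, the BFE should, in theory, equal the negative log evidence, here log(12) ≈ 2.485 (the marginal likelihood of the sequence is B(2, 3) / B(1, 1) = 1/12). Passing free_energy_diagnostics = nothing would disable the NaN/Inf checks.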
  • free_energy_diagnostics

This setting specifies either a single diagnostic check or a tuple of diagnostic checks for the stream of Bethe Free Energy values. By default, it checks for NaNs and Infs. See also RxInfer.ObjectiveDiagnosticCheckNaNs and RxInfer.ObjectiveDiagnosticCheckInfs. Pass nothing to disable all checks.

  • options

RxInfer.ReactiveMPInferenceOptions (Type)
ReactiveMPInferenceOptions(; kwargs...)

Creates a model inference options object. The available options are listed below.

Options

  • limit_stack_depth: limits the stack depth for computing messages. This helps with StackOverflowError for some very large models, but reduces the performance of the inference backend. Accepts an integer that specifies the maximum recursion depth: lower values make a stack overflow less likely, but hurt performance.
  • warn: (optional) flag to suppress warnings. Warnings are not displayed if set to false. Defaults to true.

Advanced options

  • scheduler: changes the scheduler of reactive streams, see Rocket.jl for more info, defaults to AsapScheduler.
  • rulefallback: specifies a global message update rule fallback for cases when a specific message update rule is not available. Consult the ReactiveMP documentation for the list of available fallbacks.

See also: infer

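A minimal sketch of passing options, assuming the keyword constructor shown above is reachable as RxInfer.ReactiveMPInferenceOptions and that infer accepts the resulting object directly. The 1000-observation dataset is made up for illustration; the posterior still follows the conjugate update Beta(1 + 500, 1 + 500).

```julia
using RxInfer

@model function beta_bernoulli(y, a, b)
    x  ~ Beta(a, b)
    y .~ Bernoulli(x)
end

# Illustrative dataset: 500 `true` (odd i) and 500 `false` (even i)
y = [isodd(i) for i in 1:1000]

# Limit the recursion depth used while computing messages
options = RxInfer.ReactiveMPInferenceOptions(limit_stack_depth = 100)

result = infer(
    model   = beta_bernoulli(a = 1, b = 1),
    data    = (y = y,),
    options = options,
)

println(result.posteriors[:x])
```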
  • catch_exception

The catch_exception keyword argument specifies whether exceptions during the batch inference procedure should be caught and stored in the error field of the result. By default, if an exception occurs during the inference procedure, the result is lost. Set catch_exception = true to obtain a partial result if an exception occurs. Use the RxInfer.issuccess and RxInfer.iserror functions to check whether the inference completed successfully or failed. If an error occurs, the error field stores a tuple whose first element is the exception itself and whose second element is the caught backtrace. Use the stacktrace function with the backtrace as an argument to recover the stacktrace of the error, and the Base.showerror function to display the error.
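A sketch of the error-handling workflow described above. To have something to catch, it deliberately throws from a before_iteration callback on the third iteration; this failure is purely illustrative, and the sketch assumes such a callback error is caught like any other exception raised during the iterations.

```julia
using RxInfer

@model function beta_bernoulli(y, a, b)
    x  ~ Beta(a, b)
    y .~ Bernoulli(x)
end

result = infer(
    model           = beta_bernoulli(a = 1, b = 1),
    data            = (y = [true, false, false],),
    iterations      = 5,
    catch_exception = true,
    callbacks       = (
        # Illustration only: simulate a failure on the third iteration
        before_iteration = (model, iteration) -> iteration == 3 ? error("simulated failure") : false,
    ),
)

if RxInfer.iserror(result)
    exception, backtrace = result.error   # (exception, caught backtrace)
    showerror(stdout, exception)
    println()
    frames = stacktrace(backtrace)        # recover the stacktrace from the backtrace
    println(length(frames), " stack frames recovered")
end
```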

  • callbacks

The inference function has its own lifecycle. The user is free to provide some (or none) of the callbacks to inject some extra logging or other procedures in the inference function, e.g.

result = infer(
    ...,
    callbacks = (
        on_marginal_update = (model, name, update) -> println("$(name) has been updated: $(update)"),
        after_inference    = (args...) -> println("Inference has been completed")
    )
)

The callbacks keyword argument accepts a NamedTuple of 'name = callback' pairs. The list of all possible callbacks for the different inference settings (batch or streamline) and their arguments is given below:

  • before_model_creation()
  • after_model_creation(model::ProbabilisticModel)

Exclusive for batch inference

  • on_marginal_update(model::ProbabilisticModel, name::Symbol, update)
  • before_inference(model::ProbabilisticModel)
  • before_iteration(model::ProbabilisticModel, iteration::Int)::Bool
  • before_data_update(model::ProbabilisticModel, data)
  • after_data_update(model::ProbabilisticModel, data)
  • after_iteration(model::ProbabilisticModel, iteration::Int)::Bool
  • after_inference(model::ProbabilisticModel)
Note

The before_iteration and after_iteration callbacks may return a true/false value. Returning true indicates that the iterations must be halted and no further inference should be performed.

Exclusive for streamline inference

  • before_autostart(engine::RxInferenceEngine)
  • after_autostart(engine::RxInferenceEngine)
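A minimal sketch of early stopping with the after_iteration callback: returning true halts the remaining iterations, so the inference below stops after 3 of the 100 requested iterations. The Ref counter is only there to observe how many iterations actually ran.

```julia
using RxInfer

@model function beta_bernoulli(y, a, b)
    x  ~ Beta(a, b)
    y .~ Bernoulli(x)
end

last_iteration = Ref(0)

result = infer(
    model      = beta_bernoulli(a = 1, b = 1),
    data       = (y = [true, false, false],),
    iterations = 100,
    callbacks  = (
        after_iteration = (model, iteration) -> begin
            last_iteration[] = iteration
            return iteration >= 3   # `true` halts further iterations
        end,
    ),
)

println("stopped after ", last_iteration[], " iterations")
```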
  • addons

The addons field extends the default message computation rules with some extra information, e.g. computing log-scaling factors of messages or saving debug-information. Accepts a single addon or a tuple of addons. Automatically changes the default value of the postprocess argument to NoopPostprocess.

  • postprocess

Also read the Inference results postprocessing section.

The postprocess keyword argument controls whether the inference results are modified in some way before the inference function returns. By default, the inference function uses the DefaultPostprocess strategy, which removes the Marginal wrapper type from the results. Change this setting to NoopPostprocess if you would like to keep the Marginal wrapper type, which might be useful in combination with the addons argument. If the addons argument has been used, the default strategy automatically changes to NoopPostprocess.
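A small sketch comparing the two strategies, assuming the no-op strategy is constructed as RxInfer.NoopPostprocess() (the qualified name is used in case it is not exported). With the default strategy the posterior is a plain Beta; with the no-op strategy it stays wrapped.

```julia
using RxInfer

@model function beta_bernoulli(y, a, b)
    x  ~ Beta(a, b)
    y .~ Bernoulli(x)
end

data = (y = [true, false, false],)

# Default: DefaultPostprocess() strips the Marginal wrapper
unwrapped = infer(model = beta_bernoulli(a = 1, b = 1), data = data)

# NoopPostprocess() leaves the results as produced by the engine
wrapped = infer(
    model       = beta_bernoulli(a = 1, b = 1),
    data        = data,
    postprocess = RxInfer.NoopPostprocess(),
)

println(typeof(unwrapped.posteriors[:x]))
println(typeof(wrapped.posteriors[:x]))
```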

Where to go next?

Read more about the other keyword arguments in the Streamlined (online) inference section, check out the Static Inference section, or browse some more advanced examples.