Automatic Inference Specification

RxInfer provides the infer function for quickly running and testing your model with both static and streaming datasets. To enable streaming behavior, the infer function accepts an autoupdates argument, which specifies how to update your priors for future states based on newly updated posteriors.

It's important to note that while this function covers most capabilities of the inference engine, advanced use cases may require resorting to the Manual Inference Specification.

For details on manual inference specification, see the Manual Inference section.

RxInfer.infer (Function)
infer(
    model; 
    data = nothing,
    datastream = nothing,
    autoupdates = nothing,
    initmarginals = nothing,
    initmessages = nothing,
    constraints = nothing,
    meta = nothing,
    options = nothing,
    returnvars = nothing, 
    predictvars = nothing, 
    historyvars = nothing,
    keephistory = nothing,
    iterations = nothing,
    free_energy = false,
    free_energy_diagnostics = BetheFreeEnergyDefaultChecks,
    showprogress = false,
    callbacks = nothing,
    addons = nothing,
    postprocess = DefaultPostprocess(),
    warn = true,
    events = nothing,
    uselock = false,
    autostart = true,
    catch_exception = false
)

This function provides a generic way to perform probabilistic inference for batch/static and streamline/online scenarios. Returns an InferenceResult (batch setting) or RxInferenceEngine (streamline setting) based on the parameters used.

Arguments

For more information about some of the arguments, please check below.

  • model: specifies a model generator, required
  • data: NamedTuple or Dict with data, required (or datastream or predictvars)
  • datastream: A stream of NamedTuple with data, required (or data)
  • autoupdates = nothing: auto-updates specification, required for streamline inference, see @autoupdates
  • initmarginals = nothing: NamedTuple or Dict with initial marginals, optional
  • initmessages = nothing: NamedTuple or Dict with initial messages, optional
  • constraints = nothing: constraints specification object, optional, see @constraints
  • meta = nothing: meta specification object, optional, may be required for some models, see @meta
  • options = nothing: model creation options, optional, see ModelInferenceOptions
  • returnvars = nothing: return structure info, optional, defaults to return everything at each iteration, see below for more information
  • predictvars = nothing: return structure info, optional, see below for more information (exclusive for batch inference)
  • historyvars = nothing: history structure info, optional, defaults to no history, see below for more information (exclusive for streamline inference)
  • keephistory = nothing: history buffer size, defaults to empty buffer, see below for more information (exclusive for streamline inference)
  • iterations = nothing: number of iterations, optional, defaults to nothing; note that the inference engine does not distinguish between variational message passing, loopy belief propagation, or expectation propagation iterations, see below for more information
  • free_energy = false: compute the Bethe free energy, optional, defaults to false. Can be passed a floating point type, e.g. Float64, for better efficiency, but disables automatic differentiation packages, such as ForwardDiff.jl
  • free_energy_diagnostics = BetheFreeEnergyDefaultChecks: free energy diagnostic checks, optional, by default checks for possible NaNs and Infs. nothing disables all checks.
  • showprogress = false: show a progress indicator, optional, defaults to false (exclusive for batch inference)
  • catch_exception = false: specifies whether exceptions during the inference procedure should be caught, optional, defaults to false (exclusive for batch inference)
  • callbacks = nothing: inference cycle callbacks, optional, see below for more info
  • addons = nothing: inject and send extra computation information along messages, see below for more info
  • postprocess = DefaultPostprocess(): inference results postprocessing step, optional, see below for more info
  • events = nothing: inference cycle events, optional, see below for more info (exclusive for streamline inference)
  • uselock = false: specifies whether to use a lock structure for the inference; if set to true uses Base.Threads.SpinLock. Accepts a custom AbstractLock. (exclusive for streamline inference)
  • autostart = true: specifies whether to call RxInfer.start on the created engine automatically or not (exclusive for streamline inference)
  • warn = true: enables/disables warnings

Note on NamedTuples

When passing a NamedTuple as a value for some argument, make sure you use a trailing comma for NamedTuples with a single entry. The reason is that Julia treats the returnvars = (x = KeepLast()) and returnvars = (x = KeepLast(), ) expressions differently. The first expression creates (or overwrites!) a new local/global variable named x with the contents KeepLast(). The second expression (note the trailing comma) creates a NamedTuple with x as a key and KeepLast() as the value assigned to this key.
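The difference can be seen with plain Julia values (a sketch using a stand-in Symbol instead of KeepLast() so that it runs without RxInfer):

```julia
# a value standing in for e.g. `KeepLast()` from RxInfer
value = :keep_last

without_comma = (x = value)    # plain assignment: creates/overwrites a variable `x`
with_comma    = (x = value,)   # trailing comma: a NamedTuple with the key `:x`
```

Here without_comma is just the Symbol :keep_last (and a variable x has silently been created), whereas with_comma is a proper single-entry NamedTuple.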

The model argument accepts a ModelGenerator as its input. The easiest way to create the ModelGenerator is to use the @model macro. For example:

@model function coin_toss(some_argument, some_keyword_argument = 3)
    ...
end

result = infer(
    model = coin_toss(some_argument; some_keyword_argument = 3)
)

Note: The model keyword argument does not accept a FactorGraphModel instance as a value, as it needs to inject constraints and meta during the inference procedure.

  • data

Either data or datastream or predictvars keyword argument is required. Specifying both data and datastream is not supported and will result in an error. Specifying both datastream and predictvars is not supported and will result in an error.

Note: The behavior of the data keyword argument depends on the inference setting (batch or streamline).

The data keyword argument must be a NamedTuple (or Dict) where keys (of Symbol type) correspond to all datavars defined in the model specification. For example, if a model defines x = datavar(Float64) the data field must have an :x key (of Symbol type) which holds a value of type Float64. The values in the data must have the exact same shape as the datavar container. In other words, if a model defines x = datavar(Float64, n) then data[:x] must provide a container with length n and with elements of type Float64.

  • streamline setting

All entries in the data argument are zipped together with the Base.zip function to form one slice of the data chunk. This means all containers in the data argument must be of the same size (the zip iterator finishes as soon as one container has no remaining values). In order to use a fixed value for some specific datavar, it is not necessary to create a container filled with that value; it is more efficient to use Iterators.repeated to create an infinite iterator.
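The zipping behavior with a fixed value can be sketched in plain Julia (no RxInfer required):

```julia
ys = [1.0, 2.0, 3.0]                       # a regular data container
fixed_precision = Iterators.repeated(0.1)  # infinite iterator holding a fixed value

# `zip` finishes as soon as the shortest iterator is exhausted,
# so the infinite iterator simply pairs `0.1` with every element of `ys`
slices = collect(zip(ys, fixed_precision))
```

Each element of slices is one slice of the data chunk, e.g. the first slice pairs 1.0 with 0.1.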

  • datastream

The datastream keyword argument must be an observable that supports subscribe! and unsubscribe! functions (streams from the Rocket.jl package are also supported). The elements of the observable must be of type NamedTuple where keys (of Symbol type) correspond to all datavars defined in the model specification, except for those which are listed in the autoupdates specification. For example, if a model defines x = datavar(Float64) (which is not part of the autoupdates specification) the named tuple from the observable must have an :x key (of Symbol type) which holds a value of type Float64. The values in the named tuple must have the exact same shape as the datavar container. In other words, if a model defines x = datavar(Float64, n) then namedtuple[:x] must provide a container with length n and with elements of type Float64.

Note: The behavior of the individual named tuples from the datastream observable is similar to that which is used in the batch setting. In fact, you can see the streamline inference as an efficient version of the batch inference, which automatically updates some datavars with the autoupdates specification and listens to the datastream to update the rest of the datavars.
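As a minimal sketch, a finite datastream can be created from a plain vector of NamedTuples with Rocket.jl (my_model and my_autoupdates are hypothetical placeholders):

```julia
using Rocket

# each element is a NamedTuple whose keys match the remaining datavars (here `y`)
observations = from([ (y = 0.5, ), (y = 1.5, ), (y = 2.5, ) ])

engine = infer(
    model       = my_model(),        # hypothetical model generator
    datastream  = observations,
    autoupdates = my_autoupdates,    # hypothetical @autoupdates specification
    autostart   = true
)
```

Because the stream is finite, the engine processes the three observations and the stream completes.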

  • initmarginals

For specific types of inference algorithms, such as variational message passing, it might be required to initialize (some of) the marginals before running the inference procedure in order to break dependency loops. If this is not done, the inference algorithm will not execute properly, since messages and/or marginals cannot be updated due to the lack of information. To specify these initial marginals, use the initmarginals argument, such as

infer(...
    initmarginals = (
        # initialize the marginal distribution of x as a vague Normal distribution
        # if x is a vector, then it simply uses the same value for all elements
        # However, it is also possible to provide a vector of distributions to set each element individually 
        x = vague(NormalMeanPrecision),  
    ),
)

This argument needs to be a named tuple, i.e. `initmarginals = (a = ..., )`, or a dictionary.

  • initmessages

For specific types of inference algorithms, such as loopy belief propagation or expectation propagation, it might be required to initialize (some of) the messages before running the inference procedure in order to break dependency loops. If this is not done, the inference algorithm will not execute properly, since messages and/or marginals cannot be updated due to the lack of information. To specify these initial messages, use the `initmessages` argument, such as

infer(...
    initmessages = (
        # initialize the message distribution of x as a vague Normal distribution
        # if x is a vector, then it simply uses the same value for all elements
        # However, it is also possible to provide a vector of distributions to set each element individually
        x = vague(NormalMeanPrecision),
    ),
)

  • options

  • limit_stack_depth: limits the stack depth for computing messages, helps with StackOverflowError for some huge models, but reduces the performance of inference backend. Accepts integer as an argument that specifies the maximum number of recursive depth. Lower is better for stack overflow error, but worse for performance.

  • pipeline: changes the default pipeline for each factor node in the graph

  • global_reactive_scheduler: changes the scheduler of reactive streams, see Rocket.jl for more info, defaults to no scheduler
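As a sketch, options are passed as a NamedTuple (note the trailing comma for the single entry; the depth value here is an arbitrary illustration):

```julia
result = infer(
    ...,
    options = (limit_stack_depth = 100, )
)
```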

  • returnvars

returnvars specifies latent variables of interest and their posterior updates. Its behavior depends on the inference type: streamline or batch.

Batch inference:

  • Accepts a NamedTuple or Dict of return variable specifications.
  • Two specifications available: KeepLast (saves the last update) and KeepEach (saves all updates).
  • When iterations is set, infer returns every update for each iteration (equivalent to KeepEach()); when iterations = nothing, it saves only the last update (equivalent to KeepLast()).
  • Use iterations = 1 to force KeepEach() behavior for a single iteration, or set returnvars = KeepEach() manually.

Example:

result = infer(
    ...,
    returnvars = (
        x = KeepLast(),
        τ = KeepEach()
    )
)

Shortcut for setting the same option for all variables:

result = infer(
    ...,
    returnvars = KeepLast()  # or KeepEach()
)

Streamline inference:

  • For each symbol in returnvars, infer creates an observable stream of posterior updates.
  • Agents can subscribe to these updates using the Rocket.jl package.

Example:

engine = infer(
    ...,
    autoupdates = my_autoupdates,
    returnvars = (:x, :τ),
    autostart  = false
)
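For example, with autostart = false one can subscribe to a posterior stream before starting the engine (a sketch; the printing callback is illustrative):

```julia
using Rocket

engine = infer(
    ...,
    autoupdates = my_autoupdates,
    returnvars  = (:x, ),
    autostart   = false
)

# subscribe before starting so that no posterior update is missed
subscription = subscribe!(engine.posteriors[:x], (posterior) -> println("New posterior for x: ", posterior))

RxInfer.start(engine)

# ... later, release the subscription
unsubscribe!(subscription)
```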
  • predictvars

predictvars specifies the variables which should be predicted. In the model definition these variables are specified as datavars, but they should not be passed inside the data argument.

Similar to returnvars, predictvars accepts a NamedTuple or Dict. There are two specifications:

  • KeepLast: saves the last update for a variable, ignoring any intermediate results during iterations
  • KeepEach: saves all updates for a variable for all iterations

Example:

result = infer(
    ...,
    predictvars = (
        o = KeepLast(),
        τ = KeepEach()
    )
)

Note: The predictvars argument is exclusive for batch setting.

  • historyvars

historyvars specifies the variables of interest and the amount of information to keep in history about the posterior updates when performing streamline inference. The specification is similar to returnvars in the batch setting. historyvars requires keephistory to be greater than zero.

historyvars accepts a NamedTuple, a Dict, or a single return-variable specification. There are two specifications:

  • KeepLast: saves the last update for a variable, ignoring any intermediate results during iterations
  • KeepEach: saves all updates for a variable for all iterations

Example:

result = infer(
    ...,
    autoupdates = my_autoupdates,
    historyvars = (
        x = KeepLast(),
        τ = KeepEach()
    ),
    keephistory = 10
)

It is also possible to set either historyvars = KeepLast() or historyvars = KeepEach() that acts as an alias and sets the given option for all random variables in the model.

Example:

result = infer(
    ...,
    autoupdates = my_autoupdates,
    historyvars = KeepLast(),
    keephistory = 10
)
  • keephistory

Specifies the buffer size for the updates history both for the historyvars and the free_energy buffers in streamline inference.

  • iterations

Specifies the number of variational (or loopy belief propagation) iterations. By default set to nothing, which is equivalent to performing a single iteration.

  • free_energy

Streamline inference:

Specifies if the infer function should create an observable stream of Bethe Free Energy (BFE) values, computed at each VMP iteration.

  • When free_energy = true and keephistory > 0, additional fields are exposed in the engine for accessing the history of BFE updates.
    • engine.free_energy_history: Averaged BFE history over VMP iterations.
    • engine.free_energy_final_only_history: BFE history of the values computed in the last VMP iteration for each observation.
    • engine.free_energy_raw_history: Raw BFE history.

Batch inference:

Specifies if the infer function should return Bethe Free Energy (BFE) values.

  • Optionally accepts a floating-point type (e.g., Float64) for improved BFE computation performance, but restricts the use of automatic differentiation packages.
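A minimal sketch of enabling BFE tracking in the batch setting:

```julia
result = infer(
    ...,
    iterations  = 10,
    free_energy = true
)

# one BFE value per iteration; for a well-behaved VMP run the values typically decrease
result.free_energy
```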

  • free_energy_diagnostics

This setting specifies a single diagnostic check, or a tuple of them, for the Bethe Free Energy values stream. By default checks for NaNs and Infs. See also BetheFreeEnergyCheckNaNs and BetheFreeEnergyCheckInfs. Pass nothing to disable all checks.

  • catch_exception

The catch_exception keyword argument specifies whether exceptions during the batch inference procedure should be caught in the error field of the result. By default, if an exception occurs during the inference procedure, the result is lost. Set catch_exception = true to obtain a partial result in case an exception occurs. Use the RxInfer.issuccess and RxInfer.iserror functions to check whether the inference completed successfully or failed. If an error occurs, the error field stores a tuple, where the first element is the exception itself and the second element is the caught backtrace. Use the stacktrace function with the backtrace as an argument to recover the stacktrace of the error, and the Base.showerror function to display it.
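A sketch of inspecting a failed run:

```julia
result = infer(
    ...,
    catch_exception = true
)

if RxInfer.iserror(result)
    exception, backtrace = result.error
    Base.showerror(stderr, exception)   # display the error itself
    # stacktrace(backtrace)             # recover the stacktrace if needed
else
    # inference succeeded, `result.posteriors` is safe to use
end
```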

  • callbacks

The inference function has its own lifecycle. The user is free to provide some (or none) of the callbacks to inject some extra logging or other procedures in the inference function, e.g.

result = infer(
    ...,
    callbacks = (
        on_marginal_update = (model, name, update) -> println("$(name) has been updated: $(update)"),
        after_inference    = (args...) -> println("Inference has been completed")
    )
)

The callbacks keyword argument accepts a named tuple of 'name = callback' pairs. The list of all possible callbacks for the different inference settings (batch or streamline) and their arguments is presented below:

  • on_marginal_update: args: (model::FactorGraphModel, name::Symbol, update) (exclusive for batch inference)
  • before_model_creation: args: ()
  • after_model_creation: args: (model::FactorGraphModel, returnval)
  • before_inference: args: (model::FactorGraphModel) (exclusive for batch inference)
  • before_iteration: args: (model::FactorGraphModel, iteration::Int)::Bool (exclusive for batch inference)
  • before_data_update: args: (model::FactorGraphModel, data) (exclusive for batch inference)
  • after_data_update: args: (model::FactorGraphModel, data) (exclusive for batch inference)
  • after_iteration: args: (model::FactorGraphModel, iteration::Int)::Bool (exclusive for batch inference)
  • after_inference: args: (model::FactorGraphModel) (exclusive for batch inference)
  • before_autostart: args: (engine::RxInferenceEngine) (exclusive for streamline inference)
  • after_autostart: args: (engine::RxInferenceEngine) (exclusive for streamline inference)

The before_iteration and after_iteration callbacks may return a true/false value; true indicates that the iterations must be halted and no further inference should be performed.
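For example, an early-stopping rule can be sketched as follows (the stopping condition is an arbitrary illustration; note the trailing comma for the single-entry NamedTuple):

```julia
result = infer(
    ...,
    iterations = 100,
    callbacks  = (
        # returning `true` halts the remaining iterations early
        after_iteration = (model, iteration) -> iteration >= 10,
    )
)
```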

  • addons

The addons field extends the default message computation rules with some extra information, e.g. computing log-scaling factors of messages or saving debug-information. Accepts a single addon or a tuple of addons. If set, replaces the corresponding setting in the options. Automatically changes the default value of the postprocess argument to NoopPostprocess.
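For instance, a log-scale addon can be attached to the computed messages (a minimal sketch; AddonLogScale is assumed to be available from the RxInfer/ReactiveMP ecosystem):

```julia
result = infer(
    ...,
    # attach extra computation info to messages; with addons enabled the default
    # postprocess step automatically becomes NoopPostprocess, so results keep
    # their Marginal wrappers and the addon data remains accessible
    addons = (AddonLogScale(), )
)
```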

  • postprocess

The postprocess keyword argument controls whether the inference results must be modified in some way before exiting the inference function. By default, the inference function uses the DefaultPostprocess strategy, which by default removes the Marginal wrapper type from the results. Change this setting to NoopPostprocess if you would like to keep the Marginal wrapper type, which might be useful in the combination with the addons argument. If the addons argument has been used, automatically changes the default strategy value to NoopPostprocess.

RxInfer.InferenceResult (Type)
InferenceResult

This structure is used as a return value from the infer function.

Public Fields

  • posteriors: Dict or NamedTuple of 'random variable' - 'posterior' pairs. See the returnvars argument for infer.
  • free_energy: (optional) An array of Bethe Free Energy values per VMP iteration. See the free_energy argument for infer.
  • model: FactorGraphModel object reference.
  • returnval: Return value from executed @model.
  • error: (optional) A reference to an exception, that might have occurred during the inference. See the catch_exception argument for infer.

See also: infer

RxInfer.start (Function)
start(engine::RxInferenceEngine)

Starts the RxInferenceEngine by subscribing to the data source, instantiating the free energy stream (if enabled) and starting the event loop. Use RxInfer.stop to stop the RxInferenceEngine. Note that it is not always possible to stop/restart the engine; this depends on the data source type.
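A typical pattern is to create the engine with autostart = false, attach subscriptions, and then start it manually (a sketch; my_autoupdates is a placeholder):

```julia
engine = infer(
    ...,
    autoupdates = my_autoupdates,
    autostart   = false
)

# ... set up subscriptions on `engine.posteriors` here ...

RxInfer.start(engine)   # subscribes to the data source and begins the inference
```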

See also: RxInfer.stop

RxInfer.stop (Function)
stop(engine::RxInferenceEngine)

Stops the RxInferenceEngine by unsubscribing from the data source and the free energy stream (if enabled), and stopping the event loop. Use RxInfer.start to start the RxInferenceEngine again. Note that it is not always possible to stop/restart the engine; this depends on the data source type.

See also: RxInfer.start

RxInfer.@autoupdates (Macro)
@autoupdates

Creates the auto-updates specification for the infer function. In the online-streaming Bayesian inference procedure it is important to update your priors for future states based on the newly updated posteriors. The @autoupdates structure simplifies such a specification. It accepts a single block of code where each line defines how to update the datavars in the probabilistic model specification.

Each line of code in the auto-update specification places the datavars that need to be updated on the left hand side of the equality expression, and the update function on the right hand side. The update function operates on posterior marginals in the form of the q(symbol) expression.

For example:

@autoupdates begin 
    x = f(q(z))
end

This structure specifies to automatically update x = datavar(...) as soon as the inference engine computes new posterior over z variable. It then applies the f function to the new posterior and calls update!(x, ...) automatically.

As an example consider the following model and auto-update specification:

@model function kalman_filter()
    x_current_mean = datavar(Float64)
    x_current_var  = datavar(Float64)

    x_current ~ Normal(mean = x_current_mean, var = x_current_var)

    x_next ~ Normal(mean = x_current, var = 1.0)

    y = datavar(Float64)
    y ~ Normal(mean = x_next, var = 1.0)
end

This model has two datavars that represent our prior knowledge of the x_current state of the system. The x_next random variable represents the next state of the system, which is connected to the observed variable y. The auto-update specification could look like:

autoupdates = @autoupdates begin
    x_current_mean, x_current_var = mean_cov(q(x_next))
end

This structure specifies to update our prior as soon as we have a new posterior q(x_next). It then applies the mean_cov function on the updated posteriors and updates datavars x_current_mean and x_current_var automatically.
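Putting the pieces together, a streaming run for this model could be sketched as follows (the observations, the vague initial marginal for x_next, and the history settings are illustrative assumptions):

```julia
using Rocket

observations = from([ (y = 0.5, ), (y = 1.0, ), (y = 1.5, ) ])

engine = infer(
    model         = kalman_filter(),
    datastream    = observations,
    autoupdates   = autoupdates,
    # the very first auto-update needs an initial posterior for x_next
    initmarginals = (x_next = vague(NormalMeanVariance), ),
    returnvars    = (:x_next, ),
    keephistory   = 3,
    historyvars   = (x_next = KeepLast(), ),
    autostart     = true
)
```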

See also: infer

RxInfer.RxInferenceEngine (Type)
RxInferenceEngine

The return value of the infer function in the streamline setting.

Public fields

  • posteriors: Dict or NamedTuple of 'random variable' - 'posterior stream' pairs. See the returnvars argument for the infer.
  • free_energy: (optional) A stream of Bethe Free Energy values per VMP iteration. See the free_energy argument for the infer.
  • history: (optional) Saves history of previous marginal updates. See the historyvars and keephistory arguments for the infer.
  • free_energy_history: (optional) Free energy history, average over variational iterations
  • free_energy_raw_history: (optional) Free energy history, returns computed values of all variational iterations for each data event (if available)
  • free_energy_final_only_history: (optional) Free energy history, returns computed values of final variational iteration for each data event (if available)
  • events: (optional) A stream of events sent by the inference engine. See the events argument for the infer.
  • model: FactorGraphModel object reference.
  • returnval: Return value from executed @model.

Use the RxInfer.start(engine) function to subscribe on the data source and start the inference procedure. Use RxInfer.stop(engine) to unsubscribe from the data source and stop the inference procedure. Note, that it is not always possible to start/stop the inference procedure.

See also: infer, RxInferenceEvent, RxInfer.start, RxInfer.stop

RxInfer.RxInferenceEvent (Type)
RxInferenceEvent{T, D}

The RxInferenceEngine sends events in the form of the RxInferenceEvent structure. T represents the type of an event, D represents the type of the data associated with the event. The type of data depends on the type of an event, but usually represents a tuple, which can be unpacked automatically with Julia's destructuring syntax, e.g. model, iteration = event. See the documentation of the infer function for possible event types and their associated data types.

The events system itself uses the Rocket.jl library API. For example, one may create a custom event listener in the following way:

using Rocket

struct MyEventListener <: Rocket.Actor{RxInferenceEvent}
    # ... extra fields
end

function Rocket.on_next!(listener::MyEventListener, event::RxInferenceEvent{ :after_iteration })
    model, iteration = event
    println("Iteration $(iteration) has been finished.")
end

function Rocket.on_error!(listener::MyEventListener, err)
    # ...
end

function Rocket.on_complete!(listener::MyEventListener)
    # ...
end

and later on:

engine = infer(events = Val((:after_iteration, )), ...)

subscription = subscribe!(engine.events, MyEventListener(...))

See also: infer, RxInferenceEngine
